Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
[Update] OpenAI model update, bigcodebench update (#1879)

* [Update] OpenAI model update, bigcodebench update
* update

This commit is contained in:
parent 27c916661d
commit d7daee6e25
@@ -121,8 +121,40 @@ class BigCodeBenchEvaluator(BaseEvaluator):
         logger.info('Start to extract code from predictions')
         sanitized_predictions = []
         for prediction, entrypoint in zip(predictions, entrypoints):
-            sanitized_prediction = extract_code_generation(
-                prediction, entrypoint=entrypoint)
+            try:
+                import signal
+                from contextlib import contextmanager
+
+                @contextmanager
+                def timeout_handler(seconds):
+
+                    def _handle_timeout(signum, frame):
+                        raise TimeoutError(f'Code extraction timed out '
+                                           f'after {seconds} seconds')
+
+                    original_handler = signal.signal(signal.SIGALRM,
+                                                     _handle_timeout)
+                    signal.alarm(seconds)
+                    try:
+                        yield
+                    finally:
+                        signal.alarm(0)
+                        signal.signal(signal.SIGALRM, original_handler)
+
+                with timeout_handler(10):
+                    sanitized_prediction = extract_code_generation(
+                        prediction, entrypoint=entrypoint)
+
+            except TimeoutError as e:
+                logger.warning(
+                    f'Code extraction timeout for entrypoint {entrypoint}: '
+                    f'{str(e)}')
+                sanitized_prediction = ''
+            except Exception as e:
+                logger.warning(
+                    f'Code extraction failed for entrypoint {entrypoint}: '
+                    f'{str(e)}')
+                sanitized_prediction = ''
             sanitized_predictions.append(sanitized_prediction)
 
         # Prepare for submission
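The extraction timeout above relies on signal.alarm, which works only on Unix and only in the main thread. A minimal standalone sketch of the same alarm-based pattern; slow_extract is a hypothetical stand-in for extract_code_generation:

import signal
import time
from contextlib import contextmanager


@contextmanager
def timeout_handler(seconds):

    def _handle_timeout(signum, frame):
        raise TimeoutError(f'timed out after {seconds} seconds')

    original_handler = signal.signal(signal.SIGALRM, _handle_timeout)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, original_handler)  # restore handler


def slow_extract():
    time.sleep(5)  # stand-in for a hanging extract_code_generation call
    return 'def task_func(): ...'


try:
    with timeout_handler(1):
        result = slow_extract()
except TimeoutError:
    result = ''  # fall back to an empty prediction, as the evaluator does

print(repr(result))  # -> ''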
@@ -25,12 +25,7 @@ OPENAI_API_BASE = os.path.join(
 OPENAISDK_API_BASE = os.environ.get('OPENAI_BASE_URL',
                                     'https://api.openai.com/v1/')
 
-O1_MODEL_LIST = [
-    'o1-preview-2024-09-12',
-    'o1-mini-2024-09-12',
-    'o1-preview',
-    'o1-mini',
-]
+O1_MODEL_LIST = ['o1', 'o3']
 
 
 @MODELS.register_module()
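The shortened O1_MODEL_LIST works together with the substring check introduced further down (any(model in self.path for model in O1_MODEL_LIST)), so dated snapshots still match without enumerating each one. A small sketch with illustrative model names:

O1_MODEL_LIST = ['o1', 'o3']

# Substring matching catches dated or suffixed variants that an exact
# `path in O1_MODEL_LIST` membership test would miss.
for path in ['o1', 'o1-mini', 'o1-2024-12-17', 'o3-mini', 'gpt-4o']:
    matched = any(model in path for model in O1_MODEL_LIST)
    print(path, '->', matched)
# Only 'gpt-4o' fails to match here; the trade-off is that any future
# name containing 'o1' or 'o3' as a substring would also match.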
@@ -96,7 +91,6 @@ class OpenAI(BaseAPIModel):
         temperature: Optional[float] = None,
         tokenizer_path: Optional[str] = None,
         extra_body: Optional[Dict] = None,
-        max_completion_tokens: int = 16384,
         verbose: bool = False,
     ):
 
@@ -151,9 +145,6 @@ class OpenAI(BaseAPIModel):
         self.proxy_url = openai_proxy_url
 
         self.path = path
-        self.max_completion_tokens = max_completion_tokens
-        self.logger.warning(
-            f'Max Completion tokens for {path} is {max_completion_tokens}')
 
     def generate(
         self,
@@ -250,16 +241,15 @@ class OpenAI(BaseAPIModel):
             header['OpenAI-Organization'] = self.orgs[self.org_ctr]
 
             try:
-                if self.path in O1_MODEL_LIST:
+                if any(model in self.path for model in O1_MODEL_LIST):
                     self.logger.warning(
                         f"'max_token' is unsupported for model {self.path}")
                     self.logger.warning(
-                        f'We use max_completion_tokens: '
-                        f'{self.max_completion_tokens}for this query')
+                        f'We use max_out_len: {max_out_len} for this query')
                     data = dict(
                         model=self.path,
                         messages=messages,
-                        max_completion_tokens=self.max_completion_tokens,
+                        max_completion_tokens=max_out_len,
                         n=1,
                         logprobs=self.logprobs,
                         top_logprobs=self.top_logprobs,
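The net effect of this hunk is that the per-query max_out_len now sets the completion budget instead of the removed instance-wide self.max_completion_tokens; o-series chat endpoints reject max_tokens and expect max_completion_tokens instead. A sketch of the resulting request body, with illustrative values:

# Request body built for o-series models after this change.
max_out_len = 2048  # per-query budget passed in by the caller
data = dict(
    model='o1',
    messages=[{'role': 'user', 'content': 'Hello'}],
    max_completion_tokens=max_out_len,  # was self.max_completion_tokens
    n=1,
)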
@@ -440,7 +430,7 @@ class OpenAI(BaseAPIModel):
         if mode == 'front':
             cur_prompt = sep.join(words[-mid:])
         elif mode == 'mid':
-            cur_prompt = (sep.join(words[:mid]) + sep.join(words[-mid:]))
+            cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
         elif mode == 'rear':
             cur_prompt = sep.join(words[:mid])
 
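For reference, the three trim modes keep different ends of the word list: 'front' keeps the tail, 'rear' keeps the head, and 'mid' keeps both ends and drops the middle. A toy sketch; note that, exactly as in the code above, the two halves in 'mid' mode are concatenated with no separator between them:

words = ['w1', 'w2', 'w3', 'w4', 'w5', 'w6']
sep, mid = ' ', 2

print(sep.join(words[-mid:]))                          # front: 'w5 w6'
print(sep.join(words[:mid]))                           # rear:  'w1 w2'
print(sep.join(words[:mid]) + sep.join(words[-mid:]))  # mid:   'w1 w2w5 w6'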
@@ -480,7 +470,9 @@ class OpenAI(BaseAPIModel):
         """
         # Check input length when mode is 'none'
         if mode == 'none':
-            input_len = get_token_len_func(str(input))
+            input_len = (get_token_len_func(input) if isinstance(
+                input, str) else sum(
+                    get_token_len_func(item['prompt']) for item in input))
             if input_len > max_seq_len:
                 raise ValueError(
                     f'Input length ({input_len}) exceeds max_seq_len '
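The new length check avoids str(input), which for a PromptList also counted dict keys, role names, and quoting toward the token total. A sketch with a hypothetical whitespace token counter standing in for get_token_len_func:

def get_token_len_func(text):
    return len(text.split())  # toy tokenizer, illustration only

str_input = 'What is the capital of France?'
list_input = [
    {'role': 'SYSTEM', 'prompt': 'You are a helpful assistant.'},
    {'role': 'HUMAN', 'prompt': 'What is the capital of France?'},
]

for input in (str_input, list_input):
    input_len = (get_token_len_func(input) if isinstance(input, str)
                 else sum(get_token_len_func(item['prompt'])
                          for item in input))
    print(input_len)  # counts only prompt text, not dict structure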
@@ -499,12 +491,15 @@ class OpenAI(BaseAPIModel):
         # Convert input to messages format
         if isinstance(input, str):
             messages = [{'role': 'user', 'content': input}]
+            input_len = get_token_len_func(input)
         else:
             messages = []
+            processed_prompts = []
             for item in input:
                 input_content = item['prompt']
                 if mode != 'none':
                     input_content = bin_trim_wrapper(input_content)
+                processed_prompts.append(input_content)
                 msg = {'content': input_content}
                 if item['role'] == 'HUMAN':
                     msg['role'] = 'user'
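Between this hunk and the next, the unchanged code also maps BOT to assistant. A compact sketch of the overall role conversion; ROLE_MAP and the sample items are illustrative:

# HUMAN -> user, BOT -> assistant, SYSTEM -> system (BOT branch sits in
# the unchanged lines between the two hunks).
ROLE_MAP = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}

items = [
    {'role': 'SYSTEM', 'prompt': 'Be concise.'},
    {'role': 'HUMAN', 'prompt': 'Hi!'},
]
messages = [{'role': ROLE_MAP[item['role']], 'content': item['prompt']}
            for item in items]
print(messages)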
@@ -513,19 +508,18 @@ class OpenAI(BaseAPIModel):
                 elif item['role'] == 'SYSTEM':
                     msg['role'] = 'system'
                 messages.append(msg)
+            input_len = sum(
+                get_token_len_func(prompt) for prompt in processed_prompts)
 
         # Adjust max_out_len
         if max_out_len is not None:
             original_max_out_len = max_out_len
-            max_out_len = min(
-                max_out_len,
-                max_seq_len - get_token_len_func(str(input)) - 100)
+            max_out_len = min(max_out_len, max_seq_len - input_len - 100)
             if max_out_len <= 0:
                 raise ValueError(
                     f'max_out_len ({max_out_len}) is less than or equal to 0. '
-                    f'This may be due to input length '
-                    f'({get_token_len_func(str(input))}) being too close to '
-                    f'max_seq_len ({max_seq_len}). Please either increase '
+                    f'This may be due to input length ({input_len}) being too '
+                    f'close to max_seq_len ({max_seq_len}). Please increase '
                     f'max_seq_len or use a truncation mode other than "none".')
             if max_out_len < original_max_out_len:
                 self.logger.warning(
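The clamp reserves the measured input length plus a fixed 100-token safety margin out of the max_seq_len budget. A worked sketch with illustrative numbers:

max_seq_len = 4096
input_len = 3900
original_max_out_len = 1024

# Reserve input tokens plus a 100-token margin; cap the output budget.
max_out_len = min(original_max_out_len, max_seq_len - input_len - 100)
print(max_out_len)  # min(1024, 4096 - 3900 - 100) = 96
if max_out_len <= 0:
    raise ValueError('input too close to max_seq_len')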
@@ -555,7 +549,6 @@ class OpenAISDK(OpenAI):
         temperature: float | None = None,
         tokenizer_path: str | None = None,
         extra_body: Dict | None = None,
-        max_completion_tokens: int = 16384,
         verbose: bool = False,
         status_code_mappings: dict = {},
     ):
@@ -577,7 +570,6 @@ class OpenAISDK(OpenAI):
             tokenizer_path,
             extra_body,
             verbose=verbose,
-            max_completion_tokens=max_completion_tokens,
         )
         from openai import OpenAI
 
@@ -605,8 +597,23 @@ class OpenAISDK(OpenAI):
         self.logger.info(f'Used openai_client: {self.openai_client}')
         self.status_code_mappings = status_code_mappings
 
-    def _generate(self, input: PromptList | str, max_out_len: int,
-                  temperature: float) -> str:
+    def _generate(self,
+                  input: PromptList | str,
+                  max_out_len: int,
+                  temperature: float,
+                  timeout: int = 3600) -> str:
+        """Generate results given a list of inputs.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+            max_out_len (int): The maximum length of the output.
+            temperature (float): What sampling temperature to use.
+            timeout (int, optional): Timeout in seconds for the API call.
+                Defaults to 3600 (60 minutes).
+
+        Returns:
+            str: The generated string.
+        """
         from openai import APIStatusError, BadRequestError
 
         assert isinstance(input, (str, PromptList))
@@ -618,16 +625,14 @@ class OpenAISDK(OpenAI):
         num_retries = 0
         while num_retries < self.retry:
             self.wait()
 
-            if self.path in O1_MODEL_LIST:
+            if any(model in self.path for model in O1_MODEL_LIST):
                 self.logger.warning(
                     f"'max_token' is unsupported for model {self.path}")
                 self.logger.warning(
-                    f'We use max_completion_tokens: '
-                    f'{self.max_completion_tokens}for this query')
+                    f'We use max_out_len: {max_out_len} for this query')
                 query_data = dict(
                     model=self.path,
-                    max_completion_tokens=self.max_completion_tokens,
+                    max_completion_tokens=max_out_len,
                     n=1,
                     messages=messages,
                     extra_body=self.extra_body,
@@ -646,7 +651,8 @@ class OpenAISDK(OpenAI):
                 if self.verbose:
                     self.logger.info('Start calling OpenAI API')
                 responses = self.openai_client.chat.completions.create(
-                    **query_data)
+                    **query_data, timeout=timeout)  # timeout in seconds
 
                 if self.verbose:
                     self.logger.info(
                         'Successfully get response from OpenAI API')
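The openai v1 SDK accepts a per-request timeout (in seconds) on create(), in addition to the client-wide default. A minimal sketch of the call shape used above; the client setup and query_data values are illustrative:

from openai import OpenAI

client = OpenAI(api_key='sk-...', base_url='https://api.openai.com/v1/')
query_data = dict(
    model='o1',
    messages=[{'role': 'user', 'content': 'Hello'}],
    max_completion_tokens=256,
    n=1,
)
# The request is abandoned if no response arrives within 600 seconds.
responses = client.chat.completions.create(**query_data, timeout=600)
print(responses.choices[0].message.content)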
@@ -34,39 +34,29 @@ MAP = {
         '总分',
         '中文总分',
         '英文总分',
-        'instruct/compassbenchv1_4_IF_en_fofo_sub',
-        'instruct/compassbenchv1_4_IF_zh_fofo_sub',
+        'instruct/compassbench_2501_IF_en_chatIF_sub',
+        'instruct/compassbench_2501_IF_en_functionalIF_sub',
+        'instruct/compassbench_2501_IF_cn_chatIF_sub',
+        'instruct/compassbench_2501_IF_cn_functionalIF_sub',
     ],
     'language': [
         '总分',
         '中文总分',
         '英文总分',
-        'language/compassbenchv1_4_language_zh_chat_sub',
-        'language/compassbenchv1_4_language_zh_creation_sub',
-        'language/compassbenchv1_4_language_zh_NLP_sub',
-        'language/compassbenchv1_4_language_en_chat_sub',
-        'language/compassbenchv1_4_language_en_creation_sub',
-        'language/compassbenchv1_4_language_en_NLP_sub',
+        'language/compassbench_v2501_language_zh_chat_sub',
+        'language/compassbench_v2501_language_zh_nlp_sub',
+        'language/compassbench_v2501_language_zh_creation_sub',
+        'language/compassbench_v2501_language_en_chat_sub',
+        'language/compassbench_v2501_language_en_nlp_sub',
+        'language/compassbench_v2501_language_en_creation_sub',
     ],
-    'reasoning': [
+    'code': [
         '总分',
         '中文总分',
         '英文总分',
-        'reasoning/compassbenchv1_4_reasoning_en_CommonSenseSense_sub',
-        'reasoning/compassbenchv1_4_reasoning_en_Humanities_sub',
-        'reasoning/compassbenchv1_4_reasoning_en_ScienceEngineering_sub',
-        'reasoning/compassbenchv1_4_reasoning_en_Social_sub',
-        'reasoning/compassbenchv1_4_reasoning_zh_CommonSenseSense_sub',
-        'reasoning/compassbenchv1_4_reasoning_zh_Humanities_sub',
-        'reasoning/compassbenchv1_4_reasoning_zh_ScienceEngineering_sub',
-        'reasoning/compassbenchv1_4_reasoning_zh_Social_sub',
-    ],
-    'coding': [
-        '总分',
-        '中文总分',
-        '英文总分',
-        'coding/compassbenchv1_4_coding_en_sub',
-        'coding/compassbenchv1_4_coding_zh_sub',
+        'code/compassbench_2501_code_arena_en_sub',
+        'code/compassbench_2501_code_arena_zh_sub',
     ],
 }