[Update] OpenAI model update, bigcodebench update (#1879)

* [Update] Openai model update, bigcodebench update

* update
This commit is contained in:
Linchen Xiao 2025-02-20 19:33:25 +08:00 committed by GitHub
parent 27c916661d
commit d7daee6e25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 58 deletions

View File

@ -121,8 +121,40 @@ class BigCodeBenchEvaluator(BaseEvaluator):
logger.info('Start to extract code from predictions')
sanitized_predictions = []
for prediction, entrypoint in zip(predictions, entrypoints):
sanitized_prediction = extract_code_generation(
prediction, entrypoint=entrypoint)
try:
import signal
from contextlib import contextmanager
@contextmanager
def timeout_handler(seconds):
def _handle_timeout(signum, frame):
raise TimeoutError(f'Code extraction timed out'
f'after {seconds} seconds')
original_handler = signal.signal(signal.SIGALRM,
_handle_timeout)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, original_handler)
with timeout_handler(10):
sanitized_prediction = extract_code_generation(
prediction, entrypoint=entrypoint)
except TimeoutError as e:
logger.warning(
f'Code extraction timeout for entrypoint {entrypoint}: '
f'{str(e)}')
sanitized_prediction = ''
except Exception as e:
logger.warning(
f'Code extraction failed for entrypoint {entrypoint}: '
f'{str(e)}')
sanitized_prediction = ''
sanitized_predictions.append(sanitized_prediction)
# Prepare for submission

View File

@ -25,12 +25,7 @@ OPENAI_API_BASE = os.path.join(
OPENAISDK_API_BASE = os.environ.get('OPENAI_BASE_URL',
'https://api.openai.com/v1/')
O1_MODEL_LIST = [
'o1-preview-2024-09-12',
'o1-mini-2024-09-12',
'o1-preview',
'o1-mini',
]
O1_MODEL_LIST = ['o1', 'o3']
@MODELS.register_module()
@ -96,7 +91,6 @@ class OpenAI(BaseAPIModel):
temperature: Optional[float] = None,
tokenizer_path: Optional[str] = None,
extra_body: Optional[Dict] = None,
max_completion_tokens: int = 16384,
verbose: bool = False,
):
@ -151,9 +145,6 @@ class OpenAI(BaseAPIModel):
self.proxy_url = openai_proxy_url
self.path = path
self.max_completion_tokens = max_completion_tokens
self.logger.warning(
f'Max Completion tokens for {path} is {max_completion_tokens}')
def generate(
self,
@ -250,16 +241,15 @@ class OpenAI(BaseAPIModel):
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
try:
if self.path in O1_MODEL_LIST:
if any(model in self.path for model in O1_MODEL_LIST):
self.logger.warning(
f"'max_token' is unsupported for model {self.path}")
self.logger.warning(
f'We use max_completion_tokens: '
f'{self.max_completion_tokens}for this query')
f'We use max_out_len: {max_out_len} for this query')
data = dict(
model=self.path,
messages=messages,
max_completion_tokens=self.max_completion_tokens,
max_completion_tokens=max_out_len,
n=1,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,
@ -440,7 +430,7 @@ class OpenAI(BaseAPIModel):
if mode == 'front':
cur_prompt = sep.join(words[-mid:])
elif mode == 'mid':
cur_prompt = (sep.join(words[:mid]) + sep.join(words[-mid:]))
cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
elif mode == 'rear':
cur_prompt = sep.join(words[:mid])
@ -480,7 +470,9 @@ class OpenAI(BaseAPIModel):
"""
# Check input length when mode is 'none'
if mode == 'none':
input_len = get_token_len_func(str(input))
input_len = (get_token_len_func(input) if isinstance(
input, str) else sum(
get_token_len_func(item['prompt']) for item in input))
if input_len > max_seq_len:
raise ValueError(
f'Input length ({input_len}) exceeds max_seq_len '
@ -499,12 +491,15 @@ class OpenAI(BaseAPIModel):
# Convert input to messages format
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
input_len = get_token_len_func(input)
else:
messages = []
processed_prompts = []
for item in input:
input_content = item['prompt']
if mode != 'none':
input_content = bin_trim_wrapper(input_content)
processed_prompts.append(input_content)
msg = {'content': input_content}
if item['role'] == 'HUMAN':
msg['role'] = 'user'
@ -513,19 +508,18 @@ class OpenAI(BaseAPIModel):
elif item['role'] == 'SYSTEM':
msg['role'] = 'system'
messages.append(msg)
input_len = sum(
get_token_len_func(prompt) for prompt in processed_prompts)
# Adjust max_out_len
if max_out_len is not None:
original_max_out_len = max_out_len
max_out_len = min(
max_out_len,
max_seq_len - get_token_len_func(str(input)) - 100)
max_out_len = min(max_out_len, max_seq_len - input_len - 100)
if max_out_len <= 0:
raise ValueError(
f'max_out_len ({max_out_len}) is less than or equal to 0. '
f'This may be due to input length '
f'({get_token_len_func(str(input))}) being too close to '
f'max_seq_len ({max_seq_len}). Please either increase '
f'This may be due to input length ({input_len}) being too '
f'close to max_seq_len ({max_seq_len}). Please increase '
f'max_seq_len or use a truncation mode other than "none".')
if max_out_len < original_max_out_len:
self.logger.warning(
@ -555,7 +549,6 @@ class OpenAISDK(OpenAI):
temperature: float | None = None,
tokenizer_path: str | None = None,
extra_body: Dict | None = None,
max_completion_tokens: int = 16384,
verbose: bool = False,
status_code_mappings: dict = {},
):
@ -577,7 +570,6 @@ class OpenAISDK(OpenAI):
tokenizer_path,
extra_body,
verbose=verbose,
max_completion_tokens=max_completion_tokens,
)
from openai import OpenAI
@ -605,8 +597,23 @@ class OpenAISDK(OpenAI):
self.logger.info(f'Used openai_client: {self.openai_client}')
self.status_code_mappings = status_code_mappings
def _generate(self, input: PromptList | str, max_out_len: int,
temperature: float) -> str:
def _generate(self,
input: PromptList | str,
max_out_len: int,
temperature: float,
timeout: int = 3600) -> str:
"""Generate results given a list of inputs.
Args:
input (PromptType): A string or PromptDict.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use.
timeout (int, optional): Timeout in seconds for the API call.
Defaults to 3600 (60 minutes).
Returns:
str: The generated string.
"""
from openai import APIStatusError, BadRequestError
assert isinstance(input, (str, PromptList))
@ -618,16 +625,14 @@ class OpenAISDK(OpenAI):
num_retries = 0
while num_retries < self.retry:
self.wait()
if self.path in O1_MODEL_LIST:
if any(model in self.path for model in O1_MODEL_LIST):
self.logger.warning(
f"'max_token' is unsupported for model {self.path}")
self.logger.warning(
f'We use max_completion_tokens: '
f'{self.max_completion_tokens}for this query')
f'We use max_out_len: {max_out_len} for this query')
query_data = dict(
model=self.path,
max_completion_tokens=self.max_completion_tokens,
max_completion_tokens=max_out_len,
n=1,
messages=messages,
extra_body=self.extra_body,
@ -646,7 +651,8 @@ class OpenAISDK(OpenAI):
if self.verbose:
self.logger.info('Start calling OpenAI API')
responses = self.openai_client.chat.completions.create(
**query_data)
**query_data, timeout=timeout) # timeout in seconds
if self.verbose:
self.logger.info(
'Successfully get response from OpenAI API')

View File

@ -34,39 +34,29 @@ MAP = {
'总分',
'中文总分',
'英文总分',
'instruct/compassbenchv1_4_IF_en_fofo_sub',
'instruct/compassbenchv1_4_IF_zh_fofo_sub',
'instruct/compassbench_2501_IF_en_chatIF_sub',
'instruct/compassbench_2501_IF_en_functionalIF_sub',
'instruct/compassbench_2501_IF_cn_chatIF_sub',
'instruct/compassbench_2501_IF_cn_functionalIF_sub',
],
'language': [
'总分',
'中文总分',
'英文总分',
'language/compassbenchv1_4_language_zh_chat_sub',
'language/compassbenchv1_4_language_zh_creation_sub',
'language/compassbenchv1_4_language_zh_NLP_sub',
'language/compassbenchv1_4_language_en_chat_sub',
'language/compassbenchv1_4_language_en_creation_sub',
'language/compassbenchv1_4_language_en_NLP_sub',
'language/compassbench_v2501_language_zh_chat_sub',
'language/compassbench_v2501_language_zh_nlp_sub',
'language/compassbench_v2501_language_zh_creation_sub',
'language/compassbench_v2501_language_en_chat_sub',
'language/compassbench_v2501_language_en_nlp_sub',
'language/compassbench_v2501_language_en_creation_sub',
],
'reasoning': [
'code': [
'总分',
'中文总分',
'英文总分',
'reasoning/compassbenchv1_4_reasoning_en_CommonSenseSense_sub',
'reasoning/compassbenchv1_4_reasoning_en_Humanities_sub',
'reasoning/compassbenchv1_4_reasoning_en_ScienceEngineering_sub',
'reasoning/compassbenchv1_4_reasoning_en_Social_sub',
'reasoning/compassbenchv1_4_reasoning_zh_CommonSenseSense_sub',
'reasoning/compassbenchv1_4_reasoning_zh_Humanities_sub',
'reasoning/compassbenchv1_4_reasoning_zh_ScienceEngineering_sub',
'reasoning/compassbenchv1_4_reasoning_zh_Social_sub',
],
'coding': [
'总分',
'中文总分',
'英文总分',
'coding/compassbenchv1_4_coding_en_sub',
'coding/compassbenchv1_4_coding_zh_sub',
'code/compassbench_2501_code_arena_en_sub',
'code/compassbench_2501_code_arena_zh_sub',
],
}