diff --git a/opencompass/configs/datasets/livemathbench/livemathbench_gen.py b/opencompass/configs/datasets/livemathbench/livemathbench_gen.py index c0bd6477..4977f1e2 100644 --- a/opencompass/configs/datasets/livemathbench/livemathbench_gen.py +++ b/opencompass/configs/datasets/livemathbench/livemathbench_gen.py @@ -1,4 +1,4 @@ from mmengine.config import read_base with read_base(): - from .livemathbench_gen_caed8f import livemathbench_datasets # noqa: F401, F403 \ No newline at end of file + from .livemathbench_gen_9befbf import livemathbench_datasets # noqa: F401, F403 \ No newline at end of file diff --git a/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py b/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py new file mode 100644 index 00000000..3748c022 --- /dev/null +++ b/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py @@ -0,0 +1,51 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer + +from opencompass.datasets.livemathbench import LiveMathBenchDataset, LiveMathBenchEvaluator + + +livemathbench_dataset = dict( + type=LiveMathBenchDataset, + path='', + k=16, + replication=3, + dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], + dataset_languages=['cn', 'en'], + cot=True, + version='202412', + abbr='LiveMathBench-v202412', + reader_cfg=dict( + input_columns=['prompt'], + output_column='answer' + ), + infer_cfg=dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='{prompt}'), + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict( + type=GenInferencer, + max_out_len=8192 + ), + ), + eval_cfg=dict( + evaluator=dict( + type=LiveMathBenchEvaluator, + model_name='', + url=[], + use_extract_model=False, + extract_url=[], + extract_model_name='', + k=[4, 8, 16], + replication=3, + thresholds=[0.0, 0.25, 0.5, 0.75, 1.0] + ) + ) +) +livemathbench_datasets = [livemathbench_dataset] \ No newline at end of file diff --git a/opencompass/datasets/livemathbench/livemathbench.py b/opencompass/datasets/livemathbench/livemathbench.py index 56a22ae2..9d6ac63b 100644 --- a/opencompass/datasets/livemathbench/livemathbench.py +++ b/opencompass/datasets/livemathbench/livemathbench.py @@ -1,45 +1,55 @@ -import concurrent.futures import os -import re +import warnings from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy +from functools import partial from itertools import product -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, Union import jsonlines +import mmengine import numpy as np -from datasets import Dataset +from datasets import Dataset, load_dataset +from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2 from opencompass.models import OpenAISDK -from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.openicl.icl_evaluator import GPassKEvaluator +from opencompass.openicl.icl_inferencer.icl_base_inferencer import \ + dump_results_dict from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS from opencompass.utils import get_data_path from ..base import BaseDataset from .prompts import (EXTRACT_PROMPT_CN, EXTRACT_PROMPT_EN, JUDGE_PROMPT_CN, JUDGE_PROMPT_EN, PROMPT_CN, PROMPT_EN) +from .utils import extract_judge_label @LOAD_DATASET.register_module() class 
LiveMathBenchDataset(BaseDataset): @staticmethod - def load( - path: str, - k: int, - n: int, - dataset_splits: List[str] = [ - 'AIMC', 'CEE', 'CMO', 'MATH500', 'AIME2024' - ], - dataset_languages: List[str] = ['cn', 'en'], - ) -> List[Dict[str, Any]]: + def load(path: str, + k: Union[int, List[int]], + replication: int, + dataset_splits: List[str] = [ + 'CNMO', + 'CCEE', + 'AMC', + 'WLPMC', + ], + dataset_languages: List[str] = ['cn', 'en'], + cot: bool = True, + version: str = '202412') -> List[Dict[str, Any]]: dataset = [] dataset_info = {} - path = get_data_path(path) + + if path != '': + path = get_data_path(path) + head, tail = os.path.split(path) + path = os.path.join(head, f'{tail}-{version}') for split, language in product(dataset_splits, dataset_languages): - file_path = os.path.join(path, f'{split}_{language}.jsonl') - if not os.path.exists(file_path): - continue dataset_info[f'{split}_{language}'] = { 'single-choice': 0, 'multiple-choice': 0, @@ -52,36 +62,57 @@ class LiveMathBenchDataset(BaseDataset): '填空': 'fill-in-the-blank', '问答': 'problem-solving' } - with jsonlines.open(file_path, 'r') as file: - for example_idx, example in enumerate(file): - dataset_info[f'{split}_{language}'][ - example['question_type'] if language == 'en' else - question_type_mapping[example['question_type']]] += 1 - prompt = PROMPT_EN if language == 'en' else PROMPT_CN - example.update({ - 'dataset_key': - f'{split}_{language}_{example_idx}', - 'prompt': - prompt.format(question_type=example['question_type'], - question=example['question'] + - ('' if 'options' not in example else - ' '.join(example['options']))), - 'k': - k, - 'n': - n - }) - for idx in range(k * n): - duplicated_example = deepcopy(example) - duplicated_example.update({'duplicated_idx': idx}) - dataset.append(duplicated_example) + if path != '': + file_path = os.path.join(path, f'{split}_{language}.jsonl') + if not os.path.exists(file_path): + continue + examples = [] + with jsonlines.open(file_path, 'r') as file: + for example in file: + examples.append(example) + else: + hf_dataset = load_dataset( + 'opencompass/LiveMathBench', + f'v{version}_{split}_{language}')['test'] + examples = [] + for example in hf_dataset: + examples.append(example) + + for example_idx, example in enumerate(examples): + dataset_info[f'{split}_{language}'][ + example['question_type'] if language == 'en' else + question_type_mapping[example['question_type']]] += 1 + + prompt = PROMPT_EN if language == 'en' else PROMPT_CN + if not cot: + if language == 'cn': + prompt = prompt.replace(',请逐步推理', '') + else: + prompt = prompt.replace( + ', please reasoning step by step', '') + example.update({ + 'subdivision': + f'{split}_{language}', + 'idx': + str(example_idx), + 'prompt': + prompt.format(question_type=example['question_type'], + question=example['question'] + + ('' if 'options' not in example else + ' '.join(example['options']))), + }) + max_k = k if isinstance(k, int) else max(k) + for idx in range(max_k * replication): + duplicated_example = deepcopy(example) + duplicated_example.update({'replication_idx': idx}) + dataset.append(duplicated_example) return Dataset.from_list(dataset) @ICL_EVALUATORS.register_module() -class LiveMathBenchEvaluator(BaseEvaluator): +class LiveMathBenchEvaluator(GPassKEvaluator): api_meta_template = dict(round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), @@ -90,72 +121,125 @@ class LiveMathBenchEvaluator(BaseEvaluator): def __init__(self, model_name, url, - with_postprocess=True, 
                  use_extract_model=False,
-                 post_url=[],
-                 post_model_name='',
-                 **kwargs):
+                 extract_url=[],
+                 extract_model_name='',
+                 k: Union[int, List[int]] = 16,
+                 replication: int = 3,
+                 thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
+        super().__init__(k, replication, thresholds)
+
         if isinstance(url, str):
             url = [url]
-        self.model = [
-            MODELS.build(
-                dict(
-                    type=OpenAISDK,
-                    path=model_name,
-                    openai_api_base=url,
-                    key='EMPTY',
-                    query_per_second=128,
-                    meta_template=self.api_meta_template,
-                    temperature=kwargs.get('temperature', 0.001),
-                    max_seq_len=kwargs.get('max_tokens', 16384),
-                )) for url in url
-        ]
-        self.with_postprocess = with_postprocess
-        self.use_extract_model = use_extract_model
-        self.post_url = post_url
-        self.post_model_name = post_model_name
-
-    def batch_response(self, models: List[OpenAISDK],
-                       inputs: List[str]) -> List[str]:
-        batch_num = len(models)
-        batch_size = (len(inputs) + batch_num - 1) // batch_num
-        result_responses = []
-
-        with concurrent.futures.ThreadPoolExecutor(
-                max_workers=batch_num) as executor:
-            futures = [
-                executor.submit(models[i].generate,
-                                inputs[i * batch_size:(i + 1) * batch_size])
-                for i in range(batch_num)
-            ]
-            for response in executor.map(lambda f: f.result(), futures):
-                result_responses.extend(response)
-
-        return result_responses
-
-    def postprocess(self, questions: List[str], predictions: List[str],
-                    question_types: List[str],
-                    languages: List[str]) -> List[str]:
-        if self.use_extract_model:
-            assert len(self.post_url) > 0 and self.post_model_name != ''
-            post_model = [
+        if model_name == '' or len(url) == 0:
+            warnings.warn('Unable to leverage LLM-as-judge and falling back '
+                          'to rule-based judge due to incomplete parameters; '
+                          'this may cause performance degradation. Check '
+                          '`model_name` or `url` of the evaluator if you do '
+                          'not want this fallback.')
+            self.judge_models = []
+        else:
+            self.judge_models = [
                 MODELS.build(
                     dict(
                         type=OpenAISDK,
-                        path=self.post_model_name,
+                        path=model_name,
+                        openai_api_base=_url,
+                        key='EMPTY',
+                        query_per_second=2,
+                        retry=5,
+                        meta_template=self.api_meta_template,
+                        temperature=0.0,
+                        max_seq_len=16384,
+                    )) for _url in url
+            ]
+        self.use_extract_model = use_extract_model
+        self.extract_url = extract_url
+        self.extract_model_name = extract_model_name
+
+        self.extract_output_handler = LiveMathBenchOutputHandler()
+        self.judge_output_handler = LiveMathBenchOutputHandler()
+
+    def batch_infer(self, models: List[OpenAISDK], inputs: List[str],
+                    completed_indexes: set,
+                    output_handler: 'LiveMathBenchOutputHandler',
+                    postprocess: Callable) -> List[str]:
+        batch_size = 16
+        batch_num = (len(inputs) + batch_size - 1) // batch_size
+        all_indexes = [i for i in range(len(inputs))]
+        indexes = [i for i in all_indexes if i not in completed_indexes]
+        inputs = [inputs[i] for i in indexes]
+        result_responses = []
+        result_indexes = []
+
+        def thread_worker(inputs, max_out_len, temperature, indexes, model):
+            return model.generate(inputs, max_out_len,
+                                  temperature), inputs, indexes
+
+        if len(indexes) > 0:
+            with ThreadPoolExecutor(max_workers=len(models)) as pool:
+                tasks = [
+                    pool.submit(
+                        partial(thread_worker, model=models[i % len(models)]),
+                        inputs[i * batch_size:(i + 1) * batch_size], 8192, 0.0,
+                        indexes[i * batch_size:(i + 1) * batch_size])
+                    for i in range(batch_num)
+                ]
+                for completed_task in as_completed(tasks):
+                    responses, current_inputs, indexes = completed_task.result(
+                    )
+                    for input, response, index in zip(current_inputs,
+                                                      responses, indexes):
+                        output_handler.save(
+                            index,
+                            prompt=input,
+                            response=response,
+                            postprocess_response=postprocess(response))
+                        result_responses.append(postprocess(response))
+                        result_indexes.append(index)
+                    output_handler.write_to_json()
+
+        return [
+            output_handler.output_dict[str(i)]['postprocess_response']
+            for i in all_indexes
+        ]
+
+    def extract(self, questions: List[str], predictions: List[str],
+                question_types: List[str], languages: List[str]) -> List[str]:
+
+        # extract answer by model
+        if self.use_extract_model:
+            assert len(self.extract_url) > 0 and self.extract_model_name != ''
+            extract_models = [
+                MODELS.build(
+                    dict(
+                        type=OpenAISDK,
+                        path=self.extract_model_name,
                         openai_api_base=url,
                         key='EMPTY',
                         query_per_second=2,
+                        retry=5,
                         meta_template=self.api_meta_template,
-                        temperature=0.01,
+                        temperature=0.0,
                         max_seq_len=1024,
-                    )) for url in self.post_url
+                    )) for url in self.extract_url
             ]
+            completed_indexes = set()
+            mmengine.mkdir_or_exist(self.output_dir)
+            tmp_json_file_path = os.path.join(self.output_dir,
+                                              'tmp_extract.json')
+            self.extract_output_handler.save_file_path = tmp_json_file_path
+            if os.path.exists(tmp_json_file_path):
+                tmp_dict = mmengine.load(tmp_json_file_path)
+                self.extract_output_handler.output_dict = tmp_dict
+                for index in tmp_dict:
+                    completed_indexes.add(int(index))
+
             input_prompts = []
             for question, prediction, question_type, language in zip(
                     questions, predictions, question_types, languages):
                 prompt = (EXTRACT_PROMPT_EN
                           if language == 'en' else EXTRACT_PROMPT_CN)
                 input_prompts.append(
@@ -163,245 +247,125 @@ class LiveMathBenchEvaluator(BaseEvaluator):
                                   response=prediction,
                                   question_type=question_type))
-            result_responses = self.batch_response(post_model, input_prompts)
+            results = self.batch_infer(extract_models,
+                                       input_prompts,
+                                       completed_indexes,
+                                       self.extract_output_handler,
+                                       postprocess=lambda x: x)
-            return result_responses
+            return results
-        def last_boxed_only_string(string):
-            idx = string.rfind('\\boxed')
-            if idx < 0:
-                idx = string.rfind('\\fbox')
-                if idx < 0:
-                    return None
-
-            i = idx
-            right_brace_idx = None
-            num_left_braces_open = 0
-            while i < len(string):
-                if string[i] == '{':
-                    num_left_braces_open += 1
-                if string[i] == '}':
-                    num_left_braces_open -= 1
-                    if num_left_braces_open == 0:
-                        right_brace_idx = i
-                        break
-                i += 1
-
-            if right_brace_idx is None:
-                retval = None
-            else:
-                retval = string[idx:right_brace_idx + 1]
-
-            return retval
-
-        def remove_boxed(s):
-            left = '\\boxed{'
-            try:
-                assert s[:len(left)] == left
-                assert s[-1] == '}'
-                return s[len(left):-1]
-            except Exception:
-                return None
-
-        def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
-            boxed_str = last_boxed_only_string(pred_str)
-            if boxed_str is None:
-                return None
-            answer = remove_boxed(boxed_str)
-            if answer is None:
-                return None
-            if strip_double_curly_brace:
-                match = re.match('^\{(.*)\}$', answer)  # noqa: W605
-                if match:
-                    answer = match.group(1)
-            return answer
-
-        predictions = [
-            extract_boxed_answer(prediction) for prediction in predictions
+        # extract answer in \\boxed{}
+        results = [
+            math_postprocess_v2(prediction) for prediction in predictions
         ]
-        return predictions
+        return results
-    def extract_boxed_answer(self, text):
-        match = re.findall(r'\\boxed{(.+?)}', text)
-        if match:
-            return match[-1]
-
-        return None
-
-    def score(self, predictions, references, origin_prompt, test_set):
+    def judge(self, predictions, references,
test_set): if len(predictions) != len(references): - return {'error': 'preds and refrs have different length'} + raise ValueError('preds and refrs have different length') + + completed_indexes = set() + mmengine.mkdir_or_exist(self.output_dir) + tmp_json_file_path = os.path.join(self.output_dir, 'tmp_judge.json') + self.judge_output_handler.save_file_path = tmp_json_file_path + if os.path.exists(tmp_json_file_path): + tmp_dict = mmengine.load(tmp_json_file_path) + self.judge_output_handler.output_dict = tmp_dict + for index in tmp_dict: + completed_indexes.add(int(index)) questions = test_set['question'] question_types = test_set['question_type'] - languages = [key.split('_')[1] for key in test_set['dataset_key']] + languages = [key.split('_')[1] for key in test_set['subdivision']] - if self.with_postprocess: - predictions = self.postprocess(questions, predictions, - question_types, languages) + predictions = self.extract(questions, predictions, question_types, + languages) - inputs = [] - for prediction, reference, question, language in zip( - predictions, references, questions, languages): - prompt = JUDGE_PROMPT_EN if language == 'en' else JUDGE_PROMPT_CN - inputs.append( - prompt.format(answer=prediction, - gold_answer=reference, - question=question)) - result_responses = self.batch_response(self.model, inputs) - results = [ - self.extract_boxed_answer(result) == 'yes' - for result in result_responses - ] + if len(self.judge_models) > 0: + inputs = [] + for prediction, reference, question, language in zip( + predictions, references, questions, languages): + prompt = (JUDGE_PROMPT_EN + if language == 'en' else JUDGE_PROMPT_CN) + inputs.append( + prompt.format(answer=prediction, + gold_answer=reference, + question=question)) - K = test_set['k'][0] - N = test_set['n'][0] - key2example = {} - - for example, result_response, result, prediction in zip( - test_set, result_responses, results, predictions): - if example['dataset_key'] not in key2example: - key2example[example['dataset_key']] = [] - example.update({ - 'eval_response': result_response, - 'prediction': prediction, - 'correct': result - }) - key2example[example['dataset_key']].append(example) - for key in key2example: - key2example[key] = [ - key2example[key][i * K:(i + 1) * K] for i in range(N) + labels = self.batch_infer( + self.judge_models, inputs, completed_indexes, + self.judge_output_handler, lambda x: + (1 if extract_judge_label(x) == 'yes' else 0)) + else: + is_equiv = MATHAgentEvaluator(version='v2').is_equiv + labels = [ + 1 if is_equiv(prediction, reference) else 0 + for prediction, reference in zip(predictions, references) ] + return labels - count = [] - total_pass_num = [] - details = [] - all_dataset = set() - for key, examples in key2example.items(): - detail = OrderedDict() - detail['question'] = examples[0][0]['question'] - detail['answer'] = examples[0][0]['answer'] - detail['responses'] = [] - detail['dataset'] = '_'.join(key.split('_')[:-1]) - all_dataset.add('_'.join(key.split('_')[:-1])) - if_pass_list = [] - for single_run_examples in examples: - detail['responses'].append([]) - if_pass_list.append([]) - for example in single_run_examples: - detail['responses'][-1].append({ - 'prediction': - example['prediction'], - 'eval_response': - example['eval_response'] - }) - if_pass_list[-1].append(1.0 if example['correct'] else 0.0) + def preprocess(self, predictions, references, test_set): + return self.judge(predictions, references, test_set) - if_pass_list = [ - sorted(if_pass, reverse=True) for if_pass in 
if_pass_list - ] - if_pass_list = np.array(if_pass_list) - i = 1 - while i <= K: - detail.update({ - f'pass-rate@{i}': - if_pass_list[:, :i].mean(axis=1).mean(axis=0).item(), - f'pass-rate@{i}/std': - if_pass_list[:, :i].mean(axis=1).std(axis=0).item(), - f'pass@{i}': - np.ceil( - if_pass_list[:, :i].mean(axis=1)).mean(axis=0).item(), - f'pass@{i}/std': - np.ceil( - if_pass_list[:, :i].mean(axis=1)).std(axis=0).item(), - }) - i = i * 2 + def group(self, predictions, labels, test_set): + example2replications = {} + for example, label, prediction in zip(test_set, labels, predictions): + example_abbr = f"{example['subdivision']}_{example['idx']}" + if example_abbr not in example2replications: + example2replications[example_abbr] = [] + example.update({'prediction': prediction, 'label': label}) + example2replications[example_abbr].append(example) + for _, replications in example2replications.items(): + assert len(replications) == self.n, print(len(replications), + self.n) + return example2replications - for threshold in [0.5, 0.75, 1.0]: - detail.update({ - f'{K}-pass@{threshold}': - np.floor( - np.where( - if_pass_list.mean(axis=1) >= threshold, 1.0, - 0.0).mean(axis=0)) - }) + def reduce(self, details) -> Dict[str, Any]: + """Aggregate the overall metrics. - count.append(np.ones_like(if_pass_list).sum(axis=1)) - total_pass_num.append(if_pass_list.sum(axis=1)) + Return: + A dict contains overall metrics, like: + {'details': details for each example, 'G-Pass@16': xxx} + """ + g_passk_details = OrderedDict() + g_passk_details['details'] = details - details.append(detail) + all_dataset = set([detail['subdivision'] for detail in details]) - detailed_result = OrderedDict() - detailed_result['details'] = details - - i = 1 - while i <= K: - detailed_result.update({ - f'pass-rate@{i}': - 100. * - np.mean([detail[f'pass-rate@{i}'] for detail in details]), - f'pass-rate@{i}/std': - 100. * - np.mean([detail[f'pass-rate@{i}/std'] for detail in details]), - f'pass@{i}': - 100. * np.mean([detail[f'pass@{i}'] for detail in details]), - f'pass@{i}/std': - 100. * np.mean([detail[f'pass@{i}/std'] for detail in details]) - }) - for d in sorted(list(all_dataset)): - detailed_result.update({ - f'{d}/pass-rate@{i}': - 100. * np.mean([ - detail[f'pass-rate@{i}'] - for detail in details if detail['dataset'] == d - ]), - f'{d}/pass-rate@{i}/std': - 100. * np.mean([ - detail[f'pass-rate@{i}/std'] - for detail in details if detail['dataset'] == d - ]), - f'{d}/pass@{i}': - 100. * np.mean([ - detail[f'pass@{i}'] - for detail in details if detail['dataset'] == d - ]), - f'{d}/pass@{i}/std': - 100. * np.mean([ - detail[f'pass@{i}/std'] - for detail in details if detail['dataset'] == d + for k in self.k: + for subdivision in sorted(list(all_dataset)): + for threshold in self.thresholds: + g_passk_details[ + f'{subdivision}/G-Pass@{k}_{threshold}'] = \ + 100. * np.mean( + [ + detail[f'G-Pass@{k}_{threshold}'] + for detail in details + if detail['subdivision'] == subdivision + ]) + g_passk_details[f'{subdivision}/mG-Pass@{k}'] = 100. * np.mean( + [ + detail[f'mG-Pass@{k}'] for detail in details + if detail['subdivision'] == subdivision ]) - }) - i = i * 2 - for threshold in [0.5, 0.75, 1.0]: - detailed_result.update({ - f'{K}-pass@{threshold}': - 100. * np.mean([ - detail[f'{K}-pass@{threshold}'] for detail in details - ]) - }) - detailed_result.update({ - f'{K}-pass@{threshold}/std': - 100. 
* np.mean([ - detail[f'{K}-pass@{threshold}'] for detail in details - ]) - }) - for d in sorted(list(all_dataset)): + for threshold in self.thresholds: + g_passk_details[f'G-Pass@{k}_{threshold}'] = 100. * np.mean( + [detail[f'G-Pass@{k}_{threshold}'] for detail in details]) + g_passk_details[f'mG-Pass@{k}'] = 100. * np.mean( + [detail[f'mG-Pass@{k}'] for detail in details]) - for threshold in [0.5, 0.75, 1.0]: - detailed_result.update({ - f'{d}/{K}-pass@{threshold}': - 100. * np.mean([ - detail[f'{K}-pass@{threshold}'] - for detail in details if detail['dataset'] == d - ]) - }) - detailed_result.update({ - f'{d}/{K}-pass@{threshold}/std': - 100. * np.mean([ - detail[f'{K}-pass@{threshold}'] - for detail in details if detail['dataset'] == d - ]) - }) + return g_passk_details - return detailed_result + +class LiveMathBenchOutputHandler: + output_dict = {} + save_file_path = '' + + def write_to_json(self): + """Dump the result to a json file.""" + dump_results_dict(self.output_dict, self.save_file_path) + + def save(self, idx, **kwargs): + self.output_dict[str(idx)] = kwargs diff --git a/opencompass/datasets/livemathbench/utils.py b/opencompass/datasets/livemathbench/utils.py new file mode 100644 index 00000000..411e098f --- /dev/null +++ b/opencompass/datasets/livemathbench/utils.py @@ -0,0 +1,10 @@ +import re + + +def extract_judge_label(text): + if isinstance(text, str): + match = re.findall(r'\\boxed{(.+?)}', text) + if match: + return match[-1] + + return None diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py index 261d3926..5cd38b4a 100644 --- a/opencompass/models/huggingface_above_v4_33.py +++ b/opencompass/models/huggingface_above_v4_33.py @@ -106,7 +106,7 @@ def _format_with_fast_chat_template(inputs: List[str], name: str='vicuna'): elif item['role'] == 'system': continue else: - raise ValueError(f'Unknown role {item["role"]}') + raise ValueError(f"Unknown role {item['role']}") template.append_message(template.roles[1], None) outputs.append(template.get_prompt()) return outputs @@ -474,6 +474,8 @@ class HuggingFacewithChatTemplate(BaseModel): if min_out_len is not None: generation_kwargs['min_new_tokens'] = min_out_len generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id + self.logger.info('Generation Args of Huggingface: ') + self.logger.info(generation_kwargs) # step-2: conduct model forward to generate output outputs = self.model.generate(**tokens, **generation_kwargs) diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py index 51faf7be..1ac544c1 100644 --- a/opencompass/models/openai_api.py +++ b/opencompass/models/openai_api.py @@ -516,10 +516,13 @@ class OpenAISDK(OpenAI): # support multiple api_base for acceleration if isinstance(openai_api_base, List): - openai_api_base = random.choice(openai_api_base) + self.openai_api_base = random.choice(openai_api_base) + else: + self.openai_api_base = openai_api_base if self.proxy_url is None: - self.openai_client = OpenAI(base_url=openai_api_base, api_key=key) + self.openai_client = OpenAI(base_url=self.openai_api_base, + api_key=key) else: proxies = { 'http://': self.proxy_url, @@ -527,7 +530,7 @@ class OpenAISDK(OpenAI): } self.openai_client = OpenAI( - base_url=openai_api_base, + base_url=self.openai_api_base, api_key=key, http_client=httpx.Client(proxies=proxies)) if self.verbose: @@ -617,8 +620,8 @@ class OpenAISDK(OpenAI): 'Successfully get response from OpenAI API') try: self.logger.info(responses) - except Exception as e: # noqa F841 
- pass + except Exception: + pass # noqa F841 if not responses.choices: self.logger.error( 'Response is empty, it is an internal server error \ @@ -635,13 +638,18 @@ class OpenAISDK(OpenAI): if (status_code is not None and status_code in self.status_code_mappings): error_message = self.status_code_mappings[status_code] + self.logger.error( + f'error occurs at {self.openai_api_base}') self.logger.info(f'Status Code: {status_code}, \n' f'Original Error Message: {e}, \n' f'Return Message: {error_message} ') return error_message else: + self.logger.error( + f'error occurs at {self.openai_api_base}') self.logger.error(e) except Exception as e: + self.logger.error(f'error occurs at {self.openai_api_base}') self.logger.error(e) num_retries += 1 raise RuntimeError('Calling OpenAI API failed after retrying for ' diff --git a/opencompass/models/turbomind_with_tf_above_v4_33.py b/opencompass/models/turbomind_with_tf_above_v4_33.py index 79e6e556..47bbc84b 100644 --- a/opencompass/models/turbomind_with_tf_above_v4_33.py +++ b/opencompass/models/turbomind_with_tf_above_v4_33.py @@ -128,10 +128,7 @@ class TurboMindModelwithChatTemplate(BaseModel): gen_config['max_new_tokens'] = max_out_len if min_out_len is not None: gen_config['min_new_tokens'] = min_out_len - if do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample']): - gen_config['top_k'] = 40 - gen_config['temperature'] = temperature - else: + if not(do_sample or ('do_sample' in self.gen_config and self.gen_config['do_sample'])): if self.version_info >= (0, 6, 0): gen_config['do_sample'] = False else: @@ -140,6 +137,8 @@ class TurboMindModelwithChatTemplate(BaseModel): from lmdeploy import GenerationConfig gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)} gen_config = GenerationConfig(**gen_config) + self.logger.info('Generation Config of LMdeploy: ') + self.logger.info(gen_config) results = [] outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False) diff --git a/opencompass/models/vllm_with_tf_above_v4_33.py b/opencompass/models/vllm_with_tf_above_v4_33.py index cf79ea6f..27226fa5 100644 --- a/opencompass/models/vllm_with_tf_above_v4_33.py +++ b/opencompass/models/vllm_with_tf_above_v4_33.py @@ -108,6 +108,8 @@ class VLLMwithChatTemplate(BaseModel): sampling_kwargs.update(self.generation_kwargs) sampling_kwargs.update(kwargs) sampling_kwargs = SamplingParams(**sampling_kwargs) + self.logger.info('Sampling Params of vLLM: ') + self.logger.info(sampling_kwargs) outputs = self.model.generate(messages, sampling_kwargs) diff --git a/opencompass/openicl/icl_evaluator/__init__.py b/opencompass/openicl/icl_evaluator/__init__.py index 1fd1683b..5103c00d 100644 --- a/opencompass/openicl/icl_evaluator/__init__.py +++ b/opencompass/openicl/icl_evaluator/__init__.py @@ -4,6 +4,7 @@ from .icl_base_evaluator import BaseEvaluator # noqa from .icl_bpc_evaluator import BPCEvaluator # noqa from .icl_circular_evaluator import CircularEvaluator # noqa from .icl_em_evaluator import EMEvaluator # noqa +from .icl_gpassk_evaluator import GPassKEvaluator # noqa from .icl_hf_evaluator import * # noqa from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa from .icl_misc_evaluator import AverageInferencePPLEvaluator # noqa diff --git a/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py b/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py new file mode 100644 index 00000000..80a59073 --- /dev/null +++ b/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py @@ -0,0 +1,163 @@ 
+from abc import abstractmethod +from typing import Any, Dict, List, Union + +import numpy as np +from scipy.stats import hypergeom + +from opencompass.registry import ICL_EVALUATORS + +from .icl_base_evaluator import BaseEvaluator + + +def compute_pass_at_k(n, c, k): + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + +def _compute_g_pass_at_k(n, c, k, m): + if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0: + return 0.0 + return hypergeom.sf(m - 1, n, c, k) + + +def compute_g_pass_at_k(n, c, k, t): + m = max(int(np.ceil(k * t)), 1) + return _compute_g_pass_at_k(n, c, k, m) + + +def compute_mg_pass_at_k(n, c, k): + l, r = int(np.ceil(k * 0.5)), k + + mg_pass_at_k = 0.0 + for i in range(l + 1, r + 1): + mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i) + mg_pass_at_k = 2 * mg_pass_at_k / k + + return mg_pass_at_k + + +@ICL_EVALUATORS.register_module() +class GPassKEvaluator(BaseEvaluator): + """Evaluator for computing the G-Pass@k Metric. + + This evaluator performs the following steps: + 1. Invokes task-specific `preprocess` on predictions to + assign a consistency label to each prediction and its + corresponding reference. + 2. Calculates metrics for each input example based on + these labels. + 3. Aggregates the overall metrics through a task-specific + `postprocess`. + + Args: + k (int or list of int): Number of predictions to be + considered in G-Pass@k. It can be a single integer + (e.g., `k=16` computes G-Pass@16) or a list of + integers (e.g., `[4, 8, 16]` computes G-Pass@4, + G-Pass@8, and G-Pass@16). + + replication (int): Controls the number of generations + used to estimate G-Pass@k. The total number of + generations is determined by multiplying the + maximum of `k` with `replication`. This parameter + should be a single integer. + + thresholds (list of float): A list of floating-point + numbers that define the thresholds for the G-Pass@k + metric. + """ + + def __init__( + self, + k: Union[int, List[int]] = 16, + replication: int = 3, + thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None: + super().__init__() + + if isinstance(k, int): + k = [k] + + self.k = k + self.replication = replication + self.n = max(k) * replication + self.thresholds = thresholds + + @property + def output_dir(self): + # please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200 + return self._out_dir + + @abstractmethod + def preprocess(self, predictions, references, test_set) -> None: + """Perform operations on predictions before computing metrics, for + example, do answer_extraction and model_judge in mathematical reasoning + task. + + Return: + labels: A list contains the label which indicates whether + prediction is consistency with reference at each position. + """ + raise NotImplementedError + + @abstractmethod + def group(self, predictions, labels, test_set) -> Dict[str, Any]: + """Group the predictions and references. + + Return: + A dict contains the grouped predictions and references. + """ + raise NotImplementedError + + @abstractmethod + def reduce(self, details) -> Dict[str, Any]: + """Aggregate the overall metrics. + + Return: + A dict contains overall metrics, like: + {'details': details for each example, 'G-Pass@16': xxx} + """ + raise NotImplementedError + + def score(self, predictions, references, test_set) -> Dict[str, Any]: + """Compute G-Pass@k metrics. 
+ + Return: + A dict contains metrics for each dataset sample and + overall metrics reduced by `self.reduce`, like: + {'details': details for each example, 'G-Pass@16': xxx} + """ + labels = self.preprocess(predictions, references, test_set) + grouped_examples = self.group(predictions, labels, test_set) + + details = [] + total_pass_num, count = 0, 0 + for example_abbr, examples in grouped_examples.items(): + detail = { + k: v + for k, v in examples[0].items() + if k not in ['prediction', 'label'] + } + detail.update({ + 'predictions': [{ + 'prediction': example['prediction'], + 'label': example['label'] + } for example in examples], + }) + + current_example_labels = [e['label'] for e in examples] + c = int(np.sum(current_example_labels)) + + for k in self.k: + for threshold in self.thresholds: + detail[f'G-Pass@{k}_{threshold}'] = compute_g_pass_at_k( + n=self.n, c=c, k=k, t=threshold) + detail[f'mG-Pass@{k}'] = compute_mg_pass_at_k(n=self.n, + c=c, + k=k) + count += self.n + total_pass_num += c + + details.append(detail) + + return self.reduce(details)
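A note on what the new evaluator reports, with a small usage sketch (it assumes an environment where this patched opencompass is importable; the numbers mirror the shipped config, where every problem is generated n = max(k) * replication = 16 * 3 = 48 times). G-Pass@k_t is the probability that at least ceil(k * t) of k generations drawn without replacement from those n are judged correct, and mG-Pass@k is roughly the average of G-Pass@k_t over thresholds above 0.5.

from opencompass.openicl.icl_evaluator.icl_gpassk_evaluator import (
    compute_g_pass_at_k, compute_mg_pass_at_k)

n, k = 48, 16  # 16 * 3 generations per problem, metric reported at k = 16
for c in (0, 24, 48):  # c = number of generations judged correct
    g_pass = {t: compute_g_pass_at_k(n, c, k, t) for t in (0.0, 0.5, 1.0)}
    print(c, g_pass, compute_mg_pass_at_k(n, c, k))
# c = 48 yields 1.0 at every threshold and c = 0 yields 0.0; intermediate
# values of c fall in between.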
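To make the preprocess/group/reduce contract concrete, here is a minimal, hypothetical subclass (the class name and the exact-match labelling are illustrative only and not part of this PR; LiveMathBenchEvaluator is the real implementation). `preprocess` maps each prediction to a 0/1 label, `group` collects the n = max(k) * replication replications of every example, and `reduce` aggregates the per-example details produced by `score`.

from typing import Any, Dict

import numpy as np

from opencompass.openicl.icl_evaluator import GPassKEvaluator


class ExactMatchGPassKEvaluator(GPassKEvaluator):
    """Toy evaluator: a prediction is correct iff it equals the reference."""

    def preprocess(self, predictions, references, test_set):
        # One 0/1 consistency label per prediction.
        return [
            int(p.strip() == r.strip())
            for p, r in zip(predictions, references)
        ]

    def group(self, predictions, labels, test_set):
        # Collect the replications of each example under a shared key.
        grouped = {}
        for example, label, prediction in zip(test_set, labels, predictions):
            example.update({'prediction': prediction, 'label': label})
            grouped.setdefault(example['idx'], []).append(example)
        return grouped

    def reduce(self, details) -> Dict[str, Any]:
        result = {'details': details}
        for k in self.k:
            result[f'mG-Pass@{k}'] = 100. * np.mean(
                [detail[f'mG-Pass@{k}'] for detail in details])
        return result


# Usage sketch: k=[2], replication=2 -> n = 4 generations per problem.
evaluator = ExactMatchGPassKEvaluator(k=[2], replication=2,
                                      thresholds=[0.5, 1.0])
test_set = [{'idx': '0'} for _ in range(4)]  # 4 replications of one problem
predictions = ['a', 'a', 'b', 'a']
references = ['a', 'a', 'a', 'a']
print(evaluator.score(predictions, references, test_set))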
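When the evaluator is left with model_name='' and url=[] (the defaults in livemathbench_gen_9befbf.py), judge() skips the LLM judge and takes the rule-based path instead. A rough sketch of that path, using only helpers this patch already imports (the prediction and reference strings are made up):

from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2

prediction = 'So the final answer is \\boxed{\\frac{1}{2}}.'  # model output
reference = '\\frac{1}{2}'                                    # gold answer
extracted = math_postprocess_v2(prediction)  # pull the boxed content
is_equiv = MATHAgentEvaluator(version='v2').is_equiv
label = 1 if is_equiv(extracted, reference) else 0  # 0/1 consistency label
print(extracted, label)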