From 704853e5e7fb5cb933f00071a774a6d9c7eb9f5b Mon Sep 17 00:00:00 2001 From: Songyang Zhang Date: Mon, 29 Jul 2024 18:32:50 +0800 Subject: [PATCH] [Feature] Update pip install (#1324) * [Feature] Update pip install * Update Configuration * Update * Update * Update * Update Internal Config * Update collect env --- MANIFEST.in | 2 + .../gpqa_openai_simple_evals_gen_5aeece.py | 4 +- opencompass/__init__.py | 2 +- opencompass/datasets/LCBench.py | 329 ++++++++++++++++++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/bbh.py | 10 +- opencompass/models/__init__.py | 2 + opencompass/models/claude_allesapin.py | 150 ++++++++ opencompass/models/claude_sdk_api.py | 121 +++++++ opencompass/models/turbomind.py | 2 + opencompass/runners/__init__.py | 1 + opencompass/runners/volc.py | 260 ++++++++++++++ opencompass/summarizers/default.py | 2 +- opencompass/tasks/openicl_eval.py | 24 ++ opencompass/utils/collect_env.py | 14 + setup.py | 10 +- tools/update_dataset_suffix.py | 4 +- 17 files changed, 923 insertions(+), 15 deletions(-) create mode 100644 MANIFEST.in create mode 100644 opencompass/datasets/LCBench.py create mode 100644 opencompass/models/claude_allesapin.py create mode 100644 opencompass/models/claude_sdk_api.py create mode 100644 opencompass/runners/volc.py diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..63d03d04 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +recursive-include opencompass/configs *.py *.yml *.json *.txt *.md +recursive-include opencompass/openicl/icl_evaluator/hf_metrics *.py diff --git a/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py b/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py index 1dbcc1cc..7f77116e 100644 --- a/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py +++ b/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py @@ -1,7 +1,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_inferencer import GenInferencer -from opencompass.datasets import GPQASimpleEvalDataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator +from opencompass.datasets import GPQADataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator # openai_simple_eval prompt align_prompt = """ @@ -43,7 +43,7 @@ for split in list(gpqa_subsets.keys()): gpqa_datasets.append( dict( abbr='GPQA_' + split, - type=GPQASimpleEvalDataset, + type=GPQADataset, path='./data/gpqa/', name=gpqa_subsets[split], reader_cfg=gpqa_reader_cfg, diff --git a/opencompass/__init__.py b/opencompass/__init__.py index 44b18069..9a243b92 100644 --- a/opencompass/__init__.py +++ b/opencompass/__init__.py @@ -1 +1 @@ -__version__ = '0.2.6' +__version__ = '0.2.7rc1' diff --git a/opencompass/datasets/LCBench.py b/opencompass/datasets/LCBench.py new file mode 100644 index 00000000..76cc0fae --- /dev/null +++ b/opencompass/datasets/LCBench.py @@ -0,0 +1,329 @@ +import contextlib +import io +import itertools +import multiprocessing +import re +import signal +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import List, Sequence, Union + +import numpy as np +from datasets import DatasetDict, concatenate_datasets, load_dataset + +from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class LCDataset(BaseDataset): + + @staticmethod + def load(path: str, 
num_repeats: int = 1, difficulty='ALL'): + """Load LC dataset for pass k mode. + + Note that you can use num_repeats > 1 when your model does not support + `num_return_sequence` in generation, otherwise use the raw + LC dataset and set `num_return_sequence` in model config to + generate multiple responses for testing pass@k>1. + + It better to change your dataset abbr correspondingly if you want to + change num_repeats>1, otherwise the number in + `.cache/dataset_size.json` might be inconsistent. + + Args: + num_repeats(int): Number of repetition for this dataset to get + multiple responses in special cases. + """ + + def processing_test(example): + example['test_case'] = example['test_list'] + example['test_list'] = '\n'.join(example['test_list']) + example['test_column'] = dict(test_list_2=example['test_list'], + task_id=example['Contest id']) + return example + + train = load_dataset('json', data_files=path, + split='train[:5]').map(processing_test) + test = load_dataset('json', data_files=path, + split='train[5:]').map(processing_test) + if not difficulty == 'ALL': + train = train.filter( + lambda example: example['Difficulty'] == difficulty) + test = test.filter( + lambda example: example['Difficulty'] == difficulty) + test = concatenate_datasets([test] * num_repeats) + return DatasetDict({'train': train, 'test': test}) + + +class TimeOutException(Exception): + pass + + +@contextlib.contextmanager +def swallow_io(): + stream = WriteOnlyStringIO() + with contextlib.redirect_stdout(stream): + with contextlib.redirect_stderr(stream): + with redirect_stdin(stream): + yield + + +@contextlib.contextmanager +def time_limit(seconds: float): + + def signal_handler(signum, frame): + raise TimeOutException('Time out!') + + signal.setitimer(signal.ITIMER_REAL, seconds) + signal.signal(signal.SIGALRM, signal_handler) + try: + yield + finally: + signal.setitimer(signal.ITIMER_REAL, 0) + + +class WriteOnlyStringIO(io.StringIO): + """StringIO that throws an exception when it's read from.""" + + def read(self, *args, **kwargs): + raise IOError + + def readline(self, *args, **kwargs): + raise IOError + + def readlines(self, *args, **kwargs): + raise IOError + + def readable(self, *args, **kwargs): + """Returns True if the IO object can be read.""" + return False + + +class redirect_stdin(contextlib._RedirectStream): # type: ignore + _stream = 'stdin' + + +@ICL_EVALUATORS.register_module() +class LCEvaluator(BaseEvaluator): + + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} + result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0} + details = {} + with ProcessPoolExecutor() as executor: + futures = [] + for i, (refer, pred) in enumerate(zip(references, predictions)): + pred = self._process_answer(pred) + programs = self._process_test(refer, pred) + future = executor.submit(execution, programs, i, 3) + futures.append(future) + + from tqdm import tqdm + for future in tqdm(as_completed(futures), total=len(futures)): + index, ret = future.result() + result[ret] += 1 + details[str(index)] = { + 'programs': predictions[index], + 'result': ret, + 'is_correct': ret == 'pass', + } + + result['score'] = result['pass'] / len(predictions) * 100 + result['details'] = details + return result + + def _process_answer(self, text): + try: + # for chatGLM related text + eval_text = eval(text) + except Exception: + pass + else: + if isinstance(eval_text, str): + text = eval_text + # deal with code block + if '```' in 
text: + blocks = re.findall(r'```(.*?)```', text, re.DOTALL) + if len(blocks) == 0: + text = text.split('```')[1] # fall back to default strategy + else: + text = blocks[0] # fetch the first code block + if not text.startswith('\n'): # in case starting with ```xxx + text = text[max(text.find('\n') + 1, 0):] + text = text.strip() + match = re.search(r"('\s*|)(\[DONE\]|DONE)", text) + if match: + text = text[:match.start()] + match = re.search(r"(\[BEGIN\]|BEGIN)('\s*|)", text) + if match: + text = text[match.end():] + text = text.strip() + if text.startswith("'"): + text = text[1:] + if text.endswith("'"): + text = text[:-1] + text = text.replace('\\', '') + match = re.search(r'```python(.*)```', text, re.DOTALL) + if match: + text = match.group(1).strip().split('```')[0].strip() + return text + + def _process_test(self, test_case, pred): + formatted = pred + '\n' + formatted += test_case + return formatted + + +def execution(programs, task_id, timeout): + """Execution function for running generation code. + + Args: + programs(str): Python code to be executed. + task_id(int): Task id of the current example. + timeout(int): Time limit for execution, avoid unnecessary + blocking. + + In pass@k scenario, a lot of programs should be executed. + Some internal error cannot be handled properly, such as + `RecursionError` might cause system break. It is better to + separate the execution in thread or multiprocess to better + control the process. + """ + + def _execution(programs, timeout): + try: + # Add exec globals to prevent the exec to raise + # unnecessary NameError for correct answer + exec_globals = {} + with swallow_io(): + with time_limit(timeout): + exec(programs, exec_globals) + key.append('pass') + except TimeOutException: + key.append('timeout') + except AssertionError: + key.append('wrong_answer') + except BaseException as e: + print(e) + key.append('failed') + + manager = multiprocessing.Manager() + key = manager.list() + # `signal` cannot be used in child thread, therefore, we + # need to create a process in the thread. + p = multiprocessing.Process(target=_execution, + args=(programs, timeout - 1)) + p.start() + p.join(timeout=timeout) + if p.is_alive(): + p.kill() + # key might not have value if killed + return task_id, 'timeout' + return task_id, key[0] + + +class LCPassKEvaluator(LCEvaluator): + """Better use for pass k evaluation. + + Args: + k(Tuple[int]): Choices of Pass@k. Defaults to (1, 10, 100) + """ + + def __init__(self, k=(1, 10, 100)) -> None: + if not isinstance(k, Sequence): + k = (k, ) + self.k = k + + @staticmethod + def estimate_pass_at_k( + num_samples: Union[int, List[int], np.ndarray], + num_correct: Union[List[int], np.ndarray], + k: int, + ) -> np.ndarray: + """Estimates pass@k of each problem and returns them in an array.""" + + def estimator(n: int, c: int, k: int) -> float: + """ + Calculates 1 - comb(n - c, k) / comb(n, k). 
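+
+            This is the standard unbiased pass@k estimator (as used for
+            HumanEval-style evaluation): one minus the probability that
+            all k sampled completions come from the n - c incorrect
+            ones. The product form below avoids computing large
+            binomial coefficients directly. For example, with n=10
+            samples of which c=2 pass, pass@1 = 1 - comb(8, 1) /
+            comb(10, 1) = 0.2.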
+ """ + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + if isinstance(num_samples, int): + num_samples_it = itertools.repeat(num_samples, len(num_correct)) + else: + assert len(num_samples) == len(num_correct) + num_samples_it = iter(num_samples) + + return np.array([ + estimator(int(n), int(c), k) + for n, c in zip(num_samples_it, num_correct) + ]) + + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} + + task_pass = defaultdict(int) + task_total = defaultdict(int) + + result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0} + details = {} + with ProcessPoolExecutor() as executor: + futures = [] + index, programs = 0, [] + for refer, preds in zip(references, predictions): + # suits for two case + # 1. use repeated dataset + # 2. use `num_return_sequences` to generate multiple responses + if not isinstance(preds, list): + preds = [preds] + test_case = refer['test_list_2'] + task_id = refer['task_id'] + # create empty task_pass in case all example failed + if task_id not in task_pass: + task_pass[task_id] = 0 + for pred in preds: + pred = self._process_answer(pred) + program = self._process_test(test_case, pred) + future = executor.submit(execution, program, + (index, task_id), 3) + futures.append(future) + programs.append(program) + index += 1 + + from tqdm import tqdm + for future in tqdm(as_completed(futures), total=len(futures)): + (index, task_id), ret = future.result() + result[ret] += 1 + task_total[task_id] += 1 + is_correct = ret == 'pass' + task_pass[task_id] += is_correct + details[str(index)] = { + 'program': programs[index], + 'task_id': task_id, + 'result': ret, + 'is_correct': is_correct, + } + + result['details'] = details + + def get_number(tasks): + return np.array([ + task[1] for task in sorted(tasks.items(), key=lambda x: x[0]) + ]) + + task_pass = get_number(task_pass) + task_total = get_number(task_total) + pass_at_k = { + f'pass@{k}': + self.estimate_pass_at_k(task_total, task_pass, k).mean() * 100 + for k in self.k if (task_total >= k).all() + } + result.update(pass_at_k) + return result diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index e8bb6a2e..c5bc80ea 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -62,6 +62,7 @@ from .jsonl import JsonlDataset # noqa: F401, F403 from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403 from .lambada import * # noqa: F401, F403 from .lawbench import * # noqa: F401, F403 +from .LCBench import * # noqa: F401, F403 from .lcsts import * # noqa: F401, F403 from .leval import * # noqa: F401, F403 from .llm_compression import LLMCompressionDataset # noqa: F401, F403 diff --git a/opencompass/datasets/bbh.py b/opencompass/datasets/bbh.py index 7950056f..71c3e174 100644 --- a/opencompass/datasets/bbh.py +++ b/opencompass/datasets/bbh.py @@ -50,9 +50,15 @@ def bbh_freeform_postprocess(text: str) -> str: ans_line = ans.split('answer is ') if len(ans_line) != 1: ans = ans_line[1].strip() - ans = ans.split('\n')[0] + ans = ans.split('\n')[0].strip() + if ans.endswith('.'): - ans = ans[:-1] + ans = ans[:-1].strip() + + match = re.search(r'\*\*(.*?)\*\*', ans) + if match: + return match.group(1) + return ans diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 21ce7df6..d2ba27d3 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -6,7 +6,9 @@ from 
.baidu_api import ERNIEBot # noqa: F401 from .base import BaseModel, LMTemplateParser # noqa: F401 from .base_api import APITemplateParser, BaseAPIModel # noqa: F401 from .bytedance_api import ByteDance # noqa: F401 +from .claude_allesapin import ClaudeAllesAPIN # noqa: F401 from .claude_api import Claude # noqa: F401 +from .claude_sdk_api import ClaudeSDK # noqa: F401 from .deepseek_api import DeepseekAPI # noqa: F401 from .doubao_api import Doubao # noqa: F401 from .gemini_api import Gemini # noqa: F401 diff --git a/opencompass/models/claude_allesapin.py b/opencompass/models/claude_allesapin.py new file mode 100644 index 00000000..96ad6080 --- /dev/null +++ b/opencompass/models/claude_allesapin.py @@ -0,0 +1,150 @@ +import json +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +import requests + +from opencompass.registry import MODELS +from opencompass.utils import PromptList + +from .base_api import BaseAPIModel + +PromptType = Union[PromptList, str] + + +@MODELS.register_module() +class ClaudeAllesAPIN(BaseAPIModel): + """Model wrapper around Claude-AllesAPIN. + + Args: + path (str): The name of Claude's model. + url (str): URL to AllesAPIN. + key (str): AllesAPIN key. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 1. + max_seq_len (int): Unused here. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + retry (int): Number of retires if the API call fails. Defaults to 2. + """ + + is_api: bool = True + + def __init__(self, + path: str, + url: str, + key: str, + query_per_second: int = 1, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + retry: int = 2): + super().__init__(path=path, + max_seq_len=max_seq_len, + query_per_second=query_per_second, + meta_template=meta_template, + retry=retry) + self.url = url + self.headers = { + 'alles-apin-token': key, + 'content-type': 'application/json', + } + + def generate(self, + inputs: List[PromptType], + max_out_len: int = 512, + **kwargs) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[PromptType]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenAGIEval's + API format. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + with ThreadPoolExecutor() as executor: + results = list( + executor.map(self._generate, inputs, + [max_out_len] * len(inputs))) + return results + + def _generate(self, input: PromptType, max_out_len: int) -> str: + """Generate results given an input. + + Args: + inputs (PromptType): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + temperature (float): What sampling temperature to use, + between 0 and 2. Higher values like 0.8 will make the output + more random, while lower values like 0.2 will make it more + focused and deterministic. + + Returns: + str: The generated string. 
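+
+        Note:
+            When ``input`` is a PromptList, consecutive items that map
+            to the same role (``BOT`` becomes ``assistant``, everything
+            else becomes ``user``) are merged into one message joined
+            by newlines, so the request body alternates between user
+            and assistant turns.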
+ """ + assert isinstance(input, (str, PromptList)) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + else: + messages = [] + msg_buffer, last_role = [], None + for item in input: + item['role'] = 'assistant' if item['role'] == 'BOT' else 'user' + if item['role'] != last_role and last_role is not None: + messages.append({ + 'content': '\n'.join(msg_buffer), + 'role': last_role + }) + msg_buffer = [] + msg_buffer.append(item['prompt']) + last_role = item['role'] + messages.append({ + 'content': '\n'.join(msg_buffer), + 'role': last_role + }) + + data = { + 'model': self.path, + 'messages': messages, + 'max_tokens': max_out_len, + } + + err_data = [] + for _ in range(self.retry + 1): + self.wait() + try: + raw_response = requests.post(self.url, + headers=self.headers, + data=json.dumps(data)) + except requests.ConnectionError: + time.sleep(5) + continue + except requests.ReadTimeout: + time.sleep(5) + continue + try: + response = raw_response.json() + except requests.JSONDecodeError: + if 'https://errors.aliyun.com/images' in \ + raw_response.content.decode(): + return 'request blocked by allesapin' + self.logger.error('JsonDecode error, got', + raw_response.content) + continue + if raw_response.status_code == 200 and response[ + 'msgCode'] == '10000': + data = response['data'] + generated = data['content'][0]['text'].strip() + self.logger.debug(f'Generated: {generated}') + return generated + self.logger.error(response['data']) + err_data.append(response['data']) + + raise RuntimeError(err_data) diff --git a/opencompass/models/claude_sdk_api.py b/opencompass/models/claude_sdk_api.py new file mode 100644 index 00000000..8cbf98ef --- /dev/null +++ b/opencompass/models/claude_sdk_api.py @@ -0,0 +1,121 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +from opencompass.registry import MODELS +from opencompass.utils import PromptList + +from .base_api import BaseAPIModel + +PromptType = Union[PromptList, str] + + +@MODELS.register_module() +class ClaudeSDK(BaseAPIModel): + """Model wrapper around Claude SDK API. + + Args: + key (str): Authorization key. + path (str): The model to be used. Defaults to claude-2. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 1. + max_seq_len (int): Unused here. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + retry (int): Number of retires if the API call fails. Defaults to 2. + """ + + def __init__( + self, + key: str, + path: str = 'claude-2', + query_per_second: int = 2, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + temperature: Optional[float] = 0.0, + retry: int = 2, + ): + super().__init__(path=path, + max_seq_len=max_seq_len, + query_per_second=query_per_second, + meta_template=meta_template, + retry=retry) + try: + from anthropic import Anthropic + except ImportError: + raise ImportError('Import anthropic failed. Please install it ' + 'with "pip install anthropic" and try again.') + + self.anthropic = Anthropic(api_key=key) + self.model = path + self.temperature = temperature + + def generate( + self, + inputs: List[PromptType], + max_out_len: int = 512, + ) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[PromptType]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. 
+ max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + with ThreadPoolExecutor() as executor: + results = list( + executor.map(self._generate, inputs, + [max_out_len] * len(inputs))) + return results + + def _generate( + self, + input: PromptType, + max_out_len: int = 512, + ) -> str: + """Generate results given an input. + + Args: + inputs (PromptType): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + assert isinstance(input, (str, PromptList)) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + else: + messages = [] + for item in input: + msg = {'content': item['prompt']} + if item['role'] == 'HUMAN': + msg['role'] = 'user' + elif item['role'] == 'BOT': + msg['role'] = 'assistant' + elif item['role'] == 'SYSTEM': + msg['role'] = 'system' + messages.append(msg) + + num_retries = 0 + while num_retries < self.retry: + self.wait() + try: + responses = self.anthropic.messages.create( + model=self.model, + max_tokens=max_out_len, + temperature=self.temperature, + messages=messages) + return responses.content[0].text + except Exception as e: + self.logger.error(e) + num_retries += 1 + raise RuntimeError('Calling Claude API failed after retrying for ' + f'{self.retry} times. Check the logs for details.') diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py index 70e82e1b..33e1b984 100644 --- a/opencompass/models/turbomind.py +++ b/opencompass/models/turbomind.py @@ -1,4 +1,5 @@ import copy +import os from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Union @@ -60,6 +61,7 @@ class TurboMindModel(BaseModel): from lmdeploy.messages import TurbomindEngineConfig engine_config = TurbomindEngineConfig(**engine_config) self.logger = get_logger() + assert os.path.exists(path), '{} is not existist'.format(path) tm_model = TurboMind.from_pretrained(path, engine_config=engine_config) self.tokenizer = tm_model.tokenizer self.generators = [ diff --git a/opencompass/runners/__init__.py b/opencompass/runners/__init__.py index f4a3207e..129297b3 100644 --- a/opencompass/runners/__init__.py +++ b/opencompass/runners/__init__.py @@ -2,3 +2,4 @@ from .dlc import * # noqa: F401, F403 from .local import * # noqa: F401, F403 from .slurm import * # noqa: F401, F403 from .slurm_sequential import * # noqa: F401, F403 +from .volc import * # noqa: F401, F403 diff --git a/opencompass/runners/volc.py b/opencompass/runners/volc.py new file mode 100644 index 00000000..b055faf6 --- /dev/null +++ b/opencompass/runners/volc.py @@ -0,0 +1,260 @@ +import os +import os.path as osp +import random +import re +import subprocess +import time +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import mmengine +import yaml +from mmengine.config import ConfigDict +from mmengine.utils import track_parallel_progress + +from opencompass.registry import RUNNERS, TASKS +from opencompass.utils import get_logger + +from .base import BaseRunner + + +@RUNNERS.register_module() +class VOLCRunner(BaseRunner): + """Distributed runner based on Volcano Cloud Cluster (VCC). It will launch + multiple tasks in parallel with the 'vcc' command. Please install and + configure VCC first before using this runner. + + Args: + task (ConfigDict): Task type config. + volcano_cfg (ConfigDict): Volcano Cloud config. 
+ queue_name (str): Name of resource queue. + preemptible (bool): Whether to launch task in preemptible way. + Default: False + priority (bool): Priority of tasks, ranging from 1 to 9. + 9 means the highest priority. Default: None + max_num_workers (int): Max number of workers. Default: 32. + retry (int): Number of retries when job failed. Default: 2. + debug (bool): Whether to run in debug mode. Default: False. + lark_bot_url (str): Lark bot url. Default: None. + """ + + def __init__(self, + task: ConfigDict, + volcano_cfg: ConfigDict, + queue_name: str, + preemptible: bool = False, + priority: Optional[int] = None, + max_num_workers: int = 32, + retry: int = 2, + debug: bool = False, + lark_bot_url: str = None): + super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) + self.volcano_cfg = volcano_cfg + self.max_num_workers = max_num_workers + self.retry = retry + self.queue_name = queue_name + self.preemptible = preemptible + self.priority = priority + + def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: + """Launch multiple tasks. + + Args: + tasks (list[dict]): A list of task configs, usually generated by + Partitioner. + + Returns: + list[tuple[str, int]]: A list of (task name, exit code). + """ + + if not self.debug: + status = track_parallel_progress(self._launch, + tasks, + nproc=self.max_num_workers, + keep_order=False) + else: + status = [self._launch(task, random_sleep=False) for task in tasks] + return status + + def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True): + """Launch a single task. + + Args: + task_cfg (ConfigDict): Task config. + random_sleep (bool): Whether to sleep for a random time before + running the command. This avoids cluster error when launching + multiple tasks at the same time. Default: True. + + Returns: + tuple[str, int]: Task name and exit code. 
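+
+        The task config is dumped to ``tmp/<pid>_params.py`` and the
+        chosen flavor to ``tmp/<pid>_cfg.yaml``; a ``volc ml_task
+        submit`` command is then built whose entrypoint sets up the
+        Python environment, exports the cache/offline variables from
+        ``volcano_cfg`` and runs the task command. The job status is
+        polled until it finishes, and the submission is retried up to
+        ``self.retry`` times on failure.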
+ """ + + task_type = self.task_cfg.type + if isinstance(self.task_cfg.type, str): + task_type = TASKS.get(task_type) + task = task_type(task_cfg) + num_gpus = task.num_gpus + task_name = task.name + + # Build up VCC command + pwd = os.getcwd() + # Dump task config to file + mmengine.mkdir_or_exist('tmp/') + param_file = f'{pwd}/tmp/{os.getpid()}_params.py' + + volc_cfg_file = f'{pwd}/tmp/{os.getpid()}_cfg.yaml' + volc_cfg = self._choose_flavor(num_gpus) + with open(volc_cfg_file, 'w') as fp: + yaml.dump(volc_cfg, fp, sort_keys=False) + try: + task_cfg.dump(param_file) + if self.volcano_cfg.get('bashrc_path') is not None: + # using user's conda env + bashrc_path = self.volcano_cfg['bashrc_path'] + assert osp.exists(bashrc_path) + assert self.volcano_cfg.get('conda_env_name') is not None + + conda_env_name = self.volcano_cfg['conda_env_name'] + + shell_cmd = (f'source {self.volcano_cfg["bashrc_path"]}; ' + f'source activate {conda_env_name}; ') + shell_cmd += f'export PYTHONPATH={pwd}:$PYTHONPATH; ' + else: + assert self.volcano_cfg.get('python_env_path') is not None + shell_cmd = ( + f'export PATH={self.volcano_cfg["python_env_path"]}/bin:$PATH; ' # noqa: E501 + f'export PYTHONPATH={pwd}:$PYTHONPATH; ') + + huggingface_cache = self.volcano_cfg.get('huggingface_cache') + if huggingface_cache is not None: + # HUGGINGFACE_HUB_CACHE is a Legacy env variable, here we set + # `HF_HUB_CACHE` and `HUGGINGFACE_HUB_CACHE` for bc + shell_cmd += f'export HF_HUB_CACHE={huggingface_cache}; ' + shell_cmd += f'export HUGGINGFACE_HUB_CACHE={huggingface_cache}; ' # noqa: E501 + + torch_cache = self.volcano_cfg.get('torch_cache') + if torch_cache is not None: + shell_cmd += f'export TORCH_HOME={torch_cache}; ' + + hf_offline = self.volcano_cfg.get('hf_offline', True) + + if hf_offline: + shell_cmd += 'export HF_DATASETS_OFFLINE=1; export TRANSFORMERS_OFFLINE=1; export HF_EVALUATE_OFFLINE=1; export HF_HUB_OFFLINE=1; ' # noqa: E501 + + hf_endpoint = self.volcano_cfg.get('hf_endpoint') + if hf_endpoint is not None: + shell_cmd += f'export HF_ENDPOINT={hf_endpoint}; ' + + extra_envs = self.volcano_cfg.get('extra_envs') + if extra_envs is not None: + for extra_env in extra_envs: + shell_cmd += f'export {extra_env}; ' + + shell_cmd += f'cd {pwd}; ' + shell_cmd += '{task_cmd}' + + task_name = task_name[:128].replace('[', '-').replace( + ']', '').replace('/', '-').replace(',', + '--').replace('.', '_') + tmpl = ('volc ml_task submit' + f" --conf '{volc_cfg_file}'" + f" --entrypoint '{shell_cmd}'" + f' --task_name {task_name}' + f' --resource_queue_name {self.queue_name}') + if self.preemptible: + tmpl += ' --preemptible' + if self.priority is not None: + tmpl += f' --priority {self.priority}' + get_cmd = partial(task.get_command, + cfg_path=param_file, + template=tmpl) + cmd = get_cmd() + + logger = get_logger() + logger.debug(f'Running command: {cmd}') + + out_path = task.get_log_path(file_extension='txt') + mmengine.mkdir_or_exist(osp.split(out_path)[0]) + + retry = self.retry + while True: + if random_sleep: + time.sleep(random.randint(0, 10)) + task_status, returncode = self._run_task(cmd, + out_path, + poll_interval=20) + output_paths = task.get_output_paths() + if not (self._job_failed(task_status, output_paths)) \ + or retry <= 0: + break + retry -= 1 + + finally: + # Clean up + os.remove(param_file) + os.remove(volc_cfg_file) + return task_name, returncode + + def _run_task(self, cmd, log_path, poll_interval): + result = subprocess.run(cmd, + shell=True, + text=True, + capture_output=True) + pattern = 
r'(?<=task_id=).*(?=\n\n)' + match = re.search(pattern, result.stdout) + if match: + task_id = match.group() + ask_cmd = f'volc ml_task get --id {task_id} --output json ' + \ + '--format Status' + log_cmd = f'volc ml_task logs --task {task_id} --instance worker_0' + while True: + task_status = os.popen(ask_cmd).read() + pattern = r'(?<=\[{"Status":").*(?="}\])' + match = re.search(pattern, task_status) + if match: + task_status = match.group() + else: + task_status = 'Exception' + if self.debug: + print(task_status) + logs = os.popen(log_cmd).read() + with open(log_path, 'w', encoding='utf-8') as f: + f.write(logs) + if task_status in [ + 'Success', 'Failed', 'Cancelled', 'Exception', + 'Killing' + ]: + break + time.sleep(poll_interval) + else: + task_status = 'Exception' + + return task_status, result.returncode + + def _job_failed(self, task_status: str, output_paths: List[str]) -> bool: + return task_status != 'Success' or not all( + osp.exists(output_path) for output_path in output_paths) + + def _choose_flavor(self, num_gpus): + config_path = self.volcano_cfg.volcano_config_path + with open(config_path) as fp: + volc_cfg = yaml.safe_load(fp) + if num_gpus <= 0: + flavor = 'ml.c1ie.2xlarge' + elif num_gpus == 1: + flavor = 'ml.pni2l.3xlarge' + elif num_gpus == 2: + flavor = 'ml.pni2l.7xlarge' + elif num_gpus <= 4: + flavor = 'ml.pni2l.14xlarge' + elif num_gpus <= 8: + flavor = 'ml.pni2l.28xlarge' + else: + raise NotImplementedError + + role_specs = volc_cfg['TaskRoleSpecs'] + for i in range(len(role_specs)): + if role_specs[i]['RoleName'] == 'worker': + role_specs[i]['Flavor'] = flavor + + return volc_cfg diff --git a/opencompass/summarizers/default.py b/opencompass/summarizers/default.py index 7eef249d..93dab27a 100644 --- a/opencompass/summarizers/default.py +++ b/opencompass/summarizers/default.py @@ -16,7 +16,7 @@ from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg, model_abbr_from_cfg) from opencompass.utils.prompt import get_prompt_hash -METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match'] +METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match', 'extract_rate'] METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len', 'type'] def model_abbr_from_cfg_used_in_summarizer(model): diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 1f9aff72..de536f4b 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -75,6 +75,8 @@ class OpenICLEvalTask(BaseTask): for c in sum(self.dataset_cfgs, [])) self.dump_details = cfg.get('eval', {}).get('runner', {}).get( 'task', {}).get('dump_details', False) + self.cal_extrat_rate = cfg.get('eval', {}).get('runner', {}).get( + 'task', {}).get('cal_extrat_rate', False) def get_command(self, cfg_path, template): sys.path.append(os.getcwd()) @@ -234,6 +236,9 @@ class OpenICLEvalTask(BaseTask): pred_strs, test_set[self.output_column], details, pred_dicts) result['type'] = result['details'].pop('type', None) + if self.cal_extrat_rate: + # Calculate the extraction success rate for prediction + result['extract_rate'] = self.extract_rate(result) if 'PPL' in str( self.dataset_cfg.infer_cfg.inferencer.type): @@ -262,6 +267,25 @@ class OpenICLEvalTask(BaseTask): mkdir_or_exist(osp.split(out_path)[0]) mmengine.dump(result, out_path, ensure_ascii=False, 
indent=4) + def extract_rate(self, results): + """This function is designed for calculating the extraction rate. + + Args: + results (dict): The result dict, include the information + """ + details = results['details'] + details_list = list(details.values()) + invalid_extractions = [] + for item in details_list: + try: + invalid_extractions.extend( + [item] if not item['predictions'] else []) + except KeyError as e: + self.logger.warning(f'Skip {e} due to: {item}') + raise KeyError + success_rate = 100 - len(invalid_extractions) / len(details) * 100 + return success_rate + def format_details(self, predictions, references, details, pred_dicts): """This function is responsible for formatting prediction details. diff --git a/opencompass/utils/collect_env.py b/opencompass/utils/collect_env.py index c8950491..771d44fb 100644 --- a/opencompass/utils/collect_env.py +++ b/opencompass/utils/collect_env.py @@ -9,4 +9,18 @@ def collect_env(): env_info = collect_base_env() env_info['opencompass'] = opencompass.__version__ + '+' + get_git_hash( )[:7] + + # LMDeploy + try: + import lmdeploy + env_info['lmdeploy'] = lmdeploy.__version__ + except ModuleNotFoundError as e: + env_info['lmdeploy'] = f'not installed:{e}' + # Transformers + try: + import transformers + env_info['transformers'] = transformers.__version__ + except ModuleNotFoundError as e: + env_info['transformers'] = f'not installed:{e}' + return env_info diff --git a/setup.py b/setup.py index 46d4dc23..1ac5f84b 100644 --- a/setup.py +++ b/setup.py @@ -117,14 +117,8 @@ def do_setup(): python_requires='>=3.8.0', install_requires=parse_requirements('requirements/runtime.txt'), license='Apache License 2.0', - packages=find_packages(exclude=[ - 'test*', - 'configs', - 'data', - 'docs', - 'tools', - 'tmp', - ]), + include_package_data=True, + packages=find_packages(), keywords=[ 'AI', 'NLP', 'in-context learning', 'large language model', 'evaluation', 'benchmark', 'llm' diff --git a/tools/update_dataset_suffix.py b/tools/update_dataset_suffix.py index a49f37eb..a256c7dd 100755 --- a/tools/update_dataset_suffix.py +++ b/tools/update_dataset_suffix.py @@ -109,9 +109,11 @@ def update_imports(data): def main(): parser = argparse.ArgumentParser() parser.add_argument('python_files', nargs='*') + # Could be opencompass/configs/datasets and configs/datasets + parser.add_argument('--root_folder', default='configs/datasets') args = parser.parse_args() - root_folder = 'configs/datasets' + root_folder = args.root_folder if args.python_files: python_files = [ i for i in args.python_files if i.startswith(root_folder)