From 5de85406ce0b9990015e4b73878e68b62808c90c Mon Sep 17 00:00:00 2001 From: Fengzhe Zhou Date: Fri, 17 May 2024 16:50:58 +0800 Subject: [PATCH] [Sync] add OC16 entry (#1171) --- .../mbpp/sanitized_mbpp_mdblock_gen_a447ff.py | 41 ++++ .../datasets/taco/taco_staged_gen_411572.py | 36 ++++ configs/models/deepseek/hf_deepseek_v2.py | 2 +- .../models/deepseek/hf_deepseek_v2_chat.py | 2 +- configs/models/yi/hf_yi_1_5_34b_chat.py | 12 ++ configs/models/yi/hf_yi_1_5_6b_chat.py | 12 ++ configs/models/yi/hf_yi_1_5_9b_chat.py | 12 ++ opencompass/datasets/mbpp.py | 1 + opencompass/datasets/taco.py | 7 +- opencompass/models/__init__.py | 1 + opencompass/models/huggingface_above_v4_33.py | 5 +- opencompass/models/lightllm_api.py | 2 + opencompass/models/yi_api.py | 178 ++++++++++++++++++ opencompass/runners/dlc.py | 35 ++-- opencompass/utils/run.py | 2 +- 15 files changed, 321 insertions(+), 27 deletions(-) create mode 100644 configs/datasets/mbpp/sanitized_mbpp_mdblock_gen_a447ff.py create mode 100644 configs/datasets/taco/taco_staged_gen_411572.py create mode 100644 configs/models/yi/hf_yi_1_5_34b_chat.py create mode 100644 configs/models/yi/hf_yi_1_5_6b_chat.py create mode 100644 configs/models/yi/hf_yi_1_5_9b_chat.py create mode 100644 opencompass/models/yi_api.py diff --git a/configs/datasets/mbpp/sanitized_mbpp_mdblock_gen_a447ff.py b/configs/datasets/mbpp/sanitized_mbpp_mdblock_gen_a447ff.py new file mode 100644 index 00000000..957f793e --- /dev/null +++ b/configs/datasets/mbpp/sanitized_mbpp_mdblock_gen_a447ff.py @@ -0,0 +1,41 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import SanitizedMBPPDataset, MBPPEvaluator + +sanitized_mbpp_reader_cfg = dict(input_columns=['text', 'test_list'], output_column='test_list_2') + +sanitized_mbpp_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='You are an expert Python programmer, and here is your task:\nWrite a function to find the similar elements from the given two tuple lists.\nYour code should pass these tests:\n\nassert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\nassert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)\nassert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)\n',), + dict(role='BOT', prompt='```python\ndef similar_elements(test_tup1, test_tup2):\n res = tuple(set(test_tup1) & set(test_tup2))\n return (res)```',), + + dict(role='HUMAN', prompt='You are an expert Python programmer, and here is your task:\nWrite a python function to identify non-prime numbers.\nYour code should pass these tests:\n\nassert is_not_prime(2) == False\nassert is_not_prime(10) == True\nassert is_not_prime(35) == True\n',), + dict(role='BOT', prompt='```python\nimport math\ndef is_not_prime(n):\n result = False\n for i in range(2,int(math.sqrt(n)) + 1):\n if n %% i == 0:\n result = True\n return result```',), + + dict(role='HUMAN', prompt='You are an expert Python programmer, and here is your task:\nWrite a function to find the largest integers from a given list of numbers using heap queue algorithm.\nYour code should pass these tests:\n\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65]\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75]\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]\n',), + 
dict(role='BOT', prompt='```python\nimport heapq as hq\ndef heap_queue_largest(nums,n):\n largest_nums = hq.nlargest(n, nums)\n return largest_nums```',), + + dict(role='HUMAN', prompt='You are an expert Python programmer, and here is your task:\n{text}\nYour code should pass these tests:\n\n{test_list}\n',), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512), +) + +sanitized_mbpp_eval_cfg = dict(evaluator=dict(type=MBPPEvaluator), pred_role='BOT') + +sanitized_mbpp_datasets = [ + dict( + type=SanitizedMBPPDataset, + abbr='sanitized_mbpp', + path='./data/mbpp/sanitized-mbpp.jsonl', + reader_cfg=sanitized_mbpp_reader_cfg, + infer_cfg=sanitized_mbpp_infer_cfg, + eval_cfg=sanitized_mbpp_eval_cfg, + ) +] diff --git a/configs/datasets/taco/taco_staged_gen_411572.py b/configs/datasets/taco/taco_staged_gen_411572.py new file mode 100644 index 00000000..2bead3d8 --- /dev/null +++ b/configs/datasets/taco/taco_staged_gen_411572.py @@ -0,0 +1,36 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import TACODataset, TACOEvaluator + +TACO_difficulties_list = ['EASY', 'MEDIUM', 'MEDIUM_HARD', 'HARD', 'VERY_HARD'] +TACO_reader_cfg = dict(input_columns=['question', 'starter'], output_column='problem_id', train_split='test') + +TACO_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='Please write a python program to address the following QUESTION. Your ANSWER should be in a code block format like this: ```python # Write your code here ```. \nQUESTION:\n{question} {starter}\nANSWER:\n'), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=1024), +) + +TACO_eval_cfg = dict(evaluator=dict(type=TACOEvaluator), pred_role='BOT') + +TACO_datasets = [] +for difficulty in TACO_difficulties_list: + TACO_datasets.append( + dict( + type=TACODataset, + abbr='TACO-' + difficulty, + path='./data/BAAI-TACO', + difficulty=difficulty, + reader_cfg=TACO_reader_cfg, + infer_cfg=TACO_infer_cfg, + eval_cfg=TACO_eval_cfg, + ) + ) diff --git a/configs/models/deepseek/hf_deepseek_v2.py b/configs/models/deepseek/hf_deepseek_v2.py index e05be313..1342c0d7 100644 --- a/configs/models/deepseek/hf_deepseek_v2.py +++ b/configs/models/deepseek/hf_deepseek_v2.py @@ -13,6 +13,6 @@ models = [ max_memory={i: '75GB' for i in range(8)}, attn_implementation='eager' ), - run_cfg=dict(num_gpus=4), + run_cfg=dict(num_gpus=8), ) ] diff --git a/configs/models/deepseek/hf_deepseek_v2_chat.py b/configs/models/deepseek/hf_deepseek_v2_chat.py index 67dfd0bd..e8f28d78 100644 --- a/configs/models/deepseek/hf_deepseek_v2_chat.py +++ b/configs/models/deepseek/hf_deepseek_v2_chat.py @@ -13,6 +13,6 @@ models = [ max_memory={i: '75GB' for i in range(8)}, attn_implementation='eager' ), - run_cfg=dict(num_gpus=4), + run_cfg=dict(num_gpus=8), ) ] diff --git a/configs/models/yi/hf_yi_1_5_34b_chat.py b/configs/models/yi/hf_yi_1_5_34b_chat.py new file mode 100644 index 00000000..b1742ee8 --- /dev/null +++ b/configs/models/yi/hf_yi_1_5_34b_chat.py @@ -0,0 +1,12 @@ +from opencompass.models import HuggingFacewithChatTemplate + +models = [ + dict( + type=HuggingFacewithChatTemplate, + abbr='yi-1.5-34b-chat-hf', + path='01-ai/Yi-1.5-34B-Chat', + max_out_len=1024, + batch_size=8, + run_cfg=dict(num_gpus=2), + ) +] diff --git 
a/configs/models/yi/hf_yi_1_5_6b_chat.py b/configs/models/yi/hf_yi_1_5_6b_chat.py new file mode 100644 index 00000000..015df536 --- /dev/null +++ b/configs/models/yi/hf_yi_1_5_6b_chat.py @@ -0,0 +1,12 @@ +from opencompass.models import HuggingFacewithChatTemplate + +models = [ + dict( + type=HuggingFacewithChatTemplate, + abbr='yi-1.5-6b-chat-hf', + path='01-ai/Yi-1.5-6B-Chat', + max_out_len=1024, + batch_size=8, + run_cfg=dict(num_gpus=1), + ) +] diff --git a/configs/models/yi/hf_yi_1_5_9b_chat.py b/configs/models/yi/hf_yi_1_5_9b_chat.py new file mode 100644 index 00000000..5d7e566d --- /dev/null +++ b/configs/models/yi/hf_yi_1_5_9b_chat.py @@ -0,0 +1,12 @@ +from opencompass.models import HuggingFacewithChatTemplate + +models = [ + dict( + type=HuggingFacewithChatTemplate, + abbr='yi-1.5-9b-chat-hf', + path='01-ai/Yi-1.5-9B-Chat', + max_out_len=1024, + batch_size=8, + run_cfg=dict(num_gpus=1), + ) +] diff --git a/opencompass/datasets/mbpp.py b/opencompass/datasets/mbpp.py index 90ac1410..6992a14f 100644 --- a/opencompass/datasets/mbpp.py +++ b/opencompass/datasets/mbpp.py @@ -288,6 +288,7 @@ class MBPPEvaluator(BaseEvaluator): r'(.*)\s*```.*', r"\[BEGIN\]\s*'(.*)", r'\[BEGIN\](.*)', + r"'(.*)'\s*\[DONE\]", ] for p in patterns: match = re.search(p, text, re.DOTALL) diff --git a/opencompass/datasets/taco.py b/opencompass/datasets/taco.py index 79b41297..7d548477 100644 --- a/opencompass/datasets/taco.py +++ b/opencompass/datasets/taco.py @@ -37,13 +37,16 @@ TIMEOUT = 10 class TACODataset(BaseDataset): @staticmethod - def load(path: str, num_repeats: int = 1): + def load(path: str, num_repeats: int = 1, difficulty='ALL'): dataset = load_from_disk(path) new_dataset = DatasetDict() # add new column "starter" in the prompt for split in dataset.keys(): new_samples = [] for idx, sample in enumerate(dataset[split]): + if 'ALL' not in difficulty: + if not sample['difficulty'] == difficulty: + continue starter_code = None if len( sample['starter_code']) == 0 else sample['starter_code'] try: @@ -71,7 +74,6 @@ class TACODataset(BaseDataset): for key in new_samples[0].keys() } new_dataset[split] = Dataset.from_dict(new_data) - # num_repeats duplicate # train_repeated = [] test_repeated = [] @@ -84,7 +86,6 @@ class TACODataset(BaseDataset): # train_repeated # ) dataset_test_repeated = new_dataset['test'].from_list(test_repeated) - return DatasetDict({ # 'train': dataset_train_repeated, 'test': dataset_test_repeated diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 8a9375a7..4b35b160 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -42,5 +42,6 @@ from .vllm import VLLM # noqa: F401 from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate # noqa: F401 from .xunfei_api import XunFei, XunFeiSpark # noqa: F401 from .yayi_api import Yayi # noqa: F401 +from .yi_api import YiAPI # noqa: F401 from .zhipuai_api import ZhiPuAI # noqa: F401 from .zhipuai_v2_api import ZhiPuV2AI # noqa: F401 diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py index c88d2c4c..fc758b31 100644 --- a/opencompass/models/huggingface_above_v4_33.py +++ b/opencompass/models/huggingface_above_v4_33.py @@ -64,7 +64,7 @@ def _convert_chat_messages(inputs): for _input in inputs: messages = [] if isinstance(_input, str): - messages.append({'role': 'HUMAN', 'content': _input}) + messages.append({'role': 'user', 'content': _input}) else: for item in _input: role = { @@ -165,7 +165,7 @@ class 
HuggingFacewithChatTemplate(BaseModel): def _load_tokenizer(self, path: Optional[str], kwargs: dict, pad_token_id: Optional[int] = None): from transformers import AutoTokenizer, GenerationConfig - DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', use_fast=False, trust_remote_code=True) + DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True) tokenizer_kwargs = DEFAULT_TOKENIZER_KWARGS tokenizer_kwargs.update(kwargs) self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_kwargs) @@ -199,6 +199,7 @@ class HuggingFacewithChatTemplate(BaseModel): model_kwargs = DEFAULT_MODEL_KWARGS model_kwargs.update(kwargs) model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs) + self.logger.debug(f'using model_kwargs: {model_kwargs}') try: self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py index 58fbae72..b0062525 100644 --- a/opencompass/models/lightllm_api.py +++ b/opencompass/models/lightllm_api.py @@ -68,6 +68,7 @@ class LightllmAPI(BaseModel): self.wait() header = {'content-type': 'application/json'} try: + self.logger.debug(f'input: {input}') data = dict(inputs=input, parameters=self.generation_kwargs) raw_response = requests.post(self.url, headers=header, @@ -80,6 +81,7 @@ class LightllmAPI(BaseModel): generated_text = response['generated_text'] if isinstance(generated_text, list): generated_text = generated_text[0] + self.logger.debug(f'generated_text: {generated_text}') return generated_text except requests.JSONDecodeError: self.logger.error('JsonDecode error, got', diff --git a/opencompass/models/yi_api.py b/opencompass/models/yi_api.py new file mode 100644 index 00000000..d9bc70a7 --- /dev/null +++ b/opencompass/models/yi_api.py @@ -0,0 +1,178 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +import requests + +from opencompass.utils.prompt import PromptList + +from .base_api import BaseAPIModel + +PromptType = Union[PromptList, str] + + +class YiAPI(BaseAPIModel): + """Model wrapper around the Yi API. + + Documentation: + + Args: + path (str): The name of the Yi API model. + e.g. `yi-large` + key (str): Authorization key. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 2. + max_seq_len (int): Unused here. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case it is necessary to inject or + wrap any meta instructions. + retry (int): Number of retries if the API call fails. Defaults to 2. + """ + + def __init__( + self, + path: str, + key: str, + url: str, + query_per_second: int = 2, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + retry: int = 2, + system_prompt: str = '', + ): + super().__init__(path=path, + max_seq_len=max_seq_len, + query_per_second=query_per_second, + meta_template=meta_template, + retry=retry) + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + key, + } + self.url = url + self.model = path + self.system_prompt = system_prompt + + def generate( + self, + inputs: List[PromptType], + max_out_len: int = 512, + ) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[PromptType]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format.
+ max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + with ThreadPoolExecutor() as executor: + results = list( + executor.map(self._generate, inputs, + [max_out_len] * len(inputs))) + self.flush() + return results + + def _generate( + self, + input: PromptType, + max_out_len: int = 512, + ) -> str: + """Generate results given an input. + + Args: + input (PromptType): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + assert isinstance(input, (str, PromptList)) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + else: + messages = [] + msg_buffer, last_role = [], None + for item in input: + item['role'] = 'assistant' if item['role'] == 'BOT' else 'user' + if item['role'] != last_role and last_role is not None: + messages.append({ + 'content': '\n'.join(msg_buffer), + 'role': last_role + }) + msg_buffer = [] + msg_buffer.append(item['prompt']) + last_role = item['role'] + messages.append({ + 'content': '\n'.join(msg_buffer), + 'role': last_role + }) + + if self.system_prompt: + system = {'role': 'system', 'content': self.system_prompt} + messages.insert(0, system) + + data = {'model': self.model, 'messages': messages} + + max_num_retries = 0 + while max_num_retries < self.retry: + self.acquire() + try: + raw_response = requests.request('POST', + url=self.url, + headers=self.headers, + json=data) + except Exception as err: + print('Request Error:{}'.format(err)) + time.sleep(2) + continue + + try: + response = raw_response.json() + except Exception as err: + print('Response Error:{}'.format(err)) + response = None + self.release() + + if response is None: + print('Connection error, reconnect.') + # if a connection error occurs, frequent requests will cause + # continued network instability, so wait here + # to slow down the request rate + self.wait() + continue + + if raw_response.status_code == 200: + # msg = json.load(response.text) + # response + msg = response['choices'][0]['message']['content'] + self.logger.debug(f'Generated: {msg}') + return msg + + if raw_response.status_code == 401: + print('Request rejected: invalid api_key') + continue + elif raw_response.status_code == 400: + print(messages, response) + print('Request failed, status code:', raw_response) + msg = 'The request was rejected due to high-risk content' + return msg + elif raw_response.status_code == 429: + print(messages, response) + print('Request failed, status code:', raw_response) + time.sleep(5) + continue + else: + print(messages, response) + print('Request failed, status code:', raw_response) + time.sleep(1) + + max_num_retries += 1 + + raise RuntimeError(raw_response) diff --git a/opencompass/runners/dlc.py b/opencompass/runners/dlc.py index bc4ca0dd..0590d223 100644 --- a/opencompass/runners/dlc.py +++ b/opencompass/runners/dlc.py @@ -161,17 +161,19 @@ class DLCRunner(BaseRunner): shell_cmd += 'umask 0000; ' shell_cmd += '{task_cmd}' - tmpl = ('dlc create job' - f" --command '{shell_cmd}'" - f' --name {task_name[:512]}' - ' --kind BatchJob' - f" -c {self.aliyun_cfg['dlc_config_path']}" - f" --workspace_id {self.aliyun_cfg['workspace_id']}" - ' --worker_count 1' - f' --worker_cpu {max(num_gpus * 8, 12)}' - f' --worker_gpu {num_gpus}' - f' --worker_memory {max(num_gpus * 128, 192)}' - f" --worker_image {self.aliyun_cfg['worker_image']}") + tmpl = ( + 'dlc submit pytorchjob' + f" --command '{shell_cmd}'" + f' --name {task_name[:512]}' + f" --config 
{self.aliyun_cfg['dlc_config_path']}" + f" --workspace_id {self.aliyun_cfg['workspace_id']}" + f" --resource_id {self.aliyun_cfg['resource_id']}" + ' --workers 1' + f' --worker_cpu {max(num_gpus * 8, 12)}' + f' --worker_gpu {num_gpus}' + f' --worker_memory {max(num_gpus * 128, 192)}Gi' + f" --worker_image {self.aliyun_cfg['worker_image']}" + f" --data_sources {','.join(self.aliyun_cfg['data_sources'])}") get_cmd = partial(task.get_command, cfg_path=param_file, template=tmpl) @@ -219,14 +221,9 @@ class DLCRunner(BaseRunner): pri_time = None initial_time = datetime.datetime.now() - url = 'http://pai-console.cb210e3f99cd7403f8de2a630dcc99fc3.cn-wulanchabu.alicontainer.com' # noqa: E501 + url = f"https://pai.console.aliyun.com/?regionId=cn-wulanchabu&workspaceId={self.aliyun_cfg['workspace_id']}#/dlc/jobs/{job_id}" # noqa: E501 logger = get_logger() - logger.debug('') - logger.debug('*' * 168) - logger.debug( - f'{url}/index?workspaceId={self.aliyun_cfg["workspace_id"]}#/dlc2/job/{job_id}/detail' # noqa: E501 - ) - logger.debug('*' * 168) + logger.debug('\n' + '*' * 168 + '\n' + url + '\n' + '*' * 168) while True: # 1. Avoid to request dlc too frequently. @@ -264,7 +261,7 @@ class DLCRunner(BaseRunner): cur_time = (pod_create_time + elasped_time).strftime('%Y-%m-%dT%H:%M:%SZ') logs_cmd = ('dlc logs' - f' {job_id} {job_id}-worker-0' + f' {job_id} {job_id}-master-0' f" -c {self.aliyun_cfg['dlc_config_path']}" f' --start_time {pri_time}' f' --end_time {cur_time}') diff --git a/opencompass/utils/run.py b/opencompass/utils/run.py index 816c1655..fbcf60c3 100644 --- a/opencompass/utils/run.py +++ b/opencompass/utils/run.py @@ -84,7 +84,7 @@ def get_config_from_arg(args) -> Config: # set infer accelerator if needed if args.accelerator in ['vllm', 'lmdeploy']: config['models'] = change_accelerator(config['models'], args.accelerator) - if 'eval' in config and 'partitioner' in config['eval']: + if config.get('eval', {}).get('partitioner', {}).get('models') is not None: config['eval']['partitioner']['models'] = change_accelerator(config['eval']['partitioner']['models'], args.accelerator) if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None: config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator)
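Below is a minimal usage sketch for the new `YiAPI` wrapper introduced by this patch. It is illustrative only and not part of the diff: the `abbr`, endpoint URL, model name, key placeholder, and batch settings are assumptions, while the constructor arguments (`path`, `key`, `url`, `query_per_second`, `max_seq_len`, `retry`, `system_prompt`) come from `opencompass/models/yi_api.py`.

```python
# Illustrative config sketch (not included in this patch).
# The URL, model name, and abbr below are assumptions; substitute real values.
from opencompass.models import YiAPI

models = [
    dict(
        type=YiAPI,
        abbr='yi-api-chat',                 # hypothetical abbreviation
        path='yi-large',                    # assumed Yi API model name
        key='YOUR_API_KEY',                 # sent as a Bearer token in the Authorization header
        url='https://api.lingyiwanwu.com/v1/chat/completions',  # assumed endpoint
        query_per_second=2,
        max_seq_len=2048,
        max_out_len=512,
        batch_size=8,
        retry=2,
    )
]
```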