From 04dd01a2359419c7a3042c6ea5268f460da580b7 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Wed, 5 Jul 2023 11:45:08 +0800 Subject: [PATCH] Update configs and code --- configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py | 4 + .../FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py | 4 + .../SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py | 4 + configs/datasets/XCOPA/XCOPA_ppl.py | 4 + configs/datasets/agieval/agieval_gen.py | 4 + configs/datasets/flores/flores_gen.py | 4 + configs/datasets/mmlu/mmlu_ppl.py | 4 + configs/datasets/summscreen/summscreen_gen.py | 4 + configs/datasets/winogrande/winogrande_gen.py | 4 + docs/en/advanced_guides/new_model.md | 1 + docs/zh_cn/user_guides/evaluation.md | 1 + opencompass/datasets/strategyqa.py | 14 ++ opencompass/models/__init__.py | 6 + opencompass/models/xunfei_api.py | 212 ------------------ opencompass/runners/__init__.py | 3 + 15 files changed, 61 insertions(+), 212 deletions(-) create mode 100644 configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py create mode 100644 configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py create mode 100644 configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py create mode 100644 configs/datasets/XCOPA/XCOPA_ppl.py create mode 100644 configs/datasets/agieval/agieval_gen.py create mode 100644 configs/datasets/flores/flores_gen.py create mode 100644 configs/datasets/mmlu/mmlu_ppl.py create mode 100644 configs/datasets/summscreen/summscreen_gen.py create mode 100644 configs/datasets/winogrande/winogrande_gen.py create mode 100644 docs/en/advanced_guides/new_model.md create mode 100644 docs/zh_cn/user_guides/evaluation.md create mode 100644 opencompass/datasets/strategyqa.py create mode 100644 opencompass/models/__init__.py delete mode 100644 opencompass/models/xunfei_api.py create mode 100644 opencompass/runners/__init__.py diff --git a/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py b/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py new file mode 100644 index 00000000..f2e0b9f0 --- /dev/null +++ b/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .CLUE_CMRC_gen_72a8d5 import CMRC_datasets # noqa: F401, F403 diff --git a/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py b/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py new file mode 100644 index 00000000..2037bc71 --- /dev/null +++ b/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .FewCLUE_eprstmt_ppl_d3c387 import eprstmt_datasets # noqa: F401, F403 diff --git a/configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py b/configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py new file mode 100644 index 00000000..899a7729 --- /dev/null +++ b/configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .SuperGLUE_AX_b_ppl_4bd960 import AX_b_datasets # noqa: F401, F403 diff --git a/configs/datasets/XCOPA/XCOPA_ppl.py b/configs/datasets/XCOPA/XCOPA_ppl.py new file mode 100644 index 00000000..a8f777df --- /dev/null +++ b/configs/datasets/XCOPA/XCOPA_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .XCOPA_ppl_6215c4 import XCOPA_datasets # noqa: F401, F403 diff --git a/configs/datasets/agieval/agieval_gen.py b/configs/datasets/agieval/agieval_gen.py new file mode 100644 index 00000000..27940dcd --- /dev/null +++ b/configs/datasets/agieval/agieval_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .agieval_gen_dc7dae import agieval_datasets # noqa: F401, F403 diff --git a/configs/datasets/flores/flores_gen.py b/configs/datasets/flores/flores_gen.py new file mode 100644 index 00000000..b36d3c5c --- /dev/null +++ b/configs/datasets/flores/flores_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .flores_gen_8eb9ca import flores_datasets # noqa: F401, F403 diff --git a/configs/datasets/mmlu/mmlu_ppl.py b/configs/datasets/mmlu/mmlu_ppl.py new file mode 100644 index 00000000..73c3161d --- /dev/null +++ b/configs/datasets/mmlu/mmlu_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .mmlu_ppl_c6bbe6 import mmlu_datasets # noqa: F401, F403 diff --git a/configs/datasets/summscreen/summscreen_gen.py b/configs/datasets/summscreen/summscreen_gen.py new file mode 100644 index 00000000..296abf2c --- /dev/null +++ b/configs/datasets/summscreen/summscreen_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .summscreen_gen_997ee2 import summscreen_datasets # noqa: F401, F403 diff --git a/configs/datasets/winogrande/winogrande_gen.py b/configs/datasets/winogrande/winogrande_gen.py new file mode 100644 index 00000000..ef8475b1 --- /dev/null +++ b/configs/datasets/winogrande/winogrande_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .winogrande_gen_c19d87 import winogrande_datasets # noqa: F401, F403 diff --git a/docs/en/advanced_guides/new_model.md b/docs/en/advanced_guides/new_model.md new file mode 100644 index 00000000..94354994 --- /dev/null +++ b/docs/en/advanced_guides/new_model.md @@ -0,0 +1 @@ +# New A Model diff --git a/docs/zh_cn/user_guides/evaluation.md b/docs/zh_cn/user_guides/evaluation.md new file mode 100644 index 00000000..4e79927f --- /dev/null +++ b/docs/zh_cn/user_guides/evaluation.md @@ -0,0 +1 @@ +# 评估策略 diff --git a/opencompass/datasets/strategyqa.py b/opencompass/datasets/strategyqa.py new file mode 100644 index 00000000..f0f56ec2 --- /dev/null +++ b/opencompass/datasets/strategyqa.py @@ -0,0 +1,14 @@ +from opencompass.registry import TEXT_POSTPROCESSORS + + +@TEXT_POSTPROCESSORS.register_module('strategyqa') +def strategyqa_pred_postprocess(text: str) -> str: + text = text.split('\n\n')[0] + strategyqa_pre = text.split('So the answer is ')[-1].strip().replace( + '.', '') + return strategyqa_pre + + +@TEXT_POSTPROCESSORS.register_module('strategyqa_dataset') +def strategyqa_dataset_postprocess(text: str) -> str: + return 'yes' if str(text) == 'True' else 'no' diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py new file mode 100644 index 00000000..fa46042e --- /dev/null +++ b/opencompass/models/__init__.py @@ -0,0 +1,6 @@ +from .base import BaseModel, LMTemplateParser # noqa +from .base_api import APITemplateParser, BaseAPIModel # noqa +from .glm import GLM130B # noqa: F401, F403 +from .huggingface import HuggingFace # noqa: F401, F403 +from .huggingface import HuggingFaceCausalLM # noqa: F401, F403 +from .openai_api import OpenAI # noqa: F401 diff --git a/opencompass/models/xunfei_api.py b/opencompass/models/xunfei_api.py deleted file mode 100644 index 3de606e8..00000000 --- a/opencompass/models/xunfei_api.py +++ /dev/null @@ -1,212 +0,0 @@ -import json -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Union - -from opencompass.registry import MODELS -from opencompass.utils.prompt import PromptList - -from .base_api import BaseAPIModel - -PromptType = Union[PromptList, str] - - -@MODELS.register_module(name=['XunFei']) -class XunFei(BaseAPIModel): - """Model wrapper around OpenAI-AllesAPIN. - - Args: - path (str): The name of OpenAI's model. - max_seq_len (int): Unused here. - call_interval (float): The minimum time interval in seconds between two - calls to the API. Defaults to 1. - retry (int): Number of retires if the API call fails. Defaults to 2. - """ - - def __init__(self, - path: str, - appid: str, - api_secret: str, - api_key: str, - query_per_second: int = 2, - max_seq_len: int = 2048, - meta_template: Optional[Dict] = None, - retry: int = 2): - super().__init__(path=path, - max_seq_len=max_seq_len, - query_per_second=query_per_second, - meta_template=meta_template, - retry=retry) - import ssl - import threading - from urllib.parse import urlencode, urlparse - - import websocket - self.urlencode = urlencode - self.websocket = websocket - self.websocket.enableTrace(False) - self.threading = threading - self.ssl = ssl - - # weird auth keys - self.APISecret = api_secret - self.APIKey = api_key - self.appid = appid - self.hostname = urlparse(path).netloc - self.hostpath = urlparse(path).path - - self.headers = { - 'content-type': 'application/json', - } - - def get_url(self): - from datetime import datetime - from time import mktime - from wsgiref.handlers import format_date_time - - cur_time = datetime.now() - date = format_date_time(mktime(cur_time.timetuple())) - tmp = f'host: {self.hostname}\n' - tmp += 'date: ' + date + '\n' - tmp += 'GET ' + self.hostpath + ' HTTP/1.1' - import hashlib - import hmac - tmp_sha = hmac.new(self.APISecret.encode('utf-8'), - tmp.encode('utf-8'), - digestmod=hashlib.sha256).digest() - import base64 - signature = base64.b64encode(tmp_sha).decode(encoding='utf-8') - authorization_origin = (f'api_key="{self.APIKey}", ' - 'algorithm="hmac-sha256", ' - 'headers="host date request-line", ' - f'signature="{signature}"') - authorization = base64.b64encode( - authorization_origin.encode('utf-8')).decode(encoding='utf-8') - v = { - 'authorization': authorization, - 'date': date, - 'host': self.hostname - } - url = self.path + '?' + self.urlencode(v) - return url - - def generate( - self, - inputs: List[str or PromptList], - max_out_len: int = 512, - ) -> List[str]: - """Generate results given a list of inputs. - - Args: - inputs (List[str or PromptList]): A list of strings or PromptDicts. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - - Returns: - List[str]: A list of generated strings. - """ - with ThreadPoolExecutor() as executor: - results = list( - executor.map(self._generate, inputs, - [max_out_len] * len(inputs))) - return results - - def _generate( - self, - input: str or PromptList, - max_out_len: int = 512, - ) -> List[str]: - """Generate results given an input. - - Args: - inputs (str or PromptList): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - - Returns: - str: The generated string. - """ - assert isinstance(input, (str, PromptList)) - - # FIXME: messages only contains the last input - if isinstance(input, str): - messages = [{'role': 'user', 'content': input}] - else: - messages = [] - # word_ctr = 0 - # TODO: Implement truncation in PromptList - for item in input: - msg = {'content': item['prompt']} - # if word_ctr >= self.max_seq_len: - # break - # if len(msg['content']) + word_ctr > self.max_seq_len: - # msg['content'] = msg['content'][word_ctr - - # self.max_seq_len:] - # word_ctr += len(msg['content']) - if item['role'] == 'HUMAN': - msg['role'] = 'user' - elif item['role'] == 'BOT': - msg['role'] = 'assistant' - messages.append(msg) - # in case the word break results in even number of messages - # if len(messages) > 0 and len(messages) % 2 == 0: - # messages = messages[:-1] - - data = { - 'header': { - 'app_id': self.appid, - }, - 'parameter': { - 'chat': { - 'domain': 'general', - 'max_tokens': max_out_len, - } - }, - 'payload': { - 'message': { - 'text': messages - } - } - } - - msg = '' - err_code = None - err_data = None - content_received = self.threading.Event() - - def on_open(ws): - nonlocal data - ws.send(json.dumps(data)) - - def on_message(ws, message): - nonlocal msg, err_code, err_data, content_received - err_data = json.loads(message) - err_code = err_data['header']['code'] - if err_code != 0: - content_received.set() - ws.close() - else: - choices = err_data['payload']['choices'] - status = choices['status'] - msg += choices['text'][0]['content'] - if status == 2: - content_received.set() - ws.close() - - ws = self.websocket.WebSocketApp(self.get_url(), - on_message=on_message, - on_open=on_open) - ws.appid = self.appid - ws.question = messages[-1]['content'] - - for _ in range(self.retry): - self.wait() - ws.run_forever(sslopt={'cert_reqs': self.ssl.CERT_NONE}) - content_received.wait() - if err_code == 0: - return msg.strip() - - if err_code == 10013: - return err_data['header']['message'] - raise RuntimeError(f'Code: {err_code}, data: {err_data}') diff --git a/opencompass/runners/__init__.py b/opencompass/runners/__init__.py new file mode 100644 index 00000000..5e2bb1ef --- /dev/null +++ b/opencompass/runners/__init__.py @@ -0,0 +1,3 @@ +from .dlc import * # noqa: F401, F403 +from .local import * # noqa: F401, F403 +from .slurm import * # noqa: F401, F403