mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Update configs and code
This commit is contained in:
parent
c94cc94348
commit
04dd01a235
4
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py
Normal file
4
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .CLUE_CMRC_gen_72a8d5 import CMRC_datasets # noqa: F401, F403
|
4
configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py
Normal file
4
configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .FewCLUE_eprstmt_ppl_d3c387 import eprstmt_datasets # noqa: F401, F403
|
4
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py
Normal file
4
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .SuperGLUE_AX_b_ppl_4bd960 import AX_b_datasets # noqa: F401, F403
|
4
configs/datasets/XCOPA/XCOPA_ppl.py
Normal file
4
configs/datasets/XCOPA/XCOPA_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .XCOPA_ppl_6215c4 import XCOPA_datasets # noqa: F401, F403
|
4
configs/datasets/agieval/agieval_gen.py
Normal file
4
configs/datasets/agieval/agieval_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .agieval_gen_dc7dae import agieval_datasets # noqa: F401, F403
|
4
configs/datasets/flores/flores_gen.py
Normal file
4
configs/datasets/flores/flores_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .flores_gen_8eb9ca import flores_datasets # noqa: F401, F403
|
4
configs/datasets/mmlu/mmlu_ppl.py
Normal file
4
configs/datasets/mmlu/mmlu_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .mmlu_ppl_c6bbe6 import mmlu_datasets # noqa: F401, F403
|
4
configs/datasets/summscreen/summscreen_gen.py
Normal file
4
configs/datasets/summscreen/summscreen_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .summscreen_gen_997ee2 import summscreen_datasets # noqa: F401, F403
|
4
configs/datasets/winogrande/winogrande_gen.py
Normal file
4
configs/datasets/winogrande/winogrande_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .winogrande_gen_c19d87 import winogrande_datasets # noqa: F401, F403
|
1
docs/en/advanced_guides/new_model.md
Normal file
1
docs/en/advanced_guides/new_model.md
Normal file
@ -0,0 +1 @@
|
||||
# New A Model
|
1
docs/zh_cn/user_guides/evaluation.md
Normal file
1
docs/zh_cn/user_guides/evaluation.md
Normal file
@ -0,0 +1 @@
|
||||
# 评估策略
|
14
opencompass/datasets/strategyqa.py
Normal file
14
opencompass/datasets/strategyqa.py
Normal file
@ -0,0 +1,14 @@
|
||||
from opencompass.registry import TEXT_POSTPROCESSORS
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('strategyqa')
|
||||
def strategyqa_pred_postprocess(text: str) -> str:
|
||||
text = text.split('\n\n')[0]
|
||||
strategyqa_pre = text.split('So the answer is ')[-1].strip().replace(
|
||||
'.', '')
|
||||
return strategyqa_pre
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('strategyqa_dataset')
|
||||
def strategyqa_dataset_postprocess(text: str) -> str:
|
||||
return 'yes' if str(text) == 'True' else 'no'
|
6
opencompass/models/__init__.py
Normal file
6
opencompass/models/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .base import BaseModel, LMTemplateParser # noqa
|
||||
from .base_api import APITemplateParser, BaseAPIModel # noqa
|
||||
from .glm import GLM130B # noqa: F401, F403
|
||||
from .huggingface import HuggingFace # noqa: F401, F403
|
||||
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
|
||||
from .openai_api import OpenAI # noqa: F401
|
@ -1,212 +0,0 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from opencompass.registry import MODELS
|
||||
from opencompass.utils.prompt import PromptList
|
||||
|
||||
from .base_api import BaseAPIModel
|
||||
|
||||
PromptType = Union[PromptList, str]
|
||||
|
||||
|
||||
@MODELS.register_module(name=['XunFei'])
|
||||
class XunFei(BaseAPIModel):
|
||||
"""Model wrapper around OpenAI-AllesAPIN.
|
||||
|
||||
Args:
|
||||
path (str): The name of OpenAI's model.
|
||||
max_seq_len (int): Unused here.
|
||||
call_interval (float): The minimum time interval in seconds between two
|
||||
calls to the API. Defaults to 1.
|
||||
retry (int): Number of retires if the API call fails. Defaults to 2.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
appid: str,
|
||||
api_secret: str,
|
||||
api_key: str,
|
||||
query_per_second: int = 2,
|
||||
max_seq_len: int = 2048,
|
||||
meta_template: Optional[Dict] = None,
|
||||
retry: int = 2):
|
||||
super().__init__(path=path,
|
||||
max_seq_len=max_seq_len,
|
||||
query_per_second=query_per_second,
|
||||
meta_template=meta_template,
|
||||
retry=retry)
|
||||
import ssl
|
||||
import threading
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import websocket
|
||||
self.urlencode = urlencode
|
||||
self.websocket = websocket
|
||||
self.websocket.enableTrace(False)
|
||||
self.threading = threading
|
||||
self.ssl = ssl
|
||||
|
||||
# weird auth keys
|
||||
self.APISecret = api_secret
|
||||
self.APIKey = api_key
|
||||
self.appid = appid
|
||||
self.hostname = urlparse(path).netloc
|
||||
self.hostpath = urlparse(path).path
|
||||
|
||||
self.headers = {
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
|
||||
def get_url(self):
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
cur_time = datetime.now()
|
||||
date = format_date_time(mktime(cur_time.timetuple()))
|
||||
tmp = f'host: {self.hostname}\n'
|
||||
tmp += 'date: ' + date + '\n'
|
||||
tmp += 'GET ' + self.hostpath + ' HTTP/1.1'
|
||||
import hashlib
|
||||
import hmac
|
||||
tmp_sha = hmac.new(self.APISecret.encode('utf-8'),
|
||||
tmp.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
import base64
|
||||
signature = base64.b64encode(tmp_sha).decode(encoding='utf-8')
|
||||
authorization_origin = (f'api_key="{self.APIKey}", '
|
||||
'algorithm="hmac-sha256", '
|
||||
'headers="host date request-line", '
|
||||
f'signature="{signature}"')
|
||||
authorization = base64.b64encode(
|
||||
authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
v = {
|
||||
'authorization': authorization,
|
||||
'date': date,
|
||||
'host': self.hostname
|
||||
}
|
||||
url = self.path + '?' + self.urlencode(v)
|
||||
return url
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: List[str or PromptList],
|
||||
max_out_len: int = 512,
|
||||
) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (List[str or PromptList]): A list of strings or PromptDicts.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = list(
|
||||
executor.map(self._generate, inputs,
|
||||
[max_out_len] * len(inputs)))
|
||||
return results
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
input: str or PromptList,
|
||||
max_out_len: int = 512,
|
||||
) -> List[str]:
|
||||
"""Generate results given an input.
|
||||
|
||||
Args:
|
||||
inputs (str or PromptList): A string or PromptDict.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
str: The generated string.
|
||||
"""
|
||||
assert isinstance(input, (str, PromptList))
|
||||
|
||||
# FIXME: messages only contains the last input
|
||||
if isinstance(input, str):
|
||||
messages = [{'role': 'user', 'content': input}]
|
||||
else:
|
||||
messages = []
|
||||
# word_ctr = 0
|
||||
# TODO: Implement truncation in PromptList
|
||||
for item in input:
|
||||
msg = {'content': item['prompt']}
|
||||
# if word_ctr >= self.max_seq_len:
|
||||
# break
|
||||
# if len(msg['content']) + word_ctr > self.max_seq_len:
|
||||
# msg['content'] = msg['content'][word_ctr -
|
||||
# self.max_seq_len:]
|
||||
# word_ctr += len(msg['content'])
|
||||
if item['role'] == 'HUMAN':
|
||||
msg['role'] = 'user'
|
||||
elif item['role'] == 'BOT':
|
||||
msg['role'] = 'assistant'
|
||||
messages.append(msg)
|
||||
# in case the word break results in even number of messages
|
||||
# if len(messages) > 0 and len(messages) % 2 == 0:
|
||||
# messages = messages[:-1]
|
||||
|
||||
data = {
|
||||
'header': {
|
||||
'app_id': self.appid,
|
||||
},
|
||||
'parameter': {
|
||||
'chat': {
|
||||
'domain': 'general',
|
||||
'max_tokens': max_out_len,
|
||||
}
|
||||
},
|
||||
'payload': {
|
||||
'message': {
|
||||
'text': messages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg = ''
|
||||
err_code = None
|
||||
err_data = None
|
||||
content_received = self.threading.Event()
|
||||
|
||||
def on_open(ws):
|
||||
nonlocal data
|
||||
ws.send(json.dumps(data))
|
||||
|
||||
def on_message(ws, message):
|
||||
nonlocal msg, err_code, err_data, content_received
|
||||
err_data = json.loads(message)
|
||||
err_code = err_data['header']['code']
|
||||
if err_code != 0:
|
||||
content_received.set()
|
||||
ws.close()
|
||||
else:
|
||||
choices = err_data['payload']['choices']
|
||||
status = choices['status']
|
||||
msg += choices['text'][0]['content']
|
||||
if status == 2:
|
||||
content_received.set()
|
||||
ws.close()
|
||||
|
||||
ws = self.websocket.WebSocketApp(self.get_url(),
|
||||
on_message=on_message,
|
||||
on_open=on_open)
|
||||
ws.appid = self.appid
|
||||
ws.question = messages[-1]['content']
|
||||
|
||||
for _ in range(self.retry):
|
||||
self.wait()
|
||||
ws.run_forever(sslopt={'cert_reqs': self.ssl.CERT_NONE})
|
||||
content_received.wait()
|
||||
if err_code == 0:
|
||||
return msg.strip()
|
||||
|
||||
if err_code == 10013:
|
||||
return err_data['header']['message']
|
||||
raise RuntimeError(f'Code: {err_code}, data: {err_data}')
|
3
opencompass/runners/__init__.py
Normal file
3
opencompass/runners/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .dlc import * # noqa: F401, F403
|
||||
from .local import * # noqa: F401, F403
|
||||
from .slurm import * # noqa: F401, F403
|
Loading…
Reference in New Issue
Block a user