Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
[Feature] Update pip install (#1324)

* [Feature] Update pip install
* Update Configuration
* Update
* Update
* Update
* Update Internal Config
* Update collect env
This commit is contained in:
parent edab1c07ba
commit 704853e5e7
MANIFEST.in (new file, +2)
@@ -0,0 +1,2 @@
+recursive-include opencompass/configs *.py *.yml *.json *.txt *.md
+recursive-include opencompass/openicl/icl_evaluator/hf_metrics *.py
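
A quick way to sanity-check the new packaging (a minimal sketch, assuming a wheel built from this revision is installed and Python >= 3.9 for importlib.resources.files): the config tree declared above should now live inside the installed package rather than only in the source checkout.

    # Hedged check: verify the configs declared in MANIFEST.in ship with the
    # installed package (paths follow the MANIFEST.in entries above).
    from importlib import resources

    configs_root = resources.files('opencompass') / 'configs'
    print(configs_root, configs_root.is_dir())  # expect True for a packaged install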
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import GPQASimpleEvalDataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator
+from opencompass.datasets import GPQADataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator

 # openai_simple_eval prompt
 align_prompt = """
@@ -43,7 +43,7 @@ for split in list(gpqa_subsets.keys()):
     gpqa_datasets.append(
         dict(
             abbr='GPQA_' + split,
-            type=GPQASimpleEvalDataset,
+            type=GPQADataset,
             path='./data/gpqa/',
             name=gpqa_subsets[split],
             reader_cfg=gpqa_reader_cfg,
@@ -1 +1 @@
-__version__ = '0.2.6'
+__version__ = '0.2.7rc1'
opencompass/datasets/LCBench.py (new file, +329)
@@ -0,0 +1,329 @@
+import contextlib
+import io
+import itertools
+import multiprocessing
+import re
+import signal
+from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import List, Sequence, Union
+
+import numpy as np
+from datasets import DatasetDict, concatenate_datasets, load_dataset
+
+from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class LCDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, num_repeats: int = 1, difficulty='ALL'):
+        """Load LC dataset for pass k mode.
+
+        Note that you can use num_repeats > 1 when your model does not support
+        `num_return_sequence` in generation, otherwise use the raw
+        LC dataset and set `num_return_sequence` in model config to
+        generate multiple responses for testing pass@k>1.
+
+        It is better to change your dataset abbr correspondingly if you set
+        num_repeats > 1, otherwise the number in
+        `.cache/dataset_size.json` might be inconsistent.
+
+        Args:
+            num_repeats(int): Number of repetitions for this dataset to get
+                multiple responses in special cases.
+        """
+
+        def processing_test(example):
+            example['test_case'] = example['test_list']
+            example['test_list'] = '\n'.join(example['test_list'])
+            example['test_column'] = dict(test_list_2=example['test_list'],
+                                          task_id=example['Contest id'])
+            return example
+
+        train = load_dataset('json', data_files=path,
+                             split='train[:5]').map(processing_test)
+        test = load_dataset('json', data_files=path,
+                            split='train[5:]').map(processing_test)
+        if not difficulty == 'ALL':
+            train = train.filter(
+                lambda example: example['Difficulty'] == difficulty)
+            test = test.filter(
+                lambda example: example['Difficulty'] == difficulty)
+        test = concatenate_datasets([test] * num_repeats)
+        return DatasetDict({'train': train, 'test': test})
+
+
+class TimeOutException(Exception):
+    pass
+
+
+@contextlib.contextmanager
+def swallow_io():
+    stream = WriteOnlyStringIO()
+    with contextlib.redirect_stdout(stream):
+        with contextlib.redirect_stderr(stream):
+            with redirect_stdin(stream):
+                yield
+
+
+@contextlib.contextmanager
+def time_limit(seconds: float):
+
+    def signal_handler(signum, frame):
+        raise TimeOutException('Time out!')
+
+    signal.setitimer(signal.ITIMER_REAL, seconds)
+    signal.signal(signal.SIGALRM, signal_handler)
+    try:
+        yield
+    finally:
+        signal.setitimer(signal.ITIMER_REAL, 0)
+
+
+class WriteOnlyStringIO(io.StringIO):
+    """StringIO that throws an exception when it's read from."""
+
+    def read(self, *args, **kwargs):
+        raise IOError
+
+    def readline(self, *args, **kwargs):
+        raise IOError
+
+    def readlines(self, *args, **kwargs):
+        raise IOError
+
+    def readable(self, *args, **kwargs):
+        """Returns True if the IO object can be read."""
+        return False
+
+
+class redirect_stdin(contextlib._RedirectStream):  # type: ignore
+    _stream = 'stdin'
+
+
+@ICL_EVALUATORS.register_module()
+class LCEvaluator(BaseEvaluator):
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {'error': 'predictions and references have different length'}
+        result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
+        details = {}
+        with ProcessPoolExecutor() as executor:
+            futures = []
+            for i, (refer, pred) in enumerate(zip(references, predictions)):
+                pred = self._process_answer(pred)
+                programs = self._process_test(refer, pred)
+                future = executor.submit(execution, programs, i, 3)
+                futures.append(future)
+
+            from tqdm import tqdm
+            for future in tqdm(as_completed(futures), total=len(futures)):
+                index, ret = future.result()
+                result[ret] += 1
+                details[str(index)] = {
+                    'programs': predictions[index],
+                    'result': ret,
+                    'is_correct': ret == 'pass',
+                }
+
+        result['score'] = result['pass'] / len(predictions) * 100
+        result['details'] = details
+        return result
+
+    def _process_answer(self, text):
+        try:
+            # for chatGLM related text
+            eval_text = eval(text)
+        except Exception:
+            pass
+        else:
+            if isinstance(eval_text, str):
+                text = eval_text
+        # deal with code block
+        if '```' in text:
+            blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
+            if len(blocks) == 0:
+                text = text.split('```')[1]  # fall back to default strategy
+            else:
+                text = blocks[0]  # fetch the first code block
+                if not text.startswith('\n'):  # in case starting with ```xxx
+                    text = text[max(text.find('\n') + 1, 0):]
+        text = text.strip()
+        match = re.search(r"('\s*|)(\[DONE\]|DONE)", text)
+        if match:
+            text = text[:match.start()]
+        match = re.search(r"(\[BEGIN\]|BEGIN)('\s*|)", text)
+        if match:
+            text = text[match.end():]
+        text = text.strip()
+        if text.startswith("'"):
+            text = text[1:]
+        if text.endswith("'"):
+            text = text[:-1]
+        text = text.replace('\\', '')
+        match = re.search(r'```python(.*)```', text, re.DOTALL)
+        if match:
+            text = match.group(1).strip().split('```')[0].strip()
+        return text
+
+    def _process_test(self, test_case, pred):
+        formatted = pred + '\n'
+        formatted += test_case
+        return formatted
+
+
+def execution(programs, task_id, timeout):
+    """Execution function for running generated code.
+
+    Args:
+        programs(str): Python code to be executed.
+        task_id(int): Task id of the current example.
+        timeout(int): Time limit for execution, to avoid unnecessary
+            blocking.
+
+    In the pass@k scenario, a lot of programs must be executed.
+    Some internal errors cannot be handled properly, e.g. a
+    `RecursionError` might crash the system, so it is better to
+    run each execution in a separate thread or process to keep
+    control over the main process.
+    """
+
+    def _execution(programs, timeout):
+        try:
+            # Add exec globals to prevent the exec from raising an
+            # unnecessary NameError for a correct answer
+            exec_globals = {}
+            with swallow_io():
+                with time_limit(timeout):
+                    exec(programs, exec_globals)
+            key.append('pass')
+        except TimeOutException:
+            key.append('timeout')
+        except AssertionError:
+            key.append('wrong_answer')
+        except BaseException as e:
+            print(e)
+            key.append('failed')
+
+    manager = multiprocessing.Manager()
+    key = manager.list()
+    # `signal` cannot be used in a child thread, therefore we
+    # need to create a process in the thread.
+    p = multiprocessing.Process(target=_execution,
+                                args=(programs, timeout - 1))
+    p.start()
+    p.join(timeout=timeout)
+    if p.is_alive():
+        p.kill()
+        # key might not have a value if killed
+        return task_id, 'timeout'
+    return task_id, key[0]
+
+
+class LCPassKEvaluator(LCEvaluator):
+    """Better suited for pass@k evaluation.
+
+    Args:
+        k(Tuple[int]): Choices of Pass@k. Defaults to (1, 10, 100)
+    """
+
+    def __init__(self, k=(1, 10, 100)) -> None:
+        if not isinstance(k, Sequence):
+            k = (k, )
+        self.k = k
+
+    @staticmethod
+    def estimate_pass_at_k(
+        num_samples: Union[int, List[int], np.ndarray],
+        num_correct: Union[List[int], np.ndarray],
+        k: int,
+    ) -> np.ndarray:
+        """Estimates pass@k of each problem and returns them in an array."""
+
+        def estimator(n: int, c: int, k: int) -> float:
+            """Calculates 1 - comb(n - c, k) / comb(n, k)."""
+            if n - c < k:
+                return 1.0
+            return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+        if isinstance(num_samples, int):
+            num_samples_it = itertools.repeat(num_samples, len(num_correct))
+        else:
+            assert len(num_samples) == len(num_correct)
+            num_samples_it = iter(num_samples)
+
+        return np.array([
+            estimator(int(n), int(c), k)
+            for n, c in zip(num_samples_it, num_correct)
+        ])
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {'error': 'predictions and references have different length'}
+
+        task_pass = defaultdict(int)
+        task_total = defaultdict(int)
+
+        result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
+        details = {}
+        with ProcessPoolExecutor() as executor:
+            futures = []
+            index, programs = 0, []
+            for refer, preds in zip(references, predictions):
+                # suits both cases:
+                # 1. use a repeated dataset
+                # 2. use `num_return_sequences` to generate multiple responses
+                if not isinstance(preds, list):
+                    preds = [preds]
+                test_case = refer['test_list_2']
+                task_id = refer['task_id']
+                # create empty task_pass in case all examples failed
+                if task_id not in task_pass:
+                    task_pass[task_id] = 0
+                for pred in preds:
+                    pred = self._process_answer(pred)
+                    program = self._process_test(test_case, pred)
+                    future = executor.submit(execution, program,
+                                             (index, task_id), 3)
+                    futures.append(future)
+                    programs.append(program)
+                    index += 1
+
+            from tqdm import tqdm
+            for future in tqdm(as_completed(futures), total=len(futures)):
+                (index, task_id), ret = future.result()
+                result[ret] += 1
+                task_total[task_id] += 1
+                is_correct = ret == 'pass'
+                task_pass[task_id] += is_correct
+                details[str(index)] = {
+                    'program': programs[index],
+                    'task_id': task_id,
+                    'result': ret,
+                    'is_correct': is_correct,
+                }
+
+        result['details'] = details
+
+        def get_number(tasks):
+            return np.array([
+                task[1] for task in sorted(tasks.items(), key=lambda x: x[0])
+            ])
+
+        task_pass = get_number(task_pass)
+        task_total = get_number(task_total)
+        pass_at_k = {
+            f'pass@{k}':
+            self.estimate_pass_at_k(task_total, task_pass, k).mean() * 100
+            for k in self.k if (task_total >= k).all()
+        }
+        result.update(pass_at_k)
+        return result
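
The estimator inside `estimate_pass_at_k` is the unbiased pass@k from the HumanEval paper: pass@k = 1 - C(n-c, k) / C(n, k), computed as a numerically stable running product rather than with large binomial coefficients. A standalone sketch (the sample counts below are made up):

    # Minimal demo of the pass@k estimator used by LCPassKEvaluator.
    import numpy as np

    def pass_at_k(n: int, c: int, k: int) -> float:
        """1 - C(n - c, k) / C(n, k): chance that >= 1 of k samples passes."""
        if n - c < k:
            return 1.0  # fewer than k failures exist, so a pass is guaranteed
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # 10 samples per task, 3 of them correct:
    print(pass_at_k(10, 3, 1))  # 0.3
    print(pass_at_k(10, 3, 5))  # ~0.9167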
@@ -62,6 +62,7 @@ from .jsonl import JsonlDataset  # noqa: F401, F403
 from .kaoshi import KaoshiDataset, KaoshiEvaluator  # noqa: F401, F403
 from .lambada import *  # noqa: F401, F403
 from .lawbench import *  # noqa: F401, F403
+from .LCBench import *  # noqa: F401, F403
 from .lcsts import *  # noqa: F401, F403
 from .leval import *  # noqa: F401, F403
 from .llm_compression import LLMCompressionDataset  # noqa: F401, F403
@@ -50,9 +50,15 @@ def bbh_freeform_postprocess(text: str) -> str:
     ans_line = ans.split('answer is ')
     if len(ans_line) != 1:
         ans = ans_line[1].strip()
-    ans = ans.split('\n')[0]
+    ans = ans.split('\n')[0].strip()
+
     if ans.endswith('.'):
-        ans = ans[:-1]
+        ans = ans[:-1].strip()
+
+    match = re.search(r'\*\*(.*?)\*\*', ans)
+    if match:
+        return match.group(1)
+
     return ans
@@ -6,7 +6,9 @@ from .baidu_api import ERNIEBot  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa: F401
 from .base_api import APITemplateParser, BaseAPIModel  # noqa: F401
 from .bytedance_api import ByteDance  # noqa: F401
+from .claude_allesapin import ClaudeAllesAPIN  # noqa: F401
 from .claude_api import Claude  # noqa: F401
+from .claude_sdk_api import ClaudeSDK  # noqa: F401
 from .deepseek_api import DeepseekAPI  # noqa: F401
 from .doubao_api import Doubao  # noqa: F401
 from .gemini_api import Gemini  # noqa: F401
opencompass/models/claude_allesapin.py (new file, +150)
@@ -0,0 +1,150 @@
+import json
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.registry import MODELS
+from opencompass.utils import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+@MODELS.register_module()
+class ClaudeAllesAPIN(BaseAPIModel):
+    """Model wrapper around Claude-AllesAPIN.
+
+    Args:
+        path (str): The name of Claude's model.
+        url (str): URL to AllesAPIN.
+        key (str): AllesAPIN key.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case there is a need to inject or
+            wrap any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    is_api: bool = True
+
+    def __init__(self,
+                 path: str,
+                 url: str,
+                 key: str,
+                 query_per_second: int = 1,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 retry: int = 2):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        self.url = url
+        self.headers = {
+            'alles-apin-token': key,
+            'content-type': 'application/json',
+        }
+
+    def generate(self,
+                 inputs: List[PromptType],
+                 max_out_len: int = 512,
+                 **kwargs) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        return results
+
+    def _generate(self, input: PromptType, max_out_len: int) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            msg_buffer, last_role = [], None
+            for item in input:
+                item['role'] = 'assistant' if item['role'] == 'BOT' else 'user'
+                if item['role'] != last_role and last_role is not None:
+                    messages.append({
+                        'content': '\n'.join(msg_buffer),
+                        'role': last_role
+                    })
+                    msg_buffer = []
+                msg_buffer.append(item['prompt'])
+                last_role = item['role']
+            messages.append({
+                'content': '\n'.join(msg_buffer),
+                'role': last_role
+            })
+
+        data = {
+            'model': self.path,
+            'messages': messages,
+            'max_tokens': max_out_len,
+        }
+
+        err_data = []
+        for _ in range(self.retry + 1):
+            self.wait()
+            try:
+                raw_response = requests.post(self.url,
+                                             headers=self.headers,
+                                             data=json.dumps(data))
+            except requests.ConnectionError:
+                time.sleep(5)
+                continue
+            except requests.ReadTimeout:
+                time.sleep(5)
+                continue
+            try:
+                response = raw_response.json()
+            except requests.JSONDecodeError:
+                if 'https://errors.aliyun.com/images' in \
+                        raw_response.content.decode():
+                    return 'request blocked by allesapin'
+                self.logger.error('JsonDecode error, got %s',
+                                  raw_response.content)
+                continue
+            if raw_response.status_code == 200 and response[
+                    'msgCode'] == '10000':
+                data = response['data']
+                generated = data['content'][0]['text'].strip()
+                self.logger.debug(f'Generated: {generated}')
+                return generated
+            self.logger.error(response['data'])
+            err_data.append(response['data'])
+
+        raise RuntimeError(err_data)
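
Once registered and re-exported from opencompass.models (see the __init__.py hunk above), the wrapper can be exercised outside a full eval run. A hedged smoke-test sketch; the url, key, and model name below are placeholders, not a real endpoint:

    # Hypothetical usage of ClaudeAllesAPIN; all credentials are placeholders.
    from opencompass.models import ClaudeAllesAPIN

    model = ClaudeAllesAPIN(
        path='claude-3-opus-20240229',             # model name forwarded to the gateway
        url='https://example.com/v1/claude/chat',  # placeholder AllesAPIN endpoint
        key='YOUR-ALLES-APIN-TOKEN',               # placeholder token
        query_per_second=1,
        retry=2,
    )
    print(model.generate(['Reply with one word: hello'], max_out_len=32)[0])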
opencompass/models/claude_sdk_api.py (new file, +121)
@@ -0,0 +1,121 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.registry import MODELS
+from opencompass.utils import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+@MODELS.register_module()
+class ClaudeSDK(BaseAPIModel):
+    """Model wrapper around Claude SDK API.
+
+    Args:
+        key (str): Authorization key.
+        path (str): The model to be used. Defaults to claude-2.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case there is a need to inject or
+            wrap any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        key: str,
+        path: str = 'claude-2',
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        temperature: Optional[float] = 0.0,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        try:
+            from anthropic import Anthropic
+        except ImportError:
+            raise ImportError('Import anthropic failed. Please install it '
+                              'with "pip install anthropic" and try again.')
+
+        self.anthropic = Anthropic(api_key=key)
+        self.model = path
+        self.temperature = temperature
+
+    def generate(
+        self,
+        inputs: List[PromptType],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        return results
+
+    def _generate(
+        self,
+        input: PromptType,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+                elif item['role'] == 'SYSTEM':
+                    msg['role'] = 'system'
+                messages.append(msg)
+
+        num_retries = 0
+        while num_retries < self.retry:
+            self.wait()
+            try:
+                responses = self.anthropic.messages.create(
+                    model=self.model,
+                    max_tokens=max_out_len,
+                    temperature=self.temperature,
+                    messages=messages)
+                return responses.content[0].text
+            except Exception as e:
+                self.logger.error(e)
+            num_retries += 1
+        raise RuntimeError('Calling Claude API failed after retrying for '
+                           f'{self.retry} times. Check the logs for details.')
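
A matching usage sketch for the SDK-backed wrapper (the key is a placeholder; requires `pip install anthropic`):

    # Hypothetical smoke test for ClaudeSDK; the key is a placeholder.
    from opencompass.models import ClaudeSDK

    model = ClaudeSDK(key='sk-ant-...', path='claude-2', temperature=0.0)
    print(model.generate(['2 + 2 = ?'], max_out_len=16)[0])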
@@ -1,4 +1,5 @@
 import copy
+import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union

@@ -60,6 +61,7 @@ class TurboMindModel(BaseModel):
         from lmdeploy.messages import TurbomindEngineConfig
         engine_config = TurbomindEngineConfig(**engine_config)
         self.logger = get_logger()
+        assert os.path.exists(path), '{} does not exist'.format(path)
         tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
         self.tokenizer = tm_model.tokenizer
         self.generators = [
@@ -2,3 +2,4 @@ from .dlc import *  # noqa: F401, F403
 from .local import *  # noqa: F401, F403
 from .slurm import *  # noqa: F401, F403
 from .slurm_sequential import *  # noqa: F401, F403
+from .volc import *  # noqa: F401, F403
opencompass/runners/volc.py (new file, +260)
@@ -0,0 +1,260 @@
+import os
+import os.path as osp
+import random
+import re
+import subprocess
+import time
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple
+
+import mmengine
+import yaml
+from mmengine.config import ConfigDict
+from mmengine.utils import track_parallel_progress
+
+from opencompass.registry import RUNNERS, TASKS
+from opencompass.utils import get_logger
+
+from .base import BaseRunner
+
+
+@RUNNERS.register_module()
+class VOLCRunner(BaseRunner):
+    """Distributed runner based on Volcano Cloud Cluster (VCC). It will
+    launch multiple tasks in parallel with the 'volc' command. Please
+    install and configure VCC first before using this runner.
+
+    Args:
+        task (ConfigDict): Task type config.
+        volcano_cfg (ConfigDict): Volcano Cloud config.
+        queue_name (str): Name of resource queue.
+        preemptible (bool): Whether to launch the task in a preemptible way.
+            Default: False
+        priority (int): Priority of tasks, ranging from 1 to 9.
+            9 means the highest priority. Default: None
+        max_num_workers (int): Max number of workers. Default: 32.
+        retry (int): Number of retries when a job fails. Default: 2.
+        debug (bool): Whether to run in debug mode. Default: False.
+        lark_bot_url (str): Lark bot url. Default: None.
+    """
+
+    def __init__(self,
+                 task: ConfigDict,
+                 volcano_cfg: ConfigDict,
+                 queue_name: str,
+                 preemptible: bool = False,
+                 priority: Optional[int] = None,
+                 max_num_workers: int = 32,
+                 retry: int = 2,
+                 debug: bool = False,
+                 lark_bot_url: str = None):
+        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
+        self.volcano_cfg = volcano_cfg
+        self.max_num_workers = max_num_workers
+        self.retry = retry
+        self.queue_name = queue_name
+        self.preemptible = preemptible
+        self.priority = priority
+
+    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
+        """Launch multiple tasks.
+
+        Args:
+            tasks (list[dict]): A list of task configs, usually generated by
+                Partitioner.
+
+        Returns:
+            list[tuple[str, int]]: A list of (task name, exit code).
+        """
+        if not self.debug:
+            status = track_parallel_progress(self._launch,
+                                             tasks,
+                                             nproc=self.max_num_workers,
+                                             keep_order=False)
+        else:
+            status = [self._launch(task, random_sleep=False) for task in tasks]
+        return status
+
+    def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
+        """Launch a single task.
+
+        Args:
+            task_cfg (ConfigDict): Task config.
+            random_sleep (bool): Whether to sleep for a random time before
+                running the command. This avoids cluster errors when
+                launching multiple tasks at the same time. Default: True.
+
+        Returns:
+            tuple[str, int]: Task name and exit code.
+        """
+        task_type = self.task_cfg.type
+        if isinstance(self.task_cfg.type, str):
+            task_type = TASKS.get(task_type)
+        task = task_type(task_cfg)
+        num_gpus = task.num_gpus
+        task_name = task.name
+
+        # Build up VCC command
+        pwd = os.getcwd()
+        # Dump task config to file
+        mmengine.mkdir_or_exist('tmp/')
+        param_file = f'{pwd}/tmp/{os.getpid()}_params.py'
+
+        volc_cfg_file = f'{pwd}/tmp/{os.getpid()}_cfg.yaml'
+        volc_cfg = self._choose_flavor(num_gpus)
+        with open(volc_cfg_file, 'w') as fp:
+            yaml.dump(volc_cfg, fp, sort_keys=False)
+        try:
+            task_cfg.dump(param_file)
+            if self.volcano_cfg.get('bashrc_path') is not None:
+                # using user's conda env
+                bashrc_path = self.volcano_cfg['bashrc_path']
+                assert osp.exists(bashrc_path)
+                assert self.volcano_cfg.get('conda_env_name') is not None
+
+                conda_env_name = self.volcano_cfg['conda_env_name']
+
+                shell_cmd = (f'source {self.volcano_cfg["bashrc_path"]}; '
+                             f'source activate {conda_env_name}; ')
+                shell_cmd += f'export PYTHONPATH={pwd}:$PYTHONPATH; '
+            else:
+                assert self.volcano_cfg.get('python_env_path') is not None
+                shell_cmd = (
+                    f'export PATH={self.volcano_cfg["python_env_path"]}/bin:$PATH; '  # noqa: E501
+                    f'export PYTHONPATH={pwd}:$PYTHONPATH; ')
+
+            huggingface_cache = self.volcano_cfg.get('huggingface_cache')
+            if huggingface_cache is not None:
+                # HUGGINGFACE_HUB_CACHE is a legacy env variable; here we set
+                # both `HF_HUB_CACHE` and `HUGGINGFACE_HUB_CACHE` for backward
+                # compatibility
+                shell_cmd += f'export HF_HUB_CACHE={huggingface_cache}; '
+                shell_cmd += f'export HUGGINGFACE_HUB_CACHE={huggingface_cache}; '  # noqa: E501
+
+            torch_cache = self.volcano_cfg.get('torch_cache')
+            if torch_cache is not None:
+                shell_cmd += f'export TORCH_HOME={torch_cache}; '
+
+            hf_offline = self.volcano_cfg.get('hf_offline', True)
+
+            if hf_offline:
+                shell_cmd += 'export HF_DATASETS_OFFLINE=1; export TRANSFORMERS_OFFLINE=1; export HF_EVALUATE_OFFLINE=1; export HF_HUB_OFFLINE=1; '  # noqa: E501
+
+            hf_endpoint = self.volcano_cfg.get('hf_endpoint')
+            if hf_endpoint is not None:
+                shell_cmd += f'export HF_ENDPOINT={hf_endpoint}; '
+
+            extra_envs = self.volcano_cfg.get('extra_envs')
+            if extra_envs is not None:
+                for extra_env in extra_envs:
+                    shell_cmd += f'export {extra_env}; '
+
+            shell_cmd += f'cd {pwd}; '
+            shell_cmd += '{task_cmd}'
+
+            task_name = task_name[:128].replace('[', '-').replace(
+                ']', '').replace('/', '-').replace(',',
+                                                   '--').replace('.', '_')
+            tmpl = ('volc ml_task submit'
+                    f" --conf '{volc_cfg_file}'"
+                    f" --entrypoint '{shell_cmd}'"
+                    f' --task_name {task_name}'
+                    f' --resource_queue_name {self.queue_name}')
+            if self.preemptible:
+                tmpl += ' --preemptible'
+            if self.priority is not None:
+                tmpl += f' --priority {self.priority}'
+            get_cmd = partial(task.get_command,
+                              cfg_path=param_file,
+                              template=tmpl)
+            cmd = get_cmd()
+
+            logger = get_logger()
+            logger.debug(f'Running command: {cmd}')
+
+            out_path = task.get_log_path(file_extension='txt')
+            mmengine.mkdir_or_exist(osp.split(out_path)[0])
+
+            retry = self.retry
+            while True:
+                if random_sleep:
+                    time.sleep(random.randint(0, 10))
+                task_status, returncode = self._run_task(cmd,
+                                                         out_path,
+                                                         poll_interval=20)
+                output_paths = task.get_output_paths()
+                if not (self._job_failed(task_status, output_paths)) \
+                        or retry <= 0:
+                    break
+                retry -= 1
+
+        finally:
+            # Clean up
+            os.remove(param_file)
+            os.remove(volc_cfg_file)
+        return task_name, returncode
+
+    def _run_task(self, cmd, log_path, poll_interval):
+        result = subprocess.run(cmd,
+                                shell=True,
+                                text=True,
+                                capture_output=True)
+        pattern = r'(?<=task_id=).*(?=\n\n)'
+        match = re.search(pattern, result.stdout)
+        if match:
+            task_id = match.group()
+            ask_cmd = f'volc ml_task get --id {task_id} --output json ' + \
+                '--format Status'
+            log_cmd = f'volc ml_task logs --task {task_id} --instance worker_0'
+            while True:
+                task_status = os.popen(ask_cmd).read()
+                pattern = r'(?<=\[{"Status":").*(?="}\])'
+                match = re.search(pattern, task_status)
+                if match:
+                    task_status = match.group()
+                else:
+                    task_status = 'Exception'
+                if self.debug:
+                    print(task_status)
+                logs = os.popen(log_cmd).read()
+                with open(log_path, 'w', encoding='utf-8') as f:
+                    f.write(logs)
+                if task_status in [
+                        'Success', 'Failed', 'Cancelled', 'Exception',
+                        'Killing'
+                ]:
+                    break
+                time.sleep(poll_interval)
+        else:
+            task_status = 'Exception'
+
+        return task_status, result.returncode
+
+    def _job_failed(self, task_status: str, output_paths: List[str]) -> bool:
+        return task_status != 'Success' or not all(
+            osp.exists(output_path) for output_path in output_paths)
+
+    def _choose_flavor(self, num_gpus):
+        config_path = self.volcano_cfg.volcano_config_path
+        with open(config_path) as fp:
+            volc_cfg = yaml.safe_load(fp)
+        if num_gpus <= 0:
+            flavor = 'ml.c1ie.2xlarge'
+        elif num_gpus == 1:
+            flavor = 'ml.pni2l.3xlarge'
+        elif num_gpus == 2:
+            flavor = 'ml.pni2l.7xlarge'
+        elif num_gpus <= 4:
+            flavor = 'ml.pni2l.14xlarge'
+        elif num_gpus <= 8:
+            flavor = 'ml.pni2l.28xlarge'
+        else:
+            raise NotImplementedError
+
+        role_specs = volc_cfg['TaskRoleSpecs']
+        for i in range(len(role_specs)):
+            if role_specs[i]['RoleName'] == 'worker':
+                role_specs[i]['Flavor'] = flavor
+
+        return volc_cfg
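
The fields that `_launch` and `_choose_flavor` read from `volcano_cfg` are scattered through the code above; a hedged sketch collecting them in one place (every value is a placeholder):

    # Hypothetical volcano_cfg for VOLCRunner; all paths/values are invented.
    volcano_cfg = dict(
        volcano_config_path='/path/to/volc_task_template.yaml',  # YAML with TaskRoleSpecs
        bashrc_path='/home/me/.bashrc',      # or set python_env_path instead
        conda_env_name='opencompass',
        huggingface_cache='/fs/cache/huggingface/hub',
        torch_cache='/fs/cache/torch',
        hf_offline=True,
        extra_envs=['MY_VAR=1'],
    )
    # The YAML template needs a worker role spec; _choose_flavor overwrites
    # its 'Flavor' field according to the task's GPU count, e.g.:
    # TaskRoleSpecs:
    #   - RoleName: worker
    #     Flavor: ml.pni2l.3xlarge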
@@ -16,7 +16,7 @@ from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg,
                                model_abbr_from_cfg)
 from opencompass.utils.prompt import get_prompt_hash

-METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match']
+METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match', 'extract_rate']
 METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len', 'type']

 def model_abbr_from_cfg_used_in_summarizer(model):
@@ -75,6 +75,8 @@ class OpenICLEvalTask(BaseTask):
             for c in sum(self.dataset_cfgs, []))
         self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
             'task', {}).get('dump_details', False)
+        self.cal_extrat_rate = cfg.get('eval', {}).get('runner', {}).get(
+            'task', {}).get('cal_extrat_rate', False)

     def get_command(self, cfg_path, template):
         sys.path.append(os.getcwd())
@@ -234,6 +236,9 @@ class OpenICLEvalTask(BaseTask):
                     pred_strs, test_set[self.output_column], details,
                     pred_dicts)
                 result['type'] = result['details'].pop('type', None)
+                if self.cal_extrat_rate:
+                    # Calculate the extraction success rate for predictions
+                    result['extract_rate'] = self.extract_rate(result)

                 if 'PPL' in str(
                         self.dataset_cfg.infer_cfg.inferencer.type):
@@ -262,6 +267,25 @@ class OpenICLEvalTask(BaseTask):
         mkdir_or_exist(osp.split(out_path)[0])
         mmengine.dump(result, out_path, ensure_ascii=False, indent=4)

+    def extract_rate(self, results):
+        """Calculate the extraction success rate of predictions.
+
+        Args:
+            results (dict): The result dict, including the details field.
+        """
+        details = results['details']
+        details_list = list(details.values())
+        invalid_extractions = []
+        for item in details_list:
+            try:
+                invalid_extractions.extend(
+                    [item] if not item['predictions'] else [])
+            except KeyError as e:
+                self.logger.warning(f'Skip {e} due to: {item}')
+                raise KeyError
+        success_rate = 100 - len(invalid_extractions) / len(details) * 100
+        return success_rate
+
     def format_details(self, predictions, references, details, pred_dicts):
         """This function is responsible for formatting prediction details.
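
In plainer terms, `extract_rate` is the share of examples whose `predictions` field is non-empty after answer extraction. A standalone restatement with made-up data:

    # Toy data illustrating the extract_rate metric; values are invented.
    details = {
        '0': {'predictions': 'A'},
        '1': {'predictions': ''},   # answer extraction failed here
        '2': {'predictions': 'C'},
    }
    invalid = [d for d in details.values() if not d['predictions']]
    print(100 - len(invalid) / len(details) * 100)  # 66.66...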
@@ -9,4 +9,18 @@ def collect_env():
     env_info = collect_base_env()
     env_info['opencompass'] = opencompass.__version__ + '+' + get_git_hash(
     )[:7]
+
+    # LMDeploy
+    try:
+        import lmdeploy
+        env_info['lmdeploy'] = lmdeploy.__version__
+    except ModuleNotFoundError as e:
+        env_info['lmdeploy'] = f'not installed:{e}'
+
+    # Transformers
+    try:
+        import transformers
+        env_info['transformers'] = transformers.__version__
+    except ModuleNotFoundError as e:
+        env_info['transformers'] = f'not installed:{e}'
+
     return env_info
setup.py (10 changed lines)
@@ -117,14 +117,8 @@ def do_setup():
         python_requires='>=3.8.0',
         install_requires=parse_requirements('requirements/runtime.txt'),
         license='Apache License 2.0',
-        packages=find_packages(exclude=[
-            'test*',
-            'configs',
-            'data',
-            'docs',
-            'tools',
-            'tmp',
-        ]),
+        include_package_data=True,
+        packages=find_packages(),
         keywords=[
             'AI', 'NLP', 'in-context learning', 'large language model',
             'evaluation', 'benchmark', 'llm'
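
Taken together with the new MANIFEST.in, this is the standard setuptools recipe for shipping package data; a trimmed sketch of the idea (not the full setup() call from this repo):

    # Sketch only: include_package_data=True tells setuptools to copy files
    # matched by MANIFEST.in into the built distribution, and find_packages()
    # without an exclude list lets opencompass.configs be discovered.
    from setuptools import find_packages, setup

    setup(
        name='opencompass',
        include_package_data=True,  # honor MANIFEST.in entries
        packages=find_packages(),
    )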
@@ -109,9 +109,11 @@ def update_imports(data):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('python_files', nargs='*')
+    # Could be opencompass/configs/datasets and configs/datasets
+    parser.add_argument('--root_folder', default='configs/datasets')
     args = parser.parse_args()

-    root_folder = 'configs/datasets'
+    root_folder = args.root_folder
     if args.python_files:
         python_files = [
             i for i in args.python_files if i.startswith(root_folder)