[Feature] Update pip install (#1324)

* [Feature] Update pip install

* Update Configuration

* Update

* Update

* Update

* Update Internal Config

* Update collect env
Songyang Zhang 2024-07-29 18:32:50 +08:00 committed by GitHub
parent edab1c07ba
commit 704853e5e7
17 changed files with 923 additions and 15 deletions

MANIFEST.in (new file)

@@ -0,0 +1,2 @@
recursive-include opencompass/configs *.py *.yml *.json *.txt *.md
recursive-include opencompass/openicl/icl_evaluator/hf_metrics *.py
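These two rules are what make the config tree and the local HF metric scripts part of the source distribution. A minimal sketch (not part of this PR) to preview which files the patterns match in a source checkout; directory names are taken from the two lines above:

# Minimal sketch: list the files the two recursive-include rules would pick up.
from pathlib import Path

config_patterns = ['*.py', '*.yml', '*.json', '*.txt', '*.md']
config_files = [p for pat in config_patterns
                for p in Path('opencompass/configs').rglob(pat)]
metric_files = list(
    Path('opencompass/openicl/icl_evaluator/hf_metrics').rglob('*.py'))
print(f'{len(config_files)} config files, {len(metric_files)} metric scripts')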


@@ -1,7 +1,7 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GPQASimpleEvalDataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator
from opencompass.datasets import GPQADataset, GPQA_Simple_Eval_postprocess, GPQAEvaluator
# openai_simple_eval prompt
align_prompt = """
@@ -43,7 +43,7 @@ for split in list(gpqa_subsets.keys()):
gpqa_datasets.append(
dict(
abbr='GPQA_' + split,
type=GPQASimpleEvalDataset,
type=GPQADataset,
path='./data/gpqa/',
name=gpqa_subsets[split],
reader_cfg=gpqa_reader_cfg,


@@ -1 +1 @@
__version__ = '0.2.6'
__version__ = '0.2.7rc1'
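After installing the release candidate, the bump can be sanity-checked with a one-liner (a minimal sketch; assumes `__version__` is exposed on the top-level package, as this file suggests):

import opencompass
print(opencompass.__version__)  # expect '0.2.7rc1' for this release candidate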


@@ -0,0 +1,329 @@
import contextlib
import io
import itertools
import multiprocessing
import re
import signal
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Sequence, Union
import numpy as np
from datasets import DatasetDict, concatenate_datasets, load_dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class LCDataset(BaseDataset):
@staticmethod
def load(path: str, num_repeats: int = 1, difficulty='ALL'):
"""Load LC dataset for pass k mode.
Note that you can use num_repeats > 1 when your model does not support
`num_return_sequence` in generation, otherwise use the raw
LC dataset and set `num_return_sequence` in model config to
generate multiple responses for testing pass@k>1.
It is better to change your dataset abbr accordingly if you set
num_repeats > 1, otherwise the number in
`.cache/dataset_size.json` might be inconsistent.
Args:
num_repeats(int): Number of repetition for this dataset to get
multiple responses in special cases.
"""
def processing_test(example):
example['test_case'] = example['test_list']
example['test_list'] = '\n'.join(example['test_list'])
example['test_column'] = dict(test_list_2=example['test_list'],
task_id=example['Contest id'])
return example
train = load_dataset('json', data_files=path,
split='train[:5]').map(processing_test)
test = load_dataset('json', data_files=path,
split='train[5:]').map(processing_test)
if not difficulty == 'ALL':
train = train.filter(
lambda example: example['Difficulty'] == difficulty)
test = test.filter(
lambda example: example['Difficulty'] == difficulty)
test = concatenate_datasets([test] * num_repeats)
return DatasetDict({'train': train, 'test': test})
class TimeOutException(Exception):
pass
@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with redirect_stdin(stream):
yield
@contextlib.contextmanager
def time_limit(seconds: float):
def signal_handler(signum, frame):
raise TimeOutException('Time out!')
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from."""
def read(self, *args, **kwargs):
raise IOError
def readline(self, *args, **kwargs):
raise IOError
def readlines(self, *args, **kwargs):
raise IOError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'
@ICL_EVALUATORS.register_module()
class LCEvaluator(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
details = {}
with ProcessPoolExecutor() as executor:
futures = []
for i, (refer, pred) in enumerate(zip(references, predictions)):
pred = self._process_answer(pred)
programs = self._process_test(refer, pred)
future = executor.submit(execution, programs, i, 3)
futures.append(future)
from tqdm import tqdm
for future in tqdm(as_completed(futures), total=len(futures)):
index, ret = future.result()
result[ret] += 1
details[str(index)] = {
'programs': predictions[index],
'result': ret,
'is_correct': ret == 'pass',
}
result['score'] = result['pass'] / len(predictions) * 100
result['details'] = details
return result
def _process_answer(self, text):
try:
# for chatGLM related text
eval_text = eval(text)
except Exception:
pass
else:
if isinstance(eval_text, str):
text = eval_text
# deal with code block
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```xxx
text = text[max(text.find('\n') + 1, 0):]
text = text.strip()
match = re.search(r"('\s*|)(\[DONE\]|DONE)", text)
if match:
text = text[:match.start()]
match = re.search(r"(\[BEGIN\]|BEGIN)('\s*|)", text)
if match:
text = text[match.end():]
text = text.strip()
if text.startswith("'"):
text = text[1:]
if text.endswith("'"):
text = text[:-1]
text = text.replace('\\', '')
match = re.search(r'```python(.*)```', text, re.DOTALL)
if match:
text = match.group(1).strip().split('```')[0].strip()
return text
def _process_test(self, test_case, pred):
formatted = pred + '\n'
formatted += test_case
return formatted
def execution(programs, task_id, timeout):
"""Execution function for running generation code.
Args:
programs(str): Python code to be executed.
task_id(int): Task id of the current example.
timeout(int): Time limit for execution, avoid unnecessary
blocking.
In the pass@k scenario, many programs need to be executed, and
some internal errors (such as `RecursionError`) cannot be handled
properly and may crash the system. It is better to run the
execution in a separate thread or process for better control.
"""
def _execution(programs, timeout):
try:
# Add exec globals to prevent exec from raising an
# unnecessary NameError for a correct answer
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(programs, exec_globals)
key.append('pass')
except TimeOutException:
key.append('timeout')
except AssertionError:
key.append('wrong_answer')
except BaseException as e:
print(e)
key.append('failed')
manager = multiprocessing.Manager()
key = manager.list()
# `signal` cannot be used in child thread, therefore, we
# need to create a process in the thread.
p = multiprocessing.Process(target=_execution,
args=(programs, timeout - 1))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.kill()
# key might not have value if killed
return task_id, 'timeout'
return task_id, key[0]
class LCPassKEvaluator(LCEvaluator):
"""Better use for pass k evaluation.
Args:
k(Tuple[int]): Choices of Pass@k. Defaults to (1, 10, 100)
"""
def __init__(self, k=(1, 10, 100)) -> None:
if not isinstance(k, Sequence):
k = (k, )
self.k = k
@staticmethod
def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int,
) -> np.ndarray:
"""Estimates pass@k of each problem and returns them in an array."""
def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)
return np.array([
estimator(int(n), int(c), k)
for n, c in zip(num_samples_it, num_correct)
])
def score(self, predictions, references):
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
task_pass = defaultdict(int)
task_total = defaultdict(int)
result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
details = {}
with ProcessPoolExecutor() as executor:
futures = []
index, programs = 0, []
for refer, preds in zip(references, predictions):
# suits two cases:
# 1. use repeated dataset
# 2. use `num_return_sequences` to generate multiple responses
if not isinstance(preds, list):
preds = [preds]
test_case = refer['test_list_2']
task_id = refer['task_id']
# create empty task_pass in case all examples failed
if task_id not in task_pass:
task_pass[task_id] = 0
for pred in preds:
pred = self._process_answer(pred)
program = self._process_test(test_case, pred)
future = executor.submit(execution, program,
(index, task_id), 3)
futures.append(future)
programs.append(program)
index += 1
from tqdm import tqdm
for future in tqdm(as_completed(futures), total=len(futures)):
(index, task_id), ret = future.result()
result[ret] += 1
task_total[task_id] += 1
is_correct = ret == 'pass'
task_pass[task_id] += is_correct
details[str(index)] = {
'program': programs[index],
'task_id': task_id,
'result': ret,
'is_correct': is_correct,
}
result['details'] = details
def get_number(tasks):
return np.array([
task[1] for task in sorted(tasks.items(), key=lambda x: x[0])
])
task_pass = get_number(task_pass)
task_total = get_number(task_total)
pass_at_k = {
f'pass@{k}':
self.estimate_pass_at_k(task_total, task_pass, k).mean() * 100
for k in self.k if (task_total >= k).all()
}
result.update(pass_at_k)
return result
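For reference, `estimate_pass_at_k` above is the standard unbiased estimator pass@k = 1 - C(n - c, k) / C(n, k), evaluated with a numerically stable product. A small standalone check (not part of this PR) comparing the product form against the binomial form:

import numpy as np
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    # Unbiased pass@k for n samples with c correct, via the stable product.
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# 10 samples, 3 correct, k=2: both forms give ~0.5333
assert abs(pass_at_k(10, 3, 2) - (1 - comb(7, 2) / comb(10, 2))) < 1e-9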


@@ -62,6 +62,7 @@ from .jsonl import JsonlDataset # noqa: F401, F403
from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403
from .lambada import * # noqa: F401, F403
from .lawbench import * # noqa: F401, F403
from .LCBench import * # noqa: F401, F403
from .lcsts import * # noqa: F401, F403
from .leval import * # noqa: F401, F403
from .llm_compression import LLMCompressionDataset # noqa: F401, F403


@@ -50,9 +50,15 @@ def bbh_freeform_postprocess(text: str) -> str:
ans_line = ans.split('answer is ')
if len(ans_line) != 1:
ans = ans_line[1].strip()
ans = ans.split('\n')[0]
ans = ans.split('\n')[0].strip()
if ans.endswith('.'):
ans = ans[:-1]
ans = ans[:-1].strip()
match = re.search(r'\*\*(.*?)\*\*', ans)
if match:
return match.group(1)
return ans
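The updated postprocessor now also strips whitespace, drops a trailing period, and unwraps an answer the model wrapped in Markdown bold. A small standalone sketch of just this tail logic (mirroring the lines above, not the whole function):

import re

def postprocess_tail(ans: str) -> str:
    # mirrors the updated tail of bbh_freeform_postprocess
    ans = ans.split('\n')[0].strip()
    if ans.endswith('.'):
        ans = ans[:-1].strip()
    match = re.search(r'\*\*(.*?)\*\*', ans)
    return match.group(1) if match else ans

assert postprocess_tail('**(A)**.') == '(A)'
assert postprocess_tail('valid') == 'valid'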


@@ -6,7 +6,9 @@ from .baidu_api import ERNIEBot # noqa: F401
from .base import BaseModel, LMTemplateParser # noqa: F401
from .base_api import APITemplateParser, BaseAPIModel # noqa: F401
from .bytedance_api import ByteDance # noqa: F401
from .claude_allesapin import ClaudeAllesAPIN # noqa: F401
from .claude_api import Claude # noqa: F401
from .claude_sdk_api import ClaudeSDK # noqa: F401
from .deepseek_api import DeepseekAPI # noqa: F401
from .doubao_api import Doubao # noqa: F401
from .gemini_api import Gemini # noqa: F401


@@ -0,0 +1,150 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
import requests
from opencompass.registry import MODELS
from opencompass.utils import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
@MODELS.register_module()
class ClaudeAllesAPIN(BaseAPIModel):
"""Model wrapper around Claude-AllesAPIN.
Args:
path (str): The name of Claude's model.
url (str): URL to AllesAPIN.
key (str): AllesAPIN key.
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
max_seq_len (int): Unused here.
meta_template (Dict, optional): The model's meta prompt
template, if needed, for injecting or wrapping any meta
instructions.
retry (int): Number of retries if the API call fails. Defaults to 2.
"""
is_api: bool = True
def __init__(self,
path: str,
url: str,
key: str,
query_per_second: int = 1,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2):
super().__init__(path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry)
self.url = url
self.headers = {
'alles-apin-token': key,
'content-type': 'application/json',
}
def generate(self,
inputs: List[PromptType],
max_out_len: int = 512,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs)))
return results
def _generate(self, input: PromptType, max_out_len: int) -> str:
"""Generate results given an input.
Args:
inputs (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, PromptList))
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
messages = []
msg_buffer, last_role = [], None
for item in input:
item['role'] = 'assistant' if item['role'] == 'BOT' else 'user'
if item['role'] != last_role and last_role is not None:
messages.append({
'content': '\n'.join(msg_buffer),
'role': last_role
})
msg_buffer = []
msg_buffer.append(item['prompt'])
last_role = item['role']
messages.append({
'content': '\n'.join(msg_buffer),
'role': last_role
})
data = {
'model': self.path,
'messages': messages,
'max_tokens': max_out_len,
}
err_data = []
for _ in range(self.retry + 1):
self.wait()
try:
raw_response = requests.post(self.url,
headers=self.headers,
data=json.dumps(data))
except requests.ConnectionError:
time.sleep(5)
continue
except requests.ReadTimeout:
time.sleep(5)
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
if 'https://errors.aliyun.com/images' in \
raw_response.content.decode():
return 'request blocked by allesapin'
self.logger.error('JsonDecode error, got',
raw_response.content)
continue
if raw_response.status_code == 200 and response[
'msgCode'] == '10000':
data = response['data']
generated = data['content'][0]['text'].strip()
self.logger.debug(f'Generated: {generated}')
return generated
self.logger.error(response['data'])
err_data.append(response['data'])
raise RuntimeError(err_data)
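A minimal usage sketch for the new wrapper; the endpoint URL and token below are placeholders, and the import path follows the `opencompass/models/__init__.py` change in this PR:

from opencompass.models import ClaudeAllesAPIN

model = ClaudeAllesAPIN(
    path='claude-3-opus-20240229',        # model name forwarded to the API
    url='https://example.com/v1/claude',  # hypothetical AllesAPIN endpoint
    key='YOUR_ALLES_APIN_TOKEN',
    query_per_second=1,
    retry=2,
)
print(model.generate(['Say hello.'], max_out_len=64))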


@@ -0,0 +1,121 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.registry import MODELS
from opencompass.utils import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
@MODELS.register_module()
class ClaudeSDK(BaseAPIModel):
"""Model wrapper around Claude SDK API.
Args:
key (str): Authorization key.
path (str): The model to be used. Defaults to claude-2.
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
max_seq_len (int): Unused here.
meta_template (Dict, optional): The model's meta prompt
template, if needed, for injecting or wrapping any meta
instructions.
retry (int): Number of retries if the API call fails. Defaults to 2.
"""
def __init__(
self,
key: str,
path: str = 'claude-2',
query_per_second: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
temperature: Optional[float] = 0.0,
retry: int = 2,
):
super().__init__(path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry)
try:
from anthropic import Anthropic
except ImportError:
raise ImportError('Import anthropic failed. Please install it '
'with "pip install anthropic" and try again.')
self.anthropic = Anthropic(api_key=key)
self.model = path
self.temperature = temperature
def generate(
self,
inputs: List[PromptType],
max_out_len: int = 512,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs)))
return results
def _generate(
self,
input: PromptType,
max_out_len: int = 512,
) -> str:
"""Generate results given an input.
Args:
inputs (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, PromptList))
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
messages = []
for item in input:
msg = {'content': item['prompt']}
if item['role'] == 'HUMAN':
msg['role'] = 'user'
elif item['role'] == 'BOT':
msg['role'] = 'assistant'
elif item['role'] == 'SYSTEM':
msg['role'] = 'system'
messages.append(msg)
num_retries = 0
while num_retries < self.retry:
self.wait()
try:
responses = self.anthropic.messages.create(
model=self.model,
max_tokens=max_out_len,
temperature=self.temperature,
messages=messages)
return responses.content[0].text
except Exception as e:
self.logger.error(e)
num_retries += 1
raise RuntimeError('Calling Claude API failed after retrying for '
f'{self.retry} times. Check the logs for details.')
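A minimal usage sketch for the SDK wrapper; the key is a placeholder, and `pip install anthropic` is required, as the ImportError message above indicates:

from opencompass.models import ClaudeSDK

model = ClaudeSDK(key='sk-ant-your-key', path='claude-2', temperature=0.0)
# generate() takes a list of prompts and returns a list of completions
print(model.generate(['Summarize pass@k in one sentence.'], max_out_len=128))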


@@ -1,4 +1,5 @@
import copy
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
@@ -60,6 +61,7 @@ class TurboMindModel(BaseModel):
from lmdeploy.messages import TurbomindEngineConfig
engine_config = TurbomindEngineConfig(**engine_config)
self.logger = get_logger()
assert os.path.exists(path), '{} does not exist'.format(path)
tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
self.tokenizer = tm_model.tokenizer
self.generators = [


@@ -2,3 +2,4 @@ from .dlc import * # noqa: F401, F403
from .local import * # noqa: F401, F403
from .slurm import * # noqa: F401, F403
from .slurm_sequential import * # noqa: F401, F403
from .volc import * # noqa: F401, F403

opencompass/runners/volc.py (new file)

@@ -0,0 +1,260 @@
import os
import os.path as osp
import random
import re
import subprocess
import time
from functools import partial
from typing import Any, Dict, List, Optional, Tuple
import mmengine
import yaml
from mmengine.config import ConfigDict
from mmengine.utils import track_parallel_progress
from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger
from .base import BaseRunner
@RUNNERS.register_module()
class VOLCRunner(BaseRunner):
"""Distributed runner based on Volcano Cloud Cluster (VCC). It will launch
multiple tasks in parallel with the 'volc' command. Please install and
configure the volc CLI first before using this runner.
Args:
task (ConfigDict): Task type config.
volcano_cfg (ConfigDict): Volcano Cloud config.
queue_name (str): Name of resource queue.
preemptible (bool): Whether to launch task in preemptible way.
Default: False
priority (int, optional): Priority of tasks, ranging from 1 to 9,
where 9 means the highest priority. Default: None
max_num_workers (int): Max number of workers. Default: 32.
retry (int): Number of retries when job failed. Default: 2.
debug (bool): Whether to run in debug mode. Default: False.
lark_bot_url (str): Lark bot url. Default: None.
"""
def __init__(self,
task: ConfigDict,
volcano_cfg: ConfigDict,
queue_name: str,
preemptible: bool = False,
priority: Optional[int] = None,
max_num_workers: int = 32,
retry: int = 2,
debug: bool = False,
lark_bot_url: str = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.volcano_cfg = volcano_cfg
self.max_num_workers = max_num_workers
self.retry = retry
self.queue_name = queue_name
self.preemptible = preemptible
self.priority = priority
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks.
Args:
tasks (list[dict]): A list of task configs, usually generated by
Partitioner.
Returns:
list[tuple[str, int]]: A list of (task name, exit code).
"""
if not self.debug:
status = track_parallel_progress(self._launch,
tasks,
nproc=self.max_num_workers,
keep_order=False)
else:
status = [self._launch(task, random_sleep=False) for task in tasks]
return status
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
"""Launch a single task.
Args:
task_cfg (ConfigDict): Task config.
random_sleep (bool): Whether to sleep for a random time before
running the command. This avoids cluster error when launching
multiple tasks at the same time. Default: True.
Returns:
tuple[str, int]: Task name and exit code.
"""
task_type = self.task_cfg.type
if isinstance(self.task_cfg.type, str):
task_type = TASKS.get(task_type)
task = task_type(task_cfg)
num_gpus = task.num_gpus
task_name = task.name
# Build up VCC command
pwd = os.getcwd()
# Dump task config to file
mmengine.mkdir_or_exist('tmp/')
param_file = f'{pwd}/tmp/{os.getpid()}_params.py'
volc_cfg_file = f'{pwd}/tmp/{os.getpid()}_cfg.yaml'
volc_cfg = self._choose_flavor(num_gpus)
with open(volc_cfg_file, 'w') as fp:
yaml.dump(volc_cfg, fp, sort_keys=False)
try:
task_cfg.dump(param_file)
if self.volcano_cfg.get('bashrc_path') is not None:
# using user's conda env
bashrc_path = self.volcano_cfg['bashrc_path']
assert osp.exists(bashrc_path)
assert self.volcano_cfg.get('conda_env_name') is not None
conda_env_name = self.volcano_cfg['conda_env_name']
shell_cmd = (f'source {self.volcano_cfg["bashrc_path"]}; '
f'source activate {conda_env_name}; ')
shell_cmd += f'export PYTHONPATH={pwd}:$PYTHONPATH; '
else:
assert self.volcano_cfg.get('python_env_path') is not None
shell_cmd = (
f'export PATH={self.volcano_cfg["python_env_path"]}/bin:$PATH; ' # noqa: E501
f'export PYTHONPATH={pwd}:$PYTHONPATH; ')
huggingface_cache = self.volcano_cfg.get('huggingface_cache')
if huggingface_cache is not None:
# HUGGINGFACE_HUB_CACHE is a legacy env variable; set both
# `HF_HUB_CACHE` and `HUGGINGFACE_HUB_CACHE` for backward compatibility
shell_cmd += f'export HF_HUB_CACHE={huggingface_cache}; '
shell_cmd += f'export HUGGINGFACE_HUB_CACHE={huggingface_cache}; ' # noqa: E501
torch_cache = self.volcano_cfg.get('torch_cache')
if torch_cache is not None:
shell_cmd += f'export TORCH_HOME={torch_cache}; '
hf_offline = self.volcano_cfg.get('hf_offline', True)
if hf_offline:
shell_cmd += 'export HF_DATASETS_OFFLINE=1; export TRANSFORMERS_OFFLINE=1; export HF_EVALUATE_OFFLINE=1; export HF_HUB_OFFLINE=1; ' # noqa: E501
hf_endpoint = self.volcano_cfg.get('hf_endpoint')
if hf_endpoint is not None:
shell_cmd += f'export HF_ENDPOINT={hf_endpoint}; '
extra_envs = self.volcano_cfg.get('extra_envs')
if extra_envs is not None:
for extra_env in extra_envs:
shell_cmd += f'export {extra_env}; '
shell_cmd += f'cd {pwd}; '
shell_cmd += '{task_cmd}'
task_name = task_name[:128].replace('[', '-').replace(
']', '').replace('/', '-').replace(',',
'--').replace('.', '_')
tmpl = ('volc ml_task submit'
f" --conf '{volc_cfg_file}'"
f" --entrypoint '{shell_cmd}'"
f' --task_name {task_name}'
f' --resource_queue_name {self.queue_name}')
if self.preemptible:
tmpl += ' --preemptible'
if self.priority is not None:
tmpl += f' --priority {self.priority}'
get_cmd = partial(task.get_command,
cfg_path=param_file,
template=tmpl)
cmd = get_cmd()
logger = get_logger()
logger.debug(f'Running command: {cmd}')
out_path = task.get_log_path(file_extension='txt')
mmengine.mkdir_or_exist(osp.split(out_path)[0])
retry = self.retry
while True:
if random_sleep:
time.sleep(random.randint(0, 10))
task_status, returncode = self._run_task(cmd,
out_path,
poll_interval=20)
output_paths = task.get_output_paths()
if not (self._job_failed(task_status, output_paths)) \
or retry <= 0:
break
retry -= 1
finally:
# Clean up
os.remove(param_file)
os.remove(volc_cfg_file)
return task_name, returncode
def _run_task(self, cmd, log_path, poll_interval):
result = subprocess.run(cmd,
shell=True,
text=True,
capture_output=True)
pattern = r'(?<=task_id=).*(?=\n\n)'
match = re.search(pattern, result.stdout)
if match:
task_id = match.group()
ask_cmd = f'volc ml_task get --id {task_id} --output json ' + \
'--format Status'
log_cmd = f'volc ml_task logs --task {task_id} --instance worker_0'
while True:
task_status = os.popen(ask_cmd).read()
pattern = r'(?<=\[{"Status":").*(?="}\])'
match = re.search(pattern, task_status)
if match:
task_status = match.group()
else:
task_status = 'Exception'
if self.debug:
print(task_status)
logs = os.popen(log_cmd).read()
with open(log_path, 'w', encoding='utf-8') as f:
f.write(logs)
if task_status in [
'Success', 'Failed', 'Cancelled', 'Exception',
'Killing'
]:
break
time.sleep(poll_interval)
else:
task_status = 'Exception'
return task_status, result.returncode
def _job_failed(self, task_status: str, output_paths: List[str]) -> bool:
return task_status != 'Success' or not all(
osp.exists(output_path) for output_path in output_paths)
def _choose_flavor(self, num_gpus):
config_path = self.volcano_cfg.volcano_config_path
with open(config_path) as fp:
volc_cfg = yaml.safe_load(fp)
if num_gpus <= 0:
flavor = 'ml.c1ie.2xlarge'
elif num_gpus == 1:
flavor = 'ml.pni2l.3xlarge'
elif num_gpus == 2:
flavor = 'ml.pni2l.7xlarge'
elif num_gpus <= 4:
flavor = 'ml.pni2l.14xlarge'
elif num_gpus <= 8:
flavor = 'ml.pni2l.28xlarge'
else:
raise NotImplementedError
role_specs = volc_cfg['TaskRoleSpecs']
for i in range(len(role_specs)):
if role_specs[i]['RoleName'] == 'worker':
role_specs[i]['Flavor'] = flavor
return volc_cfg
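A sketch of how the runner could be wired into an eval config; the queue name, YAML path, and env path are placeholders, and the partitioner/task choices are illustrative rather than mandated by this PR:

from opencompass.partitioners import NaivePartitioner
from opencompass.runners import VOLCRunner
from opencompass.tasks import OpenICLInferTask

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=VOLCRunner,
        volcano_cfg=dict(
            volcano_config_path='./volc_config.yaml',  # TaskRoleSpecs template
            python_env_path='/path/to/python/env',     # or bashrc_path + conda_env_name
        ),
        queue_name='your_resource_queue',
        max_num_workers=32,
        task=dict(type=OpenICLInferTask),
    ),
)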


@@ -16,7 +16,7 @@ from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg,
model_abbr_from_cfg)
from opencompass.utils.prompt import get_prompt_hash
METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match']
METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth', 'f1', 'exact_match', 'extract_rate']
METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len', 'type']
def model_abbr_from_cfg_used_in_summarizer(model):


@@ -75,6 +75,8 @@ class OpenICLEvalTask(BaseTask):
for c in sum(self.dataset_cfgs, []))
self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('dump_details', False)
self.cal_extrat_rate = cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('cal_extrat_rate', False)
def get_command(self, cfg_path, template):
sys.path.append(os.getcwd())
@@ -234,6 +236,9 @@ class OpenICLEvalTask(BaseTask):
pred_strs, test_set[self.output_column], details,
pred_dicts)
result['type'] = result['details'].pop('type', None)
if self.cal_extrat_rate:
# Calculate the extraction success rate for prediction
result['extract_rate'] = self.extract_rate(result)
if 'PPL' in str(
self.dataset_cfg.infer_cfg.inferencer.type):
@@ -262,6 +267,25 @@
mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
def extract_rate(self, results):
"""This function is designed for calculating the extraction rate.
Args:
results (dict): The result dict, including the prediction details.
"""
details = results['details']
details_list = list(details.values())
invalid_extractions = []
for item in details_list:
try:
invalid_extractions.extend(
[item] if not item['predictions'] else [])
except KeyError as e:
self.logger.warning(f'Skip {e} due to: {item}')
raise KeyError
success_rate = 100 - len(invalid_extractions) / len(details) * 100
return success_rate
def format_details(self, predictions, references, details, pred_dicts):
"""This function is responsible for formatting prediction details.


@@ -9,4 +9,18 @@ def collect_env():
env_info = collect_base_env()
env_info['opencompass'] = opencompass.__version__ + '+' + get_git_hash(
)[:7]
# LMDeploy
try:
import lmdeploy
env_info['lmdeploy'] = lmdeploy.__version__
except ModuleNotFoundError as e:
env_info['lmdeploy'] = f'not installed:{e}'
# Transformers
try:
import transformers
env_info['transformers'] = transformers.__version__
except ModuleNotFoundError as e:
env_info['transformers'] = f'not installed:{e}'
return env_info
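A quick way to print the extended report (a minimal sketch; assumes the function lives at `opencompass.utils.collect_env` as in existing releases):

from opencompass.utils.collect_env import collect_env

for key, value in collect_env().items():
    print(f'{key}: {value}')  # now includes lmdeploy and transformers versions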


@@ -117,14 +117,8 @@ def do_setup():
python_requires='>=3.8.0',
install_requires=parse_requirements('requirements/runtime.txt'),
license='Apache License 2.0',
packages=find_packages(exclude=[
'test*',
'configs',
'data',
'docs',
'tools',
'tmp',
]),
include_package_data=True,
packages=find_packages(),
keywords=[
'AI', 'NLP', 'in-context learning', 'large language model',
'evaluation', 'benchmark', 'llm'
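Together with the MANIFEST.in rules added above, `include_package_data=True` is what ships the config tree inside the wheel. A minimal post-install sanity check (a sketch; assumes the configs land under the installed `opencompass/` directory):

import importlib.util
import pathlib

spec = importlib.util.find_spec('opencompass')
pkg_root = pathlib.Path(spec.origin).parent
configs = sorted((pkg_root / 'configs').rglob('*.py'))
print(f'{len(configs)} packaged config files')  # should be non-zero after pip install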


@@ -109,9 +109,11 @@ def update_imports(data):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('python_files', nargs='*')
# Could be opencompass/configs/datasets and configs/datasets
parser.add_argument('--root_folder', default='configs/datasets')
args = parser.parse_args()
root_folder = 'configs/datasets'
root_folder = args.root_folder
if args.python_files:
python_files = [
i for i in args.python_files if i.startswith(root_folder)