[Feature] Support Math evaluation via judgemodel (#1094)
* support openai math evaluation
* support math llm judge
This commit is contained in:
parent
41196c48ae
commit
6ba1c4937d
configs/datasets/math/math_llm_judge.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess

QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{problem}

Remember to put your answer on its own line after "ANSWER:", and you do not need to use a \\boxed command.
""".strip()

math_reader_cfg = dict(input_columns=['problem'], output_column='solution')

math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(role="HUMAN", prompt=QUERY_TEMPLATE),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

math_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))

math_datasets = [
    dict(
        type=MATHDataset,
        abbr='math',
        path='./data/math/math.json',
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg)
]
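A minimal illustration (not part of the config) of the round trip this template sets up: the model is asked to end its response with an "ANSWER: ..." line, and the math_judement_preprocess postprocessor added later in this commit strips everything else before the LLM judge compares the value against the gold answer. The response string below is hypothetical.

# Hypothetical model response produced under QUERY_TEMPLATE:
response = "We complete the square: x^2 + 2x + 1 = (x+1)^2.\nANSWER: (x+1)^2"
# extract_answer(response), defined further down in this commit, returns '(x+1)^2',
# which is what reaches the judge prompt as {prediction}.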
configs/eval_math_llm_judge.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
from mmengine.config import read_base

with read_base():
    from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model  # noqa: F401, F403
    from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model  # noqa: F401, F403
    from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model  # noqa: F401, F403
    from .datasets.math.math_llm_judge import math_datasets  # noqa: F401, F403

from opencompass.models.openai_api import OpenAIAllesAPIN
from opencompass.datasets import math_judement_preprocess
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AllObjSummarizer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate


# -------------Prompt Settings ----------------------------------------
eng_obj_prompt = """
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Result: [[Correct]]
Expression 1: 3/2
Expression 2: 1.5
Result: [[Correct]]
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
Result: [[Incorrect]]
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Result: [[Correct]]
Expression 1: 3245/5
Expression 2: 649
Result: [[Incorrect]]
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Result: [[Correct]]
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Result: [[Correct]]
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Result: [[Correct]]
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Result: [[Correct]]" or "Result: [[Incorrect]]" (without quotes). Do not include a rationale.
Expression 1: {obj_gold}
Expression 2: {prediction}
""".strip()

# -------------Inference Stage ----------------------------------------
# eval models
models = [*hf_llama3_8b_instruct_model]
# judge models
judge_models = hf_llama3_70b_instruct_model

eng_datasets = [*math_datasets]
chn_datasets = []
datasets = eng_datasets + chn_datasets
work_dir = 'outputs/obj_all/'

for d in eng_datasets:
    d['eval_cfg'] = dict(
        evaluator=dict(
            type=LMEvaluator,
            # If you need to preprocess the prediction before judging,
            # you can specify the pred_postprocessor function here
            pred_postprocessor=dict(type=math_judement_preprocess),
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(round=[
                    dict(
                        role='HUMAN',
                        prompt=eng_obj_prompt
                    ),
                ]),
            ),
        ),
        pred_role="BOT",
    )

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=40000),
    runner=dict(
        type=LocalRunner,
        max_num_workers=256,
        task=dict(type=OpenICLInferTask)),
)

# ------------- Evaluation Configuration --------------------------------
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner, max_task_size=80000, mode='singlescore', models=models, judge_models=judge_models,
    ),
    runner=dict(type=LocalRunner,
                max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
)

summarizer = dict(
    type=AllObjSummarizer
)
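For reference, this config would typically be launched from the repository root with the standard OpenCompass entry point (adjust models and paths to your setup):

python run.py configs/eval_math_llm_judge.py

The infer stage runs the evaluated models on the MATH prompts; the eval stage preprocesses each prediction with math_judement_preprocess, asks the judge model for a [[Correct]]/[[Incorrect]] verdict, and AllObjSummarizer turns those verdicts into accuracy numbers.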
@@ -125,6 +125,15 @@ def normalize_final_answer(final_answer: str) -> str:
     return final_answer
 
 
+ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'
+
+
+def extract_answer(response_text: str):
+    # We suggest returning an empty string rather than None when extraction fails
+    match = re.search(ANSWER_PATTERN, response_text)
+    return match.group(1) if match else ''
+
+
 @LOAD_DATASET.register_module()
 class MATHDataset(BaseDataset):

@@ -156,6 +165,12 @@ def math_postprocess(text: str) -> str:
     # text.split('Final Answer: ', 1)[-1].split('\n\n')[0])
 
 
+@TEXT_POSTPROCESSORS.register_module('math_judement_preprocess')
+def math_judement_preprocess(text: str) -> str:
+    """Preprocess prediction before judgement."""
+    return extract_answer(text)
+
+
 @TEXT_POSTPROCESSORS.register_module('math_postprocess_v2')
 def math_postprocess_v2(text: str) -> str:
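A few illustrative (hypothetical) inputs for the new postprocessor; the pattern is case-insensitive and captures everything after the first "ANSWER:" up to the end of that line:

# math_judement_preprocess('... step-by-step reasoning ...\nANSWER: 3/2')  ->  '3/2'
# math_judement_preprocess('answer: 72 degrees')                           ->  '72 degrees'
# math_judement_preprocess('no final line in the expected format')         ->  ''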
@@ -12,8 +12,6 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.registry import ICL_PROMPT_TEMPLATES
 from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
 from opencompass.utils.logging import get_logger
-from opencompass.utils.text_postprocessors import first_number_postprocess
-from opencompass.utils.types import get_type_from_cfg
 
 
 def extract_dicts(data):

@@ -80,7 +78,7 @@ class LMEvaluator:
         dataset_cfg (ConfigDict, optional): The config of the dataset to be
             evaluated.
         pack_all_predictions (bool, optional): For multiround evaluation, judge all round or judge every single round.
-        postprocessor (ConfigDict): The model prediction's postprocessor
+        pred_postprocessor (ConfigDict): The model prediction's postprocessor
             config.
     """
 

@@ -92,7 +90,7 @@ class LMEvaluator:
         meta_review_prompt_template: Optional[ConfigDict] = None,
         pack_all_predictions: Optional[bool] = False,
         dataset_cfg: Optional[ConfigDict] = None,
-        postprocessor: ConfigDict = dict(type=first_number_postprocess)
+        pred_postprocessor: Optional[ConfigDict] = None,
     ) -> None:
         self.output_path = output_path
         out_dir, out_name = osp.split(output_path)

@@ -112,7 +110,6 @@ class LMEvaluator:
             batch_size=batch_size,
             output_json_filepath=out_dir,
             output_json_filename=out_name)
-        self.postprocessor = get_type_from_cfg(postprocessor)
         self.logger = get_logger()
         self.dataset_cfg = dataset_cfg
         self.pack_all_predictions = pack_all_predictions

@@ -163,7 +160,9 @@ class LMEvaluator:
             ):  #single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
                 for i in range(len(predictions)):
                     key = 'prediction' if i == 0 else f'prediction{i + 1}'
+                    gold_key = 'obj_gold'
                     pred_dict[key] = predictions[i]
+                    pred_dict[gold_key] = references
                 if judgements:
                     for i in range(len(judgements)):
                         key = 'judgement' if i == 0 else f'judgement{i + 1}'

@@ -189,6 +188,10 @@ class LMEvaluator:
             if judgements:
                 raise NotImplementedError(
                     'Not applied meta-reivew judge on multi-round dataset')
+        else:
+            raise NotImplementedError(
+                f'{predictions[0][0]} with type {type(predictions[0][0])}, please check whether the postprocessor applied to the prediction string is correct; we suggest returning an empty string rather than None'
+            )
 
         if self.dataset_cfg:
             dataset = build_dataset_from_cfg(self.dataset_cfg)
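With this change the evaluator hands the reference answers to the judge prompt alongside the predictions. A sketch (hypothetical values) of the dict that fills the template placeholders in the single-prediction case:

# pred_dict = {
#     'prediction': ['(x+1)^2', ...],  # model outputs, already run through math_judement_preprocess
#     'obj_gold': ['(x+1)^2', ...],    # gold solutions from the dataset
# }
# which is what allows eng_obj_prompt above to reference {prediction} and {obj_gold}.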
@@ -1,5 +1,6 @@
 # flake8: noqa: F401, E501
 from .alignmentbench import AlignmentBenchSummarizer
+from .all_obj import AllObjSummarizer
 from .alpacaeval import AlpacaSummarizer
 from .compass_arena import CompassArenaSummarizer
 from .corev2 import Corev2Summarizer
opencompass/summarizers/subjective/all_obj.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime

import numpy as np
from mmengine import ConfigDict
from prettytable import from_csv

from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg

from .utils import get_judgeanswer_and_reference, get_outdir


def post_process_allobj(judgement: str):
    """Input a string like below:

    xxx[[correct]]xxx, and extract the judgement.
    """
    pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
    matched_result = re.findall(pattern, judgement)
    if matched_result:
        content = matched_result[0].lower()
        if content in ['correct', '正确']:
            return {'score': 1}
        elif content in ['incorrect', '错误']:
            return {'score': 0}
    else:
        return None


def get_capability_results(
    judged_answers,
    references,
    fout,
    fout_flag,
    model,
):
    capability_ratings = defaultdict(int)
    capability_counts = defaultdict(int)
    for ans, ref in zip(judged_answers, references):
        capability_ratings['total'] += ans['score']
        capability_counts['total'] += 1

    capability_avg_ratings = defaultdict(float)

    for capability, total_score in capability_ratings.items():
        capability_avg_ratings[
            capability] = total_score / capability_counts[capability]
    columns = list(capability_avg_ratings.keys())
    columns.insert(0, columns.pop(columns.index('total')))
    with open(fout, 'a+', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if fout_flag == 0:
            writer.writerow(['model'] + columns)
        writer.writerow([model] +
                        [capability_avg_ratings[column] for column in columns])


class AllObjSummarizer:
    """Do the subjectivity analysis based on evaluation results.

    Args:
        config (ConfigDict): The configuration object of the evaluation task.
            It's expected to be filled out at runtime.
    """

    def __init__(self, config: ConfigDict, judge_type='single') -> None:
        self.judge_type = judge_type
        self.tasks = []
        self.cfg = config
        if self.judge_type == 'single':
            self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
            self.eval_model_abbrs = [
                model_abbr_from_cfg(model) for model in self.eval_model_cfgs
            ]
        elif self.judge_type == 'pair':
            self.base_models = self.cfg['eval']['partitioner']['base_models']
            self.compare_models = self.cfg['eval']['partitioner'][
                'compare_models']
        self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
        self.judge_map = {'single': post_process_allobj}
        self.judge_function = self.judge_map[self.judge_type]

    def summarize(self,
                  time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
        """Summarize the subjectivity analysis based on evaluation results.

        Args:
            time_str (str): Timestamp for file naming.

        Returns:
            pd.DataFrame: The summary results.
        """
        if self.judge_type == 'single':
            dataset_cfgs = self.cfg['datasets']
            judge_model = self.judge_abbr
            output_dir, results_folder = get_outdir(self.cfg, time_str)
            for dataset in dataset_cfgs:
                dataset_abbr = dataset_abbr_from_cfg(dataset)
                fout = osp.join(
                    output_dir,
                    'judged-by--' + judge_model + '-' + dataset_abbr + '.csv')
                fout_flag = 0
                for eval_model_abbr in self.eval_model_abbrs:
                    subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
                    subdir_path = os.path.join(results_folder, subdir)
                    if os.path.isdir(subdir_path):
                        model = eval_model_abbr
                        judged_answers, references = get_judgeanswer_and_reference(
                            dataset, subdir_path, self.judge_function)
                        get_capability_results(judged_answers, references,
                                               fout, fout_flag, model)
                        fout_flag += 1
                    else:
                        print(subdir_path + ' does not exist! Please check!')
            with open(fout, 'r') as f:
                x = from_csv(f)
            print(x)
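For illustration, hypothetical judge outputs and what post_process_allobj extracts from them:

# post_process_allobj('Result: [[Correct]]')        -> {'score': 1}
# post_process_allobj('Result: [[Incorrect]]')      -> {'score': 0}
# post_process_allobj('the judge gave no verdict')  -> None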
@@ -139,7 +139,8 @@ class SubjectiveEvalTask(BaseTask):
         # If no predictions get in predictions dir
         assert osp.exists(filename) or osp.exists(
             osp.realpath(partial_filename)
-        ), 'No predictions found for {filename}.'.format(filename=filename)
+        ), 'No predictions found for {filename} and {partial_filename}'.format(
+            filename=filename, partial_filename=partial_filename)
 
         # If use Naive partition in infer stage
         if osp.exists(osp.realpath(filename)):

@@ -188,10 +189,14 @@ class SubjectiveEvalTask(BaseTask):
                 if fnmatch.fnmatch(ds_abbr, pattern):
                     pred_postprocessor = model_postprocessors[pattern]
                     break
-        if 'pred_postprocessor' in eval_cfg or pred_postprocessor:
-            kwargs = pred_postprocessor or eval_cfg['pred_postprocessor']
+        if 'pred_postprocessor' in eval_cfg['evaluator'] or pred_postprocessor:
+            kwargs = pred_postprocessor or eval_cfg['evaluator'][
+                'pred_postprocessor']
             proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
+            self.logger.info('Get postprocessor {postprocessor}.')
             pred_strs = [proc(s, **kwargs) for s in pred_strs]
+        else:
+            self.logger.info('No postprocessor found.')
 
         return {
             'model_name': model_abbr_from_cfg(model_cfg),
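The lookup now expects the postprocessor under the evaluator config, which is exactly where the per-dataset eval_cfg in configs/eval_math_llm_judge.py puts it. Schematically:

# eval_cfg = dict(
#     evaluator=dict(
#         type=LMEvaluator,
#         pred_postprocessor=dict(type=math_judement_preprocess),
#         prompt_template=...),
#     pred_role='BOT')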
@@ -77,6 +77,17 @@ def get_config_from_arg(args) -> Config:
         if args.accelerator in ['vllm', 'lmdeploy']:
             config['models'] = change_accelerator(config['models'],
                                                   args.accelerator)
+            if 'eval' in config and 'partitioner' in config['eval']:
+                if 'models' in config['eval']['partitioner']:
+                    config['eval']['partitioner'][
+                        'models'] = change_accelerator(
+                            config['eval']['partitioner']['models'],
+                            args.accelerator)
+                if 'judge_models' in config['eval']['partitioner']:
+                    config['eval']['partitioner'][
+                        'judge_models'] = change_accelerator(
+                            config['eval']['partitioner']['judge_models'],
+                            args.accelerator)
         return config
     # parse dataset args
     if not args.datasets and not args.custom_dataset_path:
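With this change, an accelerator passed on the command line is applied to the judge models in the eval partitioner as well as to the evaluated models. A hypothetical invocation (assuming vLLM is installed):

python run.py configs/eval_math_llm_judge.py --accelerator vllm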