Update Config
commit 45af358798 (parent eac7a6230d)
@@ -12,8 +12,8 @@ from mmengine.config import Config, DictAction
from opencompass.registry import PARTITIONERS, RUNNERS, build_from_cfg
from opencompass.runners import SlurmRunner
from opencompass.summarizers import DefaultSummarizer
-from opencompass.utils import (LarkReporter, get_logger, read_from_station,
-                               save_to_station)
+from opencompass.utils import (LarkReporter, get_logger, pretty_print_config,
+                               read_from_station, save_to_station)
from opencompass.utils.run import (fill_eval_cfg, fill_infer_cfg,
                                   get_config_from_arg)

@@ -94,6 +94,11 @@ def parse_args():
        help='Use the custom config directory instead of config/ to '
        'search the configs for datasets, models and summarizers',
        type=str)
+    parser.add_argument(
+        '--config-verbose',
+        default=False,
+        action='store_true',
+        help='Whether to print the config in verbose mode.')
    parser.add_argument('-l',
                        '--lark',
                        help='Report the running status to lark bot',
@@ -131,7 +136,7 @@ def parse_args():
        'correctness of each sample, bpb, etc.',
        action='store_true',
    )

    # for the results persistence
    parser.add_argument('-sp',
                        '--station-path',
                        help='Path to your results station.',
@@ -150,7 +155,12 @@ def parse_args():
        'data station.',
        action='store_true',
    )

+    # for evaluation with multiple runs
+    parser.add_argument('--dataset-num-runs',
+                        help='How many runs for one dataset',
+                        type=int,
+                        default=1,
+                        )

    # set srun args
    slurm_parser = parser.add_argument_group('slurm_args')
@@ -299,7 +309,10 @@ def main():
        content = f'{getpass.getuser()}\'s task has been launched!'
        LarkReporter(cfg['lark_bot_url']).post(content)

-    logger.info(f'The full config is \n{cfg.pretty_text}')
+    # print config if specified --config-verbose
+    if args.config_verbose:
+        pretty_print_config(cfg)

    # infer
    if args.mode in ['all', 'infer']:

@@ -98,12 +98,12 @@ for sub_set in sub_sets:
    olymmath_datasets.append(
        dict(
            type=OlymMATHDataset,
-            abbr=f'olymmath_llmjudge_{sub_set}',
+            abbr=f'olymmath_{sub_set}',
            path='RUC-AIBOX/OlymMATH',
            reader_cfg=math_reader_cfg,
            infer_cfg=math_infer_cfg,
            eval_cfg=math_eval_cfg,
            subset=sub_set,
-            n=4
+            n=1
        )
    )

@@ -109,8 +109,6 @@ for _name in categories:
            reader_cfg=olympiadbench_reader_cfg,
            infer_cfg=olympiadbench_infer_cfg,
            eval_cfg=olympiadbench_eval_cfg,
-            n=4,
+            n=1,
        )
    )

del _name

@@ -6,7 +6,7 @@ Setting:
    - CascadeEvaluator
        - MATHVerifyEvaluator
        - GenericLLMEvaluator
-Repeat: 32
+Repeat: 1
Available Models:
    - Instruct/Chat Models
"""
@@ -113,6 +113,6 @@ aime2024_datasets = [
        reader_cfg=aime2024_reader_cfg,
        infer_cfg=aime2024_infer_cfg,
        eval_cfg=aime2024_eval_cfg,
-        n=32,  # Evaluate the dataset 32 times
+        n=1,  # Evaluate the dataset n times
    )
]
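
A note on the recurring change in these dataset configs: n is the number of repeated runs per dataset, and those repeats feed pass@k-style metrics (see compute_pass_at_k further down in this diff, which uses scipy.stats.hypergeom). As a rough, self-contained illustration with made-up numbers, assuming the standard unbiased pass@k estimator:

    from math import comb

    def pass_at_k(n: int, c: int, k: int) -> float:
        """Unbiased pass@k estimate from n sampled runs with c correct ones."""
        if n - c < k:              # every size-k draw must contain a correct run
            return 1.0
        return 1.0 - comb(n - c, k) / comb(n, k)

    # Made-up numbers: 32 repeats of one problem, 8 of them correct.
    print(pass_at_k(32, 8, 1))     # 0.25
    print(pass_at_k(32, 8, 4))     # about 0.70

With n=1 only pass@1 is meaningful, which is presumably why the default repeat counts drop to 1 in this commit.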

@@ -6,7 +6,7 @@ Setting:
    - CascadeEvaluator
        - MATHVerifyEvaluator
        - GenericLLMEvaluator
-Repeat: 32
+Repeat: 1
Available Models:
    - Instruct/Chat Models
"""
@@ -66,7 +66,7 @@ GRADER_TEMPLATE = """
Judging the correctness of candidates' answers:
""".strip()

-aime2025_eval_cfg = dict(
+cascade_evaluator = dict(
    type=CascadeEvaluator,
    rule_evaluator=dict(
        type=MATHVerifyEvaluator,
@@ -98,6 +98,9 @@ aime2025_eval_cfg = dict(
    ),
    parallel=False,
)
+aime2025_eval_cfg = dict(
+    evaluator=cascade_evaluator,
+)

aime2025_datasets = [
    dict(
@@ -107,5 +110,6 @@ aime2025_datasets = [
        reader_cfg=aime2025_reader_cfg,
        infer_cfg=aime2025_infer_cfg,
        eval_cfg=aime2025_eval_cfg,
+        n=1,
    )
]

@@ -0,0 +1,118 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GPQADataset, GPQA_Simple_Eval_postprocess
from opencompass.evaluator import GenericLLMEvaluator, CascadeEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.utils.text_postprocessors import match_answer_pattern

# openai_simple_eval prompt
align_prompt = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD.

{question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.

Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.

Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters "A" or "B", with no text around it.

Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.

<Original Question Begin>: {question}\n A) {A}\n B) {B}\n C) {C}\n D) {D}\n<Original Question End>\n\n
<Gold Target Begin>: \n{answer}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{prediction}\n<Predicted End>\n\n
Judging the correctness of candidates' answers:
""".strip()


gpqa_reader_cfg = dict(
    input_columns=['question', 'A', 'B', 'C', 'D'],
    output_column='answer')

gpqa_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt=align_prompt),
            ], )),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))


gpqa_datasets = []
gpqa_subsets = {
    # 'extended': 'gpqa_extended.csv',
    # 'main': 'gpqa_main.csv',
    'diamond': 'gpqa_diamond.csv'
}

for split in list(gpqa_subsets.keys()):
    gpqa_eval_cfg = dict(
        evaluator=dict(
            type=CascadeEvaluator,
            rule_evaluator=dict(
                type=AccEvaluator,
                pred_postprocessor=dict(type=match_answer_pattern, answer_pattern=r'(?i)ANSWER\s*:\s*([A-D])'),
            ),
            llm_evaluator=dict(
                type=GenericLLMEvaluator,
                prompt_template=dict(
                    type=PromptTemplate,
                    template=dict(
                        begin=[
                            dict(
                                role='SYSTEM',
                                fallback_role='HUMAN',
                                prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.")
                        ],
                        round=[
                            dict(
                                role='HUMAN',
                                prompt=GRADER_TEMPLATE
                            ),
                        ]),
                ),
                dataset_cfg=dict(
                    type=GPQADataset,
                    path='./data/gpqa/',
                    name=gpqa_subsets[split],
                    reader_cfg=gpqa_reader_cfg,
                ),
                judge_cfg=dict(),
                dict_postprocessor=dict(type=generic_llmjudge_postprocess),
            ),
            parallel=False,
        ),
    )
    gpqa_datasets.append(
        dict(
            abbr='GPQA_' + split,
            type=GPQADataset,
            path='./data/gpqa/',
            name=gpqa_subsets[split],
            reader_cfg=gpqa_reader_cfg,
            infer_cfg=gpqa_infer_cfg,
            eval_cfg=gpqa_eval_cfg,
            mode='singlescore',
        )
    )
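
A reviewer note on how this GPQA config is intended to behave, as I read CascadeEvaluator with parallel=False: the rule-based AccEvaluator scores every prediction first, and only the samples it rejects are re-judged by the LLM evaluator. A minimal sketch of that cascade with stand-in scoring functions (illustrative only, not the actual implementation):

    from typing import Callable, Dict, List

    def cascade_score(rule_eval: Callable[[str, str], bool],
                      llm_eval: Callable[[str, str], bool],
                      predictions: List[str],
                      references: List[str]) -> Dict:
        """Rule-based check first; fall back to the LLM judge only on failures."""
        details = []
        for pred, ref in zip(predictions, references):
            if rule_eval(pred, ref):          # cheap regex / exact-match pass
                details.append({'correct': True, 'judged_by': 'rule'})
            else:                             # expensive LLM judgment as fallback
                details.append({'correct': llm_eval(pred, ref), 'judged_by': 'llm'})
        accuracy = 100.0 * sum(d['correct'] for d in details) / max(len(details), 1)
        return {'accuracy': accuracy, 'details': details}

    # Toy usage: exact match as the rule, a stub standing in for the LLM judge.
    print(cascade_score(lambda p, r: p.strip() == r,
                        lambda p, r: r in p,
                        ['B', 'The answer is C'],
                        ['B', 'C']))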

@@ -114,6 +114,6 @@ livemathbench_datasets = [
                ),
            ),
        ),
-        n=32,  # repeat 32 times
+        n=1,  # repeat n times
    ) for split in splits
]

@@ -112,6 +112,6 @@ math_datasets = [
        eval_cfg=dict(
            evaluator=cascade_evaluator,
        ),
-        n=4,
+        n=1,
    )
]

@@ -0,0 +1,126 @@
"""
Setting: 0-shot No-CoT
Evaluator: GenericLLMEvaluator
"""
from mmengine.config import read_base
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MMLUDataset
from opencompass.utils.text_postprocessors import match_answer_pattern
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.evaluator import (
    CascadeEvaluator,
    GenericLLMEvaluator,
)

with read_base():
    # from .....configs.datasets.mmlu.mmlu_all_sets import mmlu_all_sets
    from .mmlu_stem_sets import mmlu_all_sets
# None of the MMLU datasets on HuggingFace is parsed correctly, so we use our own dataset reader
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar

QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD.

{input}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.

Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.

Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters "A" or "B", with no text around it.

Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.

<Original Question Begin>: {input}\n A) {A}\n B) {B}\n C) {C}\n D) {D}\n<Original Question End>\n\n
<Gold Target Begin>: \n{target}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{prediction}\n<Predicted End>\n\n
Judging the correctness of candidates' answers:
""".strip()

mmlu_reader_cfg = dict(
    input_columns=['input', 'A', 'B', 'C', 'D'],
    output_column='target',
    train_split='dev')

mmlu_datasets = []
for name in mmlu_all_sets:
    mmlu_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role='HUMAN', prompt=QUERY_TEMPLATE),
                ],
            ),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    mmlu_eval_cfg = dict(
        evaluator=dict(
            type=CascadeEvaluator,
            rule_evaluator=dict(
                type=AccEvaluator,
                pred_postprocessor=dict(type=match_answer_pattern, answer_pattern=r'(?i)ANSWER\s*:\s*([A-D])'),
            ),
            llm_evaluator=dict(
                type=GenericLLMEvaluator,
                prompt_template=dict(
                    type=PromptTemplate,
                    template=dict(
                        begin=[
                            dict(
                                role='SYSTEM',
                                fallback_role='HUMAN',
                                prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.")
                        ],
                        round=[
                            dict(
                                role='HUMAN',
                                prompt=GRADER_TEMPLATE
                            ),
                        ]),
                ),
                dataset_cfg=dict(
                    abbr=f'lukaemon_mmlu_{name}',
                    type=MMLUDataset,
                    path='opencompass/mmlu',
                    name=name,
                    reader_cfg=mmlu_reader_cfg,
                ),
                dict_postprocessor=dict(type=generic_llmjudge_postprocess),
                judge_cfg=dict(),
            ),
        ),
    )

    mmlu_datasets.append(
        dict(
            abbr=f'lukaemon_mmlu_{name}',
            type=MMLUDataset,
            path='opencompass/mmlu',
            name=name,
            reader_cfg=mmlu_reader_cfg,
            infer_cfg=mmlu_infer_cfg,
            eval_cfg=mmlu_eval_cfg,
            mode='singlescore',
        ))

@@ -3,6 +3,9 @@ from typing import Dict, List, Optional, Union
from datasets import Dataset, DatasetDict, concatenate_datasets

from opencompass.openicl import DatasetReader
+from opencompass.utils import get_logger
+
+logger = get_logger()


class BaseDataset:

@@ -76,7 +76,6 @@ class ReviewEvaluator:

            pred_data = data_sample.pred
            if pred_data is not None:
-                # import pdb; pdb.set_trace()
                metrics_result['review_quality'] = 1.0 if pred_data == \
                    data_sample.gt else 0.0
                metrics_result['parse_rate'] = 1.0

@@ -239,6 +239,9 @@ class CascadeEvaluator(BaseEvaluator):

            # Update the details for samples that were evaluated by LLM
            for i, llm_detail in enumerate(llm_details.values()):
+                # Add dataset replica index to LLM evaluation result
+                llm_detail['dataset_replica_idx'] = self.dataset_replica_idx
+
                original_index = failed_indices[i]
                # Store original rule-based evaluation result
                rule_result = details[original_index].copy()

@@ -99,7 +99,6 @@ class GenericLLMEvaluator(BaseEvaluator):
        assert len(predictions) == len(
            references), 'predictions and references must have the same length'

-        # import pdb;pdb.set_trace()
        # -------------- Build Inferencer ----------------
        self.build_inferencer()
        # ---------------- Process Predictions ------------------

@@ -8,6 +8,8 @@ import numpy as np
from datasets import Dataset
from scipy.stats import hypergeom

+from opencompass.registry import TEXT_POSTPROCESSORS
+

def compute_pass_at_k(n, c, k):
    if n - c < k:
@@ -39,8 +41,8 @@ def compute_mg_pass_at_k(n, c, k):

class BaseEvaluator:

-    def __init__(self) -> None:
-        pass
+    def __init__(self, pred_postprocessor=None) -> None:
+        self.pred_postprocessor = pred_postprocessor

    @property
    def output_dir(self):
@@ -86,6 +88,14 @@ class BaseEvaluator:
                [detail[metric] for detail in details])
        return g_passk_details

+    def pred_postprocess(self, predictions: List) -> Dict:
+        if self.pred_postprocessor is None:
+            return predictions
+        else:
+            kwargs = self.pred_postprocessor
+            proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
+            return [proc(pred, **kwargs) for pred in predictions]
+
    def evaluate(
        self,
        k: Union[int, List[int]],
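
For reviewers skimming the BaseEvaluator change: pred_postprocessor is a config dict whose 'type' is resolved through the TEXT_POSTPROCESSORS registry and whose remaining keys become kwargs for the resolved function, applied to every prediction before scoring. A rough standalone illustration of that dispatch (stand-in registry and postprocessor, not the real opencompass objects):

    import re
    from typing import Callable, Dict, List

    def match_answer_pattern(text: str, answer_pattern: str) -> str:
        """Return the first captured group of answer_pattern, or '' if nothing matches."""
        m = re.search(answer_pattern, text)
        return m.group(1) if m else ''

    # Stand-in for the TEXT_POSTPROCESSORS registry: name -> callable.
    TEXT_POSTPROCESSORS: Dict[str, Callable] = {
        'match_answer_pattern': match_answer_pattern,
    }

    def pred_postprocess(predictions: List[str], pred_postprocessor=None) -> List[str]:
        if pred_postprocessor is None:
            return predictions
        kwargs = dict(pred_postprocessor)     # copy so pop() does not mutate the config
        proc = TEXT_POSTPROCESSORS[kwargs.pop('type')]
        return [proc(pred, **kwargs) for pred in predictions]

    preds = ['Reasoning ... ANSWER: C', 'I will go with ANSWER: a']
    cfg = dict(type='match_answer_pattern', answer_pattern=r'(?i)ANSWER\s*:\s*([A-D])')
    print(pred_postprocess(preds, cfg))       # ['C', 'a']

One thing worth double-checking in the patch itself: kwargs = self.pred_postprocessor followed by kwargs.pop('type') mutates the stored config in place, so a second call on the same evaluator instance would hit a missing 'type' key; copying the dict first (as in the sketch) would avoid that.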

@@ -1,10 +1,11 @@
import os
import random
-from typing import List
+from typing import List, Optional

import evaluate
import numpy as np
from datasets import Dataset
+from mmengine.config import ConfigDict

from opencompass.registry import ICL_EVALUATORS

@@ -19,12 +20,16 @@ class HuggingfaceEvaluator(BaseEvaluator):
        seed (int): There exists some randomness during the calculation of some
            metrics, thus we set a fixed random seed for reproducing. Defaults
            to 0.
+        pred_postprocessor (optional): Function or config used to post-process predictions.
    """

-    def __init__(self, metric: str, seed: int = 0) -> None:
+    def __init__(self,
+                 metric: str,
+                 seed: int = 0,
+                 pred_postprocessor=None) -> None:
        self.metric = metric
        self.seed = seed
-        super().__init__()
+        super().__init__(pred_postprocessor=pred_postprocessor)

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.
@@ -37,7 +42,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
            dict: preprocessed results.
        """
        return {
-            'predictions': predictions,
+            'predictions': self.pred_postprocess(predictions),
            'references': references,
        }

@@ -92,8 +97,10 @@ class HuggingfaceEvaluator(BaseEvaluator):
class AccEvaluator(HuggingfaceEvaluator):
    """Accuracy evaluator."""

-    def __init__(self) -> None:
-        super().__init__(metric='accuracy')
+    def __init__(self,
+                 pred_postprocessor: Optional[ConfigDict] = None) -> None:
+        super().__init__(metric='accuracy',
+                         pred_postprocessor=pred_postprocessor)

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.
@@ -187,8 +194,9 @@ class RougeEvaluator(HuggingfaceEvaluator):
    Note: this evaluator is not suitable for chinese datasets.
    """

-    def __init__(self) -> None:
-        super().__init__(metric='rouge')
+    def __init__(self,
+                 pred_postprocessor: Optional[ConfigDict] = None) -> None:
+        super().__init__(metric='rouge', pred_postprocessor=pred_postprocessor)

    def _postprocess(self, scores: dict) -> dict:
        """Postprocess for final scores.
@@ -206,8 +214,10 @@ class BleuEvaluator(HuggingfaceEvaluator):
class BleuEvaluator(HuggingfaceEvaluator):
    """Bleu evaluator."""

-    def __init__(self) -> None:
-        super().__init__(metric='sacrebleu')
+    def __init__(self,
+                 pred_postprocessor: Optional[ConfigDict] = None) -> None:
+        super().__init__(metric='sacrebleu',
+                         pred_postprocessor=pred_postprocessor)


class BleuFloresEvaluator(HuggingfaceEvaluator):
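
Taken together with the BaseEvaluator change above, the practical payoff is that rule evaluators can now be handed a postprocessor, exactly as the GPQA and MMLU cascade configs earlier in this diff do. For reference, the shape of such a config (copied in spirit from those configs; treat it as illustrative rather than canonical):

    from opencompass.openicl.icl_evaluator import AccEvaluator
    from opencompass.utils.text_postprocessors import match_answer_pattern

    # Rule stage of a cascade: extract the answer letter before exact-match accuracy.
    rule_evaluator = dict(
        type=AccEvaluator,
        pred_postprocessor=dict(
            type=match_answer_pattern,
            answer_pattern=r'(?i)ANSWER\s*:\s*([A-D])',
        ),
    )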

@@ -26,6 +26,7 @@ class NumWorkerPartitioner(BasePartitioner):
        dataset_size_path (str): The path to the dataset size cache file.
        keep_keys (list[str]): The keys to be kept from the experiment config
            to the task config.
+        force_rebuild (bool): Whether to force rebuild dataset to get size.
    """

    def __init__(self,
@@ -35,7 +36,8 @@ class NumWorkerPartitioner(BasePartitioner):
                 min_task_size: int = 16,
                 strategy: str = 'heuristic',
                 dataset_size_path: str = '.cache/dataset_size.json',
-                 keep_keys: Optional[List[str]] = None):
+                 keep_keys: Optional[List[str]] = None,
+                 force_rebuild: bool = False):
        super().__init__(out_dir=out_dir, keep_keys=keep_keys)
        if strategy == 'split' and num_worker is not None:
            self.logger.warning('num_worker is ignored with split.')
@@ -44,6 +46,7 @@ class NumWorkerPartitioner(BasePartitioner):
        self.num_split = num_split or num_worker
        self.min_task_size = min_task_size
        self.dataset_size_path = dataset_size_path
+        self.force_rebuild = force_rebuild
        assert strategy in ('heuristic', 'split'), \
            f'Unsupported partition strategy: {strategy}. '\
            'Supported strategies are: `heuristic`, `split` .'
@@ -106,7 +109,7 @@ class NumWorkerPartitioner(BasePartitioner):
    @property
    def dataset_size(self):
        if not hasattr(self, '_dataset_size'):
-            if osp.exists(self.dataset_size_path):
+            if not self.force_rebuild and osp.exists(self.dataset_size_path):
                self._dataset_size = mmengine.load(self.dataset_size_path)
            else:
                self._dataset_size = {}
@@ -130,22 +133,25 @@ class NumWorkerPartitioner(BasePartitioner):

    def get_size(self, dataset: ConfigDict) -> int:
        dataset_abbr = dataset_abbr_from_cfg(dataset)

        test_range = dataset.reader_cfg.get('test_range', '')

-        if dataset_abbr in self.dataset_size:
+        # Use the cache only when rebuild is not forced and the size is cached
+        if not self.force_rebuild and dataset_abbr in self.dataset_size:
            actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
                               f'{test_range})')
            return actual_size

+        # Otherwise rebuild the dataset to obtain its size
        dataset = build_dataset_from_cfg(dataset)
        self.dataset_size[dataset_abbr] = len(dataset.test)

-        mmengine.mkdir_or_exist('.cache/')
-        mmengine.dump(self.dataset_size,
-                      self.dataset_size_path,
-                      indent=4,
-                      ensure_ascii=False)
+        # Persist the size to the cache file
+        if self.dataset_size_path:
+            mmengine.mkdir_or_exist('.cache/')
+            mmengine.dump(self.dataset_size,
+                          self.dataset_size_path,
+                          indent=4,
+                          ensure_ascii=False)

        actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
                           f'{test_range})')
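
My reading of the new force_rebuild flag: when set, the partitioner ignores the cached .cache/dataset_size.json and re-measures every dataset before splitting work. A hypothetical infer-stage snippet enabling it (other partitioner and runner keys omitted):

    from opencompass.partitioners import NumWorkerPartitioner

    infer = dict(
        partitioner=dict(
            type=NumWorkerPartitioner,
            num_worker=8,
            force_rebuild=True,   # new in this commit: bypass the size cache
        ),
        # runner=...  (unchanged, omitted)
    )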

@@ -2,6 +2,8 @@ import logging
import os

from mmengine.logging import MMLogger
+from rich.console import Console
+from rich.syntax import Syntax

_nameToLevel = {
    'CRITICAL': logging.CRITICAL,
@@ -79,3 +81,14 @@ class FilterDuplicateMessage(logging.Filter):
                self.seen.add(record.msg)
                return True
            return False
+
+
+def pretty_print_config(cfg):
+    """Pretty-print a config using the rich library."""
+    console = Console()
+    config_str = cfg.pretty_text
+    syntax = Syntax(config_str,
+                    'python',
+                    theme='solarized-dark',
+                    line_numbers=True)
+    console.print(syntax)
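
For anyone who wants to try the new --config-verbose path in isolation, a minimal sketch of what pretty_print_config does with a made-up config (assumes the rich package is installed):

    from mmengine.config import Config
    from opencompass.utils import pretty_print_config

    # Made-up config purely for demonstration.
    cfg = Config(dict(models=[dict(abbr='demo-model', max_out_len=1024)],
                      datasets=[dict(abbr='demo-dataset', n=1)]))

    # Renders cfg.pretty_text as syntax-highlighted Python with line numbers.
    pretty_print_config(cfg)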

@@ -150,6 +150,12 @@ def get_config_from_arg(args) -> Config:
                dataset['meta_path'] = args.custom_dataset_meta_path
                dataset = make_custom_dataset_config(dataset)
            datasets.append(dataset)
+    ## apply the dataset repeat runs
+    if len(datasets) > 0 and args.dataset_num_runs > 1:
+        logger.warning(f'User has set the --dataset-num-runs, the datasets will be evaluated with {args.dataset_num_runs} runs.')
+        for _dataset in datasets:
+            logger.warning(f"The default num runs of {_dataset['abbr']} is: {_dataset['n']}, changed into: {args.dataset_num_runs}")
+            _dataset['n'] = args.dataset_num_runs

    # parse model args
    if not args.models and not args.hf_path:
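
End to end, passing --dataset-num-runs N on the command line overrides the n field of every loaded dataset config, which then drives the repeated evaluation shown in the dataset diffs above. A small stand-alone mimic of the override (the dataset dicts are invented):

    from argparse import Namespace

    # Invented stand-ins for the dataset configs assembled by get_config_from_arg().
    datasets = [dict(abbr='aime2025', n=1), dict(abbr='GPQA_diamond', n=1)]
    args = Namespace(dataset_num_runs=4)      # i.e. --dataset-num-runs 4

    if len(datasets) > 0 and args.dataset_num_runs > 1:
        for _dataset in datasets:
            _dataset['n'] = args.dataset_num_runs

    print(datasets)   # both entries now carry n=4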