Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Sync] Update LongEval (#443)

parent 2bb7beeca3
commit 3bb3d330eb
@@ -27,7 +27,7 @@ LEval_financialqa_infer_cfg = dict(
 )

 LEval_financialqa_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )
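Note: every LEval subset below gets the same one-line change, swapping the GPT-judge evaluator for ROUGE scoring. As a rough sketch of what that metric computes, here is the ROUGE-L F-measure call pattern the repo's Rouge-based evaluators use (the two strings are made-up placeholders):

    from rouge import Rouge

    rouge = Rouge()
    prediction = 'the company reported higher quarterly revenue'
    reference = 'quarterly revenue was higher than last year'
    # get_scores returns rouge-1/rouge-2/rouge-l dicts; the evaluators
    # keep the F-measure of rouge-l.
    score = rouge.get_scores([prediction], [reference], avg=True)['rouge-l']['f']
    print(round(score, 3))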
@@ -27,7 +27,7 @@ LEval_govreport_summ_infer_cfg = dict(
 )

 LEval_govreport_summ_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_legalqa_infer_cfg = dict(
 )

 LEval_legalqa_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_meetingsumm_infer_cfg = dict(
 )

 LEval_meetingsumm_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_narrativeqa_infer_cfg = dict(
 )

 LEval_narrativeqa_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator,),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_nq_infer_cfg = dict(
 )

 LEval_nq_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_newssumm_infer_cfg = dict(
 )

 LEval_newssumm_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_ps_summ_infer_cfg = dict(
 )

 LEval_ps_summ_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
    pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_patent_summ_infer_cfg = dict(
 )

 LEval_patent_summ_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_review_summ_infer_cfg = dict(
 )

 LEval_review_summ_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )

@@ -27,7 +27,7 @@ LEval_scientificqa_infer_cfg = dict(
 )

 LEval_scientificqa_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )
@@ -2,7 +2,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.openicl.icl_evaluator import EMEvaluator, RougeEvaluator, SquadEvaluator, AccEvaluator
-from opencompass.datasets.leval import LEvalTopicRetrievalDataset
+from opencompass.datasets.leval import LEvalTopicRetrievalDataset, LEvalEMEvaluator
 from opencompass.utils.text_postprocessors import first_capital_postprocess, first_capital_postprocess_multi, general_postprocess

 LEval_tr_reader_cfg = dict(
@@ -28,7 +28,7 @@ LEval_tr_infer_cfg = dict(
 )

 LEval_tr_eval_cfg = dict(
-    evaluator=dict(type=EMEvaluator),
+    evaluator=dict(type=LEvalEMEvaluator),
     pred_postprocessor=dict(type=general_postprocess),
     pred_role='BOT'
 )
@@ -27,7 +27,7 @@ LEval_tvshow_summ_infer_cfg = dict(
 )

 LEval_tvshow_summ_eval_cfg = dict(
-    evaluator=dict(type=LEvalGPTEvaluator),
+    evaluator=dict(type=RougeEvaluator),
     pred_role='BOT'
 )
@@ -7,7 +7,6 @@ with read_base():
     from .longbenchmultifieldqa_en.longbench_multifieldqa_en_gen import LongBench_multifieldqa_en_datasets
     from .longbenchmultifieldqa_zh.longbench_multifieldqa_zh_gen import LongBench_multifieldqa_zh_datasets
     from .longbenchnarrativeqa.longbench_narrativeqa_gen import LongBench_narrativeqa_datasets
-    from .longbenchnq.longbench_nq_gen import LongBench_nq_datasets
     from .longbenchqasper.longbench_qasper_gen import LongBench_qasper_datasets
     from .longbenchtriviaqa.longbench_triviaqa_gen import LongBench_triviaqa_datasets
     from .longbenchgov_report.longbench_gov_report_gen import LongBench_gov_report_datasets
@@ -21,5 +20,7 @@ with read_base():
     from .longbenchpassage_count.longbench_passage_count_gen import LongBench_passage_count_datasets
     from .longbenchtrec.longbench_trec_gen import LongBench_trec_datasets
     from .longbenchlsht.longbench_lsht_gen import LongBench_lsht_datasets
+    from .longbenchmulti_news.longbench_multi_news_gen import LongBench_multi_news_datasets
+    from .longbenchsamsum.longbench_samsum_gen import LongBench_samsum_datasets

 longbench_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
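The aggregation line above uses sum() with an empty-list start value to concatenate every *_datasets list pulled in by the imports. A toy, self-contained illustration of the idiom (the dict stands in for the config's locals()):

    scope = {'a_datasets': [1, 2], 'b_datasets': [3], 'ignored': [9]}
    merged = sum((v for k, v in scope.items() if k.endswith('_datasets')), [])
    print(merged)  # [1, 2, 3]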
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .longbench_multi_news_gen_f6e3fb import LongBench_multi_news_datasets  # noqa: F401, F403
@@ -0,0 +1,38 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import LongBenchRougeEvaluator, LongBenchmulti_newsDataset
+
+LongBench_multi_news_reader_cfg = dict(
+    input_columns=['context'],
+    output_column='answers',
+    train_split='test',
+    test_split='test'
+)
+
+LongBench_multi_news_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(role='HUMAN', prompt='You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:'),
+            ], )),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer, max_out_len=512)
+)
+
+LongBench_multi_news_eval_cfg = dict(
+    evaluator=dict(type=LongBenchRougeEvaluator),
+    pred_role='BOT'
+)
+
+LongBench_multi_news_datasets = [
+    dict(
+        type=LongBenchmulti_newsDataset,
+        abbr='LongBench_multi_news',
+        path='THUDM/LongBench',
+        name='multi_news',
+        reader_cfg=LongBench_multi_news_reader_cfg,
+        infer_cfg=LongBench_multi_news_infer_cfg,
+        eval_cfg=LongBench_multi_news_eval_cfg)
+]
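A sketch of how the new config could be selected in an evaluation entry point, following the same read_base convention as the files above (the relative import path assumes the config lives under configs/; adjust to your layout):

    from mmengine.config import read_base

    with read_base():
        from .datasets.longbench.longbenchmulti_news.longbench_multi_news_gen import \
            LongBench_multi_news_datasets

    datasets = [*LongBench_multi_news_datasets]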
@@ -1,4 +0,0 @@
-from mmengine.config import read_base
-
-with read_base():
-    from .longbench_nq_gen_d30cb9 import LongBench_nq_datasets  # noqa: F401, F403
@@ -1,38 +0,0 @@
-from opencompass.openicl.icl_prompt_template import PromptTemplate
-from opencompass.openicl.icl_retriever import ZeroRetriever
-from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchF1Evaluator, LongBenchnqDataset
-
-LongBench_nq_reader_cfg = dict(
-    input_columns=['context', 'input'],
-    output_column='answers',
-    train_split='test',
-    test_split='test'
-)
-
-LongBench_nq_infer_cfg = dict(
-    prompt_template=dict(
-        type=PromptTemplate,
-        template=dict(
-            round=[
-                dict(role='HUMAN', prompt='Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}'),
-            ], )),
-    retriever=dict(type=ZeroRetriever),
-    inferencer=dict(type=GenInferencer, max_out_len=32)
-)
-
-LongBench_nq_eval_cfg = dict(
-    evaluator=dict(type=LongBenchF1Evaluator),
-    pred_role='BOT'
-)
-
-LongBench_nq_datasets = [
-    dict(
-        type=LongBenchnqDataset,
-        abbr='LongBench_nq',
-        path='THUDM/LongBench',
-        name='nq',
-        reader_cfg=LongBench_nq_reader_cfg,
-        infer_cfg=LongBench_nq_infer_cfg,
-        eval_cfg=LongBench_nq_eval_cfg)
-]
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .longbench_samsum_gen_f4416d import LongBench_samsum_datasets  # noqa: F401, F403
@@ -0,0 +1,38 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import LongBenchRougeEvaluator, LongBenchsamsumDataset
+
+LongBench_samsum_reader_cfg = dict(
+    input_columns=['context', 'input'],
+    output_column='answers',
+    train_split='test',
+    test_split='test'
+)
+
+LongBench_samsum_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(role='HUMAN', prompt='Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}'),
+            ], )),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer, max_out_len=128)
+)
+
+LongBench_samsum_eval_cfg = dict(
+    evaluator=dict(type=LongBenchRougeEvaluator),
+    pred_role='BOT'
+)
+
+LongBench_samsum_datasets = [
+    dict(
+        type=LongBenchsamsumDataset,
+        abbr='LongBench_samsum',
+        path='THUDM/LongBench',
+        name='samsum',
+        reader_cfg=LongBench_samsum_reader_cfg,
+        infer_cfg=LongBench_samsum_infer_cfg,
+        eval_cfg=LongBench_samsum_eval_cfg)
+]
@@ -13,19 +13,20 @@ summarizer = dict(
         '--------- LongBench Summarization ---------', # category
         'LongBench_gov_report',
         'LongBench_qmsum',
+        'LongBench_multi_news',
         'LongBench_vcsum',
         '--------- LongBench Few-shot Learning ---------', # category
         'LongBench_trec',
-        'LongBench_nq',
         'LongBench_triviaqa',
+        'LongBench_samsum',
         'LongBench_lsht',
+        '--------- LongBench Synthetic Tasks ---------', # category
+        'LongBench_passage_count',
+        'LongBench_passage_retrieval_en',
+        'LongBench_passage_retrieval_zh',
         '--------- LongBench Code Completion ---------', # category
         'LongBench_lcc',
         'LongBench_repobench-p',
-        '--------- LongBench Synthetic Tasks ---------', # category
-        'LongBench_passage_retrieval_en',
-        'LongBench_passage_count',
-        'LongBench_passage_retrieval_zh',
     ],
     summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
     prompt_db=dict(
@@ -1,3 +1,4 @@
+from .evaluators import LEvalEMEvaluator  # noqa: F401, F403
 from .evaluators import LEvalGPTEvaluator  # noqa: F401, F403
 from .leval_coursera import *  # noqa: F401, F403
 from .leval_financial_qa import *  # noqa: F401, F403
@@ -4,6 +4,7 @@ from typing import List
 from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.registry import ICL_EVALUATORS
 from opencompass.utils.prompt import PromptList
+from opencompass.utils.text_postprocessors import general_postprocess


 @ICL_EVALUATORS.register_module()
@@ -107,3 +108,32 @@ class LEvalGPTEvaluator(BaseEvaluator):

         score = score / (num_samples - bad_case) * 100
         return {'score': score}
+
+
+@ICL_EVALUATORS.register_module()
+class LEvalEMEvaluator(BaseEvaluator):
+    """Exact match evaluator."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def score(self, predictions, references):
+        if len(predictions) != len(references):
+            return {
+                'error': 'predictions and references have different '
+                'length'
+            }
+        predictions = [
+            general_postprocess(prediction) for prediction in predictions
+        ]
+        processed_answers = [general_postprocess(i) for i in references]
+
+        cnt = 0
+        for pred, ans, origin_ans in zip(predictions, processed_answers,
+                                         references):
+            if ans in pred or origin_ans in pred:
+                cnt += 1
+
+        score = cnt / len(predictions) * 100
+
+        return {'score': score}
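Unlike the stock EMEvaluator it replaces in the topic-retrieval config, LEvalEMEvaluator counts a hit when the normalized answer is contained in the prediction, not only when the two match exactly. A small usage sketch (the expected output assumes general_postprocess lowercases and strips punctuation, as in opencompass.utils):

    from opencompass.datasets.leval import LEvalEMEvaluator

    evaluator = LEvalEMEvaluator()
    result = evaluator.score(
        predictions=['The answer is B.', 'Paris'],
        references=['B', 'London'],
    )
    # Only the first prediction contains its reference answer.
    print(result)  # {'score': 50.0}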
@@ -10,17 +10,18 @@ from .longbench_gov_report import *  # noqa: F401, F403
 from .longbench_hotpot_qa import *  # noqa: F401, F403
 from .longbench_lcc import *  # noqa: F401, F403
 from .longbench_lsht import *  # noqa: F401, F403
+from .longbench_multi_news import *  # noqa: F401, F403
 from .longbench_multifieldqa_en import *  # noqa: F401, F403
 from .longbench_multifieldqa_zh import *  # noqa: F401, F403
 from .longbench_musique import *  # noqa: F401, F403
 from .longbench_narrative_qa import *  # noqa: F401, F403
-from .longbench_nq import *  # noqa: F401, F403
 from .longbench_passage_count import *  # noqa: F401, F403
 from .longbench_passage_retrieval_en import *  # noqa: F401, F403
 from .longbench_passage_retrieval_zh import *  # noqa: F401, F403
 from .longbench_qasper import *  # noqa: F401, F403
 from .longbench_qmsum import *  # noqa: F401, F403
 from .longbench_repobench import *  # noqa: F401, F403
+from .longbench_samsum import *  # noqa: F401, F403
 from .longbench_trec import *  # noqa: F401, F403
 from .longbench_trivia_qa import *  # noqa: F401, F403
 from .longbench_vcsum import *  # noqa: F401, F403
@@ -189,10 +189,10 @@ class LongBenchRougeEvaluator(BaseEvaluator):
                 list(jieba.cut(reference, cut_all=False)))

             rouge = Rouge()
-            if prediction != '':
+            try:
                 cur_score = rouge.get_scores([prediction], [reference],
                                              avg=True)['rouge-l']['f']
-            else:
+            except Exception:
                 cur_score = 0.
             task_score = max(task_score, cur_score)
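The broader try/except matters because the old empty-string guard let through predictions that are non-empty yet have no scoreable tokens (for example, bare punctuation), and rouge.get_scores can still raise on those. A hedged sketch of the failure mode this protects against:

    from rouge import Rouge

    rouge = Rouge()
    try:
        # '.' passes the old `if prediction != ''` check, but may still
        # raise inside get_scores once tokenization leaves nothing behind.
        score = rouge.get_scores(['.'], ['some reference text'],
                                 avg=True)['rouge-l']['f']
    except Exception:
        score = 0.
    print(score)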
opencompass/datasets/longbench/longbench_multi_news.py (new file, +21)
@@ -0,0 +1,21 @@
+from datasets import Dataset, load_dataset
+
+from opencompass.registry import LOAD_DATASET
+
+from ..base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class LongBenchmulti_newsDataset(BaseDataset):
+
+    @staticmethod
+    def load(**kwargs):
+        dataset = load_dataset(**kwargs)
+        split = 'test'
+        raw_data = []
+        for i in range(len(dataset[split])):
+            context = dataset[split]['context'][i]
+            answers = dataset[split]['answers'][i]
+            raw_data.append({'context': context, 'answers': answers})
+        dataset[split] = Dataset.from_list(raw_data)
+        return dataset
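A sketch of what the loader yields, assuming the THUDM/LongBench dataset on the Hugging Face hub is reachable (path and name mirror the config added earlier in this commit):

    from opencompass.datasets import LongBenchmulti_newsDataset

    ds = LongBenchmulti_newsDataset.load(path='THUDM/LongBench',
                                         name='multi_news')
    print(ds['test'].column_names)  # ['context', 'answers']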
@@ -6,7 +6,7 @@ from ..base import BaseDataset


 @LOAD_DATASET.register_module()
-class LongBenchnqDataset(BaseDataset):
+class LongBenchsamsumDataset(BaseDataset):

     @staticmethod
     def load(**kwargs):
@@ -42,6 +42,9 @@ class HuggingFace(BaseModel):
             without batch padding.
         pad_token_id (int): The id of the padding token. Defaults to None. Use
             (#vocab + pad_token_id) if get negative value.
+        mode (str, optional): The method of input truncation when input length
+            exceeds max_seq_len. 'mid' represents the part of input to
+            truncate. Defaults to 'none'.

     Note:
         About ``extract_pred_after_decode``: Commonly, we should extract the
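A sketch of a model entry opting into the new truncation mode (the model path and abbr are placeholders, and the field set is a typical HuggingFace wrapper config rather than a verified minimum):

    from opencompass.models import HuggingFace

    models = [
        dict(
            type=HuggingFace,
            abbr='llama-7b-hf',          # placeholder label
            path='huggyllama/llama-7b',  # placeholder model path
            max_seq_len=2048,
            max_out_len=512,
            batch_size=8,
            mode='mid',  # keep head and tail of over-long inputs
            run_cfg=dict(num_gpus=1),
        )
    ]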
@@ -62,7 +65,8 @@ class HuggingFace(BaseModel):
                  meta_template: Optional[Dict] = None,
                  extract_pred_after_decode: bool = False,
                  batch_padding: bool = False,
-                 pad_token_id: Optional[int] = None):
+                 pad_token_id: Optional[int] = None,
+                 mode: str = 'none'):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
@@ -73,6 +77,8 @@ class HuggingFace(BaseModel):
             patch_hf_auto_model(hf_cache_dir)
         self.logger = get_logger()
         self.pad_token_id = pad_token_id
+        assert mode in ['none', 'mid']
+        self.mode = mode
         self._load_tokenizer(path=path,
                              tokenizer_path=tokenizer_path,
                              tokenizer_kwargs=tokenizer_kwargs)
@@ -228,6 +234,18 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]

+        if self.mode == 'mid':
+            input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
+            input_ids = torch.tensor(input_ids, device=self.model.device)
+            if len(input_ids[0]) > self.max_seq_len - max_out_len:
+                half = int((self.max_seq_len - max_out_len) / 2)
+                inputs = [
+                    self.tokenizer.decode(input_ids[0][:half],
+                                          skip_special_tokens=True) +
+                    self.tokenizer.decode(input_ids[0][-half:],
+                                          skip_special_tokens=True)
+                ]
+
         input_ids = self.tokenizer(inputs,
                                    truncation=True,
                                    max_length=self.max_seq_len -
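The truncation logic above, reduced to a self-contained sketch: plain integers stand in for tokenizer output, and the budget is max_seq_len - max_out_len, split evenly between the head and the tail of the input:

    def truncate_mid(token_ids, max_seq_len, max_out_len):
        # Keep the head and the tail; drop the middle of over-long inputs.
        budget = max_seq_len - max_out_len
        if len(token_ids) <= budget:
            return token_ids
        half = budget // 2
        return token_ids[:half] + token_ids[-half:]

    print(truncate_mid(list(range(10)), max_seq_len=8, max_out_len=2))
    # -> [0, 1, 2, 7, 8, 9]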