Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Sync] Fix cmnli, fix vicuna meta template, fix longbench postprocess and other minor fixes (#625)

Parent: 5329724b65
Commit: d4d1330a5a
cmnli dataset configs: load the CLUE cmnli dev set through the new `cmnliDataset` loader instead of the generic `HFDataset` json loader. The same change is made in three cmnli configs; the first is shown below, and the other two differ only in the position of the second hunk (`@@ -41,11 +41,9 @@` and `@@ -45,11 +45,9 @@`).

```diff
@@ -2,7 +2,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
 from opencompass.openicl.icl_evaluator import AccEvaluator
-from opencompass.datasets import HFDataset
+from opencompass.datasets import cmnliDataset
 
 cmnli_reader_cfg = dict(
     input_columns=['sentence1', 'sentence2'],
@@ -25,11 +25,9 @@ cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
 
 cmnli_datasets = [
     dict(
-        type=HFDataset,
-        abbr='cmnli',
-        path='json',
-        split='train',
-        data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
+        abbr="cmnli",
+        type=cmnliDataset,
+        path='./data/CLUE/cmnli/cmnli_public/dev.json',
         reader_cfg=cmnli_reader_cfg,
         infer_cfg=cmnli_infer_cfg,
         eval_cfg=cmnli_eval_cfg)
```
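The `path` now points straight at the raw JSONL file, and the reader only relies on the `sentence1`/`sentence2` columns declared above. A minimal sanity check of that file, as a hypothetical snippet (any field beyond the two input columns, such as `label`, is an assumption here, not something shown in this commit):

```python
import json

# Hypothetical check of the JSONL file the updated cmnli configs point at.
# 'sentence1' and 'sentence2' come from cmnli_reader_cfg; 'label' is assumed.
with open('./data/CLUE/cmnli/cmnli_public/dev.json', 'r', encoding='utf-8') as f:
    first = json.loads(next(f))  # one JSON object per line
print(sorted(first.keys()))
assert {'sentence1', 'sentence2'} <= set(first)
```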
LongBench generation configs (lsht, samsum, trec, triviaqa): import the new per-dataset postprocess function and register it as `pred_postprocessor` in the eval config. The lsht config is shown in full; the other three make the same two changes with their own names:

```diff
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchlshtDataset
+from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchlshtDataset, lsht_postprocess
 
 LongBench_lsht_reader_cfg = dict(
     input_columns=['context', 'input'],
@@ -23,7 +23,8 @@ LongBench_lsht_infer_cfg = dict(
 
 LongBench_lsht_eval_cfg = dict(
     evaluator=dict(type=LongBenchClassificationEvaluator),
-    pred_role='BOT'
+    pred_role='BOT',
+    pred_postprocessor=dict(type=lsht_postprocess),
 )
 
 LongBench_lsht_datasets = [
```

- samsum: imports `samsum_postprocess` next to `LongBenchRougeEvaluator, LongBenchsamsumDataset` and adds `pred_postprocessor=dict(type=samsum_postprocess)`.
- trec: imports `trec_postprocess` next to `LongBenchClassificationEvaluator, LongBenchtrecDataset` and adds `pred_postprocessor=dict(type=trec_postprocess)`.
- triviaqa: imports `triviaqa_postprocess` next to `LongBenchF1Evaluator, LongBenchtriviaqaDataset` and adds `pred_postprocessor=dict(type=triviaqa_postprocess)`.
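Functionally, each new `pred_postprocessor` trims a raw generation down to its first non-empty line before the evaluator scores it (the functions themselves are added later in this commit). A rough illustration using an invented model output:

```python
def trec_postprocess(text: str) -> str:
    # same one-line body as the postprocessor added in this commit
    text = text.lstrip('\n').split('\n')[0]
    return text

raw = '\nLocation\nExplanation: the question asks where something is.'  # invented output
print(trec_postprocess(raw))  # -> 'Location', which is what the classification evaluator compares
```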
Chat model configs (the vicuna models named in the commit title): enable the fastchat conversation template so prompts are wrapped in the format the model expects. The same line is added to seven model configs; they differ only in `num_gpus` inside `run_cfg` (2, 1, 2, 4, 1, 1, and 1 respectively):

```diff
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False,  # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=2, num_procs=1)
     )
 ]
```
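With `use_fastchat_template=True`, the HuggingFace wrapper (changed further down in this commit) routes every prompt through fastchat's vicuna conversation template rather than sending it verbatim. A small sketch of what that wrapping does, assuming fastchat is installed via `pip install "fschat[model_worker,webui]"`; the prompt string is invented:

```python
from fastchat.model import get_conversation_template

prompt = 'Summarize the following meeting notes ...'  # any OpenCompass prompt
conv = get_conversation_template('vicuna')
conv.append_message(conv.roles[0], prompt)  # USER turn
conv.append_message(conv.roles[1], None)    # empty ASSISTANT turn to be completed
print(conv.get_prompt())
# Roughly: "A chat between a curious user and an artificial intelligence assistant. ...
#           USER: Summarize the following meeting notes ... ASSISTANT:"
```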
tydiqa summary groups: fix a typo in the subset prefix (`tyidqa` -> `tydiqa`).

```diff
@@ -1,5 +1,5 @@
 tydiqa_summary_groups = []
 
 _tydiqa = ['arabic', 'bengali', 'english', 'finnish', 'indonesian', 'japanese', 'korean', 'russian', 'swahili', 'telugu', 'thai']
-_tydiqa = ['tyidqa-goldp_' + s for s in _tydiqa]
+_tydiqa = ['tydiqa-goldp_' + s for s in _tydiqa]
 tydiqa_summary_groups.append({'name': 'tydiqa-goldp', 'subsets': _tydiqa})
```
CMBDataset: set the placeholder `answer = 'NULL'` on the validation split instead of the test split.

```diff
@@ -18,6 +18,7 @@ class CMBDataset(BaseDataset):
         for d in val_data:
             d['option_str'] = '\n'.join(
                 [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1])
+            d['answer'] = 'NULL'
         val_dataset = Dataset.from_list(val_data)
 
         with open(osp.join(path, 'test.json'), 'r', encoding='utf-8') as f:
@@ -25,7 +26,6 @@ class CMBDataset(BaseDataset):
         for d in test_data:
             d['option_str'] = '\n'.join(
                 [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1])
-            d['answer'] = 'NULL'
         test_dataset = Dataset.from_list(test_data)
 
         return DatasetDict({'val': val_dataset, 'test': test_dataset})
```
cmnli dataset module: add the `cmnliDataset` loader that the configs above now reference; it reads the CLUE cmnli JSONL file line by line.

```diff
@@ -7,6 +7,19 @@ from opencompass.registry import LOAD_DATASET
 from .base import BaseDataset
 
 
+@LOAD_DATASET.register_module()
+class cmnliDataset(BaseDataset):
+
+    @staticmethod
+    def load(path):
+        data = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = json.loads(line)
+                data.append(line)
+        return Dataset.from_list(data)
+
+
 @LOAD_DATASET.register_module()
 class cmnliDataset_V2(BaseDataset):
 
```
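A minimal sketch of calling the new loader directly, assuming the dev set sits at the path used by the configs above:

```python
from opencompass.datasets import cmnliDataset

# Load the CLUE cmnli dev split exactly as the updated configs do.
ds = cmnliDataset.load('./data/CLUE/cmnli/cmnli_public/dev.json')
print(len(ds), ds.column_names)              # a datasets.Dataset built from the JSONL rows
print(ds[0]['sentence1'], ds[0]['sentence2'])
```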
LongBench dataset modules (lsht, samsum, trec, triviaqa): register a text postprocessor per dataset that keeps only the first line of a generation. The lsht module is shown in full; samsum, trec and triviaqa get the same registry import change and an identically-shaped `samsum_postprocess`, `trec_postprocess` and `triviaqa_postprocess` appended after their `LongBenchsamsumDataset`, `LongBenchtrecDataset` and `LongBenchtriviaqaDataset` classes.

```diff
@@ -1,6 +1,6 @@
 from datasets import Dataset, load_dataset
 
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from ..base import BaseDataset
 
@@ -28,3 +28,9 @@ class LongBenchlshtDataset(BaseDataset):
                 })
             dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def lsht_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
```
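All four postprocessors share the same one-line body, so they behave identically; a quick check with an invented sample string:

```python
from opencompass.datasets import (lsht_postprocess, samsum_postprocess,
                                  trec_postprocess, triviaqa_postprocess)

sample = '\n\n42\nSome trailing explanation the evaluator should not see.'
for fn in (lsht_postprocess, samsum_postprocess, trec_postprocess, triviaqa_postprocess):
    assert fn(sample) == '42'  # leading newlines stripped, only the first line kept
```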
HuggingFace model wrapper: add a `use_fastchat_template` option that wraps prompts in fastchat's vicuna conversation template before tokenization, and give `HuggingFaceChatGLM3.generate` a `skip_overlength` switch that appends an empty response for inputs longer than 8192 tokens instead of erroring.

```diff
@@ -46,6 +46,9 @@ class HuggingFace(BaseModel):
         mode (str, optional): The method of input truncation when input length
             exceeds max_seq_len. 'mid' represents the part of input to
             truncate. Defaults to 'none'.
+        use_fastchat_template (str, optional): Whether to use fastchat to get
+            the conversation template. If True, fastchat needs to be
+            implemented first. Defaults to False.
 
     Note:
         About ``extract_pred_after_decode``: Commonly, we should extract the
@@ -68,7 +71,8 @@ class HuggingFace(BaseModel):
                  extract_pred_after_decode: bool = False,
                  batch_padding: bool = False,
                  pad_token_id: Optional[int] = None,
-                 mode: str = 'none'):
+                 mode: str = 'none',
+                 use_fastchat_template: bool = False):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
@@ -91,6 +95,7 @@ class HuggingFace(BaseModel):
                          model_kwargs=model_kwargs,
                          peft_path=peft_path)
         self.generation_kwargs = generation_kwargs
+        self.use_fastchat_template = use_fastchat_template
 
     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                         tokenizer_kwargs: dict):
@@ -220,6 +225,20 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]
 
+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not implemented. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to implement fastchat.')
+            for i in range(len(inputs)):
+                conv = get_conversation_template('vicuna')
+                conv.append_message(conv.roles[0], inputs[i])
+                conv.append_message(conv.roles[1], None)
+                inputs[i] = conv.get_prompt()
+
         # step-1: tokenize the input with batch_encode_plus
         tokens = self.tokenizer.batch_encode_plus(inputs,
                                                   padding=True,
@@ -263,6 +282,19 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]
 
+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not implemented. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to implement fastchat.')
+            conv = get_conversation_template('vicuna')
+            conv.append_message(conv.roles[0], inputs[0])
+            conv.append_message(conv.roles[1], None)
+            inputs = [conv.get_prompt()]
+
         if self.mode == 'mid':
             input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
             input_ids = torch.tensor(input_ids, device=self.model.device)
@@ -491,7 +523,8 @@ class HuggingFaceChatGLM3(HuggingFace):
     def generate(self,
                  inputs: List[str or PromptList],
                  max_out_len: int = 512,
-                 temperature: float = 0.6) -> str:
+                 temperature: float = 0.6,
+                 skip_overlength=False) -> str:
         """Generate response from input prompt.
 
         Args:
@@ -518,6 +551,20 @@ class HuggingFaceChatGLM3(HuggingFace):
                 history.append(msg)
             user_content = history[-1]['content']
             history = history[:-1]
+
+            if skip_overlength:
+                # The model will report the following error
+                # if the sequence length is greater than the maximum length:
+                # "Input length of input_ids is {INPUT_IDS},
+                # but `max_length` is set to 8192.
+                # This can lead to unexpected behavior.
+                # You should consider increasing `max_new_tokens`."
+                # The following hardcode can fix this exception.
+                len_user_content = len(self.tokenizer.encode(user_content))
+                if len_user_content > 8192:
+                    responses.append('')
+                    continue
+
             try:
                 response, history = self.model.chat(self.tokenizer,
                                                     user_content,
```
|
|||||||
path: str,
|
path: str,
|
||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
max_batch_size: int,
|
max_batch_size: int,
|
||||||
tokenizer_path: Optional[str] = None):
|
tokenizer_path: Optional[str] = None,
|
||||||
|
force_bf16=False):
|
||||||
from llama import Llama
|
from llama import Llama
|
||||||
self.generator = Llama.build(path, tokenizer_path, max_seq_len,
|
self.generator = Llama.build(path, tokenizer_path, max_seq_len,
|
||||||
max_batch_size)
|
max_batch_size)
|
||||||
self.tokenizer = self.generator.tokenizer
|
self.tokenizer = self.generator.tokenizer
|
||||||
self.model = self.generator.model
|
self.model = self.generator.model
|
||||||
|
if force_bf16:
|
||||||
|
# force set model to `bfloat16` to fix
|
||||||
|
# the exception of 'RuntimeError: probability tensor
|
||||||
|
# contains either `inf`, `nan` or element < 0',
|
||||||
|
# encountered during the inference of llama2-7b
|
||||||
|
self.model = self.model.bfloat16()
|
||||||
|
|
||||||
def _load_tokenizer(self, tokenizer_path: str):
|
def _load_tokenizer(self, tokenizer_path: str):
|
||||||
from llama import Tokenizer
|
from llama import Tokenizer
|
||||||
|
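The cast itself is the standard PyTorch module conversion; a tiny illustration on a throwaway module (hypothetical, just to show what `.bfloat16()` does to parameters):

```python
import torch

layer = torch.nn.Linear(4, 4)           # stand-in for the llama model
print(next(layer.parameters()).dtype)   # torch.float32 by default
layer = layer.bfloat16()                # same call the wrapper now makes when force_bf16=True
print(next(layer.parameters()).dtype)   # torch.bfloat16
```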
GenInferencer resume: tolerate an unreadable temporary result file when resuming; if loading fails, the partial results are ignored rather than raising.

```diff
@@ -108,9 +108,13 @@ class GenInferencer(BaseInferencer):
                                           'tmp_' + output_json_filename)
         if osp.exists(tmp_json_filepath):
             # TODO: move resume to output handler
-            tmp_result_dict = mmengine.load(tmp_json_filepath)
-            output_handler.results_dict = tmp_result_dict
-            index = len(tmp_result_dict)
+            try:
+                tmp_result_dict = mmengine.load(tmp_json_filepath)
+            except Exception:
+                pass
+            else:
+                output_handler.results_dict = tmp_result_dict
+                index = len(tmp_result_dict)
 
         # 4. Wrap prompts with Dataloader
         dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
```
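The failure this guards against is typically a tmp file that was only partially written before a previous run died; a small reproduction of the load path (the file name here is invented):

```python
import mmengine

# Simulate a tmp result file truncated mid-write by a killed run.
with open('tmp_demo_results.json', 'w', encoding='utf-8') as f:
    f.write('{"0": {"prediction": "answ')  # invalid JSON

try:
    results = mmengine.load('tmp_demo_results.json')
except Exception:
    results = {}        # same outcome as the new code: ignore the broken file
print(len(results))     # 0 -> inference effectively starts from the beginning
```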
SlurmSequentialRunner: clamp the worker count to at least 1 so the worker pool can still be created when the task list is empty.

```diff
@@ -96,7 +96,7 @@ class SlurmSequentialRunner(BaseRunner):
 
         try:
             parent_conns = []
-            num_workers = min(self.max_num_workers, len(tasks))
+            num_workers = max(min(self.max_num_workers, len(tasks)), 1)
             with Pool(processes=num_workers) as pool:
                 for task in tasks:
                     parent_conn, child_conn = Pipe()
```
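Without the clamp, an empty task list makes `processes` zero, which `multiprocessing.Pool` rejects; the fix is plain min/max arithmetic:

```python
from multiprocessing import Pool

max_num_workers, tasks = 32, []             # e.g. nothing left to run after filtering
old = min(max_num_workers, len(tasks))      # 0 -> Pool(processes=0) raises ValueError
new = max(min(max_num_workers, len(tasks)), 1)
print(old, new)                             # 0 1
with Pool(processes=new) as pool:           # fine: one idle worker, the for-loop just does nothing
    pass
```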