mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* add TheoremQA with 5-shot
* add huggingface_above_v4_33 classes
* use num_worker partitioner in cli
* update theoremqa
* update TheoremQA
* add TheoremQA
* rename theoremqa -> TheoremQA
* update TheoremQA output path
* rewrite many model configs
* update huggingface
* further update
* refine configs
* update configs
* update configs
* add configs/eval_llama3_instruct.py
* add summarizer multi faceted
* update bbh datasets
* update configs/models/hf_llama/lmdeploy_llama3_8b_instruct.py
* rename class
* update readme
* update hf above v4.33
228 lines
9.1 KiB
Python
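"""Preview the prompts OpenCompass would send for a given config.

Example invocations (file paths are illustrative, not taken from this file):

    python tools/prompt_viewer.py configs/eval_demo.py -n
    python tools/prompt_viewer.py configs/eval_demo.py -p 'gsm8k*' -c 2
"""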
import argparse
import fnmatch
from typing import Dict

from mmengine.config import Config, ConfigDict

from opencompass.openicl.icl_inferencer import (AgentInferencer,
                                                ChatInferencer, CLPInferencer,
                                                GenInferencer, LLInferencer,
                                                PPLInferencer,
                                                PPLOnlyInferencer)
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
                               build_model_from_cfg, dataset_abbr_from_cfg,
                               model_abbr_from_cfg)

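# CLI summary: -n skips the interactive menu, -a previews every
# (model, dataset) pair, -p filters datasets by name pattern, and
# -c sets how many prompts to render per dataset.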
def parse_args():
    parser = argparse.ArgumentParser(
        description='View generated prompts based on datasets (and models)')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('-n', '--non-interactive', action='store_true')
    parser.add_argument('-a', '--all', action='store_true')
    parser.add_argument('-p',
                        '--pattern',
                        type=str,
                        help='To match the dataset abbr.')
    parser.add_argument('-c',
                        '--count',
                        type=int,
                        default=1,
                        help='Number of prompts to print')
    args = parser.parse_args()
    return args

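# The two helpers below key the raw ConfigDicts by their human-readable
# abbreviations, so the menu and --pattern matching operate on names.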
def parse_model_cfg(model_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    model2cfg = {}
    for model in model_cfg:
        model2cfg[model_abbr_from_cfg(model)] = model
    return model2cfg


def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    dataset2cfg = {}
    for dataset in dataset_cfg:
        dataset2cfg[dataset_abbr_from_cfg(dataset)] = dataset
    return dataset2cfg

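# Rebuilds the prompt an inferencer would send: retrieve in-context example
# (ICE) indices, render them with the templates, and truncate ICEs while the
# rendered prompt exceeds the model's max_seq_len.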
def print_prompts(model_cfg, dataset_cfg, count=1):
    # TODO: A really dirty method that copies code from PPLInferencer and
    # GenInferencer. In the future, the prompt extraction code should be
    # extracted and generalized as a static method in these Inferencers
    # and reused here.
    if model_cfg:
        max_seq_len = model_cfg.get('max_seq_len', 32768)
        if not model_cfg['type'].is_api:
            model_cfg['tokenizer_only'] = True
        model = build_model_from_cfg(model_cfg)
    else:
        max_seq_len = None
        model = None

    infer_cfg = dataset_cfg.get('infer_cfg')

    dataset = build_dataset_from_cfg(dataset_cfg)

    ice_template = None
    if hasattr(infer_cfg, 'ice_template'):
        ice_template = ICL_PROMPT_TEMPLATES.build(infer_cfg['ice_template'])

    prompt_template = None
    if hasattr(infer_cfg, 'prompt_template'):
        prompt_template = ICL_PROMPT_TEMPLATES.build(
            infer_cfg['prompt_template'])

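    # The retriever selects the in-context examples for each test instance;
    # ice_idx_list[i] holds the chosen example indices for instance i.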
    infer_cfg['retriever']['dataset'] = dataset
    retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])

    ice_idx_list = retriever.retrieve()

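    # Guard: only inferencers whose prompt construction this viewer
    # replicates are handled.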
    supported_inferencer = [
        AgentInferencer, PPLInferencer, GenInferencer, CLPInferencer,
        PPLOnlyInferencer, ChatInferencer, LLInferencer
    ]
    if infer_cfg.inferencer.type not in supported_inferencer:
        print(f'Only {supported_inferencer} are supported')
        return

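    # PPL- and LL-style inferencers score one prompt per candidate label, so
    # every label's prompt is printed; generative inferencers produce a
    # single prompt per instance.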
    for idx in range(min(count, len(ice_idx_list))):
        if issubclass(infer_cfg.inferencer.type,
                      (PPLInferencer, LLInferencer)):
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
            ice = retriever.generate_ice(ice_idx_list[idx],
                                         ice_template=ice_template)
            print('-' * 100)
            print('ICE Template:')
            print('-' * 100)
            print(ice)
            print('-' * 100)
            for label in labels:
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice,
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=None)
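                # If the prompt exceeds max_seq_len, drop trailing in-context
                # examples one at a time and re-render until it fits.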
                if max_seq_len is not None:
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                    while len(ice_idx_list[idx]
                              ) > 0 and prompt_token_num > max_seq_len:
                        num_ice = len(ice_idx_list[idx])
                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                              f'Number of tokens: {prompt_token_num} -> ...')
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice = retriever.generate_ice(ice_idx_list[idx],
                                                     ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice,
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = model.get_token_len_from_template(
                            prompt)
                    print(f'Number of tokens: {prompt_token_num}')
                if model is not None:
                    prompt = model.parse_template(prompt, mode='ppl')
                print('-' * 100)
                print(f'Label: {label}')
                print('Sample prompt:')
                print('-' * 100)
                print(prompt)
                print('-' * 100)
        else:
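            # Generative path: a single prompt per instance, with the answer
            # slot filled by gen_field_replace_token.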
            ice_idx = ice_idx_list[idx]
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=infer_cfg.inferencer.get(
                    'gen_field_replace_token', ''),
                ice_template=ice_template,
                prompt_template=prompt_template)
            if max_seq_len is not None:
                prompt_token_num = model.get_token_len_from_template(prompt)
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    num_ice = len(ice_idx)
                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                          f'Number of tokens: {prompt_token_num} -> ...')
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=infer_cfg.inferencer.get(
                            'gen_field_replace_token', ''),
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                print(f'Number of tokens: {prompt_token_num}')
            if model is not None:
                prompt = model.parse_template(prompt, mode='gen')
            print('-' * 100)
            print('Sample prompt:')
            print('-' * 100)
            print(prompt)
            print('-' * 100)

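# main() resolves models and datasets from the config, optionally filters
# datasets with --pattern, then either renders every (model, dataset) pair
# (--all) or previews a single pair.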
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    model2cfg = parse_model_cfg(cfg.models) if 'models' in cfg else {
        'None': None
    }
    if 'datasets' in cfg:
        dataset2cfg = parse_dataset_cfg(cfg.datasets)
    else:
        dataset2cfg = {}
        for key in cfg.keys():
            if key.endswith('_datasets'):
                dataset2cfg.update(parse_dataset_cfg(cfg[key]))

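    # --pattern uses shell-style globs via fnmatch, e.g. 'gsm8k*' or
    # '*mmlu*' (example patterns are illustrative).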
    if args.pattern is not None:
        matches = fnmatch.filter(dataset2cfg, args.pattern)
        if len(matches) == 0:
            raise ValueError(
                'No dataset matches the pattern. Please select from: \n' +
                '\n'.join(dataset2cfg.keys()))
        dataset2cfg = {k: dataset2cfg[k] for k in matches}

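    # Without --all, a single (model, dataset) pair is previewed: chosen
    # interactively via Menu, or the first of each under --non-interactive.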
    if not args.all:
        if not args.non_interactive:
            model, dataset = Menu(
                [list(model2cfg.keys()),
                 list(dataset2cfg.keys())], [
                     f'Please make a selection of {s}:'
                     for s in ['model', 'dataset']
                 ]).run()
        else:
            model = list(model2cfg.keys())[0]
            dataset = list(dataset2cfg.keys())[0]
        model_cfg = model2cfg[model]
        dataset_cfg = dataset2cfg[dataset]
        print_prompts(model_cfg, dataset_cfg, args.count)
    else:
        for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
                print('=' * 64, '[BEGIN]', '=' * 64)
                print(f'[MODEL]: {model_abbr}')
                print(f'[DATASET]: {dataset_abbr}')
                print('---')
                print_prompts(model_cfg, dataset_cfg, args.count)
                print('=' * 65, '[END]', '=' * 65)
                print()


if __name__ == '__main__':
    main()