Mirror of https://github.com/open-compass/opencompass.git
[Feature] Support sample count in prompt_viewer (#273)
* support sample count in prompt_viewer

* update

---------

Co-authored-by: Leymore <zfz-960727@163.com>
parent c26ecdb1b0
commit fa7978fe08
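In short, prompt_viewer gains a -c/--count option (default 1) and print_prompts accepts a matching count parameter, so up to count sample prompts are printed per model/dataset pair instead of always one. A rough usage sketch, assuming the tool is invoked as tools/prompt_viewer.py with its config file as a positional argument (the config path and dataset pattern below are placeholders, not taken from this commit):

    python tools/prompt_viewer.py CONFIG_PATH --pattern DATASET_ABBR -c 3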
@@ -22,6 +22,11 @@ def parse_args():
                         '--pattern',
                         type=str,
                         help='To match the dataset abbr.')
+    parser.add_argument('-c',
+                        '--count',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to print')
     args = parser.parse_args()
     return args

@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg


-def print_prompts(model_cfg, dataset_cfg):
+def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
     # extracted and generalized as a static method in these Inferencers
@@ -79,90 +84,92 @@ def print_prompts(model_cfg, dataset_cfg):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'

-    if infer_cfg.inferencer.type == PPLInferencer:
-        labels = retriever.get_labels(ice_template=ice_template,
-                                      prompt_template=prompt_template)
-        ice = [
-            retriever.generate_ice(ice_idx_list[idx],
-                                   ice_template=ice_template)
-            for idx in range(len(ice_idx_list))
-        ]
-        print('-' * 100)
-        print('ICE Template:')
-        print('-' * 100)
-        print(ice[0])
-        print('-' * 100)
-        for label in labels:
-            idx = 0
-            prompt = retriever.generate_label_prompt(
-                idx,
-                ice[idx],
-                label,
-                ice_template=ice_template,
-                prompt_template=prompt_template,
-                remain_sep=None)
-            if max_seq_len is not None:
-                prompt_token_num = model.get_token_len_from_template(prompt)
-                while len(ice_idx_list[idx]
-                          ) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx_list[idx])
-                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                          f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                    ice[idx] = retriever.generate_ice(
-                        ice_idx_list[idx], ice_template=ice_template)
-                    prompt = retriever.generate_label_prompt(
-                        idx,
-                        ice[idx],
-                        label,
-                        ice_template=ice_template,
-                        prompt_template=prompt_template)
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                print(f'Number of tokens: {prompt_token_num}')
-            if model is not None:
-                prompt = model.parse_template(prompt, mode='ppl')
-            print('-' * 100)
-            print(f'Label: {label}')
-            print('Sample prompt:')
-            print('-' * 100)
-            print(prompt)
-            print('-' * 100)
-    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-        idx, ice_idx = 0, ice_idx_list[0]
-        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-        prompt = retriever.generate_prompt_for_generate_task(
-            idx,
-            ice,
-            gen_field_replace_token=infer_cfg.inferencer.get(
-                'gen_field_replace_token', ''),
-            ice_template=ice_template,
-            prompt_template=prompt_template)
-        if max_seq_len is not None:
-            prompt_token_num = model.get_token_len_from_template(prompt)
-            while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                num_ice = len(ice_idx)
-                print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                      f'Number of tokens: {prompt_token_num} -> ...')
-                ice_idx = ice_idx[:-1]
-                ice = retriever.generate_ice(ice_idx,
-                                             ice_template=ice_template)
-                prompt = retriever.generate_prompt_for_generate_task(
-                    idx,
-                    ice,
-                    gen_field_replace_token=infer_cfg.inferencer.get(
-                        'gen_field_replace_token', ''),
-                    ice_template=ice_template,
-                    prompt_template=prompt_template)
-                prompt_token_num = model.get_token_len_from_template(prompt)
-            print(f'Number of tokens: {prompt_token_num}')
-        if model is not None:
-            prompt = model.parse_template(prompt, mode='gen')
-        print('-' * 100)
-        print('Sample prompt:')
-        print('-' * 100)
-        print(prompt)
-        print('-' * 100)
+    for idx in range(min(count, len(ice_idx_list))):
+        if infer_cfg.inferencer.type == PPLInferencer:
+            labels = retriever.get_labels(ice_template=ice_template,
+                                          prompt_template=prompt_template)
+            ice = [
+                retriever.generate_ice(ice_idx_list[_idx],
+                                       ice_template=ice_template)
+                for _idx in range(len(ice_idx_list))
+            ]
+            print('-' * 100)
+            print('ICE Template:')
+            print('-' * 100)
+            print(ice[0])
+            print('-' * 100)
+            for label in labels:
+                prompt = retriever.generate_label_prompt(
+                    idx,
+                    ice[idx],
+                    label,
+                    ice_template=ice_template,
+                    prompt_template=prompt_template,
+                    remain_sep=None)
+                if max_seq_len is not None:
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                    while len(ice_idx_list[idx]
+                              ) > 0 and prompt_token_num > max_seq_len:
+                        num_ice = len(ice_idx_list[idx])
+                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                              f'Number of tokens: {prompt_token_num} -> ...')
+                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
+                        ice[idx] = retriever.generate_ice(
+                            ice_idx_list[idx], ice_template=ice_template)
+                        prompt = retriever.generate_label_prompt(
+                            idx,
+                            ice[idx],
+                            label,
+                            ice_template=ice_template,
+                            prompt_template=prompt_template)
+                        prompt_token_num = model.get_token_len_from_template(
+                            prompt)
+                    print(f'Number of tokens: {prompt_token_num}')
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('-' * 100)
+                print(f'Label: {label}')
+                print('Sample prompt:')
+                print('-' * 100)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+            ice_idx = ice_idx_list[idx]
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+            prompt = retriever.generate_prompt_for_generate_task(
+                idx,
+                ice,
+                gen_field_replace_token=infer_cfg.inferencer.get(
+                    'gen_field_replace_token', ''),
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+            if max_seq_len is not None:
+                prompt_token_num = model.get_token_len_from_template(prompt)
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    num_ice = len(ice_idx)
+                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                          f'Number of tokens: {prompt_token_num} -> ...')
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
+                        idx,
+                        ice,
+                        gen_field_replace_token=infer_cfg.inferencer.get(
+                            'gen_field_replace_token', ''),
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                print(f'Number of tokens: {prompt_token_num}')
+            if model is not None:
+                prompt = model.parse_template(prompt, mode='gen')
+            print('-' * 100)
+            print('Sample prompt:')
+            print('-' * 100)
+            print(prompt)
+            print('-' * 100)


 def main():
@@ -201,7 +208,7 @@ def main():
         dataset = list(dataset2cfg.keys())[0]
         model_cfg = model2cfg[model]
         dataset_cfg = dataset2cfg[dataset]
-        print_prompts(model_cfg, dataset_cfg)
+        print_prompts(model_cfg, dataset_cfg, args.count)
     else:
         for model_abbr, model_cfg in model2cfg.items():
             for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
                 print(f'[MODEL]: {model_abbr}')
                 print(f'[DATASET]: {dataset_abbr}')
                 print('---')
-                print_prompts(model_cfg, dataset_cfg)
+                print_prompts(model_cfg, dataset_cfg, args.count)
                 print('=' * 65, '[END]', '=' * 65)
                 print()

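Note on the new loop bound: print_prompts iterates for idx in range(min(count, len(ice_idx_list))), so the number of printed samples is capped by the number of retrieved in-context example lists; asking for more prompts than the dataset provides simply prints every available one. A minimal standalone sketch of that capping behaviour (the helper name is illustrative, not part of the patch):

    def sample_indices(count, ice_idx_list):
        # Never iterate past the number of available examples.
        return range(min(count, len(ice_idx_list)))

    # e.g. count=5 over a 3-element ice_idx_list yields indices 0, 1, 2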