Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support sample count in prompt_viewer (#273)
* support sample count in prompt_viewer
* update

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
c26ecdb1b0
commit
fa7978fe08
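The new -c/--count flag controls how many sample prompts print_prompts() emits per model/dataset pair; it defaults to 1, so the previous single-prompt behavior is unchanged. A minimal usage sketch follows: the script path tools/prompt_viewer.py and the positional config argument are assumptions not shown in this diff, while -c/--count comes from the change itself and CONFIG_PATH is a placeholder.

    # Print up to 3 sample prompts for each model/dataset pair (assumed invocation).
    python tools/prompt_viewer.py CONFIG_PATH -c 3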
@@ -22,6 +22,11 @@ def parse_args():
                         '--pattern',
                         type=str,
                         help='To match the dataset abbr.')
+    parser.add_argument('-c',
+                        '--count',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to print')
     args = parser.parse_args()
     return args
 
@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg
 
 
-def print_prompts(model_cfg, dataset_cfg):
+def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
     # extracted and generalized as a static method in these Inferencers
@@ -79,90 +84,92 @@ def print_prompts(model_cfg, dataset_cfg):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
 
-    if infer_cfg.inferencer.type == PPLInferencer:
-        labels = retriever.get_labels(ice_template=ice_template,
-                                      prompt_template=prompt_template)
-        ice = [
-            retriever.generate_ice(ice_idx_list[idx],
-                                   ice_template=ice_template)
-            for idx in range(len(ice_idx_list))
-        ]
-        print('-' * 100)
-        print('ICE Template:')
-        print('-' * 100)
-        print(ice[0])
-        print('-' * 100)
-        for label in labels:
-            idx = 0
-            prompt = retriever.generate_label_prompt(
-                idx,
-                ice[idx],
-                label,
-                ice_template=ice_template,
-                prompt_template=prompt_template,
-                remain_sep=None)
-            if max_seq_len is not None:
-                prompt_token_num = model.get_token_len_from_template(prompt)
-                while len(ice_idx_list[idx]
-                          ) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx_list[idx])
-                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                          f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                    ice[idx] = retriever.generate_ice(
-                        ice_idx_list[idx], ice_template=ice_template)
-                    prompt = retriever.generate_label_prompt(
-                        idx,
-                        ice[idx],
-                        label,
-                        ice_template=ice_template,
-                        prompt_template=prompt_template)
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                print(f'Number of tokens: {prompt_token_num}')
-            if model is not None:
-                prompt = model.parse_template(prompt, mode='ppl')
-            print('-' * 100)
-            print(f'Label: {label}')
-            print('Sample prompt:')
-            print('-' * 100)
-            print(prompt)
-            print('-' * 100)
-    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-        idx, ice_idx = 0, ice_idx_list[0]
-        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-        prompt = retriever.generate_prompt_for_generate_task(
-            idx,
-            ice,
-            gen_field_replace_token=infer_cfg.inferencer.get(
-                'gen_field_replace_token', ''),
-            ice_template=ice_template,
-            prompt_template=prompt_template)
-        if max_seq_len is not None:
-            prompt_token_num = model.get_token_len_from_template(prompt)
-            while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                num_ice = len(ice_idx)
-                print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                      f'Number of tokens: {prompt_token_num} -> ...')
-                ice_idx = ice_idx[:-1]
-                ice = retriever.generate_ice(ice_idx,
-                                             ice_template=ice_template)
-                prompt = retriever.generate_prompt_for_generate_task(
-                    idx,
-                    ice,
-                    gen_field_replace_token=infer_cfg.inferencer.get(
-                        'gen_field_replace_token', ''),
-                    ice_template=ice_template,
-                    prompt_template=prompt_template)
-                prompt_token_num = model.get_token_len_from_template(prompt)
-            print(f'Number of tokens: {prompt_token_num}')
-        if model is not None:
-            prompt = model.parse_template(prompt, mode='gen')
-        print('-' * 100)
-        print('Sample prompt:')
-        print('-' * 100)
-        print(prompt)
-        print('-' * 100)
+    for idx in range(min(count, len(ice_idx_list))):
+        if infer_cfg.inferencer.type == PPLInferencer:
+            labels = retriever.get_labels(ice_template=ice_template,
+                                          prompt_template=prompt_template)
+            ice = [
+                retriever.generate_ice(ice_idx_list[_idx],
+                                       ice_template=ice_template)
+                for _idx in range(len(ice_idx_list))
+            ]
+            print('-' * 100)
+            print('ICE Template:')
+            print('-' * 100)
+            print(ice[0])
+            print('-' * 100)
+            for label in labels:
+                prompt = retriever.generate_label_prompt(
+                    idx,
+                    ice[idx],
+                    label,
+                    ice_template=ice_template,
+                    prompt_template=prompt_template,
+                    remain_sep=None)
+                if max_seq_len is not None:
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                    while len(ice_idx_list[idx]
+                              ) > 0 and prompt_token_num > max_seq_len:
+                        num_ice = len(ice_idx_list[idx])
+                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                              f'Number of tokens: {prompt_token_num} -> ...')
+                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
+                        ice[idx] = retriever.generate_ice(
+                            ice_idx_list[idx], ice_template=ice_template)
+                        prompt = retriever.generate_label_prompt(
+                            idx,
+                            ice[idx],
+                            label,
+                            ice_template=ice_template,
+                            prompt_template=prompt_template)
+                        prompt_token_num = model.get_token_len_from_template(
+                            prompt)
+                    print(f'Number of tokens: {prompt_token_num}')
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('-' * 100)
+                print(f'Label: {label}')
+                print('Sample prompt:')
+                print('-' * 100)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+            ice_idx = ice_idx_list[idx]
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+            prompt = retriever.generate_prompt_for_generate_task(
+                idx,
+                ice,
+                gen_field_replace_token=infer_cfg.inferencer.get(
+                    'gen_field_replace_token', ''),
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+            if max_seq_len is not None:
+                prompt_token_num = model.get_token_len_from_template(prompt)
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    num_ice = len(ice_idx)
+                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                          f'Number of tokens: {prompt_token_num} -> ...')
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
+                        idx,
+                        ice,
+                        gen_field_replace_token=infer_cfg.inferencer.get(
+                            'gen_field_replace_token', ''),
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                print(f'Number of tokens: {prompt_token_num}')
+            if model is not None:
+                prompt = model.parse_template(prompt, mode='gen')
+            print('-' * 100)
+            print('Sample prompt:')
+            print('-' * 100)
+            print(prompt)
+            print('-' * 100)
 
 
 def main():
@@ -201,7 +208,7 @@ def main():
         dataset = list(dataset2cfg.keys())[0]
         model_cfg = model2cfg[model]
         dataset_cfg = dataset2cfg[dataset]
-        print_prompts(model_cfg, dataset_cfg)
+        print_prompts(model_cfg, dataset_cfg, args.count)
     else:
         for model_abbr, model_cfg in model2cfg.items():
             for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
                 print(f'[MODEL]: {model_abbr}')
                 print(f'[DATASET]: {dataset_abbr}')
                 print('---')
-                print_prompts(model_cfg, dataset_cfg)
+                print_prompts(model_cfg, dataset_cfg, args.count)
                 print('=' * 65, '[END]', '=' * 65)
                 print()