[Feature] Support sample count in prompt_viewer (#273)

* support sample count in prompt_viewer

* update

---------

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
cdpath 2023-08-29 20:51:10 +08:00 committed by GitHub
parent c26ecdb1b0
commit fa7978fe08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -22,6 +22,11 @@ def parse_args():
'--pattern', '--pattern',
type=str, type=str,
help='To match the dataset abbr.') help='To match the dataset abbr.')
parser.add_argument('-c',
'--count',
type=int,
default=1,
help='Number of prompts to print')
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
return dataset2cfg return dataset2cfg
def print_prompts(model_cfg, dataset_cfg): def print_prompts(model_cfg, dataset_cfg, count=1):
# TODO: A really dirty method that copies code from PPLInferencer and # TODO: A really dirty method that copies code from PPLInferencer and
# GenInferencer. In the future, the prompt extraction code should be # GenInferencer. In the future, the prompt extraction code should be
# extracted and generalized as a static method in these Inferencers # extracted and generalized as a static method in these Inferencers
@@ -79,13 +84,14 @@ def print_prompts(model_cfg, dataset_cfg):
assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \ assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
'Only PPLInferencer and GenInferencer are supported' 'Only PPLInferencer and GenInferencer are supported'
for idx in range(min(count, len(ice_idx_list))):
if infer_cfg.inferencer.type == PPLInferencer: if infer_cfg.inferencer.type == PPLInferencer:
labels = retriever.get_labels(ice_template=ice_template, labels = retriever.get_labels(ice_template=ice_template,
prompt_template=prompt_template) prompt_template=prompt_template)
ice = [ ice = [
retriever.generate_ice(ice_idx_list[idx], retriever.generate_ice(ice_idx_list[_idx],
ice_template=ice_template) ice_template=ice_template)
for idx in range(len(ice_idx_list)) for _idx in range(len(ice_idx_list))
] ]
print('-' * 100) print('-' * 100)
print('ICE Template:') print('ICE Template:')
@@ -93,7 +99,6 @@ def print_prompts(model_cfg, dataset_cfg):
print(ice[0]) print(ice[0])
print('-' * 100) print('-' * 100)
for label in labels: for label in labels:
idx = 0
prompt = retriever.generate_label_prompt( prompt = retriever.generate_label_prompt(
idx, idx,
ice[idx], ice[idx],
@@ -102,7 +107,8 @@ def print_prompts(model_cfg, dataset_cfg):
prompt_template=prompt_template, prompt_template=prompt_template,
remain_sep=None) remain_sep=None)
if max_seq_len is not None: if max_seq_len is not None:
prompt_token_num = model.get_token_len_from_template(prompt) prompt_token_num = model.get_token_len_from_template(
prompt)
while len(ice_idx_list[idx] while len(ice_idx_list[idx]
) > 0 and prompt_token_num > max_seq_len: ) > 0 and prompt_token_num > max_seq_len:
num_ice = len(ice_idx_list[idx]) num_ice = len(ice_idx_list[idx])
@@ -129,7 +135,7 @@ def print_prompts(model_cfg, dataset_cfg):
print(prompt) print(prompt)
print('-' * 100) print('-' * 100)
elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]: elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
idx, ice_idx = 0, ice_idx_list[0] ice_idx = ice_idx_list[idx]
ice = retriever.generate_ice(ice_idx, ice_template=ice_template) ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task( prompt = retriever.generate_prompt_for_generate_task(
idx, idx,
@@ -154,7 +160,8 @@ def print_prompts(model_cfg, dataset_cfg):
'gen_field_replace_token', ''), 'gen_field_replace_token', ''),
ice_template=ice_template, ice_template=ice_template,
prompt_template=prompt_template) prompt_template=prompt_template)
prompt_token_num = model.get_token_len_from_template(prompt) prompt_token_num = model.get_token_len_from_template(
prompt)
print(f'Number of tokens: {prompt_token_num}') print(f'Number of tokens: {prompt_token_num}')
if model is not None: if model is not None:
prompt = model.parse_template(prompt, mode='gen') prompt = model.parse_template(prompt, mode='gen')
@@ -201,7 +208,7 @@ def main():
dataset = list(dataset2cfg.keys())[0] dataset = list(dataset2cfg.keys())[0]
model_cfg = model2cfg[model] model_cfg = model2cfg[model]
dataset_cfg = dataset2cfg[dataset] dataset_cfg = dataset2cfg[dataset]
print_prompts(model_cfg, dataset_cfg) print_prompts(model_cfg, dataset_cfg, args.count)
else: else:
for model_abbr, model_cfg in model2cfg.items(): for model_abbr, model_cfg in model2cfg.items():
for dataset_abbr, dataset_cfg in dataset2cfg.items(): for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
print(f'[MODEL]: {model_abbr}') print(f'[MODEL]: {model_abbr}')
print(f'[DATASET]: {dataset_abbr}') print(f'[DATASET]: {dataset_abbr}')
print('---') print('---')
print_prompts(model_cfg, dataset_cfg) print_prompts(model_cfg, dataset_cfg, args.count)
print('=' * 65, '[END]', '=' * 65) print('=' * 65, '[END]', '=' * 65)
print() print()