From fa7978fe0857d47640da5b9bd4a4317189d6668a Mon Sep 17 00:00:00 2001
From: cdpath <11472839+cdpath@users.noreply.github.com>
Date: Tue, 29 Aug 2023 20:51:10 +0800
Subject: [PATCH] [Feature] Support sample count in prompt_viewer (#273)

* support sample count in prompt_viewer

* update

---------

Co-authored-by: Leymore
---
 tools/prompt_viewer.py | 147 +++++++++++++++++++++--------------------
 1 file changed, 77 insertions(+), 70 deletions(-)

diff --git a/tools/prompt_viewer.py b/tools/prompt_viewer.py
index ff2be2c2..35280b1f 100644
--- a/tools/prompt_viewer.py
+++ b/tools/prompt_viewer.py
@@ -22,6 +22,11 @@ def parse_args():
                         '--pattern',
                         type=str,
                         help='To match the dataset abbr.')
+    parser.add_argument('-c',
+                        '--count',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to print')
     args = parser.parse_args()
     return args
 
@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg
 
 
-def print_prompts(model_cfg, dataset_cfg):
+def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
     # extracted and generalized as a static method in these Inferencers
@@ -79,90 +84,92 @@ def print_prompts(model_cfg, dataset_cfg):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
 
-    if infer_cfg.inferencer.type == PPLInferencer:
-        labels = retriever.get_labels(ice_template=ice_template,
-                                      prompt_template=prompt_template)
-        ice = [
-            retriever.generate_ice(ice_idx_list[idx],
-                                   ice_template=ice_template)
-            for idx in range(len(ice_idx_list))
-        ]
-        print('-' * 100)
-        print('ICE Template:')
-        print('-' * 100)
-        print(ice[0])
-        print('-' * 100)
-        for label in labels:
-            idx = 0
-            prompt = retriever.generate_label_prompt(
+    for idx in range(min(count, len(ice_idx_list))):
+        if infer_cfg.inferencer.type == PPLInferencer:
+            labels = retriever.get_labels(ice_template=ice_template,
+                                          prompt_template=prompt_template)
+            ice = [
+                retriever.generate_ice(ice_idx_list[_idx],
+                                       ice_template=ice_template)
+                for _idx in range(len(ice_idx_list))
+            ]
+            print('-' * 100)
+            print('ICE Template:')
+            print('-' * 100)
+            print(ice[0])
+            print('-' * 100)
+            for label in labels:
+                prompt = retriever.generate_label_prompt(
+                    idx,
+                    ice[idx],
+                    label,
+                    ice_template=ice_template,
+                    prompt_template=prompt_template,
+                    remain_sep=None)
+                if max_seq_len is not None:
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                    while len(ice_idx_list[idx]
+                              ) > 0 and prompt_token_num > max_seq_len:
+                        num_ice = len(ice_idx_list[idx])
+                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                              f'Number of tokens: {prompt_token_num} -> ...')
+                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
+                        ice[idx] = retriever.generate_ice(
+                            ice_idx_list[idx], ice_template=ice_template)
+                        prompt = retriever.generate_label_prompt(
+                            idx,
+                            ice[idx],
+                            label,
+                            ice_template=ice_template,
+                            prompt_template=prompt_template)
+                        prompt_token_num = model.get_token_len_from_template(
+                            prompt)
+                    print(f'Number of tokens: {prompt_token_num}')
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('-' * 100)
+                print(f'Label: {label}')
+                print('Sample prompt:')
+                print('-' * 100)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+            ice_idx = ice_idx_list[idx]
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+            prompt = retriever.generate_prompt_for_generate_task(
                 idx,
-                ice[idx],
-                label,
+                ice,
+                gen_field_replace_token=infer_cfg.inferencer.get(
+                    'gen_field_replace_token', ''),
                 ice_template=ice_template,
-                prompt_template=prompt_template,
-                remain_sep=None)
+                prompt_template=prompt_template)
             if max_seq_len is not None:
                 prompt_token_num = model.get_token_len_from_template(prompt)
-                while len(ice_idx_list[idx]
-                          ) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx_list[idx])
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    num_ice = len(ice_idx)
                     print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                           f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                    ice[idx] = retriever.generate_ice(
-                        ice_idx_list[idx], ice_template=ice_template)
-                    prompt = retriever.generate_label_prompt(
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
                         idx,
-                        ice[idx],
-                        label,
+                        ice,
+                        gen_field_replace_token=infer_cfg.inferencer.get(
+                            'gen_field_replace_token', ''),
                         ice_template=ice_template,
                         prompt_template=prompt_template)
                     prompt_token_num = model.get_token_len_from_template(
                         prompt)
-                print(f'Number of tokens: {prompt_token_num}')
+                print(f'Number of tokens: {prompt_token_num}')
             if model is not None:
-                prompt = model.parse_template(prompt, mode='ppl')
+                prompt = model.parse_template(prompt, mode='gen')
             print('-' * 100)
-            print(f'Label: {label}')
             print('Sample prompt:')
             print('-' * 100)
             print(prompt)
             print('-' * 100)
-    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-        idx, ice_idx = 0, ice_idx_list[0]
-        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-        prompt = retriever.generate_prompt_for_generate_task(
-            idx,
-            ice,
-            gen_field_replace_token=infer_cfg.inferencer.get(
-                'gen_field_replace_token', ''),
-            ice_template=ice_template,
-            prompt_template=prompt_template)
-        if max_seq_len is not None:
-            prompt_token_num = model.get_token_len_from_template(prompt)
-            while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                num_ice = len(ice_idx)
-                print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                      f'Number of tokens: {prompt_token_num} -> ...')
-                ice_idx = ice_idx[:-1]
-                ice = retriever.generate_ice(ice_idx,
-                                             ice_template=ice_template)
-                prompt = retriever.generate_prompt_for_generate_task(
-                    idx,
-                    ice,
-                    gen_field_replace_token=infer_cfg.inferencer.get(
-                        'gen_field_replace_token', ''),
-                    ice_template=ice_template,
-                    prompt_template=prompt_template)
-                prompt_token_num = model.get_token_len_from_template(prompt)
-            print(f'Number of tokens: {prompt_token_num}')
-        if model is not None:
-            prompt = model.parse_template(prompt, mode='gen')
-        print('-' * 100)
-        print('Sample prompt:')
-        print('-' * 100)
-        print(prompt)
-        print('-' * 100)
 
 
 def main():
@@ -201,7 +208,7 @@ def main():
         dataset = list(dataset2cfg.keys())[0]
         model_cfg = model2cfg[model]
         dataset_cfg = dataset2cfg[dataset]
-        print_prompts(model_cfg, dataset_cfg)
+        print_prompts(model_cfg, dataset_cfg, args.count)
     else:
         for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
             print(f'[MODEL]: {model_abbr}')
             print(f'[DATASET]: {dataset_abbr}')
             print('---')
-            print_prompts(model_cfg, dataset_cfg)
+            print_prompts(model_cfg, dataset_cfg, args.count)
             print('=' * 65, '[END]', '=' * 65)
             print()
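With this patch applied, prompt_viewer can print several sample prompts per
model/dataset pair instead of only the first one. A minimal sketch of an
invocation follows; the config path and the dataset abbreviation are
hypothetical placeholders, and it assumes the tool keeps its existing
positional config argument. Only -p/--pattern (visible in the hunk context
above) and the newly added -c/--count are taken from the patch itself:

    # Print the first 3 prompts for datasets whose abbr matches 'gsm8k'.
    # 'configs/eval_demo.py' and 'gsm8k' are illustrative, not from the patch.
    python tools/prompt_viewer.py configs/eval_demo.py -p gsm8k -c 3

Because print_prompts now iterates over range(min(count, len(ice_idx_list))),
asking for more samples than the retriever produced simply prints everything
available rather than raising an error.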