[Feature] Support sample count in prompt_viewer (#273)

* support sample count in prompt_viewer

* update

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Author: cdpath (committed by GitHub)
Date: 2023-08-29 20:51:10 +08:00
Commit: fa7978fe08
Parent: c26ecdb1b0
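
In short, the change threads a new count argument from the command line down into print_prompts, which now prints up to count sample prompts instead of always the first one. A minimal standalone sketch of the capping behavior; the list contents are invented for illustration, and only the min(count, len(ice_idx_list)) expression mirrors the diff below:

# Hypothetical retrieval output: one in-context-example index list per sample.
ice_idx_list = [[3, 7], [1], [5, 2]]
count = 2
for idx in range(min(count, len(ice_idx_list))):
    # Prints samples 0 and 1; a count larger than the dataset would
    # simply print every available sample.
    print(f'sample {idx}: ice indices {ice_idx_list[idx]}')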

@@ -22,6 +22,11 @@ def parse_args():
                         '--pattern',
                         type=str,
                         help='To match the dataset abbr.')
+    parser.add_argument('-c',
+                        '--count',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to print')
     args = parser.parse_args()
     return args
@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg
 
 
-def print_prompts(model_cfg, dataset_cfg):
+def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
     # extracted and generalized as a static method in these Inferencers
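
The count=1 default keeps the signature backward compatible: two-argument callers behave exactly as before. A runnable toy illustration, with a stub body standing in for the real prompt printing:

def print_prompts(model_cfg, dataset_cfg, count=1):
    print(f'would print {count} prompt(s)')

print_prompts({}, {})           # old call style: still one prompt
print_prompts({}, {}, count=5)  # new call style: up to five prompts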
@@ -79,90 +84,92 @@ def print_prompts(model_cfg, dataset_cfg):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
 
-    if infer_cfg.inferencer.type == PPLInferencer:
-        labels = retriever.get_labels(ice_template=ice_template,
-                                      prompt_template=prompt_template)
-        ice = [
-            retriever.generate_ice(ice_idx_list[idx],
-                                   ice_template=ice_template)
-            for idx in range(len(ice_idx_list))
-        ]
-        print('-' * 100)
-        print('ICE Template:')
-        print('-' * 100)
-        print(ice[0])
-        print('-' * 100)
-        for label in labels:
-            idx = 0
-            prompt = retriever.generate_label_prompt(
-                idx,
-                ice[idx],
-                label,
-                ice_template=ice_template,
-                prompt_template=prompt_template,
-                remain_sep=None)
-            if max_seq_len is not None:
-                prompt_token_num = model.get_token_len_from_template(
-                    prompt)
-                while len(ice_idx_list[idx]
-                          ) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx_list[idx])
-                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                          f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                    ice[idx] = retriever.generate_ice(
-                        ice_idx_list[idx], ice_template=ice_template)
-                    prompt = retriever.generate_label_prompt(
-                        idx,
-                        ice[idx],
-                        label,
-                        ice_template=ice_template,
-                        prompt_template=prompt_template)
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                print(f'Number of tokens: {prompt_token_num}')
-            if model is not None:
-                prompt = model.parse_template(prompt, mode='ppl')
-            print('-' * 100)
-            print(f'Label: {label}')
-            print('Sample prompt:')
-            print('-' * 100)
-            print(prompt)
-            print('-' * 100)
-    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-        idx, ice_idx = 0, ice_idx_list[0]
-        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-        prompt = retriever.generate_prompt_for_generate_task(
-            idx,
-            ice,
-            gen_field_replace_token=infer_cfg.inferencer.get(
-                'gen_field_replace_token', ''),
-            ice_template=ice_template,
-            prompt_template=prompt_template)
-        if max_seq_len is not None:
-            prompt_token_num = model.get_token_len_from_template(prompt)
-            while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                num_ice = len(ice_idx)
-                print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                      f'Number of tokens: {prompt_token_num} -> ...')
-                ice_idx = ice_idx[:-1]
-                ice = retriever.generate_ice(ice_idx,
-                                             ice_template=ice_template)
-                prompt = retriever.generate_prompt_for_generate_task(
-                    idx,
-                    ice,
-                    gen_field_replace_token=infer_cfg.inferencer.get(
-                        'gen_field_replace_token', ''),
-                    ice_template=ice_template,
-                    prompt_template=prompt_template)
-                prompt_token_num = model.get_token_len_from_template(prompt)
-            print(f'Number of tokens: {prompt_token_num}')
-        if model is not None:
-            prompt = model.parse_template(prompt, mode='gen')
-        print('-' * 100)
-        print('Sample prompt:')
-        print('-' * 100)
-        print(prompt)
-        print('-' * 100)
+    for idx in range(min(count, len(ice_idx_list))):
+        if infer_cfg.inferencer.type == PPLInferencer:
+            labels = retriever.get_labels(ice_template=ice_template,
+                                          prompt_template=prompt_template)
+            ice = [
+                retriever.generate_ice(ice_idx_list[_idx],
+                                       ice_template=ice_template)
+                for _idx in range(len(ice_idx_list))
+            ]
+            print('-' * 100)
+            print('ICE Template:')
+            print('-' * 100)
+            print(ice[0])
+            print('-' * 100)
+            for label in labels:
+                prompt = retriever.generate_label_prompt(
+                    idx,
+                    ice[idx],
+                    label,
+                    ice_template=ice_template,
+                    prompt_template=prompt_template,
+                    remain_sep=None)
+                if max_seq_len is not None:
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                    while len(ice_idx_list[idx]
+                              ) > 0 and prompt_token_num > max_seq_len:
+                        num_ice = len(ice_idx_list[idx])
+                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                              f'Number of tokens: {prompt_token_num} -> ...')
+                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
+                        ice[idx] = retriever.generate_ice(
+                            ice_idx_list[idx], ice_template=ice_template)
+                        prompt = retriever.generate_label_prompt(
+                            idx,
+                            ice[idx],
+                            label,
+                            ice_template=ice_template,
+                            prompt_template=prompt_template)
+                        prompt_token_num = model.get_token_len_from_template(
+                            prompt)
+                    print(f'Number of tokens: {prompt_token_num}')
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('-' * 100)
+                print(f'Label: {label}')
+                print('Sample prompt:')
+                print('-' * 100)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+            ice_idx = ice_idx_list[idx]
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+            prompt = retriever.generate_prompt_for_generate_task(
+                idx,
+                ice,
+                gen_field_replace_token=infer_cfg.inferencer.get(
+                    'gen_field_replace_token', ''),
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+            if max_seq_len is not None:
+                prompt_token_num = model.get_token_len_from_template(prompt)
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    num_ice = len(ice_idx)
+                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                          f'Number of tokens: {prompt_token_num} -> ...')
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
+                        idx,
+                        ice,
+                        gen_field_replace_token=infer_cfg.inferencer.get(
+                            'gen_field_replace_token', ''),
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                print(f'Number of tokens: {prompt_token_num}')
+            if model is not None:
+                prompt = model.parse_template(prompt, mode='gen')
+            print('-' * 100)
+            print('Sample prompt:')
+            print('-' * 100)
+            print(prompt)
+            print('-' * 100)
 
 
 def main():
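
The truncation strategy inside both branches predates this change and is only re-indented here: while the rendered prompt exceeds max_seq_len and in-context examples remain, drop the last example and re-render. A standalone sketch with a whitespace token counter standing in for model.get_token_len_from_template (the helper names, example strings, and the 4-token budget are all invented for illustration):

def token_len(prompt):
    # Crude stand-in for the model's template-aware token counter.
    return len(prompt.split())

examples = ['foo bar', 'baz qux', 'quux corge']
ice_idx = [0, 1, 2]
max_seq_len = 4

def build_prompt(idx_list):
    return ' '.join(examples[i] for i in idx_list) + ' QUESTION'

prompt = build_prompt(ice_idx)
while len(ice_idx) > 0 and token_len(prompt) > max_seq_len:
    num_ice = len(ice_idx)
    print(f'Truncating ice {num_ice} -> {num_ice - 1}')
    ice_idx = ice_idx[:-1]           # drop the last in-context example...
    prompt = build_prompt(ice_idx)   # ...and rebuild the prompt
print(prompt)                        # 'foo bar QUESTION' (3 tokens, fits)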
@@ -201,7 +208,7 @@ def main():
         dataset = list(dataset2cfg.keys())[0]
         model_cfg = model2cfg[model]
         dataset_cfg = dataset2cfg[dataset]
-        print_prompts(model_cfg, dataset_cfg)
+        print_prompts(model_cfg, dataset_cfg, args.count)
     else:
         for model_abbr, model_cfg in model2cfg.items():
             for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
                 print(f'[MODEL]: {model_abbr}')
                 print(f'[DATASET]: {dataset_abbr}')
                 print('---')
-                print_prompts(model_cfg, dataset_cfg)
+                print_prompts(model_cfg, dataset_cfg, args.count)
                 print('=' * 65, '[END]', '=' * 65)
                 print()
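
End to end, requesting several sample prompts is now a single flag. Assuming the file is the repository's prompt viewer tool invoked with a config path (neither the script path nor the config argument appears in this diff, so treat both as assumptions), usage would look like:

python tools/prompt_viewer.py path/to/config.py -c 3  # print 3 sample prompts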