[Feature] Support sample count in prompt_viewer (#273)

* support sample count in prompt_viewer

* update

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Author: cdpath (committed by GitHub)
Date:   2023-08-29 20:51:10 +08:00
Parent: c26ecdb1b0
Commit: fa7978fe08
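The commit adds a -c/--count option to tools/prompt_viewer.py so more than one sample prompt can be printed per model/dataset pair. An illustrative invocation after this change (the config path is an example, not part of the diff):

    # print the first 3 prompts for every matched model/dataset pair
    python tools/prompt_viewer.py configs/eval_demo.py -c 3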

@@ -22,6 +22,11 @@ def parse_args():
                         '--pattern',
                         type=str,
                         help='To match the dataset abbr.')
+    parser.add_argument('-c',
+                        '--count',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to print')
     args = parser.parse_args()
     return args
 
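A standalone sketch of how the new option parses (plain argparse behaviour, separate from the tool itself):

    import argparse

    # minimal reproduction of the argument added above
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--count', type=int, default=1,
                        help='Number of prompts to print')

    print(parser.parse_args([]).count)                # 1 -- default keeps the old behaviour
    print(parser.parse_args(['-c', '3']).count)       # 3
    print(parser.parse_args(['--count', '5']).count)  # 5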
@@ -40,7 +45,7 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg
 
 
-def print_prompts(model_cfg, dataset_cfg):
+def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
     # extracted and generalized as a static method in these Inferencers
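Because the new parameter defaults to 1, any call site that is not updated keeps the old one-prompt behaviour. A toy illustration of the default-argument pattern (stand-in body, not the real function):

    def print_prompts(model_cfg, dataset_cfg, count=1):
        return count  # stand-in body; the real function prints prompts

    assert print_prompts({}, {}) == 1           # legacy two-argument call still works
    assert print_prompts({}, {}, count=5) == 5  # new callers can request more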
@@ -79,90 +84,92 @@ def print_prompts(model_cfg, dataset_cfg):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
 
-    if infer_cfg.inferencer.type == PPLInferencer:
-        labels = retriever.get_labels(ice_template=ice_template,
-                                      prompt_template=prompt_template)
-        ice = [
-            retriever.generate_ice(ice_idx_list[idx],
-                                   ice_template=ice_template)
-            for idx in range(len(ice_idx_list))
-        ]
-        print('-' * 100)
-        print('ICE Template:')
-        print('-' * 100)
-        print(ice[0])
-        print('-' * 100)
-        for label in labels:
-            idx = 0
-            prompt = retriever.generate_label_prompt(
-                idx,
-                ice[idx],
-                label,
-                ice_template=ice_template,
-                prompt_template=prompt_template,
-                remain_sep=None)
-            if max_seq_len is not None:
-                prompt_token_num = model.get_token_len_from_template(prompt)
-                while len(ice_idx_list[idx]
-                          ) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx_list[idx])
-                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                          f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                    ice[idx] = retriever.generate_ice(
-                        ice_idx_list[idx], ice_template=ice_template)
-                    prompt = retriever.generate_label_prompt(
-                        idx,
-                        ice[idx],
-                        label,
-                        ice_template=ice_template,
-                        prompt_template=prompt_template)
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                print(f'Number of tokens: {prompt_token_num}')
-            if model is not None:
-                prompt = model.parse_template(prompt, mode='ppl')
-            print('-' * 100)
-            print(f'Label: {label}')
-            print('Sample prompt:')
-            print('-' * 100)
-            print(prompt)
-            print('-' * 100)
-    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-        idx, ice_idx = 0, ice_idx_list[0]
-        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-        prompt = retriever.generate_prompt_for_generate_task(
-            idx,
-            ice,
-            gen_field_replace_token=infer_cfg.inferencer.get(
-                'gen_field_replace_token', ''),
-            ice_template=ice_template,
-            prompt_template=prompt_template)
-        if max_seq_len is not None:
-            prompt_token_num = model.get_token_len_from_template(prompt)
-            while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                num_ice = len(ice_idx)
-                print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                      f'Number of tokens: {prompt_token_num} -> ...')
-                ice_idx = ice_idx[:-1]
-                ice = retriever.generate_ice(ice_idx,
-                                             ice_template=ice_template)
-                prompt = retriever.generate_prompt_for_generate_task(
-                    idx,
-                    ice,
-                    gen_field_replace_token=infer_cfg.inferencer.get(
-                        'gen_field_replace_token', ''),
-                    ice_template=ice_template,
-                    prompt_template=prompt_template)
-                prompt_token_num = model.get_token_len_from_template(prompt)
-            print(f'Number of tokens: {prompt_token_num}')
-        if model is not None:
-            prompt = model.parse_template(prompt, mode='gen')
-        print('-' * 100)
-        print('Sample prompt:')
-        print('-' * 100)
-        print(prompt)
-        print('-' * 100)
+    for idx in range(min(count, len(ice_idx_list))):
+        if infer_cfg.inferencer.type == PPLInferencer:
+            labels = retriever.get_labels(ice_template=ice_template,
+                                          prompt_template=prompt_template)
+            ice = [
+                retriever.generate_ice(ice_idx_list[_idx],
+                                       ice_template=ice_template)
+                for _idx in range(len(ice_idx_list))
+            ]
+            print('-' * 100)
+            print('ICE Template:')
+            print('-' * 100)
+            print(ice[0])
+            print('-' * 100)
+            for label in labels:
+                prompt = retriever.generate_label_prompt(
+                    idx,
+                    ice[idx],
+                    label,
+                    ice_template=ice_template,
+                    prompt_template=prompt_template,
+                    remain_sep=None)
+                if max_seq_len is not None:
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                    while len(ice_idx_list[idx]
+                              ) > 0 and prompt_token_num > max_seq_len:
+                        num_ice = len(ice_idx_list[idx])
+                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                              f'Number of tokens: {prompt_token_num} -> ...')
+                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
+                        ice[idx] = retriever.generate_ice(
+                            ice_idx_list[idx], ice_template=ice_template)
+                        prompt = retriever.generate_label_prompt(
+                            idx,
+                            ice[idx],
+                            label,
+                            ice_template=ice_template,
+                            prompt_template=prompt_template)
+                        prompt_token_num = model.get_token_len_from_template(
+                            prompt)
+                    print(f'Number of tokens: {prompt_token_num}')
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('-' * 100)
+                print(f'Label: {label}')
+                print('Sample prompt:')
+                print('-' * 100)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+            ice_idx = ice_idx_list[idx]
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+            prompt = retriever.generate_prompt_for_generate_task(
+                idx,
+                ice,
+                gen_field_replace_token=infer_cfg.inferencer.get(
+                    'gen_field_replace_token', ''),
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+            if max_seq_len is not None:
+                prompt_token_num = model.get_token_len_from_template(prompt)
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    num_ice = len(ice_idx)
+                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+                          f'Number of tokens: {prompt_token_num} -> ...')
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
+                        idx,
+                        ice,
+                        gen_field_replace_token=infer_cfg.inferencer.get(
+                            'gen_field_replace_token', ''),
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = model.get_token_len_from_template(
+                        prompt)
+                print(f'Number of tokens: {prompt_token_num}')
+            if model is not None:
+                prompt = model.parse_template(prompt, mode='gen')
+            print('-' * 100)
+            print('Sample prompt:')
+            print('-' * 100)
+            print(prompt)
+            print('-' * 100)
 
 
 def main():
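The new loop bound min(count, len(ice_idx_list)) clamps the request to the number of retrieved samples, so asking for more prompts than exist cannot index out of range. A toy illustration of the clamp (hypothetical data, not the real retriever):

    ice_idx_list = [[0, 1], [2], [3, 4]]  # hypothetical retriever output: 3 samples

    def preview(count=1):
        for idx in range(min(count, len(ice_idx_list))):
            print(f'sample {idx}: ice indices {ice_idx_list[idx]}')

    preview()          # prints 1 sample, the pre-change default
    preview(count=10)  # clamped to all 3 samples, no IndexError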
@@ -201,7 +208,7 @@ def main():
         dataset = list(dataset2cfg.keys())[0]
         model_cfg = model2cfg[model]
         dataset_cfg = dataset2cfg[dataset]
-        print_prompts(model_cfg, dataset_cfg)
+        print_prompts(model_cfg, dataset_cfg, args.count)
     else:
         for model_abbr, model_cfg in model2cfg.items():
             for dataset_abbr, dataset_cfg in dataset2cfg.items():
@@ -209,7 +216,7 @@ def main():
                 print(f'[MODEL]: {model_abbr}')
                 print(f'[DATASET]: {dataset_abbr}')
                 print('---')
-                print_prompts(model_cfg, dataset_cfg)
+                print_prompts(model_cfg, dataset_cfg, args.count)
                 print('=' * 65, '[END]', '=' * 65)
                 print()
 
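With args.count threaded through both branches of main(), every model/dataset pair now prints up to count prompts instead of exactly one. An illustrative run combining the new flag with the existing pattern filter (config path and pattern are examples, not from this diff):

    python tools/prompt_viewer.py configs/eval_demo.py -p gsm8k -c 2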