[Refactor] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
Tong Gao 2023-10-06 23:53:41 -05:00 committed by GitHub
parent 767c12a660
commit 119bfd1569
30 changed files with 68 additions and 98 deletions
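
For reference, a minimal sketch of the migration this commit performs across the dataset configs: `fix_id_list` moves from the inferencer config to the `FixKRetriever` config. The fragment below is illustrative only (template, reader and dataset fields are omitted, the index list is an example value, and the import paths are the usual OpenCompass locations rather than something introduced by this diff).

```python
# Illustrative sketch, not a complete OpenCompass config.
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer

# Old style (rejected after this commit): fix_id_list lived on the inferencer.
old_infer_cfg = dict(
    retriever=dict(type=FixKRetriever),
    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
)

# New style: the fixed in-context example indices belong to the retriever.
new_infer_cfg = dict(
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type=PPLInferencer),
)
```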

View File

@@ -23,8 +23,8 @@ CoLA_infer_cfg = dict(
},
ice_token='</E>',
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21]))
+retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
+inferencer=dict(type=PPLInferencer))
CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

View File

@@ -22,8 +22,8 @@ QQP_infer_cfg = dict(
},
ice_token='</E>',
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=PPLInferencer))
QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

View File

@@ -161,8 +161,8 @@ for _split in ["val", "test"]:
]),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
ceval_eval_cfg = dict(

View File

@@ -161,8 +161,8 @@ for _split in ["val"]:
]),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
ceval_eval_cfg = dict(

View File

@@ -163,8 +163,8 @@ for _split in ["val"]:
},
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=PPLInferencer),
)
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

View File

@@ -163,8 +163,8 @@ for _split in ["val", "test"]:
},
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=PPLInferencer),
)
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

View File

@@ -28,8 +28,8 @@ cmb_infer_cfg = dict(
),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
cmb_datasets.append(

View File

@@ -96,8 +96,8 @@ for _name in cmmlu_all_sets:
]),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
cmmlu_eval_cfg = dict(

View File

@@ -98,8 +98,8 @@ for _name in cmmlu_all_sets:
},
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=PPLInferencer),
)
cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

View File

@@ -29,8 +29,8 @@ mmlu_infer_cfg = dict(
dict(role='BOT', prompt='{target}\n')
])),
prompt_template=mmlu_prompt_template,
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer))
mmlu_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),

View File

@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

View File

@@ -87,8 +87,8 @@ for _name in mmlu_all_sets:
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

View File

@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

View File

@@ -93,8 +93,8 @@ for _name in mmlu_all_sets:
},
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+inferencer=dict(type=PPLInferencer),
)
mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

View File

@@ -44,8 +44,8 @@ for k in [0, 1, 5]:
),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+inferencer=dict(type=GenInferencer, max_out_len=50),
)
nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")

View File

@@ -45,8 +45,8 @@ for k in [0, 1, 5]:
),
ice_token="</E>",
),
-retriever=dict(type=FixKRetriever),
-inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+inferencer=dict(type=GenInferencer, max_out_len=50),
)
triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")

View File

@@ -34,8 +34,8 @@ infer_cfg = dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
-retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples.
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions.
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
+inferencer=dict(type=GenInferencer), # Method used to generate predictions.
)
```

View File

@@ -34,8 +34,8 @@ infer_cfg=dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
-retriever=dict(type=FixKRetriever), # 定义 in context example 的获取方式
-inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # 使用何种方式推理得到 prediction
+retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # 定义 in context example 的获取方式
+inferencer=dict(type=GenInferencer), # 使用何种方式推理得到 prediction
)
```

View File

@@ -55,10 +55,7 @@ class AgentInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in retriever.__class__.__name__:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# Create tmp json file for saving intermediate results and future
# resuming

View File

@@ -59,7 +59,6 @@ class AttackInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
-fix_id_list: Optional[List[int]] = None,
dataset_cfg: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
@@ -78,7 +77,6 @@ class AttackInferencer(BaseInferencer):
self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
-self.fix_id_list = fix_id_list
if self.model.is_api and save_every is None:
save_every = 1
@@ -94,10 +92,7 @@ class AttackInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in self.retriever.__class__.__name__:
-    ice_idx_list = self.retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = self.retriever.retrieve()
+ice_idx_list = self.retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa

View File

@@ -25,9 +25,6 @@ class BaseInferencer:
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
-api_name (:obj:`str`, optional): Name of API service.
-call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
-    determined by :obj:`api_name`.
"""
model = None
@@ -38,8 +35,15 @@ class BaseInferencer:
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
+fix_id_list: Optional[List[int]] = None,
**kwargs,
) -> None:
+if fix_id_list:
+    raise ValueError('Passing fix_id_list to Inferencer is no longer '
+                     'allowed. Please pass it to FixKRetriever '
+                     'instead.')
self.model = model
self.max_seq_len = max_seq_len
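
With this guard in place, a stale config that still passes `fix_id_list` through an inferencer now fails fast instead of silently ignoring the indices. A rough sketch of the behaviour follows; the import path is assumed to be the module's usual location, and the `model=None` stand-in is made up for illustration and is not part of this diff.

```python
# Conceptual sketch only: the guard raises before the model is touched,
# so a dummy model value is enough to demonstrate the error.
from opencompass.openicl.icl_inferencer.icl_base_inferencer import BaseInferencer

try:
    BaseInferencer(model=None, fix_id_list=[0, 1, 2])
except ValueError as err:
    print(err)  # "Passing fix_id_list to Inferencer is no longer allowed. ..."
```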

View File

@@ -54,7 +54,6 @@ class CLPInferencer(BaseInferencer):
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
-fix_id_list: Optional[List[int]] = None,
single_token: bool = True,
**kwargs) -> None:
super().__init__(
@@ -66,7 +65,6 @@ class CLPInferencer(BaseInferencer):
**kwargs,
)
-self.fix_id_list = fix_id_list
# TODO: support multiple token
assert single_token, 'Only support single token choice currently.'
self.single_token = single_token
@@ -103,10 +101,7 @@ class CLPInferencer(BaseInferencer):
raise ValueError(err_msg)
# 2. Get results of retrieval process
-if self.fix_id_list:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Generate in-context examples for testing inputs
for idx in range(len(ice_idx_list)):

View File

@@ -51,7 +51,6 @@ class GenInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
-fix_id_list: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
model=model,
@@ -64,7 +63,6 @@ class GenInferencer(BaseInferencer):
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
-self.fix_id_list = fix_id_list
if self.model.is_api and save_every is None:
save_every = 1
@@ -85,10 +83,7 @@ class GenInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in retriever.__class__.__name__:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -220,10 +215,7 @@ class GLMChoiceInferencer(GenInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in retriever.__class__.__name__:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(

View File

@@ -41,7 +41,6 @@ class PPLInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
labels: Optional[List] = None,
-fix_id_list: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
model=model,
@@ -53,7 +52,6 @@ class PPLInferencer(BaseInferencer):
)
self.labels = labels
-self.fix_id_list = fix_id_list
def inference(self,
retriever: BaseRetriever,
@@ -75,10 +73,7 @@ class PPLInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if self.fix_id_list:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Get labels of all the classes
if self.labels is None:

View File

@@ -52,7 +52,6 @@ class SCInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
-fix_id_list: Optional[List[int]] = None,
sc_size: Optional[int] = 1,
infer_type: Optional[str] = '',
generation_kwargs: dict = {},
@@ -69,7 +68,6 @@ class SCInferencer(BaseInferencer):
self.gen_field_replace_token = gen_field_replace_token
self.generation_kwargs = generation_kwargs
self.max_out_len = max_out_len
-self.fix_id_list = fix_id_list
self.sc_size = sc_size
if self.model.is_api and save_every is None:
@@ -91,10 +89,7 @@ class SCInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in retriever.__class__.__name__:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(

View File

@@ -46,7 +46,6 @@ class ToTInferencer(GenInferencer):
`save_every` epochs.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
-fix_id_list (:obj:`List[int]`, optional): List of indices to fix
naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
ToT + BFS.
prompt_wrapper (:obj:`dict`): wrapper for prompts
@@ -76,7 +75,6 @@ class ToTInferencer(GenInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
-fix_id_list: Optional[List[int]] = None,
naive_run: bool = False,
prompt_wrapper: dict = {},
prompt_sample: str = 'standard',
@@ -97,7 +95,6 @@ class ToTInferencer(GenInferencer):
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
save_every=save_every,
-fix_id_list=fix_id_list,
sc_size=n_evaluate_sample,
**kwargs,
)
@@ -319,10 +316,7 @@ class ToTInferencer(GenInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
-if 'Fix' in retriever.__class__.__name__:
-    ice_idx_list = retriever.retrieve(self.fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(

View File

@@ -19,6 +19,8 @@ class FixKRetriever(BaseRetriever):
Args:
dataset (`BaseDataset`): Any BaseDataset instances.
Attributes of ``reader``, ``train`` and ``test`` will be used.
+fix_id_list (List[int]): List of in-context example indices for every
+    test prompts.
ice_separator (`Optional[str]`): The separator between each in-context
example template when origin `PromptTemplate` is provided. Defaults
to '\n'.
@@ -31,22 +33,19 @@ class FixKRetriever(BaseRetriever):
def __init__(self,
dataset,
+fix_id_list: List[int],
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
+self.fix_id_list = fix_id_list
-def retrieve(self, id_list: List[int]):
-"""Retrieve the in-context example index for each test example.
-Args:
-    id_list (List[int]): List of in-context example indices for every
-        test prompts.
-"""
+def retrieve(self):
+"""Retrieve the in-context example index for each test example."""
num_idx = len(self.index_ds)
-for idx in id_list:
+for idx in self.fix_id_list:
assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
rtr_idx_list = []
for _ in trange(len(self.test_ds), disable=not self.is_main_process):
-rtr_idx_list.append(id_list)
+rtr_idx_list.append(self.fix_id_list)
return rtr_idx_list
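
After this change the retriever owns the indices, so callers simply invoke `retrieve()` with no arguments. A minimal usage sketch follows; `my_dataset` is a stand-in for any BaseDataset instance and is not defined in this diff, and the import path is assumed to be the usual OpenCompass location.

```python
# Sketch only: my_dataset is hypothetical and must be a BaseDataset instance.
from opencompass.openicl.icl_retriever import FixKRetriever

retriever = FixKRetriever(my_dataset, fix_id_list=[0, 1, 2, 3, 4], ice_num=5)
ice_idx_list = retriever.retrieve()  # every test example gets [0, 1, 2, 3, 4]
```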

View File

@@ -56,6 +56,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
for k, v in dataset_cfg.infer_cfg.items():
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+# A compromise for the hash consistency
+if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+    fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+    dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
hash_object = hashlib.sha256(d_json.encode())
return hash_object.hexdigest()

View File

@@ -61,7 +61,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
infer_cfg = dataset_cfg.get('infer_cfg')
-fix_id_list = infer_cfg.inferencer.get('fix_id_list', [])
dataset = build_dataset_from_cfg(dataset_cfg)
ice_template = None
@@ -76,10 +75,7 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
infer_cfg['retriever']['dataset'] = dataset
retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])
-if fix_id_list:
-    ice_idx_list = retriever.retrieve(fix_id_list)
-else:
-    ice_idx_list = retriever.retrieve()
+ice_idx_list = retriever.retrieve()
assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
'Only PPLInferencer and GenInferencer are supported'

View File

@@ -45,6 +45,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
for k, v in dataset_cfg.infer_cfg.items():
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+# A compromise for the hash consistency
+if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+    fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+    dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
hash_object = hashlib.sha256(d_json.encode())
return hash_object.hexdigest()