[Refactor] Move fix_id_list to Retriever (#442)
* [Refactor] Move fix_id_list to Retriever
* update
* move to base
* fix
parent 767c12a660
commit 119bfd1569
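In short, `fix_id_list` now lives on `FixKRetriever` instead of on the individual inferencers. A minimal before/after sketch of a dataset config, assuming the usual OpenCompass import paths; the snippet is illustrative and not copied from any file in this diff:

```python
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

# Before this commit: the inferencer carried the in-context example indices.
old_infer_cfg = dict(
    retriever=dict(type=FixKRetriever),
    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
)

# After this commit: FixKRetriever owns fix_id_list; inferencers reject it.
new_infer_cfg = dict(
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type=GenInferencer),
)
```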
@@ -23,8 +23,8 @@ CoLA_infer_cfg = dict(
         },
         ice_token='</E>',
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
+    inferencer=dict(type=PPLInferencer))

 CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
@@ -22,8 +22,8 @@ QQP_infer_cfg = dict(
         },
         ice_token='</E>',
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer))

 QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
@@ -161,8 +161,8 @@ for _split in ["val", "test"]:
             ]),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     ceval_eval_cfg = dict(
@@ -161,8 +161,8 @@ for _split in ["val"]:
             ]),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     ceval_eval_cfg = dict(
@@ -163,8 +163,8 @@ for _split in ["val"]:
             },
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=PPLInferencer),
     )

     ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
@@ -163,8 +163,8 @@ for _split in ["val", "test"]:
             },
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=PPLInferencer),
     )

     ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
@@ -28,8 +28,8 @@ cmb_infer_cfg = dict(
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )

 cmb_datasets.append(
@@ -96,8 +96,8 @@ for _name in cmmlu_all_sets:
             ]),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     cmmlu_eval_cfg = dict(
@@ -98,8 +98,8 @@ for _name in cmmlu_all_sets:
             },
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=PPLInferencer),
     )

     cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
@@ -29,8 +29,8 @@ mmlu_infer_cfg = dict(
             dict(role='BOT', prompt='{target}\n')
         ])),
     prompt_template=mmlu_prompt_template,
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer))

 mmlu_eval_cfg = dict(
     evaluator=dict(type=AccEvaluator),
@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
             ),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     mmlu_eval_cfg = dict(
@@ -87,8 +87,8 @@ for _name in mmlu_all_sets:
                 f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     mmlu_eval_cfg = dict(
@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
             ),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
     )

     mmlu_eval_cfg = dict(
@@ -93,8 +93,8 @@ for _name in mmlu_all_sets:
             },
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=PPLInferencer),
     )

     mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
@@ -44,8 +44,8 @@ for k in [0, 1, 5]:
             ),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+        retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+        inferencer=dict(type=GenInferencer, max_out_len=50),
     )

     nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
@@ -45,8 +45,8 @@ for k in [0, 1, 5]:
             ),
             ice_token="</E>",
         ),
-        retriever=dict(type=FixKRetriever),
-        inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+        retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+        inferencer=dict(type=GenInferencer, max_out_len=50),
     )

     triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")
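The NQ and TriviaQA configs above build one config per shot count k, and with this change the k-shot index list moves onto the retriever. A minimal sketch of that pattern, assuming the usual OpenCompass import paths; the helper name `make_kshot_infer_cfg` is illustrative and not from the repository:

```python
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

def make_kshot_infer_cfg(k: int) -> dict:
    """Return a k-shot infer_cfg where fix_id_list is owned by the retriever."""
    return dict(
        retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
        inferencer=dict(type=GenInferencer, max_out_len=50),
    )

# One variant per shot count, mirroring the `for k in [0, 1, 5]` loop above.
infer_cfgs = {k: make_kshot_infer_cfg(k) for k in [0, 1, 5]}
```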
@@ -34,8 +34,8 @@ infer_cfg = dict(
         template='Solve the following questions.\n</E>{question}\n{answer}',
         ice_token="</E>"
     ),
-    retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples.
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions.
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
+    inferencer=dict(type=GenInferencer), # Method used to generate predictions.
 )
 ```
@@ -34,8 +34,8 @@ infer_cfg=dict(
         template='Solve the following questions.\n</E>{question}\n{answer}',
         ice_token="</E>"
     ),
-    retriever=dict(type=FixKRetriever), # Defines how in-context examples are retrieved.
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # The inference method used to obtain predictions.
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Defines how in-context examples are retrieved.
+    inferencer=dict(type=GenInferencer), # The inference method used to obtain predictions.
 )
 ```
@@ -55,10 +55,7 @@ class AgentInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # Create tmp json file for saving intermediate results and future
         # resuming
@@ -59,7 +59,6 @@ class AttackInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             dataset_cfg: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
@@ -78,7 +77,6 @@ class AttackInferencer(BaseInferencer):
         self.output_column = dataset_cfg['reader_cfg']['output_column']
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list

         if self.model.is_api and save_every is None:
             save_every = 1
@@ -94,10 +92,7 @@ class AttackInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in self.retriever.__class__.__name__:
-            ice_idx_list = self.retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = self.retriever.retrieve()
+        ice_idx_list = self.retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa
@@ -25,9 +25,6 @@ class BaseInferencer:
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
             `JSON` file.
-        api_name (:obj:`str`, optional): Name of API service.
-        call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
-            determined by :obj:`api_name`.
     """
     model = None

@@ -38,8 +35,15 @@ class BaseInferencer:
             batch_size: Optional[int] = 1,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
+            fix_id_list: Optional[List[int]] = None,
             **kwargs,
     ) -> None:
+
+        if fix_id_list:
+            raise ValueError('Passing fix_id_list to Inferencer is no longer '
+                             'allowed. Please pass it to FixKRetriever '
+                             'instead.')
+
         self.model = model

         self.max_seq_len = max_seq_len
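With the guard now in `BaseInferencer`, any subclass that still receives `fix_id_list` fails immediately instead of silently ignoring it. A hedged sketch of the resulting behaviour; constructing the base class directly with `model=None` is purely illustrative and assumes the signature shown in the hunk above:

```python
# Hypothetical demonstration of the new guard; not part of the repository.
from opencompass.openicl.icl_inferencer.icl_base_inferencer import BaseInferencer

try:
    # Legacy-style call: fix_id_list is no longer accepted by any inferencer.
    BaseInferencer(model=None, fix_id_list=[0, 1, 2, 3, 4])
except ValueError as err:
    # Expected message: "Passing fix_id_list to Inferencer is no longer
    # allowed. Please pass it to FixKRetriever instead."
    print(err)
```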
@@ -54,7 +54,6 @@ class CLPInferencer(BaseInferencer):
             batch_size: Optional[int] = 1,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
-            fix_id_list: Optional[List[int]] = None,
             single_token: bool = True,
             **kwargs) -> None:
         super().__init__(
@@ -66,7 +65,6 @@ class CLPInferencer(BaseInferencer):
             **kwargs,
         )

-        self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
         self.single_token = single_token
@@ -103,10 +101,7 @@ class CLPInferencer(BaseInferencer):
             raise ValueError(err_msg)

         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
@@ -51,7 +51,6 @@ class GenInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -64,7 +63,6 @@ class GenInferencer(BaseInferencer):

         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list

         if self.model.is_api and save_every is None:
             save_every = 1
@@ -85,10 +83,7 @@ class GenInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -220,10 +215,7 @@ class GLMChoiceInferencer(GenInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -41,7 +41,6 @@ class PPLInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             labels: Optional[List] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -53,7 +52,6 @@ class PPLInferencer(BaseInferencer):
         )

         self.labels = labels
-        self.fix_id_list = fix_id_list

     def inference(self,
                   retriever: BaseRetriever,
@@ -75,10 +73,7 @@ class PPLInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Get labels of all the classes
         if self.labels is None:
@@ -52,7 +52,6 @@ class SCInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             sc_size: Optional[int] = 1,
             infer_type: Optional[str] = '',
             generation_kwargs: dict = {},
@@ -69,7 +68,6 @@ class SCInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.generation_kwargs = generation_kwargs
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
         self.sc_size = sc_size

         if self.model.is_api and save_every is None:
@@ -91,10 +89,7 @@ class SCInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -46,7 +46,6 @@ class ToTInferencer(GenInferencer):
             `save_every` epochs.
         generation_kwargs (:obj:`Dict`, optional): Parameters for the
             :obj:`model.generate()` method.
-        fix_id_list (:obj:`List[int]`, optional): List of indices to fix
         naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
             ToT + BFS.
         prompt_wrapper (:obj:`dict`): wrapper for prompts
@@ -76,7 +75,6 @@ class ToTInferencer(GenInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             naive_run: bool = False,
             prompt_wrapper: dict = {},
             prompt_sample: str = 'standard',
@@ -97,7 +95,6 @@ class ToTInferencer(GenInferencer):
             output_json_filename=output_json_filename,
             output_json_filepath=output_json_filepath,
             save_every=save_every,
-            fix_id_list=fix_id_list,
             sc_size=n_evaluate_sample,
             **kwargs,
         )
@@ -319,10 +316,7 @@ class ToTInferencer(GenInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -19,6 +19,8 @@ class FixKRetriever(BaseRetriever):
     Args:
         dataset (`BaseDataset`): Any BaseDataset instances.
             Attributes of ``reader``, ``train`` and ``test`` will be used.
+        fix_id_list (List[int]): List of in-context example indices for every
+            test prompts.
         ice_separator (`Optional[str]`): The separator between each in-context
             example template when origin `PromptTemplate` is provided. Defaults
             to '\n'.
@@ -31,22 +33,19 @@ class FixKRetriever(BaseRetriever):

     def __init__(self,
                  dataset,
+                 fix_id_list: List[int],
                  ice_separator: Optional[str] = '\n',
                  ice_eos_token: Optional[str] = '\n',
                  ice_num: Optional[int] = 1) -> None:
         super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
+        self.fix_id_list = fix_id_list

-    def retrieve(self, id_list: List[int]):
-        """Retrieve the in-context example index for each test example.
-
-        Args:
-            id_list (List[int]): List of in-context example indices for every
-                test prompts.
-        """
+    def retrieve(self):
+        """Retrieve the in-context example index for each test example."""
         num_idx = len(self.index_ds)
-        for idx in id_list:
+        for idx in self.fix_id_list:
             assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
         rtr_idx_list = []
         for _ in trange(len(self.test_ds), disable=not self.is_main_process):
-            rtr_idx_list.append(id_list)
+            rtr_idx_list.append(self.fix_id_list)
         return rtr_idx_list
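On the retriever side, the index list is now a constructor argument and `retrieve()` takes no parameters. A small usage sketch under that assumption; `dataset` stands for any built OpenCompass dataset, which is out of scope here:

```python
from opencompass.openicl.icl_retriever import FixKRetriever

def build_ice_indices(dataset):
    """Return the fixed in-context example indices for every test example."""
    retriever = FixKRetriever(dataset, fix_id_list=[0, 1, 2, 3, 4])
    # One identical index list per test example, e.g. [[0, 1, 2, 3, 4], ...].
    return retriever.retrieve()
```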
@@ -56,6 +56,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
@@ -61,7 +61,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):

     infer_cfg = dataset_cfg.get('infer_cfg')

-    fix_id_list = infer_cfg.inferencer.get('fix_id_list', [])
     dataset = build_dataset_from_cfg(dataset_cfg)

     ice_template = None
@@ -76,10 +75,7 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg['retriever']['dataset'] = dataset
     retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])

-    if fix_id_list:
-        ice_idx_list = retriever.retrieve(fix_id_list)
-    else:
-        ice_idx_list = retriever.retrieve()
+    ice_idx_list = retriever.retrieve()

     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
@@ -45,6 +45,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
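The "compromise for the hash consistency" moves `fix_id_list` back under the inferencer key before hashing, so prompt hashes computed for old-style configs keep matching the new-style ones. A standalone sketch of that idea, not the repository's helper itself:

```python
import hashlib
import json

def normalise(infer_cfg: dict) -> dict:
    """Map a new-style config back onto the old layout before hashing."""
    cfg = json.loads(json.dumps(infer_cfg))  # cheap deep copy
    if 'fix_id_list' in cfg['retriever']:
        cfg['inferencer']['fix_id_list'] = cfg['retriever'].pop('fix_id_list')
    return cfg

def digest(cfg: dict) -> str:
    return hashlib.sha256(json.dumps(cfg, sort_keys=True).encode()).hexdigest()

old_style = dict(retriever=dict(type='FixKRetriever'),
                 inferencer=dict(type='GenInferencer', fix_id_list=[0, 1]))
new_style = dict(retriever=dict(type='FixKRetriever', fix_id_list=[0, 1]),
                 inferencer=dict(type='GenInferencer'))

# Both layouts hash to the same value after normalisation.
assert digest(normalise(new_style)) == digest(old_style)
```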