This commit is contained in:
zhangsongyang 2025-05-19 13:39:11 +00:00
parent 580a2b7980
commit 8274603540
3 changed files with 12 additions and 3 deletions

View File

@ -204,7 +204,11 @@ def math_postprocess_v2(text: str) -> str:
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MATHEvaluator(BaseEvaluator): class MATHEvaluator(BaseEvaluator):
def __init__(self, version='v1'): def __init__(self,
version='v1',
pred_postprocessor=None): # 可能需要接收父类__init__的参数
super().__init__(
pred_postprocessor=pred_postprocessor) # 调用父类的__init__
assert version in ['v1', 'v2'] assert version in ['v1', 'v2']
self.version = version self.version = version

View File

@ -280,7 +280,11 @@ class MusrDataset(BaseDataset):
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
class MusrEvaluator(BaseEvaluator): class MusrEvaluator(BaseEvaluator):
def __init__(self, answer_index_modifier=1, self_consistency_n=1): def __init__(self,
answer_index_modifier=1,
self_consistency_n=1,
pred_postprocessor=None):
super().__init__(pred_postprocessor=pred_postprocessor)
self.answer_index_modifier = answer_index_modifier self.answer_index_modifier = answer_index_modifier
self.self_consistency_n = self_consistency_n self.self_consistency_n = self_consistency_n

View File

@ -93,7 +93,8 @@ class BaseEvaluator:
return g_passk_details return g_passk_details
def pred_postprocess(self, predictions: List) -> Dict: def pred_postprocess(self, predictions: List) -> Dict:
if self.pred_postprocessor is None: if not hasattr(
self, 'pred_postprocessor') or self.pred_postprocessor is None:
return predictions return predictions
else: else:
kwargs = deepcopy(self.pred_postprocessor) kwargs = deepcopy(self.pred_postprocessor)