[Refine] Refine PR #122 (#123)

* update

* update
This commit is contained in:
Haodong Duan 2023-08-03 14:54:38 +08:00 committed by GitHub
parent 191a3f6f9d
commit d17a5b94fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
def __init__(self, metric: str, seed: int = 0) -> None: def __init__(self, metric: str, seed: int = 0) -> None:
self.metric = metric self.metric = metric
random.seed(seed) self.seed = seed
np.random.seed(seed)
super().__init__() super().__init__()
def _preprocess(self, predictions: List, references: List) -> dict: def _preprocess(self, predictions: List, references: List) -> dict:
@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator):
Returns: Returns:
dict: calculated scores. dict: calculated scores.
""" """
random_state = random.getstate()
np_random_state = np.random.get_state()
random.seed(self.seed)
np.random.seed(self.seed)
if len(predictions) != len(references): if len(predictions) != len(references):
return { return {
'error': 'error':
@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator):
} }
metric = evaluate.load(self.metric) metric = evaluate.load(self.metric)
scores = metric.compute(**self._preprocess(predictions, references)) scores = metric.compute(**self._preprocess(predictions, references))
return self._postprocess(scores) result = self._postprocess(scores)
random.setstate(random_state)
np.random.set_state(np_random_state)
return result
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()