From d17a5b94fa8afbe79ef6e556f05417851fbec7c8 Mon Sep 17 00:00:00 2001 From: Haodong Duan Date: Thu, 3 Aug 2023 14:54:38 +0800 Subject: [PATCH] [Refine] Refine PR #122 (#123) * update * update --- .../openicl/icl_evaluator/icl_hf_evaluator.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py index 3e2e4bcc..9d40aaf3 100644 --- a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py @@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator): def __init__(self, metric: str, seed: int = 0) -> None: self.metric = metric - random.seed(seed) - np.random.seed(seed) + self.seed = seed super().__init__() def _preprocess(self, predictions: List, references: List) -> dict: @@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator): Returns: dict: calculated scores. """ + random_state = random.getstate() + np_random_state = np.random.get_state() + + random.seed(self.seed) + np.random.seed(self.seed) if len(predictions) != len(references): return { 'error': @@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator): } metric = evaluate.load(self.metric) scores = metric.compute(**self._preprocess(predictions, references)) - return self._postprocess(scores) + result = self._postprocess(scores) + random.setstate(random_state) + np.random.set_state(np_random_state) + return result @ICL_EVALUATORS.register_module()