mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
parent
191a3f6f9d
commit
d17a5b94fa
@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self, metric: str, seed: int = 0) -> None:
|
||||
self.metric = metric
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
self.seed = seed
|
||||
super().__init__()
|
||||
|
||||
def _preprocess(self, predictions: List, references: List) -> dict:
|
||||
@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator):
|
||||
Returns:
|
||||
dict: calculated scores.
|
||||
"""
|
||||
random_state = random.getstate()
|
||||
np_random_state = np.random.get_state()
|
||||
|
||||
random.seed(self.seed)
|
||||
np.random.seed(self.seed)
|
||||
if len(predictions) != len(references):
|
||||
return {
|
||||
'error':
|
||||
@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator):
|
||||
}
|
||||
metric = evaluate.load(self.metric)
|
||||
scores = metric.compute(**self._preprocess(predictions, references))
|
||||
return self._postprocess(scores)
|
||||
result = self._postprocess(scores)
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(np_random_state)
|
||||
return result
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
|
Loading…
Reference in New Issue
Block a user