diff --git a/opencompass/openicl/icl_retriever/icl_mdl_retriever.py b/opencompass/openicl/icl_retriever/icl_mdl_retriever.py index 43fe12d1..f92e1acf 100644 --- a/opencompass/openicl/icl_retriever/icl_mdl_retriever.py +++ b/opencompass/openicl/icl_retriever/icl_mdl_retriever.py @@ -141,6 +141,7 @@ class MDLRetriever(TopkRetriever): """Retrieve the in-context example index for each test example.""" return self.topk_search() + @torch.no_grad() def cal_ce(self, input_texts: List[str], mask_length=None): if self.metric_model is None: logger.info(