fix potential oom issue (#387)

This commit is contained in:
cdpath 2023-09-12 10:41:03 +08:00 committed by GitHub
parent b9b145c335
commit 722eb39526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -141,6 +141,7 @@ class MDLRetriever(TopkRetriever):
"""Retrieve the in-context example index for each test example."""
return self.topk_search()
@torch.no_grad()
def cal_ce(self, input_texts: List[str], mask_length=None):
if self.metric_model is None:
logger.info(