#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian
# @date: 2024/07/21
#

import torch

from hakkero.dataset import IGNORE_INDEX


def lm_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """Compute the token-level cross entropy between ``logits`` and ``labels``.

    ``logits`` is flattened to two dimensions (keeping its last dimension) and
    ``labels`` to one dimension before being passed to
    ``torch.nn.functional.cross_entropy``. Positions whose label equals
    ``IGNORE_INDEX`` are excluded from the loss.

    No label shifting is performed here: unlike the huggingface convention,
    the labels are expected to have been shifted already by the dataset.

    Args:
        logits: Unnormalized scores; the last dimension is the class dimension.
        labels: Target class indices, one per leading position of ``logits``.
        reduction: Reduction mode forwarded to ``cross_entropy``.

    Returns:
        The loss tensor produced by ``torch.nn.functional.cross_entropy``.
    """
    # reshape (rather than view) also accepts non-contiguous inputs, e.g.
    # sliced or transposed tensors; for contiguous inputs it is still a view.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    return torch.nn.functional.cross_entropy(
        flat_logits, flat_labels, ignore_index=IGNORE_INDEX, reduction=reduction
    )


def lm_z_loss(logits: torch.Tensor) -> torch.Tensor:
    """Return the mean of the squared maximum logit along the last dimension.

    Args:
        logits: Unnormalized scores; the maximum is taken over the last
            dimension.

    Returns:
        A scalar tensor: ``mean(max(logits, dim=-1) ** 2)``.
    """
    # Tensor.max(dim) returns a (values, indices) named tuple; ``values`` is an
    # attribute holding a Tensor, not a method, so it must not be called.
    return logits.max(-1).values.square().mean()