vllm_hairuo/ihp/model/loss.py
2024-10-25 17:16:26 +08:00

23 lines
606 B
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/21
#
import torch
from hakkero.dataset import IGNORE_INDEX
def lm_cross_entropy(logits, labels, reduction="mean"):
    """Compute token-level cross entropy between ``logits`` and ``labels``.

    No label shifting is performed here (unlike the Hugging Face convention):
    the dataset is expected to deliver labels that are already shifted.

    Args:
        logits: tensor whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        labels: tensor of target class indices; flattened to 1-D. After
            flattening it must line up element-wise with the flattened
            logits rows. Entries equal to ``IGNORE_INDEX`` are skipped.
        reduction: forwarded unchanged to
            ``torch.nn.functional.cross_entropy``.

    Returns:
        The result of ``torch.nn.functional.cross_entropy`` for the given
        ``reduction``.
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # sliced or transposed tensors; for contiguous inputs it is still a
    # zero-copy view, so existing callers see no difference.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    return torch.nn.functional.cross_entropy(
        flat_logits,
        flat_labels,
        ignore_index=IGNORE_INDEX,
        reduction=reduction,
    )
def lm_z_loss(logits):
    """Return the mean of the squared per-position maximum logit.

    The maximum is taken over the last dimension of ``logits``; the squared
    maxima are then averaged over all remaining positions.

    Args:
        logits: tensor whose last dimension is reduced by ``max``.

    Returns:
        A scalar (0-dim) tensor.
    """
    # Tensor.max(dim) returns a (values, indices) named tuple; ``values`` is
    # an attribute holding a Tensor, not a method, so it must not be called.
    max_logits = logits.max(-1).values
    return max_logits.square().mean()