23 lines
606 B
Python
23 lines
606 B
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
#
|
||
|
# Copyright @2024 AI. Inspur Inc.
|
||
|
#
|
||
|
# @author: sunxian <sunxian@inspur.com>
|
||
|
# @date: 2024/07/21
|
||
|
#
|
||
|
|
||
|
import torch
|
||
|
from hakkero.dataset import IGNORE_INDEX
|
||
|
|
||
|
|
||
|
def lm_cross_entropy(logits, labels, reduction="mean"):
    """Compute the cross-entropy loss between ``logits`` and ``labels``.

    ``logits`` is flattened to 2-D, keeping its last dimension as the class
    dimension, and ``labels`` is flattened to 1-D before both are handed to
    ``torch.nn.functional.cross_entropy``.

    Args:
        logits: tensor of scores whose last dimension indexes the classes.
        labels: tensor of target indices; after flattening it must line up
            with the flattened rows of ``logits``. Entries equal to
            ``IGNORE_INDEX`` are passed as ``ignore_index`` and therefore
            do not contribute to the loss.
        reduction: forwarded unchanged to ``cross_entropy``
            (defaults to ``"mean"``).

    Returns:
        The tensor returned by ``torch.nn.functional.cross_entropy``.
    """
    # No HuggingFace-style label shifting here: the shift is already done
    # in the dataset.
    num_classes = logits.shape[-1]
    flat_logits = logits.view(-1, num_classes)
    flat_labels = labels.view(-1)
    return torch.nn.functional.cross_entropy(
        flat_logits,
        flat_labels,
        ignore_index=IGNORE_INDEX,
        reduction=reduction,
    )
|
||
|
|
||
|
|
||
|
def lm_z_loss(logits):
    """Return the mean of the squared maximum logit along the last dimension.

    For every position, the largest value over the last dimension of
    ``logits`` is taken, squared, and the squares are averaged into a scalar.
    NOTE(review): presumably used as an auxiliary "max-z" regularizer that
    keeps logit magnitudes small -- confirm against the training loop.

    Args:
        logits: floating-point tensor; the maximum is taken over its last
            dimension.

    Returns:
        A zero-dimensional tensor holding the mean squared maximum logit.
    """
    # ``Tensor.max(dim)`` returns a (values, indices) named tuple, so
    # ``values`` is an attribute holding a tensor, not a method; calling it
    # raised ``TypeError: 'Tensor' object is not callable``.
    max_logits = logits.max(-1).values
    return max_logits.square().mean()
|