import copy
from typing import Dict, Sequence

import torch
import transformers
import tokenizers
from packaging import version

from opencompass.models.ola import conversation as conversation_lib
from opencompass.models.ola.arguments import DataArguments
from opencompass.models.ola.constants import (DEFAULT_SPEECH_TOKEN,
                                              IGNORE_INDEX,
                                              IMAGE_TOKEN_INDEX,
                                              SPEECH_TOKEN_INDEX)
from opencompass.models.ola.model import *

IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt containing '<speech>' placeholders.

    Each text chunk is tokenized separately and the chunks are re-joined with
    `speech_token_index`, keeping a single BOS token at the start if present.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt containing '<image>' placeholders.

    Each text chunk is tokenized separately and the chunks are re-joined with
    `image_token_index`, keeping a single BOS token at the start if present.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt containing '<speech><image>' placeholders, re-joining
    the text chunks with the speech and image placeholder indices."""
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt in which an '<image>' placeholder is followed by the
    line "User's question in speech: <speech>", rebuilding that fixed template
    between chunks from its tokenized pieces."""
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # Separator pieces for the fixed template between prompt chunks.
    nl_tokens = tokenizer("\n").input_ids
    special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens]

    for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
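

# Illustrative sketch (not executed here): the helpers above follow the
# LLaVA-style placeholder scheme. Assuming `tok` is a HuggingFace tokenizer
# loaded elsewhere (a hypothetical name used only for this example):
#
#     ids = tokenizer_speech_token(
#         'Please transcribe the audio.\n<speech>', tok, return_tensors='pt')
#
# `ids` is a 1-D LongTensor in which each '<speech>' occurrence is replaced by
# the single index SPEECH_TOKEN_INDEX; the model later substitutes projected
# speech features at these placeholder positions.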


def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_speech:
        input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets: only the assistant responses are kept as supervision labels.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_speech:
                round_len = len(tokenizer_speech_token(rou, tokenizer))
                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # FIXME: tokenizer bug
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_SPEECH_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_SPEECH_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_speech=has_speech)
    raise NotImplementedError
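

# Minimal usage sketch (an assumption about how callers drive this module, not
# part of it): `sources` is a list of conversations, each a list of
# {"from": "human" | "gpt", "value": ...} turns, with a placeholder such as
# DEFAULT_SPEECH_TOKEN embedded in the human turn. Assuming the default
# conversation template is a v1 template:
#
#     sources = [[
#         {"from": "human", "value": "<speech>\nWhat is said in the audio?"},
#         {"from": "gpt", "value": "The speaker greets the audience."},
#     ]]
#     batch = preprocess(sources, tokenizer, has_speech=True)
#
# `batch["input_ids"]` holds the tokenized conversations and `batch["labels"]`
# the same ids with all non-assistant spans set to IGNORE_INDEX.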