Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support SEED-Bench (#203)

* support seedbench
* update docstrings
* update
* update
* update
* update according to review
* rebase
* fix lint
* update

parent ae3c1869da
commit 0fa2482661

configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from opencompass.multimodal.models.minigpt_4 import MiniGPT4SEEDBenchPromptConstructor  # noqa

# dataloader settings
image_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'answer', 'choices', 'data_type', 'question_type_id',
             'index', 'data_path', 'question_id'
         ])
]
video_pipeline = [
    dict(type='mmaction.Resize', scale=(224, 224), interpolation='bicubic'),
    dict(type='mmaction.CenterCrop', crop_size=224),
    dict(type='Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'answer', 'choices', 'data_type', 'question_type_id',
             'index', 'data_path', 'question_id'
         ])
]

dataset = dict(
    type='opencompass.SEEDBenchDataset',
    ann_file='data/seedbench/SEED-Bench.json',
    cc3m_path='data/seedbench/SEED-Bench-image',
    sthv2_path='data/seedbench/sthv2/videos',
    epic_kitchens_path='data/seedbench/3h91syskeag572hl6tvuovwv4d/videos/test',
    breakfast_path='data/seedbench/BreakfastII_15fps_qvga_sync',
    image_pipeline=image_pipeline,
    video_pipeline=video_pipeline,
    only_image=True)

minigpt_4_seedbench_dataloader = dict(batch_size=1,
                                      num_workers=4,
                                      dataset=dataset,
                                      collate_fn=dict(type='pseudo_collate'),
                                      sampler=dict(type='DefaultSampler',
                                                   shuffle=False))

# model settings
minigpt_4_seedbench_model = dict(
    type='minigpt-4',
    low_resource=False,
    llama_model='/path/to/vicuna/',
    prompt_constructor=dict(type=MiniGPT4SEEDBenchPromptConstructor,
                            image_prompt='###Human: <Img><ImageHere></Img>',
                            reply_prompt='###Assistant:'),
    post_processor=None,
    mode='loss')

# evaluation settings
minigpt_4_seedbench_evaluator = [dict(type='opencompass.SEEDBenchAcc')]

minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'
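
For reference, the MiniGPT4SEEDBenchPromptConstructor imported above (added later in this commit) joins image_prompt, the question, and reply_prompt with single spaces, so a prompt looks like the sketch below (the question text is made up for illustration):

    image_prompt = '###Human: <Img><ImageHere></Img>'
    reply_prompt = '###Assistant:'
    question = 'How many people appear in the image?'  # illustrative only
    prompt = image_prompt + ' ' + question + ' ' + reply_prompt
    # -> '###Human: <Img><ImageHere></Img> How many people appear in the image? ###Assistant:'
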
configs/multimodal/seedbench.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from mmengine.config import read_base

with read_base():
    from .minigpt_4.minigpt_4_7b_seedbench import (
        minigpt_4_seedbench_dataloader, minigpt_4_seedbench_evaluator,
        minigpt_4_load_from, minigpt_4_seedbench_model)

models = [minigpt_4_seedbench_model]
datasets = [minigpt_4_seedbench_dataloader]
evaluators = [minigpt_4_seedbench_evaluator]
load_froms = [minigpt_4_load_from]
num_gpus = 1
num_procs = 1
launcher = 'slurm'
opencompass/metrics/__init__.py
@@ -1,3 +1,4 @@
 from .dump_results import DumpResults
+from .seedbench import SEEDBenchAcc
 
-__all__ = ['DumpResults']
+__all__ = ['DumpResults', 'SEEDBenchAcc']
opencompass/metrics/seedbench.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import torch
from mmengine.evaluator import BaseMetric

from opencompass.registry import METRICS

EVAL_DIM_MAPPING = {
    1: 'Scene Understanding',
    2: 'Instance Identity',
    3: 'Instance Attributes',
    4: 'Instance Location',
    5: 'Instance Counting',
    6: 'Spatial Relations',
    7: 'Instance Interaction',
    8: 'Visual Reasoning',
    9: 'Text Recognition',
    10: 'Action Recognition',
    11: 'Action Prediction',
    12: 'Procedure Understanding',
}


@METRICS.register_module()
class SEEDBenchAcc(BaseMetric):
    """Compute results for SEED-Bench."""

    def process(self, data_batch, data_samples) -> None:
        for data_sample in data_samples:
            losses = data_sample['losses']
            class_ranks = torch.argsort(losses, dim=-1).cpu()
            pred_id = ['A', 'B', 'C', 'D'][class_ranks[0]]
            answer_record = {
                'q_id': data_sample['question_id'],
                'prediction': pred_id,
                'gt': data_sample['answer'],
                'q_type_id': data_sample['question_type_id'],
                'losses': [str(num) for num in list(losses.cpu().numpy())],
            }
            self.results.append(answer_record)

    def compute_metrics(self, results: list) -> dict:
        type_counts = {}
        correct_counts = {}
        out = {}
        out['answer_records'] = results
        for item in results:
            pred, gt = item['prediction'], item['gt']
            data_type = item['q_type_id']

            type_counts[data_type] = type_counts.get(data_type, 0) + 1
            if pred == gt:
                correct_counts[data_type] = correct_counts.get(data_type,
                                                               0) + 1

        total_count = 0
        total_correct = 0
        for data_type in type_counts.keys():
            accuracy = correct_counts.get(data_type,
                                          0) / type_counts[data_type] * 100
            category = EVAL_DIM_MAPPING[data_type]
            out[f'Data type {data_type} - {category}'] = accuracy

            total_count += type_counts[data_type]
            total_correct += correct_counts.get(data_type, 0)

        total_accuracy = total_correct / total_count * 100
        out['Total accuracy'] = total_accuracy
        return out
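
For a quick sanity check of the aggregation logic above, a minimal sketch (assuming opencompass and mmengine are importable; the records are hand-written, not real model output):

    from opencompass.metrics import SEEDBenchAcc

    toy_records = [
        {'q_id': 0, 'prediction': 'A', 'gt': 'A', 'q_type_id': 1},  # correct
        {'q_id': 1, 'prediction': 'B', 'gt': 'C', 'q_type_id': 1},  # wrong
        {'q_id': 2, 'prediction': 'D', 'gt': 'D', 'q_type_id': 5},  # correct
    ]
    out = SEEDBenchAcc().compute_metrics(toy_records)
    # out['Data type 1 - Scene Understanding'] -> 50.0
    # out['Data type 5 - Instance Counting']   -> 100.0
    # out['Total accuracy']                    -> ~66.67
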
opencompass/multimodal/datasets/__init__.py
@@ -1,3 +1,4 @@
 from .mmbench import MMBenchDataset
+from .seedbench import SEEDBenchDataset
 
-__all__ = ['MMBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
opencompass/multimodal/datasets/seedbench.py (new file, 173 lines)
@@ -0,0 +1,173 @@
import json
import os.path as osp
from typing import List

import av
import numpy as np
import torch
from decord import VideoReader, cpu
from mmengine.dataset import Compose
from PIL import Image
from torch.utils.data import Dataset

from opencompass.registry import DATASETS


@DATASETS.register_module()
class SEEDBenchDataset(Dataset):
    """Dataset to load SEED-Bench dataset.

    Args:
        ann_file (str): The path of the annotation file.
        cc3m_path (str): The data path of the image dimensions (1-9).
        sthv2_path (str): The data path of dimension 10.
        epic_kitchens_path (str): The data path of dimension 11.
        breakfast_path (str): The data path of dimension 12.
        image_pipeline (List[dict]): The data transforms for images.
        video_pipeline (List[dict]): The data transforms for videos.
        only_image (bool): Whether to run SEED-Bench only with image data.
            Defaults to True.
    """

    def __init__(
        self,
        ann_file: str,
        cc3m_path: str,
        sthv2_path: str,
        epic_kitchens_path: str,
        breakfast_path: str,
        image_pipeline: List[dict],
        video_pipeline: List[dict],
        only_image: bool = True,
    ) -> None:
        ann_file = json.load(open(ann_file, 'rb'))
        if 'questions' in ann_file.keys():
            self.ann_file = ann_file['questions']
        self.cc3m_path = cc3m_path
        self.sthv2_path = sthv2_path
        self.epic_kitchens_path = epic_kitchens_path
        self.breakfast_path = breakfast_path
        self.image_pipeline = Compose(image_pipeline)
        if only_image:
            image_ann_file = [
                ann for ann in self.ann_file if ann['data_type'] == 'image'
            ]
            self.ann_file = image_ann_file
        if not only_image:
            raise NotImplementedError
            self.video_pipeline = Compose(video_pipeline)

    def __len__(self) -> int:
        return len(self.ann_file)

    def __getitem__(self, idx: int) -> dict:
        item = self.ann_file[idx]
        data = {
            'question':
            item['question'],
            'answer':
            item['answer'],
            'choices': [
                item['choice_a'], item['choice_b'], item['choice_c'],
                item['choice_d']
            ],
            'data_type':
            item['data_type'],
            'question_id':
            item['question_id'],
            'question_type_id':
            item['question_type_id'],
            'index':
            idx,
        }

        if item['data_type'] == 'image':
            data_path = osp.join(self.cc3m_path, item['data_id'])
            raw_image = Image.open(open(data_path, 'rb')).convert('RGB')
            data['data_path'] = data_path
            data['img'] = raw_image
            data = self.image_pipeline(data)
        elif item['data_type'] == 'video':
            if item['question_type_id'] == 10:
                data_path = osp.join(self.sthv2_path, item['data_id'])
                data['data_path'] = data_path
            elif item['question_type_id'] == 11:
                data_path = osp.join(self.epic_kitchens_path, item['data_id'])
                data['data_path'] = data_path
                data['segment'] = item['segment']
            elif item['question_type_id'] == 12:
                data_path = osp.join(self.breakfast_path, item['data_id'])
                data['data_path'] = data_path
                data['segment'] = item['segment']
            else:
                raise ValueError('The question type id is not valid.')

            # preprocessing videos in evaluation dimension 10-12
            use_pyav = False
            if 'segment' in data.keys():
                segment = data['segment']
                if isinstance(segment[0], int):
                    # using pyav for decoding videos in evaluation dimension 12
                    use_pyav = True
                start, end = segment[0], segment[1]
            else:
                start = 0.0
                end = 0.0

            if use_pyav:
                # using pyav for videos in evaluation dimension 12
                reader = av.open(data_path)
                frames = [
                    torch.from_numpy(f.to_rgb().to_ndarray())
                    for f in reader.decode(video=0)
                ]
                video_len = len(frames)
                start_frame, end_frame = start, end
                end_frame = min(end_frame, video_len)
                offset = self.get_index(end_frame - start_frame, 8)
                frame_indices = offset + start_frame
                buffer = torch.stack([frames[idx] for idx in frame_indices])
                buffer = buffer.numpy()
            else:
                # using decord for videos in evaluating dimension 10-11
                import io

                import mmengine.fileio as fileio
                file_obj = io.BytesIO(fileio.get(data_path))
                vr = VideoReader(file_obj, num_threads=1, ctx=cpu(0))
                video_len = len(vr)
                fps = vr.get_avg_fps()
                if 'segment' in data.keys():
                    # obtain start and end frame for the video segment
                    # in evaluation dimension 11
                    start_frame = int(min(max(start * fps, 0), video_len - 1))
                    end_frame = int(min(max(end * fps, 0), video_len - 1))
                    tot_frames = int(end_frame - start_frame)
                    offset = self.get_index(tot_frames, 8)
                    frame_indices = offset + start_frame
                else:
                    # sample frames of the video in evaluation dimension 10
                    frame_indices = self.get_index(video_len - 1, 8)
                vr.seek(0)
                buffer = vr.get_batch(frame_indices)
                buffer = buffer.asnumpy()
            data['imgs'] = buffer
            data = self.video_pipeline(data)

        else:
            raise ValueError('The data type is not valid.')

        return data

    def get_index(self, num_frames, num_segments):
        if num_segments > num_frames:
            offsets = np.array([idx for idx in range(num_frames)])
        else:
            # uniform sampling
            seg_size = float(num_frames - 1) / num_segments
            start = int(seg_size / 2)
            offsets = np.array([
                start + int(np.round(seg_size * idx))
                for idx in range(num_segments)
            ])
        return offsets
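
A worked example of the uniform sampling in get_index, restated as a standalone snippet with illustrative numbers:

    import numpy as np

    num_frames, num_segments = 100, 8
    seg_size = float(num_frames - 1) / num_segments            # 12.375
    start = int(seg_size / 2)                                  # 6
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    # offsets -> [ 6 18 31 43 56 68 80 93]: 8 roughly evenly spaced frame indices
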
opencompass/multimodal/models/minigpt_4/__init__.py
@@ -7,6 +7,7 @@ from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
 from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
                                  MiniGPT4MMBenchPromptConstructor,
                                  MiniGPT4ScienceQAPromptConstructor,
+                                 MiniGPT4SEEDBenchPromptConstructor,
                                  MiniGPT4VQAPromptConstructor,
                                  MiniGPT4VSRPromptConstructor)
 
@@ -16,5 +17,5 @@ __all__ = [
     'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
     'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
     'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
-    'MiniGPT4VSRPromptConstructor'
+    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
 ]
opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -59,10 +59,14 @@ class MiniGPT4Inferencer(MiniGPT4):
                  do_sample: bool = False,
                  max_length: int = 30,
                  img_size: int = 224,
-                 low_resource: bool = False) -> None:
+                 low_resource: bool = False,
+                 mode: str = 'generation',
+                 n_segments: int = 1) -> None:
         super().__init__(llama_model=llama_model,
                          low_resource=low_resource,
                          img_size=img_size)
+        self.mode = mode
+        self.n_segments = n_segments
 
         cur_device = get_device()
         stop_words_ids = [
@@ -71,34 +75,73 @@
         ]
         self.stopping_criteria = StoppingCriteriaList(
             [StoppingCriteriaSub(stops=stop_words_ids)])
 
         self.prompt_constructor = mmengine.registry.build_from_cfg(
             prompt_constructor, MM_MODELS)
-        self.post_processor = mmengine.registry.build_from_cfg(
-            post_processor, MM_MODELS)
+        if post_processor is not None:
+            self.post_processor = mmengine.registry.build_from_cfg(
+                post_processor, MM_MODELS)
         self.do_sample = do_sample
         self.max_length = max_length
 
+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.generate(batch)
+        elif self.mode == 'loss':
+            return self.loss(batch)
+        else:
+            raise RuntimeError(f'Invalid mode "{self.mode}".')
+
     def encode_img(self, image):
         device = image.device
 
         with self.maybe_autocast():
-            image_embeds = self.ln_vision(
-                self.visual_encoder(image)).to(device)
-            image_atts = torch.ones(image_embeds.size()[:-1],
-                                    dtype=torch.long).to(device)
+            if image.dim() == 5:
+                inputs_llama, atts_llama = [], []
+                for j in range(image.size(2)):
+                    this_frame = image[:, :, j, :, :]
+                    frame_embeds = self.ln_vision(
+                        self.visual_encoder(this_frame))
+                    frame_atts = torch.ones(frame_embeds.size()[:-1],
+                                            dtype=torch.long).to(image.device)
 
-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
-                                                    -1)
-            query_output = self.Qformer.bert(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=True,
-            )
+                    query_tokens = self.query_tokens.expand(
+                        frame_embeds.shape[0], -1, -1)
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
 
-            inputs_llama = self.llama_proj(query_output.last_hidden_state)
-            atts_llama = torch.ones(inputs_llama.size()[:-1],
-                                    dtype=torch.long).to(image.device)
+                    frame_inputs_llama = self.llama_proj(
+                        frame_query_output.last_hidden_state[:, :query_tokens.
+                                                             size(1), :])
+                    frame_atts_llama = torch.ones(
+                        frame_inputs_llama.size()[:-1],
+                        dtype=torch.long).to(image.device)
+                    inputs_llama.append(frame_inputs_llama)
+                    atts_llama.append(frame_atts_llama)
+                inputs_llama = torch.cat(inputs_llama, dim=1)
+                atts_llama = torch.cat(atts_llama, dim=1)
+            else:
+                image_embeds = self.ln_vision(
+                    self.visual_encoder(image)).to(device)
+                image_atts = torch.ones(image_embeds.size()[:-1],
+                                        dtype=torch.long).to(device)
 
+                query_tokens = self.query_tokens.expand(
+                    image_embeds.shape[0], -1, -1)
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+                inputs_llama = self.llama_proj(query_output.last_hidden_state)
+                atts_llama = torch.ones(inputs_llama.size()[:-1],
+                                        dtype=torch.long).to(image.device)
         return inputs_llama, atts_llama
 
     def pack_inputs(self, batch):
@@ -153,3 +196,87 @@
             data_sample.pred_answer = output_text
             data_samples[i] = data_sample
         return data_samples
+
+    def loss(self, batch):
+        inputs = self.pack_inputs(batch)
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        batch_size = image.size(0)
+        prompt = inputs['prompt']
+        data_samples = inputs['data_samples']
+        choices = data_samples[0].choices
+
+        with torch.no_grad():
+            img_embeds, atts_img = self.encode_img(image)
+            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
+                                                    prompt)
+
+        self.llama_tokenizer.padding_side = 'right'
+
+        n_cands = len(choices)
+        losses = []
+        for n in range(self.n_segments):
+            seg_len = n_cands // self.n_segments
+            if n == (self.n_segments - 1):
+                seg_len = n_cands - seg_len * (self.n_segments - 1)
+
+            to_regress_tokens = self.llama_tokenizer(
+                choices,
+                return_tensors='pt',
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                add_special_tokens=False).to(image.device)
+
+            targets = to_regress_tokens.input_ids.masked_fill(
+                to_regress_tokens.input_ids ==
+                self.llama_tokenizer.pad_token_id, -100)
+
+            empty_targets = (
+                torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
+                           dtype=torch.long).to(image.device).fill_(
+                               -100)  # plus one for bos
+            )
+            empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
+            targets = torch.cat([empty_targets, targets], dim=1)
+
+            bos = torch.ones([batch_size, 1],
+                             dtype=to_regress_tokens.input_ids.dtype,
+                             device=to_regress_tokens.input_ids.device
+                             ) * self.llama_tokenizer.bos_token_id
+            bos_embeds = self.llama_model.model.embed_tokens(bos)
+            bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
+            img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)
+
+            atts_bos = atts_img[:, :1]
+            atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
+            atts_img = atts_img.repeat_interleave(seg_len, dim=0)
+
+            to_regress_embeds = self.llama_model.model.embed_tokens(
+                to_regress_tokens.input_ids)
+
+            inputs_embeds = torch.cat(
+                [bos_embeds, img_embeds, to_regress_embeds], dim=1)
+            attention_mask = torch.cat(
+                [atts_bos, atts_img, to_regress_tokens.attention_mask],
+                dim=1)
+
+            with self.maybe_autocast():
+                outputs = self.llama_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    labels=targets,
+                    reduction='none',
+                )
+            loss = outputs.loss
+            loss = loss.view(targets.size(0), -1).sum(1)
+            loss = loss.reshape(batch_size, seg_len)
+            losses.append(loss)
+        # losses of 4 choices
+        losses = torch.cat(losses, dim=-1)[0]
+
+        for i, data_sample in enumerate(data_samples):
+            data_sample.losses = losses
+            data_samples[i] = data_sample
+        return data_samples
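
In 'loss' mode each data sample ends up carrying one summed loss per answer choice, and SEEDBenchAcc (added earlier in this commit) turns those into a letter by ranking them, lowest loss first. A toy illustration of that ranking step, with made-up loss values:

    import torch

    losses = torch.tensor([2.31, 1.05, 3.40, 2.97])  # made-up losses for choices A-D
    class_ranks = torch.argsort(losses, dim=-1)
    pred_id = ['A', 'B', 'C', 'D'][class_ranks[0]]
    # pred_id == 'B' (the choice with the lowest loss)
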
opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -118,3 +118,23 @@ class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
         question = questions[0]
         prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt  # noqa
         return prompt
+
+
+class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
+        return prompt
@@ -127,9 +127,9 @@ class MultimodalInferTask:
 
         for batch in track_iter_progress(dataloader):
             if dist.is_initialized():
-                data_samples = model.module.generate(batch)
+                data_samples = model.module.forward(batch)
             else:
-                data_samples = model.generate(batch)
+                data_samples = model.forward(batch)
             if not isinstance(data_samples, Sequence):
                 data_samples = [data_samples]
             evaluator.process(data_samples)