mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature]: Refactor input and output (#176)
* [Feature]: Refactor input and output * [Feature]: Update tasks
This commit is contained in:
parent
876ade71a5
commit
a205629ff3
@ -1,3 +1,6 @@
|
||||
from opencompass.multimodal.models.minigpt_4 import (
|
||||
MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
|
||||
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
@ -9,8 +12,8 @@ val_pipeline = [
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict', 'options', 'split'
|
||||
'question', 'category', 'l2-category', 'context', 'index',
|
||||
'options_dict', 'options', 'split'
|
||||
])
|
||||
]
|
||||
|
||||
@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
|
||||
# model settings
|
||||
minigpt_4_model = dict(
|
||||
type='minigpt-4-mmbench',
|
||||
low_resource=True,
|
||||
llama_model='/path/to/vicuna',
|
||||
sys_prompt= # noqa: E251
|
||||
'###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
|
||||
)
|
||||
low_resource=False,
|
||||
llama_model='/path/to/vicuna-7b/',
|
||||
prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
|
||||
image_prompt='###Human: <Img><ImageHere></Img>',
|
||||
reply_prompt='###Assistant:'),
|
||||
post_processor=dict(type=MiniGPT4PostProcessor))
|
||||
|
||||
# evaluation settings
|
||||
minigpt_4_evaluator = [
|
||||
@ -39,4 +43,4 @@ minigpt_4_evaluator = [
|
||||
save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
|
||||
]
|
||||
|
||||
minigpt_4_load_from = '/path/to/minigpt-4' # noqa
|
||||
minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa
|
||||
|
@ -10,6 +10,6 @@ models = [minigpt_4_model]
|
||||
datasets = [minigpt_4_dataloader]
|
||||
evaluators = [minigpt_4_evaluator]
|
||||
load_froms = [minigpt_4_load_from]
|
||||
num_gpus = 1
|
||||
num_procs = 1
|
||||
launcher = 'slurm'
|
||||
num_gpus = 8
|
||||
num_procs = 8
|
||||
launcher = 'pytorch'
|
||||
|
@ -1,3 +1,8 @@
|
||||
from .minigpt_4 import MiniGPT4MMBench
|
||||
from .post_processor import MiniGPT4PostProcessor
|
||||
from .prompt_constructor import MiniGPT4MMBenchPromptConstructor
|
||||
|
||||
__all__ = ['MiniGPT4MMBench']
|
||||
__all__ = [
|
||||
'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
|
||||
'MiniGPT4MMBenchPromptConstructor'
|
||||
]
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.device import get_device
|
||||
@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
|
||||
|
||||
Args:
|
||||
llama_model (str): The path of vicuna path.
|
||||
sys_prompt (str): The prompt added to the beginning
|
||||
of each query. Defaults to ''.
|
||||
prompt_constructor (dict): The config of prompt constructor.
|
||||
post_processor (dict): The config of post processor.
|
||||
low_resource (bool): Whether loaded in low precision.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llama_model: str,
|
||||
sys_prompt: str = '',
|
||||
prompt_constructor: dict,
|
||||
post_processor: dict,
|
||||
low_resource: bool = False) -> None:
|
||||
super().__init__(llama_model=llama_model, low_resource=low_resource)
|
||||
|
||||
@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
|
||||
]
|
||||
self.stopping_criteria = StoppingCriteriaList(
|
||||
[StoppingCriteriaSub(stops=stop_words_ids)])
|
||||
self.sys_prompt = sys_prompt
|
||||
self.prompt_constructor = mmengine.registry.build_from_cfg(
|
||||
prompt_constructor, MM_MODELS)
|
||||
self.post_processor = mmengine.registry.build_from_cfg(
|
||||
post_processor, MM_MODELS)
|
||||
|
||||
def encode_img(self, image):
|
||||
device = image.device
|
||||
@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
|
||||
|
||||
def generate(self, batch):
|
||||
inputs = self.pack_inputs(batch)
|
||||
image = inputs.pop('image')
|
||||
inputs = self.prompt_constructor(inputs)
|
||||
image = inputs['image']
|
||||
prompt = inputs['prompt']
|
||||
data_samples = inputs['data_samples']
|
||||
samples = {'image': image}
|
||||
question = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
samples.update({'question': question[0]})
|
||||
samples.update({'options': options[0]})
|
||||
if data_samples[0].get('context') is not None:
|
||||
context = [
|
||||
data_sample.get('context') for data_sample in data_samples
|
||||
]
|
||||
samples.update({'context': context})
|
||||
data_sample = data_samples[0]
|
||||
img_prompt = '###Human: <Img><ImageHere></Img> '
|
||||
if 'context' in samples:
|
||||
context_prompt = samples['context'][0]
|
||||
|
||||
question = samples['question']
|
||||
options = samples['options']
|
||||
if 'context' in samples:
|
||||
prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa
|
||||
else:
|
||||
prompt = img_prompt + ' ' + question + ' ' + options
|
||||
|
||||
# prompt = self.sys_prompt + prompt
|
||||
prompt = prompt + '###Assistant:'
|
||||
|
||||
image = samples['image']
|
||||
# The main process of generation
|
||||
img_embeds, _ = self.encode_img(image)
|
||||
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
prompt_seg_tokens = [
|
||||
self.llama_tokenizer(seg,
|
||||
@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
num_return_sequences=1)
|
||||
|
||||
output_token = outputs[0]
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
if output_token[0] == 1:
|
||||
output_token = output_token[1:]
|
||||
output_text = self.llama_tokenizer.decode(output_token,
|
||||
add_special_tokens=False)
|
||||
output_text = self.post_process(output_text)
|
||||
for i, data_sample in enumerate(data_samples):
|
||||
output_token = outputs[i]
|
||||
output_text = self.post_processor(output_token,
|
||||
self.llama_tokenizer)
|
||||
data_sample.pred_answer = output_text
|
||||
return data_sample
|
||||
|
||||
def post_process(self, output_text):
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
||||
data_samples[i] = data_sample
|
||||
return data_samples
|
||||
|
34
opencompass/multimodal/models/minigpt_4/post_processor.py
Normal file
34
opencompass/multimodal/models/minigpt_4/post_processor.py
Normal file
@ -0,0 +1,34 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MiniGPT4PostProcessor:
|
||||
""""Post processor for MiniGPT-4 on MMBench."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
|
||||
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
if output_token[0] == 1:
|
||||
output_token = output_token[1:]
|
||||
output_text = tokenizer.decode(output_token,
|
||||
add_special_tokens=False) # noqa
|
||||
output_text = self._extract_key_words(output_text)
|
||||
return output_text
|
||||
|
||||
def _extract_key_words(self, output_text: str) -> str:
|
||||
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
@ -0,0 +1,55 @@
|
||||
from typing import List
|
||||
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class MiniGPT4MMBenchPromptConstructor:
|
||||
"""Prompt constructor for MiniGPT-4 on MMBench.
|
||||
|
||||
Args:
|
||||
image_prompt (str): Image prompt.
|
||||
reply_prompt (str): Reply prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
|
||||
self.image_prompt = image_prompt
|
||||
self.reply_prompt = reply_prompt
|
||||
|
||||
def __call__(self, inputs: dict) -> dict:
|
||||
"""Construct prompt.
|
||||
|
||||
Args:
|
||||
inputs (dict): Input data containing image and data_samples.
|
||||
|
||||
Returns:
|
||||
dict: A dict containing prompt, images and data_samples.
|
||||
"""
|
||||
data_samples = inputs['data_samples']
|
||||
prompt = self._process(data_samples)
|
||||
inputs.update({'prompt': prompt})
|
||||
|
||||
return inputs
|
||||
|
||||
def _process(self, data_samples: List[DataSample]) -> str:
|
||||
"""Process data sample to prompt.
|
||||
|
||||
Args:
|
||||
data_samples (List[DataSample]): A list of data_samples.
|
||||
|
||||
Returns:
|
||||
str: Prompt.
|
||||
"""
|
||||
assert len(data_samples) == 1, 'Only support batch size 1.'
|
||||
questions = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
contexts = [data_sample.get('context') for data_sample in data_samples]
|
||||
question = questions[0]
|
||||
option = options[0]
|
||||
context = contexts[0]
|
||||
if context is not None:
|
||||
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
|
||||
else:
|
||||
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
|
||||
return prompt
|
@ -4,7 +4,7 @@ import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -78,6 +78,22 @@ class MultimodalInferTask:
|
||||
return osp.join(model_name,
|
||||
f'{dataset_name}-{evaluator_name}.{file_extension}')
|
||||
|
||||
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
||||
"""Get the path to the output file.
|
||||
|
||||
Args:
|
||||
file_extension (str): The file extension of the log file.
|
||||
Default: 'json'.
|
||||
"""
|
||||
model_name = self.model['type']
|
||||
dataset_name = self.dataloader['dataset']['type']
|
||||
evaluator_name = self.evaluator[0]['type']
|
||||
|
||||
return [
|
||||
osp.join(model_name, dataset_name,
|
||||
f'{evaluator_name}.{file_extension}')
|
||||
]
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
"""Get the command template for the task.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user