[Feature]: Refactor input and output (#176)

* [Feature]: Refactor input and output

* [Feature]: Update tasks
Yuan Liu 2023-08-10 14:01:28 +08:00 committed by GitHub
parent 876ade71a5
commit a205629ff3
7 changed files with 147 additions and 69 deletions

View File

@@ -1,3 +1,6 @@
from opencompass.multimodal.models.minigpt_4 import (
MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
@@ -9,8 +12,8 @@ val_pipeline = [
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'category', 'l2-category', 'context',
'index', 'options_dict', 'options', 'split'
'question', 'category', 'l2-category', 'context', 'index',
'options_dict', 'options', 'split'
])
]
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
# model settings
minigpt_4_model = dict(
type='minigpt-4-mmbench',
low_resource=True,
llama_model='/path/to/vicuna',
sys_prompt= # noqa: E251
'###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
)
low_resource=False,
llama_model='/path/to/vicuna-7b/',
prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
image_prompt='###Human: <Img><ImageHere></Img>',
reply_prompt='###Assistant:'),
post_processor=dict(type=MiniGPT4PostProcessor))
# evaluation settings
minigpt_4_evaluator = [
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
]
minigpt_4_load_from = '/path/to/minigpt-4' # noqa
minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa
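For context, prompt_constructor and post_processor are plain config dicts that MiniGPT4MMBench.__init__ now builds through mmengine.registry.build_from_cfg (see the model diff further down). A minimal sketch of that pattern, using a throwaway registry because the actual MM_MODELS import path is not shown in this diff:

import mmengine
from mmengine.registry import Registry
from opencompass.multimodal.models.minigpt_4 import (
    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)

# Stand-in registry for illustration; since `type` is the class itself,
# build_from_cfg instantiates it directly instead of looking up a name.
DEMO_REGISTRY = Registry('demo')
prompt_constructor = mmengine.registry.build_from_cfg(
    dict(type=MiniGPT4MMBenchPromptConstructor,
         image_prompt='###Human: <Img><ImageHere></Img>',
         reply_prompt='###Assistant:'),
    DEMO_REGISTRY)
post_processor = mmengine.registry.build_from_cfg(
    dict(type=MiniGPT4PostProcessor), DEMO_REGISTRY)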

View File

@@ -10,6 +10,6 @@ models = [minigpt_4_model]
datasets = [minigpt_4_dataloader]
evaluators = [minigpt_4_evaluator]
load_froms = [minigpt_4_load_from]
num_gpus = 1
num_procs = 1
launcher = 'slurm'
num_gpus = 8
num_procs = 8
launcher = 'pytorch'
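These module-level lists are consumed by the task layer rather than by the config itself; with launcher='pytorch', each task runs with num_procs processes across num_gpus GPUs. A small sketch of inspecting the aggregated entries, assuming OpenCompass is installed and using a hypothetical config path:

from mmengine.config import Config

cfg = Config.fromfile('configs/multimodal/tasks.py')  # hypothetical path
# One model/dataloader/evaluator/checkpoint entry per position in the lists.
for model, loader, evaluator, ckpt in zip(cfg.models, cfg.datasets,
                                          cfg.evaluators, cfg.load_froms):
    print(model['type'], loader['dataset']['type'], ckpt)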

View File

@@ -1,3 +1,8 @@
from .minigpt_4 import MiniGPT4MMBench
from .post_processor import MiniGPT4PostProcessor
from .prompt_constructor import MiniGPT4MMBenchPromptConstructor
__all__ = ['MiniGPT4MMBench']
__all__ = [
'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
'MiniGPT4MMBenchPromptConstructor'
]

View File

@@ -1,7 +1,7 @@
import os
import re
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
Args:
llama_model (str): The path to the vicuna model.
sys_prompt (str): The prompt added to the beginning
of each query. Defaults to ''.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
low_resource (bool): Whether to load the model in low precision.
Defaults to False.
"""
def __init__(self,
llama_model: str,
sys_prompt: str = '',
prompt_constructor: dict,
post_processor: dict,
low_resource: bool = False) -> None:
super().__init__(llama_model=llama_model, low_resource=low_resource)
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
]
self.stopping_criteria = StoppingCriteriaList(
[StoppingCriteriaSub(stops=stop_words_ids)])
self.sys_prompt = sys_prompt
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
def encode_img(self, image):
device = image.device
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
def generate(self, batch):
inputs = self.pack_inputs(batch)
image = inputs.pop('image')
inputs = self.prompt_constructor(inputs)
image = inputs['image']
prompt = inputs['prompt']
data_samples = inputs['data_samples']
samples = {'image': image}
question = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
samples.update({'question': question[0]})
samples.update({'options': options[0]})
if data_samples[0].get('context') is not None:
context = [
data_sample.get('context') for data_sample in data_samples
]
samples.update({'context': context})
data_sample = data_samples[0]
img_prompt = '###Human: <Img><ImageHere></Img> '
if 'context' in samples:
context_prompt = samples['context'][0]
question = samples['question']
options = samples['options']
if 'context' in samples:
prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa
else:
prompt = img_prompt + ' ' + question + ' ' + options
# prompt = self.sys_prompt + prompt
prompt = prompt + '###Assistant:'
image = samples['image']
# The main process of generation
img_embeds, _ = self.encode_img(image)
prompt_segs = prompt.split('<ImageHere>')
prompt_seg_tokens = [
self.llama_tokenizer(seg,
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
stopping_criteria=self.stopping_criteria,
num_return_sequences=1)
output_token = outputs[0]
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = self.llama_tokenizer.decode(output_token,
add_special_tokens=False)
output_text = self.post_process(output_text)
data_sample.pred_answer = output_text
return data_sample
def post_process(self, output_text):
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_text)
if len(res) > 0:
output_text = res[0][:-1]
return output_text
for i, data_sample in enumerate(data_samples):
output_token = outputs[i]
output_text = self.post_processor(output_token,
self.llama_tokenizer)
data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples

View File

@@ -0,0 +1,34 @@
import re
import torch
class MiniGPT4PostProcessor:
""""Post processor for MiniGPT-4 on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = self._extract_key_words(output_text)
return output_text
def _extract_key_words(self, output_text: str) -> str:
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_text)
if len(res) > 0:
output_text = res[0][:-1]
return output_text
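A small usage sketch of the new post processor; the input strings are made up, and _extract_key_words is called directly so that no real tokenizer is needed:

from opencompass.multimodal.models.minigpt_4 import MiniGPT4PostProcessor

processor = MiniGPT4PostProcessor()
# The chat scaffolding is stripped and only the option letter is kept.
print(processor._extract_key_words('Assistant: A. Beijing###Human:'))  # -> A
# Answers without an 'X.' option pattern are returned essentially as-is.
print(processor._extract_key_words('The answer is Beijing.'))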

View File

@@ -0,0 +1,55 @@
from typing import List
from mmpretrain.structures import DataSample
class MiniGPT4MMBenchPromptConstructor:
"""Prompt constructor for MiniGPT-4 on MMBench.
Args:
image_prompt (str): Image prompt.
reply_prompt (str): Reply prompt.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
self.image_prompt = image_prompt
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
contexts = [data_sample.get('context') for data_sample in data_samples]
question = questions[0]
option = options[0]
context = contexts[0]
if context is not None:
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
else:
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
return prompt
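A usage sketch of the prompt constructor with a single synthetic DataSample; the question and options below are illustrative values only:

from mmpretrain.structures import DataSample
from opencompass.multimodal.models.minigpt_4 import MiniGPT4MMBenchPromptConstructor

constructor = MiniGPT4MMBenchPromptConstructor(
    image_prompt='###Human: <Img><ImageHere></Img>',
    reply_prompt='###Assistant:')

sample = DataSample()
sample.set_field('What is the capital of China?', 'question')
sample.set_field('A. Beijing B. Shanghai C. Guangzhou D. Shenzhen', 'options')

inputs = constructor({'data_samples': [sample]})
print(inputs['prompt'])
# ###Human: <Img><ImageHere></Img> What is the capital of China?
# A. Beijing B. Shanghai C. Guangzhou D. Shenzhen ###Assistant: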

View File

@@ -4,7 +4,7 @@ import os
import os.path as osp
import random
import time
from typing import Sequence
from typing import List, Sequence
import torch
import torch.distributed as dist
@@ -78,6 +78,22 @@ class MultimodalInferTask:
return osp.join(model_name,
f'{dataset_name}-{evaluator_name}.{file_extension}')
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the path to the output file.
Args:
file_extension (str): The file extension of the log file.
Default: 'json'.
"""
model_name = self.model['type']
dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type']
return [
osp.join(model_name, dataset_name,
f'{evaluator_name}.{file_extension}')
]
def get_command(self, cfg_path, template):
"""Get the command template for the task.