[Feature]: Refactor input and output (#176)

* [Feature]: Refactor input and output

* [Feature]: Update tasks
Yuan Liu 2023-08-10 14:01:28 +08:00 committed by GitHub
parent 876ade71a5
commit a205629ff3
7 changed files with 147 additions and 69 deletions

View File

@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
+
 # dataloader settings
 val_pipeline = [
     dict(type='mmpretrain.torchvision/Resize',
@@ -9,8 +12,8 @@ val_pipeline = [
          std=(0.26862954, 0.26130258, 0.27577711)),
     dict(type='mmpretrain.PackInputs',
          algorithm_keys=[
-             'question', 'category', 'l2-category', 'context',
-             'index', 'options_dict', 'options', 'split'
+             'question', 'category', 'l2-category', 'context', 'index',
+             'options_dict', 'options', 'split'
          ])
 ]
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
 # model settings
 minigpt_4_model = dict(
     type='minigpt-4-mmbench',
-    low_resource=True,
-    llama_model='/path/to/vicuna',
-    sys_prompt=  # noqa: E251
-    '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
-)
+    low_resource=False,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4PostProcessor))

 # evaluation settings
 minigpt_4_evaluator = [
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
         save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
 ]

-minigpt_4_load_from = '/path/to/minigpt-4'  # noqa
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
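Note: the prompt template now lives in the config instead of being hard-coded in the model. A minimal sketch of how the configured pieces compose, mirroring the no-context branch of MiniGPT4MMBenchPromptConstructor._process added later in this commit (the question/options strings are invented sample values):

    # Illustration only; the question/options values are made up.
    image_prompt = '###Human: <Img><ImageHere></Img>'  # from the config above
    reply_prompt = '###Assistant:'                     # from the config above
    question = 'What is the capital of China?'
    options = 'A. Beijing B. Shanghai C. Guangzhou D. Shenzhen'

    # Same concatenation the prompt constructor performs:
    prompt = image_prompt + ' ' + question + ' ' + options + ' ' + reply_prompt
    print(prompt)
    # ###Human: <Img><ImageHere></Img> What is the capital of China? A. Beijing B. Shanghai C. Guangzhou D. Shenzhen ###Assistant: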

View File

@@ -10,6 +10,6 @@ models = [minigpt_4_model]
 datasets = [minigpt_4_dataloader]
 evaluators = [minigpt_4_evaluator]
 load_froms = [minigpt_4_load_from]
-num_gpus = 1
-num_procs = 1
-launcher = 'slurm'
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch'
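Note: moving from launcher = 'slurm' to launcher = 'pytorch' with num_gpus = num_procs = 8 implies the eight workers are spawned torchrun-style and read their rank from the environment. A generic sketch of that bootstrapping, shown only as an assumption about the launch mode (standard torch.distributed usage, not OpenCompass's actual launcher code):

    import os

    import torch.distributed as dist

    def init_from_pytorch_launcher() -> None:  # hypothetical helper
        # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
        # for each of the num_procs worker processes.
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        dist.init_process_group(backend='nccl',
                                rank=rank,
                                world_size=world_size)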

View File

@@ -1,3 +1,8 @@
 from .minigpt_4 import MiniGPT4MMBench
+from .post_processor import MiniGPT4PostProcessor
+from .prompt_constructor import MiniGPT4MMBenchPromptConstructor

-__all__ = ['MiniGPT4MMBench']
+__all__ = [
+    'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
+    'MiniGPT4MMBenchPromptConstructor'
+]

View File

@@ -1,7 +1,7 @@
 import os
-import re
 import sys

+import mmengine
 import torch
 import torch.nn as nn
 from mmengine.device import get_device
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
     Args:
         llama_model (str): The path of vicuna path.
-        sys_prompt (str): The prompt added to the beginning
-            of each query. Defaults to ''.
+        prompt_constructor (dict): The config of prompt constructor.
+        post_processor (dict): The config of post processor.
         low_resource (bool): Whether loaded in low precision.
             Defaults to False.
     """

     def __init__(self,
                  llama_model: str,
-                 sys_prompt: str = '',
+                 prompt_constructor: dict,
+                 post_processor: dict,
                  low_resource: bool = False) -> None:
         super().__init__(llama_model=llama_model, low_resource=low_resource)
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
         ]
         self.stopping_criteria = StoppingCriteriaList(
             [StoppingCriteriaSub(stops=stop_words_ids)])
-        self.sys_prompt = sys_prompt
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        self.post_processor = mmengine.registry.build_from_cfg(
+            post_processor, MM_MODELS)

     def encode_img(self, image):
         device = image.device
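Note: build_from_cfg reads the 'type' key (a registered name or, as in the config above, the class itself) and instantiates it with the remaining keys as keyword arguments. A minimal sketch with a throwaway registry standing in for MM_MODELS (all names below are illustrative):

    from mmengine.registry import Registry, build_from_cfg

    TOY_MODELS = Registry('toy_models')  # stand-in for MM_MODELS

    @TOY_MODELS.register_module()
    class ToyPromptConstructor:

        def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
            self.image_prompt = image_prompt
            self.reply_prompt = reply_prompt

    cfg = dict(type='ToyPromptConstructor',
               image_prompt='###Human: <Img><ImageHere></Img>',
               reply_prompt='###Assistant:')
    obj = build_from_cfg(cfg, TOY_MODELS)  # pops 'type', passes the rest as kwargs
    assert obj.reply_prompt == '###Assistant:'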
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):

     def generate(self, batch):
         inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
         data_samples = inputs['data_samples']
-        samples = {'image': image}
-        question = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        samples.update({'question': question[0]})
-        samples.update({'options': options[0]})
-        if data_samples[0].get('context') is not None:
-            context = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            samples.update({'context': context})
-        data_sample = data_samples[0]
-        img_prompt = '###Human: <Img><ImageHere></Img> '
-        if 'context' in samples:
-            context_prompt = samples['context'][0]
-        question = samples['question']
-        options = samples['options']
-        if 'context' in samples:
-            prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options  # noqa
-        else:
-            prompt = img_prompt + ' ' + question + ' ' + options
-        # prompt = self.sys_prompt + prompt
-        prompt = prompt + '###Assistant:'
-        image = samples['image']
+
+        # The main process of generation
         img_embeds, _ = self.encode_img(image)
         prompt_segs = prompt.split('<ImageHere>')
         prompt_seg_tokens = [
             self.llama_tokenizer(seg,
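Note: once the prompt constructor has produced the full string, generate splits it on the '<ImageHere>' placeholder and splices the image embeddings between the embeddings of the two text segments. A simplified sketch of that interleaving (the embed callable and shapes are placeholders for the tokenizer-plus-embedding step, not the real API):

    import torch

    def interleave(prompt: str, img_embeds: torch.Tensor, embed) -> torch.Tensor:
        # embed: str -> (1, seq_len, hidden) text embeddings (placeholder).
        seg_before, seg_after = prompt.split('<ImageHere>')
        # Concatenate along the sequence dimension: text, image, text.
        return torch.cat([embed(seg_before), img_embeds, embed(seg_after)],
                         dim=1)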
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
             stopping_criteria=self.stopping_criteria,
             num_return_sequences=1)

-        output_token = outputs[0]
-        if output_token[0] == 0:
-            output_token = output_token[1:]
-        if output_token[0] == 1:
-            output_token = output_token[1:]
-        output_text = self.llama_tokenizer.decode(output_token,
-                                                  add_special_tokens=False)
-        output_text = self.post_process(output_text)
-        data_sample.pred_answer = output_text
-        return data_sample
-
-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('</s><s>')
-        output_text = output_text.strip('</Img>')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token,
+                                              self.llama_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples

View File

@@ -0,0 +1,34 @@
+import re
+
+import torch
+
+
+class MiniGPT4PostProcessor:
+    """Post processor for MiniGPT-4 on MMBench."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = self._extract_key_words(output_text)
+        return output_text
+
+    def _extract_key_words(self, output_text: str) -> str:
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        pattern = re.compile(r'([A-Z]\.)')
+        res = pattern.findall(output_text)
+        if len(res) > 0:
+            output_text = res[0][:-1]
+        return output_text
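Note: the extraction strips the dialogue scaffolding and, for the multiple-choice setting, keeps only the option letter. A quick check of that logic on made-up decoded strings (assumes the class is importable as exposed by the new __init__.py above):

    from opencompass.multimodal.models.minigpt_4 import MiniGPT4PostProcessor

    pp = MiniGPT4PostProcessor()
    print(pp._extract_key_words('A. Beijing###Human: ...'))  # -> 'A'
    print(pp._extract_key_words('Assistant: B. Shanghai'))   # -> 'B'
    print(pp._extract_key_words('no option letter'))         # text unchanged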

View File

@@ -0,0 +1,55 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class MiniGPT4MMBenchPromptConstructor:
+    """Prompt constructor for MiniGPT-4 on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt.
+        reply_prompt (str): Reply prompt.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        options = [data_sample.get('options') for data_sample in data_samples]
+        contexts = [data_sample.get('context') for data_sample in data_samples]
+        question = questions[0]
+        option = options[0]
+        context = contexts[0]
+        if context is not None:
+            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        else:
+            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        return prompt
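Note: the constructor only reads fields through .get(...), so it can be exercised with a stand-in sample. A real run passes mmpretrain DataSample objects; the dict-backed stub and field values below are invented for illustration:

    from opencompass.multimodal.models.minigpt_4 import (
        MiniGPT4MMBenchPromptConstructor)

    class StubSample:
        """Dict-backed stand-in for mmpretrain's DataSample."""

        def __init__(self, **fields):
            self._fields = fields

        def get(self, key, default=None):
            return self._fields.get(key, default)

    constructor = MiniGPT4MMBenchPromptConstructor(
        image_prompt='###Human: <Img><ImageHere></Img>',
        reply_prompt='###Assistant:')
    sample = StubSample(question='What is shown in the image?',
                        options='A. a cat B. a dog')
    inputs = {'image': None, 'data_samples': [sample]}
    print(constructor(inputs)['prompt'])
    # ###Human: <Img><ImageHere></Img> What is shown in the image? A. a cat B. a dog ###Assistant: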

View File

@@ -4,7 +4,7 @@ import os
 import os.path as osp
 import random
 import time
-from typing import Sequence
+from typing import List, Sequence

 import torch
 import torch.distributed as dist
@@ -78,6 +78,22 @@ class MultimodalInferTask:
         return osp.join(model_name,
                         f'{dataset_name}-{evaluator_name}.{file_extension}')

+    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
+        """Get the path to the output file.
+
+        Args:
+            file_extension (str): The file extension of the log file.
+                Default: 'json'.
+        """
+        model_name = self.model['type']
+        dataset_name = self.dataloader['dataset']['type']
+        evaluator_name = self.evaluator[0]['type']
+
+        return [
+            osp.join(model_name, dataset_name,
+                     f'{evaluator_name}.{file_extension}')
+        ]
+
     def get_command(self, cfg_path, template):
         """Get the command template for the task.