Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Feature]: Refactor input and output (#176)

* [Feature]: Refactor input and output
* [Feature]: Update tasks

parent 876ade71a5
commit a205629ff3
@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
+
 # dataloader settings
 val_pipeline = [
     dict(type='mmpretrain.torchvision/Resize',
@@ -9,8 +12,8 @@ val_pipeline = [
          std=(0.26862954, 0.26130258, 0.27577711)),
     dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
-            'question', 'category', 'l2-category', 'context',
-            'index', 'options_dict', 'options', 'split'
+            'question', 'category', 'l2-category', 'context', 'index',
+            'options_dict', 'options', 'split'
         ])
 ]
 
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
 # model settings
 minigpt_4_model = dict(
     type='minigpt-4-mmbench',
-    low_resource=True,
-    llama_model='/path/to/vicuna',
-    sys_prompt=  # noqa: E251
-    '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
-)
+    low_resource=False,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4PostProcessor))
 
 # evaluation settings
 minigpt_4_evaluator = [
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
                       save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
 ]
 
-minigpt_4_load_from = '/path/to/minigpt-4'  # noqa
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
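
The prompt_constructor and post_processor entries are plain mmengine-style config dicts: the type key names the class and every remaining key becomes a constructor argument. As a minimal sketch (assuming only the classes added in this commit), the model config above builds the two helpers like this:

    # Sketch: what the two config dicts expand to once built.
    from opencompass.multimodal.models.minigpt_4 import (
        MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)

    prompt_constructor = MiniGPT4MMBenchPromptConstructor(
        image_prompt='###Human: <Img><ImageHere></Img>',
        reply_prompt='###Assistant:')
    post_processor = MiniGPT4PostProcessor()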

@@ -10,6 +10,6 @@ models = [minigpt_4_model]
 datasets = [minigpt_4_dataloader]
 evaluators = [minigpt_4_evaluator]
 load_froms = [minigpt_4_load_from]
-num_gpus = 1
-num_procs = 1
-launcher = 'slurm'
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch'

@@ -1,3 +1,8 @@
 from .minigpt_4 import MiniGPT4MMBench
+from .post_processor import MiniGPT4PostProcessor
+from .prompt_constructor import MiniGPT4MMBenchPromptConstructor
 
-__all__ = ['MiniGPT4MMBench']
+__all__ = [
+    'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
+    'MiniGPT4MMBenchPromptConstructor'
+]

@@ -1,7 +1,7 @@
 import os
-import re
 import sys
 
+import mmengine
 import torch
 import torch.nn as nn
 from mmengine.device import get_device
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
 
     Args:
         llama_model (str): The path of the vicuna model.
-        sys_prompt (str): The prompt added to the beginning
-            of each query. Defaults to ''.
+        prompt_constructor (dict): The config of prompt constructor.
+        post_processor (dict): The config of post processor.
         low_resource (bool): Whether loaded in low precision.
             Defaults to False.
     """
 
     def __init__(self,
                  llama_model: str,
-                 sys_prompt: str = '',
+                 prompt_constructor: dict,
+                 post_processor: dict,
                  low_resource: bool = False) -> None:
         super().__init__(llama_model=llama_model, low_resource=low_resource)
 
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
         ]
         self.stopping_criteria = StoppingCriteriaList(
             [StoppingCriteriaSub(stops=stop_words_ids)])
-        self.sys_prompt = sys_prompt
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        self.post_processor = mmengine.registry.build_from_cfg(
+            post_processor, MM_MODELS)
 
     def encode_img(self, image):
         device = image.device
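
For readers unfamiliar with mmengine's registry, build_from_cfg turns a config dict into an object. A minimal sketch of the behavior relied on above (the real mmengine implementation also handles scopes, default_args, and error reporting; in the configs of this commit, type is already a class, so no MM_MODELS string lookup is needed):

    def build_from_cfg_sketch(cfg: dict, registry):
        cfg = dict(cfg)                     # don't mutate the caller's config
        obj_type = cfg.pop('type')
        if isinstance(obj_type, str):       # name -> class via the registry
            obj_type = registry.get(obj_type)
        return obj_type(**cfg)              # remaining keys become kwargs
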
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
 
     def generate(self, batch):
         inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
         data_samples = inputs['data_samples']
-        samples = {'image': image}
-        question = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        samples.update({'question': question[0]})
-        samples.update({'options': options[0]})
-        if data_samples[0].get('context') is not None:
-            context = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            samples.update({'context': context})
-        data_sample = data_samples[0]
-        img_prompt = '###Human: <Img><ImageHere></Img> '
-        if 'context' in samples:
-            context_prompt = samples['context'][0]
-
-        question = samples['question']
-        options = samples['options']
-        if 'context' in samples:
-            prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options  # noqa
-        else:
-            prompt = img_prompt + ' ' + question + ' ' + options
-
-        # prompt = self.sys_prompt + prompt
-        prompt = prompt + '###Assistant:'
-
-        image = samples['image']
+        # The main process of generation
         img_embeds, _ = self.encode_img(image)
 
         prompt_segs = prompt.split('<ImageHere>')
         prompt_seg_tokens = [
             self.llama_tokenizer(seg,
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
             stopping_criteria=self.stopping_criteria,
             num_return_sequences=1)
 
-        output_token = outputs[0]
-        if output_token[0] == 0:
-            output_token = output_token[1:]
-        if output_token[0] == 1:
-            output_token = output_token[1:]
-        output_text = self.llama_tokenizer.decode(output_token,
-                                                  add_special_tokens=False)
-        output_text = self.post_process(output_text)
-        data_sample.pred_answer = output_text
-        return data_sample
-
-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('</s><s>')
-        output_text = output_text.strip('</Img>')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token,
+                                              self.llama_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples
|
34
opencompass/multimodal/models/minigpt_4/post_processor.py
Normal file
34
opencompass/multimodal/models/minigpt_4/post_processor.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MiniGPT4PostProcessor:
|
||||||
|
""""Post processor for MiniGPT-4 on MMBench."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
|
||||||
|
|
||||||
|
if output_token[0] == 0:
|
||||||
|
output_token = output_token[1:]
|
||||||
|
if output_token[0] == 1:
|
||||||
|
output_token = output_token[1:]
|
||||||
|
output_text = tokenizer.decode(output_token,
|
||||||
|
add_special_tokens=False) # noqa
|
||||||
|
output_text = self._extract_key_words(output_text)
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
def _extract_key_words(self, output_text: str) -> str:
|
||||||
|
|
||||||
|
output_text = output_text.split('###')[0]
|
||||||
|
output_text = output_text.split('Assistant:')[-1].strip()
|
||||||
|
output_text = output_text.strip('</s><s>')
|
||||||
|
output_text = output_text.strip('</Img>')
|
||||||
|
output_text = output_text.strip()
|
||||||
|
pattern = re.compile(r'([A-Z]\.)')
|
||||||
|
res = pattern.findall(output_text)
|
||||||
|
if len(res) > 0:
|
||||||
|
output_text = res[0][:-1]
|
||||||
|
return output_text
|
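
A hypothetical usage sketch (the inputs are illustrative): the processor strips a leading padding or BOS token if present, decodes with the model's tokenizer, then reduces the reply to the option letter. The keyword extraction step can be exercised without a tokenizer:

    processor = MiniGPT4PostProcessor()
    # In the model: answer = processor(output_token, self.llama_tokenizer)
    print(processor._extract_key_words('Assistant: A. Beijing###'))  # -> A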

opencompass/multimodal/models/minigpt_4/prompt_constructor.py (new file, 55 lines)

@@ -0,0 +1,55 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class MiniGPT4MMBenchPromptConstructor:
+    """Prompt constructor for MiniGPT-4 on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt.
+        reply_prompt (str): Reply prompt.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        options = [data_sample.get('options') for data_sample in data_samples]
+        contexts = [data_sample.get('context') for data_sample in data_samples]
+        question = questions[0]
+        option = options[0]
+        context = contexts[0]
+        if context is not None:
+            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        else:
+            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        return prompt
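
A hypothetical round trip through the constructor; the field names mirror the algorithm_keys packed by the dataloader config above, and the question and option strings are examples only:

    from mmpretrain.structures import DataSample

    constructor = MiniGPT4MMBenchPromptConstructor(
        image_prompt='###Human: <Img><ImageHere></Img>',
        reply_prompt='###Assistant:')

    sample = DataSample()
    sample.question = 'What is the capital of China?'
    sample.options = 'A. Beijing\nB. Shanghai'

    out = constructor({'data_samples': [sample]})
    print(out['prompt'])
    # ###Human: <Img><ImageHere></Img> What is the capital of China? A. Beijing
    # B. Shanghai ###Assistant: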

@@ -4,7 +4,7 @@ import os
 import os.path as osp
 import random
 import time
-from typing import Sequence
+from typing import List, Sequence
 
 import torch
 import torch.distributed as dist
@@ -78,6 +78,22 @@ class MultimodalInferTask:
         return osp.join(model_name,
                         f'{dataset_name}-{evaluator_name}.{file_extension}')
 
+    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
+        """Get the paths to the output files.
+
+        Args:
+            file_extension (str): The file extension of the log file.
+                Default: 'json'.
+        """
+        model_name = self.model['type']
+        dataset_name = self.dataloader['dataset']['type']
+        evaluator_name = self.evaluator[0]['type']
+
+        return [
+            osp.join(model_name, dataset_name,
+                     f'{evaluator_name}.{file_extension}')
+        ]
+
     def get_command(self, cfg_path, template):
         """Get the command template for the task.
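
For illustration, the existing single-file helper (context lines above) joins model and 'dataset-evaluator' into one path, while the new get_output_paths gives each dataset its own directory; the names below are examples only:

    import os.path as osp

    model_name, dataset_name, evaluator_name = (
        'minigpt-4-mmbench', 'MMBenchDataset', 'DumpResults')
    # old layout: minigpt-4-mmbench/MMBenchDataset-DumpResults.json
    print(osp.join(model_name, f'{dataset_name}-{evaluator_name}.json'))
    # new layout: minigpt-4-mmbench/MMBenchDataset/DumpResults.json
    print(osp.join(model_name, dataset_name, f'{evaluator_name}.json'))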