From a205629ff3d2e17551d74a8ead023d284fde2f2a Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Thu, 10 Aug 2023 14:01:28 +0800 Subject: [PATCH] [Feature]: Refactor input and output (#176) * [Feature]: Refactor input and output * [Feature]: Update tasks --- .../minigpt_4/minigpt_4_7b_mmbench.py | 20 +++-- configs/multimodal/tasks.py | 6 +- .../multimodal/models/minigpt_4/__init__.py | 7 +- .../multimodal/models/minigpt_4/minigpt_4.py | 76 +++++-------------- .../models/minigpt_4/post_processor.py | 34 +++++++++ .../models/minigpt_4/prompt_constructor.py | 55 ++++++++++++++ opencompass/tasks/mm_infer.py | 18 ++++- 7 files changed, 147 insertions(+), 69 deletions(-) create mode 100644 opencompass/multimodal/models/minigpt_4/post_processor.py create mode 100644 opencompass/multimodal/models/minigpt_4/prompt_constructor.py diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py index 90a6df61..43ecb801 100644 --- a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py +++ b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py @@ -1,3 +1,6 @@ +from opencompass.multimodal.models.minigpt_4 import ( + MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor) + # dataloader settings val_pipeline = [ dict(type='mmpretrain.torchvision/Resize', @@ -9,8 +12,8 @@ val_pipeline = [ std=(0.26862954, 0.26130258, 0.27577711)), dict(type='mmpretrain.PackInputs', algorithm_keys=[ - 'question', 'category', 'l2-category', 'context', - 'index', 'options_dict', 'options', 'split' + 'question', 'category', 'l2-category', 'context', 'index', + 'options_dict', 'options', 'split' ]) ] @@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1, # model settings minigpt_4_model = dict( type='minigpt-4-mmbench', - low_resource=True, - llama_model='/path/to/vicuna', - sys_prompt= # noqa: E251 - '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. 
Shenzhen\n###Assistant: A\n' -) + low_resource=False, + llama_model='/path/to/vicuna-7b/', + prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor, + image_prompt='###Human: ', + reply_prompt='###Assistant:'), + post_processor=dict(type=MiniGPT4PostProcessor)) # evaluation settings minigpt_4_evaluator = [ @@ -39,4 +43,4 @@ minigpt_4_evaluator = [ save_path='work_dirs/minigpt-4-7b-mmbench.xlsx') ] -minigpt_4_load_from = '/path/to/minigpt-4' # noqa +minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py index b8fd75fd..94273b96 100644 --- a/configs/multimodal/tasks.py +++ b/configs/multimodal/tasks.py @@ -10,6 +10,6 @@ models = [minigpt_4_model] datasets = [minigpt_4_dataloader] evaluators = [minigpt_4_evaluator] load_froms = [minigpt_4_load_from] -num_gpus = 1 -num_procs = 1 -launcher = 'slurm' +num_gpus = 8 +num_procs = 8 +launcher = 'pytorch' diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py index 3104855c..6604c669 100644 --- a/opencompass/multimodal/models/minigpt_4/__init__.py +++ b/opencompass/multimodal/models/minigpt_4/__init__.py @@ -1,3 +1,8 @@ from .minigpt_4 import MiniGPT4MMBench +from .post_processor import MiniGPT4PostProcessor +from .prompt_constructor import MiniGPT4MMBenchPromptConstructor -__all__ = ['MiniGPT4MMBench'] +__all__ = [ + 'MiniGPT4MMBench', 'MiniGPT4PostProcessor', + 'MiniGPT4MMBenchPromptConstructor' +] diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py index 306fec58..ee4d4c8c 100644 --- a/opencompass/multimodal/models/minigpt_4/minigpt_4.py +++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py @@ -1,7 +1,7 @@ import os -import re import sys +import mmengine import torch import torch.nn as nn from mmengine.device import get_device @@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4): Args: llama_model (str): The path of vicuna path. - sys_prompt (str): The prompt added to the beginning - of each query. Defaults to ''. + prompt_constructor (dict): The config of prompt constructor. + post_processor (dict): The config of post processor. low_resource (bool): Whether loaded in low precision. Defaults to False. 
""" def __init__(self, llama_model: str, - sys_prompt: str = '', + prompt_constructor: dict, + post_processor: dict, low_resource: bool = False) -> None: super().__init__(llama_model=llama_model, low_resource=low_resource) @@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4): ] self.stopping_criteria = StoppingCriteriaList( [StoppingCriteriaSub(stops=stop_words_ids)]) - self.sys_prompt = sys_prompt + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) def encode_img(self, image): device = image.device @@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4): def generate(self, batch): inputs = self.pack_inputs(batch) - image = inputs.pop('image') + inputs = self.prompt_constructor(inputs) + image = inputs['image'] + prompt = inputs['prompt'] data_samples = inputs['data_samples'] - samples = {'image': image} - question = [ - data_sample.get('question') for data_sample in data_samples - ] - options = [data_sample.get('options') for data_sample in data_samples] - samples.update({'question': question[0]}) - samples.update({'options': options[0]}) - if data_samples[0].get('context') is not None: - context = [ - data_sample.get('context') for data_sample in data_samples - ] - samples.update({'context': context}) - data_sample = data_samples[0] - img_prompt = '###Human: ' - if 'context' in samples: - context_prompt = samples['context'][0] - question = samples['question'] - options = samples['options'] - if 'context' in samples: - prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa - else: - prompt = img_prompt + ' ' + question + ' ' + options - - # prompt = self.sys_prompt + prompt - prompt = prompt + '###Assistant:' - - image = samples['image'] + # The main process of generation img_embeds, _ = self.encode_img(image) - prompt_segs = prompt.split('') prompt_seg_tokens = [ self.llama_tokenizer(seg, @@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4): stopping_criteria=self.stopping_criteria, num_return_sequences=1) - output_token = outputs[0] - if output_token[0] == 0: - output_token = output_token[1:] - if output_token[0] == 1: - output_token = output_token[1:] - output_text = self.llama_tokenizer.decode(output_token, - add_special_tokens=False) - output_text = self.post_process(output_text) - data_sample.pred_answer = output_text - return data_sample - - def post_process(self, output_text): - output_text = output_text.split('###')[0] - output_text = output_text.split('Assistant:')[-1].strip() - output_text = output_text.strip('') - output_text = output_text.strip('') - output_text = output_text.strip() - pattern = re.compile(r'([A-Z]\.)') - res = pattern.findall(output_text) - if len(res) > 0: - output_text = res[0][:-1] - return output_text + for i, data_sample in enumerate(data_samples): + output_token = outputs[i] + output_text = self.post_processor(output_token, + self.llama_tokenizer) + data_sample.pred_answer = output_text + data_samples[i] = data_sample + return data_samples diff --git a/opencompass/multimodal/models/minigpt_4/post_processor.py b/opencompass/multimodal/models/minigpt_4/post_processor.py new file mode 100644 index 00000000..301a3422 --- /dev/null +++ b/opencompass/multimodal/models/minigpt_4/post_processor.py @@ -0,0 +1,34 @@ +import re + +import torch + + +class MiniGPT4PostProcessor: + """"Post processor for MiniGPT-4 on MMBench.""" + + def __init__(self) -> None: + pass + + def __call__(self, output_token: 
torch.tensor, tokenizer) -> str: + + if output_token[0] == 0: + output_token = output_token[1:] + if output_token[0] == 1: + output_token = output_token[1:] + output_text = tokenizer.decode(output_token, + add_special_tokens=False) # noqa + output_text = self._extract_key_words(output_text) + return output_text + + def _extract_key_words(self, output_text: str) -> str: + + output_text = output_text.split('###')[0] + output_text = output_text.split('Assistant:')[-1].strip() + output_text = output_text.strip('') + output_text = output_text.strip('') + output_text = output_text.strip() + pattern = re.compile(r'([A-Z]\.)') + res = pattern.findall(output_text) + if len(res) > 0: + output_text = res[0][:-1] + return output_text diff --git a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py new file mode 100644 index 00000000..de07c1bf --- /dev/null +++ b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py @@ -0,0 +1,55 @@ +from typing import List + +from mmpretrain.structures import DataSample + + +class MiniGPT4MMBenchPromptConstructor: + """Prompt constructor for MiniGPT-4 on MMBench. + + Args: + image_prompt (str): Image prompt. + reply_prompt (str): Reply prompt. + """ + + def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None: + self.image_prompt = image_prompt + self.reply_prompt = reply_prompt + + def __call__(self, inputs: dict) -> dict: + """Construct prompt. + + Args: + inputs (dict): Input data containing image and data_samples. + + Returns: + dict: A dict containing prompt, images and data_samples. + """ + data_samples = inputs['data_samples'] + prompt = self._process(data_samples) + inputs.update({'prompt': prompt}) + + return inputs + + def _process(self, data_samples: List[DataSample]) -> str: + """Process data sample to prompt. + + Args: + data_samples (List[DataSample]): A list of data_samples. + + Returns: + str: Prompt. + """ + assert len(data_samples) == 1, 'Only support batch size 1.' + questions = [ + data_sample.get('question') for data_sample in data_samples + ] + options = [data_sample.get('options') for data_sample in data_samples] + contexts = [data_sample.get('context') for data_sample in data_samples] + question = questions[0] + option = options[0] + context = contexts[0] + if context is not None: + prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa + else: + prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa + return prompt diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py index 5ecc1ab3..46ef4a79 100644 --- a/opencompass/tasks/mm_infer.py +++ b/opencompass/tasks/mm_infer.py @@ -4,7 +4,7 @@ import os import os.path as osp import random import time -from typing import Sequence +from typing import List, Sequence import torch import torch.distributed as dist @@ -78,6 +78,22 @@ class MultimodalInferTask: return osp.join(model_name, f'{dataset_name}-{evaluator_name}.{file_extension}') + def get_output_paths(self, file_extension: str = 'json') -> List[str]: + """Get the path to the output file. + + Args: + file_extension (str): The file extension of the log file. + Default: 'json'. 
+ """ + model_name = self.model['type'] + dataset_name = self.dataloader['dataset']['type'] + evaluator_name = self.evaluator[0]['type'] + + return [ + osp.join(model_name, dataset_name, + f'{evaluator_name}.{file_extension}') + ] + def get_command(self, cfg_path, template): """Get the command template for the task.
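For reference, a minimal sketch of how the two components introduced by this patch compose after the refactor. It is illustrative only: the question/options fields mirror the algorithm_keys packed by the MMBench val_pipeline, the dummy image tensor and the hand-populated DataSample are placeholders, and in the real task both objects are built from their config dicts through the MM_MODELS registry inside MiniGPT4MMBench.__init__ rather than instantiated directly.

    import torch
    from mmpretrain.structures import DataSample

    from opencompass.multimodal.models.minigpt_4 import (
        MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)

    # Instantiate the components directly; the config above builds them
    # through mmengine.registry.build_from_cfg(..., MM_MODELS) instead.
    prompt_constructor = MiniGPT4MMBenchPromptConstructor(
        image_prompt='###Human: ', reply_prompt='###Assistant:')
    post_processor = MiniGPT4PostProcessor()

    # One MMBench-style sample; the prompt constructor asserts batch size 1.
    sample = DataSample()
    sample.question = 'What is shown in the image?'
    sample.options = 'A. a cat\nB. a dog\nC. a bird\nD. a horse'

    inputs = {
        'image': torch.zeros(1, 3, 224, 224),  # placeholder image batch
        'data_samples': [sample],
    }
    inputs = prompt_constructor(inputs)
    print(inputs['prompt'])
    # -> image_prompt + question + options + reply_prompt, roughly
    #    "###Human:  What is shown in the image? A. a cat ... ###Assistant:"

    # After MiniGPT4MMBench.generate() produces output token ids, the post
    # processor decodes them with the LLaMA tokenizer and extracts the option
    # letter; the variables below are hypothetical and shown for illustration:
    # answer = post_processor(output_tokens[i], llama_tokenizer)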