From 191a3f6f9d5fe4f0e4b96e3a6be95d50a14efd63 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Thu, 3 Aug 2023 11:07:50 +0800 Subject: [PATCH] [Feature]: Use multimodal (#73) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Feature]: Add minigpt-4 * [Feature]: Add mm local runner * [Feature]: Add instructblip * [Feature]: Delete redundant file * [Feature]: Delete redundant file * [Feature]: Add README to InstructBLIP * [Feature]: Update MiniGPT-4 * [Fix]: Fix lint * [Feature]add omnibenchmark readme (#49) * add omnibenchmark readme * fix * Update OmniMMBench.md * Update OmniMMBench.md * Update OmniMMBench.md * [Fix]: Refine name (#54) * [Feature]: Unify out and err * [Fix]: Fix lint * [Feature]: Rename to mmbench and change weight path * [Feature]: Delete Omni in instructblip * [Feature]: Check the avaliablity of lavis * [Fix]: Fix lint * [Feature]: Refactor MM * [Refactor]: Refactor path * [Feature]: Delete redundant files * [Refactor]: Delete redundant files --------- Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com> --- .gitignore | 1 + configs/multimodal/instructblip/README.md | 9 + .../instructblip/instructblip-mmbench.py | 45 +++ configs/multimodal/minigpt_4/README.md | 10 + .../minigpt_4/minigpt_4_7b_mmbench.py | 42 +++ configs/multimodal/tasks.py | 15 + docs/en/MMBench.md | 1 - opencompass/metrics/__init__.py | 3 + opencompass/metrics/dump_results.py | 53 ++++ opencompass/multimodal/datasets/__init__.py | 3 + opencompass/multimodal/datasets/mmbench.py | 79 ++++++ opencompass/multimodal/models/__init__.py | 5 + .../models/instructblip/__init__.py | 3 + .../instructblip/blip2_vicuna_instruct.py | 260 ++++++++++++++++++ .../multimodal/models/minigpt_4/__init__.py | 3 + .../multimodal/models/minigpt_4/minigpt_4.py | 181 ++++++++++++ .../multimodal/models/minigpt_4/utils.py | 56 ++++ .../openicl/icl_evaluator/icl_hf_evaluator.py | 2 +- opencompass/partitioners/__init__.py | 1 + opencompass/partitioners/mm_naive.py | 119 ++++++++ opencompass/registry.py | 12 + opencompass/runners/slurm.py | 1 - opencompass/tasks/__init__.py | 1 + opencompass/tasks/mm_infer.py | 126 +++++++++ opencompass/utils/__init__.py | 1 + opencompass/utils/dependency.py | 32 +++ run.py | 37 ++- 27 files changed, 1096 insertions(+), 5 deletions(-) create mode 100644 configs/multimodal/instructblip/README.md create mode 100644 configs/multimodal/instructblip/instructblip-mmbench.py create mode 100644 configs/multimodal/minigpt_4/README.md create mode 100644 configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py create mode 100644 configs/multimodal/tasks.py create mode 100644 opencompass/metrics/__init__.py create mode 100644 opencompass/metrics/dump_results.py create mode 100644 opencompass/multimodal/datasets/__init__.py create mode 100644 opencompass/multimodal/datasets/mmbench.py create mode 100644 opencompass/multimodal/models/__init__.py create mode 100644 opencompass/multimodal/models/instructblip/__init__.py create mode 100644 opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py create mode 100644 opencompass/multimodal/models/minigpt_4/__init__.py create mode 100644 opencompass/multimodal/models/minigpt_4/minigpt_4.py create mode 100644 opencompass/multimodal/models/minigpt_4/utils.py create mode 100644 opencompass/partitioners/mm_naive.py create mode 100644 opencompass/tasks/mm_infer.py create mode 100644 opencompass/utils/dependency.py diff --git a/.gitignore b/.gitignore 
index d3a72f8b..23bf2a52 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ configs/datasets/log.json configs/eval_debug*.py configs/viz_*.py data +work_dirs # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/configs/multimodal/instructblip/README.md b/configs/multimodal/instructblip/README.md new file mode 100644 index 00000000..1002328d --- /dev/null +++ b/configs/multimodal/instructblip/README.md @@ -0,0 +1,9 @@ +# InstructBLIP + +### Prepare the environment + +```sh +git clone https://github.com/salesforce/LAVIS.git +cd ./LAVIS +pip install -e . +``` \ No newline at end of file diff --git a/configs/multimodal/instructblip/instructblip-mmbench.py b/configs/multimodal/instructblip/instructblip-mmbench.py new file mode 100644 index 00000000..f4923e89 --- /dev/null +++ b/configs/multimodal/instructblip/instructblip-mmbench.py @@ -0,0 +1,45 @@ +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', + algorithm_keys=[ + 'question', 'answer', 'category', 'l2-category', 'context', + 'index', 'options_dict', 'options', 'split' + ]) +] + +dataset = dict(type='opencompass.MMBench', + data_file='data/mmbench/mmbench_test_20230712.tsv', + pipeline=val_pipeline) + +dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +model = dict( + type='blip2-vicuna-instruct-mmbench', + freeze_vit=True, + low_resource=False, + llm_model='/path/to/vicuna-7b/', + sys_prompt= # noqa: E251 + '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. 
Shenzhen\n###Assistant: A\n' +) + +# evaluation settings +evaluator = [ + dict( + type='opencompass.DumpResults', + save_path= # noqa: E251 + 'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx') +] + +load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa diff --git a/configs/multimodal/minigpt_4/README.md b/configs/multimodal/minigpt_4/README.md new file mode 100644 index 00000000..7d2bf4e5 --- /dev/null +++ b/configs/multimodal/minigpt_4/README.md @@ -0,0 +1,10 @@ +# MiniGPT-4 + +### Prepare the environment + +```sh +cd opencompass/multimodal/models/minigpt_4 +git clone https://github.com/Vision-CAIR/MiniGPT-4.git +``` + +Then prepare the environement according to this [doc](https://github.com/Vision-CAIR/MiniGPT-4) \ No newline at end of file diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py new file mode 100644 index 00000000..913007cc --- /dev/null +++ b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py @@ -0,0 +1,42 @@ +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', + algorithm_keys=[ + 'question', 'answer', 'category', 'l2-category', 'context', + 'index', 'options_dict', 'options', 'split' + ]) +] + +dataset = dict(type='opencompass.MMBenchDataset', + data_file='data/mmbench/mmbench_test_20230712.tsv', + pipeline=val_pipeline) + +minigpt_4_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +minigpt_4_model = dict( + type='minigpt-4-mmbench', + low_resource=True, + llama_model='/path/to/vicuna', + sys_prompt= # noqa: E251 + '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. 
Shenzhen\n###Assistant: A\n' +) + +# evaluation settings +minigpt_4_evaluator = [ + dict(type='opencompass.DumpResults', + save_path='work_dirs/minigpt-4-7b-mmbench.xlsx') +] + +minigpt_4_load_from = '/path/to/minigpt-4' # noqa diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py new file mode 100644 index 00000000..b8fd75fd --- /dev/null +++ b/configs/multimodal/tasks.py @@ -0,0 +1,15 @@ +from mmengine.config import read_base + +with read_base(): + from .minigpt_4.minigpt_4_7b_mmbench import (minigpt_4_dataloader, + minigpt_4_evaluator, + minigpt_4_load_from, + minigpt_4_model) + +models = [minigpt_4_model] +datasets = [minigpt_4_dataloader] +evaluators = [minigpt_4_evaluator] +load_froms = [minigpt_4_load_from] +num_gpus = 1 +num_procs = 1 +launcher = 'slurm' diff --git a/docs/en/MMBench.md b/docs/en/MMBench.md index 02db8b95..103854c6 100644 --- a/docs/en/MMBench.md +++ b/docs/en/MMBench.md @@ -76,7 +76,6 @@ class MMBenchDataset(Dataset): 'context': hint, } return data - def load_from_df(self, idx, key): if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]): return self.df.iloc[idx][key] diff --git a/opencompass/metrics/__init__.py b/opencompass/metrics/__init__.py new file mode 100644 index 00000000..09333edf --- /dev/null +++ b/opencompass/metrics/__init__.py @@ -0,0 +1,3 @@ +from .dump_results import DumpResults + +__all__ = ['DumpResults'] diff --git a/opencompass/metrics/dump_results.py b/opencompass/metrics/dump_results.py new file mode 100644 index 00000000..0b330729 --- /dev/null +++ b/opencompass/metrics/dump_results.py @@ -0,0 +1,53 @@ +import os +from typing import Optional + +import pandas as pd +from mmengine.evaluator import BaseMetric + +from opencompass.registry import METRICS + + +@METRICS.register_module() +class DumpResults(BaseMetric): + """Dump model's prediction to a file. + + Args: + save_path (str): the path to save model's prediction. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Default: None. 
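+
+    Note:
+        The collected predictions are written to ``save_path`` as a single
+        Excel sheet when ``compute_metrics`` is called.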
+ """ + + def __init__(self, + save_path: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + self.save_path = save_path + if not os.path.exists(os.path.dirname(self.save_path)): + os.makedirs(os.path.dirname(self.save_path), exist_ok=True) + + def process(self, data_batch, data_samples) -> None: + for data_sample in data_samples: + result = dict() + + result['question'] = data_sample.get('question') + result.update(data_sample.get('options_dict')) + result['prediction'] = data_sample.get('pred_answer') + if data_sample.get('category') is not None: + result['category'] = data_sample.get('category') + if data_sample.get('l2-category') is not None: + result['l2-category'] = data_sample.get('l2-category') + result['index'] = data_sample.get('index') + result['split'] = data_sample.get('split') + self.results.append(result) + + def compute_metrics(self, results: list) -> dict: + df = pd.DataFrame(results) + with pd.ExcelWriter(self.save_path, engine='openpyxl') as writer: + df.to_excel(writer, index=False) + return {} diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py new file mode 100644 index 00000000..47fa24ad --- /dev/null +++ b/opencompass/multimodal/datasets/__init__.py @@ -0,0 +1,3 @@ +from .mmbench import MMBenchDataset + +__all__ = ['MMBenchDataset'] diff --git a/opencompass/multimodal/datasets/mmbench.py b/opencompass/multimodal/datasets/mmbench.py new file mode 100644 index 00000000..a1ba4c58 --- /dev/null +++ b/opencompass/multimodal/datasets/mmbench.py @@ -0,0 +1,79 @@ +import base64 +import io +from typing import List, Optional + +import pandas as pd +from mmengine.dataset import Compose +from PIL import Image +from torch.utils.data import Dataset + +from opencompass.registry import DATASETS + + +def decode_base64_to_image(base64_string) -> Image: + """Convert raw data into Pillow image.""" + image_data = base64.b64decode(base64_string) + image = Image.open(io.BytesIO(image_data)) + return image + + +@DATASETS.register_module() +class MMBenchDataset(Dataset): + """Dataset to load MMBench dataset. + + Args: + data_file (str): The path of the dataset. + pipeline (dict): The data augmentation. + sys_prompt (str): The system prompt added to the head + of these options. Defaults to + There are several options: + """ + + def __init__(self, + data_file: str, + pipeline: List[dict], + sys_prompt: str = 'There are several options:') -> None: + self.df = pd.read_csv(data_file, sep='\t') + self.pipeline = Compose(pipeline) + self.sys_prompt = sys_prompt + + def __len__(self) -> None: + return len(self.df) + + def __getitem__(self, idx: str) -> dict: + index = self.df.iloc[idx]['index'] + image = self.df.iloc[idx]['image'] + image = decode_base64_to_image(image) + question = self.df.iloc[idx]['question'] + catetory = self.df.iloc[idx]['category'] + l2_catetory = self.df.iloc[idx]['l2-category'] + + option_candidate = ['A', 'B', 'C', 'D', 'E'] + options = { + cand: self.load_from_df(idx, cand) + for cand in option_candidate + if self.load_from_df(idx, cand) is not None + } + options_prompt = f'{self.sys_prompt}\n' + for key, item in options.items(): + options_prompt += f'{key}. 
{item}\n' + + hint = self.load_from_df(idx, 'hint') + data = { + 'img': image, + 'question': question, + 'options': options_prompt, + 'category': catetory, + 'l2-category': l2_catetory, + 'options_dict': options, + 'index': index, + 'context': hint, + } + data = self.pipeline(data) + return data + + def load_from_df(self, idx: int, key: str) -> Optional[str]: + if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]): + return self.df.iloc[idx][key] + else: + return None diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py new file mode 100644 index 00000000..3747a125 --- /dev/null +++ b/opencompass/multimodal/models/__init__.py @@ -0,0 +1,5 @@ +from opencompass.utils import satisfy_requirement + +if satisfy_requirement('salesforce-lavis'): + from .instructblip import * # noqa: F401, F403 +from .minigpt_4 import * # noqa: F401, F403 diff --git a/opencompass/multimodal/models/instructblip/__init__.py b/opencompass/multimodal/models/instructblip/__init__.py new file mode 100644 index 00000000..1aa1c98b --- /dev/null +++ b/opencompass/multimodal/models/instructblip/__init__.py @@ -0,0 +1,3 @@ +from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench + +__all__ = ['Blip2VicunaInstructMMBench'] diff --git a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py new file mode 100644 index 00000000..6595df10 --- /dev/null +++ b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py @@ -0,0 +1,260 @@ +"""Requires Transformer 4.28 and above, implementation may change according the +Llama implementation.""" +import logging +import re + +import torch +import torch.nn as nn +from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train +from mmengine.device import get_device +from transformers import LlamaForCausalLM, LlamaTokenizer + +from opencompass.registry import MM_MODELS + + +@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench') +class Blip2VicunaInstructMMBench(Blip2Base): + + def __init__( + self, + vit_model='eva_clip_g', + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision='fp16', + freeze_vit=True, + num_query_token=32, + llm_model='', + sys_prompt='', + prompt='', + max_txt_len=128, + max_output_txt_len=256, + qformer_text_input=True, + low_resource=False, + ): + super().__init__() + self.tokenizer = self.init_tokenizer(truncation_side='left') + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, + vit_precision) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + logging.info('freeze vision encoder') + + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features) + + if not qformer_text_input: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + else: + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + self.Qformer.cls = None + + self.llm_tokenizer = LlamaTokenizer.from_pretrained( + llm_model, use_fast=False, truncation_side='left') + + if low_resource: + self.llm_model = LlamaForCausalLM.from_pretrained( + llm_model, + torch_dtype=torch.float16, + 
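+                # Low-resource mode: quantize the LLM to 8-bit and pin it to
+                # a single GPU to reduce memory usage.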
load_in_8bit=True, + device_map={'': 0}) + else: + self.llm_model = LlamaForCausalLM.from_pretrained( + llm_model, torch_dtype=torch.float16) + self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + self.llm_tokenizer.add_special_tokens({'bos_token': ''}) + self.llm_tokenizer.add_special_tokens({'eos_token': ''}) + self.llm_tokenizer.add_special_tokens({'unk_token': ''}) + + self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) + + for name, param in self.llm_model.named_parameters(): + param.requires_grad = False + + self.llm_proj = nn.Linear(self.Qformer.config.hidden_size, + self.llm_model.config.hidden_size) + + self.max_txt_len = max_txt_len + self.max_output_txt_len = max_output_txt_len + self.sys_prompt = sys_prompt + self.prompt = prompt + + self._lemmatizer = None + + self.qformer_text_input = qformer_text_input + + def concat_text_input_output(self, input_ids, input_atts, output_ids, + output_atts): + input_part_targets_len = [] + llm_tokens = {'input_ids': [], 'attention_mask': []} + for i in range(input_ids.size(0)): + this_input_ones = input_atts[i].sum() + input_part_targets_len.append(this_input_ones) + llm_tokens['input_ids'].append( + torch.cat([ + input_ids[i][:this_input_ones], output_ids[i][1:], + input_ids[i][this_input_ones:] + ])) + llm_tokens['attention_mask'].append( + torch.cat([ + input_atts[i][:this_input_ones], output_atts[i][1:], + input_atts[i][this_input_ones:] + ])) + llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids']) + llm_tokens['attention_mask'] = torch.stack( + llm_tokens['attention_mask']) + return llm_tokens, input_part_targets_len + + def pack_inputs(self, batch): + images = [image.unsqueeze(0) for image in batch['inputs']] + data_samples = [data_sample for data_sample in batch['data_samples']] + images = torch.cat(images, dim=0).to(get_device()) + inputs = {'image': images, 'data_samples': data_samples} + return inputs + + @torch.no_grad() + def generate( + self, + batch, + use_nucleus_sampling=False, + num_beams=5, + max_length=256, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1, + num_captions=1, + temperature=1, + ): + inputs = self.pack_inputs(batch) + image = inputs.pop('image') + data_samples = inputs['data_samples'] + samples = {'image': image} + questions = [ + data_sample.get('question') for data_sample in data_samples + ] + options = [data_sample.get('options') for data_sample in data_samples] + if data_samples[0].get('context') is not None: + contexts = [ + data_sample.get('context') for data_sample in data_samples + ] + prompt = [ + context + ' ' + question + ' ' + option for context, question, + option in zip(contexts, questions, options) + ] + else: + prompt = [ + question + ' ' + option + for question, option in zip(questions, options) + ] + + self.llm_tokenizer.padding_side = 'left' + + image = samples['image'] + + bs = image.size(0) + + if isinstance(prompt, str): + prompt = [prompt] * bs + else: + assert len( + prompt + ) == bs, 'The number of prompts must be equal to the batch size.' 
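+
+        # InstructBLIP pipeline: the learned query tokens (conditioned on the
+        # tokenized prompt when `qformer_text_input` is set) cross-attend to
+        # the frozen ViT features inside the Q-Former; the resulting query
+        # embeddings are projected into the LLM embedding space and prepended
+        # to the prompt embeddings before generation.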
+ + query_tokens = self.query_tokens.expand(bs, -1, -1) + if self.qformer_text_input: + text_Qformer = self.tokenizer( + prompt, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], + dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], + dim=1) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(image.device) + + if self.qformer_text_input: + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llm = self.llm_proj( + query_output.last_hidden_state[:, :query_tokens.size(1), :]) + atts_llm = torch.ones(inputs_llm.size()[:-1], + dtype=torch.long).to(image.device) + + prompt = ['###Human: ' + p + '###Assistant:' for p in prompt] + prompt = [self.sys_prompt + p for p in prompt] + llm_tokens = self.llm_tokenizer(prompt, + padding='longest', + return_tensors='pt').to(image.device) + + with self.maybe_autocast(): + inputs_embeds = self.llm_model.get_input_embeddings()( + llm_tokens.input_ids) + inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1) + attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], + dim=1) + + outputs = self.llm_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) + output_text = self.llm_tokenizer.batch_decode(outputs, + skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + output_text = self.post_process(output_text[0]) + data_sample = data_samples[0] + data_sample.pred_answer = output_text + + return data_sample + + def post_process(self, output_text): + output_text = output_text.split('###')[0] + output_text = output_text.split('Assistant:')[-1].strip() + output_text = output_text.strip('') + output_text = output_text.strip('') + output_text = output_text.strip() + pattern = re.compile(r'([A-Z]\.)') + res = pattern.findall(output_text) + if len(res) > 0: + output_text = res[0][:-1] + return output_text diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py new file mode 100644 index 00000000..3104855c --- /dev/null +++ b/opencompass/multimodal/models/minigpt_4/__init__.py @@ -0,0 +1,3 @@ +from .minigpt_4 import MiniGPT4MMBench + +__all__ = ['MiniGPT4MMBench'] diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py new file mode 100644 index 00000000..306fec58 --- /dev/null +++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py @@ -0,0 +1,181 @@ +import os +import re +import sys + +import torch +import torch.nn as nn +from mmengine.device import get_device +from transformers import StoppingCriteriaList + +from 
opencompass.registry import MM_MODELS + +from .utils import StoppingCriteriaSub + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def load_package(): + """Load required packages from MiniGPT-4.""" + current_file_path = os.path.abspath(__file__) + current_folder_path = os.path.dirname(current_file_path) + + sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4')) # noqa + from minigpt4.models.mini_gpt4 import MiniGPT4 + + sys.path.pop(-1) + + return MiniGPT4 + + +MiniGPT4 = load_package() + + +@MM_MODELS.register_module('minigpt-4-mmbench') +class MiniGPT4MMBench(MiniGPT4): + """Inference code of MiniGPT-4 on MMBench. + + Args: + llama_model (str): The path of vicuna path. + sys_prompt (str): The prompt added to the beginning + of each query. Defaults to ''. + low_resource (bool): Whether loaded in low precision. + Defaults to False. + """ + + def __init__(self, + llama_model: str, + sys_prompt: str = '', + low_resource: bool = False) -> None: + super().__init__(llama_model=llama_model, low_resource=low_resource) + + cur_device = get_device() + stop_words_ids = [ + torch.tensor([835]).to(cur_device), + torch.tensor([2277, 29937]).to(cur_device), + ] + self.stopping_criteria = StoppingCriteriaList( + [StoppingCriteriaSub(stops=stop_words_ids)]) + self.sys_prompt = sys_prompt + + def encode_img(self, image): + device = image.device + + with self.maybe_autocast(): + image_embeds = self.ln_vision( + self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, + -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + atts_llama = torch.ones(inputs_llama.size()[:-1], + dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + def pack_inputs(self, batch): + images = [image.unsqueeze(0) for image in batch['inputs']] + data_samples = [data_sample for data_sample in batch['data_samples']] + images = torch.cat(images, dim=0).to(get_device()) + inputs = {'image': images, 'data_samples': data_samples} + return inputs + + def generate(self, batch): + inputs = self.pack_inputs(batch) + image = inputs.pop('image') + data_samples = inputs['data_samples'] + samples = {'image': image} + question = [ + data_sample.get('question') for data_sample in data_samples + ] + options = [data_sample.get('options') for data_sample in data_samples] + samples.update({'question': question[0]}) + samples.update({'options': options[0]}) + if data_samples[0].get('context') is not None: + context = [ + data_sample.get('context') for data_sample in data_samples + ] + samples.update({'context': context}) + data_sample = data_samples[0] + img_prompt = '###Human: ' + if 'context' in samples: + context_prompt = samples['context'][0] + + question = samples['question'] + options = samples['options'] + if 'context' in samples: + prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa + else: + prompt = img_prompt + ' ' + question + ' ' + options + + # prompt = self.sys_prompt + prompt + prompt = prompt + '###Assistant:' + + image = samples['image'] + img_embeds, _ = self.encode_img(image) + + 
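+        # Interleave text and image: split the prompt at the image
+        # placeholder, embed each text segment with the LLaMA embedding
+        # table, and splice the Q-Former image embeddings in between.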
prompt_segs = prompt.split('') + prompt_seg_tokens = [ + self.llama_tokenizer(seg, + return_tensors='pt', + add_special_tokens=i == 0). + to(self.llama_model.model.embed_tokens.weight.device).input_ids + for i, seg in enumerate(prompt_segs) + ] + prompt_seg_embs = [ + self.llama_model.model.embed_tokens(seg) + for seg in prompt_seg_tokens + ] + prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]] + prompt_embs = torch.cat(prompt_seg_embs, dim=1) + + # generate output + outputs = self.llama_model.generate( + inputs_embeds=prompt_embs, + max_new_tokens=20, + num_beams=5, + do_sample=False, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=-1.0, + temperature=1.0, + stopping_criteria=self.stopping_criteria, + num_return_sequences=1) + + output_token = outputs[0] + if output_token[0] == 0: + output_token = output_token[1:] + if output_token[0] == 1: + output_token = output_token[1:] + output_text = self.llama_tokenizer.decode(output_token, + add_special_tokens=False) + output_text = self.post_process(output_text) + data_sample.pred_answer = output_text + return data_sample + + def post_process(self, output_text): + output_text = output_text.split('###')[0] + output_text = output_text.split('Assistant:')[-1].strip() + output_text = output_text.strip('') + output_text = output_text.strip('') + output_text = output_text.strip() + pattern = re.compile(r'([A-Z]\.)') + res = pattern.findall(output_text) + if len(res) > 0: + output_text = res[0][:-1] + return output_text diff --git a/opencompass/multimodal/models/minigpt_4/utils.py b/opencompass/multimodal/models/minigpt_4/utils.py new file mode 100644 index 00000000..777c1939 --- /dev/null +++ b/opencompass/multimodal/models/minigpt_4/utils.py @@ -0,0 +1,56 @@ +import os +import re + +import timm.models.hub as timm_hub +import torch +import torch.distributed as dist +from mmengine.dist import is_distributed, is_main_process +from transformers import StoppingCriteria + + +class StoppingCriteriaSub(StoppingCriteria): + + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +def download_cached_file(url, check_hash=True, progress=False): + """Download a file from a URL and cache it locally. + + If the file already exists, it is not downloaded again. If distributed, + only the main process downloads the file, and the other processes wait for + the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_distributed(): + dist.barrier() + + return get_cached_file_path() + + +def is_url(input_url): + """Check if an input string is a url. 
+ + look for http(s):// and ignoring the case + """ + is_url = re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None + return is_url diff --git a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py index 89fb1d17..3e2e4bcc 100644 --- a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py @@ -120,7 +120,7 @@ class AccEvaluator(HuggingfaceEvaluator): @ICL_EVALUATORS.register_module() class RougeEvaluator(HuggingfaceEvaluator): - """Rouge evaluator.""" + """Rouge evaluator.""" # noqa def __init__(self) -> None: super().__init__(metric='rouge') diff --git a/opencompass/partitioners/__init__.py b/opencompass/partitioners/__init__.py index 836081fb..ee9fe108 100644 --- a/opencompass/partitioners/__init__.py +++ b/opencompass/partitioners/__init__.py @@ -1,2 +1,3 @@ +from .mm_naive import * # noqa: F401, F403 from .naive import * # noqa: F401, F403 from .size import * # noqa: F401, F403 diff --git a/opencompass/partitioners/mm_naive.py b/opencompass/partitioners/mm_naive.py new file mode 100644 index 00000000..817b1276 --- /dev/null +++ b/opencompass/partitioners/mm_naive.py @@ -0,0 +1,119 @@ +from copy import deepcopy +from typing import Dict, List + +from mmengine.config import Config, ConfigDict + +from opencompass.registry import PARTITIONERS + +from .base import BasePartitioner + + +@PARTITIONERS.register_module() +class MultimodalNaivePartitioner(BasePartitioner): + """Multimodal naive task partitioner. + + This partitioner will generate a task for each + model-dataset-evaluator pair. + + Args: + config (ConfigDict): The full config dict. + """ + + def partition(self, models: List[ConfigDict], datasets: List[ConfigDict], + evaluators: List[ConfigDict], load_froms: List[ConfigDict], + work_dir: str, num_gpus: int, num_procs: int, + launcher: str) -> List[Dict]: + """Partition model-dataset pairs into tasks. Each task is defined as a + dict and will run independently as a unit. Its structure is as follows: + + .. code-block:: python + + { + 'models': [], # a list of model configs + 'datasets': [], # a list of dataset configs + 'evaluators': [], # a list of evaluator configs + 'load_froms': [], # a list of load_from paths + 'work_dir': '', # the work dir + 'num_gpus': int, # integer, number of gpus for each task + 'num_procs': int, # integer, number of gpus on single machine + 'launcher': str, # string, how to launch distributed training + } + + Args: + models (List[ConfigDict]): A list of model configs. + datasets (List[ConfigDict]): A list of dataset configs. + evaluators (List[ConfigDict]): A list of evaluator configs. + load_froms (List[ConfigDict]): A list of load_from paths. + work_dir (str): The work dir for the task. + num_gpus (int): Number of gpus for each task. + num_procs (int): Number of gpus on single machine. + launcher (str): How to launch distributed training. + Only `slurm`, `pytorch` and `mpi` are available. + + Returns: + List[Dict]: A list of tasks. + """ + + tasks = [] + for model, dataset, evaluator, load_from in zip( + models, datasets, evaluators, load_froms): + task = Config({ + 'model': model, + 'dataset': dataset, + 'evaluator': evaluator, + 'load_from': load_from, + 'work_dir': work_dir, + 'num_gpus': num_gpus, + 'num_procs': num_procs, + 'launcher': launcher + }) + tasks.append(task) + + return tasks + + def __call__(self, cfg: ConfigDict) -> List[Dict]: + """Generate tasks from config. 
Each task is defined as a + dict and will run independently as a unit. Its structure is as + follows: + + .. code-block:: python + + { + 'models': [], # a list of model configs + 'datasets': [], # a list of dataset configs + 'evaluators': [], # a list of evaluator configs + 'load_froms': [], # a list of load_from paths + 'work_dir': '', # the work dir + 'num_gpus': int, # integer, number of gpus for each task + 'num_procs': int, # integer, number of gpus on single machine + } + + Args: + cfg (ConfigDict): The config dict, containing "models", "dataset" + and "work_dir" keys. + + Returns: + List[Dict]: A list of tasks. + """ + cfg = deepcopy(cfg) + models = cfg['models'] + datasets = cfg['datasets'] + evaluators = cfg['evaluators'] + load_froms = cfg['load_froms'] + work_dir = cfg['work_dir'] + num_gpus = cfg['num_gpus'] + num_procs = cfg['num_procs'] + launcher = cfg['launcher'] + + tasks = self.partition(models, datasets, evaluators, load_froms, + work_dir, num_gpus, num_procs, launcher) + + self.logger.info(f'Partitioned into {len(tasks)} tasks.') + for i, task in enumerate(tasks): + model_name = task['model']['type'] + dataset_name = task['dataset']['dataset']['type'] + evaluator_name = task['evaluator'][0]['type'] + self.logger.debug( + f'Task {i}: {model_name}-{dataset_name}-{evaluator_name}') + + return tasks diff --git a/opencompass/registry.py b/opencompass/registry.py index 8f26e607..a48ee51c 100644 --- a/opencompass/registry.py +++ b/opencompass/registry.py @@ -1,3 +1,6 @@ +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODELS as MMENGINE_MODELS from mmengine.registry import Registry PARTITIONERS = Registry('partitioner', locations=['opencompass.partitioners']) @@ -22,3 +25,12 @@ ICL_PROMPT_TEMPLATES = Registry( locations=['opencompass.openicl.icl_prompt_template']) ICL_EVALUATORS = Registry('icl_evaluators', locations=['opencompass.openicl.icl_evaluator']) +DATASETS = Registry('mm_datasets', + parent=MMENGINE_DATASETS, + locations=['opencompass.multimodal.datasets']) +METRICS = Registry('metric', + parent=MMENGINE_METRICS, + locations=['opencompass.metrics']) +MM_MODELS = Registry('mm_model', + parent=MMENGINE_MODELS, + locations=['opencompass.multimodal.models']) diff --git a/opencompass/runners/slurm.py b/opencompass/runners/slurm.py index ddb7808d..c6efb60c 100644 --- a/opencompass/runners/slurm.py +++ b/opencompass/runners/slurm.py @@ -81,7 +81,6 @@ class SlurmRunner(BaseRunner): Returns: tuple[str, int]: Task name and exit code. 
""" - task_type = self.task_cfg.type if isinstance(self.task_cfg.type, str): task_type = TASKS.get(task_type) diff --git a/opencompass/tasks/__init__.py b/opencompass/tasks/__init__.py index 308ea5d6..ac63c77d 100644 --- a/opencompass/tasks/__init__.py +++ b/opencompass/tasks/__init__.py @@ -1,2 +1,3 @@ +from .mm_infer import * # noqa: F401, F403 from .openicl_eval import * # noqa: F401, F403 from .openicl_infer import * # noqa: F401, F403 diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py new file mode 100644 index 00000000..2d52c230 --- /dev/null +++ b/opencompass/tasks/mm_infer.py @@ -0,0 +1,126 @@ +import argparse +import json +import os +import os.path as osp +import random +import time +from typing import Sequence + +import torch +import torch.distributed as dist +from mmengine.config import Config, ConfigDict +from mmengine.device import get_device +from mmengine.dist import init_dist +from mmengine.evaluator import Evaluator +from mmengine.logging import print_log +from mmengine.model.wrappers import MMDistributedDataParallel +from mmengine.runner import Runner +from mmengine.utils import track_iter_progress + +from opencompass.registry import MM_MODELS, TASKS +from opencompass.utils import get_logger + + +def build_model(cfg): + model = MM_MODELS.build(cfg['model']) + load_from = cfg.get('load_from', None) + if load_from is not None: + state_dict = torch.load(cfg['load_from'], map_location='cpu') + if 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + msg = model.load_state_dict(state_dict, strict=False) + print_log(msg) + model.to(get_device()) + if dist.is_initialized(): + model = MMDistributedDataParallel( + model, + device_ids=[int(os.environ['LOCAL_RANK'])], + broadcast_buffers=False) + return model + + +@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run +class MultimodalInferTask: + """Multimodal Inference Task. + + This task is used to run the inference process. + """ + + def __init__(self, cfg: ConfigDict): + self.num_gpus = cfg.get('num_gpus', 0) + self.num_procs = cfg.get('num_procs', 1) + self.dataloader = cfg.get('dataset') + self.model = cfg.get('model') + self.evaluator = cfg.get('evaluator') + self.cfg = cfg + self.logger = get_logger() + + @property + def name(self) -> str: + model_name = self.model['type'] + dataset_name = self.dataloader['dataset']['type'] + evaluator_name = self.evaluator[0]['type'] + return f'{model_name}-{dataset_name}-{evaluator_name}' + + def get_command(self, cfg_path, template): + """Get the command template for the task. + + Args: + cfg_path (str): The path to the config file of the task. + template (str): The template which have '{task_cmd}' to format + the command. 
+ """ + script_path = __file__ + if self.num_gpus > 0: + port = random.randint(12000, 32000) + command = (f'torchrun --master_port={port} ' + f'--nproc_per_node {self.num_procs} ' + f'{script_path} {cfg_path}') + else: + command = f'python {script_path} {cfg_path}' + + return template.format(task_cmd=command) + + def run(self): + # only support slurm, pytorch, mpi + init_dist(self.cfg.launcher) + self.logger.info(f'Task {self.name}') + # build dataloader + dataloader = Runner.build_dataloader(self.dataloader) + # build model + model = build_model(self.cfg) + # build evaluator + evaluator = Evaluator(self.evaluator) + + for batch in track_iter_progress(dataloader): + if dist.is_initialized(): + data_samples = model.module.generate(batch) + else: + data_samples = model.generate(batch) + if not isinstance(data_samples, Sequence): + data_samples = [data_samples] + evaluator.process(data_samples) + + metrics = evaluator.evaluate(len(dataloader.dataset)) + metrics_file = osp.join(cfg.work_dir, 'res.log') + with open(metrics_file, 'w') as f: + json.dump(metrics, f) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Model Inferencer') + parser.add_argument('config', help='Config file path') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + cfg = Config.fromfile(args.config) + start_time = time.time() + inferencer = MultimodalInferTask(cfg) + inferencer.run() + end_time = time.time() + get_logger().info(f'time elapsed: {end_time - start_time:.2f}s') diff --git a/opencompass/utils/__init__.py b/opencompass/utils/__init__.py index d9fdeb4c..d899d161 100644 --- a/opencompass/utils/__init__.py +++ b/opencompass/utils/__init__.py @@ -1,6 +1,7 @@ from .abbr import * # noqa from .build import * # noqa from .collect_env import * # noqa +from .dependency import * # noqa from .fileio import * # noqa from .git import * # noqa from .lark import * # noqa diff --git a/opencompass/utils/dependency.py b/opencompass/utils/dependency.py new file mode 100644 index 00000000..821735f7 --- /dev/null +++ b/opencompass/utils/dependency.py @@ -0,0 +1,32 @@ +import re + +from importlib_metadata import PackageNotFoundError, distribution +from mmengine.utils import digit_version + + +def satisfy_requirement(dep): + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, dep, maxsplit=1) + parts = [p.strip() for p in parts] + package = parts[0] + if len(parts) > 1: + op, version = parts[1:] + op = { + '>=': '__ge__', + '==': '__eq__', + '>': '__gt__', + '<': '__lt__', + '<=': '__le__' + }[op] + else: + op, version = None, None + + try: + dist = distribution(package) + if op is None or getattr(digit_version(dist.version), op)( + digit_version(version)): + return True + except PackageNotFoundError: + pass + + return False diff --git a/run.py b/run.py index 661beed2..3c975b95 100644 --- a/run.py +++ b/run.py @@ -6,7 +6,8 @@ from datetime import datetime from mmengine.config import Config -from opencompass.partitioners import NaivePartitioner, SizePartitioner +from opencompass.partitioners import (MultimodalNaivePartitioner, + NaivePartitioner, SizePartitioner) from opencompass.registry import PARTITIONERS, RUNNERS from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner from opencompass.utils import LarkReporter, Summarizer, get_logger @@ -37,6 +38,10 @@ def parse_args(): 'redirected to files', action='store_true', default=False) + parser.add_argument('--mm-eval', + help='Whether or not enable multimodal evaluation', + 
action='store_true', + default=False) parser.add_argument('--dry-run', help='Dry run mode, in which the scheduler will not ' 'actually run the tasks, but only print the commands ' @@ -201,7 +206,14 @@ def main(): 'also specified --slurm or --dlc. ' 'The "infer" configuration will be overridden by ' 'your runtime arguments.') - if args.dlc or args.slurm or cfg.get('infer', None) is None: + # Check whether run multimodal evaluation + if args.mm_eval: + partitioner = MultimodalNaivePartitioner( + osp.join(cfg['work_dir'], 'predictions/')) + tasks = partitioner(cfg) + exec_mm_infer_runner(tasks, args, cfg) + return + elif args.dlc or args.slurm or cfg.get('infer', None) is None: # Use SizePartitioner to split into subtasks partitioner = SizePartitioner( osp.join(cfg['work_dir'], 'predictions/'), @@ -283,6 +295,27 @@ def main(): summarizer.summarize(time_str=cfg_time_str) +def exec_mm_infer_runner(tasks, args, cfg): + """execute multimodal infer runner according to args.""" + if args.slurm: + runner = SlurmRunner(dict(type='MultimodalInferTask'), + max_num_workers=args.max_num_workers, + partition=args.partition, + quotatype=args.quotatype, + retry=args.retry, + debug=args.debug, + lark_bot_url=cfg['lark_bot_url']) + elif args.dlc: + raise NotImplementedError('Currently, we do not support evaluating \ + multimodal models on dlc.') + else: + runner = LocalRunner(task=dict(type='MultimodalInferTask'), + max_num_workers=args.max_num_workers, + debug=args.debug, + lark_bot_url=cfg['lark_bot_url']) + runner(tasks) + + def exec_infer_runner(tasks, args, cfg): """execute infer runner according to args.""" if args.slurm: