From 191a3f6f9d5fe4f0e4b96e3a6be95d50a14efd63 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Thu, 3 Aug 2023 11:07:50 +0800
Subject: [PATCH] [Feature]: Use multimodal (#73)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* [Feature]: Add minigpt-4
* [Feature]: Add mm local runner
* [Feature]: Add instructblip
* [Feature]: Delete redundant file
* [Feature]: Delete redundant file
* [Feature]: Add README to InstructBLIP
* [Feature]: Update MiniGPT-4
* [Fix]: Fix lint
* [Feature]: Add omnibenchmark readme (#49)
* add omnibenchmark readme
* fix
* Update OmniMMBench.md
* Update OmniMMBench.md
* Update OmniMMBench.md
* [Fix]: Refine name (#54)
* [Feature]: Unify out and err
* [Fix]: Fix lint
* [Feature]: Rename to mmbench and change weight path
* [Feature]: Delete Omni in instructblip
* [Feature]: Check the availability of lavis
* [Fix]: Fix lint
* [Feature]: Refactor MM
* [Refactor]: Refactor path
* [Feature]: Delete redundant files
* [Refactor]: Delete redundant files
---------
Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com>
---
.gitignore | 1 +
configs/multimodal/instructblip/README.md | 9 +
.../instructblip/instructblip-mmbench.py | 45 +++
configs/multimodal/minigpt_4/README.md | 10 +
.../minigpt_4/minigpt_4_7b_mmbench.py | 42 +++
configs/multimodal/tasks.py | 15 +
docs/en/MMBench.md | 1 -
opencompass/metrics/__init__.py | 3 +
opencompass/metrics/dump_results.py | 53 ++++
opencompass/multimodal/datasets/__init__.py | 3 +
opencompass/multimodal/datasets/mmbench.py | 79 ++++++
opencompass/multimodal/models/__init__.py | 5 +
.../models/instructblip/__init__.py | 3 +
.../instructblip/blip2_vicuna_instruct.py | 260 ++++++++++++++++++
.../multimodal/models/minigpt_4/__init__.py | 3 +
.../multimodal/models/minigpt_4/minigpt_4.py | 181 ++++++++++++
.../multimodal/models/minigpt_4/utils.py | 56 ++++
.../openicl/icl_evaluator/icl_hf_evaluator.py | 2 +-
opencompass/partitioners/__init__.py | 1 +
opencompass/partitioners/mm_naive.py | 119 ++++++++
opencompass/registry.py | 12 +
opencompass/runners/slurm.py | 1 -
opencompass/tasks/__init__.py | 1 +
opencompass/tasks/mm_infer.py | 126 +++++++++
opencompass/utils/__init__.py | 1 +
opencompass/utils/dependency.py | 32 +++
run.py | 37 ++-
27 files changed, 1096 insertions(+), 5 deletions(-)
create mode 100644 configs/multimodal/instructblip/README.md
create mode 100644 configs/multimodal/instructblip/instructblip-mmbench.py
create mode 100644 configs/multimodal/minigpt_4/README.md
create mode 100644 configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
create mode 100644 configs/multimodal/tasks.py
create mode 100644 opencompass/metrics/__init__.py
create mode 100644 opencompass/metrics/dump_results.py
create mode 100644 opencompass/multimodal/datasets/__init__.py
create mode 100644 opencompass/multimodal/datasets/mmbench.py
create mode 100644 opencompass/multimodal/models/__init__.py
create mode 100644 opencompass/multimodal/models/instructblip/__init__.py
create mode 100644 opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
create mode 100644 opencompass/multimodal/models/minigpt_4/__init__.py
create mode 100644 opencompass/multimodal/models/minigpt_4/minigpt_4.py
create mode 100644 opencompass/multimodal/models/minigpt_4/utils.py
create mode 100644 opencompass/partitioners/mm_naive.py
create mode 100644 opencompass/tasks/mm_infer.py
create mode 100644 opencompass/utils/dependency.py
diff --git a/.gitignore b/.gitignore
index d3a72f8b..23bf2a52 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ configs/datasets/log.json
configs/eval_debug*.py
configs/viz_*.py
data
+work_dirs
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/configs/multimodal/instructblip/README.md b/configs/multimodal/instructblip/README.md
new file mode 100644
index 00000000..1002328d
--- /dev/null
+++ b/configs/multimodal/instructblip/README.md
@@ -0,0 +1,9 @@
+# InstructBLIP
+
+### Prepare the environment
+
+```sh
+git clone https://github.com/salesforce/LAVIS.git
+cd ./LAVIS
+pip install -e .
+```
\ No newline at end of file
diff --git a/configs/multimodal/instructblip/instructblip-mmbench.py b/configs/multimodal/instructblip/instructblip-mmbench.py
new file mode 100644
index 00000000..f4923e89
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip-mmbench.py
@@ -0,0 +1,45 @@
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'answer', 'category', 'l2-category', 'context',
+ 'index', 'options_dict', 'options', 'split'
+ ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+ data_file='data/mmbench/mmbench_test_20230712.tsv',
+ pipeline=val_pipeline)
+
+dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+model = dict(
+ type='blip2-vicuna-instruct-mmbench',
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ sys_prompt= # noqa: E251
+ '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
+)
+
+# evaluation settings
+evaluator = [
+ dict(
+ type='opencompass.DumpResults',
+ save_path= # noqa: E251
+ 'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
+]
+
+load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa
diff --git a/configs/multimodal/minigpt_4/README.md b/configs/multimodal/minigpt_4/README.md
new file mode 100644
index 00000000..7d2bf4e5
--- /dev/null
+++ b/configs/multimodal/minigpt_4/README.md
@@ -0,0 +1,10 @@
+# MiniGPT-4
+
+### Prepare the environment
+
+```sh
+cd opencompass/multimodal/models/minigpt_4
+git clone https://github.com/Vision-CAIR/MiniGPT-4.git
+```
+
+Then prepare the environment according to this [doc](https://github.com/Vision-CAIR/MiniGPT-4).
\ No newline at end of file
diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
new file mode 100644
index 00000000..913007cc
--- /dev/null
+++ b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
@@ -0,0 +1,42 @@
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'answer', 'category', 'l2-category', 'context',
+ 'index', 'options_dict', 'options', 'split'
+ ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+ data_file='data/mmbench/mmbench_test_20230712.tsv',
+ pipeline=val_pipeline)
+
+minigpt_4_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+minigpt_4_model = dict(
+ type='minigpt-4-mmbench',
+ low_resource=True,
+ llama_model='/path/to/vicuna',
+ sys_prompt= # noqa: E251
+ '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
+)
+
+# evaluation settings
+minigpt_4_evaluator = [
+ dict(type='opencompass.DumpResults',
+ save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
+]
+
+minigpt_4_load_from = '/path/to/minigpt-4' # noqa
diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py
new file mode 100644
index 00000000..b8fd75fd
--- /dev/null
+++ b/configs/multimodal/tasks.py
@@ -0,0 +1,15 @@
+from mmengine.config import read_base
+
+with read_base():
+ from .minigpt_4.minigpt_4_7b_mmbench import (minigpt_4_dataloader,
+ minigpt_4_evaluator,
+ minigpt_4_load_from,
+ minigpt_4_model)
+
+models = [minigpt_4_model]
+datasets = [minigpt_4_dataloader]
+evaluators = [minigpt_4_evaluator]
+load_froms = [minigpt_4_load_from]
+num_gpus = 1
+num_procs = 1
+launcher = 'slurm'
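+
+# Example launch (illustrative; adjust the CLI flags to your setup):
+#   python run.py configs/multimodal/tasks.py --mm-eval --slurm --partition <partition>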
diff --git a/docs/en/MMBench.md b/docs/en/MMBench.md
index 02db8b95..103854c6 100644
--- a/docs/en/MMBench.md
+++ b/docs/en/MMBench.md
@@ -76,7 +76,6 @@ class MMBenchDataset(Dataset):
'context': hint,
}
return data
-
def load_from_df(self, idx, key):
if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
return self.df.iloc[idx][key]
diff --git a/opencompass/metrics/__init__.py b/opencompass/metrics/__init__.py
new file mode 100644
index 00000000..09333edf
--- /dev/null
+++ b/opencompass/metrics/__init__.py
@@ -0,0 +1,3 @@
+from .dump_results import DumpResults
+
+__all__ = ['DumpResults']
diff --git a/opencompass/metrics/dump_results.py b/opencompass/metrics/dump_results.py
new file mode 100644
index 00000000..0b330729
--- /dev/null
+++ b/opencompass/metrics/dump_results.py
@@ -0,0 +1,53 @@
+import os
+from typing import Optional
+
+import pandas as pd
+from mmengine.evaluator import BaseMetric
+
+from opencompass.registry import METRICS
+
+
+@METRICS.register_module()
+class DumpResults(BaseMetric):
+ """Dump model's prediction to a file.
+
+ Args:
+ save_path (str): the path to save model's prediction.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None.
+ """
+
+ def __init__(self,
+ save_path: str,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device, prefix)
+ self.save_path = save_path
+ if not os.path.exists(os.path.dirname(self.save_path)):
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
+
+ def process(self, data_batch, data_samples) -> None:
+ for data_sample in data_samples:
+ result = dict()
+
+ result['question'] = data_sample.get('question')
+ result.update(data_sample.get('options_dict'))
+ result['prediction'] = data_sample.get('pred_answer')
+ if data_sample.get('category') is not None:
+ result['category'] = data_sample.get('category')
+ if data_sample.get('l2-category') is not None:
+ result['l2-category'] = data_sample.get('l2-category')
+ result['index'] = data_sample.get('index')
+ result['split'] = data_sample.get('split')
+ self.results.append(result)
+
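+    # Each dumped row carries the question, the option texts (A-E), the model
+    # prediction and the sample's category, l2-category, index and split.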
+ def compute_metrics(self, results: list) -> dict:
+ df = pd.DataFrame(results)
+ with pd.ExcelWriter(self.save_path, engine='openpyxl') as writer:
+ df.to_excel(writer, index=False)
+ return {}
diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py
new file mode 100644
index 00000000..47fa24ad
--- /dev/null
+++ b/opencompass/multimodal/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .mmbench import MMBenchDataset
+
+__all__ = ['MMBenchDataset']
diff --git a/opencompass/multimodal/datasets/mmbench.py b/opencompass/multimodal/datasets/mmbench.py
new file mode 100644
index 00000000..a1ba4c58
--- /dev/null
+++ b/opencompass/multimodal/datasets/mmbench.py
@@ -0,0 +1,79 @@
+import base64
+import io
+from typing import List, Optional
+
+import pandas as pd
+from mmengine.dataset import Compose
+from PIL import Image
+from torch.utils.data import Dataset
+
+from opencompass.registry import DATASETS
+
+
+def decode_base64_to_image(base64_string: str) -> Image.Image:
+ """Convert raw data into Pillow image."""
+ image_data = base64.b64decode(base64_string)
+ image = Image.open(io.BytesIO(image_data))
+ return image
+
+
+@DATASETS.register_module()
+class MMBenchDataset(Dataset):
+ """Dataset to load MMBench dataset.
+
+ Args:
+ data_file (str): The path of the dataset.
+ pipeline (dict): The data augmentation.
+        sys_prompt (str): The system prompt added before the options.
+            Defaults to 'There are several options:'.
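+
+    Note:
+        Each row of the TSV file is expected to provide 'index', 'image'
+        (base64-encoded), 'question', 'category', 'l2-category', the option
+        columns 'A'-'E' and an optional 'hint'.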
+ """
+
+ def __init__(self,
+ data_file: str,
+ pipeline: List[dict],
+ sys_prompt: str = 'There are several options:') -> None:
+ self.df = pd.read_csv(data_file, sep='\t')
+ self.pipeline = Compose(pipeline)
+ self.sys_prompt = sys_prompt
+
+    def __len__(self) -> int:
+ return len(self.df)
+
+    def __getitem__(self, idx: int) -> dict:
+ index = self.df.iloc[idx]['index']
+ image = self.df.iloc[idx]['image']
+ image = decode_base64_to_image(image)
+ question = self.df.iloc[idx]['question']
+        category = self.df.iloc[idx]['category']
+        l2_category = self.df.iloc[idx]['l2-category']
+
+ option_candidate = ['A', 'B', 'C', 'D', 'E']
+ options = {
+ cand: self.load_from_df(idx, cand)
+ for cand in option_candidate
+ if self.load_from_df(idx, cand) is not None
+ }
+ options_prompt = f'{self.sys_prompt}\n'
+ for key, item in options.items():
+ options_prompt += f'{key}. {item}\n'
+
+ hint = self.load_from_df(idx, 'hint')
+ data = {
+ 'img': image,
+ 'question': question,
+ 'options': options_prompt,
+            'category': category,
+            'l2-category': l2_category,
+ 'options_dict': options,
+ 'index': index,
+ 'context': hint,
+ }
+ data = self.pipeline(data)
+ return data
+
+ def load_from_df(self, idx: int, key: str) -> Optional[str]:
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
+ return self.df.iloc[idx][key]
+ else:
+ return None
diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py
new file mode 100644
index 00000000..3747a125
--- /dev/null
+++ b/opencompass/multimodal/models/__init__.py
@@ -0,0 +1,5 @@
+from opencompass.utils import satisfy_requirement
+
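+# InstructBLIP depends on salesforce-lavis, so it is only imported when the
+# package is installed; MiniGPT-4 is imported unconditionally.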
+if satisfy_requirement('salesforce-lavis'):
+ from .instructblip import * # noqa: F401, F403
+from .minigpt_4 import * # noqa: F401, F403
diff --git a/opencompass/multimodal/models/instructblip/__init__.py b/opencompass/multimodal/models/instructblip/__init__.py
new file mode 100644
index 00000000..1aa1c98b
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/__init__.py
@@ -0,0 +1,3 @@
+from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
+
+__all__ = ['Blip2VicunaInstructMMBench']
diff --git a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
new file mode 100644
index 00000000..6595df10
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
@@ -0,0 +1,260 @@
+"""Requires Transformer 4.28 and above, implementation may change according the
+Llama implementation."""
+import logging
+import re
+
+import torch
+import torch.nn as nn
+from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
+from mmengine.device import get_device
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+from opencompass.registry import MM_MODELS
+
+
+@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
+class Blip2VicunaInstructMMBench(Blip2Base):
+
+ def __init__(
+ self,
+ vit_model='eva_clip_g',
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision='fp16',
+ freeze_vit=True,
+ num_query_token=32,
+ llm_model='',
+ sys_prompt='',
+ prompt='',
+ max_txt_len=128,
+ max_output_txt_len=256,
+ qformer_text_input=True,
+ low_resource=False,
+ ):
+ super().__init__()
+ self.tokenizer = self.init_tokenizer(truncation_side='left')
+
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint,
+ vit_precision)
+ if freeze_vit:
+ for name, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder = self.visual_encoder.eval()
+ self.visual_encoder.train = disabled_train
+ logging.info('freeze vision encoder')
+
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features)
+
+ if not qformer_text_input:
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ else:
+ self.Qformer.resize_token_embeddings(len(self.tokenizer))
+ self.Qformer.cls = None
+
+ self.llm_tokenizer = LlamaTokenizer.from_pretrained(
+ llm_model, use_fast=False, truncation_side='left')
+
+ if low_resource:
+ self.llm_model = LlamaForCausalLM.from_pretrained(
+ llm_model,
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map={'': 0})
+ else:
+ self.llm_model = LlamaForCausalLM.from_pretrained(
+ llm_model, torch_dtype=torch.float16)
+ self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+
+ self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+ for name, param in self.llm_model.named_parameters():
+ param.requires_grad = False
+
+ self.llm_proj = nn.Linear(self.Qformer.config.hidden_size,
+ self.llm_model.config.hidden_size)
+
+ self.max_txt_len = max_txt_len
+ self.max_output_txt_len = max_output_txt_len
+ self.sys_prompt = sys_prompt
+ self.prompt = prompt
+
+ self._lemmatizer = None
+
+ self.qformer_text_input = qformer_text_input
+
+ def concat_text_input_output(self, input_ids, input_atts, output_ids,
+ output_atts):
+ input_part_targets_len = []
+ llm_tokens = {'input_ids': [], 'attention_mask': []}
+ for i in range(input_ids.size(0)):
+ this_input_ones = input_atts[i].sum()
+ input_part_targets_len.append(this_input_ones)
+ llm_tokens['input_ids'].append(
+ torch.cat([
+ input_ids[i][:this_input_ones], output_ids[i][1:],
+ input_ids[i][this_input_ones:]
+ ]))
+ llm_tokens['attention_mask'].append(
+ torch.cat([
+ input_atts[i][:this_input_ones], output_atts[i][1:],
+ input_atts[i][this_input_ones:]
+ ]))
+ llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+ llm_tokens['attention_mask'] = torch.stack(
+ llm_tokens['attention_mask'])
+ return llm_tokens, input_part_targets_len
+
+ def pack_inputs(self, batch):
+ images = [image.unsqueeze(0) for image in batch['inputs']]
+ data_samples = [data_sample for data_sample in batch['data_samples']]
+ images = torch.cat(images, dim=0).to(get_device())
+ inputs = {'image': images, 'data_samples': data_samples}
+ return inputs
+
+ @torch.no_grad()
+ def generate(
+ self,
+ batch,
+ use_nucleus_sampling=False,
+ num_beams=5,
+ max_length=256,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1.5,
+ length_penalty=1,
+ num_captions=1,
+ temperature=1,
+ ):
+ inputs = self.pack_inputs(batch)
+ image = inputs.pop('image')
+ data_samples = inputs['data_samples']
+ samples = {'image': image}
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ options = [data_sample.get('options') for data_sample in data_samples]
+ if data_samples[0].get('context') is not None:
+ contexts = [
+ data_sample.get('context') for data_sample in data_samples
+ ]
+ prompt = [
+ context + ' ' + question + ' ' + option for context, question,
+ option in zip(contexts, questions, options)
+ ]
+ else:
+ prompt = [
+ question + ' ' + option
+ for question, option in zip(questions, options)
+ ]
+
+ self.llm_tokenizer.padding_side = 'left'
+
+ image = samples['image']
+
+ bs = image.size(0)
+
+ if isinstance(prompt, str):
+ prompt = [prompt] * bs
+ else:
+ assert len(
+ prompt
+ ) == bs, 'The number of prompts must be equal to the batch size.'
+
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
+ if self.qformer_text_input:
+ text_Qformer = self.tokenizer(
+ prompt,
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ return_tensors='pt',
+ ).to(image.device)
+ query_atts = torch.ones(query_tokens.size()[:-1],
+ dtype=torch.long).to(image.device)
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],
+ dim=1)
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ if self.qformer_text_input:
+ query_output = self.Qformer.bert(
+ text_Qformer.input_ids,
+ attention_mask=Qformer_atts,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ else:
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ inputs_llm = self.llm_proj(
+ query_output.last_hidden_state[:, :query_tokens.size(1), :])
+ atts_llm = torch.ones(inputs_llm.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ prompt = ['###Human: ' + p + '###Assistant:' for p in prompt]
+ prompt = [self.sys_prompt + p for p in prompt]
+ llm_tokens = self.llm_tokenizer(prompt,
+ padding='longest',
+ return_tensors='pt').to(image.device)
+
+ with self.maybe_autocast():
+ inputs_embeds = self.llm_model.get_input_embeddings()(
+ llm_tokens.input_ids)
+ inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+ attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask],
+ dim=1)
+
+ outputs = self.llm_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ do_sample=use_nucleus_sampling,
+ top_p=top_p,
+ temperature=temperature,
+ num_beams=num_beams,
+ max_length=max_length,
+ min_length=min_length,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ num_return_sequences=num_captions,
+ )
+ outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
+ output_text = self.llm_tokenizer.batch_decode(outputs,
+ skip_special_tokens=True)
+ output_text = [text.strip() for text in output_text]
+ output_text = self.post_process(output_text[0])
+ data_sample = data_samples[0]
+ data_sample.pred_answer = output_text
+
+ return data_sample
+
+ def post_process(self, output_text):
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+ output_text = output_text.strip()
+ pattern = re.compile(r'([A-Z]\.)')
+ res = pattern.findall(output_text)
+ if len(res) > 0:
+ output_text = res[0][:-1]
+ return output_text
diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py
new file mode 100644
index 00000000..3104855c
--- /dev/null
+++ b/opencompass/multimodal/models/minigpt_4/__init__.py
@@ -0,0 +1,3 @@
+from .minigpt_4 import MiniGPT4MMBench
+
+__all__ = ['MiniGPT4MMBench']
diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
new file mode 100644
index 00000000..306fec58
--- /dev/null
+++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -0,0 +1,181 @@
+import os
+import re
+import sys
+
+import torch
+import torch.nn as nn
+from mmengine.device import get_device
+from transformers import StoppingCriteriaList
+
+from opencompass.registry import MM_MODELS
+
+from .utils import StoppingCriteriaSub
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+def load_package():
+ """Load required packages from MiniGPT-4."""
+ current_file_path = os.path.abspath(__file__)
+ current_folder_path = os.path.dirname(current_file_path)
+
+ sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4')) # noqa
+ from minigpt4.models.mini_gpt4 import MiniGPT4
+
+ sys.path.pop(-1)
+
+ return MiniGPT4
+
+
+MiniGPT4 = load_package()
+
+
+@MM_MODELS.register_module('minigpt-4-mmbench')
+class MiniGPT4MMBench(MiniGPT4):
+ """Inference code of MiniGPT-4 on MMBench.
+
+ Args:
+        llama_model (str): The path of the vicuna weights.
+        sys_prompt (str): The prompt added to the beginning
+            of each query. Defaults to ''.
+        low_resource (bool): Whether to load the model in low precision.
+            Defaults to False.
+ """
+
+ def __init__(self,
+ llama_model: str,
+ sys_prompt: str = '',
+ low_resource: bool = False) -> None:
+ super().__init__(llama_model=llama_model, low_resource=low_resource)
+
+ cur_device = get_device()
+ stop_words_ids = [
+ torch.tensor([835]).to(cur_device),
+ torch.tensor([2277, 29937]).to(cur_device),
+ ]
+ self.stopping_criteria = StoppingCriteriaList(
+ [StoppingCriteriaSub(stops=stop_words_ids)])
+ self.sys_prompt = sys_prompt
+
+ def encode_img(self, image):
+ device = image.device
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(
+ self.visual_encoder(image)).to(device)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(device)
+
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
+ -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
+ atts_llama = torch.ones(inputs_llama.size()[:-1],
+ dtype=torch.long).to(image.device)
+ return inputs_llama, atts_llama
+
+ def pack_inputs(self, batch):
+ images = [image.unsqueeze(0) for image in batch['inputs']]
+ data_samples = [data_sample for data_sample in batch['data_samples']]
+ images = torch.cat(images, dim=0).to(get_device())
+ inputs = {'image': images, 'data_samples': data_samples}
+ return inputs
+
+ def generate(self, batch):
+ inputs = self.pack_inputs(batch)
+ image = inputs.pop('image')
+ data_samples = inputs['data_samples']
+ samples = {'image': image}
+ question = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ options = [data_sample.get('options') for data_sample in data_samples]
+ samples.update({'question': question[0]})
+ samples.update({'options': options[0]})
+ if data_samples[0].get('context') is not None:
+ context = [
+ data_sample.get('context') for data_sample in data_samples
+ ]
+ samples.update({'context': context})
+ data_sample = data_samples[0]
+        img_prompt = '###Human: <Img><ImageHere></Img> '
+ if 'context' in samples:
+ context_prompt = samples['context'][0]
+
+ question = samples['question']
+ options = samples['options']
+ if 'context' in samples:
+ prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa
+ else:
+ prompt = img_prompt + ' ' + question + ' ' + options
+
+ # prompt = self.sys_prompt + prompt
+ prompt = prompt + '###Assistant:'
+
+ image = samples['image']
+ img_embeds, _ = self.encode_img(image)
+
+        prompt_segs = prompt.split('<ImageHere>')
+ prompt_seg_tokens = [
+ self.llama_tokenizer(seg,
+ return_tensors='pt',
+ add_special_tokens=i == 0).
+ to(self.llama_model.model.embed_tokens.weight.device).input_ids
+ for i, seg in enumerate(prompt_segs)
+ ]
+ prompt_seg_embs = [
+ self.llama_model.model.embed_tokens(seg)
+ for seg in prompt_seg_tokens
+ ]
+ prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
+ prompt_embs = torch.cat(prompt_seg_embs, dim=1)
+
+ # generate output
+ outputs = self.llama_model.generate(
+ inputs_embeds=prompt_embs,
+ max_new_tokens=20,
+ num_beams=5,
+ do_sample=False,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ length_penalty=-1.0,
+ temperature=1.0,
+ stopping_criteria=self.stopping_criteria,
+ num_return_sequences=1)
+
+ output_token = outputs[0]
+ if output_token[0] == 0:
+ output_token = output_token[1:]
+ if output_token[0] == 1:
+ output_token = output_token[1:]
+ output_text = self.llama_tokenizer.decode(output_token,
+ add_special_tokens=False)
+ output_text = self.post_process(output_text)
+ data_sample.pred_answer = output_text
+ return data_sample
+
+ def post_process(self, output_text):
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+ output_text = output_text.strip()
+ pattern = re.compile(r'([A-Z]\.)')
+ res = pattern.findall(output_text)
+ if len(res) > 0:
+ output_text = res[0][:-1]
+ return output_text
diff --git a/opencompass/multimodal/models/minigpt_4/utils.py b/opencompass/multimodal/models/minigpt_4/utils.py
new file mode 100644
index 00000000..777c1939
--- /dev/null
+++ b/opencompass/multimodal/models/minigpt_4/utils.py
@@ -0,0 +1,56 @@
+import os
+import re
+
+import timm.models.hub as timm_hub
+import torch
+import torch.distributed as dist
+from mmengine.dist import is_distributed, is_main_process
+from transformers import StoppingCriteria
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
+ return True
+
+ return False
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ """Download a file from a URL and cache it locally.
+
+ If the file already exists, it is not downloaded again. If distributed,
+ only the main process downloads the file, and the other processes wait for
+ the file to be downloaded.
+ """
+
+ def get_cached_file_path():
+ # a hack to sync the file path across processes
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+ return cached_file
+
+ if is_main_process():
+ timm_hub.download_cached_file(url, check_hash, progress)
+
+ if is_distributed():
+ dist.barrier()
+
+ return get_cached_file_path()
+
+
+def is_url(input_url):
+ """Check if an input string is a url.
+
+    Looks for http(s):// and ignores the case.
+ """
+ is_url = re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None
+ return is_url
diff --git a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py
index 89fb1d17..3e2e4bcc 100644
--- a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py
+++ b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py
@@ -120,7 +120,7 @@ class AccEvaluator(HuggingfaceEvaluator):
@ICL_EVALUATORS.register_module()
class RougeEvaluator(HuggingfaceEvaluator):
- """Rouge evaluator."""
+ """Rouge evaluator.""" # noqa
def __init__(self) -> None:
super().__init__(metric='rouge')
diff --git a/opencompass/partitioners/__init__.py b/opencompass/partitioners/__init__.py
index 836081fb..ee9fe108 100644
--- a/opencompass/partitioners/__init__.py
+++ b/opencompass/partitioners/__init__.py
@@ -1,2 +1,3 @@
+from .mm_naive import * # noqa: F401, F403
from .naive import * # noqa: F401, F403
from .size import * # noqa: F401, F403
diff --git a/opencompass/partitioners/mm_naive.py b/opencompass/partitioners/mm_naive.py
new file mode 100644
index 00000000..817b1276
--- /dev/null
+++ b/opencompass/partitioners/mm_naive.py
@@ -0,0 +1,119 @@
+from copy import deepcopy
+from typing import Dict, List
+
+from mmengine.config import Config, ConfigDict
+
+from opencompass.registry import PARTITIONERS
+
+from .base import BasePartitioner
+
+
+@PARTITIONERS.register_module()
+class MultimodalNaivePartitioner(BasePartitioner):
+ """Multimodal naive task partitioner.
+
+    This partitioner generates one task for each aligned
+    (model, dataset, evaluator, load_from) tuple.
+
+ Args:
+ config (ConfigDict): The full config dict.
+ """
+
+ def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
+ evaluators: List[ConfigDict], load_froms: List[ConfigDict],
+ work_dir: str, num_gpus: int, num_procs: int,
+ launcher: str) -> List[Dict]:
+ """Partition model-dataset pairs into tasks. Each task is defined as a
+ dict and will run independently as a unit. Its structure is as follows:
+
+ .. code-block:: python
+
+ {
+ 'models': [], # a list of model configs
+ 'datasets': [], # a list of dataset configs
+ 'evaluators': [], # a list of evaluator configs
+ 'load_froms': [], # a list of load_from paths
+ 'work_dir': '', # the work dir
+ 'num_gpus': int, # integer, number of gpus for each task
+                'num_procs': int, # integer, number of processes on a single machine
+ 'launcher': str, # string, how to launch distributed training
+ }
+
+ Args:
+ models (List[ConfigDict]): A list of model configs.
+ datasets (List[ConfigDict]): A list of dataset configs.
+ evaluators (List[ConfigDict]): A list of evaluator configs.
+ load_froms (List[ConfigDict]): A list of load_from paths.
+ work_dir (str): The work dir for the task.
+ num_gpus (int): Number of gpus for each task.
+            num_procs (int): Number of processes on a single machine.
+ launcher (str): How to launch distributed training.
+ Only `slurm`, `pytorch` and `mpi` are available.
+
+ Returns:
+ List[Dict]: A list of tasks.
+ """
+
+ tasks = []
+ for model, dataset, evaluator, load_from in zip(
+ models, datasets, evaluators, load_froms):
+ task = Config({
+ 'model': model,
+ 'dataset': dataset,
+ 'evaluator': evaluator,
+ 'load_from': load_from,
+ 'work_dir': work_dir,
+ 'num_gpus': num_gpus,
+ 'num_procs': num_procs,
+ 'launcher': launcher
+ })
+ tasks.append(task)
+
+ return tasks
+
+ def __call__(self, cfg: ConfigDict) -> List[Dict]:
+ """Generate tasks from config. Each task is defined as a
+ dict and will run independently as a unit. Its structure is as
+ follows:
+
+ .. code-block:: python
+
+ {
+ 'models': [], # a list of model configs
+ 'datasets': [], # a list of dataset configs
+ 'evaluators': [], # a list of evaluator configs
+ 'load_froms': [], # a list of load_from paths
+ 'work_dir': '', # the work dir
+ 'num_gpus': int, # integer, number of gpus for each task
+                'num_procs': int, # integer, number of processes on a single machine
+ }
+
+ Args:
+            cfg (ConfigDict): The config dict, containing "models",
+                "datasets", "evaluators", "load_froms" and "work_dir" keys.
+
+ Returns:
+ List[Dict]: A list of tasks.
+ """
+ cfg = deepcopy(cfg)
+ models = cfg['models']
+ datasets = cfg['datasets']
+ evaluators = cfg['evaluators']
+ load_froms = cfg['load_froms']
+ work_dir = cfg['work_dir']
+ num_gpus = cfg['num_gpus']
+ num_procs = cfg['num_procs']
+ launcher = cfg['launcher']
+
+ tasks = self.partition(models, datasets, evaluators, load_froms,
+ work_dir, num_gpus, num_procs, launcher)
+
+ self.logger.info(f'Partitioned into {len(tasks)} tasks.')
+ for i, task in enumerate(tasks):
+ model_name = task['model']['type']
+ dataset_name = task['dataset']['dataset']['type']
+ evaluator_name = task['evaluator'][0]['type']
+ self.logger.debug(
+ f'Task {i}: {model_name}-{dataset_name}-{evaluator_name}')
+
+ return tasks
diff --git a/opencompass/registry.py b/opencompass/registry.py
index 8f26e607..a48ee51c 100644
--- a/opencompass/registry.py
+++ b/opencompass/registry.py
@@ -1,3 +1,6 @@
+from mmengine.registry import DATASETS as MMENGINE_DATASETS
+from mmengine.registry import METRICS as MMENGINE_METRICS
+from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import Registry
PARTITIONERS = Registry('partitioner', locations=['opencompass.partitioners'])
@@ -22,3 +25,12 @@ ICL_PROMPT_TEMPLATES = Registry(
locations=['opencompass.openicl.icl_prompt_template'])
ICL_EVALUATORS = Registry('icl_evaluators',
locations=['opencompass.openicl.icl_evaluator'])
+DATASETS = Registry('mm_datasets',
+ parent=MMENGINE_DATASETS,
+ locations=['opencompass.multimodal.datasets'])
+METRICS = Registry('metric',
+ parent=MMENGINE_METRICS,
+ locations=['opencompass.metrics'])
+MM_MODELS = Registry('mm_model',
+ parent=MMENGINE_MODELS,
+ locations=['opencompass.multimodal.models'])
diff --git a/opencompass/runners/slurm.py b/opencompass/runners/slurm.py
index ddb7808d..c6efb60c 100644
--- a/opencompass/runners/slurm.py
+++ b/opencompass/runners/slurm.py
@@ -81,7 +81,6 @@ class SlurmRunner(BaseRunner):
Returns:
tuple[str, int]: Task name and exit code.
"""
-
task_type = self.task_cfg.type
if isinstance(self.task_cfg.type, str):
task_type = TASKS.get(task_type)
diff --git a/opencompass/tasks/__init__.py b/opencompass/tasks/__init__.py
index 308ea5d6..ac63c77d 100644
--- a/opencompass/tasks/__init__.py
+++ b/opencompass/tasks/__init__.py
@@ -1,2 +1,3 @@
+from .mm_infer import * # noqa: F401, F403
from .openicl_eval import * # noqa: F401, F403
from .openicl_infer import * # noqa: F401, F403
diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py
new file mode 100644
index 00000000..2d52c230
--- /dev/null
+++ b/opencompass/tasks/mm_infer.py
@@ -0,0 +1,126 @@
+import argparse
+import json
+import os
+import os.path as osp
+import random
+import time
+from typing import Sequence
+
+import torch
+import torch.distributed as dist
+from mmengine.config import Config, ConfigDict
+from mmengine.device import get_device
+from mmengine.dist import init_dist
+from mmengine.evaluator import Evaluator
+from mmengine.logging import print_log
+from mmengine.model.wrappers import MMDistributedDataParallel
+from mmengine.runner import Runner
+from mmengine.utils import track_iter_progress
+
+from opencompass.registry import MM_MODELS, TASKS
+from opencompass.utils import get_logger
+
+
+def build_model(cfg):
+ model = MM_MODELS.build(cfg['model'])
+ load_from = cfg.get('load_from', None)
+ if load_from is not None:
+ state_dict = torch.load(cfg['load_from'], map_location='cpu')
+ if 'model' in state_dict:
+ state_dict = state_dict['model']
+ elif 'state_dict' in state_dict:
+ state_dict = state_dict['state_dict']
+ msg = model.load_state_dict(state_dict, strict=False)
+ print_log(msg)
+ model.to(get_device())
+ if dist.is_initialized():
+ model = MMDistributedDataParallel(
+ model,
+ device_ids=[int(os.environ['LOCAL_RANK'])],
+ broadcast_buffers=False)
+ return model
+
+
+@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
+class MultimodalInferTask:
+ """Multimodal Inference Task.
+
+ This task is used to run the inference process.
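+
+    It builds the dataloader, model and evaluator from the task config,
+    loads the checkpoint given by `load_from` (if any), runs inference over
+    the dataloader and writes the metrics to `res.log` in the work dir.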
+ """
+
+ def __init__(self, cfg: ConfigDict):
+ self.num_gpus = cfg.get('num_gpus', 0)
+ self.num_procs = cfg.get('num_procs', 1)
+ self.dataloader = cfg.get('dataset')
+ self.model = cfg.get('model')
+ self.evaluator = cfg.get('evaluator')
+ self.cfg = cfg
+ self.logger = get_logger()
+
+ @property
+ def name(self) -> str:
+ model_name = self.model['type']
+ dataset_name = self.dataloader['dataset']['type']
+ evaluator_name = self.evaluator[0]['type']
+ return f'{model_name}-{dataset_name}-{evaluator_name}'
+
+ def get_command(self, cfg_path, template):
+ """Get the command template for the task.
+
+ Args:
+ cfg_path (str): The path to the config file of the task.
+            template (str): The template which contains '{task_cmd}' to
+                format the command.
+ """
+ script_path = __file__
+ if self.num_gpus > 0:
+ port = random.randint(12000, 32000)
+ command = (f'torchrun --master_port={port} '
+ f'--nproc_per_node {self.num_procs} '
+ f'{script_path} {cfg_path}')
+ else:
+ command = f'python {script_path} {cfg_path}'
+
+ return template.format(task_cmd=command)
+
+ def run(self):
+ # only support slurm, pytorch, mpi
+ init_dist(self.cfg.launcher)
+ self.logger.info(f'Task {self.name}')
+ # build dataloader
+ dataloader = Runner.build_dataloader(self.dataloader)
+ # build model
+ model = build_model(self.cfg)
+ # build evaluator
+ evaluator = Evaluator(self.evaluator)
+
+ for batch in track_iter_progress(dataloader):
+ if dist.is_initialized():
+ data_samples = model.module.generate(batch)
+ else:
+ data_samples = model.generate(batch)
+ if not isinstance(data_samples, Sequence):
+ data_samples = [data_samples]
+ evaluator.process(data_samples)
+
+ metrics = evaluator.evaluate(len(dataloader.dataset))
+        metrics_file = osp.join(self.cfg.work_dir, 'res.log')
+ with open(metrics_file, 'w') as f:
+ json.dump(metrics, f)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Model Inferencer')
+ parser.add_argument('config', help='Config file path')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ cfg = Config.fromfile(args.config)
+ start_time = time.time()
+ inferencer = MultimodalInferTask(cfg)
+ inferencer.run()
+ end_time = time.time()
+ get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')
diff --git a/opencompass/utils/__init__.py b/opencompass/utils/__init__.py
index d9fdeb4c..d899d161 100644
--- a/opencompass/utils/__init__.py
+++ b/opencompass/utils/__init__.py
@@ -1,6 +1,7 @@
from .abbr import * # noqa
from .build import * # noqa
from .collect_env import * # noqa
+from .dependency import * # noqa
from .fileio import * # noqa
from .git import * # noqa
from .lark import * # noqa
diff --git a/opencompass/utils/dependency.py b/opencompass/utils/dependency.py
new file mode 100644
index 00000000..821735f7
--- /dev/null
+++ b/opencompass/utils/dependency.py
@@ -0,0 +1,32 @@
+import re
+
+from importlib_metadata import PackageNotFoundError, distribution
+from mmengine.utils import digit_version
+
+
+def satisfy_requirement(dep):
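+    """Check whether an installed package satisfies a requirement string.
+
+    Illustrative usage (the requirement strings are only examples)::
+
+        satisfy_requirement('salesforce-lavis')
+        satisfy_requirement('transformers>=4.28')
+    """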
+ pat = '(' + '|'.join(['>=', '==', '>']) + ')'
+ parts = re.split(pat, dep, maxsplit=1)
+ parts = [p.strip() for p in parts]
+ package = parts[0]
+ if len(parts) > 1:
+ op, version = parts[1:]
+ op = {
+ '>=': '__ge__',
+ '==': '__eq__',
+ '>': '__gt__',
+ '<': '__lt__',
+ '<=': '__le__'
+ }[op]
+ else:
+ op, version = None, None
+
+ try:
+ dist = distribution(package)
+ if op is None or getattr(digit_version(dist.version), op)(
+ digit_version(version)):
+ return True
+ except PackageNotFoundError:
+ pass
+
+ return False
diff --git a/run.py b/run.py
index 661beed2..3c975b95 100644
--- a/run.py
+++ b/run.py
@@ -6,7 +6,8 @@ from datetime import datetime
from mmengine.config import Config
-from opencompass.partitioners import NaivePartitioner, SizePartitioner
+from opencompass.partitioners import (MultimodalNaivePartitioner,
+ NaivePartitioner, SizePartitioner)
from opencompass.registry import PARTITIONERS, RUNNERS
from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner
from opencompass.utils import LarkReporter, Summarizer, get_logger
@@ -37,6 +38,10 @@ def parse_args():
'redirected to files',
action='store_true',
default=False)
+ parser.add_argument('--mm-eval',
+                        help='Whether or not to enable multimodal evaluation',
+ action='store_true',
+ default=False)
parser.add_argument('--dry-run',
help='Dry run mode, in which the scheduler will not '
'actually run the tasks, but only print the commands '
@@ -201,7 +206,14 @@ def main():
'also specified --slurm or --dlc. '
'The "infer" configuration will be overridden by '
'your runtime arguments.')
- if args.dlc or args.slurm or cfg.get('infer', None) is None:
+ # Check whether run multimodal evaluation
+ if args.mm_eval:
+ partitioner = MultimodalNaivePartitioner(
+ osp.join(cfg['work_dir'], 'predictions/'))
+ tasks = partitioner(cfg)
+ exec_mm_infer_runner(tasks, args, cfg)
+ return
+ elif args.dlc or args.slurm or cfg.get('infer', None) is None:
# Use SizePartitioner to split into subtasks
partitioner = SizePartitioner(
osp.join(cfg['work_dir'], 'predictions/'),
@@ -283,6 +295,27 @@ def main():
summarizer.summarize(time_str=cfg_time_str)
+def exec_mm_infer_runner(tasks, args, cfg):
+ """execute multimodal infer runner according to args."""
+ if args.slurm:
+ runner = SlurmRunner(dict(type='MultimodalInferTask'),
+ max_num_workers=args.max_num_workers,
+ partition=args.partition,
+ quotatype=args.quotatype,
+ retry=args.retry,
+ debug=args.debug,
+ lark_bot_url=cfg['lark_bot_url'])
+ elif args.dlc:
+        raise NotImplementedError('Currently, we do not support evaluating '
+                                  'multimodal models on dlc.')
+ else:
+ runner = LocalRunner(task=dict(type='MultimodalInferTask'),
+ max_num_workers=args.max_num_workers,
+ debug=args.debug,
+ lark_bot_url=cfg['lark_bot_url'])
+ runner(tasks)
+
+
def exec_infer_runner(tasks, args, cfg):
"""execute infer runner according to args."""
if args.slurm: