mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature]: Use multimodal (#73)
* [Feature]: Add minigpt-4 * [Feature]: Add mm local runner * [Feature]: Add instructblip * [Feature]: Delete redundant file * [Feature]: Delete redundant file * [Feature]: Add README to InstructBLIP * [Feature]: Update MiniGPT-4 * [Fix]: Fix lint * [Feature]add omnibenchmark readme (#49) * add omnibenchmark readme * fix * Update OmniMMBench.md * Update OmniMMBench.md * Update OmniMMBench.md * [Fix]: Refine name (#54) * [Feature]: Unify out and err * [Fix]: Fix lint * [Feature]: Rename to mmbench and change weight path * [Feature]: Delete Omni in instructblip * [Feature]: Check the avaliablity of lavis * [Fix]: Fix lint * [Feature]: Refactor MM * [Refactor]: Refactor path * [Feature]: Delete redundant files * [Refactor]: Delete redundant files --------- Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com>
This commit is contained in:
parent
289e0567bd
commit
191a3f6f9d
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,6 +10,7 @@ configs/datasets/log.json
|
||||
configs/eval_debug*.py
|
||||
configs/viz_*.py
|
||||
data
|
||||
work_dirs
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
9
configs/multimodal/instructblip/README.md
Normal file
9
configs/multimodal/instructblip/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# InstructBLIP
|
||||
|
||||
### Prepare the environment
|
||||
|
||||
```sh
|
||||
git clone https://github.com/salesforce/LAVIS.git
|
||||
cd ./LAVIS
|
||||
pip install -e .
|
||||
```
|
45
configs/multimodal/instructblip/instructblip-mmbench.py
Normal file
45
configs/multimodal/instructblip/instructblip-mmbench.py
Normal file
@ -0,0 +1,45 @@
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
size=(224, 224),
|
||||
interpolation=3),
|
||||
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||
dict(type='mmpretrain.torchvision/Normalize',
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'answer', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict', 'options', 'split'
|
||||
])
|
||||
]
|
||||
|
||||
dataset = dict(type='opencompass.MMBench',
|
||||
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||
pipeline=val_pipeline)
|
||||
|
||||
dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='blip2-vicuna-instruct-mmbench',
|
||||
freeze_vit=True,
|
||||
low_resource=False,
|
||||
llm_model='/path/to/vicuna-7b/',
|
||||
sys_prompt= # noqa: E251
|
||||
'###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
|
||||
)
|
||||
|
||||
# evaluation settings
|
||||
evaluator = [
|
||||
dict(
|
||||
type='opencompass.DumpResults',
|
||||
save_path= # noqa: E251
|
||||
'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
|
||||
]
|
||||
|
||||
load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa
|
10
configs/multimodal/minigpt_4/README.md
Normal file
10
configs/multimodal/minigpt_4/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
# MiniGPT-4
|
||||
|
||||
### Prepare the environment
|
||||
|
||||
```sh
|
||||
cd opencompass/multimodal/models/minigpt_4
|
||||
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
|
||||
```
|
||||
|
||||
Then prepare the environement according to this [doc](https://github.com/Vision-CAIR/MiniGPT-4)
|
42
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
Normal file
42
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
Normal file
@ -0,0 +1,42 @@
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
size=(224, 224),
|
||||
interpolation=3),
|
||||
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||
dict(type='mmpretrain.torchvision/Normalize',
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'answer', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict', 'options', 'split'
|
||||
])
|
||||
]
|
||||
|
||||
dataset = dict(type='opencompass.MMBenchDataset',
|
||||
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||
pipeline=val_pipeline)
|
||||
|
||||
minigpt_4_dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||
|
||||
# model settings
|
||||
minigpt_4_model = dict(
|
||||
type='minigpt-4-mmbench',
|
||||
low_resource=True,
|
||||
llama_model='/path/to/vicuna',
|
||||
sys_prompt= # noqa: E251
|
||||
'###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
|
||||
)
|
||||
|
||||
# evaluation settings
|
||||
minigpt_4_evaluator = [
|
||||
dict(type='opencompass.DumpResults',
|
||||
save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
|
||||
]
|
||||
|
||||
minigpt_4_load_from = '/path/to/minigpt-4' # noqa
|
15
configs/multimodal/tasks.py
Normal file
15
configs/multimodal/tasks.py
Normal file
@ -0,0 +1,15 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .minigpt_4.minigpt_4_7b_mmbench import (minigpt_4_dataloader,
|
||||
minigpt_4_evaluator,
|
||||
minigpt_4_load_from,
|
||||
minigpt_4_model)
|
||||
|
||||
models = [minigpt_4_model]
|
||||
datasets = [minigpt_4_dataloader]
|
||||
evaluators = [minigpt_4_evaluator]
|
||||
load_froms = [minigpt_4_load_from]
|
||||
num_gpus = 1
|
||||
num_procs = 1
|
||||
launcher = 'slurm'
|
@ -76,7 +76,6 @@ class MMBenchDataset(Dataset):
|
||||
'context': hint,
|
||||
}
|
||||
return data
|
||||
|
||||
def load_from_df(self, idx, key):
|
||||
if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
|
||||
return self.df.iloc[idx][key]
|
||||
|
3
opencompass/metrics/__init__.py
Normal file
3
opencompass/metrics/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .dump_results import DumpResults
|
||||
|
||||
__all__ = ['DumpResults']
|
53
opencompass/metrics/dump_results.py
Normal file
53
opencompass/metrics/dump_results.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from opencompass.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class DumpResults(BaseMetric):
|
||||
"""Dump model's prediction to a file.
|
||||
|
||||
Args:
|
||||
save_path (str): the path to save model's prediction.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
save_path: str,
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device, prefix)
|
||||
self.save_path = save_path
|
||||
if not os.path.exists(os.path.dirname(self.save_path)):
|
||||
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
|
||||
|
||||
def process(self, data_batch, data_samples) -> None:
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
|
||||
result['question'] = data_sample.get('question')
|
||||
result.update(data_sample.get('options_dict'))
|
||||
result['prediction'] = data_sample.get('pred_answer')
|
||||
if data_sample.get('category') is not None:
|
||||
result['category'] = data_sample.get('category')
|
||||
if data_sample.get('l2-category') is not None:
|
||||
result['l2-category'] = data_sample.get('l2-category')
|
||||
result['index'] = data_sample.get('index')
|
||||
result['split'] = data_sample.get('split')
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
df = pd.DataFrame(results)
|
||||
with pd.ExcelWriter(self.save_path, engine='openpyxl') as writer:
|
||||
df.to_excel(writer, index=False)
|
||||
return {}
|
3
opencompass/multimodal/datasets/__init__.py
Normal file
3
opencompass/multimodal/datasets/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .mmbench import MMBenchDataset
|
||||
|
||||
__all__ = ['MMBenchDataset']
|
79
opencompass/multimodal/datasets/mmbench.py
Normal file
79
opencompass/multimodal/datasets/mmbench.py
Normal file
@ -0,0 +1,79 @@
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from mmengine.dataset import Compose
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from opencompass.registry import DATASETS
|
||||
|
||||
|
||||
def decode_base64_to_image(base64_string) -> Image:
|
||||
"""Convert raw data into Pillow image."""
|
||||
image_data = base64.b64decode(base64_string)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
return image
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MMBenchDataset(Dataset):
|
||||
"""Dataset to load MMBench dataset.
|
||||
|
||||
Args:
|
||||
data_file (str): The path of the dataset.
|
||||
pipeline (dict): The data augmentation.
|
||||
sys_prompt (str): The system prompt added to the head
|
||||
of these options. Defaults to
|
||||
There are several options:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_file: str,
|
||||
pipeline: List[dict],
|
||||
sys_prompt: str = 'There are several options:') -> None:
|
||||
self.df = pd.read_csv(data_file, sep='\t')
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.sys_prompt = sys_prompt
|
||||
|
||||
def __len__(self) -> None:
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx: str) -> dict:
|
||||
index = self.df.iloc[idx]['index']
|
||||
image = self.df.iloc[idx]['image']
|
||||
image = decode_base64_to_image(image)
|
||||
question = self.df.iloc[idx]['question']
|
||||
catetory = self.df.iloc[idx]['category']
|
||||
l2_catetory = self.df.iloc[idx]['l2-category']
|
||||
|
||||
option_candidate = ['A', 'B', 'C', 'D', 'E']
|
||||
options = {
|
||||
cand: self.load_from_df(idx, cand)
|
||||
for cand in option_candidate
|
||||
if self.load_from_df(idx, cand) is not None
|
||||
}
|
||||
options_prompt = f'{self.sys_prompt}\n'
|
||||
for key, item in options.items():
|
||||
options_prompt += f'{key}. {item}\n'
|
||||
|
||||
hint = self.load_from_df(idx, 'hint')
|
||||
data = {
|
||||
'img': image,
|
||||
'question': question,
|
||||
'options': options_prompt,
|
||||
'category': catetory,
|
||||
'l2-category': l2_catetory,
|
||||
'options_dict': options,
|
||||
'index': index,
|
||||
'context': hint,
|
||||
}
|
||||
data = self.pipeline(data)
|
||||
return data
|
||||
|
||||
def load_from_df(self, idx: int, key: str) -> Optional[str]:
|
||||
if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
|
||||
return self.df.iloc[idx][key]
|
||||
else:
|
||||
return None
|
5
opencompass/multimodal/models/__init__.py
Normal file
5
opencompass/multimodal/models/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from opencompass.utils import satisfy_requirement
|
||||
|
||||
if satisfy_requirement('salesforce-lavis'):
|
||||
from .instructblip import * # noqa: F401, F403
|
||||
from .minigpt_4 import * # noqa: F401, F403
|
3
opencompass/multimodal/models/instructblip/__init__.py
Normal file
3
opencompass/multimodal/models/instructblip/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
|
||||
|
||||
__all__ = ['Blip2VicunaInstructMMBench']
|
@ -0,0 +1,260 @@
|
||||
"""Requires Transformer 4.28 and above, implementation may change according the
|
||||
Llama implementation."""
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
|
||||
from mmengine.device import get_device
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from opencompass.registry import MM_MODELS
|
||||
|
||||
|
||||
@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
|
||||
class Blip2VicunaInstructMMBench(Blip2Base):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model='eva_clip_g',
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision='fp16',
|
||||
freeze_vit=True,
|
||||
num_query_token=32,
|
||||
llm_model='',
|
||||
sys_prompt='',
|
||||
prompt='',
|
||||
max_txt_len=128,
|
||||
max_output_txt_len=256,
|
||||
qformer_text_input=True,
|
||||
low_resource=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = self.init_tokenizer(truncation_side='left')
|
||||
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint,
|
||||
vit_precision)
|
||||
if freeze_vit:
|
||||
for name, param in self.visual_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.visual_encoder = self.visual_encoder.eval()
|
||||
self.visual_encoder.train = disabled_train
|
||||
logging.info('freeze vision encoder')
|
||||
|
||||
self.Qformer, self.query_tokens = self.init_Qformer(
|
||||
num_query_token, self.visual_encoder.num_features)
|
||||
|
||||
if not qformer_text_input:
|
||||
self.Qformer.bert.embeddings.word_embeddings = None
|
||||
self.Qformer.bert.embeddings.position_embeddings = None
|
||||
for layer in self.Qformer.bert.encoder.layer:
|
||||
layer.output = None
|
||||
layer.intermediate = None
|
||||
else:
|
||||
self.Qformer.resize_token_embeddings(len(self.tokenizer))
|
||||
self.Qformer.cls = None
|
||||
|
||||
self.llm_tokenizer = LlamaTokenizer.from_pretrained(
|
||||
llm_model, use_fast=False, truncation_side='left')
|
||||
|
||||
if low_resource:
|
||||
self.llm_model = LlamaForCausalLM.from_pretrained(
|
||||
llm_model,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_8bit=True,
|
||||
device_map={'': 0})
|
||||
else:
|
||||
self.llm_model = LlamaForCausalLM.from_pretrained(
|
||||
llm_model, torch_dtype=torch.float16)
|
||||
self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
|
||||
self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
|
||||
self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
|
||||
|
||||
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
|
||||
|
||||
for name, param in self.llm_model.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.llm_proj = nn.Linear(self.Qformer.config.hidden_size,
|
||||
self.llm_model.config.hidden_size)
|
||||
|
||||
self.max_txt_len = max_txt_len
|
||||
self.max_output_txt_len = max_output_txt_len
|
||||
self.sys_prompt = sys_prompt
|
||||
self.prompt = prompt
|
||||
|
||||
self._lemmatizer = None
|
||||
|
||||
self.qformer_text_input = qformer_text_input
|
||||
|
||||
def concat_text_input_output(self, input_ids, input_atts, output_ids,
|
||||
output_atts):
|
||||
input_part_targets_len = []
|
||||
llm_tokens = {'input_ids': [], 'attention_mask': []}
|
||||
for i in range(input_ids.size(0)):
|
||||
this_input_ones = input_atts[i].sum()
|
||||
input_part_targets_len.append(this_input_ones)
|
||||
llm_tokens['input_ids'].append(
|
||||
torch.cat([
|
||||
input_ids[i][:this_input_ones], output_ids[i][1:],
|
||||
input_ids[i][this_input_ones:]
|
||||
]))
|
||||
llm_tokens['attention_mask'].append(
|
||||
torch.cat([
|
||||
input_atts[i][:this_input_ones], output_atts[i][1:],
|
||||
input_atts[i][this_input_ones:]
|
||||
]))
|
||||
llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
|
||||
llm_tokens['attention_mask'] = torch.stack(
|
||||
llm_tokens['attention_mask'])
|
||||
return llm_tokens, input_part_targets_len
|
||||
|
||||
def pack_inputs(self, batch):
|
||||
images = [image.unsqueeze(0) for image in batch['inputs']]
|
||||
data_samples = [data_sample for data_sample in batch['data_samples']]
|
||||
images = torch.cat(images, dim=0).to(get_device())
|
||||
inputs = {'image': images, 'data_samples': data_samples}
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
batch,
|
||||
use_nucleus_sampling=False,
|
||||
num_beams=5,
|
||||
max_length=256,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1,
|
||||
num_captions=1,
|
||||
temperature=1,
|
||||
):
|
||||
inputs = self.pack_inputs(batch)
|
||||
image = inputs.pop('image')
|
||||
data_samples = inputs['data_samples']
|
||||
samples = {'image': image}
|
||||
questions = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
if data_samples[0].get('context') is not None:
|
||||
contexts = [
|
||||
data_sample.get('context') for data_sample in data_samples
|
||||
]
|
||||
prompt = [
|
||||
context + ' ' + question + ' ' + option for context, question,
|
||||
option in zip(contexts, questions, options)
|
||||
]
|
||||
else:
|
||||
prompt = [
|
||||
question + ' ' + option
|
||||
for question, option in zip(questions, options)
|
||||
]
|
||||
|
||||
self.llm_tokenizer.padding_side = 'left'
|
||||
|
||||
image = samples['image']
|
||||
|
||||
bs = image.size(0)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt] * bs
|
||||
else:
|
||||
assert len(
|
||||
prompt
|
||||
) == bs, 'The number of prompts must be equal to the batch size.'
|
||||
|
||||
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
||||
if self.qformer_text_input:
|
||||
text_Qformer = self.tokenizer(
|
||||
prompt,
|
||||
padding='longest',
|
||||
truncation=True,
|
||||
max_length=self.max_txt_len,
|
||||
return_tensors='pt',
|
||||
).to(image.device)
|
||||
query_atts = torch.ones(query_tokens.size()[:-1],
|
||||
dtype=torch.long).to(image.device)
|
||||
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],
|
||||
dim=1)
|
||||
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image))
|
||||
image_atts = torch.ones(image_embeds.size()[:-1],
|
||||
dtype=torch.long).to(image.device)
|
||||
|
||||
if self.qformer_text_input:
|
||||
query_output = self.Qformer.bert(
|
||||
text_Qformer.input_ids,
|
||||
attention_mask=Qformer_atts,
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
query_output = self.Qformer.bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
inputs_llm = self.llm_proj(
|
||||
query_output.last_hidden_state[:, :query_tokens.size(1), :])
|
||||
atts_llm = torch.ones(inputs_llm.size()[:-1],
|
||||
dtype=torch.long).to(image.device)
|
||||
|
||||
prompt = ['###Human: ' + p + '###Assistant:' for p in prompt]
|
||||
prompt = [self.sys_prompt + p for p in prompt]
|
||||
llm_tokens = self.llm_tokenizer(prompt,
|
||||
padding='longest',
|
||||
return_tensors='pt').to(image.device)
|
||||
|
||||
with self.maybe_autocast():
|
||||
inputs_embeds = self.llm_model.get_input_embeddings()(
|
||||
llm_tokens.input_ids)
|
||||
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
||||
attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask],
|
||||
dim=1)
|
||||
|
||||
outputs = self.llm_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=use_nucleus_sampling,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
num_beams=num_beams,
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
num_return_sequences=num_captions,
|
||||
)
|
||||
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
|
||||
output_text = self.llm_tokenizer.batch_decode(outputs,
|
||||
skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
output_text = self.post_process(output_text[0])
|
||||
data_sample = data_samples[0]
|
||||
data_sample.pred_answer = output_text
|
||||
|
||||
return data_sample
|
||||
|
||||
def post_process(self, output_text):
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
3
opencompass/multimodal/models/minigpt_4/__init__.py
Normal file
3
opencompass/multimodal/models/minigpt_4/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .minigpt_4 import MiniGPT4MMBench
|
||||
|
||||
__all__ = ['MiniGPT4MMBench']
|
181
opencompass/multimodal/models/minigpt_4/minigpt_4.py
Normal file
181
opencompass/multimodal/models/minigpt_4/minigpt_4.py
Normal file
@ -0,0 +1,181 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.device import get_device
|
||||
from transformers import StoppingCriteriaList
|
||||
|
||||
from opencompass.registry import MM_MODELS
|
||||
|
||||
from .utils import StoppingCriteriaSub
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
def load_package():
|
||||
"""Load required packages from MiniGPT-4."""
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_folder_path = os.path.dirname(current_file_path)
|
||||
|
||||
sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4')) # noqa
|
||||
from minigpt4.models.mini_gpt4 import MiniGPT4
|
||||
|
||||
sys.path.pop(-1)
|
||||
|
||||
return MiniGPT4
|
||||
|
||||
|
||||
MiniGPT4 = load_package()
|
||||
|
||||
|
||||
@MM_MODELS.register_module('minigpt-4-mmbench')
|
||||
class MiniGPT4MMBench(MiniGPT4):
|
||||
"""Inference code of MiniGPT-4 on MMBench.
|
||||
|
||||
Args:
|
||||
llama_model (str): The path of vicuna path.
|
||||
sys_prompt (str): The prompt added to the beginning
|
||||
of each query. Defaults to ''.
|
||||
low_resource (bool): Whether loaded in low precision.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llama_model: str,
|
||||
sys_prompt: str = '',
|
||||
low_resource: bool = False) -> None:
|
||||
super().__init__(llama_model=llama_model, low_resource=low_resource)
|
||||
|
||||
cur_device = get_device()
|
||||
stop_words_ids = [
|
||||
torch.tensor([835]).to(cur_device),
|
||||
torch.tensor([2277, 29937]).to(cur_device),
|
||||
]
|
||||
self.stopping_criteria = StoppingCriteriaList(
|
||||
[StoppingCriteriaSub(stops=stop_words_ids)])
|
||||
self.sys_prompt = sys_prompt
|
||||
|
||||
def encode_img(self, image):
|
||||
device = image.device
|
||||
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(
|
||||
self.visual_encoder(image)).to(device)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1],
|
||||
dtype=torch.long).to(device)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
|
||||
-1)
|
||||
query_output = self.Qformer.bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
inputs_llama = self.llama_proj(query_output.last_hidden_state)
|
||||
atts_llama = torch.ones(inputs_llama.size()[:-1],
|
||||
dtype=torch.long).to(image.device)
|
||||
return inputs_llama, atts_llama
|
||||
|
||||
def pack_inputs(self, batch):
|
||||
images = [image.unsqueeze(0) for image in batch['inputs']]
|
||||
data_samples = [data_sample for data_sample in batch['data_samples']]
|
||||
images = torch.cat(images, dim=0).to(get_device())
|
||||
inputs = {'image': images, 'data_samples': data_samples}
|
||||
return inputs
|
||||
|
||||
def generate(self, batch):
|
||||
inputs = self.pack_inputs(batch)
|
||||
image = inputs.pop('image')
|
||||
data_samples = inputs['data_samples']
|
||||
samples = {'image': image}
|
||||
question = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
samples.update({'question': question[0]})
|
||||
samples.update({'options': options[0]})
|
||||
if data_samples[0].get('context') is not None:
|
||||
context = [
|
||||
data_sample.get('context') for data_sample in data_samples
|
||||
]
|
||||
samples.update({'context': context})
|
||||
data_sample = data_samples[0]
|
||||
img_prompt = '###Human: <Img><ImageHere></Img> '
|
||||
if 'context' in samples:
|
||||
context_prompt = samples['context'][0]
|
||||
|
||||
question = samples['question']
|
||||
options = samples['options']
|
||||
if 'context' in samples:
|
||||
prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa
|
||||
else:
|
||||
prompt = img_prompt + ' ' + question + ' ' + options
|
||||
|
||||
# prompt = self.sys_prompt + prompt
|
||||
prompt = prompt + '###Assistant:'
|
||||
|
||||
image = samples['image']
|
||||
img_embeds, _ = self.encode_img(image)
|
||||
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
prompt_seg_tokens = [
|
||||
self.llama_tokenizer(seg,
|
||||
return_tensors='pt',
|
||||
add_special_tokens=i == 0).
|
||||
to(self.llama_model.model.embed_tokens.weight.device).input_ids
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
prompt_seg_embs = [
|
||||
self.llama_model.model.embed_tokens(seg)
|
||||
for seg in prompt_seg_tokens
|
||||
]
|
||||
prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
|
||||
prompt_embs = torch.cat(prompt_seg_embs, dim=1)
|
||||
|
||||
# generate output
|
||||
outputs = self.llama_model.generate(
|
||||
inputs_embeds=prompt_embs,
|
||||
max_new_tokens=20,
|
||||
num_beams=5,
|
||||
do_sample=False,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=-1.0,
|
||||
temperature=1.0,
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
num_return_sequences=1)
|
||||
|
||||
output_token = outputs[0]
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
if output_token[0] == 1:
|
||||
output_token = output_token[1:]
|
||||
output_text = self.llama_tokenizer.decode(output_token,
|
||||
add_special_tokens=False)
|
||||
output_text = self.post_process(output_text)
|
||||
data_sample.pred_answer = output_text
|
||||
return data_sample
|
||||
|
||||
def post_process(self, output_text):
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
56
opencompass/multimodal/models/minigpt_4/utils.py
Normal file
56
opencompass/multimodal/models/minigpt_4/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import timm.models.hub as timm_hub
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmengine.dist import is_distributed, is_main_process
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
|
||||
def __init__(self, stops=[], encounters=1):
|
||||
super().__init__()
|
||||
self.stops = stops
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||
for stop in self.stops:
|
||||
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def download_cached_file(url, check_hash=True, progress=False):
|
||||
"""Download a file from a URL and cache it locally.
|
||||
|
||||
If the file already exists, it is not downloaded again. If distributed,
|
||||
only the main process downloads the file, and the other processes wait for
|
||||
the file to be downloaded.
|
||||
"""
|
||||
|
||||
def get_cached_file_path():
|
||||
# a hack to sync the file path across processes
|
||||
parts = torch.hub.urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
||||
|
||||
return cached_file
|
||||
|
||||
if is_main_process():
|
||||
timm_hub.download_cached_file(url, check_hash, progress)
|
||||
|
||||
if is_distributed():
|
||||
dist.barrier()
|
||||
|
||||
return get_cached_file_path()
|
||||
|
||||
|
||||
def is_url(input_url):
|
||||
"""Check if an input string is a url.
|
||||
|
||||
look for http(s):// and ignoring the case
|
||||
"""
|
||||
is_url = re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None
|
||||
return is_url
|
@ -120,7 +120,7 @@ class AccEvaluator(HuggingfaceEvaluator):
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class RougeEvaluator(HuggingfaceEvaluator):
|
||||
"""Rouge evaluator."""
|
||||
"""Rouge evaluator.""" # noqa
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(metric='rouge')
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .mm_naive import * # noqa: F401, F403
|
||||
from .naive import * # noqa: F401, F403
|
||||
from .size import * # noqa: F401, F403
|
||||
|
119
opencompass/partitioners/mm_naive.py
Normal file
119
opencompass/partitioners/mm_naive.py
Normal file
@ -0,0 +1,119 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
|
||||
from opencompass.registry import PARTITIONERS
|
||||
|
||||
from .base import BasePartitioner
|
||||
|
||||
|
||||
@PARTITIONERS.register_module()
|
||||
class MultimodalNaivePartitioner(BasePartitioner):
|
||||
"""Multimodal naive task partitioner.
|
||||
|
||||
This partitioner will generate a task for each
|
||||
model-dataset-evaluator pair.
|
||||
|
||||
Args:
|
||||
config (ConfigDict): The full config dict.
|
||||
"""
|
||||
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
evaluators: List[ConfigDict], load_froms: List[ConfigDict],
|
||||
work_dir: str, num_gpus: int, num_procs: int,
|
||||
launcher: str) -> List[Dict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'models': [], # a list of model configs
|
||||
'datasets': [], # a list of dataset configs
|
||||
'evaluators': [], # a list of evaluator configs
|
||||
'load_froms': [], # a list of load_from paths
|
||||
'work_dir': '', # the work dir
|
||||
'num_gpus': int, # integer, number of gpus for each task
|
||||
'num_procs': int, # integer, number of gpus on single machine
|
||||
'launcher': str, # string, how to launch distributed training
|
||||
}
|
||||
|
||||
Args:
|
||||
models (List[ConfigDict]): A list of model configs.
|
||||
datasets (List[ConfigDict]): A list of dataset configs.
|
||||
evaluators (List[ConfigDict]): A list of evaluator configs.
|
||||
load_froms (List[ConfigDict]): A list of load_from paths.
|
||||
work_dir (str): The work dir for the task.
|
||||
num_gpus (int): Number of gpus for each task.
|
||||
num_procs (int): Number of gpus on single machine.
|
||||
launcher (str): How to launch distributed training.
|
||||
Only `slurm`, `pytorch` and `mpi` are available.
|
||||
|
||||
Returns:
|
||||
List[Dict]: A list of tasks.
|
||||
"""
|
||||
|
||||
tasks = []
|
||||
for model, dataset, evaluator, load_from in zip(
|
||||
models, datasets, evaluators, load_froms):
|
||||
task = Config({
|
||||
'model': model,
|
||||
'dataset': dataset,
|
||||
'evaluator': evaluator,
|
||||
'load_from': load_from,
|
||||
'work_dir': work_dir,
|
||||
'num_gpus': num_gpus,
|
||||
'num_procs': num_procs,
|
||||
'launcher': launcher
|
||||
})
|
||||
tasks.append(task)
|
||||
|
||||
return tasks
|
||||
|
||||
def __call__(self, cfg: ConfigDict) -> List[Dict]:
|
||||
"""Generate tasks from config. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'models': [], # a list of model configs
|
||||
'datasets': [], # a list of dataset configs
|
||||
'evaluators': [], # a list of evaluator configs
|
||||
'load_froms': [], # a list of load_from paths
|
||||
'work_dir': '', # the work dir
|
||||
'num_gpus': int, # integer, number of gpus for each task
|
||||
'num_procs': int, # integer, number of gpus on single machine
|
||||
}
|
||||
|
||||
Args:
|
||||
cfg (ConfigDict): The config dict, containing "models", "dataset"
|
||||
and "work_dir" keys.
|
||||
|
||||
Returns:
|
||||
List[Dict]: A list of tasks.
|
||||
"""
|
||||
cfg = deepcopy(cfg)
|
||||
models = cfg['models']
|
||||
datasets = cfg['datasets']
|
||||
evaluators = cfg['evaluators']
|
||||
load_froms = cfg['load_froms']
|
||||
work_dir = cfg['work_dir']
|
||||
num_gpus = cfg['num_gpus']
|
||||
num_procs = cfg['num_procs']
|
||||
launcher = cfg['launcher']
|
||||
|
||||
tasks = self.partition(models, datasets, evaluators, load_froms,
|
||||
work_dir, num_gpus, num_procs, launcher)
|
||||
|
||||
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
|
||||
for i, task in enumerate(tasks):
|
||||
model_name = task['model']['type']
|
||||
dataset_name = task['dataset']['dataset']['type']
|
||||
evaluator_name = task['evaluator'][0]['type']
|
||||
self.logger.debug(
|
||||
f'Task {i}: {model_name}-{dataset_name}-{evaluator_name}')
|
||||
|
||||
return tasks
|
@ -1,3 +1,6 @@
|
||||
from mmengine.registry import DATASETS as MMENGINE_DATASETS
|
||||
from mmengine.registry import METRICS as MMENGINE_METRICS
|
||||
from mmengine.registry import MODELS as MMENGINE_MODELS
|
||||
from mmengine.registry import Registry
|
||||
|
||||
PARTITIONERS = Registry('partitioner', locations=['opencompass.partitioners'])
|
||||
@ -22,3 +25,12 @@ ICL_PROMPT_TEMPLATES = Registry(
|
||||
locations=['opencompass.openicl.icl_prompt_template'])
|
||||
ICL_EVALUATORS = Registry('icl_evaluators',
|
||||
locations=['opencompass.openicl.icl_evaluator'])
|
||||
DATASETS = Registry('mm_datasets',
|
||||
parent=MMENGINE_DATASETS,
|
||||
locations=['opencompass.multimodal.datasets'])
|
||||
METRICS = Registry('metric',
|
||||
parent=MMENGINE_METRICS,
|
||||
locations=['opencompass.metrics'])
|
||||
MM_MODELS = Registry('mm_model',
|
||||
parent=MMENGINE_MODELS,
|
||||
locations=['opencompass.multimodal.models'])
|
||||
|
@ -81,7 +81,6 @@ class SlurmRunner(BaseRunner):
|
||||
Returns:
|
||||
tuple[str, int]: Task name and exit code.
|
||||
"""
|
||||
|
||||
task_type = self.task_cfg.type
|
||||
if isinstance(self.task_cfg.type, str):
|
||||
task_type = TASKS.get(task_type)
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .mm_infer import * # noqa: F401, F403
|
||||
from .openicl_eval import * # noqa: F401, F403
|
||||
from .openicl_infer import * # noqa: F401, F403
|
||||
|
126
opencompass/tasks/mm_infer.py
Normal file
126
opencompass/tasks/mm_infer.py
Normal file
@ -0,0 +1,126 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.device import get_device
|
||||
from mmengine.dist import init_dist
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model.wrappers import MMDistributedDataParallel
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.utils import track_iter_progress
|
||||
|
||||
from opencompass.registry import MM_MODELS, TASKS
|
||||
from opencompass.utils import get_logger
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
model = MM_MODELS.build(cfg['model'])
|
||||
load_from = cfg.get('load_from', None)
|
||||
if load_from is not None:
|
||||
state_dict = torch.load(cfg['load_from'], map_location='cpu')
|
||||
if 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
elif 'state_dict' in state_dict:
|
||||
state_dict = state_dict['state_dict']
|
||||
msg = model.load_state_dict(state_dict, strict=False)
|
||||
print_log(msg)
|
||||
model.to(get_device())
|
||||
if dist.is_initialized():
|
||||
model = MMDistributedDataParallel(
|
||||
model,
|
||||
device_ids=[int(os.environ['LOCAL_RANK'])],
|
||||
broadcast_buffers=False)
|
||||
return model
|
||||
|
||||
|
||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||
class MultimodalInferTask:
|
||||
"""Multimodal Inference Task.
|
||||
|
||||
This task is used to run the inference process.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: ConfigDict):
|
||||
self.num_gpus = cfg.get('num_gpus', 0)
|
||||
self.num_procs = cfg.get('num_procs', 1)
|
||||
self.dataloader = cfg.get('dataset')
|
||||
self.model = cfg.get('model')
|
||||
self.evaluator = cfg.get('evaluator')
|
||||
self.cfg = cfg
|
||||
self.logger = get_logger()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
model_name = self.model['type']
|
||||
dataset_name = self.dataloader['dataset']['type']
|
||||
evaluator_name = self.evaluator[0]['type']
|
||||
return f'{model_name}-{dataset_name}-{evaluator_name}'
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
"""Get the command template for the task.
|
||||
|
||||
Args:
|
||||
cfg_path (str): The path to the config file of the task.
|
||||
template (str): The template which have '{task_cmd}' to format
|
||||
the command.
|
||||
"""
|
||||
script_path = __file__
|
||||
if self.num_gpus > 0:
|
||||
port = random.randint(12000, 32000)
|
||||
command = (f'torchrun --master_port={port} '
|
||||
f'--nproc_per_node {self.num_procs} '
|
||||
f'{script_path} {cfg_path}')
|
||||
else:
|
||||
command = f'python {script_path} {cfg_path}'
|
||||
|
||||
return template.format(task_cmd=command)
|
||||
|
||||
def run(self):
|
||||
# only support slurm, pytorch, mpi
|
||||
init_dist(self.cfg.launcher)
|
||||
self.logger.info(f'Task {self.name}')
|
||||
# build dataloader
|
||||
dataloader = Runner.build_dataloader(self.dataloader)
|
||||
# build model
|
||||
model = build_model(self.cfg)
|
||||
# build evaluator
|
||||
evaluator = Evaluator(self.evaluator)
|
||||
|
||||
for batch in track_iter_progress(dataloader):
|
||||
if dist.is_initialized():
|
||||
data_samples = model.module.generate(batch)
|
||||
else:
|
||||
data_samples = model.generate(batch)
|
||||
if not isinstance(data_samples, Sequence):
|
||||
data_samples = [data_samples]
|
||||
evaluator.process(data_samples)
|
||||
|
||||
metrics = evaluator.evaluate(len(dataloader.dataset))
|
||||
metrics_file = osp.join(cfg.work_dir, 'res.log')
|
||||
with open(metrics_file, 'w') as f:
|
||||
json.dump(metrics, f)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Model Inferencer')
|
||||
parser.add_argument('config', help='Config file path')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
cfg = Config.fromfile(args.config)
|
||||
start_time = time.time()
|
||||
inferencer = MultimodalInferTask(cfg)
|
||||
inferencer.run()
|
||||
end_time = time.time()
|
||||
get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')
|
@ -1,6 +1,7 @@
|
||||
from .abbr import * # noqa
|
||||
from .build import * # noqa
|
||||
from .collect_env import * # noqa
|
||||
from .dependency import * # noqa
|
||||
from .fileio import * # noqa
|
||||
from .git import * # noqa
|
||||
from .lark import * # noqa
|
||||
|
32
opencompass/utils/dependency.py
Normal file
32
opencompass/utils/dependency.py
Normal file
@ -0,0 +1,32 @@
|
||||
import re
|
||||
|
||||
from importlib_metadata import PackageNotFoundError, distribution
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
|
||||
def satisfy_requirement(dep):
|
||||
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
|
||||
parts = re.split(pat, dep, maxsplit=1)
|
||||
parts = [p.strip() for p in parts]
|
||||
package = parts[0]
|
||||
if len(parts) > 1:
|
||||
op, version = parts[1:]
|
||||
op = {
|
||||
'>=': '__ge__',
|
||||
'==': '__eq__',
|
||||
'>': '__gt__',
|
||||
'<': '__lt__',
|
||||
'<=': '__le__'
|
||||
}[op]
|
||||
else:
|
||||
op, version = None, None
|
||||
|
||||
try:
|
||||
dist = distribution(package)
|
||||
if op is None or getattr(digit_version(dist.version), op)(
|
||||
digit_version(version)):
|
||||
return True
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return False
|
37
run.py
37
run.py
@ -6,7 +6,8 @@ from datetime import datetime
|
||||
|
||||
from mmengine.config import Config
|
||||
|
||||
from opencompass.partitioners import NaivePartitioner, SizePartitioner
|
||||
from opencompass.partitioners import (MultimodalNaivePartitioner,
|
||||
NaivePartitioner, SizePartitioner)
|
||||
from opencompass.registry import PARTITIONERS, RUNNERS
|
||||
from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner
|
||||
from opencompass.utils import LarkReporter, Summarizer, get_logger
|
||||
@ -37,6 +38,10 @@ def parse_args():
|
||||
'redirected to files',
|
||||
action='store_true',
|
||||
default=False)
|
||||
parser.add_argument('--mm-eval',
|
||||
help='Whether or not enable multimodal evaluation',
|
||||
action='store_true',
|
||||
default=False)
|
||||
parser.add_argument('--dry-run',
|
||||
help='Dry run mode, in which the scheduler will not '
|
||||
'actually run the tasks, but only print the commands '
|
||||
@ -201,7 +206,14 @@ def main():
|
||||
'also specified --slurm or --dlc. '
|
||||
'The "infer" configuration will be overridden by '
|
||||
'your runtime arguments.')
|
||||
if args.dlc or args.slurm or cfg.get('infer', None) is None:
|
||||
# Check whether run multimodal evaluation
|
||||
if args.mm_eval:
|
||||
partitioner = MultimodalNaivePartitioner(
|
||||
osp.join(cfg['work_dir'], 'predictions/'))
|
||||
tasks = partitioner(cfg)
|
||||
exec_mm_infer_runner(tasks, args, cfg)
|
||||
return
|
||||
elif args.dlc or args.slurm or cfg.get('infer', None) is None:
|
||||
# Use SizePartitioner to split into subtasks
|
||||
partitioner = SizePartitioner(
|
||||
osp.join(cfg['work_dir'], 'predictions/'),
|
||||
@ -283,6 +295,27 @@ def main():
|
||||
summarizer.summarize(time_str=cfg_time_str)
|
||||
|
||||
|
||||
def exec_mm_infer_runner(tasks, args, cfg):
|
||||
"""execute multimodal infer runner according to args."""
|
||||
if args.slurm:
|
||||
runner = SlurmRunner(dict(type='MultimodalInferTask'),
|
||||
max_num_workers=args.max_num_workers,
|
||||
partition=args.partition,
|
||||
quotatype=args.quotatype,
|
||||
retry=args.retry,
|
||||
debug=args.debug,
|
||||
lark_bot_url=cfg['lark_bot_url'])
|
||||
elif args.dlc:
|
||||
raise NotImplementedError('Currently, we do not support evaluating \
|
||||
multimodal models on dlc.')
|
||||
else:
|
||||
runner = LocalRunner(task=dict(type='MultimodalInferTask'),
|
||||
max_num_workers=args.max_num_workers,
|
||||
debug=args.debug,
|
||||
lark_bot_url=cfg['lark_bot_url'])
|
||||
runner(tasks)
|
||||
|
||||
|
||||
def exec_infer_runner(tasks, args, cfg):
|
||||
"""execute infer runner according to args."""
|
||||
if args.slurm:
|
||||
|
Loading…
Reference in New Issue
Block a user