OpenCompass/opencompass/multimodal/models/llava/prompt_constructor.py
Yike Yuan 8d368d1cd6
[Feat] Support visualglm and llava for MMBench evaluation. (#211)
* [Feat] Support visualglm inference on MMBench.

* [Feat] Support llava inference on MMBench.

* [Fix] Fix pre-commit format.

* [Fix] Add docstring for llava

* [Fix] Fix multi-process inference error of LLaVA and add comments.
1. Set `low_cpu_mem_usage` to False to address device issue.
2. Add docstring and type hints.
3. Rename class and remove registry.

* [Fix] Pre-commit fix.

* [Fix] add forward entry, add dynamic import to seedbench

* [Fix] Fix pre-commit.

* [Fix] Fix missing context.

* [Fix] Fix docstring.
2023-08-21 15:57:30 +08:00

60 lines
2.1 KiB
Python

import importlib
from typing import Any
# Special tokens used when assembling the text prompt for an image.
# Placeholder put in front of the question text to stand for the image.
DEFAULT_IMAGE_TOKEN: str = '<image>'
# Patch-level token; not referenced by the code visible in this file.
DEFAULT_IMAGE_PATCH_TOKEN: str = '<im_patch>'
# Markers placed immediately before / after DEFAULT_IMAGE_TOKEN when the
# prompt constructor is configured with ``mm_use_im_start_end=True``.
DEFAULT_IM_START_TOKEN: str = '<im_start>'
DEFAULT_IM_END_TOKEN: str = '<im_end>'
class LLaVAMMBenchPromptConstructor:
    """Prompt constructor for LLaVA on MMBench.

    Args:
        conv_templates (Any): Container of conversation templates, indexed
            by ``conv_mode``. The selected template is copied on every
            call, so the stored template itself is never appended to.
        conv_mode (str): Key used to pick the template from
            ``conv_templates`` (version control arg for different
            versions of LLaVA).
        mm_use_im_start_end (bool): Config arg. Whether to wrap the image
            token in the image start and end tokens when building the
            prompt.
    """

    def __init__(self, conv_templates: Any, conv_mode: str,
                 mm_use_im_start_end: bool) -> None:
        self.conv_templates = conv_templates
        self.conv_mode = conv_mode
        self.mm_use_im_start_end = mm_use_im_start_end
        # Imported here instead of at module level, so the ``llava`` package
        # is only required once an instance is actually created.
        conversation = importlib.import_module('llava.conversation')
        self.SeparatorStyle = conversation.SeparatorStyle

    def __call__(self, inputs: dict) -> tuple:
        """Construct prompt.

        The text part is ``context``, ``question`` and ``options`` of the
        single data sample, joined by one space in that order. Fields for
        which ``get`` returns ``None`` are left out.

        Args:
            inputs (dict): Input data. Only ``inputs['data_samples']`` is
                read; it must contain exactly one sample that offers a
                ``get(key)`` method.

        Returns:
            tuple: ``(prompt, stop_str)`` where ``prompt`` is the full
            conversation prompt produced by the template and ``stop_str``
            is the template separator to stop generation on.
        """
        data_samples = inputs['data_samples']
        assert len(data_samples) == 1, (
            f'expected exactly one data sample, got {len(data_samples)}')
        sample = data_samples[0]

        # Skip absent fields so a sample without e.g. ``options`` does not
        # fail with a TypeError on ``str + None``.
        fields = (sample.get('context'), sample.get('question'),
                  sample.get('options'))
        text = ' '.join(field for field in fields if field is not None)

        if self.mm_use_im_start_end:
            image_tokens = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
                            DEFAULT_IM_END_TOKEN)
        else:
            image_tokens = DEFAULT_IMAGE_TOKEN
        prompt = image_tokens + '\n' + text

        # Work on a copy so repeated calls start from a clean template.
        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        # ``None`` as the second role's message leaves its turn open.
        conv.append_message(conv.roles[1], None)
        output_prompt = conv.get_prompt()

        # Templates using the TWO separator style stop on ``sep2``; every
        # other style stops on ``sep``.
        if conv.sep_style == self.SeparatorStyle.TWO:
            stop_str = conv.sep2
        else:
            stop_str = conv.sep
        return output_prompt, stop_str