OpenCompass/opencompass/multimodal/models/openflamingo/prompt_constructor.py
Yike Yuan bd50bad8b5
[Feat] Support mm models on public dataset and fix several issues. (#412)
* [Feat] Add public dataset support for visualglm, qwenvl, and flamingo

* [Fix] MMBench related changes.

* [Fix] Openflamingo inference.

* [Fix] Hide ckpt path.

* [Fix] Pre-commit.

---------

Co-authored-by: Haodong Duan <dhd.efz@gmail.com>
2023-09-19 19:08:44 +08:00

131 lines
4.5 KiB
Python

from typing import List, Optional

from mmpretrain.structures import DataSample
class OpenFlamingoMMBenchPromptConstructor:
    """MMBench prompt constructor for OpenFlamingo.

    Produces a single zero-shot multiple-choice prompt of the form
    ``[<context> ]<image><question> <options> Answer:``.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, data_samples: List[DataSample]) -> List[str]:
        """Construct prompt.

        Args:
            data_samples (List[DataSample]): Input data samples. Exactly one
                sample is supported. It is read with ``.get`` for the keys
                ``'question'`` (str), ``'options'`` (str, concatenated
                verbatim) and the optional ``'context'`` (str or None).

        Returns:
            List[str]: A one-element list holding the raw text prompt.
        """
        # Only a batch containing a single sample is supported.
        assert len(data_samples) == 1
        sample = data_samples[0]
        question = sample.get('question')
        option = sample.get('options')

        # Plain concatenation (rather than str.format) so that a missing
        # question/options raises instead of silently rendering 'None'.
        prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'

        # The optional context is placed in front of the image token.
        context = sample.get('context')
        if context is not None:
            prompt = context + ' ' + prompt

        return [prompt]
class OpenFlamingoCaptionPromptConstructor:
    """Caption prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): In-context examples prepended to the
            prompt. ``None`` (default) uses two built-in caption examples;
            an explicit ``''`` yields a zero-shot prompt.
    """

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string disables the default few-shot examples.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Output:A child holding a flowered umbrella and petting a yak.<|endofchunk|>'  # noqa
                'Output:The child is holding a brush close to his mouth.<|endofchunk|>'  # noqa
            )  # noqa

    def __call__(self, data_samples: List[DataSample]) -> List[str]:
        """Construct prompt.

        Args:
            data_samples (List[DataSample]): Input data samples. Exactly one
                sample is supported; its content is not read, because the
                caption prompt is identical for every image.

        Returns:
            List[str]: A one-element list holding the raw text prompt.
        """
        # Only a batch containing a single sample is supported.
        assert len(data_samples) == 1
        return [self.shot_prompt + '<image>Output:']
class OpenFlamingoVQAPromptConstructor:
    """VQA prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): In-context examples prepended to every
            prompt. ``None`` (default) uses two built-in VQA examples; an
            explicit ``''`` yields zero-shot prompts.
    """

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string disables the default few-shot examples.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Question:Is the sky dark? Short Answer:yes<|endofchunk|>'  # noqa: E501
                'Question:What is on the white wall? Short Answer:pipe<|endofchunk|>'  # noqa: E501
            )  # noqa

    def __call__(self, data_samples: List[DataSample]) -> List[str]:
        """Construct prompt.

        Args:
            data_samples (List[DataSample]): Input data samples; any number
                is accepted. Each sample is read with ``.get('question')``.

        Returns:
            List[str]: One raw text prompt per input sample, in order.
        """
        # One prompt per sample: shared few-shot prefix + this question.
        return [
            self.shot_prompt
            + '<image>Question:{} Short Answer:'.format(sample.get('question'))
            for sample in data_samples
        ]
class OpenFlamingoScienceQAPromptConstructor:
    """ScienceQA prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): In-context examples prepended to the
            prompt. ``None`` (default) uses two built-in ScienceQA examples;
            an explicit ``''`` yields a zero-shot prompt.
    """

    # Letter label for each choice index. Indices not in this table fall back
    # to consecutive upper-case letters ('G', 'H', ...).
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string disables the default few-shot examples.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                "Context:Question:Which of these states is farthest north? Choices:['(A) West Virginia' '(B) Louisiana' '(C) Arizona' '(D) Oklahoma'] Answer with a single character: A<|endofchunk|>"  # noqa
                'Context:The diagrams below show two pure samples of gas in identical closed, rigid containers. Each colored ball represents one gas particle. Both samples have the same number of particles.'  # noqa
                "Question:Compare the average kinetic energies of the particles in each sample. Which sample has the higher temperature? Choices:['(A) neither' '(B) sample A' '(C) sample B'] Answer with a single character: C<|endofchunk|>"  # noqa
            )  # noqa

    def __call__(self, data_samples: List[DataSample]) -> List[str]:
        """Construct prompt.

        Args:
            data_samples (List[DataSample]): Input data samples. Exactly one
                sample is supported. It is read with ``.get`` for the keys
                ``'question'``, ``'choices'`` (iterable of str) and
                ``'hint'`` (str or None).

        Returns:
            List[str]: A one-element list holding the raw text prompt.
        """
        # Only a batch containing a single sample is supported.
        assert len(data_samples) == 1
        sample = data_samples[0]
        question = sample.get('question')

        # Prefix each choice with its letter label, e.g. '(A) foo'.
        choices = [
            '({}) '.format(self.choice_mapping.get(i, chr(ord('A') + i))) +
            item for i, item in enumerate(sample.get('choices'))
        ]

        # Avoid rendering the literal text 'None' when no hint is present.
        hint = sample.get('hint')
        if hint is None:
            hint = ''

        # Choices are rendered via the list's str() repr, as before.
        prompt = '<image>Context:{} Question:{} Choices:{}'.format(
            hint, question, choices)
        prompt += ' Answer with a single character:'
        return [self.shot_prompt + prompt]