diff --git a/configs/multimodal/visualglm/visualglm_6b_coco_caption.py b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
index 66e0801c..c2ffa6a9 100644
--- a/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
+++ b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
- prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+ prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
diff --git a/configs/multimodal/visualglm/visualglm_6b_flickr30k.py b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
index 58ab4649..9860ba78 100644
--- a/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
+++ b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
- prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+ prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
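Note: both config diffs make the same change. As a minimal sketch, assuming the base constructor's default human_prompt 'Q:' and assistant_prompt 'A:' from prompt_constructor.py, the caption prompt these configs now produce looks like this (illustrative only):

    system_prompt = 'Describe the image.'            # value set in both configs above
    human_prompt, assistant_prompt = 'Q:', 'A:'      # constructor defaults
    prompt = '<img></img>' + human_prompt + system_prompt + assistant_prompt
    image_position = prompt.rfind('<img>') + 5       # image tokens are spliced in right after '<img>'
    print(prompt)                                    # <img></img>Q:Describe the image.A:
    print(image_position)                            # 5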
diff --git a/opencompass/multimodal/models/visualglm/post_processor.py b/opencompass/multimodal/models/visualglm/post_processor.py
index 4ff3b4a8..8289fc8a 100644
--- a/opencompass/multimodal/models/visualglm/post_processor.py
+++ b/opencompass/multimodal/models/visualglm/post_processor.py
@@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
def __init__(self) -> None:
pass
- def __call__(self, output_token: torch.tensor, tokenizer: Any,
- input_len: int) -> str:
- return tokenizer.decode(output_token[input_len:])
+ def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+ return tokenizer.decode(output_token)
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
def __init__(self) -> None:
super().__init__()
- def __call__(self, output_token: torch.tensor, tokenizer: Any,
- input_len: int) -> str:
- output_text = tokenizer.decode(output_token[input_len:])
+ def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+ output_text = tokenizer.decode(output_token)
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():
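Note: the post processor no longer receives input_len because the model (see the visualglm.py hunk below) now strips the prompt tokens before calling it. A toy sketch of the new calling convention; _DummyTokenizer is a made-up stand-in for the real ChatGLM tokenizer:

    from typing import Any

    class _DummyTokenizer:
        # Illustrative stand-in: "decodes" token ids by joining them as text.
        def decode(self, tokens: Any) -> str:
            return ' '.join(str(t) for t in tokens)

    full_output = [101, 5, 6, 7, 8, 9]      # pretend generate() returned these ids
    prompt_len = 2                          # corresponds to input_all.shape[1] in the model
    generated = full_output[prompt_len:]    # the model now performs this slice itself
    answer = _DummyTokenizer().decode(generated)
    print(answer)                           # -> "6 7 8 9"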
diff --git a/opencompass/multimodal/models/visualglm/prompt_constructor.py b/opencompass/multimodal/models/visualglm/prompt_constructor.py
index a10b7d77..68bea8e4 100644
--- a/opencompass/multimodal/models/visualglm/prompt_constructor.py
+++ b/opencompass/multimodal/models/visualglm/prompt_constructor.py
@@ -1,24 +1,16 @@
-import torch
-
-
class VisualGLMMMBenchPromptConstructor:
"""MMBench prompt constructor for VisualGLM.
- The overall prompt will be formulated as
- "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
- image_prompt (str): Image prompt. (Default: '<img></img>')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self,
system_prompt: str = '',
human_prompt: str = 'Q:',
- image_prompt: str = '<img></img>',
assistant_prompt: str = 'A:') -> None:
- self.image_prompt = image_prompt
self.system_prompt = system_prompt
self.human_prompt = human_prompt
self.assistant_prompt = assistant_prompt
@@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
A tuple containing images, prompt, data_samples and image_position.
"""
- images = batch.pop('inputs')
- images = torch.stack(images, dim=0)
+ assert len(batch['inputs']) == 1
+ image = batch.pop('inputs')[0].unsqueeze(0)
+ data_sample = batch.pop('data_samples')[0]
+ img_prompt = '<img></img>'
+ if data_sample.get('context') is not None:
+ prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options # noqa
+ else:
+ prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options # noqa
+ prompt += self.assistant_prompt
+ image_position = prompt.rfind('<img>') + 5
- data_samples = batch.pop('data_samples')
- questions = [sample.get('question') for sample in data_samples]
- options = [sample.get('options') for sample in data_samples]
- contexts = [sample.get('context') for sample in data_samples]
- contexts = [c if c else '' for c in contexts]
-
- # generate text prompt
- prompt = [
- '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
- self.human_prompt, context, question,
- option, self.assistant_prompt)
- for context, question, option in zip(contexts, questions, options)
- ]
-
- image_position = 5
-
- return images, prompt, data_samples, image_position
+ return image, prompt, data_sample, image_position
class VisualGLMBasePromptConstructor:
@@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
The prompt will concat <img></img> and the given system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
+ human_prompt (str): Human prompt. (Default: 'Q:')
+ assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
- def __init__(self, system_prompt='') -> None:
+ def __init__(self,
+ system_prompt: str = '',
+ human_prompt: str = 'Q:',
+ assistant_prompt: str = 'A:') -> None:
self.prompt = system_prompt
+ self.human_prompt = human_prompt
+ self.assistant_prompt = assistant_prompt
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
A tuple containing images, prompt, data_samples and image_position.
"""
- images = batch.pop('inputs')
- images = torch.stack(images, dim=0)
- data_samples = batch.pop('data_samples')
+ assert len(batch['inputs']) == 1
+ image = batch.pop('inputs')[0].unsqueeze(0)
+ data_sample = batch.pop('data_samples')[0]
# generate text prompt
- prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
+ prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt # noqa
- image_position = 5
+ image_position = prompt.rfind('<img>') + 5
- return images, prompt, data_samples, image_position
+ return image, prompt, data_sample, image_position
class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat <img></img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
+ human_prompt (str): Human prompt. (Default: 'Q:')
+ assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
- def __init__(self, system_prompt='') -> None:
- super().__init__(system_prompt)
+ def __init__(self,
+ system_prompt='',
+ human_prompt: str = 'Q:',
+ assistant_prompt: str = 'A:') -> None:
+ super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
- images = batch.pop('inputs')
- images = torch.stack(images, dim=0)
- data_samples = batch.pop('data_samples')
- questions = [sample.get('question') for sample in data_samples]
+ assert len(batch['inputs']) == 1
+ image = batch.pop('inputs')[0].unsqueeze(0)
+ data_sample = batch.pop('data_samples')[0]
# generate text prompt
- prompt = [
- '<img></img>Q:{} {}\nA:'.format(question, self.prompt)
- for question in questions
- ]
- image_position = 5
+ question = data_sample.get('question')
+ prompt = '<img></img>' + self.human_prompt + question + self.prompt
+ prompt += '\n' + self.assistant_prompt
- return images, prompt, data_samples, image_position
+ image_position = prompt.rfind('<img>') + 5
+
+ return image, prompt, data_sample, image_position
class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat image and all terms in a question.
Args:
system_prompt (str): System prompt. (Default: '')
+ human_prompt (str): Human prompt. (Default: 'Q:')
+ assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
- def __init__(self, system_prompt='') -> None:
- super().__init__(system_prompt)
+ def __init__(self,
+ system_prompt='',
+ human_prompt: str = 'Q:',
+ assistant_prompt: str = 'A:') -> None:
+ super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
- images = batch.pop('inputs')
- images = torch.stack(images, dim=0)
- data_samples = batch.pop('data_samples')
- questions = [
- 'Q: ' + sample.get('question') + '\n' for sample in data_samples
- ]
- choices = [sample.get('choices') for sample in data_samples]
- choices = [[
- f'({self.choice_mapping[i]}) ' + item
- for i, item in enumerate(choice)
- ] for choice in choices]
+ assert len(batch['inputs']) == 1
+ image = batch.pop('inputs')[0].unsqueeze(0)
+ data_sample = batch.pop('data_samples')[0]
+
+ questions = 'Question: ' + data_sample.get('question')
+ choices = data_sample.get('choices')
choices = [
- 'Choices: ' + ' '.join(choice) + '\n' for choice in choices
- ] # noqa
- contexts = [
- 'Context: ' + data_sample.get('hint') + '\n'
- for data_sample in data_samples
- ] # noqa
+ f'({self.choice_mapping[i]}) ' + item
+ for i, item in enumerate(choices)
+ ]
+ choices = 'Choices: ' + ' '.join(choices) + '\n'
+ contexts = 'Context: ' + data_sample.get('hint') + '\n'
# generate text prompt
- prompt = [
- '<img></img>' + context + question + choice + self.prompt
- for context, question, choice in zip(contexts, questions, choices)
- ]
- image_position = 5
+ prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt # noqa
+ image_position = prompt.rfind('<img>') + 5
- return images, prompt, data_samples, image_position
+ return image, prompt, data_sample, image_position
class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat <img></img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
+ human_prompt (str): Human prompt. (Default: 'Q:')
+ assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
- def __init__(self, system_prompt='') -> None:
- super().__init__(system_prompt)
+ def __init__(self,
+ system_prompt='',
+ human_prompt: str = 'Q:',
+ assistant_prompt: str = 'A:') -> None:
+ super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
- images = batch.pop('inputs')
- images = torch.stack(images, dim=0)
- data_samples = batch.pop('data_samples')
- questions = [
- 'Q: ' + sample.get('question') + '\n' for sample in data_samples
- ]
- choices = [sample.get('choices') for sample in data_samples]
- choices = [
- 'Options: ' + ', '.join(choice) + '.\n' for choice in choices
- ] # noqa
+ assert len(batch['inputs']) == 1
+ image = batch.pop('inputs')[0].unsqueeze(0)
+ data_sample = batch.pop('data_samples')[0]
+
+ questions = data_sample.get('question') + '\n'
+ choices = data_sample.get('choices')
+ choices = 'Options: ' + ', '.join(choices) + '.\n'
# generate text prompt
- prompt = [
- '<img></img>' + question + choice + self.prompt
- for question, choice in zip(questions, choices)
- ]
- image_position = 5
+ prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt # noqa
+ image_position = prompt.rfind('<img>') + 5
- return images, prompt, data_samples, image_position
+ return image, prompt, data_sample, image_position
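Note: every constructor in this file now builds a single prompt string and locates the image slot with rfind('<img>') + 5. A minimal sketch (question text assumed) of what the refactored VQA constructor produces for one sample:

    # Defaults from the diff: human_prompt='Q:', assistant_prompt='A:', system_prompt=''.
    human_prompt, assistant_prompt, system_prompt = 'Q:', 'A:', ''
    question = 'What is on the table?'            # assumed sample question

    prompt = '<img></img>' + human_prompt + question + system_prompt
    prompt += '\n' + assistant_prompt
    image_position = prompt.rfind('<img>') + 5    # index just past '<img>'

    print(repr(prompt))    # '<img></img>Q:What is on the table?\nA:'
    print(image_position)  # 5 -> image placeholder tokens are spliced in here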
diff --git a/opencompass/multimodal/models/visualglm/visualglm.py b/opencompass/multimodal/models/visualglm/visualglm.py
index 7187b97e..a9534d94 100644
--- a/opencompass/multimodal/models/visualglm/visualglm.py
+++ b/opencompass/multimodal/models/visualglm/visualglm.py
@@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
if gen_kwargs:
self.gen_kwargs = gen_kwargs
else:
- self.gen_kwargs = dict(
- max_new_tokens=30,
- num_beams=1,
- do_sample=False,
- repetition_penalty=1.0,
- length_penalty=-1.0,
- )
+ self.gen_kwargs = dict(max_length=1024,
+ min_length=100,
+ do_sample=True,
+ temperature=0.8,
+ top_p=0.4,
+ top_k=100,
+ repetition_penalty=1.2)
self.is_caption_task = is_caption_task
- def encode_by_tokenizer(self, multi_prompts, image_position):
- input_ids = []
- max_seq_length = 0
- for prompt in multi_prompts:
- input0 = self.tokenizer.encode(prompt[:image_position],
- add_special_tokens=False)
- input1 = [self.tokenizer.pad_token_id] * self.model.image_length
- input2 = self.tokenizer.encode(prompt[image_position:],
- add_special_tokens=False)
- input_all = sum([input0, input1, input2], [])
- input_all = self.tokenizer.build_inputs_with_special_tokens(
- input_all)
- max_seq_length = max(max_seq_length, len(input_all))
- input_ids.append(input_all)
+ def encode_by_tokenizer(self, prompt, image_position):
+
+ input0 = self.tokenizer.encode(prompt[:image_position],
+ add_special_tokens=False)
+ input1 = [self.tokenizer.unk_token_id] * self.model.image_length
+ input2 = self.tokenizer.encode(prompt[image_position:],
+ add_special_tokens=False)
+ input_all = sum([input0, input1, input2], [])
+ input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
+ input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
+ input_all = input_all.unsqueeze(0)
+
pre_image_len = len(input0)
- # padding
- for i, _ in enumerate(input_ids):
- pad_len = max_seq_length - len(input_ids[i])
- input_ids[i] = [self.tokenizer.pad_token_id
- ] * pad_len + input_ids[i]
-
- return input_ids, pre_image_len
+ return input_all, pre_image_len
def generate(self, batch):
# process input
@@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
input_all, pre_image_len = self.encode_by_tokenizer(
prompt, image_position)
- input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
-
# build input param
inputs = {
'input_ids': input_all,
'pre_image_length': pre_image_len,
'images': image
}
+
# generate answer
outputs = self.model.generate(**inputs, **self.gen_kwargs)
# format output
- outputs = outputs.tolist()
- for i, sample in enumerate(data_sample):
- answer = self.post_processor(outputs[i], self.tokenizer,
- input_all.shape[1])
- if self.is_caption_task:
- data_sample[i].pred_caption = answer
- else:
- data_sample[i].pred_answer = answer
+ outputs = outputs.tolist()[0][input_all.shape[1]:]
+ answer = self.post_processor(outputs, self.tokenizer)
+
+ if self.is_caption_task:
+ data_sample.pred_caption = answer
+ else:
+ data_sample.pred_answer = answer
return data_sample
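Note: a rough sketch (toy encoder, assumed image_length) of how the reworked encode_by_tokenizer splices image placeholder ids into the single prompt; the real code uses the ChatGLM tokenizer, tokenizer.unk_token_id and self.model.image_length, and additionally wraps the sequence with special tokens:

    UNK_ID = 0
    image_length = 4                                # assumed number of image placeholder tokens
    prompt = '<img></img>Q:Describe the image.A:'
    image_position = prompt.rfind('<img>') + 5      # index just past '<img>'

    def toy_encode(text):                           # stand-in encoder: one id per character
        return [ord(c) for c in text]

    input0 = toy_encode(prompt[:image_position])    # text before the image slot
    input1 = [UNK_ID] * image_length                # placeholders later replaced by image features
    input2 = toy_encode(prompt[image_position:])    # text after the image slot
    input_all = input0 + input1 + input2
    pre_image_len = len(input0)                     # tells the model where the image tokens start
    print(len(input_all), pre_image_len)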