Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
[Fix] Fix performance issue of visualglm. (#424)
* [Fix] Visualglm performance fixed.
* [Fix] Hide ckpt path.
This commit is contained in:
parent 8803f7f7a6
commit 97fdc51102
@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )
@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
    post_processor=dict(type=VisualGLMBasePostProcessor)
 )
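Note on the config change above: the new system_prompt is consumed by VisualGLMBasePromptConstructor, which (as updated later in this diff) wraps it with the image tag and the Q:/A: markers. A minimal, illustrative sketch of the caption prompt that results, assuming the default human_prompt and assistant_prompt:

    # Illustrative sketch only; the defaults mirror those shown later in this diff.
    human_prompt, assistant_prompt = 'Q:', 'A:'
    system_prompt = 'Describe the image.'

    prompt = '<img></img>' + human_prompt + system_prompt + assistant_prompt
    # -> '<img></img>Q:Describe the image.A:'
    image_position = prompt.rfind('<img>') + 5  # 5: image tokens are spliced in right after '<img>'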
@@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
     def __init__(self) -> None:
         pass
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        return tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        return tokenizer.decode(output_token)
 
 
 class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
     def __init__(self) -> None:
         super().__init__()
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        output_text = tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        output_text = tokenizer.decode(output_token)
         if 'yes' in output_text.lower():
             return 'yes'
         elif 'no' in output_text.lower():
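With the input_len argument gone, the responsibility for stripping the prompt tokens moves into VisualGLM.generate (see the model diff further down); the post processors now simply decode whatever they are handed. A runnable sketch of that hand-off, using a hypothetical stand-in tokenizer rather than the real ChatGLM one:

    # Hypothetical stand-ins for illustration; not the real tokenizer or model output.
    class FakeTokenizer:
        def decode(self, ids):
            return ' '.join(f'<{i}>' for i in ids)

    tokenizer = FakeTokenizer()
    prompt_len = 3                      # plays the role of input_all.shape[1]
    outputs = [[11, 12, 13, 101, 102]]  # plays the role of outputs.tolist()

    new_tokens = outputs[0][prompt_len:]   # generate() now slices off the prompt
    answer = tokenizer.decode(new_tokens)  # the post processor only decodes new tokens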
@@ -1,24 +1,16 @@
-import torch
-
-
 class VisualGLMMMBenchPromptConstructor:
     """MMBench prompt constructor for VisualGLM.
 
-    The overall prompt will be formulated as
-    "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
     Args:
         system_prompt (str): System prompt. (Default: '')
         human_prompt (str): Human prompt. (Default: 'Q:')
-        image_prompt (str): Image prompt. (Default: '<img></img>')
         assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
     def __init__(self,
                  system_prompt: str = '',
                  human_prompt: str = 'Q:',
-                 image_prompt: str = '<img></img>',
                  assistant_prompt: str = 'A:') -> None:
-        self.image_prompt = image_prompt
         self.system_prompt = system_prompt
         self.human_prompt = human_prompt
         self.assistant_prompt = assistant_prompt
@@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
         A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
-        options = [sample.get('options') for sample in data_samples]
-        contexts = [sample.get('context') for sample in data_samples]
-        contexts = [c if c else '' for c in contexts]
-
-        # generate text prompt
-        prompt = [
-            '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
-                                    self.human_prompt, context, question,
-                                    option, self.assistant_prompt)
-            for context, question, option in zip(contexts, questions, options)
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        img_prompt = '<img></img>'
+        if data_sample.get('context') is not None:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options  # noqa
+        else:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options  # noqa
+        prompt += self.assistant_prompt
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMBasePromptConstructor:
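For reference, a runnable sketch of the string the rewritten MMBench constructor assembles for one sample; the question, options and context values are made up:

    # Made-up sample fields; mirrors the new single-sample branch above.
    system_prompt, human_prompt, assistant_prompt = '', 'Q:', 'A:'
    context = 'The image shows an animal.'
    question = 'What animal is it?'
    options = 'A. cat B. dog'

    prompt = '<img></img>' + system_prompt + human_prompt + context + ' ' + question + ' ' + options
    prompt += assistant_prompt
    # -> '<img></img>Q:The image shows an animal. What animal is it? A. cat B. dogA:'
    image_position = prompt.rfind('<img>') + 5  # 5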
@@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
     The prompt will concat <img> and the given system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
+    def __init__(self,
+                 system_prompt: str = '',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
         self.prompt = system_prompt
+        self.human_prompt = human_prompt
+        self.assistant_prompt = assistant_prompt
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
         A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
 
         # generate text prompt
-        prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
+        prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt  # noqa
 
-        image_position = 5
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
         A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
 
         # generate text prompt
-        prompt = [
-            '<img></img>Q:{} {}\nA:'.format(question, self.prompt)
-            for question in questions
-        ]
-        image_position = 5
+        question = data_sample.get('question')
+        prompt = '<img></img>' + self.human_prompt + question + self.prompt
+        prompt += '\n' + self.assistant_prompt
+
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
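A similar sketch for the VQA constructor; note the newline inserted before the assistant marker. The question and system prompt below are illustrative only:

    # Made-up values; mirrors the new VQA branch above.
    system_prompt, human_prompt, assistant_prompt = 'Answer the question.', 'Q:', 'A:'
    question = 'How many people are in the image?'

    prompt = '<img></img>' + human_prompt + question + system_prompt
    prompt += '\n' + assistant_prompt
    # -> '<img></img>Q:How many people are in the image?Answer the question.\nA:'
    image_position = prompt.rfind('<img>') + 5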
@@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat image and all terms in a question.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
     choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
         A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [[
-            f'({self.choice_mapping[i]}) ' + item
-            for i, item in enumerate(choice)
-        ] for choice in choices]
-        choices = [
-            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
-        ]  # noqa
-        contexts = [
-            'Context: ' + data_sample.get('hint') + '\n'
-            for data_sample in data_samples
-        ]  # noqa
-
-        # generate text prompt
-        prompt = [
-            '<img></img>' + context + question + choice + self.prompt
-            for context, question, choice in zip(contexts, questions, choices)
-        ]
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        questions = 'Question: ' + data_sample.get('question')
+        choices = data_sample.get('choices')
+        choices = [
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choices)
+        ]
+        choices = 'Choices: ' + ' '.join(choices) + '\n'
+        contexts = 'Context: ' + data_sample.get('hint') + '\n'
+
+        # generate text prompt
+        prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
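The ScienceQA constructor now formats a single sample's hint, question and lettered choices before concatenating them. A runnable sketch with made-up fields:

    # Made-up sample; mirrors the new ScienceQA branch above.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
    system_prompt, human_prompt, assistant_prompt = '', 'Q:', 'A:'
    hint = 'Producers make their own food.'
    question = 'Which organism is a producer?'
    raw_choices = ['grass', 'rabbit']

    questions = 'Question: ' + question
    choices = [f'({choice_mapping[i]}) ' + item for i, item in enumerate(raw_choices)]
    choices = 'Choices: ' + ' '.join(choices) + '\n'
    contexts = 'Context: ' + hint + '\n'
    prompt = '<img></img>' + human_prompt + contexts + questions + choices + system_prompt + assistant_prompt
    # -> '<img></img>Q:Context: Producers make their own food.\nQuestion: Which organism is a producer?Choices: (A) grass (B) rabbit\nA:'
    image_position = prompt.rfind('<img>') + 5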
@@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
         A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [
-            'Options: ' + ', '.join(choice) + '.\n' for choice in choices
-        ]  # noqa
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        questions = data_sample.get('question') + '\n'
+        choices = data_sample.get('choices')
+        choices = 'Options: ' + ', '.join(choices) + '.\n'
 
         # generate text prompt
-        prompt = [
-            '<img></img>' + question + choice + self.prompt
-            for question, choice in zip(questions, choices)
-        ]
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
@@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
         if gen_kwargs:
             self.gen_kwargs = gen_kwargs
         else:
-            self.gen_kwargs = dict(
-                max_new_tokens=30,
-                num_beams=1,
-                do_sample=False,
-                repetition_penalty=1.0,
-                length_penalty=-1.0,
-            )
+            self.gen_kwargs = dict(max_length=1024,
+                                   min_length=100,
+                                   do_sample=True,
+                                   temperature=0.8,
+                                   top_p=0.4,
+                                   top_k=100,
+                                   repetition_penalty=1.2)
 
         self.is_caption_task = is_caption_task
 
-    def encode_by_tokenizer(self, multi_prompts, image_position):
-        input_ids = []
-        max_seq_length = 0
-        for prompt in multi_prompts:
-            input0 = self.tokenizer.encode(prompt[:image_position],
-                                           add_special_tokens=False)
-            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
-            input2 = self.tokenizer.encode(prompt[image_position:],
-                                           add_special_tokens=False)
-            input_all = sum([input0, input1, input2], [])
-            input_all = self.tokenizer.build_inputs_with_special_tokens(
-                input_all)
-            max_seq_length = max(max_seq_length, len(input_all))
-            input_ids.append(input_all)
-            pre_image_len = len(input0)
-
-        # padding
-        for i, _ in enumerate(input_ids):
-            pad_len = max_seq_length - len(input_ids[i])
-            input_ids[i] = [self.tokenizer.pad_token_id
-                            ] * pad_len + input_ids[i]
-
-        return input_ids, pre_image_len
+    def encode_by_tokenizer(self, prompt, image_position):
+        input0 = self.tokenizer.encode(prompt[:image_position],
+                                       add_special_tokens=False)
+        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
+        input2 = self.tokenizer.encode(prompt[image_position:],
+                                       add_special_tokens=False)
+        input_all = sum([input0, input1, input2], [])
+        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
+        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
+        input_all = input_all.unsqueeze(0)
+
+        pre_image_len = len(input0)
+
+        return input_all, pre_image_len
 
     def generate(self, batch):
         # process input
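The rewritten encode_by_tokenizer builds one token sequence per call: prompt tokens up to the image tag, then unk-token placeholders (one per image token) where the visual features will go, then the rest of the prompt. A toy, runnable sketch of that layout with hand-picked ids; the real code uses the ChatGLM tokenizer and get_device():

    # Toy ids for illustration; the real tokenizer and image_length differ.
    prompt = '<img></img>Q:Describe the image.A:'
    image_position = prompt.rfind('<img>') + 5  # split point right after '<img>'

    input0 = [101]                    # stands in for tokenizer.encode(prompt[:image_position])
    input1 = [0] * 4                  # unk_token_id placeholders, one per image token
    input2 = [102, 103, 104]          # stands in for tokenizer.encode(prompt[image_position:])
    input_all = sum([input0, input1, input2], [])
    pre_image_len = len(input0)       # tells the model where the image features are inserted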
@@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
         input_all, pre_image_len = self.encode_by_tokenizer(
             prompt, image_position)
 
-        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
-
         # build input param
         inputs = {
             'input_ids': input_all,
             'pre_image_length': pre_image_len,
             'images': image
         }
 
         # generate answer
         outputs = self.model.generate(**inputs, **self.gen_kwargs)
 
         # format output
-        outputs = outputs.tolist()
-        for i, sample in enumerate(data_sample):
-            answer = self.post_processor(outputs[i], self.tokenizer,
-                                         input_all.shape[1])
-            if self.is_caption_task:
-                data_sample[i].pred_caption = answer
-            else:
-                data_sample[i].pred_answer = answer
+        outputs = outputs.tolist()[0][input_all.shape[1]:]
+        answer = self.post_processor(outputs, self.tokenizer)
+
+        if self.is_caption_task:
+            data_sample.pred_caption = answer
+        else:
+            data_sample.pred_answer = answer
 
         return data_sample