mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] Fix performance issue of visualglm. (#424)
* [Fix] Visualglm performance fixed. * [Fix] Hide ckpt path.
This commit is contained in:
parent
8803f7f7a6
commit
97fdc51102
@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
|
||||
type='visualglm',
|
||||
pretrained_path='/path/to/visualglm', # or Huggingface repo id
|
||||
is_caption_task=True,
|
||||
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
|
||||
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
|
||||
post_processor=dict(type=VisualGLMBasePostProcessor)
|
||||
)
|
||||
|
||||
|
@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
|
||||
type='visualglm',
|
||||
pretrained_path='/path/to/visualglm', # or Huggingface repo id
|
||||
is_caption_task=True,
|
||||
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
|
||||
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
|
||||
post_processor=dict(type=VisualGLMBasePostProcessor)
|
||||
)
|
||||
|
||||
|
@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, output_token: torch.tensor, tokenizer: Any,
|
||||
input_len: int) -> str:
|
||||
return tokenizer.decode(output_token[input_len:])
|
||||
def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
|
||||
return tokenizer.decode(output_token)
|
||||
|
||||
|
||||
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
|
||||
@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, output_token: torch.tensor, tokenizer: Any,
|
||||
input_len: int) -> str:
|
||||
output_text = tokenizer.decode(output_token[input_len:])
|
||||
def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
|
||||
output_text = tokenizer.decode(output_token)
|
||||
if 'yes' in output_text.lower():
|
||||
return 'yes'
|
||||
elif 'no' in output_text.lower():
|
||||
|
@ -1,24 +1,16 @@
|
||||
import torch
|
||||
|
||||
|
||||
class VisualGLMMMBenchPromptConstructor:
|
||||
"""MMBench prompt constructor for VisualGLM.
|
||||
|
||||
The overall prompt will be formulated as
|
||||
"system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
|
||||
Args:
|
||||
system_prompt (str): System prompt. (Default: '')
|
||||
human_prompt (str): Human prompt. (Default: 'Q:')
|
||||
image_prompt (str): Image prompt. (Default: '<img></img>')
|
||||
assistant_prompt (str): Assistant prompt. (Default: 'A:')
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
system_prompt: str = '',
|
||||
human_prompt: str = 'Q:',
|
||||
image_prompt: str = '<img></img>',
|
||||
assistant_prompt: str = 'A:') -> None:
|
||||
self.image_prompt = image_prompt
|
||||
self.system_prompt = system_prompt
|
||||
self.human_prompt = human_prompt
|
||||
self.assistant_prompt = assistant_prompt
|
||||
@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
|
||||
A tuple containing images, prompt, data_samples and image_position.
|
||||
"""
|
||||
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
assert len(batch['inputs']) == 1
|
||||
image = batch.pop('inputs')[0].unsqueeze(0)
|
||||
data_sample = batch.pop('data_samples')[0]
|
||||
img_prompt = '<img></img>'
|
||||
if data_sample.get('context') is not None:
|
||||
prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options # noqa
|
||||
else:
|
||||
prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options # noqa
|
||||
prompt += self.assistant_prompt
|
||||
image_position = prompt.rfind('<img>') + 5
|
||||
|
||||
data_samples = batch.pop('data_samples')
|
||||
questions = [sample.get('question') for sample in data_samples]
|
||||
options = [sample.get('options') for sample in data_samples]
|
||||
contexts = [sample.get('context') for sample in data_samples]
|
||||
contexts = [c if c else '' for c in contexts]
|
||||
|
||||
# generate text prompt
|
||||
prompt = [
|
||||
'{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
|
||||
self.human_prompt, context, question,
|
||||
option, self.assistant_prompt)
|
||||
for context, question, option in zip(contexts, questions, options)
|
||||
]
|
||||
|
||||
image_position = 5
|
||||
|
||||
return images, prompt, data_samples, image_position
|
||||
return image, prompt, data_sample, image_position
|
||||
|
||||
|
||||
class VisualGLMBasePromptConstructor:
|
||||
@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
|
||||
The prompt will concat <img> and the given system prompt.
|
||||
Args:
|
||||
system_prompt (str): System prompt. (Default: '')
|
||||
human_prompt (str): Human prompt. (Default: 'Q:')
|
||||
assistant_prompt (str): Assistant prompt. (Default: 'A:')
|
||||
"""
|
||||
|
||||
def __init__(self, system_prompt='') -> None:
|
||||
def __init__(self,
|
||||
system_prompt: str = '',
|
||||
human_prompt: str = 'Q:',
|
||||
assistant_prompt: str = 'A:') -> None:
|
||||
self.prompt = system_prompt
|
||||
self.human_prompt = human_prompt
|
||||
self.assistant_prompt = assistant_prompt
|
||||
|
||||
def __call__(self, batch: dict) -> tuple:
|
||||
"""Construct prompt.
|
||||
@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
|
||||
A tuple containing images, prompt, data_samples and image_position.
|
||||
"""
|
||||
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
data_samples = batch.pop('data_samples')
|
||||
assert len(batch['inputs']) == 1
|
||||
image = batch.pop('inputs')[0].unsqueeze(0)
|
||||
data_sample = batch.pop('data_samples')[0]
|
||||
|
||||
# generate text prompt
|
||||
prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
|
||||
prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt # noqa
|
||||
|
||||
image_position = 5
|
||||
image_position = prompt.rfind('<img>') + 5
|
||||
|
||||
return images, prompt, data_samples, image_position
|
||||
return image, prompt, data_sample, image_position
|
||||
|
||||
|
||||
class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
The prompt will concat <img>, the question and the system prompt.
|
||||
Args:
|
||||
system_prompt (str): System prompt. (Default: '')
|
||||
human_prompt (str): Human prompt. (Default: 'Q:')
|
||||
assistant_prompt (str): Assistant prompt. (Default: 'A:')
|
||||
"""
|
||||
|
||||
def __init__(self, system_prompt='') -> None:
|
||||
super().__init__(system_prompt)
|
||||
def __init__(self,
|
||||
system_prompt='',
|
||||
human_prompt: str = 'Q:',
|
||||
assistant_prompt: str = 'A:') -> None:
|
||||
super().__init__(system_prompt, human_prompt, assistant_prompt)
|
||||
|
||||
def __call__(self, batch: dict) -> tuple:
|
||||
"""Construct prompt.
|
||||
@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
A tuple containing images, prompt, data_samples and image_position.
|
||||
"""
|
||||
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
data_samples = batch.pop('data_samples')
|
||||
questions = [sample.get('question') for sample in data_samples]
|
||||
assert len(batch['inputs']) == 1
|
||||
image = batch.pop('inputs')[0].unsqueeze(0)
|
||||
data_sample = batch.pop('data_samples')[0]
|
||||
|
||||
# generate text prompt
|
||||
prompt = [
|
||||
'<img></img>Q:{} {}\nA:'.format(question, self.prompt)
|
||||
for question in questions
|
||||
]
|
||||
image_position = 5
|
||||
question = data_sample.get('question')
|
||||
prompt = '<img></img>' + self.human_prompt + question + self.prompt
|
||||
prompt += '\n' + self.assistant_prompt
|
||||
|
||||
return images, prompt, data_samples, image_position
|
||||
image_position = prompt.rfind('<img>') + 5
|
||||
|
||||
return image, prompt, data_sample, image_position
|
||||
|
||||
|
||||
class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
The prompt will concat image and all terms in a question.
|
||||
Args:
|
||||
system_prompt (str): System prompt. (Default: '')
|
||||
human_prompt (str): Human prompt. (Default: 'Q:')
|
||||
assistant_prompt (str): Assistant prompt. (Default: 'A:')
|
||||
"""
|
||||
|
||||
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
|
||||
|
||||
def __init__(self, system_prompt='') -> None:
|
||||
super().__init__(system_prompt)
|
||||
def __init__(self,
|
||||
system_prompt='',
|
||||
human_prompt: str = 'Q:',
|
||||
assistant_prompt: str = 'A:') -> None:
|
||||
super().__init__(system_prompt, human_prompt, assistant_prompt)
|
||||
|
||||
def __call__(self, batch: dict) -> tuple:
|
||||
"""Construct prompt.
|
||||
@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
A tuple containing images, prompt, data_samples and image_position.
|
||||
"""
|
||||
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
data_samples = batch.pop('data_samples')
|
||||
questions = [
|
||||
'Q: ' + sample.get('question') + '\n' for sample in data_samples
|
||||
]
|
||||
choices = [sample.get('choices') for sample in data_samples]
|
||||
choices = [[
|
||||
f'({self.choice_mapping[i]}) ' + item
|
||||
for i, item in enumerate(choice)
|
||||
] for choice in choices]
|
||||
assert len(batch['inputs']) == 1
|
||||
image = batch.pop('inputs')[0].unsqueeze(0)
|
||||
data_sample = batch.pop('data_samples')[0]
|
||||
|
||||
questions = 'Question: ' + data_sample.get('question')
|
||||
choices = data_sample.get('choices')
|
||||
choices = [
|
||||
'Choices: ' + ' '.join(choice) + '\n' for choice in choices
|
||||
] # noqa
|
||||
contexts = [
|
||||
'Context: ' + data_sample.get('hint') + '\n'
|
||||
for data_sample in data_samples
|
||||
] # noqa
|
||||
f'({self.choice_mapping[i]}) ' + item
|
||||
for i, item in enumerate(choices)
|
||||
]
|
||||
choices = 'Choices: ' + ' '.join(choices) + '\n'
|
||||
contexts = 'Context: ' + data_sample.get('hint') + '\n'
|
||||
|
||||
# generate text prompt
|
||||
prompt = [
|
||||
'<img></img>' + context + question + choice + self.prompt
|
||||
for context, question, choice in zip(contexts, questions, choices)
|
||||
]
|
||||
image_position = 5
|
||||
prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt # noqa
|
||||
image_position = prompt.rfind('<img>') + 5
|
||||
|
||||
return images, prompt, data_samples, image_position
|
||||
return image, prompt, data_sample, image_position
|
||||
|
||||
|
||||
class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
The prompt will concat <img>, the question and the system prompt.
|
||||
Args:
|
||||
system_prompt (str): System prompt. (Default: '')
|
||||
human_prompt (str): Human prompt. (Default: 'Q:')
|
||||
assistant_prompt (str): Assistant prompt. (Default: 'A:')
|
||||
"""
|
||||
|
||||
def __init__(self, system_prompt='') -> None:
|
||||
super().__init__(system_prompt)
|
||||
def __init__(self,
|
||||
system_prompt='',
|
||||
human_prompt: str = 'Q:',
|
||||
assistant_prompt: str = 'A:') -> None:
|
||||
super().__init__(system_prompt, human_prompt, assistant_prompt)
|
||||
|
||||
def __call__(self, batch: dict) -> tuple:
|
||||
"""Construct prompt.
|
||||
@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
|
||||
A tuple containing images, prompt, data_samples and image_position.
|
||||
"""
|
||||
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
data_samples = batch.pop('data_samples')
|
||||
questions = [
|
||||
'Q: ' + sample.get('question') + '\n' for sample in data_samples
|
||||
]
|
||||
choices = [sample.get('choices') for sample in data_samples]
|
||||
choices = [
|
||||
'Options: ' + ', '.join(choice) + '.\n' for choice in choices
|
||||
] # noqa
|
||||
assert len(batch['inputs']) == 1
|
||||
image = batch.pop('inputs')[0].unsqueeze(0)
|
||||
data_sample = batch.pop('data_samples')[0]
|
||||
|
||||
questions = data_sample.get('question') + '\n'
|
||||
choices = data_sample.get('choices')
|
||||
choices = 'Options: ' + ', '.join(choices) + '.\n'
|
||||
|
||||
# generate text prompt
|
||||
prompt = [
|
||||
'<img></img>' + question + choice + self.prompt
|
||||
for question, choice in zip(questions, choices)
|
||||
]
|
||||
image_position = 5
|
||||
prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt # noqa
|
||||
image_position = prompt.rfind('<img>') + 5
|
||||
|
||||
return images, prompt, data_samples, image_position
|
||||
return image, prompt, data_sample, image_position
|
||||
|
@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
|
||||
if gen_kwargs:
|
||||
self.gen_kwargs = gen_kwargs
|
||||
else:
|
||||
self.gen_kwargs = dict(
|
||||
max_new_tokens=30,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=-1.0,
|
||||
)
|
||||
self.gen_kwargs = dict(max_length=1024,
|
||||
min_length=100,
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_p=0.4,
|
||||
top_k=100,
|
||||
repetition_penalty=1.2)
|
||||
|
||||
self.is_caption_task = is_caption_task
|
||||
|
||||
def encode_by_tokenizer(self, multi_prompts, image_position):
|
||||
input_ids = []
|
||||
max_seq_length = 0
|
||||
for prompt in multi_prompts:
|
||||
def encode_by_tokenizer(self, prompt, image_position):
|
||||
|
||||
input0 = self.tokenizer.encode(prompt[:image_position],
|
||||
add_special_tokens=False)
|
||||
input1 = [self.tokenizer.pad_token_id] * self.model.image_length
|
||||
input1 = [self.tokenizer.unk_token_id] * self.model.image_length
|
||||
input2 = self.tokenizer.encode(prompt[image_position:],
|
||||
add_special_tokens=False)
|
||||
input_all = sum([input0, input1, input2], [])
|
||||
input_all = self.tokenizer.build_inputs_with_special_tokens(
|
||||
input_all)
|
||||
max_seq_length = max(max_seq_length, len(input_all))
|
||||
input_ids.append(input_all)
|
||||
input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
|
||||
input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
|
||||
input_all = input_all.unsqueeze(0)
|
||||
|
||||
pre_image_len = len(input0)
|
||||
|
||||
# padding
|
||||
for i, _ in enumerate(input_ids):
|
||||
pad_len = max_seq_length - len(input_ids[i])
|
||||
input_ids[i] = [self.tokenizer.pad_token_id
|
||||
] * pad_len + input_ids[i]
|
||||
|
||||
return input_ids, pre_image_len
|
||||
return input_all, pre_image_len
|
||||
|
||||
def generate(self, batch):
|
||||
# process input
|
||||
@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
|
||||
input_all, pre_image_len = self.encode_by_tokenizer(
|
||||
prompt, image_position)
|
||||
|
||||
input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
|
||||
|
||||
# build input param
|
||||
inputs = {
|
||||
'input_ids': input_all,
|
||||
'pre_image_length': pre_image_len,
|
||||
'images': image
|
||||
}
|
||||
|
||||
# generate answer
|
||||
outputs = self.model.generate(**inputs, **self.gen_kwargs)
|
||||
|
||||
# format output
|
||||
outputs = outputs.tolist()
|
||||
for i, sample in enumerate(data_sample):
|
||||
answer = self.post_processor(outputs[i], self.tokenizer,
|
||||
input_all.shape[1])
|
||||
outputs = outputs.tolist()[0][input_all.shape[1]:]
|
||||
answer = self.post_processor(outputs, self.tokenizer)
|
||||
|
||||
if self.is_caption_task:
|
||||
data_sample[i].pred_caption = answer
|
||||
data_sample.pred_caption = answer
|
||||
else:
|
||||
data_sample[i].pred_answer = answer
|
||||
data_sample.pred_answer = answer
|
||||
|
||||
return data_sample
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user