[Fix] Fix performance issue of visualglm. (#424)

* [Fix] Visualglm performance fixed.

* [Fix] Hide ckpt path.
This commit is contained in:
Yike Yuan 2023-09-21 19:54:23 +08:00 committed by GitHub
parent 8803f7f7a6
commit 97fdc51102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 135 deletions

View File

@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
post_processor=dict(type=VisualGLMBasePostProcessor)
)

View File

@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
post_processor=dict(type=VisualGLMBasePostProcessor)
)

View File

@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer: Any,
input_len: int) -> str:
return tokenizer.decode(output_token[input_len:])
def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
return tokenizer.decode(output_token)
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
def __init__(self) -> None:
super().__init__()
def __call__(self, output_token: torch.tensor, tokenizer: Any,
input_len: int) -> str:
output_text = tokenizer.decode(output_token[input_len:])
def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
output_text = tokenizer.decode(output_token)
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():

View File

@ -1,24 +1,16 @@
import torch
class VisualGLMMMBenchPromptConstructor:
"""MMBench prompt constructor for VisualGLM.
The overall prompt will be formulated as
"system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
image_prompt (str): Image prompt. (Default: '<img></img>')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self,
system_prompt: str = '',
human_prompt: str = 'Q:',
image_prompt: str = '<img></img>',
assistant_prompt: str = 'A:') -> None:
self.image_prompt = image_prompt
self.system_prompt = system_prompt
self.human_prompt = human_prompt
self.assistant_prompt = assistant_prompt
@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
assert len(batch['inputs']) == 1
image = batch.pop('inputs')[0].unsqueeze(0)
data_sample = batch.pop('data_samples')[0]
img_prompt = '<img></img>'
if data_sample.get('context') is not None:
prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options # noqa
else:
prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options # noqa
prompt += self.assistant_prompt
image_position = prompt.rfind('<img>') + 5
data_samples = batch.pop('data_samples')
questions = [sample.get('question') for sample in data_samples]
options = [sample.get('options') for sample in data_samples]
contexts = [sample.get('context') for sample in data_samples]
contexts = [c if c else '' for c in contexts]
# generate text prompt
prompt = [
'{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
self.human_prompt, context, question,
option, self.assistant_prompt)
for context, question, option in zip(contexts, questions, options)
]
image_position = 5
return images, prompt, data_samples, image_position
return image, prompt, data_sample, image_position
class VisualGLMBasePromptConstructor:
@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
The prompt will concat <img> and the given system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self, system_prompt='') -> None:
def __init__(self,
system_prompt: str = '',
human_prompt: str = 'Q:',
assistant_prompt: str = 'A:') -> None:
self.prompt = system_prompt
self.human_prompt = human_prompt
self.assistant_prompt = assistant_prompt
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
assert len(batch['inputs']) == 1
image = batch.pop('inputs')[0].unsqueeze(0)
data_sample = batch.pop('data_samples')[0]
# generate text prompt
prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt # noqa
image_position = 5
image_position = prompt.rfind('<img>') + 5
return images, prompt, data_samples, image_position
return image, prompt, data_sample, image_position
class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat <img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __init__(self,
system_prompt='',
human_prompt: str = 'Q:',
assistant_prompt: str = 'A:') -> None:
super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [sample.get('question') for sample in data_samples]
assert len(batch['inputs']) == 1
image = batch.pop('inputs')[0].unsqueeze(0)
data_sample = batch.pop('data_samples')[0]
# generate text prompt
prompt = [
'<img></img>Q:{} {}\nA:'.format(question, self.prompt)
for question in questions
]
image_position = 5
question = data_sample.get('question')
prompt = '<img></img>' + self.human_prompt + question + self.prompt
prompt += '\n' + self.assistant_prompt
return images, prompt, data_samples, image_position
image_position = prompt.rfind('<img>') + 5
return image, prompt, data_sample, image_position
class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat image and all terms in a question.
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __init__(self,
system_prompt='',
human_prompt: str = 'Q:',
assistant_prompt: str = 'A:') -> None:
super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [
'Q: ' + sample.get('question') + '\n' for sample in data_samples
]
choices = [sample.get('choices') for sample in data_samples]
choices = [[
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choice)
] for choice in choices]
assert len(batch['inputs']) == 1
image = batch.pop('inputs')[0].unsqueeze(0)
data_sample = batch.pop('data_samples')[0]
questions = 'Question: ' + data_sample.get('question')
choices = data_sample.get('choices')
choices = [
'Choices: ' + ' '.join(choice) + '\n' for choice in choices
] # noqa
contexts = [
'Context: ' + data_sample.get('hint') + '\n'
for data_sample in data_samples
] # noqa
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choices)
]
choices = 'Choices: ' + ' '.join(choices) + '\n'
contexts = 'Context: ' + data_sample.get('hint') + '\n'
# generate text prompt
prompt = [
'<img></img>' + context + question + choice + self.prompt
for context, question, choice in zip(contexts, questions, choices)
]
image_position = 5
prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt # noqa
image_position = prompt.rfind('<img>') + 5
return images, prompt, data_samples, image_position
return image, prompt, data_sample, image_position
class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
The prompt will concat <img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __init__(self,
system_prompt='',
human_prompt: str = 'Q:',
assistant_prompt: str = 'A:') -> None:
super().__init__(system_prompt, human_prompt, assistant_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [
'Q: ' + sample.get('question') + '\n' for sample in data_samples
]
choices = [sample.get('choices') for sample in data_samples]
choices = [
'Options: ' + ', '.join(choice) + '.\n' for choice in choices
] # noqa
assert len(batch['inputs']) == 1
image = batch.pop('inputs')[0].unsqueeze(0)
data_sample = batch.pop('data_samples')[0]
questions = data_sample.get('question') + '\n'
choices = data_sample.get('choices')
choices = 'Options: ' + ', '.join(choices) + '.\n'
# generate text prompt
prompt = [
'<img></img>' + question + choice + self.prompt
for question, choice in zip(questions, choices)
]
image_position = 5
prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt # noqa
image_position = prompt.rfind('<img>') + 5
return images, prompt, data_samples, image_position
return image, prompt, data_sample, image_position

View File

@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
if gen_kwargs:
self.gen_kwargs = gen_kwargs
else:
self.gen_kwargs = dict(
max_new_tokens=30,
num_beams=1,
do_sample=False,
repetition_penalty=1.0,
length_penalty=-1.0,
)
self.gen_kwargs = dict(max_length=1024,
min_length=100,
do_sample=True,
temperature=0.8,
top_p=0.4,
top_k=100,
repetition_penalty=1.2)
self.is_caption_task = is_caption_task
def encode_by_tokenizer(self, multi_prompts, image_position):
input_ids = []
max_seq_length = 0
for prompt in multi_prompts:
def encode_by_tokenizer(self, prompt, image_position):
input0 = self.tokenizer.encode(prompt[:image_position],
add_special_tokens=False)
input1 = [self.tokenizer.pad_token_id] * self.model.image_length
input1 = [self.tokenizer.unk_token_id] * self.model.image_length
input2 = self.tokenizer.encode(prompt[image_position:],
add_special_tokens=False)
input_all = sum([input0, input1, input2], [])
input_all = self.tokenizer.build_inputs_with_special_tokens(
input_all)
max_seq_length = max(max_seq_length, len(input_all))
input_ids.append(input_all)
input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
input_all = input_all.unsqueeze(0)
pre_image_len = len(input0)
# padding
for i, _ in enumerate(input_ids):
pad_len = max_seq_length - len(input_ids[i])
input_ids[i] = [self.tokenizer.pad_token_id
] * pad_len + input_ids[i]
return input_ids, pre_image_len
return input_all, pre_image_len
def generate(self, batch):
# process input
@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
input_all, pre_image_len = self.encode_by_tokenizer(
prompt, image_position)
input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
# build input param
inputs = {
'input_ids': input_all,
'pre_image_length': pre_image_len,
'images': image
}
# generate answer
outputs = self.model.generate(**inputs, **self.gen_kwargs)
# format output
outputs = outputs.tolist()
for i, sample in enumerate(data_sample):
answer = self.post_processor(outputs[i], self.tokenizer,
input_all.shape[1])
outputs = outputs.tolist()[0][input_all.shape[1]:]
answer = self.post_processor(outputs, self.tokenizer)
if self.is_caption_task:
data_sample[i].pred_caption = answer
data_sample.pred_caption = answer
else:
data_sample[i].pred_answer = answer
data_sample.pred_answer = answer
return data_sample