[Fix] Fix performance issue of visualglm. (#424)

* [Fix] Visualglm performance fixed.

* [Fix] Hide ckpt path.
Yike Yuan 2023-09-21 19:54:23 +08:00 committed by GitHub
parent 8803f7f7a6
commit 97fdc51102
5 changed files with 113 additions and 135 deletions

Changed file 1/5: VisualGLM COCO Caption model config (visualglm_coco_caption_model)

@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )

Changed file 2/5: VisualGLM Flickr30k model config (visualglm_flickr30k_model)

@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )
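For reference, a minimal sketch of the caption prompt this new system_prompt produces once it reaches the reworked VisualGLMBasePromptConstructor (shown later in this commit); the lines below simply trace that constructor's logic with its default human/assistant prompts:

    # Sketch only: how the caption prompt is assembled under the new base constructor.
    system_prompt = 'Describe the image.'          # value set by this config change
    human_prompt, assistant_prompt = 'Q:', 'A:'    # constructor defaults
    prompt = '<img></img>' + human_prompt + system_prompt + assistant_prompt
    # -> '<img></img>Q:Describe the image.A:'
    image_position = prompt.rfind('<img>') + 5     # 5, i.e. right after the '<img>' tag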

Changed file 3/5: VisualGLM post processors

@@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
     def __init__(self) -> None:
         pass
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        return tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        return tokenizer.decode(output_token)
 
 
 class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
     def __init__(self) -> None:
         super().__init__()
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        output_text = tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        output_text = tokenizer.decode(output_token)
         if 'yes' in output_text.lower():
             return 'yes'
         elif 'no' in output_text.lower():
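Why input_len disappears from both post processors: the model wrapper (last file in this commit) now slices the prompt tokens off the generated sequence before calling the post processor, so the processor only ever receives freshly generated tokens. A minimal sketch of the new call convention, using the same names as the wrapper code below:

    outputs = outputs.tolist()[0][input_all.shape[1]:]    # drop the prompt, keep only generated tokens
    answer = self.post_processor(outputs, self.tokenizer) # input_len argument no longer needed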

Changed file 4/5: VisualGLM prompt constructors

@@ -1,24 +1,16 @@
-import torch
-
 class VisualGLMMMBenchPromptConstructor:
     """MMBench prompt constructor for VisualGLM.
 
-    The overall prompt will be formulated as
-    "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
-
     Args:
         system_prompt (str): System prompt. (Default: '')
         human_prompt (str): Human prompt. (Default: 'Q:')
-        image_prompt (str): Image prompt. (Default: '<img></img>')
         assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
     def __init__(self,
                  system_prompt: str = '',
                  human_prompt: str = 'Q:',
-                 image_prompt: str = '<img></img>',
                  assistant_prompt: str = 'A:') -> None:
-        self.image_prompt = image_prompt
         self.system_prompt = system_prompt
         self.human_prompt = human_prompt
         self.assistant_prompt = assistant_prompt
@@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
             A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
-        options = [sample.get('options') for sample in data_samples]
-        contexts = [sample.get('context') for sample in data_samples]
-        contexts = [c if c else '' for c in contexts]
-
-        # generate text prompt
-        prompt = [
-            '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
-                                    self.human_prompt, context, question,
-                                    option, self.assistant_prompt)
-            for context, question, option in zip(contexts, questions, options)
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        img_prompt = '<img></img>'
+        if data_sample.get('context') is not None:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options  # noqa
+        else:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options  # noqa
+        prompt += self.assistant_prompt
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMBasePromptConstructor:
@@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
     The prompt will concat <img> and the given system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
+    def __init__(self,
+                 system_prompt: str = '',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
         self.prompt = system_prompt
+        self.human_prompt = human_prompt
+        self.assistant_prompt = assistant_prompt
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
             A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
 
         # generate text prompt
-        prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
+        prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt  # noqa
 
-        image_position = 5
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
 
         # generate text prompt
-        prompt = [
-            '<img></img>Q:{} {}\nA:'.format(question, self.prompt)
-            for question in questions
-        ]
-
-        image_position = 5
+        question = data_sample.get('question')
+        prompt = '<img></img>' + self.human_prompt + question + self.prompt
+        prompt += '\n' + self.assistant_prompt
 
-        return images, prompt, data_samples, image_position
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat image and all terms in a question.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
     choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [[
-            f'({self.choice_mapping[i]}) ' + item
-            for i, item in enumerate(choice)
-        ] for choice in choices]
-        choices = [
-            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
-        ]  # noqa
-        contexts = [
-            'Context: ' + data_sample.get('hint') + '\n'
-            for data_sample in data_samples
-        ]  # noqa
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+
+        questions = 'Question: ' + data_sample.get('question')
+        choices = data_sample.get('choices')
+        choices = [
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choices)
+        ]
+        choices = 'Choices: ' + ' '.join(choices) + '\n'
+        contexts = 'Context: ' + data_sample.get('hint') + '\n'
 
         # generate text prompt
-        prompt = [
-            '<img></img>' + context + question + choice + self.prompt
-            for context, question, choice in zip(contexts, questions, choices)
-        ]
-
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
 
 
 class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
 
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [
-            'Options: ' + ', '.join(choice) + '.\n' for choice in choices
-        ]  # noqa
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+
+        questions = data_sample.get('question') + '\n'
+        choices = data_sample.get('choices')
+        choices = 'Options: ' + ', '.join(choices) + '.\n'
 
         # generate text prompt
-        prompt = [
-            '<img></img>' + question + choice + self.prompt
-            for question, choice in zip(questions, choices)
-        ]
-
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
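To make the new single-sample behaviour concrete, here is a worked trace of the reworked VQA constructor, assuming an empty system prompt and a hypothetical question string; the values are illustrative only, and the steps mirror the new __call__ above:

    # Hypothetical sample; constructor defaults are human_prompt='Q:', assistant_prompt='A:'.
    question = 'What is on the table?'
    prompt = '<img></img>' + 'Q:' + question + ''   # system prompt is empty here
    prompt += '\n' + 'A:'
    # prompt == '<img></img>Q:What is on the table?\nA:'
    image_position = prompt.rfind('<img>') + 5      # 5: image tokens are spliced in right after '<img>'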

Changed file 5/5: VisualGLM model wrapper (VisualGLM class)

@@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
         if gen_kwargs:
             self.gen_kwargs = gen_kwargs
         else:
-            self.gen_kwargs = dict(
-                max_new_tokens=30,
-                num_beams=1,
-                do_sample=False,
-                repetition_penalty=1.0,
-                length_penalty=-1.0,
-            )
+            self.gen_kwargs = dict(max_length=1024,
+                                   min_length=100,
+                                   do_sample=True,
+                                   temperature=0.8,
+                                   top_p=0.4,
+                                   top_k=100,
+                                   repetition_penalty=1.2)
 
         self.is_caption_task = is_caption_task
 
-    def encode_by_tokenizer(self, multi_prompts, image_position):
-        input_ids = []
-        max_seq_length = 0
-        for prompt in multi_prompts:
-            input0 = self.tokenizer.encode(prompt[:image_position],
-                                           add_special_tokens=False)
-            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
-            input2 = self.tokenizer.encode(prompt[image_position:],
-                                           add_special_tokens=False)
-            input_all = sum([input0, input1, input2], [])
-            input_all = self.tokenizer.build_inputs_with_special_tokens(
-                input_all)
-            max_seq_length = max(max_seq_length, len(input_all))
-            input_ids.append(input_all)
-        pre_image_len = len(input0)
-
-        # padding
-        for i, _ in enumerate(input_ids):
-            pad_len = max_seq_length - len(input_ids[i])
-            input_ids[i] = [self.tokenizer.pad_token_id
-                            ] * pad_len + input_ids[i]
-
-        return input_ids, pre_image_len
+    def encode_by_tokenizer(self, prompt, image_position):
+        input0 = self.tokenizer.encode(prompt[:image_position],
+                                       add_special_tokens=False)
+        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
+        input2 = self.tokenizer.encode(prompt[image_position:],
+                                       add_special_tokens=False)
+        input_all = sum([input0, input1, input2], [])
+        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
+        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
+        input_all = input_all.unsqueeze(0)
+        pre_image_len = len(input0)
+
+        return input_all, pre_image_len
 
     def generate(self, batch):
         # process input
@@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
         input_all, pre_image_len = self.encode_by_tokenizer(
             prompt, image_position)
-        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
 
         # build input param
         inputs = {
             'input_ids': input_all,
             'pre_image_length': pre_image_len,
             'images': image
         }
 
         # generate answer
         outputs = self.model.generate(**inputs, **self.gen_kwargs)
 
         # format output
-        outputs = outputs.tolist()
-        for i, sample in enumerate(data_sample):
-            answer = self.post_processor(outputs[i], self.tokenizer,
-                                         input_all.shape[1])
-            if self.is_caption_task:
-                data_sample[i].pred_caption = answer
-            else:
-                data_sample[i].pred_answer = answer
+        outputs = outputs.tolist()[0][input_all.shape[1]:]
+        answer = self.post_processor(outputs, self.tokenizer)
+        if self.is_caption_task:
+            data_sample.pred_caption = answer
+        else:
+            data_sample.pred_answer = answer
 
         return data_sample
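For readers new to this wrapper, a paraphrased sketch of the input layout the reworked encode_by_tokenizer builds for a single prompt; the comments are explanatory additions, and image_length and the special tokens depend on the loaded checkpoint:

    # prompt[:image_position]  -> text tokens before the image slot (e.g. '<img>')
    # unk placeholders         -> one slot per visual feature, later filled by the model
    # prompt[image_position:]  -> the rest of the text prompt ('</img>Q:...A:')
    input0 = tokenizer.encode(prompt[:image_position], add_special_tokens=False)
    input1 = [tokenizer.unk_token_id] * model.image_length
    input2 = tokenizer.encode(prompt[image_position:], add_special_tokens=False)
    input_all = tokenizer.build_inputs_with_special_tokens(input0 + input1 + input2)
    pre_image_len = len(input0)   # passed to generate() as pre_image_length, marking where the image slots begin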