diff --git a/configs/multimodal/visualglm/visualglm_6b_coco_caption.py b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
index 66e0801c..c2ffa6a9 100644
--- a/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
+++ b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py
@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )

diff --git a/configs/multimodal/visualglm/visualglm_6b_flickr30k.py b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
index 58ab4649..9860ba78 100644
--- a/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
+++ b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py
@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )

diff --git a/opencompass/multimodal/models/visualglm/post_processor.py b/opencompass/multimodal/models/visualglm/post_processor.py
index 4ff3b4a8..8289fc8a 100644
--- a/opencompass/multimodal/models/visualglm/post_processor.py
+++ b/opencompass/multimodal/models/visualglm/post_processor.py
@@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
     def __init__(self) -> None:
         pass

-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        return tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        return tokenizer.decode(output_token)


 class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
     def __init__(self) -> None:
         super().__init__()

-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        output_text = tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        output_text = tokenizer.decode(output_token)
         if 'yes' in output_text.lower():
             return 'yes'
         elif 'no' in output_text.lower():
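The two post_processor.py hunks above drop the input_len parameter: stripping the prompt tokens is now the caller's job (see the visualglm.py hunk at the end of this patch), and the post-processor only decodes what it is handed. Below is a minimal, self-contained sketch of that calling convention; the DummyTokenizer and the token ids are invented for illustration and are not part of the patch.

from typing import Any


class BasePostProcessor:
    """Mirrors the patched VisualGLMBasePostProcessor: decode only."""

    def __call__(self, output_token: Any, tokenizer: Any) -> str:
        return tokenizer.decode(output_token)


class DummyTokenizer:
    """Illustrative stand-in for the real VisualGLM tokenizer."""

    def decode(self, ids):
        return ' '.join(f'<{i}>' for i in ids)


outputs = [11, 12, 13, 14, 15]   # full sequence returned by generate()
input_len = 2                    # number of prompt tokens that were fed in

# New contract: the caller slices off the prompt, the processor decodes.
answer = BasePostProcessor()(outputs[input_len:], DummyTokenizer())
print(answer)  # -> '<13> <14> <15>'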
diff --git a/opencompass/multimodal/models/visualglm/prompt_constructor.py b/opencompass/multimodal/models/visualglm/prompt_constructor.py
index a10b7d77..68bea8e4 100644
--- a/opencompass/multimodal/models/visualglm/prompt_constructor.py
+++ b/opencompass/multimodal/models/visualglm/prompt_constructor.py
@@ -1,24 +1,16 @@
-import torch
-
-
 class VisualGLMMMBenchPromptConstructor:
     """MMBench prompt constructor for VisualGLM.

-    The overall prompt will be formulated as
-    "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
     Args:
         system_prompt (str): System prompt. (Default: '')
         human_prompt (str): Human prompt. (Default: 'Q:')
-        image_prompt (str): Image prompt. (Default: '<img></img>')
         assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """

     def __init__(self,
                  system_prompt: str = '',
                  human_prompt: str = 'Q:',
-                 image_prompt: str = '<img></img>',
                  assistant_prompt: str = 'A:') -> None:
-        self.image_prompt = image_prompt
         self.system_prompt = system_prompt
         self.human_prompt = human_prompt
         self.assistant_prompt = assistant_prompt
@@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        img_prompt = '<img></img>'
+        if data_sample.get('context') is not None:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.context + ' ' + data_sample.question + ' ' + data_sample.options  # noqa
+        else:
+            prompt = img_prompt + self.system_prompt + self.human_prompt + data_sample.question + ' ' + data_sample.options  # noqa
+        prompt += self.assistant_prompt
+        image_position = prompt.rfind('<img>') + 5

-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
-        options = [sample.get('options') for sample in data_samples]
-        contexts = [sample.get('context') for sample in data_samples]
-        contexts = [c if c else '' for c in contexts]
-
-        # generate text prompt
-        prompt = [
-            '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
                                    self.human_prompt, context, question,
                                    option, self.assistant_prompt)
-            for context, question, option in zip(contexts, questions, options)
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position


 class VisualGLMBasePromptConstructor:
@@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
     """Base prompt constructor for VisualGLM.
     The prompt will concat <img> and the given system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """

-    def __init__(self, system_prompt='') -> None:
+    def __init__(self,
+                 system_prompt: str = '',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
         self.prompt = system_prompt
+        self.human_prompt = human_prompt
+        self.assistant_prompt = assistant_prompt

     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]

         # generate text prompt
-        prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
+        prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt  # noqa

-        image_position = 5
+        image_position = prompt.rfind('<img>') + 5

-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position


 class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
     """VQA prompt constructor for VisualGLM.
     The prompt will concat <img>, the question and the system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """

-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)

     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]

         # generate text prompt
-        prompt = [
-            '<img></img>Q:{} {}\nA:'.format(question, self.prompt)
-            for question in questions
-        ]
-        image_position = 5
+        question = data_sample.get('question')
+        prompt = '<img></img>' + self.human_prompt + question + self.prompt
+        prompt += '\n' + self.assistant_prompt

-        return images, prompt, data_samples, image_position
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position


 class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
     """ScienceQA prompt constructor for VisualGLM.
     The prompt will concat image and all terms in a question.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """

     choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)

     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [[
-            f'({self.choice_mapping[i]}) ' + item
-            for i, item in enumerate(choice)
-        ] for choice in choices]
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+
+        questions = 'Question: ' + data_sample.get('question')
+        choices = data_sample.get('choices')
         choices = [
-            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
-        ]  # noqa
-        contexts = [
-            'Context: ' + data_sample.get('hint') + '\n'
-            for data_sample in data_samples
-        ]  # noqa
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choices)
+        ]
+        choices = 'Choices: ' + ' '.join(choices) + '\n'
+        contexts = 'Context: ' + data_sample.get('hint') + '\n'

         # generate text prompt
-        prompt = [
-            '<img></img>' + context + question + choice + self.prompt
-            for context, question, choice in zip(contexts, questions, choices)
-        ]
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5

-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position


 class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
@@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
     """IconQA prompt constructor for VisualGLM.
     The prompt will concat <img>, the question and the system prompt.
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """

-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)

     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
@@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [
-            'Options: ' + ', '.join(choice) + '.\n' for choice in choices
-        ]  # noqa
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+
+        questions = data_sample.get('question') + '\n'
+        choices = data_sample.get('choices')
+        choices = 'Options: ' + ', '.join(choices) + '.\n'

         # generate text prompt
-        prompt = [
-            '<img></img>' + question + choice + self.prompt
-            for question, choice in zip(questions, choices)
-        ]
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5

-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
diff --git a/opencompass/multimodal/models/visualglm/visualglm.py b/opencompass/multimodal/models/visualglm/visualglm.py
index 7187b97e..a9534d94 100644
--- a/opencompass/multimodal/models/visualglm/visualglm.py
+++ b/opencompass/multimodal/models/visualglm/visualglm.py
@@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
         if gen_kwargs:
             self.gen_kwargs = gen_kwargs
         else:
-            self.gen_kwargs = dict(
-                max_new_tokens=30,
-                num_beams=1,
-                do_sample=False,
-                repetition_penalty=1.0,
-                length_penalty=-1.0,
-            )
+            self.gen_kwargs = dict(max_length=1024,
+                                   min_length=100,
+                                   do_sample=True,
+                                   temperature=0.8,
+                                   top_p=0.4,
+                                   top_k=100,
+                                   repetition_penalty=1.2)

         self.is_caption_task = is_caption_task

-    def encode_by_tokenizer(self, multi_prompts, image_position):
-        input_ids = []
-        max_seq_length = 0
-        for prompt in multi_prompts:
-            input0 = self.tokenizer.encode(prompt[:image_position],
-                                           add_special_tokens=False)
-            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
-            input2 = self.tokenizer.encode(prompt[image_position:],
-                                           add_special_tokens=False)
-            input_all = sum([input0, input1, input2], [])
-            input_all = self.tokenizer.build_inputs_with_special_tokens(
-                input_all)
-            max_seq_length = max(max_seq_length, len(input_all))
-            input_ids.append(input_all)
+    def encode_by_tokenizer(self, prompt, image_position):
+
+        input0 = self.tokenizer.encode(prompt[:image_position],
+                                       add_special_tokens=False)
+        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
+        input2 = self.tokenizer.encode(prompt[image_position:],
+                                       add_special_tokens=False)
+        input_all = sum([input0, input1, input2], [])
+        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
+        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
+        input_all = input_all.unsqueeze(0)
+        pre_image_len = len(input0)

-        # padding
-        for i, _ in enumerate(input_ids):
-            pad_len = max_seq_length - len(input_ids[i])
-            input_ids[i] = [self.tokenizer.pad_token_id
-                            ] * pad_len + input_ids[i]
-
-        return input_ids, pre_image_len
+        return input_all, pre_image_len

     def generate(self, batch):
         # process input
@@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
         input_all, pre_image_len = self.encode_by_tokenizer(
             prompt, image_position)

-        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
-
         # build input param
         inputs = {
             'input_ids': input_all,
             'pre_image_length': pre_image_len,
             'images': image
         }
+
         # generate answer
         outputs = self.model.generate(**inputs, **self.gen_kwargs)

         # format output
-        outputs = outputs.tolist()
-        for i, sample in enumerate(data_sample):
-            answer = self.post_processor(outputs[i], self.tokenizer,
-                                         input_all.shape[1])
-            if self.is_caption_task:
-                data_sample[i].pred_caption = answer
-            else:
-                data_sample[i].pred_answer = answer
+        outputs = outputs.tolist()[0][input_all.shape[1]:]
+        answer = self.post_processor(outputs, self.tokenizer)
+
+        if self.is_caption_task:
+            data_sample.pred_caption = answer
+        else:
+            data_sample.pred_answer = answer

         return data_sample
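Taken together with the prompt constructor changes, the patched encode_by_tokenizer/generate path now handles exactly one sample per call: the prompt text is split at image_position, model.image_length placeholder ids are spliced in between the two text segments, and after generation the prompt portion of the output is cut off before the (now slice-free) post-processor decodes it. The sketch below mirrors only that flow; the character-level toy tokenizer, IMAGE_LENGTH and UNK_ID values are invented for the example, and the special-token and tensor handling of the real code is omitted.

IMAGE_LENGTH = 4   # stands in for self.model.image_length
UNK_ID = 0         # stands in for self.tokenizer.unk_token_id


def toy_encode(text):
    """Toy tokenizer: one id per character (illustration only)."""
    return [ord(c) for c in text]


def encode_by_tokenizer(prompt, image_position):
    input0 = toy_encode(prompt[:image_position])   # text before the image slot
    input1 = [UNK_ID] * IMAGE_LENGTH               # placeholders later filled by image features
    input2 = toy_encode(prompt[image_position:])   # text after the image slot
    input_all = input0 + input1 + input2
    return input_all, len(input0)                  # (input ids, pre_image_len)


prompt = '<img></img>Q:Describe the image.A:'
image_position = prompt.rfind('<img>') + 5
input_all, pre_image_len = encode_by_tokenizer(prompt, image_position)

# generate() returns the prompt followed by the continuation; fake a continuation
# here and strip the prompt exactly as the patched generate() does.
outputs = input_all + toy_encode('A dog on the grass.')
continuation = outputs[len(input_all):]
print(''.join(chr(i) for i in continuation))  # -> 'A dog on the grass.'
print(pre_image_len)                          # 5 -> tokens before the image slot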