diff --git a/opencompass/models/base.py b/opencompass/models/base.py index 6a560ce8..ec29a8f3 100644 --- a/opencompass/models/base.py +++ b/opencompass/models/base.py @@ -245,11 +245,14 @@ class LMTemplateParser: section_stack.append((item['section'], i + 1)) else: raise ValueError(f'Invalid pos {item["pos"]}') + # if in "begin" or "end" section elif section_stack[-1][0] in ['begin', 'end']: role_dict = self._update_role_dict(item) - new_str, generate = self._prompt2str(item, - role_dict, - for_gen=mode == 'gen') + new_str, generate = self._prompt2str( + item, + role_dict, + # never stop generation + for_gen=False) prompt += new_str prompt = self.meta_template.get('begin', '') + prompt diff --git a/tests/openicl/test_prompt_template.py b/tests/openicl/test_prompt_template.py index a831a20f..0faf00bf 100644 --- a/tests/openicl/test_prompt_template.py +++ b/tests/openicl/test_prompt_template.py @@ -12,37 +12,29 @@ class TestPromptTemplate(unittest.TestCase): '', ], round=[ - dict(role='HUMAN', prompt=''), - dict(role='BOT', - prompt='Answer: ') + dict(role='HUMAN', prompt='{input}'), + dict(role='BOT', prompt='Answer: {answer}') ]) self.multiround_qa_template = dict(round=[ - dict(role='HUMAN', prompt=''), + dict(role='HUMAN', prompt='{input}'), dict(role='BOT', prompt='A1', end='\n'), dict(role='HUMAN', prompt='Q1'), dict(role='BOT', prompt='A2', end='\n\n'), dict(role='HUMAN', prompt='Q2', begin='HUMAN:'), - dict(role='BOT', prompt='Answer: ') + dict(role='BOT', prompt='Answer: {answer}') ]) - self.column_token_map = { - 'input': '', - 'answer': '', - } self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'} def test_init(self): - template = 'Translate the following English text to French: {t}.' - column_token_map = {'input': '{t}'} - pt = PromptTemplate(template, column_token_map) + template = 'Translate the following English text to French: {input}.' + pt = PromptTemplate(template) self.assertEqual(pt.template, template) - self.assertEqual(pt.column_token_map, column_token_map) def test_generate_ice_item(self): # Test simple prompt - template = 'Translate the following English text to French: {t}.' - column_token_map = {'input': '{t}'} - pt = PromptTemplate(template, column_token_map) + template = 'Translate the following English text to French: {input}.' + pt = PromptTemplate(template) label = None ice = pt.generate_ice_item(self.entry, label) @@ -51,9 +43,7 @@ class TestPromptTemplate(unittest.TestCase): 'Hello, how are you?.')) # test meta prompt style - pt = PromptTemplate(self.qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.qa_template, ice_token='') label = None ice = pt.generate_ice_item(self.entry, label) @@ -72,9 +62,7 @@ class TestPromptTemplate(unittest.TestCase): self.assertEqual(ice, ice_target) # test_multiround - pt = PromptTemplate(self.multiround_qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.multiround_qa_template, ice_token='') label = None ice = pt.generate_ice_item(self.entry, label) @@ -98,9 +86,9 @@ class TestPromptTemplate(unittest.TestCase): def test_generate_label_prompt_item(self): # Test simple prompt - template = ' Translate the following English text to French: {t}.' - column_token_map = {'input': '{t}'} - pt = PromptTemplate(template, column_token_map, ice_token='') + template = (' Translate the following English text to French: ' + '{input}.') + pt = PromptTemplate(template, ice_token='') ice = 'ICE' label = None prompt = pt.generate_label_prompt_item(self.entry, ice, label) @@ -123,9 +111,7 @@ class TestPromptTemplate(unittest.TestCase): ]) # test meta prompt style - pt = PromptTemplate(self.qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.qa_template, ice_token='') label = None prompt = pt.generate_label_prompt_item(self.entry, ice, label) target = PromptList([ @@ -162,9 +148,7 @@ class TestPromptTemplate(unittest.TestCase): self.assertEqual(prompt, target) # test_multiround - pt = PromptTemplate(self.multiround_qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.multiround_qa_template, ice_token='') label = None prompt = pt.generate_label_prompt_item(self.entry, ice, label) target = PromptList([ @@ -187,9 +171,8 @@ class TestPromptTemplate(unittest.TestCase): def test_generate_item(self): # Test simple prompt - template = 'Translate the following English text to French: {t}.' - column_token_map = {'input': '{t}'} - pt = PromptTemplate(template, column_token_map) + template = 'Translate the following English text to French: {input}.' + pt = PromptTemplate(template) item = pt.generate_item(self.entry) self.assertEqual(item, @@ -210,9 +193,7 @@ class TestPromptTemplate(unittest.TestCase): ]) # test meta prompt (without system role) - pt = PromptTemplate(self.qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.qa_template, ice_token='') prompt = pt.generate_item(self.entry, ice_field_replace_token=ice) target = PromptList([ { @@ -247,9 +228,7 @@ class TestPromptTemplate(unittest.TestCase): ]) self.assertEqual(prompt, target) - pt = PromptTemplate(self.multiround_qa_template, - self.column_token_map, - ice_token='') + pt = PromptTemplate(self.multiround_qa_template, ice_token='') prompt = pt.generate_item(self.entry, ice_field_replace_token=ice) target = PromptList([ {