mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] meta template & unit tests (#170)
This commit is contained in:
parent
ed248af136
commit
312095de9d
@ -245,11 +245,14 @@ class LMTemplateParser:
|
|||||||
section_stack.append((item['section'], i + 1))
|
section_stack.append((item['section'], i + 1))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Invalid pos {item["pos"]}')
|
raise ValueError(f'Invalid pos {item["pos"]}')
|
||||||
|
# if in "begin" or "end" section
|
||||||
elif section_stack[-1][0] in ['begin', 'end']:
|
elif section_stack[-1][0] in ['begin', 'end']:
|
||||||
role_dict = self._update_role_dict(item)
|
role_dict = self._update_role_dict(item)
|
||||||
new_str, generate = self._prompt2str(item,
|
new_str, generate = self._prompt2str(
|
||||||
|
item,
|
||||||
role_dict,
|
role_dict,
|
||||||
for_gen=mode == 'gen')
|
# never stop generation
|
||||||
|
for_gen=False)
|
||||||
prompt += new_str
|
prompt += new_str
|
||||||
|
|
||||||
prompt = self.meta_template.get('begin', '') + prompt
|
prompt = self.meta_template.get('begin', '') + prompt
|
||||||
|
@ -12,37 +12,29 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
'</E>',
|
'</E>',
|
||||||
],
|
],
|
||||||
round=[
|
round=[
|
||||||
dict(role='HUMAN', prompt='</input>'),
|
dict(role='HUMAN', prompt='{input}'),
|
||||||
dict(role='BOT',
|
dict(role='BOT', prompt='Answer: {answer}')
|
||||||
prompt='Answer: </answer>')
|
|
||||||
])
|
])
|
||||||
self.multiround_qa_template = dict(round=[
|
self.multiround_qa_template = dict(round=[
|
||||||
dict(role='HUMAN', prompt='</input>'),
|
dict(role='HUMAN', prompt='{input}'),
|
||||||
dict(role='BOT', prompt='A1', end='\n'),
|
dict(role='BOT', prompt='A1', end='\n'),
|
||||||
dict(role='HUMAN', prompt='Q1'),
|
dict(role='HUMAN', prompt='Q1'),
|
||||||
dict(role='BOT', prompt='A2', end='\n\n'),
|
dict(role='BOT', prompt='A2', end='\n\n'),
|
||||||
dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
|
dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
|
||||||
dict(role='BOT', prompt='Answer: </answer>')
|
dict(role='BOT', prompt='Answer: {answer}')
|
||||||
])
|
])
|
||||||
self.column_token_map = {
|
|
||||||
'input': '</input>',
|
|
||||||
'answer': '</answer>',
|
|
||||||
}
|
|
||||||
self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'}
|
self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'}
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
template = 'Translate the following English text to French: {t}.'
|
template = 'Translate the following English text to French: {input}.'
|
||||||
column_token_map = {'input': '{t}'}
|
pt = PromptTemplate(template)
|
||||||
pt = PromptTemplate(template, column_token_map)
|
|
||||||
|
|
||||||
self.assertEqual(pt.template, template)
|
self.assertEqual(pt.template, template)
|
||||||
self.assertEqual(pt.column_token_map, column_token_map)
|
|
||||||
|
|
||||||
def test_generate_ice_item(self):
|
def test_generate_ice_item(self):
|
||||||
# Test simple prompt
|
# Test simple prompt
|
||||||
template = 'Translate the following English text to French: {t}.'
|
template = 'Translate the following English text to French: {input}.'
|
||||||
column_token_map = {'input': '{t}'}
|
pt = PromptTemplate(template)
|
||||||
pt = PromptTemplate(template, column_token_map)
|
|
||||||
label = None
|
label = None
|
||||||
ice = pt.generate_ice_item(self.entry, label)
|
ice = pt.generate_ice_item(self.entry, label)
|
||||||
|
|
||||||
@ -51,9 +43,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
'Hello, how are you?.'))
|
'Hello, how are you?.'))
|
||||||
|
|
||||||
# test meta prompt style
|
# test meta prompt style
|
||||||
pt = PromptTemplate(self.qa_template,
|
pt = PromptTemplate(self.qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
label = None
|
label = None
|
||||||
ice = pt.generate_ice_item(self.entry, label)
|
ice = pt.generate_ice_item(self.entry, label)
|
||||||
|
|
||||||
@ -72,9 +62,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
self.assertEqual(ice, ice_target)
|
self.assertEqual(ice, ice_target)
|
||||||
|
|
||||||
# test_multiround
|
# test_multiround
|
||||||
pt = PromptTemplate(self.multiround_qa_template,
|
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
label = None
|
label = None
|
||||||
ice = pt.generate_ice_item(self.entry, label)
|
ice = pt.generate_ice_item(self.entry, label)
|
||||||
|
|
||||||
@ -98,9 +86,9 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
|
|
||||||
def test_generate_label_prompt_item(self):
|
def test_generate_label_prompt_item(self):
|
||||||
# Test simple prompt
|
# Test simple prompt
|
||||||
template = '</E> Translate the following English text to French: {t}.'
|
template = ('</E> Translate the following English text to French: '
|
||||||
column_token_map = {'input': '{t}'}
|
'{input}.')
|
||||||
pt = PromptTemplate(template, column_token_map, ice_token='</E>')
|
pt = PromptTemplate(template, ice_token='</E>')
|
||||||
ice = 'ICE'
|
ice = 'ICE'
|
||||||
label = None
|
label = None
|
||||||
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
||||||
@ -123,9 +111,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# test meta prompt style
|
# test meta prompt style
|
||||||
pt = PromptTemplate(self.qa_template,
|
pt = PromptTemplate(self.qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
label = None
|
label = None
|
||||||
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
||||||
target = PromptList([
|
target = PromptList([
|
||||||
@ -162,9 +148,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
self.assertEqual(prompt, target)
|
self.assertEqual(prompt, target)
|
||||||
|
|
||||||
# test_multiround
|
# test_multiround
|
||||||
pt = PromptTemplate(self.multiround_qa_template,
|
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
label = None
|
label = None
|
||||||
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
|
||||||
target = PromptList([
|
target = PromptList([
|
||||||
@ -187,9 +171,8 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
|
|
||||||
def test_generate_item(self):
|
def test_generate_item(self):
|
||||||
# Test simple prompt
|
# Test simple prompt
|
||||||
template = 'Translate the following English text to French: {t}.'
|
template = 'Translate the following English text to French: {input}.'
|
||||||
column_token_map = {'input': '{t}'}
|
pt = PromptTemplate(template)
|
||||||
pt = PromptTemplate(template, column_token_map)
|
|
||||||
item = pt.generate_item(self.entry)
|
item = pt.generate_item(self.entry)
|
||||||
|
|
||||||
self.assertEqual(item,
|
self.assertEqual(item,
|
||||||
@ -210,9 +193,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# test meta prompt (without system role)
|
# test meta prompt (without system role)
|
||||||
pt = PromptTemplate(self.qa_template,
|
pt = PromptTemplate(self.qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
|
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
|
||||||
target = PromptList([
|
target = PromptList([
|
||||||
{
|
{
|
||||||
@ -247,9 +228,7 @@ class TestPromptTemplate(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
self.assertEqual(prompt, target)
|
self.assertEqual(prompt, target)
|
||||||
|
|
||||||
pt = PromptTemplate(self.multiround_qa_template,
|
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
|
||||||
self.column_token_map,
|
|
||||||
ice_token='</E>')
|
|
||||||
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
|
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
|
||||||
target = PromptList([
|
target = PromptList([
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user