[Fix] meta template & unit tests (#170)

This commit is contained in:
Tong Gao 2023-08-10 16:49:13 +08:00 committed by GitHub
parent ed248af136
commit 312095de9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 43 deletions

View File

@ -245,11 +245,14 @@ class LMTemplateParser:
section_stack.append((item['section'], i + 1))
else:
raise ValueError(f'Invalid pos {item["pos"]}')
# if in "begin" or "end" section
elif section_stack[-1][0] in ['begin', 'end']:
role_dict = self._update_role_dict(item)
new_str, generate = self._prompt2str(item,
role_dict,
for_gen=mode == 'gen')
new_str, generate = self._prompt2str(
item,
role_dict,
# never stop generation
for_gen=False)
prompt += new_str
prompt = self.meta_template.get('begin', '') + prompt

View File

@ -12,37 +12,29 @@ class TestPromptTemplate(unittest.TestCase):
'</E>',
],
round=[
dict(role='HUMAN', prompt='</input>'),
dict(role='BOT',
prompt='Answer: </answer>')
dict(role='HUMAN', prompt='{input}'),
dict(role='BOT', prompt='Answer: {answer}')
])
self.multiround_qa_template = dict(round=[
dict(role='HUMAN', prompt='</input>'),
dict(role='HUMAN', prompt='{input}'),
dict(role='BOT', prompt='A1', end='\n'),
dict(role='HUMAN', prompt='Q1'),
dict(role='BOT', prompt='A2', end='\n\n'),
dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
dict(role='BOT', prompt='Answer: </answer>')
dict(role='BOT', prompt='Answer: {answer}')
])
self.column_token_map = {
'input': '</input>',
'answer': '</answer>',
}
self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'}
def test_init(self):
template = 'Translate the following English text to French: {t}.'
column_token_map = {'input': '{t}'}
pt = PromptTemplate(template, column_token_map)
template = 'Translate the following English text to French: {input}.'
pt = PromptTemplate(template)
self.assertEqual(pt.template, template)
self.assertEqual(pt.column_token_map, column_token_map)
def test_generate_ice_item(self):
# Test simple prompt
template = 'Translate the following English text to French: {t}.'
column_token_map = {'input': '{t}'}
pt = PromptTemplate(template, column_token_map)
template = 'Translate the following English text to French: {input}.'
pt = PromptTemplate(template)
label = None
ice = pt.generate_ice_item(self.entry, label)
@ -51,9 +43,7 @@ class TestPromptTemplate(unittest.TestCase):
'Hello, how are you?.'))
# test meta prompt style
pt = PromptTemplate(self.qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.qa_template, ice_token='</E>')
label = None
ice = pt.generate_ice_item(self.entry, label)
@ -72,9 +62,7 @@ class TestPromptTemplate(unittest.TestCase):
self.assertEqual(ice, ice_target)
# test_multiround
pt = PromptTemplate(self.multiround_qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
label = None
ice = pt.generate_ice_item(self.entry, label)
@ -98,9 +86,9 @@ class TestPromptTemplate(unittest.TestCase):
def test_generate_label_prompt_item(self):
# Test simple prompt
template = '</E> Translate the following English text to French: {t}.'
column_token_map = {'input': '{t}'}
pt = PromptTemplate(template, column_token_map, ice_token='</E>')
template = ('</E> Translate the following English text to French: '
'{input}.')
pt = PromptTemplate(template, ice_token='</E>')
ice = 'ICE'
label = None
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
@ -123,9 +111,7 @@ class TestPromptTemplate(unittest.TestCase):
])
# test meta prompt style
pt = PromptTemplate(self.qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.qa_template, ice_token='</E>')
label = None
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
target = PromptList([
@ -162,9 +148,7 @@ class TestPromptTemplate(unittest.TestCase):
self.assertEqual(prompt, target)
# test_multiround
pt = PromptTemplate(self.multiround_qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
label = None
prompt = pt.generate_label_prompt_item(self.entry, ice, label)
target = PromptList([
@ -187,9 +171,8 @@ class TestPromptTemplate(unittest.TestCase):
def test_generate_item(self):
# Test simple prompt
template = 'Translate the following English text to French: {t}.'
column_token_map = {'input': '{t}'}
pt = PromptTemplate(template, column_token_map)
template = 'Translate the following English text to French: {input}.'
pt = PromptTemplate(template)
item = pt.generate_item(self.entry)
self.assertEqual(item,
@ -210,9 +193,7 @@ class TestPromptTemplate(unittest.TestCase):
])
# test meta prompt (without system role)
pt = PromptTemplate(self.qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.qa_template, ice_token='</E>')
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
target = PromptList([
{
@ -247,9 +228,7 @@ class TestPromptTemplate(unittest.TestCase):
])
self.assertEqual(prompt, target)
pt = PromptTemplate(self.multiround_qa_template,
self.column_token_map,
ice_token='</E>')
pt = PromptTemplate(self.multiround_qa_template, ice_token='</E>')
prompt = pt.generate_item(self.entry, ice_field_replace_token=ice)
target = PromptList([
{