# Mirror of https://github.com/open-compass/opencompass.git
# Synced 2025-05-30 16:03:24 +08:00
# 77 lines, 2.3 KiB, Python
import unittest
|
||
|
|
||
|
from opencompass.utils.prompt import PromptList
|
||
|
|
||
|
|
||
|
class TestPromptList(unittest.TestCase):
    """Unit tests for ``PromptList``.

    Covers construction, ``format``, ``replace``, the ``+`` / ``+=``
    operators and ``str()`` conversion.
    """

    def test_initialization(self):
        """A PromptList compares equal to a plain list with the same items."""
        empty = PromptList()
        self.assertEqual(empty, [])

        populated = PromptList(['test', '123'])
        self.assertEqual(populated, ['test', '123'])

    def test_format(self):
        """``format`` fills the placeholders it is given values for.

        Placeholders without a matching keyword are expected to stay as-is,
        and keywords without a matching placeholder are expected to be
        ignored. Every case formats the same source list, so the expectations
        also rely on ``format`` leaving that list unchanged.
        """
        template = PromptList(['hi {a}{b}', {'prompt': 'hey {a}!'}, '123'])

        # (keyword arguments, expected result), checked in this order.
        cases = [
            ({'a': 1, 'b': 2, 'c': 3},
             ['hi 12', {'prompt': 'hey 1!'}, '123']),
            ({'b': 2},
             ['hi {a}2', {'prompt': 'hey {a}!'}, '123']),
            ({'d': 1},
             ['hi {a}{b}', {'prompt': 'hey {a}!'}, '123']),
        ]
        for kwargs, expected in cases:
            formatted = template.format(**kwargs)
            self.assertEqual(formatted, expected)

    def test_replace(self):
        """``replace`` substitutes text in string items and dict prompts.

        When the replacement is itself a PromptList, its items are expected
        to be spliced into the result at the position of the match.
        """
        source = PromptList(['hello world', {'prompt': 'hello world'}, '123'])

        # Plain string replacement touches both the str item and the dict.
        text_swapped = source.replace('world', 'there')
        expected_text = ['hello there', {'prompt': 'hello there'}, '123']
        self.assertEqual(text_swapped, expected_text)

        # Replacing a whole string item with a PromptList.
        item_swapped = source.replace('123', PromptList(['p', {'role': 'BOT'}]))
        expected_item = [
            'hello world',
            {'prompt': 'hello world'},
            'p',
            {'role': 'BOT'},
        ]
        self.assertEqual(item_swapped, expected_item)

        # Replacing part of a string item splits it around the match.
        split_swapped = source.replace('2', PromptList(['p', {'role': 'BOT'}]))
        expected_split = [
            'hello world',
            {'prompt': 'hello world'},
            '1',
            'p',
            {'role': 'BOT'},
            '3',
        ]
        self.assertEqual(split_swapped, expected_split)

        # NOTE(review): 'world' also occurs inside the dict's prompt, which is
        # presumably what makes a PromptList replacement invalid here.
        with self.assertRaises(TypeError):
            source.replace('world', PromptList(['new', 'world']))

    def test_add(self):
        """``+``, reflected ``+`` and ``+=`` append or prepend items."""
        greeting = PromptList(['hello'])

        # PromptList + str
        self.assertEqual(greeting + ' world', ['hello', ' world'])

        # PromptList + PromptList
        suffix = PromptList([' world'])
        self.assertEqual(greeting + suffix, ['hello', ' world'])

        # str + PromptList
        self.assertEqual('hi, ' + greeting, ['hi, ', 'hello'])

        # In-place addition extends the original list.
        greeting += '!'
        self.assertEqual(greeting, ['hello', '!'])

    def test_str(self):
        """``str()`` joins string items and dict ``prompt`` values."""
        prompts = PromptList(['hello', ' world', {'prompt': '!'}])
        rendered = str(prompts)
        self.assertEqual(rendered, 'hello world!')

        # An int item is expected to be rejected with TypeError; construction
        # and conversion both stay inside the context manager.
        with self.assertRaises(TypeError):
            invalid = PromptList(['hello', ' world', 123])
            str(invalid)