2023-07-04 21:34:55 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import hashlib
|
|
|
|
import json
|
|
|
|
from copy import deepcopy
|
2023-07-05 18:29:07 +08:00
|
|
|
from typing import Dict, List, Union
|
2023-07-04 21:34:55 +08:00
|
|
|
|
|
|
|
from mmengine.config import ConfigDict
|
|
|
|
|
|
|
|
|
|
|
|
def safe_format(input_str: str, **kwargs) -> str:
    """Safely formats a string with the given keyword arguments. If a keyword
    is not found in the string, it will be ignored.

    Only ``{key}`` placeholders whose ``key`` is one of ``kwargs`` are
    substituted; every other brace sequence (including placeholders for keys
    that were not supplied) is left untouched.

    All placeholders are substituted in a single pass over ``input_str``.
    Text inserted for one key is therefore never re-scanned for other
    placeholders, and the result does not depend on the order of ``kwargs``.

    Args:
        input_str (str): The string to be formatted.
        **kwargs: The keyword arguments to be used for formatting. Each value
            is converted with ``str()`` before insertion.

    Returns:
        str: The formatted string.
    """
    import re  # local import keeps this helper self-contained

    if not kwargs:
        return input_str
    replacements = {f'{{{k}}}': str(v) for k, v in kwargs.items()}
    # Try longer placeholders first so that, for unusual keys where one
    # placeholder is a prefix of another, the most specific one wins.
    pattern = re.compile('|'.join(
        re.escape(p) for p in sorted(replacements, key=len, reverse=True)))
    # A function replacement inserts values verbatim (no backslash escapes
    # or group references are interpreted).
    return pattern.sub(lambda m: replacements[m.group(0)], input_str)
|
|
|
|
|
|
|
|
|
2023-07-05 18:29:07 +08:00
|
|
|
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    The hash is the SHA-256 hex digest of ``dataset_cfg.infer_cfg``
    serialized as JSON with sorted keys, after this normalization:

    * if ``infer_cfg`` contains a ``reader_cfg`` ("new config"), a
      ``reader`` entry is added, split names are copied into the retriever,
      and every component's ``type`` is reduced to its last dotted segment;
    * ``fix_id_list`` is moved from the retriever to the inferencer.

    The normalization is applied to a deep copy of ``infer_cfg``, so the
    caller's ``dataset_cfg`` is never modified.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration. A single-element list is treated as its only
            element; a longer list is hashed as the SHA-256 of the
            comma-joined hashes of its elements.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join(get_prompt_hash(cfg) for cfg in dataset_cfg)
            return hashlib.sha256(hashes.encode()).hexdigest()

    # Normalize a private copy: the rewrites below must not leak back into
    # the caller's config (e.g. popping fix_id_list off its retriever).
    infer_cfg = deepcopy(dataset_cfg.infer_cfg)

    if 'reader_cfg' in infer_cfg:
        # new config
        infer_cfg.reader = dict(
            type='DatasetReader',
            input_columns=dataset_cfg.reader_cfg.input_columns,
            output_column=dataset_cfg.reader_cfg.output_column)
        if 'train_split' in infer_cfg.reader_cfg:
            infer_cfg.retriever['index_split'] = \
                infer_cfg['reader_cfg']['train_split']
        if 'test_split' in infer_cfg.reader_cfg:
            infer_cfg.retriever['test_split'] = infer_cfg.reader_cfg.test_split
        # Keep only the last dotted segment of each component's `type`.
        for k, v in infer_cfg.items():
            infer_cfg[k]['type'] = v['type'].split('.')[-1]

    # A compromise for the hash consistency: fix_id_list is hashed as part
    # of the inferencer, not the retriever.
    if 'fix_id_list' in infer_cfg.retriever:
        infer_cfg.inferencer['fix_id_list'] = \
            infer_cfg.retriever.pop('fix_id_list')

    d_json = json.dumps(infer_cfg.to_dict(), sort_keys=True)
    return hashlib.sha256(d_json.encode()).hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
class PromptList(list):
    """An enhanced list, used for intermediate representation of a prompt.

    Items are expected to be plain strings or dicts; a dict item's text, if
    it has any, is stored under its ``'prompt'`` key. ``format``, ``replace``
    and ``+`` return a new PromptList and leave ``self`` unchanged; ``+=``
    modifies the list in place.
    """

    def format(self, **kwargs) -> PromptList:
        """Fill ``{key}`` placeholders in every item from ``kwargs``.

        Dict items are deep-copied and only their ``'prompt'`` value (if
        present) is formatted; every other item is passed to
        :func:`safe_format`, so it is expected to be a string. Placeholders
        whose key is not in ``kwargs`` are left as-is.

        Args:
            **kwargs: Placeholder names mapped to their replacement values.

        Returns:
            PromptList: A new, formatted PromptList.
        """
        new_list = PromptList()
        for item in self:
            if isinstance(item, dict):
                new_item = deepcopy(item)
                if 'prompt' in item:
                    new_item['prompt'] = safe_format(item['prompt'], **kwargs)
                new_list.append(new_item)
            else:
                new_list.append(safe_format(item, **kwargs))
        return new_list

    def replace(self, src: str, dst: Union[str, PromptList]) -> PromptList:
        """Replaces all instances of 'src' in the PromptList with 'dst'.

        When ``dst`` is a string it is substituted textually, both in string
        items and in the ``'prompt'`` value of dict items. When ``dst`` is a
        list (typically a PromptList), each string item is split on ``src``
        and the items of ``dst`` are spliced in between the non-empty pieces.

        Args:
            src (str): The string to be replaced.
            dst (str or PromptList): The string or PromptList to replace with.

        Returns:
            PromptList: A new PromptList with 'src' replaced by 'dst'.

        Raises:
            TypeError: If 'dst' is a PromptList and 'src' is in a dictionary's
                'prompt' key, or if 'dst' is neither a string nor a list
                while a string item has to be processed.
        """
        splice = isinstance(dst, list)
        new_list = PromptList()
        for item in self:
            if isinstance(item, str):
                if isinstance(dst, str):
                    new_list.append(item.replace(src, dst))
                elif splice:
                    for i, piece in enumerate(item.split(src)):
                        # Every piece after the first was preceded by `src`.
                        if i:
                            new_list += dst
                        if piece:
                            new_list.append(piece)
                else:
                    # Previously such items were silently dropped.
                    raise TypeError(
                        'Expected dst to be a str or PromptList, got '
                        f'{type(dst).__name__}.')
            elif isinstance(item, dict):
                new_item = deepcopy(item)
                if 'prompt' in item and src in item['prompt']:
                    if splice:
                        raise TypeError(
                            f'Found keyword {src} in a dictionary\'s '
                            'prompt key. Cannot replace with a '
                            'PromptList.')
                    new_item['prompt'] = new_item['prompt'].replace(src, dst)
                new_list.append(new_item)
            else:
                # Any other item must provide its own `replace`.
                new_list.append(item.replace(src, dst))
        return new_list

    def __add__(self, other: Union[str, PromptList]) -> PromptList:
        """Adds a string or another PromptList to this PromptList.

        Args:
            other (str or PromptList): The string or PromptList to be added.
                A falsy value yields a plain copy of ``self``.

        Returns:
            PromptList: A new PromptList that is the result of the addition.
        """
        if not other:
            return PromptList([*self])
        if isinstance(other, str):
            return PromptList([*self, other])
        return PromptList(super().__add__(other))

    def __radd__(self, other: Union[str, PromptList]) -> PromptList:
        """Implements addition when the PromptList is on the right side of the
        '+' operator.

        A falsy ``other`` (e.g. the ``0`` start value of ``sum()``) yields a
        plain copy of ``self``.

        Args:
            other (str or PromptList): The string or list to be prepended.

        Returns:
            PromptList: A new PromptList that is the result of the addition,
            or ``NotImplemented`` for unsupported operand types.
        """
        if not other:
            return PromptList([*self])
        if isinstance(other, str):
            return PromptList([other, *self])
        if isinstance(other, list):
            # Do not evaluate `other + self` here: Python dispatches
            # `list + PromptList` to this very method, so that expression
            # would recurse without end.
            return PromptList([*other, *self])
        return NotImplemented

    def __iadd__(self, other: Union[str, PromptList]) -> PromptList:
        """Implements in-place addition for the PromptList.

        Args:
            other (str or PromptList): The string is appended as one item;
                anything else is extended item by item. A falsy value is
                ignored.

        Returns:
            PromptList: The updated PromptList (``self``).
        """
        if not other:
            return self
        if isinstance(other, str):
            self.append(other)
        else:
            super().__iadd__(other)
        return self

    def __str__(self) -> str:
        """Converts the PromptList into a string.

        String items are concatenated as-is; dict items contribute their
        ``'prompt'`` value, and dicts without one contribute nothing.

        Returns:
            str: The string representation of the PromptList.

        Raises:
            TypeError: If there's an item in the PromptList that is not a
                string or dictionary.
        """
        res = []
        for item in self:
            if isinstance(item, str):
                res.append(item)
            elif isinstance(item, dict):
                if 'prompt' in item:
                    res.append(item['prompt'])
            else:
                raise TypeError('Invalid type in prompt list when '
                                'converting to string')
        return ''.join(res)
|