[Fix] bin_trim (#237)

Co-authored-by: wangchonghua <wangchonghua@pjlab.org.cn>
This commit is contained in:
philipwangOvO 2023-08-21 15:44:49 +08:00 committed by GitHub
parent 655a807f4b
commit 3b29aaee2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -281,18 +281,20 @@ class OpenAI(BaseAPIModel):
pattern = re.compile(r'[\u4e00-\u9fa5]')
if pattern.search(prompt):
words = list(jieba.cut(prompt, cut_all=False))
sep = ''
else:
words = prompt.split(' ')
sep = ' '
l, r = 1, len(words)
while l + 2 < r:
mid = (l + r) // 2
if self.mode == 'front':
cur_prompt = ' '.join(words[-mid:])
cur_prompt = sep.join(words[-mid:])
elif self.mode == 'mid':
cur_prompt = ' '.join(words[:mid]) + ' '.join(words[-mid:])
cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
elif self.mode == 'rear':
cur_prompt = ' '.join(words[:mid])
cur_prompt = sep.join(words[:mid])
if self.get_token_len(cur_prompt) <= num_token:
l = mid # noqa: E741
@ -300,9 +302,9 @@ class OpenAI(BaseAPIModel):
r = mid
if self.mode == 'front':
prompt = ' '.join(words[-l:])
prompt = sep.join(words[-l:])
elif self.mode == 'mid':
prompt = ' '.join(words[:l]) + ' '.join(words[-l:])
prompt = sep.join(words[:l]) + sep.join(words[-l:])
elif self.mode == 'rear':
prompt = ' '.join(words[:l])
prompt = sep.join(words[:l])
return prompt