adapt to lmdeploy v0.4.0 (#1073)

* adapt to lmdeploy v0.4.0

* compatible
This commit is contained in:
Lyu Han 2024-04-28 19:57:40 +08:00 committed by GitHub
parent 58a57a4c45
commit 1013dce60c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 4 deletions

View File

@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel):
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.pytorch import engine as tm
from lmdeploy.version import version_info
if engine_config is not None:
from lmdeploy.messages import PytorchEngineConfig
@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel):
self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config
self.end_str = end_str
self.major_version, self.minor_version, _ = version_info
def generate(
self,
@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel):
assert type(
prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt)
_, output_ids, _ = generator.infer(session_id,
input_ids,
gen_config=gen_config)
if self.major_version >= 0 and self.minor_version >= 4:
outputs = generator.infer(session_id,
input_ids,
gen_config=gen_config)
output_ids = outputs.token_ids
else:
_, output_ids, _ = generator.infer(session_id,
input_ids,
gen_config=gen_config)
# stop engine
if hasattr(generator, 'end'):
generator.end(session_id)

View File

@ -54,6 +54,7 @@ class TurboMindModel(BaseModel):
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.turbomind import TurboMind
from lmdeploy.version import version_info
if engine_config is not None:
from lmdeploy.messages import TurbomindEngineConfig
@ -70,6 +71,7 @@ class TurboMindModel(BaseModel):
self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config
self.end_str = end_str
self.major_version, self.minor_version, _ = version_info
def generate(self,
inputs: List[str],
@ -165,7 +167,10 @@ class TurboMindModel(BaseModel):
sequence_end=True,
step=0,
stream_output=False):
_, output_ids, _ = outputs
if self.major_version >= 0 and self.minor_version >= 4:
output_ids = outputs.token_ids
else:
_, output_ids, _ = outputs
response = self.tokenizer.decode(output_ids)
response = valid_str(response)
# used to trim