Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
adapt to lmdeploy v0.4.0 (#1073)
* adapt to lmdeploy v0.4.0
* compatible
parent 58a57a4c45
commit 1013dce60c
@@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.pytorch import engine as tm
+        from lmdeploy.version import version_info

         if engine_config is not None:
             from lmdeploy.messages import PytorchEngineConfig
@@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info

     def generate(
         self,
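Taken together, these two hunks set up a simple runtime gate: the parsed version tuple ships with lmdeploy itself, and caching its components on the model lets generate() branch per call without re-importing. A minimal sketch of the pattern, assuming only what the diff shows (version_info unpacks into three elements):

# Sketch of the version gate the commit introduces; everything here is
# inferred from the diff above, not copied from OpenCompass.
from lmdeploy.version import version_info

major, minor, _ = version_info
if major >= 0 and minor >= 4:
    print('lmdeploy >= 0.4: infer() returns an object exposing .token_ids')
else:
    print('lmdeploy < 0.4: infer() returns a 3-tuple with token ids in the middle')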
@@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel):
         assert type(
             prompt) is str, 'We only support string for TurboMind Python API'
         input_ids = self.tokenizer.encode(prompt)
-        _, output_ids, _ = generator.infer(session_id,
-                                           input_ids,
-                                           gen_config=gen_config)
+        if self.major_version >= 0 and self.minor_version >= 4:
+            outputs = generator.infer(session_id,
+                                      input_ids,
+                                      gen_config=gen_config)
+            output_ids = outputs.token_ids
+        else:
+            _, output_ids, _ = generator.infer(session_id,
+                                               input_ids,
+                                               gen_config=gen_config)

         # stop engine
         if hasattr(generator, 'end'):
             generator.end(session_id)
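The same if/else is duplicated in TurboMindModel below, so the gate can be factored out once. A hedged sketch, not OpenCompass code: extract_token_ids is a hypothetical helper, and the tuple comparison version_info[:2] >= (0, 4) expresses the same "at least 0.4" intent as the committed major/minor check while also covering any future 1.x release:

# Hypothetical helper normalizing infer() results across lmdeploy versions;
# the gate matches the commit's intent, the factoring is illustrative.
from lmdeploy.version import version_info

NEW_API = version_info[:2] >= (0, 4)  # tuple compare: (0, 4, 0)[:2] >= (0, 4)

def extract_token_ids(outputs):
    if NEW_API:
        return outputs.token_ids   # >= 0.4: response object with .token_ids
    _, token_ids, _ = outputs      # < 0.4: plain 3-tuple
    return token_ids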
@@ -54,6 +54,7 @@ class TurboMindModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.turbomind import TurboMind
+        from lmdeploy.version import version_info

         if engine_config is not None:
             from lmdeploy.messages import TurbomindEngineConfig
@@ -70,6 +71,7 @@ class TurboMindModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info

     def generate(self,
                  inputs: List[str],
@@ -165,7 +167,10 @@ class TurboMindModel(BaseModel):
                                   sequence_end=True,
                                   step=0,
                                   stream_output=False):
-            _, output_ids, _ = outputs
+            if self.major_version >= 0 and self.minor_version >= 4:
+                output_ids = outputs.token_ids
+            else:
+                _, output_ids, _ = outputs
             response = self.tokenizer.decode(output_ids)
             response = valid_str(response)
             # used to trim
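Before running an evaluation, it is easy to confirm which branch an installation will take; a quick check, assuming a standard lmdeploy install:

# Quick check of the installed lmdeploy version in any Python shell.
from lmdeploy.version import version_info

print(version_info)                 # e.g. (0, 4, 0)
print(version_info[:2] >= (0, 4))   # True means the new .token_ids path runs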