Mirror of https://github.com/open-compass/opencompass.git
Synced 2025-05-30 16:03:24 +08:00
Commit: adapt to lmdeploy v0.4.0 (#1073)

* adapt to lmdeploy v0.4.0
* compatible

parent 58a57a4c45
commit 1013dce60c
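In lmdeploy v0.4.0 the engine's infer API changed its return type: it now returns a response object whose token_ids attribute carries the generated tokens, where earlier releases returned a plain tuple. The commit therefore reads lmdeploy's version_info at construction time and branches on it when unpacking generator output. A minimal sketch of the pattern, assuming a hypothetical helper name (extract_token_ids) and the pre-0.4 tuple layout shown in the diff:

# Sketch only: the helper name is hypothetical; version_info is the
# (major, minor, patch) tuple exported by lmdeploy.version.
from lmdeploy.version import version_info

MAJOR, MINOR = version_info[0], version_info[1]


def extract_token_ids(outputs):
    """Return generated token ids across lmdeploy versions."""
    if MAJOR >= 0 and MINOR >= 4:
        # v0.4.0+: infer()/stream_infer() yield a response object.
        return outputs.token_ids
    # Pre-0.4: a plain tuple with the token ids in the middle slot.
    _, token_ids, _ = outputs
    return token_ids

The changes to LmdeployPytorchModel: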
@@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.pytorch import engine as tm
+        from lmdeploy.version import version_info
 
         if engine_config is not None:
             from lmdeploy.messages import PytorchEngineConfig
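version_info is the integer tuple lmdeploy derives from its __version__ string, which lets the wrapper compare version components numerically instead of parsing strings. A quick illustration; the exact values shown are an example:

from lmdeploy.version import __version__, version_info

# With lmdeploy v0.4.0 installed this would print: 0.4.0 (0, 4, 0)
print(__version__, version_info)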
@@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(
         self,
@@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel):
         assert type(
             prompt) is str, 'We only support string for TurboMind Python API'
         input_ids = self.tokenizer.encode(prompt)
-        _, output_ids, _ = generator.infer(session_id,
-                                           input_ids,
-                                           gen_config=gen_config)
+        if self.major_version >= 0 and self.minor_version >= 4:
+            outputs = generator.infer(session_id,
+                                      input_ids,
+                                      gen_config=gen_config)
+            output_ids = outputs.token_ids
+        else:
+            _, output_ids, _ = generator.infer(session_id,
+                                               input_ids,
+                                               gen_config=gen_config)
 
         # stop engine
         if hasattr(generator, 'end'):
             generator.end(session_id)
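One caveat worth noting: the gate self.major_version >= 0 and self.minor_version >= 4 checks the minor component in isolation, so it would take the legacy branch for a hypothetical v1.0.0 even though that release would use the new API. A lexicographic tuple comparison avoids this; a sketch of the alternative, not what the commit ships:

from lmdeploy.version import version_info

# (1, 0) >= (0, 4) is True, so a future 1.x release would still take
# the new-API branch under this comparison.
USE_NEW_API = version_info[:2] >= (0, 4)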
The same changes are applied to TurboMindModel:

@@ -54,6 +54,7 @@ class TurboMindModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.turbomind import TurboMind
+        from lmdeploy.version import version_info
 
         if engine_config is not None:
             from lmdeploy.messages import TurbomindEngineConfig
@@ -70,6 +71,7 @@ class TurboMindModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(self,
                  inputs: List[str],
@@ -165,7 +167,10 @@ class TurboMindModel(BaseModel):
                                               sequence_end=True,
                                               step=0,
                                               stream_output=False):
-            _, output_ids, _ = outputs
+            if self.major_version >= 0 and self.minor_version >= 4:
+                output_ids = outputs.token_ids
+            else:
+                _, output_ids, _ = outputs
         response = self.tokenizer.decode(output_ids)
         response = valid_str(response)
         # used to trim
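Putting the TurboMind hunk in context: the gate sits inside the loop over generator.stream_infer(...), normalizing each yielded item before the ids are decoded. A condensed sketch of that loop; the keyword arguments not visible in the diff (session_id, input_ids, gen_config, sequence_start) are assumptions based on typical usage:

# Condensed sketch; only sequence_end, step and stream_output appear in
# the diff context, the remaining arguments are assumed.
for outputs in generator.stream_infer(session_id=session_id,
                                      input_ids=[input_ids],
                                      gen_config=gen_config,
                                      sequence_start=True,
                                      sequence_end=True,
                                      step=0,
                                      stream_output=False):
    if self.major_version >= 0 and self.minor_version >= 4:
        output_ids = outputs.token_ids  # v0.4.0+: response object
    else:
        _, output_ids, _ = outputs      # pre-0.4: plain tuple
response = self.tokenizer.decode(output_ids)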