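# Minimal inference demo for the Hairuo model. It can run either with the
# Transformers-style classes from ihp.zoo.hairuo or with vLLM; pick the
# backend via run_type below.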
prompt = "你好,我的名字是"
model_path = './model_ckpt/hairuo'
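
# Select the inference backend: 'transformers' or 'vllm'.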
# run_type = 'vllm'
run_type = 'transformers'

if run_type == 'transformers':
    from ihp.zoo.hairuo import HairuoTokenizer
    from ihp.zoo.hairuo import HairuoForCausalLM

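    # Load the model weights and tokenizer from the local checkpoint.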
    model = HairuoForCausalLM.from_pretrained(model_path)
    tokenizer = HairuoTokenizer.from_pretrained(model_path)

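    # Freeze all parameters and switch to evaluation mode for inference.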
    model.requires_grad_(False)
    model.eval()

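    # Tokenize the prompt, then sample up to 200 total tokens
    # (max_length counts the prompt tokens as well).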
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=200,
        temperature=0.8,
        do_sample=True,
        eos_token_id=151644,
        pad_token_id=151644,
    )
    generated_text = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    print(generated_text)


if run_type == 'vllm':

    # Import LLM and SamplingParams, and register the custom Hairuo
    # implementation with vLLM's model registry.
    from vllm import LLM, SamplingParams
    from vllm import ModelRegistry
    from ihp.zoo.hairuo.vllm_hairuo import HairuoForCausalLM
    ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
    # Inference inputs, organized as a List[str].
    prompts = [
        "你好,我的名字是",
        "The president of the United States is",
        "The capital of France is",
        "AI的未来是什么?",
    ]
    # Set the sampling parameters.
    sampling_params = SamplingParams(temperature=0.8, top_p=1)
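    # Note: enforce_eager=True skips CUDA graph capture, and max_model_len
    # caps the context window; both keep this small demo light on GPU memory.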
    # Load the model.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=1,
        # dtype='float32',
        gpu_memory_utilization=0.95,
        max_model_len=100,
        enforce_eager=True,
    )
    # Run inference.
    outputs = llm.generate(prompts, sampling_params)

    # Print the inference results.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")