mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support loading PEFT adapters for HuggingFace models (#74)
* support peft for HuggingFace model
* add docstring
This commit is contained in:
parent
f36c0496f3
commit
26e2f171f4
@ -25,6 +25,9 @@ class HuggingFace(BaseModel):
|
||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||
Defaults to {}.
|
||||
peft_path (str, optional): The name or path of the HuggingFace PEFT
|
||||
model. If None, the original model will not be converted to PEFT.
|
||||
Defaults to None.
|
||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||
Defaults to False.
|
||||
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||
@ -51,6 +54,7 @@ class HuggingFace(BaseModel):
|
||||
max_seq_len: int = 2048,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
tokenizer_kwargs: dict = dict(),
|
||||
peft_path: Optional[str] = None,
|
||||
tokenizer_only: bool = False,
|
||||
model_kwargs: dict = dict(device_map='auto'),
|
||||
meta_template: Optional[Dict] = None,
|
||||
@ -71,7 +75,9 @@ class HuggingFace(BaseModel):
|
||||
self.batch_padding = batch_padding
|
||||
self.extract_pred_after_decode = extract_pred_after_decode
|
||||
if not tokenizer_only:
|
||||
self._load_model(path=path, model_kwargs=model_kwargs)
|
||||
self._load_model(path=path,
|
||||
model_kwargs=model_kwargs,
|
||||
peft_path=peft_path)
|
||||
|
||||
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
|
||||
tokenizer_kwargs: dict):
|
||||
@ -94,11 +100,19 @@ class HuggingFace(BaseModel):
|
||||
self.tokenizer.eos_token = '</s>'
|
||||
self.tokenizer.pad_token_id = 0
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict):
|
||||
def _load_model(self,
|
||||
path: str,
|
||||
model_kwargs: dict,
|
||||
peft_path: Optional[str] = None):
|
||||
from transformers import AutoModel
|
||||
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||
if peft_path is not None:
|
||||
from peft import PeftModel
|
||||
self.model = PeftModel.from_pretrained(self.model,
|
||||
peft_path,
|
||||
is_trainable=False)
|
||||
self.model.eval()
|
||||
|
||||
# A patch for llama when batch_padding = True
|
||||
@ -184,7 +198,8 @@ class HuggingFace(BaseModel):
|
||||
max_length=self.max_seq_len -
|
||||
max_out_len)['input_ids']
|
||||
input_ids = torch.tensor(input_ids, device=self.model.device)
|
||||
outputs = self.model.generate(input_ids, max_new_tokens=max_out_len)
|
||||
outputs = self.model.generate(input_ids=input_ids,
|
||||
max_new_tokens=max_out_len)
|
||||
|
||||
if not self.extract_pred_after_decode:
|
||||
outputs = outputs[:, input_ids.shape[1]:]
|
||||
@ -318,6 +333,9 @@ class HuggingFaceCausalLM(HuggingFace):
|
||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||
Defaults to {}.
|
||||
peft_path (str, optional): The name or path of the HuggingFace PEFT
|
||||
model. If None, the original model will not be converted to PEFT.
|
||||
Defaults to None.
|
||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||
Defaults to False.
|
||||
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||
@ -329,10 +347,17 @@ class HuggingFaceCausalLM(HuggingFace):
|
||||
without batch padding.
|
||||
"""
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict):
|
||||
def _load_model(self,
|
||||
path: str,
|
||||
model_kwargs: dict,
|
||||
peft_path: Optional[str] = None):
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
|
||||
if peft_path is not None:
|
||||
from peft import PeftModel
|
||||
self.model = PeftModel.from_pretrained(self.model,
|
||||
peft_path,
|
||||
is_trainable=False)
|
||||
self.model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user