mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support load PEFT adapter for HuggingFace model (#74)
* support peft for HuggingFace model
* add docstring
This commit is contained in:
parent
f36c0496f3
commit
26e2f171f4
@ -25,6 +25,9 @@ class HuggingFace(BaseModel):
|
|||||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||||
Defaults to {}.
|
Defaults to {}.
|
||||||
|
peft_path (str, optional): The name or path to the HuggingFace's PEFT
|
||||||
|
model. If None, the original model will not be converted to PEFT.
|
||||||
|
Defaults to None.
|
||||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||||
@ -51,6 +54,7 @@ class HuggingFace(BaseModel):
|
|||||||
max_seq_len: int = 2048,
|
max_seq_len: int = 2048,
|
||||||
tokenizer_path: Optional[str] = None,
|
tokenizer_path: Optional[str] = None,
|
||||||
tokenizer_kwargs: dict = dict(),
|
tokenizer_kwargs: dict = dict(),
|
||||||
|
peft_path: Optional[str] = None,
|
||||||
tokenizer_only: bool = False,
|
tokenizer_only: bool = False,
|
||||||
model_kwargs: dict = dict(device_map='auto'),
|
model_kwargs: dict = dict(device_map='auto'),
|
||||||
meta_template: Optional[Dict] = None,
|
meta_template: Optional[Dict] = None,
|
||||||
@ -71,7 +75,9 @@ class HuggingFace(BaseModel):
|
|||||||
self.batch_padding = batch_padding
|
self.batch_padding = batch_padding
|
||||||
self.extract_pred_after_decode = extract_pred_after_decode
|
self.extract_pred_after_decode = extract_pred_after_decode
|
||||||
if not tokenizer_only:
|
if not tokenizer_only:
|
||||||
self._load_model(path=path, model_kwargs=model_kwargs)
|
self._load_model(path=path,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
peft_path=peft_path)
|
||||||
|
|
||||||
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
|
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
|
||||||
tokenizer_kwargs: dict):
|
tokenizer_kwargs: dict):
|
||||||
@ -94,11 +100,19 @@ class HuggingFace(BaseModel):
|
|||||||
self.tokenizer.eos_token = '</s>'
|
self.tokenizer.eos_token = '</s>'
|
||||||
self.tokenizer.pad_token_id = 0
|
self.tokenizer.pad_token_id = 0
|
||||||
|
|
||||||
def _load_model(self, path: str, model_kwargs: dict):
|
def _load_model(self,
|
||||||
|
path: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
peft_path: Optional[str] = None):
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||||
|
if peft_path is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
self.model = PeftModel.from_pretrained(self.model,
|
||||||
|
peft_path,
|
||||||
|
is_trainable=False)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# A patch for llama when batch_padding = True
|
# A patch for llama when batch_padding = True
|
||||||
@ -184,7 +198,8 @@ class HuggingFace(BaseModel):
|
|||||||
max_length=self.max_seq_len -
|
max_length=self.max_seq_len -
|
||||||
max_out_len)['input_ids']
|
max_out_len)['input_ids']
|
||||||
input_ids = torch.tensor(input_ids, device=self.model.device)
|
input_ids = torch.tensor(input_ids, device=self.model.device)
|
||||||
outputs = self.model.generate(input_ids, max_new_tokens=max_out_len)
|
outputs = self.model.generate(input_ids=input_ids,
|
||||||
|
max_new_tokens=max_out_len)
|
||||||
|
|
||||||
if not self.extract_pred_after_decode:
|
if not self.extract_pred_after_decode:
|
||||||
outputs = outputs[:, input_ids.shape[1]:]
|
outputs = outputs[:, input_ids.shape[1]:]
|
||||||
@ -318,6 +333,9 @@ class HuggingFaceCausalLM(HuggingFace):
|
|||||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||||
Defaults to {}.
|
Defaults to {}.
|
||||||
|
peft_path (str, optional): The name or path to the HuggingFace's PEFT
|
||||||
|
model. If None, the original model will not be converted to PEFT.
|
||||||
|
Defaults to None.
|
||||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||||
@ -329,10 +347,17 @@ class HuggingFaceCausalLM(HuggingFace):
|
|||||||
without batch padding.
|
without batch padding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _load_model(self, path: str, model_kwargs: dict):
|
def _load_model(self,
|
||||||
|
path: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
peft_path: Optional[str] = None):
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||||
|
if peft_path is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
self.model = PeftModel.from_pretrained(self.model,
|
||||||
|
peft_path,
|
||||||
|
is_trainable=False)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
Loading…
Reference in New Issue
Block a user