2023-07-04 21:34:55 +08:00
|
|
|
|
# 支持新模型
|
|
|
|
|
|
2023-07-06 12:58:44 +08:00
|
|
|
|
目前我们已经支持的模型有 HF 模型、部分模型 API 、部分第三方模型。
|
2023-07-04 21:34:55 +08:00
|
|
|
|
|
|
|
|
|
## 新增API模型
|
|
|
|
|
|
|
|
|
|
新增基于API的模型,需要在 `opencompass/models` 下新建 `mymodel_api.py` 文件,继承 `BaseAPIModel`,并实现 `generate` 方法来进行推理,以及 `get_token_len` 方法来计算 token 的长度。在定义好之后修改对应配置文件名称即可。
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
from ..base_api import BaseAPIModel
|
|
|
|
|
|
|
|
|
|
class MyModelAPI(BaseAPIModel):
|
|
|
|
|
|
|
|
|
|
is_api: bool = True
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
path: str,
|
|
|
|
|
max_seq_len: int = 2048,
|
|
|
|
|
query_per_second: int = 1,
|
|
|
|
|
retry: int = 2,
|
|
|
|
|
**kwargs):
|
|
|
|
|
super().__init__(path=path,
|
|
|
|
|
max_seq_len=max_seq_len,
|
|
|
|
|
meta_template=meta_template,
|
|
|
|
|
query_per_second=query_per_second,
|
|
|
|
|
retry=retry)
|
|
|
|
|
...
|
|
|
|
|
|
|
|
|
|
def generate(
|
|
|
|
|
self,
|
|
|
|
|
inputs,
|
|
|
|
|
max_out_len: int = 512,
|
|
|
|
|
temperature: float = 0.7,
|
|
|
|
|
) -> List[str]:
|
|
|
|
|
"""Generate results given a list of inputs."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def get_token_len(self, prompt: str) -> int:
|
|
|
|
|
"""Get lengths of the tokenized string."""
|
|
|
|
|
pass
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## 新增第三方模型
|
|
|
|
|
|
|
|
|
|
新增基于API的模型,需要在 `opencompass/models` 下新建 `mymodel.py` 文件,继承 `BaseModel`,并实现 `generate` 方法来进行生成式推理, `get_ppl` 方法来进行判别式推理,以及 `get_token_len` 方法来计算 token 的长度。在定义好之后修改对应配置文件名称即可。
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
from ..base import BaseModel
|
|
|
|
|
|
|
|
|
|
class MyModel(BaseModel):
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
pkg_root: str,
|
|
|
|
|
ckpt_path: str,
|
|
|
|
|
tokenizer_only: bool = False,
|
|
|
|
|
meta_template: Optional[Dict] = None,
|
|
|
|
|
**kwargs):
|
|
|
|
|
...
|
|
|
|
|
|
|
|
|
|
def get_token_len(self, prompt: str) -> int:
|
|
|
|
|
"""Get lengths of the tokenized strings."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
|
|
|
|
|
"""Generate results given a list of inputs. """
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def get_ppl(self,
|
|
|
|
|
inputs: List[str],
|
|
|
|
|
mask_length: Optional[List[int]] = None) -> List[float]:
|
|
|
|
|
"""Get perplexity scores given a list of inputs."""
|
|
|
|
|
pass
|
|
|
|
|
```
|