mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
74 lines
2.3 KiB
Markdown
74 lines
2.3 KiB
Markdown
# 支持新模型
|
||
|
||
目前我们已经支持的模型有 HF 模型、部分模型 API 、部分第三方模型。
|
||
|
||
## 新增API模型
|
||
|
||
新增基于API的模型,需要在 `opencompass/models` 下新建 `mymodel_api.py` 文件,继承 `BaseAPIModel`,并实现 `generate` 方法来进行推理,以及 `get_token_len` 方法来计算 token 的长度。在定义好之后修改对应配置文件名称即可。
|
||
|
||
```python
|
||
from ..base_api import BaseAPIModel
|
||
|
||
class MyModelAPI(BaseAPIModel):
|
||
|
||
is_api: bool = True
|
||
|
||
def __init__(self,
|
||
path: str,
|
||
max_seq_len: int = 2048,
|
||
query_per_second: int = 1,
|
||
retry: int = 2,
|
||
**kwargs):
|
||
super().__init__(path=path,
|
||
max_seq_len=max_seq_len,
|
||
meta_template=meta_template,
|
||
query_per_second=query_per_second,
|
||
retry=retry)
|
||
...
|
||
|
||
def generate(
|
||
self,
|
||
inputs,
|
||
max_out_len: int = 512,
|
||
temperature: float = 0.7,
|
||
) -> List[str]:
|
||
"""Generate results given a list of inputs."""
|
||
pass
|
||
|
||
def get_token_len(self, prompt: str) -> int:
|
||
"""Get lengths of the tokenized string."""
|
||
pass
|
||
```
|
||
|
||
## 新增第三方模型
|
||
|
||
新增基于API的模型,需要在 `opencompass/models` 下新建 `mymodel.py` 文件,继承 `BaseModel`,并实现 `generate` 方法来进行生成式推理, `get_ppl` 方法来进行判别式推理,以及 `get_token_len` 方法来计算 token 的长度。在定义好之后修改对应配置文件名称即可。
|
||
|
||
```python
|
||
from ..base import BaseModel
|
||
|
||
class MyModel(BaseModel):
|
||
|
||
def __init__(self,
|
||
pkg_root: str,
|
||
ckpt_path: str,
|
||
tokenizer_only: bool = False,
|
||
meta_template: Optional[Dict] = None,
|
||
**kwargs):
|
||
...
|
||
|
||
def get_token_len(self, prompt: str) -> int:
|
||
"""Get lengths of the tokenized strings."""
|
||
pass
|
||
|
||
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
|
||
"""Generate results given a list of inputs. """
|
||
pass
|
||
|
||
def get_ppl(self,
|
||
inputs: List[str],
|
||
mask_length: Optional[List[int]] = None) -> List[float]:
|
||
"""Get perplexity scores given a list of inputs."""
|
||
pass
|
||
```
|