OpenCompass/docs/zh_cn/advanced_guides/new_model.md
Hubert 7f8eee4725
[Docs] add en docs (#15)
* add en docs

* update

---------

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
2023-07-06 12:58:44 +08:00


# Adding New Models

Currently, we support HuggingFace models, some model APIs, and some third-party models.

## Adding API Models

To add a new API-based model, create a file `mymodel_api.py` under `opencompass/models`, inherit from `BaseAPIModel`, and implement the `generate` method for inference and the `get_token_len` method for computing token lengths. Once the model is defined, reference its name in the corresponding config file.
```python
from typing import Dict, List, Optional

from ..base_api import BaseAPIModel


class MyModelAPI(BaseAPIModel):

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        ...

    def generate(
        self,
        inputs,
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string."""
        pass
```
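After the class is in place, the model is selected through a config entry. The sketch below shows what such an entry might look like; the `abbr`, `path`, and other field values are illustrative placeholders, not taken from the repository (real configs also pass the class object rather than a string for `type`):

```python
# Hypothetical model config entry; all values are placeholders.
models = [
    dict(
        type='MyModelAPI',    # the class added under opencompass/models
        abbr='mymodel-api',   # short name shown in result tables
        path='my-model',      # model identifier passed to __init__
        max_seq_len=2048,
        query_per_second=1,
        retry=2,
        max_out_len=512,
        batch_size=8,
    )
]
```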
## Adding Third-Party Models

To add a new third-party model, create a file `mymodel.py` under `opencompass/models`, inherit from `BaseModel`, and implement the `generate` method for generative inference, the `get_ppl` method for discriminative inference, and the `get_token_len` method for computing token lengths. Once the model is defined, reference its name in the corresponding config file.
```python
from typing import Dict, List, Optional

from ..base import BaseModel


class MyModel(BaseModel):

    def __init__(self,
                 pkg_root: str,
                 ckpt_path: str,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        ...

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings."""
        pass

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs."""
        pass
```
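To clarify the `get_ppl` contract, here is a standalone toy sketch: it scores each input with an assumed uniform model over a fixed vocabulary and excludes the tokens covered by `mask_length` (e.g. the prompt part) from the loss. The `VOCAB_SIZE` constant, the whitespace "tokenizer", and the mean-NLL averaging are illustrative assumptions, not OpenCompass code:

```python
import math
from typing import List, Optional

VOCAB_SIZE = 100  # assumed toy vocabulary size


def get_ppl_toy(inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
    """Toy get_ppl: a uniform model assigns -log(1/V) to every token.

    The first ``mask_length[i]`` tokens of each input are excluded
    from the loss, mirroring how a prompt prefix is masked out.
    """
    scores = []
    for i, text in enumerate(inputs):
        tokens = text.split()  # whitespace "tokenizer" for illustration
        start = mask_length[i] if mask_length is not None else 0
        scored = max(len(tokens) - start, 1)    # number of scored tokens
        nll = scored * math.log(VOCAB_SIZE)     # per-token NLL is log(V)
        scores.append(nll / scored)             # mean NLL per scored token
    return scores
```

In discriminative evaluation, the framework calls `get_ppl` on one candidate completion per option and picks the option with the lowest score, so lower values mean a more likely continuation.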