# Add a Model

Currently, we support HF models, some model APIs, and some third-party models.

## Adding API Models

To add a new API-based model, you need to create a new file named `mymodel_api.py` under `opencompass/models` directory. In this file, you should inherit from `BaseAPIModel` and implement the `generate` method for inference and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.

```python
from typing import Dict, List, Optional

from ..base_api import BaseAPIModel


class MyModelAPI(BaseAPIModel):

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        ...

    def generate(
        self,
        inputs,
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string."""
        pass
```

## Adding Third-Party Models

To add a new third-party model, you need to create a new file named `mymodel.py` under `opencompass/models` directory. In this file, you should inherit from `BaseModel` and implement the `generate` method for generative inference, the `get_ppl` method for discriminative inference, and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.

```python
from typing import Dict, List, Optional

from ..base import BaseModel


class MyModel(BaseModel):

    def __init__(self,
                 pkg_root: str,
                 ckpt_path: str,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        ...

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings."""
        pass

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs."""
        pass
```