diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py
index 364cc7fb..5e3c1207 100644
--- a/opencompass/models/openai_api.py
+++ b/opencompass/models/openai_api.py
@@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
 from .base_api import BaseAPIModel
 
 PromptType = Union[PromptList, str]
+OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
 
 
 @MODELS.register_module()
@@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
             wrapping of any meta instructions.
         openai_api_base (str): The base url of OpenAI's API. Defaults
             to 'https://api.openai.com/v1/chat/completions'.
+        temperature (float, optional): What sampling temperature to use.
+            If not None, will override the temperature in the `generate()`
+            call. Defaults to None.
     """
 
     is_api: bool = True
 
-    def __init__(
-        self,
-        path: str,
-        max_seq_len: int = 2048,
-        query_per_second: int = 1,
-        retry: int = 2,
-        key: Union[str, List[str]] = 'ENV',
-        org: Optional[Union[str, List[str]]] = None,
-        meta_template: Optional[Dict] = None,
-        openai_api_base: str = 'https://api.openai.com/v1/chat/completions'
-    ):  # noqa
+    def __init__(self,
+                 path: str,
+                 max_seq_len: int = 2048,
+                 query_per_second: int = 1,
+                 retry: int = 2,
+                 key: Union[str, List[str]] = 'ENV',
+                 org: Optional[Union[str, List[str]]] = None,
+                 meta_template: Optional[Dict] = None,
+                 openai_api_base: str = OPENAI_API_BASE,
+                 temperature: Optional[float] = None):
+
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template,
@@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
                          retry=retry)
         import tiktoken
         self.tiktoken = tiktoken
+        self.temperature = temperature
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
         Returns:
             List[str]: A list of generated strings.
         """
+        if self.temperature is not None:
+            temperature = self.temperature
+
         with ThreadPoolExecutor() as executor:
             results = list(
                 executor.map(self._generate, inputs,
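
Below is a minimal usage sketch of the temperature override introduced in this diff. It is illustrative only: the `from opencompass.models import OpenAI` import path and the example prompt are assumptions, not part of the change itself.

    from opencompass.models import OpenAI

    # Pin sampling at temperature 0.2 for every request. Because
    # self.temperature is not None, generate() rebinds its local
    # `temperature` argument before dispatching to _generate, so the
    # model-level setting wins over whatever the caller passes.
    # Assumes OPENAI_API_KEY is set in the environment (key='ENV' default).
    model = OpenAI(path='gpt-3.5-turbo', temperature=0.2)

    outputs = model.generate(['Say hello.'], max_out_len=64)

Leaving `temperature=None` (the default) preserves the previous behaviour: generate() uses the temperature supplied at call time.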