[Feature] Allow explicitly setting the temperature for API model (#121)

* allow explicitly setting the temperature

* update
This commit is contained in:
Haodong Duan 2023-07-28 11:28:15 +08:00 committed by GitHub
parent 80ce18f860
commit 46c9645753
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
@MODELS.register_module()
@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
"""
is_api: bool = True
def __init__(
self,
path: str,
max_seq_len: int = 2048,
query_per_second: int = 1,
retry: int = 2,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = 'https://api.openai.com/v1/chat/completions'
): # noqa
def __init__(self,
path: str,
max_seq_len: int = 2048,
query_per_second: int = 1,
retry: int = 2,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE,
temperature: Optional[float] = None):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template,
@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
retry=retry)
import tiktoken
self.tiktoken = tiktoken
self.temperature = temperature
if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
Returns:
List[str]: A list of generated strings.
"""
if self.temperature is not None:
temperature = self.temperature
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,