[Feature] Allow explicitly setting the temperature for API model (#121)

* allow explicitly setting the temperature

* update
This commit is contained in:
Haodong Duan 2023-07-28 11:28:15 +08:00 committed by GitHub
parent 80ce18f860
commit 46c9645753
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel from .base_api import BaseAPIModel
PromptType = Union[PromptList, str] PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
@MODELS.register_module() @MODELS.register_module()
@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions. wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'. 'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
""" """
is_api: bool = True is_api: bool = True
def __init__( def __init__(self,
self, path: str,
path: str, max_seq_len: int = 2048,
max_seq_len: int = 2048, query_per_second: int = 1,
query_per_second: int = 1, retry: int = 2,
retry: int = 2, key: Union[str, List[str]] = 'ENV',
key: Union[str, List[str]] = 'ENV', org: Optional[Union[str, List[str]]] = None,
org: Optional[Union[str, List[str]]] = None, meta_template: Optional[Dict] = None,
meta_template: Optional[Dict] = None, openai_api_base: str = OPENAI_API_BASE,
openai_api_base: str = 'https://api.openai.com/v1/chat/completions' temperature: Optional[float] = None):
): # noqa
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template, meta_template=meta_template,
@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
retry=retry) retry=retry)
import tiktoken import tiktoken
self.tiktoken = tiktoken self.tiktoken = tiktoken
self.temperature = temperature
if isinstance(key, str): if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
Returns: Returns:
List[str]: A list of generated strings. List[str]: A list of generated strings.
""" """
if self.temperature is not None:
temperature = self.temperature
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
results = list( results = list(
executor.map(self._generate, inputs, executor.map(self._generate, inputs,