mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Allow explicitly setting the temperature for API model (#121)
* allow explicitly setting the temperature * update
This commit is contained in:
parent
80ce18f860
commit
46c9645753
@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
|
||||
from .base_api import BaseAPIModel
|
||||
|
||||
PromptType = Union[PromptList, str]
|
||||
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
|
||||
wrapping of any meta instructions.
|
||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||
'https://api.openai.com/v1/chat/completions'.
|
||||
temperature (float, optional): What sampling temperature to use.
|
||||
If not None, will override the temperature in the `generate()`
|
||||
call. Defaults to None.
|
||||
"""
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
key: Union[str, List[str]] = 'ENV',
|
||||
org: Optional[Union[str, List[str]]] = None,
|
||||
meta_template: Optional[Dict] = None,
|
||||
openai_api_base: str = 'https://api.openai.com/v1/chat/completions'
|
||||
): # noqa
|
||||
def __init__(self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
key: Union[str, List[str]] = 'ENV',
|
||||
org: Optional[Union[str, List[str]]] = None,
|
||||
meta_template: Optional[Dict] = None,
|
||||
openai_api_base: str = OPENAI_API_BASE,
|
||||
temperature: Optional[float] = None):
|
||||
|
||||
super().__init__(path=path,
|
||||
max_seq_len=max_seq_len,
|
||||
meta_template=meta_template,
|
||||
@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
|
||||
retry=retry)
|
||||
import tiktoken
|
||||
self.tiktoken = tiktoken
|
||||
self.temperature = temperature
|
||||
|
||||
if isinstance(key, str):
|
||||
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
|
||||
@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
if self.temperature is not None:
|
||||
temperature = self.temperature
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = list(
|
||||
executor.map(self._generate, inputs,
|
||||
|
Loading…
Reference in New Issue
Block a user