mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
![]() |
import requests
|
||
|
from typing import List
|
||
|
from .base_api import BaseAPIModel
|
||
|
import fireworks.client
|
||
|
import os
|
||
|
import tiktoken
|
||
|
from opencompass.registry import MODELS
|
||
|
from typing import Dict, List, Optional, Union
|
||
|
# Completion length substituted by Fireworks.query_fireworks when a caller
# passes max_tokens == 0.
NUM_ALLOWED_TOKENS_GPT_4 = 8192

# Configure the (module-global) Fireworks client with the API key taken from
# the FIREWORKS_API_KEY environment variable; the key is None when unset.
fireworks.client.api_key = os.environ.get('FIREWORKS_API_KEY')
|
||
|
class LLMError(Exception):
    """Signal that a Large Language Model call produced an unusable result.

    The Fireworks wrapper in this module raises it when a completion request
    does not come back with exactly one choice.
    """
|
||
|
@MODELS.register_module()
class Fireworks(BaseAPIModel):
    """OpenCompass API model backed by the Fireworks completion endpoint.

    Each prompt is sent to ``fireworks.client.Completion.create`` with
    ``n=1``. The text of the single returned choice is the model output.

    Args:
        path (str): Fireworks model identifier, forwarded as ``model``.
        max_seq_len (int): Forwarded to ``BaseAPIModel``.
        query_per_second (int): Forwarded to ``BaseAPIModel``.
        retry (int): Number of extra attempts made for a prompt after a
            failed request, so at most ``retry + 1`` requests per prompt.
        key (str | list[str]): API key. ``'ENV'`` (the default) keeps the key
            the module read from ``FIREWORKS_API_KEY``. Any other string
            replaces it. For a non-empty list, the first entry is used.
            Note that this sets the module-global ``fireworks.client.api_key``,
            so it affects every instance.
        **kwargs: Forwarded to ``BaseAPIModel``.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = "accounts/fireworks/models/mistral-7b",
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = 'ENV',
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry,
                         **kwargs)
        self.model_name = path
        # NOTE(review): token counts use the GPT-4 tiktoken encoding, which is
        # only an approximation for non-OpenAI models such as Mistral.
        self.tokenizer = tiktoken.encoding_for_model("gpt-4")
        self.retry = retry
        # Honour an explicitly supplied key instead of silently ignoring it.
        # 'ENV' keeps the module-level value read from FIREWORKS_API_KEY.
        if isinstance(key, str):
            if key != 'ENV':
                fireworks.client.api_key = key
        elif key:
            fireworks.client.api_key = key[0]

    def query_fireworks(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.01,
    ) -> str:
        """Send one completion request and return the generated text.

        Args:
            prompt: Text sent as the completion prompt.
            max_tokens: Maximum completion length. ``0`` is replaced by
                ``NUM_ALLOWED_TOKENS_GPT_4``.
            temperature: Sampling temperature forwarded to Fireworks.

        Returns:
            The ``text`` of the single returned choice.

        Raises:
            LLMError: If the response does not contain exactly one choice.
        """
        if max_tokens == 0:
            max_tokens = NUM_ALLOWED_TOKENS_GPT_4  # Adjust based on your model's capabilities
        completion = fireworks.client.Completion.create(
            model=self.model_name,
            prompt=prompt,
            n=1,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        if len(completion.choices) != 1:
            raise LLMError("Unexpected number of choices returned by Fireworks.")
        return completion.choices[0].text

    def _query_with_retry(self, input_text, max_out_len: int,
                          temperature: float) -> str:
        """Query one prompt, retrying on failure; return '' if all attempts fail."""
        last_error = None
        for _ in range(max(self.retry, 0) + 1):
            try:
                return self.query_fireworks(prompt=input_text,
                                            max_tokens=max_out_len,
                                            temperature=temperature)
            except Exception as e:  # best-effort: keep the evaluation running
                last_error = e
        print(f"Failed to generate output for input: {input_text} due to {last_error}")
        # An empty placeholder keeps outputs index-aligned with inputs.
        return ''

    def generate(
        self,
        inputs,
        max_out_len: int = 512,
        temperature: float = 0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs: Iterable of prompts. Each one is passed to Fireworks as-is.
                NOTE(review): if callers pass non-str prompt objects, they may
                need converting to str first; confirm against callers.
            max_out_len: Maximum completion length per prompt.
            temperature: Sampling temperature.

        Returns:
            One output string per input, in input order. A prompt whose every
            attempt failed yields ``''`` so the list stays aligned with inputs.
        """
        return [
            self._query_with_retry(input_text, max_out_len, temperature)
            for input_text in inputs
        ]

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string (GPT-4 tiktoken encoding)."""
        return len(self.tokenizer.encode(prompt))
|