Add support for Fireworks and vLLM models, plus a sample config for wiki-SFT model eval

Arshdeep Singh 2024-04-22 21:15:33 -04:00
parent 004ed79593
commit fc95a851a7
8 changed files with 254 additions and 0 deletions

configs/eval_fireworks.py Normal file

@@ -0,0 +1,26 @@
from mmengine.config import read_base
from opencompass.models import Fireworks
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
    from .datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets
    from .datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets
    from .datasets.SuperGLUE_AX_b.SuperGLUE_AX_b_gen import AX_b_datasets
    from .datasets.SuperGLUE_AX_g.SuperGLUE_AX_g_gen import AX_g_datasets
    from .datasets.SuperGLUE_CB.SuperGLUE_CB_gen import CB_datasets
    from .datasets.SuperGLUE_RTE.SuperGLUE_RTE_gen import RTE_datasets
    from .datasets.anli.anli_gen import anli_datasets

datasets = [*CB_datasets]

models = [
    dict(abbr='mistral-7b',
         type=Fireworks,
         path='accounts/fireworks/models/mistral-7b',
         key='ENV',
         query_per_second=1,
         max_out_len=2048,
         max_seq_len=2048,
         batch_size=8),
]

work_dir = "outputs/api_mistral_7b/"
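The partitioner, runner, and task imports above are not wired into this config; for an API-backed model, OpenCompass inference is usually driven through LocalAPIRunner. A minimal sketch of the missing infer stage (an assumption, not part of this commit; the worker counts are illustrative):

# Sketch only: typical API-model inference setup for OpenCompass.
infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask)),
)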

configs/eval_mistral.py Normal file

@@ -0,0 +1,8 @@
from mmengine.config import read_base

with read_base():
    from .datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
    from .models.mistral.hf_mistral_7b_v0_1 import models

datasets = [*cmnli_datasets]
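As a quick sanity check, a minimal HF config like this is typically launched from the repository root with python run.py configs/eval_mistral.py.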

@@ -0,0 +1,27 @@
from mmengine.config import read_base
from opencompass.models import VLLM_OPENAI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
    # from .models.fireworks.mistral_7b import mistral_7b
    from .datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets
    from .datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets
    from .datasets.SuperGLUE_AX_b.SuperGLUE_AX_b_gen import AX_b_datasets
    from .datasets.SuperGLUE_AX_g.SuperGLUE_AX_g_gen import AX_g_datasets
    from .datasets.SuperGLUE_CB.SuperGLUE_CB_gen import CB_datasets
    from .datasets.SuperGLUE_RTE.SuperGLUE_RTE_gen import RTE_datasets
    from .datasets.anli.anli_gen import anli_datasets

datasets = [*ocnli_datasets, *ocnli_fc_datasets]

models = [
    dict(abbr='full-finetuned',
         type=VLLM_OPENAI,
         path='model',
         key='ENV',
         query_per_second=1,
         max_out_len=2048,
         max_seq_len=2048,
         batch_size=8),
]

work_dir = "outputs/api_full_finetuned_ift_20k/"
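Note: the VLLM_OPENAI wrapper added later in this commit reads OPENAI_API_KEY and OPENAI_API_BASE from the environment, so both should be exported (with the base URL pointing at the served vLLM endpoint) before running this config.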

configs/eval_wikisft.py Normal file

@@ -0,0 +1,27 @@
from mmengine.config import read_base
from opencompass.models import Fireworks
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
    # from .models.fireworks.mistral_7b import mistral_7b
    from .datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets
    from .datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets
    from .datasets.SuperGLUE_AX_b.SuperGLUE_AX_b_gen import AX_b_datasets
    from .datasets.SuperGLUE_AX_g.SuperGLUE_AX_g_gen import AX_g_datasets
    from .datasets.SuperGLUE_CB.SuperGLUE_CB_gen import CB_datasets
    from .datasets.SuperGLUE_RTE.SuperGLUE_RTE_gen import RTE_datasets
    from .datasets.anli.anli_gen import anli_datasets

datasets = [*CB_datasets]

models = [
    dict(abbr='mistral_wiki_sft',
         type=Fireworks,
         path='path/to/fireworks/api',
         key='ENV',
         query_per_second=1,
         max_out_len=2048,
         max_seq_len=2048,
         batch_size=8),
]

work_dir = "outputs/api_wikisft/"

configs/models/fireworks/mistral_7b.py Normal file

@@ -0,0 +1,23 @@
from opencompass.models import Fireworks
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

mistral_7b = [
    dict(abbr='mistral-7b',
         type=Fireworks,
         path='accounts/fireworks/models/mistral-7b',
         key='ENV',  # the key is read from $FIREWORKS_API_KEY, but you can write your key here as well
         meta_template=api_meta_template,
         query_per_second=1,
         max_out_len=2048,
         max_seq_len=4096,
         batch_size=8),
]
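This shared model list mirrors the commented-out import in the eval configs above; a hedged sketch of how an eval config would consume it (read_base comes from mmengine, as in the other configs):

# Sketch only: pulling the shared Fireworks model list into an eval config.
from mmengine.config import read_base

with read_base():
    from .models.fireworks.mistral_7b import mistral_7b

models = mistral_7b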

opencompass/models/__init__.py

@@ -36,3 +36,5 @@ from .xunfei_api import XunFei  # noqa: F401
from .yayi_api import Yayi # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401
from .zhipuai_v2_api import ZhiPuV2AI # noqa: F401
from .fireworks_api import Fireworks  # noqa: F401
from .vllm_openai_api import VLLM_OPENAI  # noqa: F401

opencompass/models/fireworks_api.py Normal file

@@ -0,0 +1,67 @@
import os
from typing import List, Union

import fireworks.client
import tiktoken

from opencompass.registry import MODELS

from .base_api import BaseAPIModel

NUM_ALLOWED_TOKENS_GPT_4 = 8192

# By default the API key is read from the environment.
fireworks.client.api_key = os.getenv('FIREWORKS_API_KEY')


class LLMError(Exception):
    """A custom exception used to report errors when calling the Fireworks API."""


@MODELS.register_module()
class Fireworks(BaseAPIModel):
    """Wrapper around the Fireworks completion API."""

    is_api: bool = True

    def __init__(self,
                 path: str = 'accounts/fireworks/models/mistral-7b',
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = 'ENV',
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry,
                         **kwargs)
        self.model_name = path
        # Approximate token counting with the GPT-4 tokenizer.
        self.tokenizer = tiktoken.encoding_for_model('gpt-4')
        self.retry = retry
        # key='ENV' keeps the key read from $FIREWORKS_API_KEY; an explicit key overrides it.
        if isinstance(key, str) and key != 'ENV':
            fireworks.client.api_key = key

    def query_fireworks(self,
                        prompt: str,
                        max_tokens: int = 512,
                        temperature: float = 0.01) -> str:
        """Send a single completion request to Fireworks."""
        if max_tokens == 0:
            max_tokens = NUM_ALLOWED_TOKENS_GPT_4  # adjust based on your model's capabilities
        completion = fireworks.client.Completion.create(
            model=self.model_name,
            prompt=prompt,
            n=1,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        if len(completion.choices) != 1:
            raise LLMError('Unexpected number of choices returned by Fireworks.')
        return completion.choices[0].text

    def generate(self,
                 inputs,
                 max_out_len: int = 512,
                 temperature: float = 0) -> List[str]:
        """Generate results given a list of inputs."""
        outputs = []
        for input_text in inputs:
            try:
                output = self.query_fireworks(prompt=input_text,
                                              max_tokens=max_out_len,
                                              temperature=temperature)
            except Exception as e:
                print(f'Failed to generate output for input: {input_text} due to {e}')
                output = ''  # keep outputs aligned with inputs on failure
            outputs.append(output)
        return outputs

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string."""
        return len(self.tokenizer.encode(prompt))
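A minimal standalone smoke test for the new wrapper, assuming FIREWORKS_API_KEY is exported and the account has access to mistral-7b (the prompt text is purely illustrative):

# Sketch only: exercising the Fireworks wrapper outside the OpenCompass runner.
from opencompass.models import Fireworks

model = Fireworks(path='accounts/fireworks/models/mistral-7b', max_seq_len=2048)
print(model.get_token_len('Hello, world!'))                     # tokenizer sanity check
print(model.generate(['Say hello in one sentence.'], max_out_len=32))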

opencompass/models/vllm_openai_api.py Normal file

@@ -0,0 +1,74 @@
import os
from typing import List, Union

import tiktoken
from openai import OpenAI

from opencompass.registry import MODELS

from .base_api import BaseAPIModel

NUM_ALLOWED_TOKENS_GPT_4 = 8192


class LLMError(Exception):
    """A custom exception used to report errors when calling the vLLM OpenAI-compatible API."""


@MODELS.register_module()
class VLLM_OPENAI(BaseAPIModel):
    """Wrapper around a vLLM server exposing the OpenAI-compatible completion API."""

    is_api: bool = True

    def __init__(self,
                 path: str = 'model',
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = 'ENV',
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry,
                         **kwargs)
        self.model_name = path
        # Approximate token counting with the GPT-4 tokenizer.
        self.tokenizer = tiktoken.encoding_for_model('gpt-4')
        self.retry = retry
        # key='ENV' reads the key from $OPENAI_API_KEY; an explicit key is used as-is.
        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key
        if self.keys[0] is None:
            raise ValueError(
                'OPENAI_API_KEY environment variable must be set when using the OpenAI API.')
        # OPENAI_API_BASE should point at the vLLM server, e.g. http://localhost:8000/v1.
        self.client = OpenAI(api_key=self.keys[0],
                             base_url=os.getenv('OPENAI_API_BASE'))

    def query_vllm(self,
                   prompt: str,
                   temperature: float,
                   max_tokens: int) -> str:
        """Send a single completion request to the vLLM server."""
        response = self.client.completions.create(  # type: ignore
            prompt=prompt,
            model=self.model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        answer: str = response.choices[0].text
        return answer

    def generate(self,
                 inputs,
                 max_out_len: int = 512,
                 temperature: float = 0) -> List[str]:
        """Generate results given a list of inputs."""
        outputs = []
        for input_text in inputs:
            try:
                output = self.query_vllm(prompt=input_text,
                                         max_tokens=max_out_len,
                                         temperature=temperature)
            except Exception as e:
                print(f'Failed to generate output for input: {input_text} due to {e}')
                output = ''  # keep outputs aligned with inputs on failure
            outputs.append(output)
        return outputs

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string."""
        return len(self.tokenizer.encode(prompt))
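A similar hedged smoke test for the vLLM wrapper, assuming an OpenAI-compatible vLLM server is already serving the model; the localhost URL below is a placeholder:

# Sketch only: exercising the VLLM_OPENAI wrapper outside the OpenCompass runner.
import os

from opencompass.models import VLLM_OPENAI

os.environ.setdefault('OPENAI_API_KEY', 'EMPTY')  # vLLM typically accepts any key value
os.environ.setdefault('OPENAI_API_BASE', 'http://localhost:8000/v1')  # placeholder endpoint

model = VLLM_OPENAI(path='model', max_seq_len=2048)
print(model.generate(['Say hello in one sentence.'], max_out_len=32))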