From 001e77fea236276aa8018b34cd23076145ab1672 Mon Sep 17 00:00:00 2001
From: bittersweet1999 <148421775+bittersweet1999@users.noreply.github.com>
Date: Wed, 28 Feb 2024 19:38:34 +0800
Subject: [PATCH] [Feature] add support for gemini (#931)

* add gemini

* add gemini

* add gemini
---
 .../alignbench_judgeby_critiquellm.py |   4 +-
 configs/models/gemini/gemini_pro.py   |  22 ++
 opencompass/models/__init__.py        |   1 +
 opencompass/models/gemini_api.py      | 257 ++++++++++++++++++
 4 files changed, 282 insertions(+), 2 deletions(-)
 create mode 100644 configs/models/gemini/gemini_pro.py
 create mode 100644 opencompass/models/gemini_api.py

diff --git a/configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py b/configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py
index df710138..522a36eb 100644
--- a/configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py
+++ b/configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py
@@ -14,8 +14,8 @@ subjective_all_sets = [
 ]
 data_path ="data/subjective/alignment_bench"
 
-alignment_bench_config_path = "data/subjective/alignment_bench/"
-alignment_bench_config_name = 'config/multi-dimension'
+alignment_bench_config_path = "data/subjective/alignment_bench/config"
+alignment_bench_config_name = 'multi-dimension'
 
 subjective_datasets = []
 
diff --git a/configs/models/gemini/gemini_pro.py b/configs/models/gemini/gemini_pro.py
new file mode 100644
index 00000000..d4861540
--- /dev/null
+++ b/configs/models/gemini/gemini_pro.py
@@ -0,0 +1,22 @@
+from opencompass.models import Gemini
+
+
+api_meta_template = dict(
+    round=[
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=True),
+    ],
+)
+
+models = [
+    dict(abbr='gemini',
+         type=Gemini,
+         path='gemini-pro',
+         key='your keys',  # Fill in your Gemini API key; it is sent as a URL query parameter
+         meta_template=api_meta_template,
+         query_per_second=16,
+         max_out_len=100,
+         max_seq_len=2048,
+         batch_size=1,
+         temperature=1,)
+]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 8f3f26e5..ab0f28b3 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -7,6 +7,7 @@ from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
 from .bytedance_api import ByteDance  # noqa: F401
 from .claude_api import Claude  # noqa: F401
+from .gemini_api import Gemini, GeminiAllesAPIN  # noqa: F401, F403
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
diff --git a/opencompass/models/gemini_api.py b/opencompass/models/gemini_api.py
new file mode 100644
index 00000000..bbb2c427
--- /dev/null
+++ b/opencompass/models/gemini_api.py
@@ -0,0 +1,257 @@
+# flake8: noqa: E501
+import json
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class Gemini(BaseAPIModel):
+    """Model wrapper around Gemini models.
+
+    Documentation:
+
+    Args:
+        path (str): The name of the Gemini model, e.g. `gemini-pro`.
+        key (str): Authorization key.
+        query_per_second (int): The maximum number of queries allowed
+            per second between two consecutive API calls. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case there is a requirement to
+            inject or wrap any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        key: str,
+        path: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+        temperature: float = 1.0,
+        top_p: float = 0.8,
+        top_k: float = 10.0,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        # Use the requested model name rather than a hard-coded one.
+        self.url = f'https://generativelanguage.googleapis.com/v1beta/models/{path}:generateContent?key={key}'
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.headers = {
+            'content-type': 'application/json',
+        }
+
+    def generate(
+        self,
+        inputs: List[Union[str, PromptList]],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[Union[str, PromptList]]): A list of strings or
+                PromptDicts. The PromptDict should be organized in
+                OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: Union[str, PromptList],
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (Union[str, PromptList]): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'parts': [{'text': input}]}]
+        else:
+            messages = []
+            # The API has no dedicated system role, so the system prompt
+            # is prepended to the text of every turn instead.
+            system_prompt = None
+            for item in input:
+                if item['role'] == 'SYSTEM':
+                    system_prompt = item['prompt']
+            for item in input:
+                if system_prompt is not None:
+                    msg = {
+                        'parts': [{
+                            'text': system_prompt + '\n' + item['prompt']
+                        }]
+                    }
+                else:
+                    msg = {'parts': [{'text': item['prompt']}]}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                    messages.append(msg)
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'model'
+                    messages.append(msg)
+                elif item['role'] == 'SYSTEM':
+                    pass
+
+            # The model can only answer when the final message comes from
+            # the user or system side, which may not be the case when an
+            # agent is involved.
+            assert msg['role'] in ['user', 'system']
+
+        data = {
+            'model': self.path,
+            'contents': messages,
+            'safetySettings': [
+                {
+                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                    'threshold': 'BLOCK_NONE'
+                },
+                {
+                    'category': 'HARM_CATEGORY_HATE_SPEECH',
+                    'threshold': 'BLOCK_NONE'
+                },
+                {
+                    'category': 'HARM_CATEGORY_HARASSMENT',
+                    'threshold': 'BLOCK_NONE'
+                },
+                {
+                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+                    'threshold': 'BLOCK_NONE'
+                },
+            ],
+            'generationConfig': {
+                'candidateCount': 1,
+                'temperature': self.temperature,
+                'maxOutputTokens': max_out_len,
+                'topP': self.top_p,
+                'topK': self.top_k
+            }
+        }
+
+        for _ in range(self.retry):
+            self.wait()
+            raw_response = requests.post(self.url,
+                                         headers=self.headers,
+                                         data=json.dumps(data))
+            try:
+                response = raw_response.json()
+            except requests.JSONDecodeError:
+                self.logger.error('JsonDecode error, got %s',
+                                  str(raw_response.content))
+                time.sleep(1)
+                continue
+            if raw_response.status_code == 200:
+                # The direct Gemini endpoint returns 'candidates' at the
+                # top level, while the AllesAPIN gateway wraps them in a
+                # 'body' field.
+                body = response.get('body', response)
+                if 'candidates' not in body:
+                    self.logger.error(response)
+                elif 'content' not in body['candidates'][0]:
+                    # The answer was blocked by the safety filters.
+                    return "Due to Google's restrictive policies, I am unable to respond to this question."
+                else:
+                    return body['candidates'][0]['content']['parts'][0][
+                        'text'].strip()
+            self.logger.error(response.get('msg'))
+            self.logger.error(response)
+            time.sleep(1)
+
+        raise RuntimeError('API call failed.')
+
+
+class GeminiAllesAPIN(Gemini):
+    """Model wrapper around Gemini models served through AllesAPIN.
+
+    Documentation:
+
+    Args:
+        path (str): The name of the Gemini model, e.g. `gemini-pro`.
+        key (str): Authorization key.
+        query_per_second (int): The maximum number of queries allowed
+            per second between two consecutive API calls. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case there is a requirement to
+            inject or wrap any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        key: str,
+        url: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+        temperature: float = 1.0,
+        top_p: float = 0.8,
+        top_k: float = 10.0,
+    ):
+        super().__init__(key=key,
+                         path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry,
+                         temperature=temperature,
+                         top_p=top_p,
+                         top_k=top_k)
+        # Replace the url and headers with the AllesAPIN ones.
+        self.url = url
+        self.headers = {
+            'alles-apin-token': key,
+            'content-type': 'application/json',
+        }
+
+    def generate(
+        self,
+        inputs: List[Union[str, PromptList]],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[Union[str, PromptList]]): A list of strings or
+                PromptDicts. The PromptDict should be organized in
+                OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        return super().generate(inputs, max_out_len)
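
Below is a minimal smoke-test sketch (not part of the patch) showing how the new wrapper can be driven directly from Python. The GEMINI_API_KEY environment-variable name is this note's own convention, not something the wrapper reads; the class takes the key as a plain constructor argument, exactly as in gemini_pro.py above.

import os

from opencompass.models import Gemini

# Hypothetical env var; the wrapper itself never reads the environment.
model = Gemini(key=os.environ['GEMINI_API_KEY'],
               path='gemini-pro',
               query_per_second=2,
               max_seq_len=2048,
               temperature=1.0)

# generate() maps a list of prompts to a list of completions.
print(model.generate(['Say hello in one short sentence.'],
                     max_out_len=64)[0])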
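
For reviewers, a sketch of the two response shapes the retry loop in _generate tries to accommodate. The direct-endpoint fields follow Google's generateContent response; the 'msg'/'body' envelope is the AllesAPIN wrapping this patch assumes.

# Direct generativelanguage.googleapis.com response (sketch).
direct = {
    'candidates': [{
        'content': {
            'parts': [{
                'text': '...model output...'
            }]
        }
    }]
}

# AllesAPIN gateway envelope assumed by this patch (sketch).
wrapped = {'msg': 'ok', 'body': direct}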