[Enhancement] Skip invalid keys to avoid requesting API (#184)

* Skip invalid keys to avoid requesting API

* get expected key

* print warning info
This commit is contained in:
Zaida Zhou 2023-08-10 18:41:43 +08:00 committed by GitHub
parent 0406e4e7ed
commit f256abffd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -73,6 +73,11 @@ class OpenAI(BaseAPIModel):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
else: else:
self.keys = key self.keys = key
# record invalid keys and skip them when requesting API
# - keys have insufficient_quota
self.invalid_keys = set()
self.key_ctr = 0 self.key_ctr = 0
if isinstance(org, str): if isinstance(org, str):
self.orgs = [org] self.orgs = [org]
@ -164,15 +169,27 @@ class OpenAI(BaseAPIModel):
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.wait() self.wait()
if hasattr(self, 'keys'):
with Lock(): with Lock():
if len(self.invalid_keys) == len(self.keys):
raise RuntimeError('All keys have insufficient quota.')
# find the next valid key
while True:
self.key_ctr += 1 self.key_ctr += 1
if self.key_ctr == len(self.keys): if self.key_ctr == len(self.keys):
self.key_ctr = 0 self.key_ctr = 0
header = {
'Authorization': f'Bearer {self.keys[self.key_ctr]}', if self.keys[self.key_ctr] not in self.invalid_keys:
'content-type': 'application/json', break
}
key = self.keys[self.key_ctr]
header = {
'Authorization': f'Bearer {key}',
'content-type': 'application/json',
}
if self.orgs: if self.orgs:
with Lock(): with Lock():
self.org_ctr += 1 self.org_ctr += 1
@ -208,6 +225,11 @@ class OpenAI(BaseAPIModel):
if response['error']['code'] == 'rate_limit_exceeded': if response['error']['code'] == 'rate_limit_exceeded':
time.sleep(1) time.sleep(1)
continue continue
elif response['error']['code'] == 'insufficient_quota':
self.invalid_keys.add(key)
self.logger.warn(f'insufficient_quota key: {key}')
continue
self.logger.error('Find error message in response: ', self.logger.error('Find error message in response: ',
str(response['error'])) str(response['error']))
max_num_retries += 1 max_num_retries += 1