Mirror of https://github.com/open-compass/opencompass.git (last synced 2025-05-30 16:03:24 +08:00)
[Fix] Fix alpacaeval while adding caching path (#1139)
* fix alpacaeval
* fix alpacaeval
This commit is contained in:
parent 19d7e630d6
commit 833a35140b
@@ -74,4 +74,3 @@ eval = dict(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
work_dir = 'outputs/alpaca/'
|
work_dir = 'outputs/alpaca/'
|
||||||
|
|
@@ -28,12 +28,14 @@ class PredictionMerger:
|
|||||||
self.model_cfg, self.dataset_cfg,
|
self.model_cfg, self.dataset_cfg,
|
||||||
osp.join(self.work_dir, 'predictions'))
|
osp.join(self.work_dir, 'predictions'))
|
||||||
root, ext = osp.splitext(filename)
|
root, ext = osp.splitext(filename)
|
||||||
|
alpaca_format_filename = root + '_alpaca' + ext
|
||||||
partial_filename = root + '_0' + ext
|
partial_filename = root + '_0' + ext
|
||||||
|
|
||||||
if osp.exists(osp.realpath(filename)):
|
if osp.exists(osp.realpath(alpaca_format_filename)):
|
||||||
return
|
return
|
||||||
|
|
||||||
if not osp.exists(osp.realpath(partial_filename)):
|
if not osp.exists(osp.realpath(partial_filename)) and not osp.exists(
|
||||||
|
osp.realpath(filename)):
|
||||||
print(f'{filename} not found')
|
print(f'{filename} not found')
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -67,8 +69,8 @@ class PredictionMerger:
|
|||||||
data_format[idx]['output'] = preds[str(idx)]['prediction']
|
data_format[idx]['output'] = preds[str(idx)]['prediction']
|
||||||
data_format[idx]['generator'] = self.model_cfg['abbr']
|
data_format[idx]['generator'] = self.model_cfg['abbr']
|
||||||
|
|
||||||
print(f'Merge {partial_filenames} to {filename}')
|
print(f'Convert to {alpaca_format_filename}')
|
||||||
with open(filename, 'w', encoding='utf-8') as f:
|
with open(alpaca_format_filename, 'w', encoding='utf-8') as f:
|
||||||
json.dump(data_format, f, indent=4, ensure_ascii=False)
|
json.dump(data_format, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -107,6 +109,7 @@ class AlpacaEvalTask(BaseTask):
|
|||||||
# script_path = __file__
|
# script_path = __file__
|
||||||
alpaca_cfg = self.judge_cfg.get('config', None)
|
alpaca_cfg = self.judge_cfg.get('config', None)
|
||||||
api_key = self.judge_cfg.get('key', None)
|
api_key = self.judge_cfg.get('key', None)
|
||||||
|
base_url = self.judge_cfg.get('base_url', None)
|
||||||
assert alpaca_cfg is not None
|
assert alpaca_cfg is not None
|
||||||
all_cfg = Config.fromfile(cfg_path)
|
all_cfg = Config.fromfile(cfg_path)
|
||||||
model_cfg = all_cfg['models']
|
model_cfg = all_cfg['models']
|
||||||
@@ -120,7 +123,12 @@ class AlpacaEvalTask(BaseTask):
|
|||||||
}).run()
|
}).run()
|
||||||
filename = get_infer_output_path(m_cfg, dataset_cfg,
|
filename = get_infer_output_path(m_cfg, dataset_cfg,
|
||||||
osp.join(work_dir, 'predictions'))
|
osp.join(work_dir, 'predictions'))
|
||||||
|
root, ext = osp.splitext(filename)
|
||||||
|
alpaca_format_filename = root + '_alpaca' + ext
|
||||||
output_path = osp.join(work_dir, 'results', m_cfg['abbr'])
|
output_path = osp.join(work_dir, 'results', m_cfg['abbr'])
|
||||||
|
if not osp.exists(output_path):
|
||||||
|
os.makedirs(output_path)
|
||||||
|
caching_path = osp.join(output_path, 'tmp_annotations.json')
|
||||||
command = ''
|
command = ''
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
command += f'export OPENAI_API_KEY={api_key}; '
|
command += f'export OPENAI_API_KEY={api_key}; '
|
||||||
@@ -128,7 +136,9 @@ class AlpacaEvalTask(BaseTask):
|
|||||||
api_key = os.environ.get('OPENAI_API_KEY', '').split(',')[0]
|
api_key = os.environ.get('OPENAI_API_KEY', '').split(',')[0]
|
||||||
if api_key:
|
if api_key:
|
||||||
command += f'export OPENAI_API_KEY={api_key}; '
|
command += f'export OPENAI_API_KEY={api_key}; '
|
||||||
command += f'alpaca_eval --model_outputs {filename} --annotators_config {alpaca_cfg} --output_path {output_path}'
|
if base_url is not None:
|
||||||
|
command += f'export OPENAI_BASE_URL={base_url}; '
|
||||||
|
command += f'alpaca_eval --model_outputs {alpaca_format_filename} --annotators_config {alpaca_cfg} --output_path {output_path} --caching_path {caching_path};'
|
||||||
return template.format(task_cmd=command)
|
return template.format(task_cmd=command)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user