[Fix] fix summarizer (#1217)

* fix summarizer
bittersweet1999 2024-05-31 11:40:47 +08:00 committed by GitHub
parent a77b8a5cec
commit 7c381e5be8
3 changed files with 30 additions and 9 deletions

@@ -72,8 +72,8 @@ judge_models = [dict(
     key='',
     meta_template=api_meta_template,
     query_per_second=1,
-    max_out_len=5120,
-    max_seq_len=9216,
+    max_out_len=4096,
+    max_seq_len=8192,
     batch_size=10,
     retry=10,
     temperature = 0,
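
This hunk tightens the judge model's generation and context limits, presumably so that prompt plus completion fit a standard 8k context window. For orientation, a minimal sketch of a complete judge_models entry with the new limits; the abbr, type, and path values are placeholder assumptions, not part of this diff:

    from opencompass.models import OpenAI  # assumed API wrapper class

    judge_models = [dict(
        abbr='judge-model-alias',         # hypothetical alias
        type=OpenAI,
        path='gpt-4',                     # hypothetical judge model name
        key='',                           # API key goes here
        meta_template=api_meta_template,  # defined elsewhere in the config
        query_per_second=1,
        max_out_len=4096,                 # cap on generated tokens
        max_seq_len=8192,                 # total window: prompt + completion
        batch_size=10,
        retry=10,
        temperature=0,
    )]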

@@ -133,6 +133,27 @@ def get_win_rate_column(df, column, baseline='gpt4-0314'):
     return win_rate_table[baseline].fillna(0.5).apply(lambda x: round(x * 100, 2))
 
+
+def load_model_preds(filename):
+    root, ext = osp.splitext(filename)
+    partial_filename = root + '_0' + ext
+    if osp.exists(osp.realpath(filename)):
+        preds = mmengine.load(filename)
+        pred_strs = [
+            preds[str(i)]['prediction'] for i in range(len(preds))
+        ]
+    else:
+        filename = partial_filename
+        pred_strs = []
+        i = 1
+        while osp.exists(osp.realpath(filename)):
+            preds = mmengine.load(filename)
+            filename = root + f'_{i}' + ext
+            i += 1
+            pred_strs += [
+                preds[str(i)]['prediction'] for i in range(len(preds))
+            ]
+    return pred_strs
 
 def get_battles_from_judgment(dataset, subdir_path, post_process, WEIGHT=3):
     arena_hard_battles = pd.DataFrame()
     dataset_abbr = dataset_abbr_from_cfg(dataset)
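
The new load_model_preds helper covers both ways predictions can be laid out on disk: a single JSON file for an unpartitioned run, or numbered shards when inference was split across tasks, read in order until a shard is missing. A hedged usage sketch; the paths and record shape below are hypothetical:

    # Single-file layout:
    #   predictions/my_model/arenahard.json
    #   -> {"0": {"prediction": "..."}, "1": {"prediction": "..."}, ...}
    # Partitioned layout:
    #   predictions/my_model/arenahard_0.json, arenahard_1.json, ...
    pred_strs = load_model_preds('predictions/my_model/arenahard.json')
    # -> flat list of prediction strings, concatenated across shards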
@@ -274,12 +295,12 @@ class ArenaHardSummarizer:
             if model == 'gpt4-0314':
                 stats.at[i, 'avg_tokens'] = 423
             else:
-                with open(os.path.join(output_dir.split('summary')[0], 'predictions', model, dataset_abbr+'.json'), 'r') as f:
-                    model_preds = json.load(f)
-                pred_length = 0
-                for k, v in model_preds.items():
-                    pred_length += len(tiktoken.encoding_for_model('gpt-3.5-turbo').encode(v['prediction']))
-                pred_length /= len(model_preds)
+                file_name = os.path.join(output_dir.split('summary')[0], 'predictions', model, dataset_abbr+'.json')
+                model_preds = load_model_preds(file_name)
+                pred_length = 0
+                for model_pred in model_preds:
+                    pred_length += len(tiktoken.encoding_for_model('gpt-3.5-turbo').encode(model_pred, disallowed_special=()))
+                pred_length /= len(model_preds)
                 stats.at[i, 'avg_tokens'] = pred_length
             stats.at[i, 'results'] = bootstrap_elo_lu[model].tolist()
         stats.sort_values(by='model', inplace=True)
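
Two things change in the average-token computation: predictions now come through load_model_preds, so partitioned runs no longer break the summarizer, and encode gains disallowed_special=(). By default tiktoken's encode raises a ValueError when the input contains special-token text such as '<|endoftext|>', which model output can legitimately contain; passing an empty disallowed set makes such text encode as ordinary characters. A quick illustration:

    import tiktoken

    enc = tiktoken.encoding_for_model('gpt-3.5-turbo')
    text = 'final answer <|endoftext|>'
    # enc.encode(text) would raise ValueError here
    ids = enc.encode(text, disallowed_special=())  # special text treated as plain text
    print(len(ids))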

@@ -92,7 +92,7 @@ def get_config_from_arg(args) -> Config:
             config['eval']['partitioner']['compare_models'] = change_accelerator(config['eval']['partitioner']['compare_models'], args.accelerator)
         if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None:
             config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator)
-        if config.get('judge_models', {}) is not None:
+        if config.get('judge_models') is not None:
             config['judge_models'] = change_accelerator(config['judge_models'], args.accelerator)
         return config
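
The final hunk fixes a guard that was always true: dict.get('judge_models', {}) returns the {} default when the key is absent, and {} is not None evaluates to True, so the old code entered the branch and then hit a KeyError on config['judge_models']. A minimal illustration:

    cfg = {}  # a config with no 'judge_models' key
    print(cfg.get('judge_models', {}) is not None)  # True  -> old guard enters the branch
    print(cfg.get('judge_models') is not None)      # False -> new guard skips it
    # with the old guard, cfg['judge_models'] would then raise KeyError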