Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Script] Add scripts to evaluate MMBench (#161)
* update
* update
* Update README.md

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* refine
* update default
* update CN README

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
parent 1bab316624
commit 6ca2be6626
@@ -27,7 +27,9 @@ Just like a compass guides us on our journey, OpenCompass will guide you through

## News

- **\[2023.07.27\]** We have supported [CMMLU](https://github.com/haonan-li/CMMLU)! More datasets are welcomed to join OpenCompass. 🔥🔥🔥.
- **\[2023.08.07\]** We have added a [script](tools/eval_mmbench.py) for users to evaluate the inference results of [MMBench](https://opencompass.org.cn/MMBench)-dev. 🔥🔥🔥.
- **\[2023.08.05\]** We have supported [GPT-4](https://openai.com/gpt-4) and [Qwen-7B](https://github.com/QwenLM/Qwen-7B)! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass. 🔥🔥🔥.
- **\[2023.07.27\]** We have supported [CMMLU](https://github.com/haonan-li/CMMLU)! More datasets are welcome to join OpenCompass. 🔥🔥🔥.
- **\[2023.07.21\]** Performances of Llama-2 are available in [OpenCompass leaderboard](https://opencompass.org.cn/leaderboard-llm)! 🔥🔥🔥.
- **\[2023.07.19\]** We have supported [Llama-2](https://ai.meta.com/llama/)! Its performance report will be available soon. \[[Doc](./docs/en/get_started.md#Installation)\] 🔥🔥🔥.
- **\[2023.07.13\]** We release [MMBench](https://opencompass.org.cn/MMBench), a meticulously curated dataset to comprehensively evaluate different abilities of multimodality models 🔥🔥🔥.
@@ -27,6 +27,8 @@

## News

- **\[2023.08.07\]** We have added an [MMBench evaluation script](tools/eval_mmbench.py) so that users can obtain the test results of [MMBench](https://opencompass.org.cn/MMBench)-dev on their own. 🔥🔥🔥.
- **\[2023.08.05\]** Evaluation results for [GPT-4](https://openai.com/gpt-4) and [Qwen-7B](https://github.com/QwenLM/Qwen-7B) are now available on the OpenCompass [LLM leaderboard](https://opencompass.org.cn/leaderboard-llm)! 🔥🔥🔥.
- **\[2023.07.27\]** We have supported [CMMLU](https://github.com/haonan-li/CMMLU)! More datasets are welcome to join OpenCompass. 🔥🔥🔥.
- **\[2023.07.21\]** Evaluation results for Llama-2 are now available on the OpenCompass [LLM leaderboard](https://opencompass.org.cn/leaderboard-llm)! 🔥🔥🔥.
- **\[2023.07.19\]** We have supported [Llama-2](https://ai.meta.com/llama/)! Its evaluation results will be released soon. \[[Doc](./docs/zh_cn/get_started.md#安装)\] 🔥🔥🔥.
tools/eval_mmbench.py (new file, 403 lines)
@@ -0,0 +1,403 @@
# Usage: python eval_mmbench.py mmbench_dev_inference_result.xlsx
import argparse
import json
import os.path as osp
import pickle
import random as rd
import string
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm

from opencompass.models import OpenAI

fout = None


# Utils
def double_log(msg, fout=None):
    print(msg)
    if fout is not None:
        fout.write(str(msg) + '\n')
        fout.flush()


def dump(data, f):

    def dump_pkl(data, pth):
        pickle.dump(data, open(pth, 'wb'))

    def dump_json(data, pth):
        json.dump(data, open(pth, 'w'))

    def dump_jsonl(data, f):
        lines = [json.dumps(x, ensure_ascii=False) for x in data]
        with open(f, 'w', encoding='utf8') as fout:
            fout.write('\n'.join(lines))

    def dump_xlsx(data, f):
        data.to_excel(f, index=False)

    def dump_csv(data, f):
        data.to_csv(f, index=False)

    def dump_tsv(data, f):
        data.to_csv(f, sep='\t', index=False)

    handlers = dict(pkl=dump_pkl,
                    json=dump_json,
                    jsonl=dump_jsonl,
                    xlsx=dump_xlsx,
                    csv=dump_csv,
                    tsv=dump_tsv)
    suffix = f.split('.')[-1]
    return handlers[suffix](data, f)


def load(f):

    def load_pkl(pth):
        return pickle.load(open(pth, 'rb'))

    def load_json(pth):
        return json.load(open(pth, 'r', encoding='utf-8'))

    def load_jsonl(f):
        lines = open(f, encoding='utf-8').readlines()
        lines = [x.strip() for x in lines]
        if lines[-1] == '':
            lines = lines[:-1]
        data = [json.loads(x) for x in lines]
        return data

    def load_xlsx(f):
        return pd.read_excel(f)

    def load_csv(f):
        return pd.read_csv(f)

    def load_tsv(f):
        return pd.read_csv(f, sep='\t')

    handlers = dict(pkl=load_pkl,
                    json=load_json,
                    jsonl=load_jsonl,
                    xlsx=load_xlsx,
                    csv=load_csv,
                    tsv=load_tsv)
    suffix = f.split('.')[-1]
    return handlers[suffix](f)
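
# Note: dump() and load() dispatch purely on the file suffix, so for example
# load('data/mmbench_dev_20230712.tsv') returns a DataFrame via
# pd.read_csv(sep='\t'), while dump(result, 'xxx_result.pkl') pickles a dict.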


# Accuracy Report
def report_acc(df, group='category'):
    assert 'split' in df
    assert group in [None, 'category', 'l2-category']

    res = defaultdict(list)
    res['split'] = ['full', 'dev', 'test']
    if group is None:
        res['overall'] = [
            np.mean(df['hit']),
            np.mean(df[df['split'] == 'dev']['hit']),
            np.mean(df[df['split'] == 'test']['hit'])
        ]
        return pd.DataFrame(res)

    elif group in df:
        abilities = list(set(df[group]))
        abilities.sort()
        for ab in abilities:
            sub_df = df[df[group] == ab]
            res[ab] = [
                np.mean(sub_df['hit']),
                np.mean(sub_df[sub_df['split'] == 'dev']['hit']),
                np.mean(sub_df[sub_df['split'] == 'test']['hit'])
            ]
        return pd.DataFrame(res)


# Prompt Building
def build_option_str(option_list):
    chars = string.ascii_uppercase
    s = 'There are several options: \n'
    for c, opt in zip(chars, option_list):
        if not pd.isna(opt):
            s += f'{c}. {opt}\n'
        else:
            return s
    return s


def extract_options(item):
    options = []
    for c in 'ABCD':
        if c in item and not pd.isna(item[c]):
            options.append(item[c])
        else:
            return options
    return options


def build_choices(item):
    ret = {}
    for ch in 'ABCD':
        if not pd.isna(item[ch]):
            ret[ch] = item[ch]
    return ret


def build_prompt(question, options, prediction):
    tmpl = (
        'You are an AI assistant who will help me to match an answer '
        'with several options of a single-choice question. '
        'You are provided with a question, several options, and an answer, '
        'and you need to find which option is most similar to the answer. '
        'If the meaning of all options are significantly different '
        'from the answer, output E. '
        'Your should output a single uppercase character in A, B, C, D '
        '(if they are valid options), and E. \n'
        'Example 1: \n'
        'Question: What is the main object in image?\nOptions: A. teddy bear '
        'B. rabbit C. cat D. dog\nAnswer: a cute teddy bear\nYour output: A\n'
        'Example 2: \n'
        'Question: What is the main object in image?\nOptions: A. teddy bear '
        'B. rabbit C. cat D. dog\nAnswer: Spider\nYour output: E\n'
        'Example 3: \n'
        'Question: {}?\nOptions: {}\nAnswer: {}\nYour output: ')
    return tmpl.format(question, options, prediction)
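
# Note: build_prompt() asks the OpenAI model to act as an answer matcher:
# given the question, the lettered options, and the free-form prediction, it
# should reply with a single letter (A-D, or E when nothing matches). The
# few-shot examples in the template pin down the expected output format.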


# Prefetch Answers
def can_infer_option(answer, num_choice=5):
    choices = string.ascii_uppercase[:num_choice]
    if 'Failed to obtain answer via API' in answer:
        return False

    def count(splits, choices='ABCD', prefix='', suffix=''):
        cnt = 0
        for c in choices:
            if prefix + c + suffix in splits:
                cnt += 1
        return cnt

    splits = [x.strip() for x in answer.split()]
    if count(splits, choices) == 1:
        for ch in choices:
            if 'A' in splits and len(splits) > 3:
                double_log(
                    f'A might be a quantifier in the string: {answer}. ', fout)
                break
            if ch in splits:
                return ch
    tups = [('', '.'), ('', ','), ('', ':'), ('', ')'), ('', ').'), ('(', ')'),
            ('(', ').'), (':', ''), (':', ','), (':', '.'), (':', ')'),
            (':', ').')]
    for tup in tups:
        if count(splits, choices, prefix=tup[0], suffix=tup[1]) == 1:
            for ch in choices:
                if tup[0] + ch + tup[1] in splits:
                    return ch
    return False


def can_infer_text(answer, choices):
    answer = answer.lower()
    assert isinstance(choices, dict)
    for k in choices:
        assert k in 'ABCD'
        choices[k] = str(choices[k]).lower()
    cands = []
    for k in choices:
        if choices[k] in answer:
            cands.append(k)
    if len(cands) == 1:
        return cands[0]
    return False


def can_infer(answer, choices):
    copt = can_infer_option(answer)
    return copt if copt else can_infer_text(answer, choices)


def prefetch_answer(item):
    choices = build_choices(item)
    return can_infer(item['prediction'], choices)
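
# Note: answers are matched in two stages. prefetch_answer() first tries to
# read an option letter straight from the raw prediction with the rule-based
# can_infer_option() / can_infer_text() helpers; only when that fails does the
# evaluator fall back to querying GPT via extract_answer_from_item() below.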


# Extract answer from a single record
def extract_answer_from_item(model, item):
    # It will return: (pred, raw, llm_time)
    options = extract_options(item)
    option_str = build_option_str(options)

    prompt = build_prompt(item['question'], option_str, item['prediction'])
    retry = 3
    choices = build_choices(item)

    ret = can_infer(item['prediction'], choices)
    if ret:
        return ret, item['prediction']

    while retry:
        ans = model.generate([prompt])[0]
        if 'Failed to obtain answer via API' in ans:
            msg = 'GPT API failed to answer. '
            double_log(msg, fout)
            retry -= 1
        else:
            ret = can_infer(ans, choices)
            if ret:
                return ret, ans
            else:
                double_log(
                    f'GPT output includes 0 / >1 letter in "ABCD": {ans}',
                    fout)
                retry -= 1

    if retry == 0:
        num_options = sum([ch in item for ch in 'ABCD'])
        if num_options >= 2:
            chars = string.ascii_uppercase[:num_options]
            chars = chars + 'E'
            num_options += 1
            tmp = rd.randint(0, num_options - 1)
            return chars[
                tmp], 'Failed to predict, thus randomly generate one. '
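
# Note: when all three GPT retries fail and the item has at least two valid
# options, extract_answer_from_item() falls back to a random letter among the
# valid options plus 'E', so occasional API failures do not abort the run.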


# Extract answer from multiple rolling records
def eval_sub_data(model, sub_data, answer_map):
    lt = len(sub_data)
    GT, PRED = [], []
    for i in range(lt):
        item = sub_data.iloc[i]
        idx = item['index']
        GT.append(answer_map[idx])
        PRED.append(prefetch_answer(item))
        if PRED[-1] and (GT[-1] != PRED[-1]):
            return 0

    for i in range(lt):
        if PRED[i]:
            continue
        else:
            ret, _ = extract_answer_from_item(model, sub_data.iloc[i])
            PRED[i] = ret
            if PRED[i] != GT[i]:
                return 0
    return 1
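
# Note: eval_sub_data() scores a question in an all-or-nothing fashion: every
# rolling record in sub_data (the same question with its options reordered, as
# in MMBench's circular evaluation) must be answered correctly for the
# question to count as a hit.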


# Evaluate Results
def eval_result(eval_file, eval_method, meta_file):
    rd.seed(2680)
    assert eval_method == 'openai'
    # Set a large retry number to avoid failure
    model = OpenAI('gpt-3.5-turbo-0613', retry=99)

    double_log(f'Evaluating {eval_file}', fout)

    result_file = eval_file.replace('.xlsx', f'_{eval_method}_result.pkl')
    result = {}
    if osp.exists(result_file):
        result = load(result_file)

    data = load(eval_file)
    data = data.sort_values(by='index')
    data['prediction'] = [str(x) for x in data['prediction']]
    for k in data.keys():
        data[k.lower() if k not in 'ABCD' else k] = data.pop(k)

    meta = load(meta_file)

    data_main = data[data['index'] < int(1e6)]
    cate_map = {i: c for i, c in zip(meta['index'], meta['category'])}
    l2_cate_map = {i: c for i, c in zip(meta['index'], meta['l2-category'])}
    split_map = {i: c for i, c in zip(meta['index'], meta['split'])}
    answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}
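
    # Note: records with index < 1e6 are the original questions; the extra
    # rolling copies of question `idx` reuse the same base index offset by a
    # multiple of 1e6, which is why the loop below gathers them with
    # `data['index'] % int(1e6) == idx`.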

    lt = len(data_main)
    hit, tot = 0, 0

    for i in tqdm(range(lt)):
        # Dealing with the normal part
        item_main = data_main.iloc[i]
        idx = item_main['index']

        if idx in result:
            correct = result[idx]
            assert correct in [0, 1]
            hit += correct
            tot += 1
            continue

        sub_data = data[data['index'] % int(1e6) == idx]
        ret = eval_sub_data(model, sub_data, answer_map)
        result[idx] = ret
        hit += ret
        tot += 1

        dump(result, result_file)

        if (i + 1) % 10 == 0:
            double_log((f'Evaluating {eval_file}: {i + 1}/{lt}, '
                        f'Acc: {hit / tot * 100: .2f}%. '), fout)

    dump(data_main, 'tmp.xlsx')
    data_main = load('tmp.xlsx')

    res = load(result_file)
    indices = data_main['index']
    data_main['hit'] = [res[i] for i in indices]
    data_main['split'] = [split_map[i] for i in indices]
    main_idx = data_main['index']
    data_main['category'] = [cate_map[i] for i in main_idx]
    data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]

    # load split
    dump(data_main, eval_file.replace('.xlsx', f'_{eval_method}_result.xlsx'))
    data_main = load(eval_file.replace('.xlsx', f'_{eval_method}_result.xlsx'))

    overall = report_acc(data_main, None)
    dump(overall, eval_file.replace('.xlsx', '_overall.csv'))
    double_log(overall)

    l2 = report_acc(data_main, 'l2-category')
    dump(l2, eval_file.replace('.xlsx', '_l2.csv'))
    double_log(l2)

    leaf = report_acc(data_main, 'category')
    dump(leaf, eval_file.replace('.xlsx', '_leaf.csv'))
    double_log(leaf)

    if fout is not None:
        fout.close()

    return overall, l2, leaf


def parse_args():
    parser = argparse.ArgumentParser(
        description='Evaluate Inference Results of MMBench-DEV SPLIT. ')
    parser.add_argument('result',
                        type=str,
                        help='The path to your inference result. ')
    parser.add_argument('--meta',
                        type=str,
                        default='data/mmbench_dev_20230712.tsv',
                        help=('The path to your meta file (dev). '
                              'Downloaded from MMBench website. '))
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    log_pth = args.result.replace('.xlsx', '_openai_eval.log')
    fout = open(log_pth, 'a')

    acc, l2, leaf = eval_result(args.result, 'openai', args.meta)
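
Per the usage comment at the top of the file, the script is invoked as `python eval_mmbench.py mmbench_dev_inference_result.xlsx`, with `--meta` pointing at the MMBench dev TSV (default `data/mmbench_dev_20230712.tsv`). It writes the graded per-question sheet to `*_openai_result.xlsx` and accuracy tables to `*_overall.csv`, `*_l2.csv`, and `*_leaf.csv`. Below is a minimal sketch of how that graded sheet could be inspected afterwards with pandas; the result file name is a placeholder.

```python
# Sketch: read back the graded MMBench-dev sheet written by eval_result() and
# summarize accuracy per category. The xlsx name below is a placeholder for
# whatever `<result>_openai_result.xlsx` file the script produced.
import pandas as pd

graded = pd.read_excel('mmbench_dev_inference_result_openai_result.xlsx')
print(f"overall accuracy: {graded['hit'].mean():.4f}")
print(graded.groupby('category')['hit'].mean().sort_values())
```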