# Mirror of https://github.com/open-compass/opencompass.git
# Synced 2025-05-30 16:03:24 +08:00
# 78 lines, 2.1 KiB, Python
import csv
|
|
import os
|
|
|
|
import pytest
|
|
|
|
# Root directory searched (recursively, via find_csv_files) for the result CSV.
output_path = "regression_result"
# Column name of the model under test in that CSV.
model = "internlm2-chat-7b-hf"
# Value of the CSV's 'dataset' column whose score is checked.
dataset = "siqa"
|
|
|
|
|
|
@pytest.fixture()
def result_scores():
    """Load the regression result CSV located under ``output_path``.

    Returns:
        The nested ``{column: {dataset: value}}`` mapping built by
        ``read_csv_file``, or ``None`` when ``find_csv_files`` finds no
        CSV file to read.
    """
    csv_path = find_csv_files(output_path)
    if csv_path is not None:
        return read_csv_file(csv_path)
    return None
|
|
|
|
|
|
@pytest.mark.usefixtures('result_scores')
class TestChatScore:
    """Test cases for chat model."""

    def test_model_dataset_score(self, result_scores):
        """Check that ``model`` scores within tolerance of 79.53 on ``dataset``."""
        # The fixture yields None when no result CSV exists; fail with a clear
        # message instead of an AttributeError on ``None.get``.
        assert result_scores is not None, (
            f'no result csv file found under {output_path!r}')
        # A missing model column falls through to None, which assert_score
        # reports as 'value is none' rather than raising AttributeError here.
        result_score = result_scores.get(model, {}).get(dataset)
        assert_score(result_score, 79.53)
|
|
|
|
|
|
def assert_score(score, baseline, tolerance=0.03):
    """Assert that ``score`` lies strictly within ``tolerance`` of ``baseline``.

    Args:
        score: The measured score, either the raw string read from the
            result CSV or a number. ``None`` or ``'-'`` means the score is
            missing.
        baseline: The reference score to compare against.
        tolerance: Allowed relative deviation from ``baseline``. The default
            of 0.03 reproduces the original +/-3% window.

    Raises:
        AssertionError: If the score is missing, cannot be converted to
            ``float``, or is outside the open interval
            ``(baseline * (1 - tolerance), baseline * (1 + tolerance))``.
    """
    # Raise explicitly instead of ``assert False`` so the check still fires
    # when Python runs with -O (which strips assert statements).
    if score is None or score == '-':
        raise AssertionError('value is none')
    try:
        value = float(score)
    except (TypeError, ValueError):
        raise AssertionError(f'{score!r} is not a numeric score') from None
    # sorted() keeps the window well-formed even for a negative baseline.
    lower, upper = sorted(
        (baseline * (1 - tolerance), baseline * (1 + tolerance)))
    # f-strings (not ``score + '...'``) so numeric scores don't raise TypeError.
    if not lower < value < upper:
        raise AssertionError(f'{score} not between {lower} and {upper}')
    print(f'{score} between {lower} and {upper}')
|
|
|
|
|
|
def find_csv_files(directory):
    """Return the path of the single ``.csv`` file under ``directory``.

    The directory tree is walked recursively with ``os.walk``.

    Args:
        directory: Root directory to search. ``os.walk`` yields nothing for a
            missing directory, so that case behaves like an empty one.

    Returns:
        The path of the only ``.csv`` file found, or ``None`` if none exists.

    Raises:
        RuntimeError: If more than one ``.csv`` file is found, because the
            result to check would be ambiguous.
    """
    csv_files = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
        if name.endswith('.csv')
    ]
    if len(csv_files) > 1:
        # Raising a bare string is itself a TypeError in Python 3, which
        # hides the message; raise a real exception listing the candidates.
        raise RuntimeError(
            'have more than 1 result file, please check the result manually: '
            + ', '.join(csv_files))
    return csv_files[0] if csv_files else None
|
|
|
|
|
|
def read_csv_file(file_path):
    """Read a result CSV and pivot it into ``{column: {dataset: value}}``.

    Each row's ``'dataset'`` value becomes the inner key. Every other column
    except ``'version'``, ``'metric'`` and ``'mode'`` becomes an outer key.
    Values are kept as the strings produced by ``csv.DictReader``. If two
    rows share a dataset name, the later row's values win.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The nested mapping described above (empty for a header-less file).
    """
    # Columns that are never treated as per-model score columns.
    excluded = {'dataset', 'version', 'metric', 'mode'}
    result = {}
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are parsed correctly.
    with open(file_path, 'r', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            # Named to avoid shadowing the module-level ``dataset`` constant.
            dataset_name = row.get('dataset')
            for column, value in row.items():
                if column in excluded:
                    continue
                result.setdefault(column, {})[dataset_name] = value
    return result
|