"""Regression check for one model/dataset score in the result CSV.

The single ``.csv`` file found under ``output_path`` is parsed into a
``{model: {dataset: score}}`` mapping, and the score of ``model`` on
``dataset`` must lie strictly within 3% of the recorded baseline.
"""
import csv
import os

import pytest

# Export only the reusable helpers: a star-import of this module from another
# test file must not pull in TestChatScore / result_scores, which pytest would
# otherwise collect and run a second time from that file.
__all__ = ['assert_score', 'find_csv_files', 'read_csv_file']

output_path = 'regression_result'
model = 'internlm2-chat-7b-hf'
dataset = 'siqa'


@pytest.fixture()
def result_scores():
    """Load the scores from the result CSV under ``output_path``.

    Returns:
        The ``{column: {dataset: score_string}}`` dict built by
        ``read_csv_file``, or ``None`` when no ``.csv`` file exists.

    Raises:
        RuntimeError: if more than one ``.csv`` file is found.
    """
    file = find_csv_files(output_path)
    if file is None:
        return None
    return read_csv_file(file)


@pytest.mark.usefixtures('result_scores')
class TestChatScore:
    """Test cases for chat model."""

    def test_model_dataset_score(self, result_scores):
        # Fail with a readable message instead of an AttributeError on None.
        assert result_scores is not None, (
            f'no result csv file found under {output_path}')
        model_scores = result_scores.get(model)
        assert model_scores is not None, (
            f'model {model} not found in result file')
        assert_score(model_scores.get(dataset), 79.53)


def assert_score(score, baseline):
    """Check that ``score`` lies strictly within +/-3% of ``baseline``.

    Args:
        score: The score as read from the CSV (normally a string), a number,
            ``None``, or ``'-'`` (treated as a missing value).
        baseline: The expected score.

    Raises:
        AssertionError: if the score is missing, not numeric, or outside the
            open interval ``(baseline * 0.97, baseline * 1.03)``.
    """
    if score is None or score == '-':
        raise AssertionError('value is none')
    try:
        value = float(score)
    except (TypeError, ValueError):
        raise AssertionError(f'{score!r} is not a number') from None
    lower = baseline * 0.97
    upper = baseline * 1.03
    # Raise explicitly rather than ``assert False`` so the check still fails
    # when Python runs with -O (which strips assert statements).
    if not lower < value < upper:
        raise AssertionError(f'{score} not between {lower} and {upper}')
    print(f'{score} between {lower} and {upper}')


def find_csv_files(directory):
    """Return the path of the single ``.csv`` file under ``directory``.

    The directory tree is searched recursively with ``os.walk``, which by
    default yields nothing (rather than raising) for a missing directory.

    Returns:
        The file path, or ``None`` if no ``.csv`` file is found.

    Raises:
        RuntimeError: if more than one ``.csv`` file is found.
    """
    csv_files = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(directory)
        for name in files
        if name.endswith('.csv')
    ]
    if len(csv_files) > 1:
        raise RuntimeError(
            'have more than 1 result file, please check the result manually: '
            + ', '.join(csv_files))
    return csv_files[0] if csv_files else None


def read_csv_file(file_path):
    """Parse a result CSV into ``{column: {dataset: score}}``.

    Each row is keyed by its ``dataset`` cell. Every other column except
    ``version``, ``metric`` and ``mode`` (presumably one column per model,
    as the test looks scores up by model name) becomes a top-level key whose
    value maps each row's dataset to that column's raw string cell.
    """
    skipped_columns = {'dataset', 'version', 'metric', 'mode'}
    result = {}
    # newline='' is the mode the csv module expects for files it reads.
    with open(file_path, 'r', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            dataset_name = row.get('dataset')
            for key, value in row.items():
                if key in skipped_columns:
                    continue
                result.setdefault(key, {})[dataset_name] = value
    return result