From 9e488c2f8f83fc1b6fefdee9be8b2fe5a93c8f33 Mon Sep 17 00:00:00 2001
From: Jun <1557706594@qq.com>
Date: Tue, 27 May 2025 03:41:21 +0000
Subject: [PATCH] 250527

---
 .../configs/datasets/srbench/srbench_gen.py   |  58 +++++
 opencompass/datasets/srbench.py               | 222 ++++++++++++++++++
 .../srbench/Feynman/FeynmanEquation_23.csv    |  48 ++++
 opencompass/utils/datasets_info.py            |   5 +
 4 files changed, 333 insertions(+)
 create mode 100644 opencompass/configs/datasets/srbench/srbench_gen.py
 create mode 100644 opencompass/datasets/srbench.py
 create mode 100644 opencompass/datasets/srbench/Feynman/FeynmanEquation_23.csv

diff --git a/opencompass/configs/datasets/srbench/srbench_gen.py b/opencompass/configs/datasets/srbench/srbench_gen.py
new file mode 100644
index 00000000..fe7fdf8e
--- /dev/null
+++ b/opencompass/configs/datasets/srbench/srbench_gen.py
@@ -0,0 +1,58 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import (
+    SRbenchDataset,
+    SRbenchDatasetEvaluator,
+)
+
+# The f-string is intentional: {{prompt1}}/{{prompt2}} render to the literal
+# {prompt1}/{prompt2} placeholders that PromptTemplate fills in per sample.
+INFER_TEMPLATE = f'''
+    You will be provided with a set of input-output pairs. Based on these data, infer the mathematical relationship between y and the input variables. Note that the possible mathematical operations include: +, -, *, /, exp, sqrt, sin, arcsin, and constant terms.
+    The input sample data are as follows:
+    {{prompt1}}
+    Based on the above data, please infer the possible formula. Ensure that your inference applies to all the provided data points, and consider both linear and nonlinear combinations.
+    Verify whether your formula applies to the following new data point and adjust it to ensure accuracy:
+    {{prompt2}}
+    Finally, output only the formula string you inferred (e.g. y=x0 * x1), without any additional information.
+    '''
+
+srbench_reader_cfg = dict(input_columns=['prompt1', 'prompt2'], output_column='Formula')
+
+srbench_datasets = []
+
+srbench_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(role='HUMAN', prompt=INFER_TEMPLATE),
+            ]
+        ),
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer),
+)
+
+srbench_eval_cfg = dict(
+    # `path` must sit inside the evaluator cfg so that
+    # SRbenchDatasetEvaluator.__init__ actually receives it.
+    evaluator=dict(type=SRbenchDatasetEvaluator, path='opencompass/srbench'),
+    pred_role='BOT',
+)
+
+srbench_datasets.append(
+    dict(
+        abbr='srbench',
+        type=SRbenchDataset,
+        path='opencompass/srbench',
+        reader_cfg=srbench_reader_cfg,
+        infer_cfg=srbench_infer_cfg,
+        eval_cfg=srbench_eval_cfg,
+    )
+)

diff --git a/opencompass/datasets/srbench.py b/opencompass/datasets/srbench.py
new file mode 100644
index 00000000..f2441df7
--- /dev/null
+++ b/opencompass/datasets/srbench.py
@@ -0,0 +1,222 @@
+import json
+import os
+
+import numpy as np
+import pandas as pd
+import requests
+import sympy as sp
+from datasets import load_dataset
+from sklearn.metrics import r2_score, root_mean_squared_error
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.registry import LOAD_DATASET
+from opencompass.utils import get_data_path
+
+
+@LOAD_DATASET.register_module()
+class SRbenchDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, local_mode=True):
+        # Resolve the configured path (e.g. 'opencompass/srbench') through
+        # DATASETS_MAPPING rather than overwriting it with a placeholder.
+        base_path = get_data_path(path, local_mode=local_mode)
+        formula_csv_path = os.path.join(base_path, 'FeynmanEquation_23.csv')
+        data_files_base_dir = os.path.join(base_path, 'Feynman_with_units')
+        processed_formulas_df = load_dataset('csv', data_files=formula_csv_path)['train']
+        sample_data = []
+        prompt_1_out = []
+        prompt_2_out = []
+        for row in processed_formulas_df:
+            n_var = int(row['n_variables'])
+            data_file_path = os.path.join(data_files_base_dir, str(row['Filename']))
+            full_dataset = np.loadtxt(data_file_path)
+            # Sample 100 points: the first 99 seed the prompt, the last one is
+            # held out as the verification point.
+            rand_idx = np.random.choice(full_dataset.shape[0], 100, replace=False)
+            sampled_data_i = full_dataset[rand_idx]
+            sample_data.append(sampled_data_i.tolist())
+            if n_var == 2:
+                prompt_1 = '\n'.join(f'x0={x1:.4f}, x1={x2:.4f}, y={y:.4f}'
+                                     for x1, x2, y in sampled_data_i[:-1])
+                prompt_2 = (f'x0={sampled_data_i[-1, 0]:.4f}, '
+                            f'x1={sampled_data_i[-1, 1]:.4f}, '
+                            f'y={sampled_data_i[-1, 2]:.4f}')
+            else:
+                prompt_1 = '\n'.join(f'x0={x1:.4f}, x1={x2:.4f}, x2={x3:.4f}, y={y:.4f}'
+                                     for x1, x2, x3, y in sampled_data_i[:-1])
+                prompt_2 = (f'x0={sampled_data_i[-1, 0]:.4f}, '
+                            f'x1={sampled_data_i[-1, 1]:.4f}, '
+                            f'x2={sampled_data_i[-1, 2]:.4f}, '
+                            f'y={sampled_data_i[-1, 3]:.4f}')
+            prompt_1_out.append(prompt_1)
+            prompt_2_out.append(prompt_2)
+        processed_formulas_df = processed_formulas_df.add_column(name='prompt1', column=prompt_1_out)
+        processed_formulas_df = processed_formulas_df.add_column(name='prompt2', column=prompt_2_out)
+        processed_formulas_df = processed_formulas_df.add_column(name='data_samples_list', column=sample_data)
+        processed_formulas_df = processed_formulas_df.rename_column('n_variables', 'n_var')
+        return processed_formulas_df
+
+
+class SRbenchDatasetEvaluator(BaseEvaluator):
+
+    def __init__(self, path: str = 'opencompass/srbench', local_mode: bool = True):
+        super().__init__()
+        # Reload the dataset so the sampled data points are available at scoring time.
+        self.dataset = SRbenchDataset.load(path=path, local_mode=local_mode)
+
+    def _send_request(self, messages, mllm='4o'):
+        URL = 'your_api_url'
+        API_KEY = 'your_api_key'
+        HEADERS = {
+            'Accept': 'application/json',
+            'Authorization': f'Bearer {API_KEY}',
+            'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
+            'Content-Type': 'application/json',
+        }
+        content = None
+        # Retry up to 20 times to ride out transient API failures.
+        for _ in range(20):
+            payload = json.dumps({
+                'model': mllm,
+                'messages': messages,
+                'temperature': 0.6,
+                'max_tokens': 50,
+            })
+            session = requests.Session()
+            session.keep_alive = False
+            response = session.post(URL, headers=HEADERS, data=payload, verify=True)
+            try:
+                content = response.json()['choices'][0]['message']['content']
+                break
+            except (KeyError, IndexError, ValueError):
+                content = None
+        return content
+
+    def parse_formula(self, formula_str, n_var=2):
+        """Compile a 'y = ...' string into a numpy-callable function, or None."""
+        try:
+            if '=' in formula_str:
+                _, expr_str = formula_str.split('=', 1)
+            else:
+                expr_str = formula_str
+            variables = [sp.Symbol(f'x{i}') for i in range(n_var)]
+            expr = sp.sympify(expr_str)
+            return sp.lambdify(variables, expr, modules='numpy')
+        except Exception as e:
+            print(f'[Parse Error] {formula_str}\n{e}')
+            return None
+
+    def is_symbolically_equivalent(self, formula1, formula2, n_var=2):
+        try:
+            expr1 = sp.sympify(formula1.split('=')[1] if '=' in formula1 else formula1)
+            expr2 = sp.sympify(formula2.split('=')[1] if '=' in formula2 else formula2)
+            return sp.simplify(expr1 - expr2) == 0
+        except Exception:
+            return False
+
+    def llm_evaluate(self, inferred_formula, true_formula, mllm='gpt-4o'):
+        content = f'''
+        You are given two mathematical formulas. Your task is to evaluate how structurally similar they are, and return a similarity score between 0 and 1.
+
+        The score should reflect how closely the formulas match in terms of:
+        - Mathematical operations and structure (e.g., same use of +, *, sin, etc.)
+        - Term arrangement and complexity
+        - Overall symbolic expression and intent
+
+        A score of:
+        - 1 means the formulas are structurally identical or mathematically equivalent
+        - Around 0.8-0.9 means they are very similar but not identical
+        - Around 0.5 means moderately similar (e.g., same overall shape but different terms)
+        - Near 0 means structurally unrelated formulas
+
+        Do not consider numerical evaluation or specific input values; judge only the symbolic structure and mathematical form.
+
+        Formulas:
+        Inferred Formula: {inferred_formula}
+        True Formula: {true_formula}
+
+        ONLY RETURN [THE SIMILARITY SCORE]
+        '''
+        messages = [{'role': 'user', 'content': content}]
+        similarity_score = self._send_request(messages, mllm=mllm)
+        if similarity_score is None:
+            return 0.0
+        # Strip decorations the judge model sometimes wraps around the bare
+        # score, e.g. brackets or a trailing emoji.
+        similarity_score = similarity_score.strip()
+        specific_emoji = '😊'
+        if similarity_score.endswith(specific_emoji):
+            similarity_score = similarity_score[:-len(specific_emoji)].rstrip()
+        similarity_score = similarity_score.strip('[]')
+        if similarity_score in ('', '.'):
+            return 0.0
+        return similarity_score
+
+    def llm_translate(self, dirty_formula, mllm='gpt-4o'):
+        content = f'''
+        This is a language model's judgment on a mathematical formula. Please extract the mathematical formula from this judgment and return it:
+        {dirty_formula}
+        Please keep pi written as pi and use x0, x1, x2, ... as the variable names.
+        ONLY RETURN THE FORMULA STRING (not LaTeX).
+        '''
+        messages = [{'role': 'user', 'content': content}]
+        clean_formula = self._send_request(messages, mllm=mllm)
+        return clean_formula
+
+    def score(self, predictions, references) -> dict:
+        metrics_out = {
+            'LLM_Score': None,
+            'RMSE': None,
+            'R2': 0,
+            'Accuracy': False,
+        }
+        result = pd.DataFrame({
+            'GT': pd.Series(dtype=str),
+            'Pred': pd.Series(dtype=str),
+            'Score': pd.Series(dtype=float),
+            'RMSE': pd.Series(dtype=float),
+            'R2': pd.Series(dtype=float),
+            'SymbolicMatch': pd.Series(dtype=bool),
+        })
+
+        for row in range(len(references)):
+            # Fresh per-sample metrics so a failed fit cannot leak values from
+            # the previous row.
+            metrics = {'LLM_Score': None, 'RMSE': np.nan, 'R2': 0.0, 'SymbolicMatch': False}
+            metrics['LLM_Score'] = float(self.llm_evaluate(predictions[row], references[row], mllm='gpt-4o'))
+            n_var = self.dataset[row]['n_var']
+            # Numeric ground truth comes from the data points sampled at load
+            # time: the first n_var columns are inputs, the last column is y.
+            data = np.array(self.dataset[row]['data_samples_list'])
+            x, y_true = data[:, :n_var], data[:, n_var]
+            func = self.parse_formula(predictions[row], n_var=n_var)
+            if func is not None:
+                try:
+                    y_pred = func(*[x[:, i] for i in range(n_var)])
+                    if np.isscalar(y_pred):
+                        y_pred = np.full_like(y_true, y_pred)
+                    metrics['RMSE'] = root_mean_squared_error(y_true, y_pred)
+                    metrics['R2'] = r2_score(y_true, y_pred)
+                except Exception:
+                    pass
+            metrics['SymbolicMatch'] = self.is_symbolically_equivalent(predictions[row], references[row], n_var)
+            result = pd.concat([result, pd.DataFrame([{
+                'GT': references[row],
+                'Pred': predictions[row],
+                'Score': metrics['LLM_Score'],
+                'RMSE': metrics['RMSE'],
+                'R2': metrics['R2'],
+                'SymbolicMatch': bool(metrics['SymbolicMatch']),
+            }])], ignore_index=True)
+
+        if not result.empty:
+            # NaN RMSEs (unparseable predictions) are skipped by .mean().
+            metrics_out = {
+                'LLM_Score': result['Score'].mean(),
+                'RMSE': result['RMSE'].mean(),
+                'R2': result['R2'].mean(),
+                'Accuracy': result['SymbolicMatch'].sum() / len(result),
+            }
+        return metrics_out
diff --git a/opencompass/datasets/srbench/Feynman/FeynmanEquation_23.csv b/opencompass/datasets/srbench/Feynman/FeynmanEquation_23.csv
new file mode 100644
index 00000000..80b68624
--- /dev/null
+++ b/opencompass/datasets/srbench/Feynman/FeynmanEquation_23.csv
@@ -0,0 +1,48 @@
+Formula,Filename,n_variables
+y = exp(-(x1/x0)**2/2) / (sqrt(2*pi)*x0),I.6.2,2
+y = x0 * x1,I.12.1,2
+y = x0 * x1,I.12.5,2
+y = 1/2 * x0 * x1**2,I.14.4,2
+y = x0 / x1,I.25.13,2
+y = arcsin(x0 * sin(x1)),I.26.2,2
+y = x0 / x1,I.29.4,2
+y = (x1 / (2 * pi)) * x0,I.34.27,2
+y = (3/2) * x0 * x1,I.39.1,2
+y = x0 / (4 * pi * x1**2),II.3.24,2
+y = x0 * x1**2 / 2,II.8.31,2
+y = 1 + x0 * x1 / (1 - (x0 * x1 / 3)),II.11.28,2
+y = x0 * x1**2,II.27.18,2
+y = x0 / (2 * (1 + x1)),II.38.14,2
+y = x0 * (x1 / (2 * pi)),III.12.43,2
+y = exp(-((x1 - x2) / x0)**2 / 2) / (sqrt(2 * pi) * x0),I.6.2b,3
+y = x0 / sqrt(1 - x1**2 / x2**2),I.10.7,3
+y = x0*x2/(4*pi*x1*x2**3),I.12.4,3
+y = x0 * x1 * x2,I.14.3,3
+y = (x1 + x2) / (1 + x1 * x2 / x0**2),I.16.6,3
+y = 1 / (1 / x0 + x2 / x1),I.27.6,3
+y = x0 * sin(x2 * x1 / 2)**2 / sin(x1 / 2)**2,I.30.3,3
+y = arcsin(x0 / (x2 * x1)),I.30.5,3
+y = x2 / (1 - x1 / x0),I.34.1,3
+y = (1 + x1/x0) / sqrt(1 - x1**2/x0**2) * x2,I.34.14,3
+y = x0 + x1 + 2 * sqrt(x0 * x1) * cos(x2),I.37.4,3
+y = 1/(x0-1) * x1 * x2,I.39.11,3
+y = x0 * x2 * x1,I.43.31,3
+y = sqrt(x0 * x1 / x2),I.47.23,3
+y = x0 / (4 * pi * x1 * x2),II.4.23,3
+y = (3/5)*x0**2/(4*pi*x1*x2),II.8.7,3
+y = x0 / (x1 * (1 + x2)),II.10.9,3
+y = x0 / sqrt(1 - x1**2 / x2**2),II.13.23,3
+y = x0 * x1 / sqrt(1 - x1**2 / x2**2),II.13.34,3
+y = -x0 * x1 * cos(x2),II.15.4,3
+y = -x0 * x1 * cos(x2),II.15.5,3
+y = sqrt(x0**2/x1**2 - pi**2/x2**2),II.24.17,3
+y = x0 * x1 * x2**2,II.27.16,3
+y = x0 * x1 / (2 * pi * x2),II.34.2a,3
+y = x0 * x1 * x2 / 2,II.34.2,3
+y = x0 * x1 / (4 * pi * x2),II.34.29a,3
+y = 2*x0*x1/(x2/(2*pi)),III.7.38,3
+y = sin(x0 * x1 / (x2 / (2 * pi)))**2,III.8.54,3
+y = 2*x0*(1 - cos(x1*x2)),III.15.12,3
+y = (x0 / (2 * pi))**2 / (2 * x1 * x2**2),III.15.14,3
+y = 2*pi*x0/(x1*x2),III.15.27,3
+y = x0 * (1 + x1 * cos(x2)),III.17.37,3
diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py
index 5048a496..ccbb652d 100644
--- a/opencompass/utils/datasets_info.py
+++ b/opencompass/utils/datasets_info.py
@@ -446,6 +446,11 @@ DATASETS_MAPPING = {
         "hf_id": "",
         "local": "./data/ChemBench4K",
     },
+    "opencompass/srbench": {
+        "ms_id": "",
+        "hf_id": "",
+        "local": "./data/srbench",
+    },
 }
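--
Usage sketch (reviewer note, not part of the diff): once applied, the new
dataset can be pulled into an eval config through OpenCompass's usual
read_base convention. A minimal sketch follows; the model import is a
placeholder, so substitute any model config present in your checkout, and
make sure ./data/srbench holds FeynmanEquation_23.csv plus the
Feynman_with_units directory, as required by the DATASETS_MAPPING entry
above.

    from mmengine.config import read_base

    with read_base():
        # New dataset config added by this patch.
        from opencompass.configs.datasets.srbench.srbench_gen import \
            srbench_datasets
        # Placeholder model config; replace with one available locally.
        from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b_chat import \
            models

    datasets = srbench_datasets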