mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
0530
This commit is contained in:
parent
e227acc1a8
commit
f038ffac17
@ -18,7 +18,7 @@ import sympy as sp
|
||||
class SRbenchDataset(BaseDataset):
|
||||
@staticmethod
|
||||
def load(path: str,local_mode=True):
|
||||
base_path = get_data_path(path,local_mode=local_mode) # Resolve base path if necessary
|
||||
base_path = get_data_path(path,local_mode=local_mode)
|
||||
formula_csv_path = os.path.join(base_path, f'FeynmanEquation_23.csv')
|
||||
data_files_base_dir = os.path.join(base_path, 'Feynman_with_units')
|
||||
dataset = load_dataset('csv', data_files=formula_csv_path)['train']
|
||||
@ -46,6 +46,8 @@ class SRbenchDataset(BaseDataset):
|
||||
else:
|
||||
prompt_1 = '\n'.join([f'x0={x1:.4f}, x1={x2:.4f}, x2={x3:.4f},y={y:.4f}' for x1, x2,x3, y in sampled_data_i[:-1]])
|
||||
prompt_2=f'x0={sampled_data_i[-1, 0]:.4f}, x1={sampled_data_i[-1, 1]:.4f},x3={sampled_data_i[-1, 2]:.4f}, y={sampled_data_i[-1, 3]:.4f}'
|
||||
|
||||
|
||||
prompt_1_out.append(prompt_1)
|
||||
prompt_2_out.append(prompt_2)
|
||||
dataset=dataset.add_column(name="prompt1",column=prompt_1_out)
|
||||
@ -55,13 +57,18 @@ class SRbenchDataset(BaseDataset):
|
||||
return dataset
|
||||
|
||||
def mydataset_postprocess(formula_str):
|
||||
|
||||
# 1. Normalize Unicode math operators and typographic quotes to their ASCII equivalents
|
||||
formula_str = formula_str.replace('×', '*').replace('·', '*').replace('÷', '/')
|
||||
formula_str = formula_str.replace('−', '-').replace('^', '**')
|
||||
formula_str = formula_str.replace('“', '"').replace('”', '"').replace('’', "'")
|
||||
|
||||
# 2. Strip Markdown backticks (`) and $ signs
|
||||
formula_str = formula_str.replace('`', '').replace('$', '').strip()
|
||||
|
||||
# 3. Keep only the first line of the formula (guards against multi-line explanatory output)
|
||||
formula_str = formula_str.split('\n')[0].strip()
|
||||
|
||||
# 4. Use a regex to drop disallowed characters (keeping those of basic math expressions)
|
||||
formula_str = re.sub(r'[^\w\s\+\-\*/\^\=\.\(\)]', '', formula_str)
|
||||
|
||||
# 5. Make sure leading and trailing whitespace is stripped
|
||||
@ -70,7 +77,7 @@ def mydataset_postprocess(formula_str):
|
||||
class SRbenchDatasetEvaluator(BaseEvaluator):
|
||||
def __init__(self,
|
||||
local_mode: bool = True,path=""):
|
||||
self.dataset=SRbenchDataset.load(path="",local_mode=local_mode)
|
||||
self.dataset=SRbenchDataset.load(path,local_mode=local_mode)
|
||||
def parse_formula(self,formula_str, n_var=2):
|
||||
try:
|
||||
if '=' in formula_str:
|
||||
@ -115,22 +122,22 @@ class SRbenchDatasetEvaluator(BaseEvaluator):
|
||||
'R2': pd.Series(dtype=float),
|
||||
'SymbolicMatch': pd.Series(dtype=bool)
|
||||
})
|
||||
# Structural score (via LLM)
|
||||
for row in range(len(references)):
|
||||
#metrics['LLM_Score'] = float(self.llm_evaluate(predictions[row], references[row], mllm='gpt-4o'))
|
||||
print(self.dataset[row]["n_var"])
|
||||
n_var=self.dataset[row]["n_var"]
|
||||
y_true=references[row]
|
||||
data_sample=self.dataset[row]["data_samples_list"]
|
||||
data_sample = np.array(data_sample)
|
||||
x=data_sample[:,:n_var]
|
||||
y_true=data_sample[:,-1]
|
||||
func = self.parse_formula(predictions[row], n_var=n_var)
|
||||
if func is not None:
|
||||
try:
|
||||
x_vars = [x[:, i] for i in range(n_var)]
|
||||
y_pred = func(*x_vars)
|
||||
if np.isscalar(y_pred):
|
||||
y_pred = np.full_like(y_true, y_pred)
|
||||
metrics['RMSE'] = root_mean_squared_error(y_true, y_pred)
|
||||
metrics['R2'] = r2_score(y_true, y_pred)
|
||||
except Exception:
|
||||
pass
|
||||
x_vars = [x[:, i] for i in range(n_var)]
|
||||
y_pred = func(*x_vars)
|
||||
if np.isscalar(y_pred):
|
||||
y_pred = np.full_like(y_true, y_pred)
|
||||
metrics['RMSE'] = root_mean_squared_error(y_true, y_pred)
|
||||
metrics['R2'] = r2_score(y_true, y_pred)
|
||||
else:
|
||||
metrics["R2"]=0
|
||||
metrics["RMSE"]= np.inf
|
||||
@ -156,5 +163,3 @@ class SRbenchDatasetEvaluator(BaseEvaluator):
|
||||
|
||||
return metrics_out
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user