[Feature] Add S3Eval Dataset (#916)

* s3eval_branch

* update s3eval
Fangyu Lei 2024-05-06 11:41:52 +00:00 committed by GitHub
parent d501710155
commit 862044fb7d
4 changed files with 329 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
# S3Eval
## Introduction
This benchmark comes from the paper [S3Eval: A Synthetic, Scalable, Systematic Evaluation Suite for Large Language Models](https://arxiv.org/abs/2310.15147).
S3Eval addresses the critical need for comprehensive evaluation resources for Large Language Models (LLMs): a benchmark suite that is both synthetic and scalable, designed to probe long-context comprehension and reasoning capabilities.
The benchmark is built on SQL execution tasks: S3Eval presents LLMs with randomly generated tables and SQL queries and checks whether they produce the correct execution results. Because tables and queries can be generated without limit, the benchmark provides effectively unlimited evaluation data for a robust assessment of LLM capabilities.
This submission includes a batch of high-quality data covering nearly all query types with strong diversity. Table lengths span from 200 to 200K tokens, enabling a systematic evaluation of the models' long-context capabilities.
For researchers and practitioners alike, S3Eval promises deeper insight into LLM performance. See the paper for details on its design, experiments, and implications. We invite you to use S3Eval in your research and to contribute to the evolving landscape of synthetic benchmark construction. 😊
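At its core, each example is just a table plus a SQL query whose execution result is the gold answer. The toy sketch below illustrates that pipeline only; it is not the official generator (which lives in the S3Eval repository), and the column names are borrowed from the example table later in this README:
```
# Illustrative sketch: build a small random table, execute a SQL query on it,
# and keep the execution result as the gold answer the LLM must reproduce.
import random
import sqlite3
import string

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE my_table (suiting TEXT, highboy INTEGER)')
rows = [(''.join(random.choices(string.ascii_lowercase, k=7)),
         random.randint(0, 320)) for _ in range(20)]
conn.executemany('INSERT INTO my_table VALUES (?, ?)', rows)

sql = 'SELECT suiting FROM my_table WHERE highboy > 200'
gold = conn.execute(sql).fetchall()  # expected execution result
print(sql, gold)
```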
## Official link
### Paper
[S3Eval: A Synthetic, Scalable, Systematic Evaluation Suite for Large Language Models](https://arxiv.org/abs/2310.15147)
### Repository
[s3eval](https://github.com/lfy79001/S3Eval)
## Examples
Input example I:
```
You are an SQL executor, you need to execute SQL based on the give table and SQL statement to obtain the execution results.
| suiting | chisel | highboy | broccoli | newburgh | acetum | brewpub |
|:----------|:----------|----------:|-----------:|:-----------|:----------|----------:|
| zbwamhiui | nnkfvevxw | 50 | 88 | zhwohj | opufj | 214 |
| zroosgm | yvftt | 309 | 168 | zhwohj | xqsu | 136 |
| zroosgm | lnri | 152 | 78 | zhwohj | ikvsd | 219 |
| kjsdl | trei | 234 | 287 | egkgkvbec | mhxcxyg | 23 |
| zroosgm | mctnpwbd | 71 | 242 | egkgkvbec | yszfokeom | 180 |
| zbwamhiui | ptqtj | 19 | 81 | egkgkvbec | hyfmk | 116 |
| zroosgm | lpjvwn | 258 | 313 | uftnwbd | oevmj | 65 |
| kjsdl | ididumrhw | 64 | 101 | uftnwbd | xjakwpayx | 327 |
| zbwamhiui | wdtncbyn | 165 | 209 | uftnwbd | xrbqvxb | 192 |
| zbwamhiui | wyjjc | 219 | 6 | uftnwbd | pzqr | 188 |
| zroosgm | qumxgwvls | 314 | 246 | uftnwbd | ehevtf | 60 |
| zbwamhiui | adiyf | 207 | 298 | egkgkvbec | wbrgejgf | 80 |
| zbwamhiui | qpgpbj | 307 | 306 | egkgkvbec | mcjuonhc | 192 |
| zbwamhiui | ehsk | 47 | 244 | zhwohj | tcdlnc | 280 |
| kjsdl | orlosbok | 21 | 93 | egkgkvbec | dzvwohjo | 103 |
| zbwamhiui | webyyylw | 84 | 195 | egkgkvbec | xbmv | 289 |
| kjsdl | mrcecp | 48 | 264 | egkgkvbec | xhprcocik | 265 |
| kjsdl | ngajupd | 247 | 52 | zhwohj | pcokyw | 247 |
| zroosgm | xeeuixkze | 120 | 288 | zhwohj | yishnriw | 138 |
| kjsdl | kbczy | 119 | 13 | egkgkvbec | ltpmyfdt | 73 |
| zbwamhiui | uvvdzo | 150 | 57 | uftnwbd | tajlsm | 295 |
| zbwamhiui | enbffevhp | 290 | 92 | zhwohj | gjjznp | 18 |
| zroosgm | imubtcc | 79 | 19 | uftnwbd | eqymwj | 112 |
SQL:select suiting from my_table group by suiting having count ( newburgh ) > 6
Answer:
| suiting |
|:----------|
| zbwamhiui |
| zroosgm |
SQL:select acetum,newburgh,suiting from my_table where highboy > 234
Answer:
| acetum | newburgh | suiting |
|:---------|:-----------|:----------|
| xqsu | zhwohj | zroosgm |
| oevmj | uftnwbd | zroosgm |
| ehevtf | uftnwbd | zroosgm |
| mcjuonhc | egkgkvbec | zbwamhiui |
| pcokyw | zhwohj | kjsdl |
| gjjznp | zhwohj | zbwamhiui |
SQL:select count ( chisel ) from my_table where highboy < brewpub group by newburgh having min ( highboy ) < 47
Answer:
| count ( chisel ) |
|-------------------:|
| 5 |
SQL:select newburgh from my_table where brewpub > 138 order by broccoli desc limit 1
Answer:
| newburgh |
|:-----------|
| egkgkvbec |
SQL:select suiting from my_table where highboy > broccoli group by suiting having min ( highboy ) < 314
Answer:
```
Output example I (from GPT-4):
```
| suiting |
|:----------|
| kjsdl |
| zbwamhiui |
| zroosgm |
```
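The released data is hosted on the HuggingFace Hub under `FangyuLei/s3eval` (the same path used by the dataset config in this PR), so it can also be browsed directly. A minimal sketch, assuming the `test` split with `input`/`output` fields that the dataset loader below relies on:
```
from datasets import load_dataset

# Browse the released S3Eval data; the OpenCompass loader in this PR
# reads the same split and fields.
data = load_dataset('FangyuLei/s3eval', split='test')
example = data[0]
print(example['input'][:500])   # prompt: table plus few-shot SQL/answer pairs
print(example['output'])        # gold execution result
print(len(example['input']))    # rough proxy for context length
```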
## Evaluation results
| Model          | Score |
|----------------|-------|
| GPT-4          | 61.3  |
| GPT-3.5-Turbo  | 40.2  |
| Code Llama 34B | 28.3  |
| Code Llama 13B | 21.5  |
| Code Llama 7B  | 12.7  |
| StarCoder1 15B | 12.5  |
| StarCoder1 7B  | 10.2  |
| StarCoder1 3B  | 7.8   |
| StarCoder1 1B  | 5.4   |
| Llama 13B      | 13.1  |
| Llama 7B       | 6.5   |
| DeepSeek 7B    | 12.6  |
| OLMo 7B        | 8.2   |
| Qwen 14B       | 12.3  |
| Qwen 7B        | 11.6  |
| Mistral 7B     | 12.4  |
| InternLM 20B   | 14.6  |
## Reference
```
@article{lei2023s3eval,
  title={S3Eval: A Synthetic, Scalable, Systematic Evaluation Suite for Large Language Models},
  author={Lei, Fangyu and Liu, Qian and Huang, Yiming and He, Shizhu and Zhao, Jun and Liu, Kang},
  journal={arXiv preprint arXiv:2310.15147},
  year={2023}
}
```

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .s3eval_gen_370cc2 import s3eval_datasets  # noqa: F401, F403

View File

@@ -0,0 +1,17 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import S3EvalDataset, S3EvalEvaluator

s3eval_cfg = dict(evaluator=dict(type=S3EvalEvaluator))

s3eval_datasets = [
    dict(
        type=S3EvalDataset,
        abbr="s3eval",
        path='FangyuLei/s3eval',
        eval_cfg=s3eval_cfg)
]
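To actually evaluate a model on this dataset, a top-level OpenCompass config would typically pull `s3eval_datasets` in through `read_base` alongside a model definition. The sketch below is illustrative only; the relative import paths and the `hf_llama2_7b` model config are assumptions about the surrounding repository layout, not part of this PR:
```
# Hypothetical summary config; adjust the relative import paths to where
# these files live in your OpenCompass checkout.
from mmengine.config import read_base

with read_base():
    from .datasets.s3eval.s3eval_gen import s3eval_datasets
    from .models.hf_llama.hf_llama2_7b import models  # any model config works

datasets = s3eval_datasets
```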

View File

@@ -0,0 +1,169 @@
import re
import string
from collections import Counter

from datasets import Dataset, load_dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class S3EvalDataset(BaseDataset):

    @staticmethod
    def load(path: str):
        train_data = []
        s3eval_dataset = load_dataset(path)
        for example in s3eval_dataset['test']:
            train_data.append({
                'input': example['input'],
                'output': example['output']
            })
        dataset = Dataset.from_list(train_data)
        return dataset


@ICL_EVALUATORS.register_module()
class S3EvalEvaluator(BaseEvaluator):

    def score(self, predictions, references):

        def is_numeric(value):
            try:
                float(value)
                return True
            except ValueError:
                return False

        def normalize_answer(s):
            """Lower text and remove punctuation, articles and extra
            whitespace."""

            def remove_articles(text):
                return re.sub(r'\b(a|an|the)\b', ' ', text)

            def white_space_fix(text):
                return ' '.join(text.split())

            def remove_punc(text):
                exclude = set(string.punctuation)
                return ''.join(ch for ch in text if ch not in exclude)

            def lower(text):
                return text.lower()

            return white_space_fix(remove_articles(remove_punc(lower(s))))

        def markdown_to_list(data):
            # Skip the header and separator rows, then parse each remaining
            # row of the markdown table into a tuple of cell values.
            lines = data.split('\n')[2:]
            result = []
            for line in lines:
                if line.strip():
                    content = line.split('|')[1:-1]
                    content = [item.strip() for item in content]
                    result.append(tuple(content))
            return result

        def calculate_multi_em_score(pred, gold):
            # Multiset F1 over table rows: duplicate rows are counted.
            true_positives = 0
            false_positives = 0
            false_negatives = 0
            pred_counts = {}
            gold_counts = {}
            for answer in pred:
                pred_counts[answer] = pred_counts.get(answer, 0) + 1
            for answer in gold:
                gold_counts[answer] = gold_counts.get(answer, 0) + 1
            for answer in pred_counts:
                true_positives += min(pred_counts[answer],
                                      gold_counts.get(answer, 0))
                false_positives += max(
                    0, pred_counts[answer] - gold_counts.get(answer, 0))
            for answer in gold_counts:
                false_negatives += max(
                    0, gold_counts[answer] - pred_counts.get(answer, 0))
            if true_positives == 0 or (true_positives + false_positives
                                       ) == 0 or (true_positives +
                                                  false_negatives) == 0:
                return 0
            precision = true_positives / (true_positives + false_positives)
            recall = true_positives / (true_positives + false_negatives)
            f1_score = 2 * (precision * recall) / (precision + recall)
            return f1_score

        def comma_f1_score(prediction, ground_truth, **kwargs):
            # Set-level F1 over comma-separated answer items.
            prediction_tokens = prediction.split(',')
            pred = [item.strip() for item in prediction_tokens]
            ground_truth_tokens = ground_truth.split(',')
            gold = [item.strip() for item in ground_truth_tokens]
            true_positives = len(set(pred) & set(gold))
            false_positives = len(set(pred) - set(gold))
            false_negatives = len(set(gold) - set(pred))
            if true_positives == 0 or (true_positives + false_positives
                                       ) == 0 or (true_positives +
                                                  false_negatives) == 0:
                return 0
            precision = true_positives / (true_positives + false_positives)
            recall = true_positives / (true_positives + false_negatives)
            f1_score = 2 * (precision * recall) / (precision + recall)
            return f1_score

        def f1_score(prediction, ground_truth, **kwargs):
            # Token-level F1 between prediction and ground-truth token lists.
            common = Counter(prediction) & Counter(ground_truth)
            num_same = sum(common.values())
            if num_same == 0:
                return 0
            precision = 1.0 * num_same / len(prediction)
            recall = 1.0 * num_same / len(ground_truth)
            f1 = (2 * precision * recall) / (precision + recall)
            return f1

        def qa_f1_score(prediction, ground_truth, **kwargs):
            # Numeric answers are compared exactly; otherwise fall back to
            # normalized token-level F1.
            if is_numeric(prediction) and is_numeric(ground_truth):
                if float(prediction) == float(ground_truth):
                    return 1
                else:
                    return 0
            normalized_prediction = normalize_answer(prediction)
            normalized_ground_truth = normalize_answer(ground_truth)
            prediction_tokens = normalized_prediction.split()
            ground_truth_tokens = normalized_ground_truth.split()
            return f1_score(prediction_tokens, ground_truth_tokens)

        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        scores = []
        for pred_str, gold_str in zip(predictions, references):
            if '|' in gold_str:
                # Multi-row answer rendered as a markdown table.
                pred = markdown_to_list(pred_str)
                gold = markdown_to_list(gold_str)
                score = calculate_multi_em_score(pred, gold)
            else:
                if ',' in gold_str:
                    # Single-row, multi-column answer.
                    score = comma_f1_score(pred_str, gold_str)
                else:
                    # Single-cell answer.
                    score = qa_f1_score(pred_str, gold_str)
            scores.append(score)
        score = sum(scores) / len(scores) * 100
        return {'score': score}
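As a quick sanity check of the three scoring branches (markdown-table gold, comma-separated gold, single-value gold), the evaluator can be exercised directly with toy strings. This snippet is illustrative and not part of the PR; it assumes the evaluator is constructed without arguments, as in the dataset config above:
```
from opencompass.datasets import S3EvalEvaluator

evaluator = S3EvalEvaluator()
predictions = [
    '| suiting |\n|:---|\n| zroosgm |\n| kjsdl |',  # markdown-table answer
    'zhwohj, uftnwbd, egkgkvbec',                   # comma-separated answer
    '42',                                           # single numeric answer
]
references = [
    '| suiting |\n|:---|\n| zroosgm |\n| kjsdl |',
    'zhwohj, uftnwbd',
    '42.0',
]
# Prints a dict like {'score': ...}, the per-example scores averaged and
# scaled to 0-100.
print(evaluator.score(predictions, references))
```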