[Feature] add dataset Fofo (#1224)

* add fofo dataset

* add dataset fofo
This commit is contained in:
bittersweet1999 2024-06-06 11:40:48 +08:00 committed by GitHub
parent 02a0a4e857
commit 982e024540
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 390 additions and 1 deletions

View File

@ -0,0 +1,30 @@
# Fofo
## Introduction
This paper presents FoFo, a pioneering benchmark for evaluating large language models' (LLMs) ability to follow complex, domain-specific formats, a crucial yet underexamined capability for their application as AI agents. Despite LLMs' advancements, existing benchmarks fail to assess their format-following proficiency adequately. FoFo fills this gap with a diverse range of real-world formats and instructions, developed through an AI-Human collaborative method. Our evaluation across both open-source (e.g., Llama 2, WizardLM) and closed-source (e.g., GPT-4, PALM2, Gemini) LLMs highlights three key findings: open-source models significantly lag behind closed-source ones in format adherence; LLMs' format-following performance is independent of their content generation quality; and LLMs' format proficiency varies across different domains. These insights suggest the need for specialized tuning for format-following skills and highlight FoFo's role in guiding the selection of domain-specific AI agents.
## Official link
https://github.com/SalesforceAIResearch/FoFo/tree/main
### Paper
https://arxiv.org/abs/2402.18667
## Examples
Input example I:
```
Create a detailed medical diagnostic report in JSON format for a hypothetical patient based on the following clinical scenario and laboratory results. \n\n**Clinical Scenario:**\n- Patient Identifier: 12345X\n- Gender: Female\n- Age: 40 years\n- Presenting Complaint: Acute onset of sharp, right lower quadrant abdominal pain that began approximately 6 hours ago\n- Past Medical History: Hypertension, well-controlled on medication; no known allergies; nonsmoker; nulliparous\n- Recent Labs: Slight leukocytosis, normal hemoglobin, elevated C-reactive protein\n- Imaging: Ultrasound indicates a thickened wall of the appendix with peri-appendiceal fluid collection\n- Surgery: The patient underwent an emergency laparoscopic appendectomy\n- Pathology Report: Confirmed acute appendicitis with peri-appendiceal abscess formation, no malignancy noted\n\nUsing the sample data provided above, generate a JSON object with the following structure:\n\n1. Basic Patient Information:\n - PatientID\n - Gender\n - Age\n\n2. Encounter details:\n - Presentation\n - Symptoms\n - DurationOfSymptoms\n - PastMedicalHistory\n - RecentLaboratoryFindings (nested object with LabName as key and the value as another object containing Result and ReferenceRange)\n - ImagingResults (array of objects with ImagingType, Findings, and Conclusion)\n\n3. Surgical Intervention:\n - SurgeryType\n - Urgency (Elective or Emergency)\n - Outcome (Success or Complication, followed by a brief description)\n\n4. Pathology Report:\n - Diagnosis\n - Findings\n - Implications\n \n5. Recommendations:\n - Follow up actions\n - Medications (array of objects with DrugName, Dosage, Frequency, and Duration)\n - Warnings (e.g., warning signs of complications to watch for post-operation)\n\nYou are allowed to make up any additional relevant data, such as reference ranges for the lab results or specific medication names and dosages for the follow-up recommendations. The JSON must strictly adhere to proper JSON formatting rules, including the use of curly brackets for objects, square brackets for arrays, and double quotes for strings. Keep key names succinct and ensure they accurately reflect their corresponding data. The output should be directly usable as a comprehensive digital diagnostic report for the hypothetical patient described.
```
Output example I (from GPT-4):
```
\n \"BasicPatientInformation\": {\n \"PatientID\": \"12345X\",\n \"Gender\": \"Female\",\n \"Age\": 40\n },\n \"EncounterDetails\": {\n \"Presentation\": \"Acute onset of sharp, right lower quadrant abdominal pain\",\n \"Symptoms\": \"Sharp abdominal pain\",\n \"DurationOfSymptoms\": \"6 hours\",\n \"PastMedicalHistory\": \"Hypertension, well-controlled on medication; no known allergies; nonsmoker; nulliparous\",\n \"RecentLaboratoryFindings\": {\n \"LeukocyteCount\": {\n \"Result\": \"Slightly elevated\",\n \"ReferenceRange\": \"4.5-11.0 x 10^9\/L\"\n },\n \"Hemoglobin\": {\n \"Result\": \"Normal\",\n \"ReferenceRange\": \"12.0-16.0 g\/dL\"\n },\n \"CReactiveProtein\": {\n \"Result\": \"Elevated\",\n \"ReferenceRange\": \"0-5 mg\/L\"\n }\n },\n \"ImagingResults\": [\n {\n \"ImagingType\": \"Ultrasound\",\n \"Findings\": \"Thickened wall of the appendix with peri-appendiceal fluid collection\",\n \"Conclusion\": \"Appendix inflammation indicative of acute appendicitis\"\n }\n ]\n },\n \"SurgicalIntervention\": {\n \"SurgeryType\": \"Laparoscopic Appendectomy\",\n \"Urgency\": \"Emergency\",\n \"Outcome\": \"Success, no post-operative complications noted\"\n },\n \"PathologyReport\": {\n \"Diagnosis\": \"Acute appendicitis with peri-appendiceal abscess formation\",\n \"Findings\": \"Inflammation and abscess formation in the appendix, no malignancy\",\n \"Implications\": \"Appendix successfully removed, no malignancy found\"\n },\n \"Recommendations\": {\n \"FollowUpActions\": \"Patient should rest and recover, and attend a follow-up appointment in 2 weeks\",\n \"Medications\": [\n {\n \"DrugName\": \"Ibuprofen\",\n \"Dosage\": \"400 mg\",\n \"Frequency\": \"Every 6 hours\",\n \"Duration\": \"1 week\"\n },\n {\n \"DrugName\": \"Paracetamol\",\n \"Dosage\": \"500 mg\",\n \"Frequency\": \"Every 4 hours\",\n \"Duration\": \"1 week\"\n }\n ],\n \"Warnings\": \"Contact healthcare provider if pain persists or worsens, if fever develops, or if there are any signs of infection at the surgical site\"\n }\n
```
## Reference
```
@article{xia2024fofo,
title={FOFO: A Benchmark to Evaluate LLMs' Format-Following Capability},
author={Xia, Congying and Xing, Chen and Du, Jiangshu and Yang, Xinyi and Feng, Yihao and Xu, Ran and Yin, Wenpeng and Xiong, Caiming},
journal={arXiv preprint arXiv:2402.18667},
year={2024}
}
```

View File

@ -0,0 +1,96 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import FofoDataset
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question'],
output_column='judge',
)
subjective_all_sets = [
'fofo_test_prompts', 'fofo_test_prompts_cn',
]
base_prompt = """
I would like you to create a leaderboard that evaluates the correctness of the format of answers from various large language models. To accomplish this, you will need to analyze the text prompts given to the models and their corresponding answers. Specifically, please ensure that your evaluation outputs are properly formatted as a json string. I will provide both the prompts and the responses for this purpose.
Here is the prompt:
{
"instruction": "{question}",
}
Here are the outputs of the models:
[
{
"model": "model",
"answer": "{prediction}"
},
]
Please evaluate the formatting of the model's responses by checking if they comply with the format specifications stated in the prompt. Perform a thorough format check and provide a detailed explanation for why the format is correct or incorrect. Your feedback should include the name of the model, followed by the format correctness status represented as '1' for correct and '0' for incorrect. Present your reasoning as bullet points within a single string for each model assessed. In other words, you should produce the following output:
```json
[
{
'model': <model-name>,
'format_correctness': <correctness>,
'reasons': <reasons-of-format-correctness>
}
]
```
Please note that your response should be a properly formatted JSON string and should not contain any additional content. We will load it directly as a JSON string in Python.
"""
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{question}'
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.")
],
round=[
dict(
role='HUMAN',
prompt = base_prompt
),
]),
),
),
pred_role='BOT',
)
subjective_datasets.append(
dict(
abbr=f'{_name}',
type=FofoDataset,
path='./data/subjective/fofo',
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

View File

@ -0,0 +1,69 @@
from mmengine.config import read_base
with read_base():
from .datasets.subjective.fofo.fofo_judge import subjective_datasets
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.models import HuggingFacewithChatTemplate
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import FofoSummarizer
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
# -------------Inference Stage ----------------------------------------
# For subjective evaluation, we often set do sample for models
models = [
dict(
type=HuggingFacewithChatTemplate,
abbr='internlm2-chat-1.8b-hf',
path='internlm/internlm2-chat-1_8b',
max_out_len=1024,
batch_size=8,
run_cfg=dict(num_gpus=1),
stop_words=['</s>', '<|im_end|>'],
generation_kwargs=dict(
do_sample=True,
),
)
]
datasets = [*subjective_datasets]
# -------------Evalation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
key='xxxx', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=16,
max_out_len=2048,
max_seq_len=2048,
batch_size=8,
temperature=0,
)]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner, max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models,
),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=FofoSummarizer, judge_type='general')
work_dir = 'outputs/fofo/'

View File

@ -4,6 +4,7 @@ from .compass_arena import CompassArenaDataset # noqa: F401, F403
from .compassbench import CompassBenchDataset # noqa: F401, F403
from .corev2 import Corev2Dataset # noqa: F401, F403
from .creationbench import CreationBenchDataset # noqa: F401, F403
from .fofo import FofoDataset # noqa: F401, F403
from .information_retrival import IRDataset # noqa: F401, F403
from .mtbench import MTBenchDataset # noqa: F401, F403
from .mtbench101 import MTBench101Dataset # noqa: F401, F403

View File

@ -20,7 +20,7 @@ base_prompt_zh = """请根据 用户问题 以及 相应的两个回答,判断
{prediction2}
[回答2结束]
根据评分要求请先对两个回答进行评价最后在以下 3 个选项中做出选择:
请先对两个回答进行评价最后在以下 3 个选项中做出选择:
A. 回答1更好
B. 回答2更好
C. 回答12平局
@ -87,6 +87,7 @@ class CompassBenchDataset(BaseDataset):
lan = problem['language']
others = problem['others']
judge_prompt = base_prompt_zh if lan == 'zh' else base_prompt_en
judge_prompt = judge_prompt.replace('{question}', question)
raw_data.append({
'question': question,
'judge_prompt': judge_prompt,

View File

@ -0,0 +1,36 @@
# flake8: noqa
import json
import os.path as osp
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class FofoDataset(BaseDataset):
def load(self, path: str, name: str):
filename = osp.join(path, f'{name}.json')
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
json_data = json.load(f)
for problem in json_data:
question = problem['instruction']
lan = 'cn' if 'cn' in name else 'en'
raw_data.append({
'question': question,
'judge': {
'lan': lan,
'id': problem['id'],
'domain': problem['domain'],
'sub_domain': problem['sub_domain'],
'format': problem['format'],
'format_type': problem['format_type'],
'question': question
}
})
dataset = Dataset.from_list(raw_data)
return dataset

View File

@ -215,6 +215,7 @@ class LMEvaluator:
for k, v in pred_dict.items():
dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
dataset.reader.input_columns.append(k)
if references:
dataset.reader.input_columns.append('reference')
dataset.reader.dataset['test'] = dataset.test.add_column(

View File

@ -8,6 +8,7 @@ from .compassbench import CompassBenchSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer
from .flames import FlamesSummarizer
from .fofo import FofoSummarizer
from .information_retrival import IRSummarizer
from .mtbench import MTBenchSummarizer
from .mtbench101 import MTBench101Summarizer

View File

@ -0,0 +1,154 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import numpy as np
from mmengine import ConfigDict
from tabulate import tabulate
try:
from prettytable import from_csv
except ImportError:
from_csv = None
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .compass_arena import CompassArenaSummarizer
from .utils import get_judgeanswer_and_reference, get_outdir
# from .utils.writer import Writer
def post_process_fofo(judgement: str):
"""Input a string like below:
xxx[[5]]xxx, and extract the score
"""
match = re.search(r"[\"']format_correctness[\"']:\s*([0-1]+)", judgement)
if match:
score = int(match.group(1))
else:
return None
return {'score': score, 'judgement': judgement}
class FofoSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='single') -> None:
self.tasks = []
self.cfg = config
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
self.judge_models = self.cfg.get('judge_models', None)
self.judge_function = post_process_fofo
def get_score(self, time_str):
output_dir, results_folder = get_outdir(self.cfg, time_str)
total_scores = {}
for idx, judge_model_cfg in enumerate(self.judge_models):
judge_model = model_abbr_from_cfg(judge_model_cfg)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
scores = defaultdict(list)
for ans, ref in zip(judged_answers, references):
domain = ref['domain']
format_name = ref['format']
format_type = ref['format_type']
score = ans['score']
if score is not None:
scores['overall'].append(score)
scores[domain].append(score)
if format_type == 'general':
scores[format_name].append(score)
single_model_scores = {
task: sum(score) / len(score)
for task, score in scores.items()
}
if judge_model not in total_scores:
total_scores[judge_model] = {}
if dataset_abbr not in total_scores[judge_model]:
total_scores[judge_model][dataset_abbr] = {}
total_scores[judge_model][dataset_abbr][
eval_model_abbr] = single_model_scores
else:
print(subdir_path + ' is not exist! please check!')
return total_scores
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
scores = self.get_score(time_str)
print(scores)
output_dir, results_folder = get_outdir(self.cfg, time_str)
for idx, judge_model in enumerate(self.judge_models):
judge_abbr = model_abbr_from_cfg(judge_model)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
summarizer_model_abbrs = self.eval_model_abbrs
one_column = list(scores[judge_abbr][dataset_abbr].values())[0]
format_types = ['Json', 'CSV', 'XML', 'YAML', 'Markdown']
row_headers = [
i for i in one_column.keys()
if i not in [dataset_abbr] + format_types
]
row_headers = ['overall'] + format_types + row_headers
headers = [dataset_abbr] + summarizer_model_abbrs
table = []
for row_header in row_headers:
row = [row_header]
for model_abbr in summarizer_model_abbrs:
s = scores[judge_abbr][dataset_abbr][model_abbr].get(
row_header, '')
if isinstance(s, float):
s = f'{s:.2f}'
if isinstance(s, int):
s = str(s)
row.append(s)
table.append(row)
txt = tabulate(table, headers=headers)
print(txt)
if idx == len(self.judge_models):
output_filename = osp.join(
output_dir, 'summarized-by--' + judge_abbr + '-' +
dataset_abbr + '-report.csv')
else:
output_filename = osp.join(
output_dir, 'judged-by--' + judge_abbr + '-' +
dataset_abbr + '-report.csv')
with open(output_filename, 'w') as f:
f.write(','.join(headers) + '\n')
for line in table:
f.write(','.join(line) + '\n')
print(output_filename)