Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00.

* add calm dataset * modify config max_out_len * update README * Modify README * update README * update README * update README * update README * update README * add summarizer and modify readme * delete summarizer config comment * update summarizer * modify same response to all questions * update README
43 lines
1.8 KiB
Python
43 lines
1.8 KiB
Python
import datasets
|
|
from ..base import BaseDataset
|
|
from datasets import Dataset
|
|
from .data_processing.generate_questions import generate_question_list
|
|
from .evaluation.core_metrics import compute_core_metrics, task_to_accuracy_module_map
|
|
from .evaluation.errors import identify_model_errors
|
|
from opencompass.registry import LOAD_DATASET
|
|
from typing import List
|
|
|
|
@LOAD_DATASET.register_module()
class CaLMDataset(BaseDataset):
    """Dataset loader for the CaLM causal-reasoning benchmark."""

    @staticmethod
    def load(path: str, prompt_style: str) -> datasets.Dataset:
        """Build a HuggingFace ``Dataset`` of CaLM questions.

        Args:
            path: Location of the raw CaLM dataset files.
            prompt_style: Prompt formatting style used when generating
                the question list.

        Returns:
            A ``datasets.Dataset`` wrapping the generated questions.
        """
        # Generate the prompt-formatted question dicts, then wrap them
        # directly in a HuggingFace Dataset.
        questions = generate_question_list(dataset_path=path,
                                           prompt_style=prompt_style)
        return Dataset.from_list(questions)
|
|
|
|
|
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
|
|
|
class CaLMEvaluator(BaseEvaluator):
    """Evaluator for the CaLM benchmark.

    Computes core metrics and, optionally, model-error analysis for a
    given CaLM task and prompt style.
    """

    def __init__(self, core_metrics, error_analysis, prompt_style, task) -> None:
        """
        Args:
            core_metrics: Truthy to compute core metrics in :meth:`score`.
            error_analysis: Truthy to run error analysis in :meth:`score`.
            prompt_style: Prompt style the predictions were generated with.
            task: CaLM task name (e.g. one starting with ``"CEG-O_E-CARE"``).
        """
        super().__init__()
        self.core_metrics = core_metrics
        self.error_analysis = error_analysis
        self.prompt_style = prompt_style
        self.task = task

    def score(self, predictions: List, references: List) -> dict:
        """Score *predictions* against ground-truth *references*.

        Args:
            predictions: Model outputs, one per question.
            references: Ground-truth items, aligned with ``predictions``.

        Returns:
            A dict merging the core metrics and, if enabled, the
            error-analysis results.
        """
        results: dict = {}
        if self.core_metrics:
            # compute_core_metrics also returns the parsed prediction list,
            # which is not used here — discard it instead of binding an
            # unused name.
            metrics, _ = compute_core_metrics(predictions,
                                              task=self.task,
                                              prompt_style=self.prompt_style,
                                              gt_items=references)
            results.update(metrics)
        if self.error_analysis:
            # Error analysis is not defined for the CEG-O_E-CARE task family.
            if self.task.startswith("CEG-O_E-CARE"):
                print("There's no error analysis for CEG-O_E-CARE task."
                      " Skipping error analysis.")
                return results
            errors = identify_model_errors(predictions,
                                           task=self.task,
                                           prompt_style=self.prompt_style,
                                           gt_items=references)
            results.update(errors)
        return results