[Feature] Update CHARM Memorization (#1230)

* update gemini api and add gemini models

* add openai models

* update CHARM evaluation

* add CHARM memorization tasks

* add CharmMemSummarizer (output eval details for memorization-independent reasoning analysis)

* update CHARM readme

---------

Co-authored-by: wujiang <wujiang@pjlab.org.cn>
jxd 2024-07-26 18:42:30 +08:00 committed by GitHub
parent d3782c1d47
commit 12b84aeb3b
15 changed files with 762 additions and 41 deletions


@@ -86,15 +86,69 @@ Below are the steps for quickly downloading CHARM and using OpenCompass for eval
 ### 1. Download CHARM
 ```bash
 git clone https://github.com/opendatalab/CHARM ${path_to_CHARM_repo}
-cd ${path_to_opencompass}
-mkdir data
-ln -snf ${path_to_CHARM_repo}/data/CHARM ./data/CHARM
 ```
 ### 2. Run Inference and Evaluation
 ```bash
 cd ${path_to_opencompass}
+mkdir -p data
+ln -snf ${path_to_CHARM_repo}/data/CHARM ./data/CHARM
-# Infering and evaluating CHARM with hf_llama3_8b_instruct model
-python run.py --models hf_llama3_8b_instruct --datasets charm_gen
+# modify config file `configs/eval_charm_rea.py`: uncomment or add models you want to evaluate
+python run.py configs/eval_charm_rea.py -r --dump-eval-details
+# modify config file `configs/eval_charm_mem.py`: uncomment or add models you want to evaluate
+python run.py configs/eval_charm_mem.py -r --dump-eval-details
+```
+The inference and evaluation results would be in `${path_to_opencompass}/outputs`, like this:
+```bash
+outputs
+├── CHARM_mem
+│   └── chat
+│       └── 20240605_151442
+│           ├── predictions
+│           │   ├── internlm2-chat-1.8b-turbomind
+│           │   ├── llama-3-8b-instruct-lmdeploy
+│           │   └── qwen1.5-1.8b-chat-hf
+│           ├── results
+│           │   ├── internlm2-chat-1.8b-turbomind_judged-by--GPT-3.5-turbo-0125
+│           │   ├── llama-3-8b-instruct-lmdeploy_judged-by--GPT-3.5-turbo-0125
+│           │   └── qwen1.5-1.8b-chat-hf_judged-by--GPT-3.5-turbo-0125
+│           └── summary
+│               └── 20240605_205020 # MEMORY_SUMMARY_DIR
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Anachronisms_Judgment
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Movie_and_Music_Recommendation
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Sport_Understanding
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Time_Understanding
+│                   └── judged-by--GPT-3.5-turbo-0125.csv # MEMORY_SUMMARY_CSV
+└── CHARM_rea
+    └── chat
+        └── 20240605_152359
+            ├── predictions
+            │   ├── internlm2-chat-1.8b-turbomind
+            │   ├── llama-3-8b-instruct-lmdeploy
+            │   └── qwen1.5-1.8b-chat-hf
+            ├── results # REASON_RESULTS_DIR
+            │   ├── internlm2-chat-1.8b-turbomind
+            │   ├── llama-3-8b-instruct-lmdeploy
+            │   └── qwen1.5-1.8b-chat-hf
+            └── summary
+                ├── summary_20240605_205328.csv # REASON_SUMMARY_CSV
+                └── summary_20240605_205328.txt
+```
+### 3. Generate Analysis Results
+```bash
+cd ${path_to_CHARM_repo}
+# generate Table5, Table6, Table9 and Table10 in https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/summarize_reasoning.py ${REASON_SUMMARY_CSV}
+# generate Figure3 and Figure9 in https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/summarize_mem_rea.py ${REASON_SUMMARY_CSV} ${MEMORY_SUMMARY_CSV}
+# generate Table7, Table12, Table13 and Figure11 in https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/analyze_mem_indep_rea.py data/CHARM ${REASON_RESULTS_DIR} ${MEMORY_SUMMARY_DIR} ${MEMORY_SUMMARY_CSV}
 ```
 ## 🖊️ Citation


@@ -84,15 +84,69 @@
 ### 1. Download CHARM
 ```bash
 git clone https://github.com/opendatalab/CHARM ${path_to_CHARM_repo}
-cd ${path_to_opencompass}
-mkdir data
-ln -snf ${path_to_CHARM_repo}/data/CHARM ./data/CHARM
 ```
 ### 2. Run Inference and Evaluation
 ```bash
 cd ${path_to_opencompass}
+mkdir -p data
+ln -snf ${path_to_CHARM_repo}/data/CHARM ./data/CHARM
-# Infer and evaluate the hf_llama3_8b_instruct model on CHARM
-python run.py --models hf_llama3_8b_instruct --datasets charm_gen
+# modify config file `configs/eval_charm_rea.py`: uncomment or add the models you want to evaluate
+python run.py configs/eval_charm_rea.py -r --dump-eval-details
+# modify config file `configs/eval_charm_mem.py`: uncomment or add the models you want to evaluate
+python run.py configs/eval_charm_mem.py -r --dump-eval-details
+```
+The inference and evaluation results are under `${path_to_opencompass}/outputs`, as shown below:
+```bash
+outputs
+├── CHARM_mem
+│   └── chat
+│       └── 20240605_151442
+│           ├── predictions
+│           │   ├── internlm2-chat-1.8b-turbomind
+│           │   ├── llama-3-8b-instruct-lmdeploy
+│           │   └── qwen1.5-1.8b-chat-hf
+│           ├── results
+│           │   ├── internlm2-chat-1.8b-turbomind_judged-by--GPT-3.5-turbo-0125
+│           │   ├── llama-3-8b-instruct-lmdeploy_judged-by--GPT-3.5-turbo-0125
+│           │   └── qwen1.5-1.8b-chat-hf_judged-by--GPT-3.5-turbo-0125
+│           └── summary
+│               └── 20240605_205020 # MEMORY_SUMMARY_DIR
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Anachronisms_Judgment
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Movie_and_Music_Recommendation
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Sport_Understanding
+│                   ├── judged-by--GPT-3.5-turbo-0125-charm-memory-Chinese_Time_Understanding
+│                   └── judged-by--GPT-3.5-turbo-0125.csv # MEMORY_SUMMARY_CSV
+└── CHARM_rea
+    └── chat
+        └── 20240605_152359
+            ├── predictions
+            │   ├── internlm2-chat-1.8b-turbomind
+            │   ├── llama-3-8b-instruct-lmdeploy
+            │   └── qwen1.5-1.8b-chat-hf
+            ├── results # REASON_RESULTS_DIR
+            │   ├── internlm2-chat-1.8b-turbomind
+            │   ├── llama-3-8b-instruct-lmdeploy
+            │   └── qwen1.5-1.8b-chat-hf
+            └── summary
+                ├── summary_20240605_205328.csv # REASON_SUMMARY_CSV
+                └── summary_20240605_205328.txt
+```
+### 3. Generate Analysis Results
+```bash
+cd ${path_to_CHARM_repo}
+# generate Table5, Table6, Table9 and Table10 of the paper; see https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/summarize_reasoning.py ${REASON_SUMMARY_CSV}
+# generate Figure3 and Figure9 of the paper; see https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/summarize_mem_rea.py ${REASON_SUMMARY_CSV} ${MEMORY_SUMMARY_CSV}
+# generate Table7, Table12, Table13 and Figure11 of the paper; see https://arxiv.org/abs/2403.14112
+PYTHONPATH=. python tools/analyze_mem_indep_rea.py data/CHARM ${REASON_RESULTS_DIR} ${MEMORY_SUMMARY_DIR} ${MEMORY_SUMMARY_CSV}
 ```
 ## 🖊️ Citation


@@ -0,0 +1,63 @@
import os
from mmengine.config import read_base
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import CharmDataset, CharmMemoryEvaluator, LMEvaluator
with read_base():
    from .charm_memory_settings import charm_memory_tasks, judge_system_prompts, dataset_path

charm_memory_datasets = []

for _task in charm_memory_tasks:
    charm_memory_reader_cfg = dict(input_columns=['input'],
                                   output_column='target')

    charm_memory_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(role='HUMAN', prompt='请尽可能简短地回答下述问题。\n问题:{input}\n答:')
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=512),
)
if _task == 'Chinese_Movie_and_Music_Recommendation':
charm_memory_eval_cfg = dict(
evaluator=dict(type=CharmMemoryEvaluator),
pred_role='BOT',
)
else:
judge_system_prompt = judge_system_prompts[_task]
charm_memory_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt=judge_system_prompt +
"\n\n[Question]\n{input}\n[The Start of Reference Answer]\n{target}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{prediction}\n[The End of Assistant's Answer]" # noqa
),
]),
),
),
pred_role='BOT',
)
charm_memory_datasets.append(
dict(
type=CharmDataset,
path=dataset_path,
name=_task,
abbr='charm-memory-' + _task,
reader_cfg=charm_memory_reader_cfg,
infer_cfg=charm_memory_infer_cfg.copy(),
eval_cfg=charm_memory_eval_cfg.copy(),
))

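A quick sanity check of what the loop above produces (editor's sketch, not part of the PR; it assumes it runs in the same namespace as this config, e.g. appended temporarily to the file):

```python
# Only Chinese_Movie_and_Music_Recommendation is scored by the rule-based
# CharmMemoryEvaluator; the other three memorization tasks are routed to an
# LLM judge via LMEvaluator.
expected = {
    'charm-memory-Chinese_Anachronisms_Judgment': 'LMEvaluator',
    'charm-memory-Chinese_Movie_and_Music_Recommendation': 'CharmMemoryEvaluator',
    'charm-memory-Chinese_Sport_Understanding': 'LMEvaluator',
    'charm-memory-Chinese_Time_Understanding': 'LMEvaluator',
}
for d in charm_memory_datasets:
    assert d['eval_cfg']['evaluator']['type'].__name__ == expected[d['abbr']]
```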

@@ -0,0 +1,31 @@
import os
charm_memory_tasks = [
'Chinese_Anachronisms_Judgment',
'Chinese_Movie_and_Music_Recommendation',
'Chinese_Sport_Understanding',
'Chinese_Time_Understanding',
]
dataset_path = 'data/CHARM/memorization'
system_prompt_template = """Please act as an impartial judge, comparing the responses of the AI assistants to the reference answer and determining if the answers are correct.
You will receive the reference answer provided by a human and the responses of the AI assistants.
Your task is to judge whether the AI assistant's answers is correct.
{task_specific_prompt}
After providing your explanation, strictly output your final judgment in the following format: [正确] if the AI assistant's response is correct, “[错误]” if the AI assistant's response is incorrect.
"""
task_specific_prompts = {
'Chinese_Anachronisms_Judgment':
"If the provided reference answer is a list, the model's prediction is considered correct if it matches any item in the list.",
'Chinese_Time_Understanding':
"When evaluating the AI assistant's response regarding Chinese solar terms, as long as the AI assistant's response falls within the time frame provided in the reference answer, consider it correct.",
'Chinese_Sport_Understanding':
"If the provided reference answer is a list, the model's prediction is considered correct if it matches any item in the list."
}
judge_system_prompts = {
k: system_prompt_template.format(task_specific_prompt=v)
for k, v in task_specific_prompts.items()
}

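To see the judge instruction a task ends up with, a minimal sketch (editor's example; it assumes it runs in the same namespace as the settings above, e.g. appended temporarily to this file):

```python
# The {task_specific_prompt} slot is filled per task. Note that
# Chinese_Movie_and_Music_Recommendation has no entry here because it is
# scored by the rule-based CharmMemoryEvaluator rather than an LLM judge.
print(judge_system_prompts['Chinese_Time_Understanding'])
assert 'Chinese_Movie_and_Music_Recommendation' not in judge_system_prompts
```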
configs/eval_charm_mem.py

@@ -0,0 +1,94 @@
from mmengine.config import read_base
from opencompass.models import OpenAI
from opencompass.runners import LocalRunner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import CharmMemSummarizer
with read_base():
    from .datasets.CHARM.charm_memory_gen_bbbd53 import charm_memory_datasets as datasets
# ------>>>>>> https://arxiv.org/abs/2403.14112
# from .models.openai.gpt_3_5_turbo_1106 import models as gpt_3_5_turbo_1106_model
# from .models.openai.gpt_4_1106_preview import models as gpt_4_1106_preview_model
# from .models.hf_llama.hf_llama2_7b_chat import models as llama2_7b_chat_model
# from .models.hf_llama.hf_llama2_13b_chat import models as llama2_13b_chat_model
# from .models.hf_llama.hf_llama2_70b_chat import models as llama2_70b_chat_model
# from .models.vicuna.hf_vicuna_7b_v15_16k import models as vicuna_7b_v15_16k_model
# from .models.vicuna.hf_vicuna_13b_v15_16k import models as vicuna_13b_v15_16k_model
# from .models.chatglm.hf_chatglm3_6b_32k import models as chatglm3_6b_32k_model
# from .models.baichuan.hf_baichuan2_7b_chat import models as baichuan2_7b_chat_model # need torch 2.1
# from .models.baichuan.hf_baichuan2_13b_chat import models as baichuan2_13b_chat_model # need torch 2.1
# from .models.hf_internlm.hf_internlm2_chat_7b import models as hf_internlm2_chat_7b_model
# from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model
# from .models.yi.hf_yi_6b_chat import models as yi_6b_chat_model
# from .models.yi.hf_yi_34b_chat import models as yi_34b_chat_model
# from .models.deepseek.hf_deepseek_7b_chat import models as deepseek_7b_chat_model
# from .models.deepseek.hf_deepseek_67b_chat import models as deepseek_67b_chat_model
# from .models.qwen.hf_qwen_7b_chat import models as qwen_7b_chat_model
# from .models.qwen.hf_qwen_14b_chat import models as qwen_14b_chat_model
# from .models.qwen.hf_qwen_72b_chat import models as qwen_72b_chat_model
# <<<<<<------ https://arxiv.org/abs/2403.14112
# from .models.openai.gpt_3_5_turbo_0125 import models as gpt_3_5_turbo_0125_model
# from .models.openai.gpt_4o_2024_05_13 import models as gpt_4o_2024_05_13_model
# from .models.gemini.gemini_1_5_flash import models as gemini_1_5_flash_model
# from .models.gemini.gemini_1_5_pro import models as gemini_1_5_pro_model
# from .models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model
# from .models.hf_llama.lmdeploy_llama3_70b_instruct import models as lmdeploy_llama3_70b_instruct_model
# from .models.hf_internlm.lmdeploy_internlm2_chat_1_8b import models as lmdeploy_internlm2_chat_1_8b_model
# from .models.hf_internlm.lmdeploy_internlm2_chat_7b import models as lmdeploy_internlm2_chat_7b_model
# from .models.hf_internlm.lmdeploy_internlm2_chat_20b import models as lmdeploy_internlm2_chat_20b_model
# from .models.yi.hf_yi_1_5_6b_chat import models as yi_1_5_6b_chat_model
# from .models.yi.hf_yi_1_5_34b_chat import models as yi_1_5_34b_chat_model
# from .models.deepseek.hf_deepseek_v2_chat import models as deepseek_v2_chat_model
# from .models.qwen.hf_qwen1_5_1_8b_chat import models as qwen1_5_1_8b_chat_model
# from .models.qwen.hf_qwen1_5_7b_chat import models as qwen1_5_7b_chat_model
# from .models.qwen.hf_qwen1_5_14b_chat import models as qwen1_5_14b_chat_model
# from .models.qwen.hf_qwen1_5_72b_chat import models as qwen1_5_72b_chat_model
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
## ------------- JudgeLLM Configuration
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
])
judge_models = [
dict(
abbr='GPT-3.5-turbo-0125',
type=OpenAI,
path='gpt-3.5-turbo-0125',
key='ENV',
meta_template=api_meta_template,
query_per_second=16,
max_out_len=2048,
max_seq_len=2048,
batch_size=8,
temperature=0,
)
]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
max_task_size=1000,
mode='singlescore',
models=models,
judge_models=judge_models,
),
runner=dict(type=LocalRunner,
max_num_workers=2,
task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=CharmMemSummarizer)
work_dir = './outputs/CHARM_mem/chat/'

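The `models = sum(...)` line in this config gathers every uncommented `*_model` import into one flat list, so whichever model configs you uncomment are the ones evaluated. A standalone sketch of the same idiom (the two entries below are hypothetical, for illustration only):

```python
# Two fake model lists standing in for the imported configs.
qwen1_5_1_8b_chat_model = [dict(abbr='qwen1.5-1.8b-chat-hf')]
lmdeploy_internlm2_chat_1_8b_model = [dict(abbr='internlm2-chat-1.8b-turbomind')]

# Same aggregation as in the config: pick up every *_model name in scope.
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
print([m['abbr'] for m in models])
# ['qwen1.5-1.8b-chat-hf', 'internlm2-chat-1.8b-turbomind']
```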

@@ -2,35 +2,55 @@ from mmengine.config import read_base
 with read_base():
     from .datasets.CHARM.charm_reason_gen_f8fca2 import charm_reason_datasets as datasets
-    from .models.hf_internlm.lmdeploy_internlm2_chat_7b import models as lmdeploy_7b_chat_model
-    # from models.openai.gpt_3_5_turbo_1106 import models as gpt_3_5_turbo_1106_model
-    # from models.openai.gpt_4_1106_preview import models as gpt_4_1106_preview_model
-    # from .models.chatglm.hf_chatglm3_6b_32k import models as chatglm3_6b_32k_model
-    # from .models.yi.hf_yi_6b_chat import models as yi_6b_chat_model
-    # from .models.hf_internlm.hf_internlm2_chat_7b import models as hf_internlm2_chat_7b_model
-    # from .models.deepseek.hf_deepseek_7b_chat import models as deepseek_7b_chat_model
-    # from .models.baichuan.hf_baichuan2_7b_chat import models as baichuan2_7b_chat_model # need torch 2.1
+    # ------>>>>>> https://arxiv.org/abs/2403.14112
+    # from .models.openai.gpt_3_5_turbo_1106 import models as gpt_3_5_turbo_1106_model
+    # from .models.openai.gpt_4_1106_preview import models as gpt_4_1106_preview_model
     # from .models.hf_llama.hf_llama2_7b_chat import models as llama2_7b_chat_model
-    # from .models.vicuna.hf_vicuna_7b_v15_16k import models as vicuna_7b_v15_16k_model
-    # from .models.baichuan.hf_baichuan2_13b_chat import models as baichuan2_13b_chat_model # need torch 2.1
     # from .models.hf_llama.hf_llama2_13b_chat import models as llama2_13b_chat_model
-    # from .models.vicuna.hf_vicuna_13b_v15_16k import models as vicuna_13b_v15_16k_model
-    # from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model
-    # from .models.yi.hf_yi_34b_chat import models as yi_34b_chat_model
-    # from .models.deepseek.hf_deepseek_67b_chat import models as deepseek_67b_chat_model
     # from .models.hf_llama.hf_llama2_70b_chat import models as llama2_70b_chat_model
-    # from .models.hf_llama.hf_llama3_8b_instruct import models as llama3_8b_instruct_model
-    # from .models.hf_llama.hf_llama3_70b_instruct import models as llama3_70b_instruct_model
-    from .summarizers.charm_rea import summarizer
+    # from .models.vicuna.hf_vicuna_7b_v15_16k import models as vicuna_7b_v15_16k_model
+    # from .models.vicuna.hf_vicuna_13b_v15_16k import models as vicuna_13b_v15_16k_model
+    # from .models.chatglm.hf_chatglm3_6b_32k import models as chatglm3_6b_32k_model
+    # from .models.baichuan.hf_baichuan2_7b_chat import models as baichuan2_7b_chat_model # need torch 2.1
+    # from .models.baichuan.hf_baichuan2_13b_chat import models as baichuan2_13b_chat_model # need torch 2.1
+    # from .models.hf_internlm.hf_internlm2_chat_7b import models as hf_internlm2_chat_7b_model
+    # from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model
+    # from .models.yi.hf_yi_6b_chat import models as yi_6b_chat_model
+    # from .models.yi.hf_yi_34b_chat import models as yi_34b_chat_model
+    # from .models.deepseek.hf_deepseek_7b_chat import models as deepseek_7b_chat_model
+    # from .models.deepseek.hf_deepseek_67b_chat import models as deepseek_67b_chat_model
+    # from .models.qwen.hf_qwen_7b_chat import models as qwen_7b_chat_model
+    # from .models.qwen.hf_qwen_14b_chat import models as qwen_14b_chat_model
+    # from .models.qwen.hf_qwen_72b_chat import models as qwen_72b_chat_model
+    # <<<<<<------ https://arxiv.org/abs/2403.14112
+    # from .models.openai.gpt_3_5_turbo_0125 import models as gpt_3_5_turbo_0125_model
+    # from .models.openai.gpt_4o_2024_05_13 import models as gpt_4o_2024_05_13_model
+    # from .models.gemini.gemini_1_5_flash import models as gemini_1_5_flash_model
+    # from .models.gemini.gemini_1_5_pro import models as gemini_1_5_pro_model
+    # from .models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model
+    # from .models.hf_llama.lmdeploy_llama3_70b_instruct import models as lmdeploy_llama3_70b_instruct_model
+    # from .models.hf_internlm.lmdeploy_internlm2_chat_1_8b import models as lmdeploy_internlm2_chat_1_8b_model
+    # from .models.hf_internlm.lmdeploy_internlm2_chat_7b import models as lmdeploy_internlm2_chat_7b_model
+    # from .models.hf_internlm.lmdeploy_internlm2_chat_20b import models as lmdeploy_internlm2_chat_20b_model
+    # from .models.yi.hf_yi_1_5_6b_chat import models as yi_1_5_6b_chat_model
+    # from .models.yi.hf_yi_1_5_34b_chat import models as yi_1_5_34b_chat_model
+    # from .models.deepseek.hf_deepseek_v2_chat import models as deepseek_v2_chat_model
+    # from .models.qwen.hf_qwen1_5_1_8b_chat import models as qwen1_5_1_8b_chat_model
+    # from .models.qwen.hf_qwen1_5_7b_chat import models as qwen1_5_7b_chat_model
+    # from .models.qwen.hf_qwen1_5_14b_chat import models as qwen1_5_14b_chat_model
+    # from .models.qwen.hf_qwen1_5_72b_chat import models as qwen1_5_72b_chat_model
+    from .summarizers.charm_reason import summarizer
 models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
-work_dir = './outputs/CHARM/chat/'
+work_dir = './outputs/CHARM_rea/chat/'
 # dataset                                                          version    metric         mode    internlm2-chat-7b-turbomind
 # -------------------------------------------------------------  ---------  -------------  ------  -----------------------------


@@ -0,0 +1,22 @@
from opencompass.models import Gemini
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
], )
models = [
dict(
abbr='gemini-1.5-flash',
type=Gemini,
path='gemini-1.5-flash',
key=
'ENV', # The key will be obtained from $GEMINI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=15,
max_out_len=100,
max_seq_len=2048,
batch_size=1,
temperature=1,
)
]


@@ -0,0 +1,22 @@
from opencompass.models import Gemini
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
], )
models = [
dict(
abbr='gemini-1.5-pro',
type=Gemini,
path='gemini-1.5-pro',
key=
'ENV', # The key will be obtained from $GEMINI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=2,
max_out_len=100,
max_seq_len=2048,
batch_size=1,
temperature=1,
)
]


@@ -12,8 +12,7 @@ models = [
     dict(abbr='gemini',
          type=Gemini,
          path='gemini-pro',
-         key='your keys', # The key will be obtained from Environment, but you can write down your key here as well
-         url = 'your url',
+         key='ENV', # The key will be obtained from $GEMINI_API_KEY, but you can write down your key here as well
          meta_template=api_meta_template,
          query_per_second=16,
          max_out_len=100,


@@ -0,0 +1,20 @@
from opencompass.models import OpenAI
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
], )
models = [
dict(
abbr='GPT-3.5-turbo-0125',
type=OpenAI,
path='gpt-3.5-turbo-0125',
key=
'ENV', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048,
max_seq_len=4096,
batch_size=8),
]


@@ -0,0 +1,20 @@
from opencompass.models import OpenAI
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
], )
models = [
dict(
abbr='GPT-4o-2024-05-13',
type=OpenAI,
path='gpt-4o-2024-05-13',
key=
'ENV', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048,
max_seq_len=4096,
batch_size=8),
]


@@ -1,12 +1,14 @@
 import json
 import os.path as osp
 import re
+from typing import List, Union
 from datasets import Dataset
-from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.openicl.icl_evaluator import BaseEvaluator, LMEvaluator
 from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
                                   TEXT_POSTPROCESSORS)
+from opencompass.utils import build_dataset_from_cfg
 from .base import BaseDataset
@@ -44,6 +46,102 @@ class CharmReasonEvaluator(BaseEvaluator):
        return {'score': score, 'details': details}
UNCERTAIN_LIST = ['不确定', '无法确定', '无法回答', '不知道', '不认识']
def charm_memory_eval(pred: str, ref: Union[str, List[str]]) -> str:
for uncertain in UNCERTAIN_LIST:
if uncertain in pred:
return '[错误]'
is_negative = False
if isinstance(ref, str):
if ref.startswith('[not]'):
            # For some CHARM memorization questions the ref is "[not]xxx",
            # i.e. xxx is a negative example that must NOT appear in pred, e.g.
            # https://github.com/opendatalab/CHARM/blob/v1.0/data/CHARM/memorization/Chinese_Movie_and_Music_Recommendation.json#L45
            is_negative = True
            ref = ref[5:]  # strip the '[not]' prefix and keep xxx
references = [ref]
else:
        references = ref  # for some CHARM memorization questions the ref is a List[str]
assert isinstance(references, list)
for r in references:
if r in pred: # pred中包含ref
if is_negative:
return '[错误]'
else:
return '[正确]'
    if is_negative:  # pred contains none of the negative refs, so pred is correct
return '[正确]'
else:
return '[错误]'
class CharmMemoryEvaluator(LMEvaluator):
"""本Evaluator是基于规则评判CHARM记忆题目的回答是否正确,
只用于Chinese_Movie_and_Music_Recommendation这一个任务的评判
由于CHARM其他的记忆任务需要使用LLM作为judge使用LMEvaluator因而整个eval使用的是SubjectiveEvalTask
因此本Evaluator的输入输出与LMEvaluator一致"""
def __init__(self, prompt_template=None, *nargs, **kwargs):
if prompt_template is None:
prompt_template = dict(
type='PromptTemplate',
template=dict(
                    round=[dict(role='HUMAN', prompt='')]))  # placeholder, unused
super().__init__(prompt_template, *nargs, **kwargs)
def score(self, predictions, references, **kwargs):
assert isinstance(predictions, dict) # single-model scoring
        references = [{} for _ in range(len(predictions['model_preds']))
                      ] if references is None else references
predictions = predictions['model_preds']
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
eval_results = [
charm_memory_eval(pred, ref)
for pred, ref in zip(predictions, references)
]
dataset = None
if self.dataset_cfg:
dataset = build_dataset_from_cfg(self.dataset_cfg)
output = dict()
for i in range(len(predictions)):
if dataset is not None:
question = ''
for col in dataset.reader.input_columns:
question += dataset.reader['test'][col][i] + '\n'
output[str(i)] = {
'origin_prompt': [{
'role':
'HUMAN',
'prompt':
f"[Question]: {question}[Assistant's Answer]: {predictions[i]}" # noqa
}],
'prediction':
eval_results[i],
'gold':
references[i],
}
return output
@LOAD_DATASET.register_module()
class CharmDataset(BaseDataset):

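A minimal sketch of how the new rule-based scorer behaves (editor's example; it assumes `charm_memory_eval` is importable from `opencompass.datasets.charm` as added above):

```python
from opencompass.datasets.charm import charm_memory_eval

# Uncertain answers ('不确定', '无法确定', ...) are always judged wrong.
print(charm_memory_eval('不确定', '秦始皇'))                    # [错误]
# A plain reference is correct if it appears anywhere in the prediction.
print(charm_memory_eval('这部电影是《霸王别姬》', '霸王别姬'))      # [正确]
# A '[not]' prefix marks a negative reference: it must NOT appear in the prediction.
print(charm_memory_eval('我推荐《无间道》', '[not]无间道'))        # [错误]
print(charm_memory_eval('我推荐《英雄》', '[not]无间道'))          # [正确]
# A list reference is correct if any item matches.
print(charm_memory_eval('他是一名篮球运动员', ['篮球', '足球']))    # [正确]
```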

@@ -1,5 +1,6 @@
 # flake8: noqa: E501
 import json
+import os
 import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
@@ -48,7 +49,18 @@ class Gemini(BaseAPIModel):
                          query_per_second=query_per_second,
                          meta_template=meta_template,
                          retry=retry)
-        self.url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key={key}'
+        assert isinstance(key, str)
+        if key == 'ENV':
+            if 'GEMINI_API_KEY' not in os.environ:
+                raise ValueError('GEMINI API key is not set.')
+            key = os.getenv('GEMINI_API_KEY')
+        assert path in [
+            'gemini-1.0-pro', 'gemini-pro', 'gemini-1.5-flash',
+            'gemini-1.5-pro'
+        ] # https://ai.google.dev/gemini-api/docs/models/gemini#model-variations
+        self.url = f'https://generativelanguage.googleapis.com/v1beta/models/{path}:generateContent?key={key}'
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -171,17 +183,20 @@
                                   str(raw_response.content))
                 time.sleep(1)
                 continue
-            if raw_response.status_code == 200 and response['msg'] == 'ok':
-                body = response['body']
-                if 'candidates' not in body:
+            if raw_response.status_code == 200:
+                if 'candidates' not in response:
                     self.logger.error(response)
                 else:
-                    if 'content' not in body['candidates'][0]:
+                    if 'content' not in response['candidates'][0]:
                         return "Due to Google's restrictive policies, I am unable to respond to this question."
                     else:
-                        return body['candidates'][0]['content']['parts'][0][
-                            'text'].strip()
-            self.logger.error(response['msg'])
+                        return response['candidates'][0]['content']['parts'][
+                            0]['text'].strip()
+            try:
+                msg = response['error']['message']
+                self.logger.error(msg)
+            except KeyError:
+                pass
             self.logger.error(response)
             time.sleep(1)

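The key-handling change above boils down to a small resolution step; a standalone sketch of the same pattern (editor's illustration, mirroring the lines added in this hunk):

```python
import os


def resolve_gemini_key(key: str) -> str:
    """Return a usable API key: 'ENV' is a sentinel for $GEMINI_API_KEY."""
    if key == 'ENV':
        if 'GEMINI_API_KEY' not in os.environ:
            raise ValueError('GEMINI API key is not set.')
        return os.environ['GEMINI_API_KEY']
    return key


# With key='ENV' in a model config, the request URL is then built per model path:
# f'https://generativelanguage.googleapis.com/v1beta/models/{path}:generateContent?key={key}'
```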

@@ -3,6 +3,7 @@ from .alignmentbench import AlignmentBenchSummarizer
 from .all_obj import AllObjSummarizer
 from .alpacaeval import AlpacaSummarizer
 from .arenahard import ArenaHardSummarizer
+from .charm import CharmMemSummarizer
 from .compass_arena import CompassArenaSummarizer
 from .compassbench import CompassBenchSummarizer
 from .corev2 import Corev2Summarizer


@@ -0,0 +1,208 @@
# flake8: noqa: E501
import csv
import json
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import mmengine
import numpy as np
import pandas as pd
from mmengine import ConfigDict
from prettytable import from_csv
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
model_abbr_from_cfg)
from .utils import get_outdir
def post_process_charm_mem(judgement: str):
"""Input a string like below:
xxx[correct]xxx, and extract the judge
"""
pattern = r'(?i)\[(incorrect|correct|正确|错误|Yes|No)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
content = matched_result[0].lower()
if content in ['correct', '正确', 'yes']:
return {'correct': True}
elif content in ['incorrect', '错误', 'no']:
return {'correct': False}
else:
return None
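# Illustrative behaviour (editor's note, not part of the original file): the
# regex accepts bracketed verdicts in either language, case-insensitively.
#   post_process_charm_mem('判断:[正确]')           -> {'correct': True}
#   post_process_charm_mem('Verdict: [Incorrect].')  -> {'correct': False}
#   post_process_charm_mem('no bracketed verdict')   -> None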
def get_judgeanswer_and_reference_charm_mem(dataset, subdir_path,
post_process):
"""Extract judgements (scores), references and original judging prompts.
Args:
dataset (ConfigDict): Dataset config.
subdir_path (str): Model path in results dir.
post_process (function): The pre-defined extract function.
"""
dataset_abbr = dataset_abbr_from_cfg(dataset)
filename = osp.join(subdir_path, dataset_abbr + '.json')
partial_filename = osp.join(subdir_path, dataset_abbr + '_0.json')
if osp.exists(osp.realpath(filename)):
result = mmengine.load(filename)
elif osp.exists(osp.realpath(partial_filename)):
filename = partial_filename
result = {}
i = 1
partial_dict_flag = 0
while osp.exists(osp.realpath(filename)):
res = mmengine.load(filename)
for k, v in res.items():
result[partial_dict_flag] = v
partial_dict_flag += 1
filename = osp.join(subdir_path,
dataset_abbr + '_' + str(i) + '.json')
i += 1
else:
result = {}
if len(result) == 0:
print('*' * 100)
print('There are no results for ' + filename + ' or ' +
partial_filename)
print('*' * 100)
assert len(result) > 0
judging_prompts = []
judged_answers = []
references = []
for k, v in result.items():
processed_judge = post_process(v['prediction'])
if processed_judge is not None:
judged_answers.append(processed_judge)
references.append(v['gold'])
judging_origin_prompts = v['origin_prompt']
if len(judging_origin_prompts) > 0:
judging_prompts.append(judging_origin_prompts[0].get(
'prompt', None))
if len(judged_answers) != len(result):
print(
f'Among {len(result)} judgements, successfully extracted {len(judged_answers)} judgements, please check!'
)
if len(judged_answers) == 0:
print('*' * 100)
print(
'There are no extracted judgements, please change your judge model or check your prompt!!!'
)
print('*' * 100)
assert len(judged_answers) > 0
return judged_answers, references, judging_prompts
def get_accuracy(judged_answers):
n_total = 0
n_correct = 0
for ans in judged_answers:
if ans.get('correct', False):
n_correct += 1
n_total += 1
return round(n_correct / n_total * 100, 2)
class CharmMemSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='single') -> None:
self.judge_type = judge_type
self.tasks = []
self.cfg = config
if self.judge_type == 'single':
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
else:
raise NotImplementedError
self.judge_abbr = model_abbr_from_cfg(
self.cfg['eval']['partitioner']['judge_models'][0])
self.judge_map = {'single': post_process_charm_mem}
self.judge_function = self.judge_map[self.judge_type]
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
if self.judge_type == 'single':
dataset_cfgs = self.cfg['datasets']
judge_model = self.judge_abbr
output_dir, results_folder = get_outdir(self.cfg, time_str)
accuracy_df = pd.DataFrame(columns=self.eval_model_abbrs)
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
dataset_instance = build_dataset_from_cfg(dataset)
out_dir = osp.join(
output_dir,
'judged-by--' + judge_model + '-' + dataset_abbr)
os.makedirs(out_dir, exist_ok=True)
cur_acc_dict = {'dataset': dataset_abbr}
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model = eval_model_abbr
(judged_answers, references, judging_prompts
) = get_judgeanswer_and_reference_charm_mem(
dataset,
subdir_path,
self.judge_function,
)
accuracy = get_accuracy(judged_answers)
cur_acc_dict[eval_model_abbr] = accuracy
detail_dict = {}
for i in range(len(judged_answers)):
cur_dict = {}
cur_dict['judging_prompt'] = judging_prompts[i]
for input_col in dataset_instance.reader.input_columns:
cur_dict[input_col] = dataset_instance.reader[
'test'][input_col][i]
cur_dict['reference'] = references[i]
cur_dict.update(judged_answers[i])
detail_dict[str(i)] = cur_dict
out_dict = {'score': accuracy, 'details': detail_dict}
fout = osp.join(out_dir, model + '.json')
with open(fout, 'w', encoding='utf-8') as f:
json.dump(out_dict,
f,
indent=4,
ensure_ascii=False)
else:
                        print(subdir_path + ' does not exist, please check!')
accuracy_df = accuracy_df.append(cur_acc_dict,
ignore_index=True)
accuracy_df.set_index('dataset', inplace=True)
accuracy_file = osp.join(output_dir,
'judged-by--' + judge_model + '.csv')
accuracy_df.to_csv(accuracy_file, index=True)
with open(accuracy_file, 'r') as f:
x = from_csv(f)
print(x)
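Once the summarizer has run, the per-model accuracies land in the judged-by CSV shown in the README tree above; a short sketch for loading it afterwards (the path follows the example tree and is illustrative):

```python
import pandas as pd

# MEMORY_SUMMARY_CSV from the README tree: one row per memory dataset,
# one accuracy column per evaluated model, indexed by the 'dataset' column
# written by CharmMemSummarizer.
csv_path = ('outputs/CHARM_mem/chat/20240605_151442/summary/20240605_205020/'
            'judged-by--GPT-3.5-turbo-0125.csv')
df = pd.read_csv(csv_path, index_col='dataset')
print(df.loc['charm-memory-Chinese_Sport_Understanding'])
```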