[Feature] Add SC (#126)

* add self-consistency

* add CoT method Self-Consistency

* fix typo error and update openicl_eval

* add tydiQA-GoldP task

* fix sc

* rename gsm8k_sc

* fix sc

* add self-consistency doc

* refine sc

---------

Authored-by: liushz <qq1791167085@163.com>
Committed-by: Leymore, 2023-07-28 17:29:37 +08:00 (via GitHub)
parent 538b439302
commit d862f570aa
8 changed files with 487 additions and 19 deletions

View File

@ -0,0 +1,90 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import SCInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import HFDataset, gsm8k_postprocess, gsm8k_dataset_postprocess
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')
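# Sampling-based decoding: Self-Consistency needs diverse generations, so greedy decoding is disabled here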
generation_kwargs = dict(do_sample=True, temperature=0.7, top_k=40)
gsm8k_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=
'''Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?
Let's think step by step
Answer:
Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.
For the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.
Angelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.
However, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.
They also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.
And they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.
So Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.
They want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75
They will need to plan to study 4 days to allow for all the time they need.
The answer is 4
Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
Let's think step by step
Answer:
Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
All together his team scored 50+24+10= 84 points
Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
All together Mark's opponents scored 100+12+5=117 points
The total score for the game is both team's scores added together, so it is 84+117=201 points
The answer is 201
Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
Let's think step by step
Answer:
When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24
The total number of marbles she'll have is 60+24 = 84
If Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.
If Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.
The total number of frisbees she'll have will increase to 30+12 = 42
Bella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards
If she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.
The total number of deck cards she'll have is 10+4 = 14
Together, Bella will have a total of 14+42+84 = 140 items
The answer is 140
Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
Let's think step by step
Answer:
For the first three baskets, the number of apples and oranges in one basket is 9+15=24
In total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.
Since there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.
The number of apples in the fourth basket is 9-2=7
There are also 15-2=13 oranges in the fourth basket
The combined number of oranges and apples in the fourth basket is 13+7=20
The fourth basket also contains 14-2=12 bananas.
In total, the fourth basket has 20+12=32 fruits.
The four baskets together have 32+114=146 fruits.
The answer is 146
Question: {question}{answer}
'''),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=SCInferencer, max_out_len=512, generation_kwargs=generation_kwargs, infer_type='sc', sc_size=20))
gsm8k_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=gsm8k_postprocess),
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
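# Number of sampled predictions the evaluator should majority-vote over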
sc_size=20)
gsm8k_datasets = [
dict(
abbr='gsm8k',
type=HFDataset,
path='gsm8k',
name='main',
reader_cfg=gsm8k_reader_cfg,
infer_cfg=gsm8k_infer_cfg,
eval_cfg=gsm8k_eval_cfg)
]

View File

@ -0,0 +1,72 @@
# Chain of Thought
## Background
CoT (Chain of Thought) prompting is an effective way to help LLMs tackle complex reasoning questions, such as math problems and relational inference. OpenCompass supports multiple types of CoT methods.
## 1. Zero Shot CoT
You can enable a zero-shot CoT prompt for your evaluation by simply appending *Let's think step by step* to the `PromptTemplate` in the dataset config:
```python
qa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="Answer the question:\nQ: {question}?\nLet's think step by step:\n"
),
retriever=dict(type=ZeroRetriever)
)
```
## 2. Few Shot CoT
Few-shot CoT makes it easier for LLMs to follow your instructions and produce better answers. Add your CoT exemplars to the `PromptTemplate` as in the following config to create a one-shot prompt:
```python
qa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=
'''Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
Let's think step by step
Answer:
Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
All together his team scored 50+24+10= 84 points
Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
All together Mark's opponents scored 100+12+5=117 points
The total score for the game is both team's scores added together, so it is 84+117=201 points
The answer is 201
Question: {question}\nLet's think step by step:\n{answer}
'''),
retriever=dict(type=ZeroRetriever)
)
```
## 3. Self-Consistency
The SC (Self-Consistency) method, proposed in [this paper](https://arxiv.org/abs/2203.11171), samples multiple reasoning paths for each question and takes a majority vote over the generated answers. It achieves notably higher accuracy on reasoning tasks, but because multiple reasoning paths must be sampled for every question, inference consumes more time and resources. In OpenCompass, you can enable SC in the dataset config like this:
```python
gsm8k_infer_cfg = dict(
inferencer=dict(
type=SCInferencer,
generation_kwargs=dict(do_sample=True, temperature=0.7, top_k=40),  # Set sampling parameters so the model generates diverse outputs
infer_type='sc',
sc_size=SAMPLE_SIZE
)
)
gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
```
```{note}
Note that OpenCompass defaults to greedy (argmax) decoding of the next token, so if no sampling parameters are specified, the model's output will be identical on every run and multi-round evaluation will be ineffective.
```
Here `SAMPLE_SIZE` is the number of reasoning paths sampled for Self-Consistency; a higher value usually yields better performance. The following figure from the paper shows the relation between the number of reasoning paths and performance on several reasoning tasks:
![image](https://github.com/InternLM/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)
As the figure shows, performance tends to improve as the number of reasoning paths increases, but for some tasks the gains saturate: beyond a certain point, adding more paths brings little further improvement. It is therefore worth experimenting on the specific task to find the number of reasoning paths that best balances accuracy and cost.
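For intuition, here is a minimal sketch of the majority-voting step over the sampled answers (the helper name and sample data are illustrative, not part of the OpenCompass API; internally the evaluator performs an equivalent vote after post-processing each sampled prediction):
```python
from collections import Counter

def majority_vote(sampled_answers):
    """Return the answer that occurs most often among the sampled reasoning paths."""
    return Counter(sampled_answers).most_common(1)[0][0]

# Final answers extracted from, e.g., five sampled reasoning paths for one question
print(majority_vote(['201', '201', '197', '201', '210']))  # -> '201'
```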

View File

@ -0,0 +1,72 @@
# Chain of Thought
## Background
CoT (Chain of Thought) is an effective way to help large language models solve complex problems such as math problems and relational reasoning. OpenCompass supports multiple types of CoT methods.
## 1. Zero Shot CoT
You can enable a zero-shot CoT prompt for evaluation by simply appending *Let's think step by step* to the `PromptTemplate` in the dataset config:
```python
qa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="Answer the question:\nQ: {question}?\nLet's think step by step:\n"
),
retriever=dict(type=ZeroRetriever)
)
```
## 2. Few Shot CoT
Few-shot CoT makes it easier for large language models to follow the given instructions and produce better answers. Add your CoT exemplars to the `PromptTemplate` as in the following config to create a one-shot prompt:
```python
qa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=
'''Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
Let's think step by step
Answer:
Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
All together his team scored 50+24+10= 84 points
Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
All together Mark's opponents scored 100+12+5=117 points
The total score for the game is both team's scores added together, so it is 84+117=201 points
The answer is 201
Question: {question}\nLet's think step by step:\n{answer}
'''),
retriever=dict(type=ZeroRetriever)
)
```
## 3. Self-Consistency
The SC (Self-Consistency) method, proposed in [this paper](https://arxiv.org/abs/2203.11171), generates multiple different reasoning paths for each question and takes a majority vote over the generated answers. It performs remarkably well on complex reasoning tasks, but since multiple reasoning chains must be sampled, it can consume considerably more time and resources. In OpenCompass, you can enable SC in the dataset config, for example:
```python
gsm8k_infer_cfg = dict(
inferencer=dict(
type=SCInferencer,
generation_kwargs=dict(do_sample=True, temperature=0.7, top_k=40),  # Set sampling parameters so the model generates diverse outputs
infer_type='sc',
sc_size=SAMPLE_SIZE
)
)
gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
```
```{note}
Note that OpenCompass defaults to greedy (argmax) decoding of the next token, so if no sampling parameters are specified, the model's output will be identical on every run and multi-round evaluation will be ineffective.
```
Here `SAMPLE_SIZE` is the number of reasoning paths; a higher value usually yields better performance. The paper shows the relation between the number of reasoning paths and performance across different reasoning tasks:
![image](https://github.com/InternLM/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)
As the figure shows, performance across different reasoning tasks tends to improve as the number of reasoning paths grows. For some tasks, however, the gains saturate: adding more paths beyond a certain point brings little further improvement. It is therefore worth experimenting on the specific task to find the number of reasoning paths that suits it best.

View File

@ -106,7 +106,7 @@ class BaseModel:
return self.get_ppl(inputs, mask_length)
def generate_from_template(self, templates: List[PromptType],
max_out_len: int):
max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
Args:
@ -114,7 +114,7 @@ class BaseModel:
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates, mode='gen')
return self.generate(inputs, max_out_len=max_out_len)
return self.generate(inputs, max_out_len=max_out_len, **kwargs)
def get_token_len_from_template(
self,

View File

@ -121,7 +121,8 @@ class HuggingFace(BaseModel):
self.model.config.eos_token_id = 2
self.model.config.pad_token_id = self.tokenizer.pad_token_id
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
def generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
@ -132,14 +133,16 @@ class HuggingFace(BaseModel):
List[str]: A list of generated strings.
"""
if self.batch_padding and len(inputs) > 1:
return self._batch_generate(inputs=inputs, max_out_len=max_out_len)
return self._batch_generate(inputs=inputs,
max_out_len=max_out_len,
**kwargs)
else:
return sum((self._single_generate(inputs=[input_],
max_out_len=max_out_len)
return sum((self._single_generate(
inputs=[input_], max_out_len=max_out_len, **kwargs)
for input_ in inputs), [])
def _batch_generate(self, inputs: List[str],
max_out_len: int) -> List[str]:
def _batch_generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
"""Support for batch prompts inference.
Args:
@ -164,7 +167,9 @@ class HuggingFace(BaseModel):
}
# step-2: conduct model forward to generate output
outputs = self.model.generate(**tokens, max_new_tokens=max_out_len)
outputs = self.model.generate(**tokens,
max_new_tokens=max_out_len,
**kwargs)
if not self.extract_pred_after_decode:
outputs = outputs[:, tokens['input_ids'].shape[1]:]
@ -179,8 +184,8 @@ class HuggingFace(BaseModel):
return decodeds
def _single_generate(self, inputs: List[str],
max_out_len: int) -> List[str]:
def _single_generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
"""Support for single prompt inference.
Args:
@ -198,8 +203,9 @@ class HuggingFace(BaseModel):
max_length=self.max_seq_len -
max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(input_ids=input_ids,
max_new_tokens=max_out_len)
outputs = self.model.generate(input_ids,
max_new_tokens=max_out_len,
**kwargs)
if not self.extract_pred_after_decode:
outputs = outputs[:, input_ids.shape[1]:]

View File

@ -2,3 +2,4 @@ from .icl_base_inferencer import BaseInferencer # noqa
from .icl_clp_inferencer import CLPInferencer # noqa
from .icl_gen_inferencer import GenInferencer # noqa
from .icl_ppl_inferencer import PPLInferencer # noqa
from .icl_sc_inferencer import SCInferencer # noqa

View File

@ -0,0 +1,197 @@
"""Self-Consistency Generation Inferencer."""
import os
import os.path as osp
from typing import List, Optional
import mmengine
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
logger = get_logger(__name__)
class SCInferencer(BaseInferencer):
"""Self-Consistency Inferencer class to evaluate by multiple generations.
Attributes:
model (:obj:`BaseModelWrapper`, optional): The module to inference.
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
allowed by the LM.
batch_size (:obj:`int`, optional): Batch size for the
:obj:`DataLoader`.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
gen_field_replace_token (:obj:`str`, optional): Used to replace the
generation field token when generating prompts.
save_every (:obj:`int`, optional): Save intermediate results every
`save_every` epochs.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
sc_size (:obj:`int`, optional): Sample size for Self-Consistency
infer_type (:obj:`str`, optional): Infer CoT type for
:obj:`inference()` method.
"""
def __init__(
self,
model: BaseModel,
max_out_len: int,
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
gen_field_replace_token: Optional[str] = '',
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
sc_size: Optional[int] = 1,
infer_type: Optional[str] = '',
generation_kwargs: dict = {},
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.gen_field_replace_token = gen_field_replace_token
self.generation_kwargs = generation_kwargs
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list
self.sc_size = sc_size
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = GenInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
ice_idx_list,
retriever,
self.gen_field_replace_token,
max_seq_len=self.max_seq_len,
ice_template=ice_template,
prompt_template=prompt_template)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
# TODO: add more types of CoT method
# 5-1. Inference sc_size times with local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
sc_results = []
for _ in range(self.sc_size):
results = self.model.generate_from_template(
entry,
max_out_len=self.max_out_len,
**self.generation_kwargs)
sc_results.append(results)
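# Transpose sc_results from [sc_size][batch] to [batch][sc_size], so each test sample keeps its own list of sampled answers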
sc_prediction = list(map(list, zip(*sc_results)))
generated = sc_prediction
print(generated)
# 5-3. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
index = index + 1
# 5-4. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [
sample['prediction']
for sample in output_handler.results_dict.values()
]
def get_generation_prompt_list_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
gen_field_replace_token: str,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
return prompt_list

View File

@ -1,6 +1,7 @@
import argparse
import os.path as osp
import time
from collections import Counter
from typing import Optional
import mmengine
@ -77,6 +78,9 @@ class OpenICLEvalTask(BaseTask):
root, ext = osp.splitext(filename)
partial_filename = root + '_0' + ext
# Get sc_size if use Self-Consistency
sc_size = self.eval_cfg.get('sc_size')
if not osp.exists(osp.realpath(filename)) and not osp.exists(
osp.realpath(partial_filename)):
result = {'error': 'No predictions found.'}
@ -105,6 +109,19 @@ class OpenICLEvalTask(BaseTask):
from opencompass.models.base import LMTemplateParser
parser = LMTemplateParser(self.model_cfg['meta_template'])
role = parser.roles[self.eval_cfg['pred_role']]
if sc_size is not None:
sc_pred_strs = []
for pred in pred_strs:
if not isinstance(pred, list):
raise TypeError(
'The prediction for Self-Consistency '
'must be a list.')
sc_pred_strs.append([
self._extract_role_pred(sc_pred,
role.get('begin', None),
role.get('end', None))
for sc_pred in pred
])
pred_strs = sc_pred_strs
else:
pred_strs = [
self._extract_role_pred(pred, role.get('begin', None),
role.get('end', None))
@ -115,6 +132,11 @@ class OpenICLEvalTask(BaseTask):
if 'pred_postprocessor' in self.eval_cfg:
proc = TEXT_POSTPROCESSORS.get(
self.eval_cfg['pred_postprocessor']['type'])
if sc_size is not None:
pred_strs = [
self._get_vote_out(proc, s) for s in pred_strs
]
else:
pred_strs = [proc(s) for s in pred_strs]
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
@ -164,6 +186,14 @@ class OpenICLEvalTask(BaseTask):
return s[start:end]
def _get_vote_out(
self,
proc: Optional[callable],
sc_prediction: Optional[list],
) -> str:
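# Post-process each sampled answer, then return the most frequent one (majority vote)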
counter = Counter([proc(prediction) for prediction in sc_prediction])
return counter.most_common(1)[0][0]
def parse_args():
parser = argparse.ArgumentParser(description='Score Calculator')