diff --git a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py
index cbbea494..46e13fa0 100644
--- a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py
+++ b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py
@@ -2,7 +2,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
 from opencompass.openicl.icl_evaluator import AccEvaluator
-from opencompass.datasets import HFDataset
+from opencompass.datasets import cmnliDataset
 
 cmnli_reader_cfg = dict(
     input_columns=['sentence1', 'sentence2'],
@@ -25,11 +25,9 @@ cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
 
 cmnli_datasets = [
     dict(
-        type=HFDataset,
-        abbr='cmnli',
-        path='json',
-        split='train',
-        data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
+        abbr="cmnli",
+        type=cmnliDataset,
+        path='./data/CLUE/cmnli/cmnli_public/dev.json',
         reader_cfg=cmnli_reader_cfg,
         infer_cfg=cmnli_infer_cfg,
         eval_cfg=cmnli_eval_cfg)
diff --git a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py
index b91c9bf3..bc5b765d 100644
--- a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py
+++ b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py
@@ -2,7 +2,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
 from opencompass.openicl.icl_evaluator import AccEvaluator
-from opencompass.datasets import HFDataset
+from opencompass.datasets import cmnliDataset
 
 cmnli_reader_cfg = dict(
     input_columns=['sentence1', 'sentence2'],
@@ -41,11 +41,9 @@ cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
 
 cmnli_datasets = [
     dict(
-        type=HFDataset,
-        abbr='cmnli',
-        path='json',
-        split='train',
-        data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
+        abbr="cmnli",
+        type=cmnliDataset,
+        path='./data/CLUE/cmnli/cmnli_public/dev.json',
         reader_cfg=cmnli_reader_cfg,
         infer_cfg=cmnli_infer_cfg,
         eval_cfg=cmnli_eval_cfg)
diff --git a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py
index eb051898..a3770db6 100644
--- a/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py
+++ b/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py
@@ -2,7 +2,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
 from opencompass.openicl.icl_evaluator import AccEvaluator
-from opencompass.datasets import HFDataset
+from opencompass.datasets import cmnliDataset
 
 cmnli_reader_cfg = dict(
     input_columns=['sentence1', 'sentence2'],
@@ -45,11 +45,9 @@ cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
 
 cmnli_datasets = [
     dict(
-        type=HFDataset,
-        abbr='cmnli',
-        path='json',
-        split='train',
-        data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
+        abbr="cmnli",
+        type=cmnliDataset,
+        path='./data/CLUE/cmnli/cmnli_public/dev.json',
         reader_cfg=cmnli_reader_cfg,
         infer_cfg=cmnli_infer_cfg,
         eval_cfg=cmnli_eval_cfg)
diff --git a/configs/datasets/longbench/longbenchlsht/longbench_lsht_gen_e8a339.py b/configs/datasets/longbench/longbenchlsht/longbench_lsht_gen_e8a339.py
index 6ff0ab1e..9ebb82b3 100644
--- a/configs/datasets/longbench/longbenchlsht/longbench_lsht_gen_e8a339.py
+++ b/configs/datasets/longbench/longbenchlsht/longbench_lsht_gen_e8a339.py
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchlshtDataset
+from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchlshtDataset, lsht_postprocess
 
 LongBench_lsht_reader_cfg = dict(
     input_columns=['context', 'input'],
@@ -23,7 +23,8 @@ LongBench_lsht_infer_cfg = dict(
 
 LongBench_lsht_eval_cfg = dict(
     evaluator=dict(type=LongBenchClassificationEvaluator),
-    pred_role='BOT'
+    pred_role='BOT',
+    pred_postprocessor=dict(type=lsht_postprocess),
 )
 
 LongBench_lsht_datasets = [
diff --git a/configs/datasets/longbench/longbenchsamsum/longbench_samsum_gen_f4416d.py b/configs/datasets/longbench/longbenchsamsum/longbench_samsum_gen_f4416d.py
index af6b1c2e..51d2f74a 100644
--- a/configs/datasets/longbench/longbenchsamsum/longbench_samsum_gen_f4416d.py
+++ b/configs/datasets/longbench/longbenchsamsum/longbench_samsum_gen_f4416d.py
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchRougeEvaluator, LongBenchsamsumDataset
+from opencompass.datasets import LongBenchRougeEvaluator, LongBenchsamsumDataset, samsum_postprocess
 
 LongBench_samsum_reader_cfg = dict(
     input_columns=['context', 'input'],
@@ -23,7 +23,8 @@ LongBench_samsum_infer_cfg = dict(
 
 LongBench_samsum_eval_cfg = dict(
     evaluator=dict(type=LongBenchRougeEvaluator),
-    pred_role='BOT'
+    pred_role='BOT',
+    pred_postprocessor=dict(type=samsum_postprocess),
 )
 
 LongBench_samsum_datasets = [
diff --git a/configs/datasets/longbench/longbenchtrec/longbench_trec_gen_824187.py b/configs/datasets/longbench/longbenchtrec/longbench_trec_gen_824187.py
index f414696c..66719fb9 100644
--- a/configs/datasets/longbench/longbenchtrec/longbench_trec_gen_824187.py
+++ b/configs/datasets/longbench/longbenchtrec/longbench_trec_gen_824187.py
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchtrecDataset
+from opencompass.datasets import LongBenchClassificationEvaluator, LongBenchtrecDataset, trec_postprocess
 
 LongBench_trec_reader_cfg = dict(
     input_columns=['context', 'input'],
@@ -23,7 +23,8 @@ LongBench_trec_infer_cfg = dict(
 
 LongBench_trec_eval_cfg = dict(
     evaluator=dict(type=LongBenchClassificationEvaluator),
-    pred_role='BOT'
+    pred_role='BOT',
+    pred_postprocessor=dict(type=trec_postprocess),
 )
 
 LongBench_trec_datasets = [
diff --git a/configs/datasets/longbench/longbenchtriviaqa/longbench_triviaqa_gen_d30cb9.py b/configs/datasets/longbench/longbenchtriviaqa/longbench_triviaqa_gen_d30cb9.py
index a09732e6..2cfb7fc1 100644
--- a/configs/datasets/longbench/longbenchtriviaqa/longbench_triviaqa_gen_d30cb9.py
+++ b/configs/datasets/longbench/longbenchtriviaqa/longbench_triviaqa_gen_d30cb9.py
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import LongBenchF1Evaluator, LongBenchtriviaqaDataset
+from opencompass.datasets import LongBenchF1Evaluator, LongBenchtriviaqaDataset, triviaqa_postprocess
 
 LongBench_triviaqa_reader_cfg = dict(
     input_columns=['context', 'input'],
@@ -23,7 +23,8 @@ LongBench_triviaqa_infer_cfg = dict(
 
 LongBench_triviaqa_eval_cfg = dict(
     evaluator=dict(type=LongBenchF1Evaluator),
-    pred_role='BOT'
+    pred_role='BOT',
+    pred_postprocessor=dict(type=triviaqa_postprocess),
 )
 
 LongBench_triviaqa_datasets = [
diff --git a/configs/models/vicuna/hf_vicuna_13b_v13.py b/configs/models/vicuna/hf_vicuna_13b_v13.py
index 97eb4401..6a04a3c4 100644
--- a/configs/models/vicuna/hf_vicuna_13b_v13.py
+++ b/configs/models/vicuna/hf_vicuna_13b_v13.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=2, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_13b_v15.py b/configs/models/vicuna/hf_vicuna_13b_v15.py
index ede90573..c87b9dc7 100644
--- a/configs/models/vicuna/hf_vicuna_13b_v15.py
+++ b/configs/models/vicuna/hf_vicuna_13b_v15.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=1, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_13b_v15_16k.py b/configs/models/vicuna/hf_vicuna_13b_v15_16k.py
index 794084d6..3496b355 100644
--- a/configs/models/vicuna/hf_vicuna_13b_v15_16k.py
+++ b/configs/models/vicuna/hf_vicuna_13b_v15_16k.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=2, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_33b_v13.py b/configs/models/vicuna/hf_vicuna_33b_v13.py
index 32f55365..0f280e63 100644
--- a/configs/models/vicuna/hf_vicuna_33b_v13.py
+++ b/configs/models/vicuna/hf_vicuna_33b_v13.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=4, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_7b_v13.py b/configs/models/vicuna/hf_vicuna_7b_v13.py
index 6b25db0d..67e1c79b 100644
--- a/configs/models/vicuna/hf_vicuna_7b_v13.py
+++ b/configs/models/vicuna/hf_vicuna_7b_v13.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=1, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_7b_v15.py b/configs/models/vicuna/hf_vicuna_7b_v15.py
index 76d3486a..06f3ef73 100644
--- a/configs/models/vicuna/hf_vicuna_7b_v15.py
+++ b/configs/models/vicuna/hf_vicuna_7b_v15.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=1, num_procs=1)
     )
 ]
diff --git a/configs/models/vicuna/hf_vicuna_7b_v15_16k.py b/configs/models/vicuna/hf_vicuna_7b_v15_16k.py
index 45d93b6d..ce590347 100644
--- a/configs/models/vicuna/hf_vicuna_7b_v15_16k.py
+++ b/configs/models/vicuna/hf_vicuna_7b_v15_16k.py
@@ -17,6 +17,7 @@ models = [
         batch_size=8,
         model_kwargs=dict(device_map='auto'),
         batch_padding=False, # if false, inference with for-loop without batch padding
+        use_fastchat_template=True,
         run_cfg=dict(num_gpus=1, num_procs=1)
     )
 ]
diff --git a/configs/summarizers/groups/tydiqa.py b/configs/summarizers/groups/tydiqa.py
index e5191ad8..2a22adea 100644
--- a/configs/summarizers/groups/tydiqa.py
+++ b/configs/summarizers/groups/tydiqa.py
@@ -1,5 +1,5 @@
 tydiqa_summary_groups = []
 
 _tydiqa = ['arabic', 'bengali', 'english', 'finnish', 'indonesian', 'japanese', 'korean', 'russian', 'swahili', 'telugu', 'thai']
-_tydiqa = ['tyidqa-goldp_' + s for s in _tydiqa]
+_tydiqa = ['tydiqa-goldp_' + s for s in _tydiqa]
 tydiqa_summary_groups.append({'name': 'tydiqa-goldp', 'subsets': _tydiqa})
diff --git a/opencompass/datasets/cmb.py b/opencompass/datasets/cmb.py
index d09f0f8b..f2dd321c 100644
--- a/opencompass/datasets/cmb.py
+++ b/opencompass/datasets/cmb.py
@@ -18,6 +18,7 @@ class CMBDataset(BaseDataset):
             for d in val_data:
                 d['option_str'] = '\n'.join(
                     [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1])
+                d['answer'] = 'NULL'
             val_dataset = Dataset.from_list(val_data)
 
         with open(osp.join(path, 'test.json'), 'r', encoding='utf-8') as f:
@@ -25,7 +26,6 @@ class CMBDataset(BaseDataset):
             for d in test_data:
                 d['option_str'] = '\n'.join(
                     [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1])
-            d['answer'] = 'NULL'
             test_dataset = Dataset.from_list(test_data)
 
         return DatasetDict({'val': val_dataset, 'test': test_dataset})
diff --git a/opencompass/datasets/cmnli.py b/opencompass/datasets/cmnli.py
index 653148d3..e0309baa 100644
--- a/opencompass/datasets/cmnli.py
+++ b/opencompass/datasets/cmnli.py
@@ -7,6 +7,19 @@ from opencompass.registry import LOAD_DATASET
 from .base import BaseDataset
 
 
+@LOAD_DATASET.register_module()
+class cmnliDataset(BaseDataset):
+
+    @staticmethod
+    def load(path):
+        data = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = json.loads(line)
+                data.append(line)
+        return Dataset.from_list(data)
+
+
 @LOAD_DATASET.register_module()
 class cmnliDataset_V2(BaseDataset):
 
diff --git a/opencompass/datasets/longbench/longbench_lsht.py b/opencompass/datasets/longbench/longbench_lsht.py
index c6b65350..99cb4127 100644
--- a/opencompass/datasets/longbench/longbench_lsht.py
+++ b/opencompass/datasets/longbench/longbench_lsht.py
@@ -1,6 +1,6 @@
 from datasets import Dataset, load_dataset
 
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from ..base import BaseDataset
 
@@ -28,3 +28,9 @@ class LongBenchlshtDataset(BaseDataset):
                 })
             dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def lsht_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
diff --git a/opencompass/datasets/longbench/longbench_samsum.py b/opencompass/datasets/longbench/longbench_samsum.py
index ae6e1b57..096f9a0f 100644
--- a/opencompass/datasets/longbench/longbench_samsum.py
+++ b/opencompass/datasets/longbench/longbench_samsum.py
@@ -1,6 +1,6 @@
 from datasets import Dataset, load_dataset
 
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from ..base import BaseDataset
 
@@ -24,3 +24,9 @@ class LongBenchsamsumDataset(BaseDataset):
                 })
             dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def samsum_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
diff --git a/opencompass/datasets/longbench/longbench_trec.py b/opencompass/datasets/longbench/longbench_trec.py
index d8d5c177..c70d0008 100644
--- a/opencompass/datasets/longbench/longbench_trec.py
+++ b/opencompass/datasets/longbench/longbench_trec.py
@@ -1,6 +1,6 @@
 from datasets import Dataset, load_dataset
 
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from ..base import BaseDataset
 
@@ -28,3 +28,9 @@ class LongBenchtrecDataset(BaseDataset):
                 })
             dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def trec_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
diff --git a/opencompass/datasets/longbench/longbench_trivia_qa.py b/opencompass/datasets/longbench/longbench_trivia_qa.py
index a698e8df..de52d7e0 100644
--- a/opencompass/datasets/longbench/longbench_trivia_qa.py
+++ b/opencompass/datasets/longbench/longbench_trivia_qa.py
@@ -1,6 +1,6 @@
 from datasets import Dataset, load_dataset
 
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from ..base import BaseDataset
 
@@ -24,3 +24,9 @@ class LongBenchtriviaqaDataset(BaseDataset):
                 })
             dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def triviaqa_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
diff --git a/opencompass/models/huggingface.py b/opencompass/models/huggingface.py
index 08aaf492..3b939db9 100644
--- a/opencompass/models/huggingface.py
+++ b/opencompass/models/huggingface.py
@@ -46,6 +46,9 @@ class HuggingFace(BaseModel):
         mode (str, optional): The method of input truncation when input
             length exceeds max_seq_len. 'mid' represents the part of input to
             truncate. Defaults to 'none'.
+        use_fastchat_template (bool, optional): Whether to use fastchat to get
+            the conversation template. If True, fastchat must be installed
+            first. Defaults to False.
 
     Note:
         About ``extract_pred_after_decode``: Commonly, we should extract the
@@ -68,7 +71,8 @@ class HuggingFace(BaseModel):
                  extract_pred_after_decode: bool = False,
                  batch_padding: bool = False,
                  pad_token_id: Optional[int] = None,
-                 mode: str = 'none'):
+                 mode: str = 'none',
+                 use_fastchat_template: bool = False):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
@@ -91,6 +95,7 @@ class HuggingFace(BaseModel):
                              model_kwargs=model_kwargs,
                              peft_path=peft_path)
         self.generation_kwargs = generation_kwargs
+        self.use_fastchat_template = use_fastchat_template
 
     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                         tokenizer_kwargs: dict):
@@ -220,6 +225,20 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]
 
+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not installed. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to install fastchat.')
+            for i in range(len(inputs)):
+                conv = get_conversation_template('vicuna')
+                conv.append_message(conv.roles[0], inputs[i])
+                conv.append_message(conv.roles[1], None)
+                inputs[i] = conv.get_prompt()
+
         # step-1: tokenize the input with batch_encode_plus
         tokens = self.tokenizer.batch_encode_plus(inputs,
                                                   padding=True,
@@ -263,6 +282,19 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]
 
+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not installed. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to install fastchat.')
+            conv = get_conversation_template('vicuna')
+            conv.append_message(conv.roles[0], inputs[0])
+            conv.append_message(conv.roles[1], None)
+            inputs = [conv.get_prompt()]
+
         if self.mode == 'mid':
             input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
             input_ids = torch.tensor(input_ids, device=self.model.device)
@@ -491,7 +523,8 @@ class HuggingFaceChatGLM3(HuggingFace):
     def generate(self,
                  inputs: List[str or PromptList],
                  max_out_len: int = 512,
-                 temperature: float = 0.6) -> str:
+                 temperature: float = 0.6,
+                 skip_overlength=False) -> str:
         """Generate response from input prompt.
 
         Args:
@@ -518,6 +551,20 @@ class HuggingFaceChatGLM3(HuggingFace):
                 history.append(msg)
             user_content = history[-1]['content']
             history = history[:-1]
+
+            if skip_overlength:
+                # The model will report the following error
+                # if the sequence length is greater than the maximum length:
+                # "Input length of input_ids is {INPUT_IDS},
+                # but `max_length` is set to 8192.
+                # This can lead to unexpected behavior.
+                # You should consider increasing `max_new_tokens`."
+                # The following hard-coded check avoids this exception.
+                len_user_content = len(self.tokenizer.encode(user_content))
+                if len_user_content > 8192:
+                    responses.append('')
+                    continue
+
             try:
                 response, history = self.model.chat(self.tokenizer,
                                                     user_content,
diff --git a/opencompass/models/llama2.py b/opencompass/models/llama2.py
index 9971ece3..92068a0c 100644
--- a/opencompass/models/llama2.py
+++ b/opencompass/models/llama2.py
@@ -141,12 +141,19 @@ class Llama2Chat(BaseModel):
                     path: str,
                     max_seq_len: int,
                     max_batch_size: int,
-                    tokenizer_path: Optional[str] = None):
+                    tokenizer_path: Optional[str] = None,
+                    force_bf16=False):
         from llama import Llama
         self.generator = Llama.build(path, tokenizer_path, max_seq_len,
                                      max_batch_size)
         self.tokenizer = self.generator.tokenizer
         self.model = self.generator.model
+        if force_bf16:
+            # force set model to `bfloat16` to fix
+            # the exception of 'RuntimeError: probability tensor
+            # contains either `inf`, `nan` or element < 0',
+            # encountered during the inference of llama2-7b
+            self.model = self.model.bfloat16()
 
     def _load_tokenizer(self, tokenizer_path: str):
         from llama import Tokenizer
diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
index 6d5fc669..f5dc5a4c 100644
--- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
@@ -108,9 +108,13 @@ class GenInferencer(BaseInferencer):
                                         'tmp_' + output_json_filename)
             if osp.exists(tmp_json_filepath):
                 # TODO: move resume to output handler
-                tmp_result_dict = mmengine.load(tmp_json_filepath)
-                output_handler.results_dict = tmp_result_dict
-                index = len(tmp_result_dict)
+                try:
+                    tmp_result_dict = mmengine.load(tmp_json_filepath)
+                except Exception:
+                    pass
+                else:
+                    output_handler.results_dict = tmp_result_dict
+                    index = len(tmp_result_dict)
 
         # 4. Wrap prompts with Dataloader
         dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
diff --git a/opencompass/runners/slurm_sequential.py b/opencompass/runners/slurm_sequential.py
index aa3f5493..18707249 100644
--- a/opencompass/runners/slurm_sequential.py
+++ b/opencompass/runners/slurm_sequential.py
@@ -96,7 +96,7 @@ class SlurmSequentialRunner(BaseRunner):
 
         try:
             parent_conns = []
-            num_workers = min(self.max_num_workers, len(tasks))
+            num_workers = max(min(self.max_num_workers, len(tasks)), 1)
             with Pool(processes=num_workers) as pool:
                 for task in tasks:
                     parent_conn, child_conn = Pipe()