Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Fix] Fix Slurm ENV (#1392)

1. Support Slurm cluster
2. Support automatic data download
3. Update InternLM2.5-1.8B/20B-Chat

parent c09fc79ba8
commit c81329b548
@@ -14,7 +14,9 @@ exclude: |
        docs/zh_cn/advanced_guides/compassbench_v2_0.md |
        opencompass/configs/datasets/ |
        opencompass/configs/models/|
-       opencompass/configs/summarizers/
+       opencompass/configs/summarizers/|
+       opencompass/utils/datasets.py |
+       opencompass/utils/datasets_info.py
    )
repos:
  - repo: https://gitee.com/openmmlab/mirrors-flake8
@@ -16,7 +16,9 @@ exclude: |
        docs/zh_cn/advanced_guides/compassbench_v2_0.md |
        opencompass/configs/datasets/ |
        opencompass/configs/models/|
-       opencompass/configs/summarizers/
+       opencompass/configs/summarizers/ |
+       opencompass/utils/datasets.py |
+       opencompass/utils/datasets_info.py
    )
repos:
  - repo: https://github.com/PyCQA/flake8
@@ -150,5 +150,5 @@ for _name, _prompt in sub_map.items():
            infer_order='double',
            base_models=gpt4,
            summarizer = dict(type=CompassArenaSummarizer, summary_type='half_add'),
-           given_pred = [{'abbr':'gpt4-turbo', 'path':'./data/subjective/alpaca_eval/gpt4-turbo'}]
+           given_pred = [{'abbr':'gpt4-turbo', 'path':'./data/subjective/compass_arena/gpt4-turbo'}]
        ))
configs/models/hf_internlm/hf_internlm2_5_1_8b_chat.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='internlm2_5-1_8b-chat-hf',
        path='internlm/internlm2_5-1_8b-chat',
        max_out_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
    )
]
configs/models/hf_internlm/hf_internlm2_5_20b_chat.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='internlm2_5-20b-chat-hf',
        path='internlm/internlm2_5-20b-chat',
        max_out_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=2),
    )
]
configs/models/hf_internlm/lmdeploy_internlm2_5_1_8b_chat.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModelwithChatTemplate

models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='internlm2_5-1_8b-chat-turbomind',
        path='internlm/internlm2_5-1_8b-chat',
        engine_config=dict(session_len=8192, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=2048),
        max_seq_len=8192,
        max_out_len=2048,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]
configs/models/hf_internlm/lmdeploy_internlm2_5_20b_chat.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from opencompass.models import TurboMindModelwithChatTemplate

models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='internlm2_5-20b-chat-turbomind',
        path='internlm/internlm2_5-20b-chat',
        engine_config=dict(session_len=8192, max_batch_size=16, tp=2),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=2048),
        max_seq_len=8192,
        max_out_len=2048,
        batch_size=16,
        run_cfg=dict(num_gpus=2),
    )
]
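The four files above only define `models` lists; a top-level evaluation config still has to pull one of them in. A minimal sketch of the usual OpenCompass pattern for doing that (the gsm8k dataset import is only an illustration, not part of this commit):

from mmengine.config import read_base

with read_base():
    # one of the model configs added in this commit
    from .models.hf_internlm.lmdeploy_internlm2_5_1_8b_chat import models
    # any dataset config works here; gsm8k is just an example
    from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets

datasets = gsm8k_datasets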
@@ -8,7 +8,7 @@ from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator, LMEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
                                  TEXT_POSTPROCESSORS)
-from opencompass.utils import build_dataset_from_cfg
+from opencompass.utils import build_dataset_from_cfg, get_data_path

from .base import BaseDataset

@@ -147,6 +147,7 @@ class CharmDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
+       path = get_data_path(path, local_mode=True)
        with open(osp.join(path, f'{name}.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)['examples']
        dataset = Dataset.from_list(data)
@@ -7,6 +7,7 @@ from typing import Optional
from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET
+from opencompass.utils import get_data_path

from .subjective_cmp import SubjectiveCmpDataset

@@ -15,6 +16,8 @@ class Config:

    def __init__(self, alignment_bench_config_path,
                 alignment_bench_config_name) -> None:
+       alignment_bench_config_path = get_data_path(
+           alignment_bench_config_path, local_mode=True)
        config_file_path = osp.join(alignment_bench_config_path,
                                    alignment_bench_config_name + '.json')
        with open(config_file_path, 'r') as config_file:
@@ -4,6 +4,7 @@ import os.path as osp
from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET
+from opencompass.utils import get_data_path

from ..base import BaseDataset

@@ -12,6 +13,7 @@ from ..base import BaseDataset
class ArenaHardDataset(BaseDataset):

    def load(self, path: str, name: str, *args, **kwargs):
+       path = get_data_path(path, local_mode=True)
        filename = osp.join(path, f'{name}.jsonl')
        dataset = DatasetDict()
        raw_data = []
@@ -6,6 +6,7 @@ import re
from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET
+from opencompass.utils import get_data_path

from ..base import BaseDataset

@@ -172,6 +173,7 @@ class MTBenchDataset(BaseDataset):
             multi_turn=True,
             *args,
             **kwargs):
+       path = get_data_path(path, local_mode=True)
        filename = osp.join(path, f'{name}.json')
        dataset = DatasetDict()
        raw_data = []
@@ -4,6 +4,7 @@ import os.path as osp
from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET
+from opencompass.utils import get_data_path

from ..base import BaseDataset

@@ -12,6 +13,7 @@ from ..base import BaseDataset
class SubjectiveCmpDataset(BaseDataset):

    def load(self, path: str, name: str, *args, **kwargs):
+       path = get_data_path(path, local_mode=True)
        filename = osp.join(path, f'{name}.json')
        dataset = DatasetDict()
        raw_data = []
@@ -311,11 +311,15 @@ class OpenAI(BaseAPIModel):
            try:
                enc = self.tiktoken.encoding_for_model(self.tokenizer_path)
                return len(enc.encode(prompt))
-           except Exception:
+           except Exception as e:
+               self.logger.warn(f'{e}, tiktoken encoding cannot load '
+                                f'{self.tokenizer_path}')
                from transformers import AutoTokenizer
                if self.hf_tokenizer is None:
                    self.hf_tokenizer = AutoTokenizer.from_pretrained(
-                       self.tokenizer_path)
+                       self.tokenizer_path, trust_remote_code=True)
+                   self.logger.info(
+                       f'Tokenizer is loaded from {self.tokenizer_path}')
                return len(self.hf_tokenizer(prompt).input_ids)
        else:
            enc = self.tiktoken.encoding_for_model(self.path)
@@ -424,14 +428,14 @@ class OpenAISDK(OpenAI):
            messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
-       try:
-           max_out_len = min(
-               max_out_len,
-               context_window - self.get_token_len(str(input)) - 100)
-       except KeyError:
-           max_out_len = max_out_len
-       if max_out_len <= 0:
-           return ''
+       # try:
+       #     max_out_len = min(
+       #         max_out_len,
+       #         context_window - self.get_token_len(str(input)) - 100)
+       # except KeyError:
+       #     max_out_len = max_out_len
+       # if max_out_len <= 0:
+       #     return ''

        num_retries = 0
        while num_retries < self.retry:
@@ -15,7 +15,7 @@ from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
from opencompass.tasks.base import BaseTask
from opencompass.tasks.openicl_eval import extract_role_pred
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
-                              deal_with_judge_model_abbr,
+                              deal_with_judge_model_abbr, get_data_path,
                               get_infer_output_path, get_logger,
                               model_abbr_from_cfg, task_abbr_from_cfg)

@@ -140,6 +140,7 @@ class SubjectiveEvalTask(BaseTask):
        for given_pred in given_preds:
            abbr = given_pred['abbr']
            path = given_pred['path']
+           path = get_data_path(path, local_mode=True)
            if abbr == model_cfg['abbr']:
                filename = osp.join(path, osp.basename(filename))
        # Get partition name
@@ -1,262 +1,10 @@
import os
+from .fileio import download_and_extract_archive
+from .datasets_info import DATASETS_MAPPING, DATASETS_URL
+from .logging import get_logger

-DATASETS_MAPPING = {...}  # removed; the full mapping now lives in opencompass/utils/datasets_info.py (listed below)

USER_HOME = os.path.expanduser("~")
DEFAULT_DATA_FOLDER = os.path.join(USER_HOME, '.cache/opencompass/')


def get_data_path(dataset_id: str, local_mode: bool = False):
@@ -278,8 +26,10 @@ def get_data_path(dataset_id: str, local_mode: bool = False):
    # For relative path, with CACHE_DIR
    if local_mode:
        local_path = os.path.join(cache_dir, dataset_id)
-       assert os.path.exists(local_path), f'{local_path} does not exist!'
-       return local_path
+       if not os.path.exists(local_path):
+           return download_dataset(local_path, cache_dir)
+       else:
+           return local_path

    dataset_source = os.environ.get('DATASET_SOURCE', None)
    if dataset_source == 'ModelScope':
@@ -297,5 +47,57 @@ def get_data_path(dataset_id: str, local_mode: bool = False):
    # for the local path
    local_path = DATASETS_MAPPING[dataset_id]['local']
    local_path = os.path.join(cache_dir, local_path)
-   assert os.path.exists(local_path), f'{local_path} does not exist!'
-   return local_path
+
+   if not os.path.exists(local_path):
+       return download_dataset(local_path, cache_dir)
+   else:
+       return local_path


def download_dataset(data_path, cache_dir, remove_finished=True):
    get_logger().info(f'{data_path} does not exist! '
                      'Start to download data automatically! '
                      'If you have downloaded the data before, '
                      'you can specify `COMPASS_DATA_CACHE` '
                      'to avoid downloading~')
    # Try to load from default cache folder
    try_default_path = os.path.join(DEFAULT_DATA_FOLDER, data_path)
    if os.path.exists(try_default_path):
        get_logger().info(f'Try to load the data from {try_default_path}')
        return try_default_path

    # Cannot find data from default cache folder, download data.
    # Update DATASETS_URL for internal dataset
    try:
        import json
        internal_datasets = '.OPENCOMPASS_INTERNAL_DATA_URL.json'
        file_path = os.path.join(USER_HOME, internal_datasets)
        assert os.path.exists(file_path), f'{file_path} does not exist'
        with open(file_path, 'r') as f:
            internal_datasets_info = json.load(f)
        DATASETS_URL.update(internal_datasets_info)
        get_logger().info(f'Load internal dataset from: {file_path}')
    except Exception as e:  # noqa
        pass

    valid_data_names = list(DATASETS_URL.keys())
    dataset_name = ''
    for name in valid_data_names:
        if name in data_path:
            dataset_name = name
    assert dataset_name, f'No valid url for {data_path}!\n' + \
        f'Please make sure `{data_path}` is correct'
    dataset_info = DATASETS_URL[dataset_name]
    dataset_url = dataset_info['url']
    dataset_md5 = dataset_info['md5']
    cache_dir = cache_dir if cache_dir else DEFAULT_DATA_FOLDER

    # download and extract files
    download_and_extract_archive(
        url=dataset_url,
        download_root=os.path.join(cache_dir, 'data'),
        md5=dataset_md5,
        remove_finished=remove_finished
    )

    return os.path.join(cache_dir, data_path)
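Taken together, the new resolution logic can be exercised roughly as below. This is only a sketch: it assumes `COMPASS_DATA_CACHE` is the environment variable behind `cache_dir`, as the log message above suggests, and the cache location is hypothetical.

import os

from opencompass.utils import get_data_path

# Reuse an existing cache instead of downloading again (hypothetical location).
os.environ['COMPASS_DATA_CACHE'] = '/data/opencompass_cache'

# Registered id: resolved through DATASETS_MAPPING; if the local copy is
# missing, download_dataset() fetches the matching archive from DATASETS_URL.
mmlu_path = get_data_path('opencompass/mmlu')

# Plain relative path with local_mode=True: joined with the cache dir and,
# when absent, downloaded via the substring match on DATASETS_URL keys.
charm_path = get_data_path('./data/CHARM', local_mode=True)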
opencompass/utils/datasets_info.py (new file, 345 lines)
@@ -0,0 +1,345 @@
DATASETS_MAPPING = {
    # ADVGLUE Datasets
    'opencompass/advglue-dev': {
        'ms_id': None,
        'hf_id': None,
        'local': './data/adv_glue/dev_ann.json',
    },
    # AGIEval Datasets
    'opencompass/agieval': {
        'ms_id': 'opencompass/agieval',
        'hf_id': 'opencompass/agieval',
        'local': './data/AGIEval/data/v1/',
    },

    # ARC Datasets(Test)
    'opencompass/ai2_arc-test': {
        'ms_id': 'opencompass/ai2_arc',
        'hf_id': 'opencompass/ai2_arc',
        'local': './data/ARC/ARC-c/ARC-Challenge-Test.jsonl',
    },
    'opencompass/ai2_arc-dev': {
        'ms_id': 'opencompass/ai2_arc',
        'hf_id': 'opencompass/ai2_arc',
        'local': './data/ARC/ARC-c/ARC-Challenge-Dev.jsonl',
    },
    'opencompass/ai2_arc-easy-dev': {
        'ms_id': 'opencompass/ai2_arc',
        'hf_id': 'opencompass/ai2_arc',
        'local': './data/ARC/ARC-e/ARC-Easy-Dev.jsonl',
    },
    # BBH
    'opencompass/bbh': {
        'ms_id': 'opencompass/bbh',
        'hf_id': 'opencompass/bbh',
        'local': './data/BBH/data',
    },
    # C-Eval
    'opencompass/ceval-exam': {
        'ms_id': 'opencompass/ceval-exam',
        'hf_id': 'opencompass/ceval-exam',
        'local': './data/ceval/formal_ceval',
    },
    # AFQMC
    'opencompass/afqmc-dev': {
        'ms_id': 'opencompass/afqmc',
        'hf_id': 'opencompass/afqmc',
        'local': './data/CLUE/AFQMC/dev.json',
    },
    # CMNLI
    'opencompass/cmnli-dev': {
        'ms_id': 'opencompass/cmnli',
        'hf_id': 'opencompass/cmnli',
        'local': './data/CLUE/cmnli/cmnli_public/dev.json',
    },
    # OCNLI
    'opencompass/OCNLI-dev': {
        'ms_id': 'opencompass/OCNLI',
        'hf_id': 'opencompass/OCNLI',
        'local': './data/CLUE/OCNLI/dev.json',
    },
    # ChemBench
    'opencompass/ChemBench': {
        'ms_id': 'opencompass/ChemBench',
        'hf_id': 'opencompass/ChemBench',
        'local': './data/ChemBench/',
    },
    # CMMLU
    'opencompass/cmmlu': {
        'ms_id': 'opencompass/cmmlu',
        'hf_id': 'opencompass/cmmlu',
        'local': './data/cmmlu/',
    },
    # CommonsenseQA
    'opencompass/commonsense_qa': {
        'ms_id': 'opencompass/commonsense_qa',
        'hf_id': 'opencompass/commonsense_qa',
        'local': './data/commonsenseqa',
    },
    # CMRC
    'opencompass/cmrc_dev': {
        'ms_id': 'opencompass/cmrc_dev',
        'hf_id': 'opencompass/cmrc_dev',
        'local': './data/CLUE/CMRC/dev.json'
    },
    # DRCD_dev
    'opencompass/drcd_dev': {
        'ms_id': 'opencompass/drcd_dev',
        'hf_id': 'opencompass/drcd_dev',
        'local': './data/CLUE/DRCD/dev.json'
    },
    # clozeTest_maxmin
    'opencompass/clozeTest_maxmin': {
        'ms_id': None,
        'hf_id': None,
        'local': './data/clozeTest-maxmin/python/clozeTest.json',
    },
    # clozeTest_maxmin
    'opencompass/clozeTest_maxmin_answers': {
        'ms_id': None,
        'hf_id': None,
        'local': './data/clozeTest-maxmin/python/answers.txt',
    },
    # Flores
    'opencompass/flores': {
        'ms_id': 'opencompass/flores',
        'hf_id': 'opencompass/flores',
        'local': './data/flores_first100',
    },
    # MBPP
    'opencompass/mbpp': {
        'ms_id': 'opencompass/mbpp',
        'hf_id': 'opencompass/mbpp',
        'local': './data/mbpp/mbpp.jsonl',
    },
    # 'opencompass/mbpp': {
    #     'ms_id': 'opencompass/mbpp',
    #     'hf_id': 'opencompass/mbpp',
    #     'local': './data/mbpp/mbpp.jsonl',
    # },
    'opencompass/sanitized_mbpp': {
        'ms_id': 'opencompass/mbpp',
        'hf_id': 'opencompass/mbpp',
        'local': './data/mbpp/sanitized-mbpp.jsonl',
    },
    # GSM
    'opencompass/gsm8k': {
        'ms_id': 'opencompass/gsm8k',
        'hf_id': 'opencompass/gsm8k',
        'local': './data/gsm8k/',
    },
    # HellaSwag
    'opencompass/hellaswag': {
        'ms_id': 'opencompass/hellaswag',
        'hf_id': 'opencompass/hellaswag',
        'local': './data/hellaswag/hellaswag.jsonl',
    },
    # HellaSwagICE
    'opencompass/hellaswag_ice': {
        'ms_id': 'opencompass/hellaswag',
        'hf_id': 'opencompass/hellaswag',
        'local': './data/hellaswag/',
    },
    # HumanEval
    'opencompass/humaneval': {
        'ms_id': 'opencompass/humaneval',
        'hf_id': 'opencompass/humaneval',
        'local': './data/humaneval/human-eval-v2-20210705.jsonl',
    },
    # HumanEvalCN
    'opencompass/humaneval_cn': {
        'ms_id': 'opencompass/humaneval',
        'hf_id': 'opencompass/humaneval',
        'local': './data/humaneval_cn/human-eval-cn-v2-20210705.jsonl',
    },
    # Lambada
    'opencompass/lambada': {
        'ms_id': 'opencompass/lambada',
        'hf_id': 'opencompass/lambada',
        'local': './data/lambada/test.jsonl',
    },
    # LCSTS
    'opencompass/LCSTS': {
        'ms_id': 'opencompass/LCSTS',
        'hf_id': 'opencompass/LCSTS',
        'local': './data/LCSTS',
    },
    # MATH
    'opencompass/math': {
        'ms_id': 'opencompass/math',
        'hf_id': 'opencompass/math',
        'local': './data/math/math.json',
    },
    # MMLU
    'opencompass/mmlu': {
        'ms_id': 'opencompass/mmlu',
        'hf_id': 'opencompass/mmlu',
        'local': './data/mmlu/',
    },
    # NQ
    'opencompass/natural_question': {
        'ms_id': 'opencompass/natural_question',
        'hf_id': 'opencompass/natural_question',
        'local': './data/nq/',
    },
    # OpenBook QA-test
    'opencompass/openbookqa_test': {
        'ms_id': 'opencompass/openbookqa',
        'hf_id': 'opencompass/openbookqa',
        'local': './data/openbookqa/Main/test.jsonl',
    },
    # OpenBook QA-fact
    'opencompass/openbookqa_fact': {
        'ms_id': 'opencompass/openbookqa',
        'hf_id': 'opencompass/openbookqa',
        'local': './data/openbookqa/Additional/test_complete.jsonl',
    },
    # PIQA
    'opencompass/piqa': {
        'ms_id': 'opencompass/piqa',
        'hf_id': 'opencompass/piqa',
        'local': './data/piqa',
    },
    # RACE
    'opencompass/race': {
        'ms_id': 'opencompass/race',
        'hf_id': 'opencompass/race',
        'local': './data/race',
    },
    # SIQA
    'opencompass/siqa': {
        'ms_id': 'opencompass/siqa',
        'hf_id': 'opencompass/siqa',
        'local': './data/siqa',
    },
    # XStoryCloze
    'opencompass/xstory_cloze': {
        'ms_id': 'opencompass/xstory_cloze',
        'hf_id': 'opencompass/xstory_cloze',
        'local': './data/xstory_cloze',
    },
    # StrategyQA
    'opencompass/strategy_qa': {
        'ms_id': 'opencompass/strategy_qa',
        'hf_id': 'opencompass/strategy_qa',
        'local': './data/strategyqa/strategyQA_train.json',
    },
    # SummEdits
    'opencompass/summedits': {
        'ms_id': 'opencompass/summedits',
        'hf_id': 'opencompass/summedits',
        'local': './data/summedits/summedits.jsonl',
    },
    # TriviaQA
    'opencompass/trivia_qa': {
        'ms_id': 'opencompass/trivia_qa',
        'hf_id': 'opencompass/trivia_qa',
        'local': './data/triviaqa/',
    },
    # TydiQA
    'opencompass/tydiqa': {
        'ms_id': 'opencompass/tydiqa',
        'hf_id': 'opencompass/tydiqa',
        'local': './data/tydiqa/',
    },
    # Winogrande
    'opencompass/winogrande': {
        'ms_id': 'opencompass/winogrande',
        'hf_id': 'opencompass/winogrande',
        'local': './data/winogrande/',
    },
    # XSum
    'opencompass/xsum': {
        'ms_id': 'opencompass/xsum',
        'hf_id': 'opencompass/xsum',
        'local': './data/Xsum/dev.jsonl',
    }
}

DATASETS_URL = {
    '/mmlu/': {
        'url':
        'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip',
        'md5': '761310671509a239e41c4b717f7fab9c',
    },
    '/gpqa/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gpqa.zip',
        'md5': '2e9657959030a765916f1f2aca29140d'
    },
    '/CHARM/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/CHARM.zip',
        'md5': 'fdf51e955d1b8e0bb35bc1997eaf37cb'
    },
    '/ifeval/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ifeval.zip',
        'md5': '64d98b6f36b42e7390c9cef76cace75f'
    },
    '/mbpp/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mbpp.zip',
        'md5': '777739c90f04bce44096a5bc96c8f9e5'
    },
    '/cmmlu/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/cmmlu.zip',
        'md5': 'a59f4003d6918509a719ce3bc2a5d5bc'
    },
    '/math/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/math.zip',
        'md5': '8b1b897259684672055e6fd4fc07c808'
    },
    '/hellaswag/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/hellaswag.zip',
        'md5': '2b700a02ffb58571c7df8d8d0619256f'
    },
    '/BBH/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/BBH.zip',
        'md5': '60c49f9bef5148aa7e1941328e96a554'
    },
    '/mmlu/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mmlu.zip',
        'md5': '761310671509a239e41c4b717f7fab9c'
    },
    '/compass_arena/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/compass_arena.zip',
        'md5': 'cd59b54a179d16f2a858b359b60588f6'
    },
    '/TheoremQA/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/TheoremQA.zip',
        'md5': 'f2793b07bc26510d507aa710d9bd8622'
    },
    '/mathbench_v1/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mathbench_v1.zip',
        'md5': '50257a910ca43d1f61a610a79fdb16b5'
    },
    '/gsm8k/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gsm8k.zip',
        'md5': '901e5dc93a2889789a469da9850cdca8'
    },
    '/LCBench2023/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/LCBench2023.zip',
        'md5': 'e1a38c94a42ad1809e9e0650476a9306'
    },
    '/humaneval/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/humaneval.zip',
        'md5': '88b1b89dc47b7121c81da6bcd85a69c3'
    },
    '/drop_simple_eval/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/drop_simple_eval.zip',
        'md5': 'c912afe5b4a63509851cf16e6b91830e'
    },
    'subjective/alignment_bench/': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alignment_bench.zip',
        'md5': 'd8ae9a0398526479dbbcdb80fafabceb'
    },
    'subjective/alpaca_eval': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/alpaca_eval.zip',
        'md5': 'd7399d63cb46c82f089447160ef49b6a'
    },
    'subjective/arena_hard': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/arena_hard.zip',
        'md5': '02cd09a482cb0f0cd9d2c2afe7a1697f'
    },
    'subjective/mtbench': {
        'url': 'http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/mtbench.zip',
        'md5': 'd1afc0787aeac7f1f24872742e161069'
    },
}
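The keys of DATASETS_URL are plain substrings that download_dataset matches against the missing local path, so a path only needs to contain the key. A small illustration (the file name here is made up):

from opencompass.utils.datasets_info import DATASETS_URL

data_path = './data/subjective/mtbench/example.json'  # illustrative path
matched = [name for name in DATASETS_URL if name in data_path]
assert matched == ['subjective/mtbench']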
@@ -1,4 +1,14 @@
+import gzip
+import hashlib
import io
+import os
+import os.path
+import shutil
+import tarfile
+import tempfile
+import urllib.error
+import urllib.request
+import zipfile
from contextlib import contextmanager

import mmengine.fileio as fileio
@@ -166,3 +176,203 @@ def patch_hf_auto_model(cache_dir=None):
        auto_class.from_pretrained = auto_pt

    patch_hf_auto_model._patched = True


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
    md5 = hashlib.md5()
    backend = get_file_backend(fpath, enable_singleton=True)
    if isinstance(backend, LocalBackend):
        # Enable chunk update for local file.
        with open(fpath, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                md5.update(chunk)
    else:
        md5.update(backend.get(fpath))
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Download object at the given URL to a local path.

    Modified from
    https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved,
            e.g. ``/tmp/temporary_file``
        hash_prefix (string, optional): If not None, the SHA256 downloaded
            file should start with ``hash_prefix``. Defaults to None.
        progress (bool): whether or not to display a progress bar to stderr.
            Defaults to True
    """
    file_size = None
    req = urllib.request.Request(url)
    u = urllib.request.urlopen(req)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders('Content-Length')
    else:
        content_length = meta.get_all('Content-Length')
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after download is
    # complete. This prevents a local file being overridden by a broken
    # download.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    import rich.progress
    columns = [
        rich.progress.DownloadColumn(),
        rich.progress.BarColumn(bar_width=None),
        rich.progress.TimeRemainingColumn(),
    ]
    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with rich.progress.Progress(*columns) as pbar:
            task = pbar.add_task('download', total=file_size, visible=progress)
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(task, advance=len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in.
        filename (str | None): Name to save the file under.
            If filename is None, use the basename of the URL.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
    else:
        try:
            print(f'Downloading {url} to {fpath}')
            download_url_to_file(url, fpath)
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      f' Downloading {url} to {fpath}')
                download_url_to_file(url, fpath)
            else:
                raise e
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')


def _is_tarxz(filename):
    return filename.endswith('.tar.xz')


def _is_tar(filename):
    return filename.endswith('.tar')


def _is_targz(filename):
    return filename.endswith('.tar.gz')


def _is_tgz(filename):
    return filename.endswith('.tgz')


def _is_gzip(filename):
    return filename.endswith('.gz') and not filename.endswith('.tar.gz')


def _is_zip(filename):
    return filename.endswith('.zip')


def extract_archive(from_path, to_path=None, remove_finished=False):
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(
            to_path,
            os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)


def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)
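download_and_extract_archive is what download_dataset ultimately calls, but it can also be used on its own. A minimal sketch using one of the URL/md5 pairs listed in DATASETS_URL above (the target directory is only an example):

from opencompass.utils.fileio import download_and_extract_archive

# Fetch gsm8k.zip into ./data, verify its md5, unpack it there and drop the zip.
download_and_extract_archive(
    url='http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/gsm8k.zip',
    download_root='./data',
    md5='901e5dc93a2889789a469da9850cdca8',
    remove_finished=True,
)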
@@ -1,3 +1,4 @@
 alpaca-eval==0.6
 faiss_gpu==1.7.2
 latex2sympy2
+scikit-learn==1.5