mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
|
|
from opencompass.datasets.evalplus.provider.utility import EOS
|
|
|
|
|
|
class DecoderBase(ABC):
    """Abstract base class that records the settings shared by decoders.

    The constructor only stores its arguments (plus two fixed values) on the
    instance. Concrete subclasses must implement :meth:`codegen` and
    :meth:`is_direct_completion`.
    """

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 768,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = False,
        instruction_prefix: str = None,
        response_prefix: str = None,
    ) -> None:
        """Announce the model on stdout and store the configuration.

        Args:
            name: Identifier of the model; also returned by ``repr()`` and
                ``str()`` of the instance.
            batch_size: Stored as ``self.batch_size``.
            temperature: Stored as ``self.temperature``.
            max_new_tokens: Stored as ``self.max_new_tokens``.
            dtype: Stored as ``self.dtype`` (a string, not a dtype object).
            trust_remote_code: Stored as ``self.trust_remote_code``.
            instruction_prefix: Stored as ``self.instruction_prefix``;
                may be ``None``.
            response_prefix: Stored as ``self.response_prefix``;
                may be ``None``.
        """
        announcement = "Initializing a decoder model: {} ...".format(name)
        print(announcement)

        # Identity and loading options.
        self.name = name
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code

        # Generation settings; how they are used is up to the subclass.
        self.batch_size = batch_size
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

        # Fixed values: the module-level EOS object imported by this file,
        # and special tokens are not skipped unless a subclass overrides it.
        self.eos = EOS
        self.skip_special_tokens = False

        # Optional prompt prefixes.
        self.instruction_prefix = instruction_prefix
        self.response_prefix = response_prefix

    @abstractmethod
    def codegen(
        self, prompt: str, do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        """Return a list of generated strings for ``prompt``.

        Abstract: subclasses provide the implementation.
        """

    @abstractmethod
    def is_direct_completion(self) -> bool:
        """Return a bool; its exact meaning is defined by subclasses.

        NOTE(review): presumably tells whether the model completes the raw
        prompt rather than following instructions -- confirm with callers.
        """

    def __repr__(self) -> str:
        """Return the model name."""
        return self.name

    def __str__(self) -> str:
        """Return the model name."""
        return self.name