zhang qiao
Upload folder using huggingface_hub
8cf4695
import logging
from typing import Any, List, TypedDict

from .xgbregressor import XGBRegressor
from .multiple_linear_regressor import MultipleLinearRegressor
from .xgboost import XGBoost
from .prophet import ProphetForecaster
class Model(TypedDict):
    """One initialised model entry, as built by ``AllModels.init_models``."""

    # Registry key the model was requested by (e.g. 'xgbreg', 'mlr').
    name: str
    # Instance of the registered model class. Typed as ``Any`` (not the
    # builtin function ``any``) because no common model base type is
    # visible here.
    model: Any
class AllModels:
    """Registry of available models and factory for fresh model instances.

    Models are registered by short name in ``self.all_models`` and are
    instantiated with no constructor arguments by :meth:`init_models`.
    """

    def __init__(self) -> None:
        # Any available model must register here: short name -> class.
        # NOTE(review): XGBoost and ProphetForecaster are imported by this
        # module but not registered below; confirm whether that is intended.
        self.all_models = {
            'xgbreg': XGBRegressor,
            'mlr': MultipleLinearRegressor
        }
        # Live keys view over the registry.
        self.all_model_names = self.all_models.keys()

    def init_models(
        self,
        models
    ) -> List[Model]:
        """Instantiate the requested models.

        Args:
            models: ``'all'`` for every registered model, a single model
                name, or an iterable of model names. Any iterable is
                accepted, including one-shot iterators such as generators.

        Returns:
            A list of ``{'name': ..., 'model': ...}`` dicts, one per
            requested name in the order given, each holding a new instance
            of the registered class. The list is also stored on
            ``self.models`` and the resolved names on ``self.model_names``.

        Raises:
            ValueError: if any requested name is not registered.
        """
        logging.debug('Init models')
        if isinstance(models, str):
            if models == 'all':
                self.model_names = list(self.all_model_names)
            else:
                self.model_names = [models]
        else:
            # Materialise once: the names are iterated twice (validation and
            # instantiation), so a generator would otherwise be exhausted by
            # the validation step and no model would be built.
            self.model_names = list(models)

        logging.debug('Check model names')
        unknown_models = set(self.model_names) - set(self.all_model_names)
        if unknown_models:
            # key=str keeps the ordering deterministic without failing on
            # mixed, mutually unorderable element types.
            raise ValueError(
                f'Unknown model : {sorted(unknown_models, key=str)}, '
                f'please use active models: {list(self.all_model_names)}')

        self.models = [
            {
                'name': name,
                'model': self.all_models[name]()
            }
            for name in self.model_names]
        return self.models