Spaces:
Runtime error
Runtime error
import logging
from typing import Any, List, TypedDict

from .multiple_linear_regressor import MultipleLinearRegressor
from .prophet import ProphetForecaster
from .xgboost import XGBoost
from .xgbregressor import XGBRegressor
class Model(TypedDict):
    """One instantiated model entry, as built by ``AllModels.init_models``."""

    # Registry key the model was requested by (e.g. 'xgbreg').
    name: str
    # The model instance itself; entries may hold different model classes,
    # so the field is typed as ``typing.Any`` (the builtin ``any`` is a
    # function, not a type).
    model: Any
class AllModels:
    """Registry of the model classes that can be instantiated by key.

    ``all_models`` maps a short key (e.g. ``'xgbreg'``) to a model class;
    :meth:`init_models` builds fresh instances for the requested keys.
    """

    def __init__(self) -> None:
        # Any available model must register here
        self.all_models = {
            'xgbreg': XGBRegressor,
            'mlr': MultipleLinearRegressor,
        }
        # Live view of the registered keys, used for validation and in the
        # error message raised by ``init_models``.
        self.all_model_names = self.all_models.keys()

    def init_models(
        self,
        models
    ) -> List[Model]:
        """Instantiate the requested models.

        Args:
            models: ``'all'`` for every registered model, a single model key
                as a string, or any iterable of model keys (a one-shot
                iterator or generator is accepted).

        Returns:
            A list of ``{'name': key, 'model': instance}`` dicts, one per
            requested key, in request order. The list is also stored on
            ``self.models``; the requested keys are stored as a list on
            ``self.model_names``.

        Raises:
            ValueError: if any requested key is not registered.
        """
        logging.debug('Init models')
        # Handle strings first: a bare string is a single key (not an
        # iterable of characters), and comparing a non-string array-like to
        # 'all' could be evaluated elementwise.
        if isinstance(models, str):
            names = list(self.all_model_names) if models == 'all' else [models]
        else:
            # Materialise once: the names are iterated twice below
            # (validation, then instantiation). A one-shot iterator would
            # otherwise pass validation and then yield no models at all.
            names = list(models)
        self.model_names = names

        logging.debug('Check model names')
        unknown_models = set(names) - set(self.all_model_names)
        if unknown_models:
            raise ValueError(
                f'Unknown model : {unknown_models}, please use active models: {self.all_model_names}')

        self.models = [
            {'name': name, 'model': self.all_models[name]()}
            for name in names
        ]
        return self.models