Spaces:
Runtime error
Runtime error
File size: 1,430 Bytes
8cf4695 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import logging
from typing import Any, List, TypedDict

from .multiple_linear_regressor import MultipleLinearRegressor
from .prophet import ProphetForecaster
from .xgboost import XGBoost
from .xgbregressor import XGBRegressor
class Model(TypedDict):
    """A named, instantiated model as produced by ``AllModels.init_models``."""

    name: str  # registry key the model was requested under
    model: Any  # the constructed model instance (was ``any``, the builtin function)
class AllModels:
    """Registry of the models that can be instantiated by short name."""

    def __init__(self) -> None:
        # Any available model must register here: short name -> model class.
        # NOTE(review): XGBoost and ProphetForecaster are imported by this
        # module but not registered below -- confirm whether that is intended.
        self.all_models = {
            'xgbreg': XGBRegressor,
            'mlr': MultipleLinearRegressor
        }
        # Live view of the registered names; stays in sync with all_models.
        self.all_model_names = self.all_models.keys()

    def init_models(
        self,
        models
    ) -> "List[Model]":
        """Instantiate the requested models.

        Args:
            models: ``'all'`` for every registered model, a single model name,
                or an iterable of model names. Any iterable is accepted,
                including a one-shot generator.

        Returns:
            A list of ``{'name': ..., 'model': ...}`` dicts, one per requested
            name and in request order. Each holds a freshly constructed
            instance of the registered class. The list is also stored on
            ``self.models``, and the resolved names are stored on
            ``self.model_names``.

        Raises:
            ValueError: if any requested name is not registered.
        """
        logging.debug('Init models')
        # Test for str first so ``models == 'all'`` is never evaluated on
        # array-like inputs, whose ``==`` is element-wise.
        if isinstance(models, str):
            model_names = list(self.all_model_names) if models == 'all' else [models]
        else:
            # Materialize once: a one-shot iterator would otherwise be
            # exhausted by the validation below and yield no models.
            model_names = list(models)
        self.model_names = model_names

        logging.debug('Check model names')
        # Keep request order so the error message is deterministic.
        unknown_models = [name for name in model_names if name not in self.all_models]
        if unknown_models:
            raise ValueError(
                f'Unknown model : {unknown_models}, please use active models: {list(self.all_model_names)}')

        self.models = [
            {
                'name': name,
                'model': self.all_models[name]()
            }
            for name in model_names
        ]
        return self.models
|