File size: 1,430 Bytes
8cf4695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
from typing import Any, Iterable, List, TypedDict, Union

from .xgbregressor import XGBRegressor
from .multiple_linear_regressor import MultipleLinearRegressor
from .xgboost import XGBoost
from .prophet import ProphetForecaster


class Model(TypedDict):
    """Dictionary shape of one entry returned by ``AllModels.init_models``."""

    # Registry key the model was requested under (e.g. 'xgbreg', 'mlr').
    name: str
    # Freshly constructed model instance. The registered classes share no
    # visible base type here, so the value is deliberately left untyped.
    # NOTE: this was previously annotated with the builtin function `any`,
    # which is not a type; `typing.Any` is the intended annotation.
    model: Any


class AllModels():
    """Registry of the available model classes and factory for their instances.

    Attributes:
        all_models: Mapping of registry name to model class.
        all_model_names: Live key view of ``all_models``.
        model_names: Names selected by the last successful ``init_models``
            call (set by that call, absent before it).
        models: Entries built by the last successful ``init_models`` call
            (set by that call, absent before it).
    """

    def __init__(self) -> None:
        # Any available model must register here
        self.all_models = {
            'xgbreg': XGBRegressor,
            'mlr': MultipleLinearRegressor,
        }
        # Kept as a key view (not a copy) so it follows later changes made
        # to `all_models`.
        self.all_model_names = self.all_models.keys()

    def init_models(
        self,
        models: Union[str, Iterable[str]],
    ) -> List[Model]:
        """Instantiate the requested models.

        Args:
            models: ``'all'`` to select every registered model, a single
                registry name, or an iterable of registry names. One-shot
                iterables such as generators are supported.

        Returns:
            One ``{'name': ..., 'model': ...}`` entry per requested name, in
            request order. Each model is built by calling its registered
            class with no arguments. The list is also stored on
            ``self.models``.

        Raises:
            ValueError: If any requested name is not registered. In that
                case ``self.model_names`` and ``self.models`` are left
                untouched.
        """
        logging.debug('Init models')

        # Test for `str` first so that array-like inputs are never compared
        # against 'all' (which may be elementwise and have no single truth
        # value).
        if isinstance(models, str):
            requested = list(self.all_model_names) if models == 'all' else [models]
        else:
            # Materialise once: validation and instantiation both iterate
            # the names, and a generator would be exhausted after the first
            # pass. This also decouples our state from the caller's object.
            requested = list(models)

        logging.debug('Check model names')
        known = set(self.all_model_names)
        # dict.fromkeys de-duplicates while keeping the order of first
        # appearance, so the error message is deterministic.
        unknown = [name for name in dict.fromkeys(requested) if name not in known]
        if unknown:
            raise ValueError(
                f'Unknown model : {unknown}, please use active models: {list(self.all_model_names)}')

        # Only commit state once the request is known to be valid.
        self.model_names = requested
        self.models = [
            {'name': name, 'model': self.all_models[name]()}
            for name in requested
        ]
        return self.models