# ECG2HRV/src/ecg_2_hrv_pipeline.py
# Georg Willer
# Remove unused code
# 708c7e0
from transformers import Pipeline
from src.ecg2hrv import ECG2HRV
class Ecg2HrvPipeline(Pipeline):
    """``transformers.Pipeline`` that delegates its work to ``ECG2HRV.extract_features``.

    All feature extraction happens in :meth:`preprocess`. :meth:`_forward`
    and :meth:`postprocess` currently pass their argument through unchanged.
    """

    # Class attribute: created once when the class body executes (at import
    # time) and shared by every pipeline instance.
    ecg2HrvExtractor = ECG2HRV()

    def _sanitize_parameters(self, **kwargs):
        """Route the supported call-time keyword arguments to ``preprocess``.

        ``transformers.Pipeline`` expects a triple of parameter dicts, one each
        for ``preprocess``, ``_forward`` and ``postprocess``.
        ``sampling_rate``, ``baseline`` and ``normalization_method`` are copied
        into the first dict, but only when the caller actually supplied them.
        Every other keyword is ignored here. The forward and postprocess
        dicts are always empty.
        """
        accepted = ("sampling_rate", "baseline", "normalization_method")
        preprocess_params = {
            name: kwargs[name] for name in accepted if name in kwargs
        }
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, sampling_rate=1000, baseline=None,
                   normalization_method=None):
        """Extract HRV features from ``inputs`` using the shared extractor.

        Args:
            inputs: Passed straight through to ``extract_features``. Its
                expected format is defined by ``ECG2HRV`` (presumably an ECG
                recording -- confirm against that class).
            sampling_rate: Forwarded unchanged; defaults to 1000 (the unit is
                not stated here, presumably Hz -- confirm).
            baseline: Forwarded unchanged; ``None`` by default.
            normalization_method: Forwarded unchanged; ``None`` by default.

        Returns:
            Whatever ``ECG2HRV.extract_features`` returns.
        """
        extractor = self.ecg2HrvExtractor
        return extractor.extract_features(
            inputs, sampling_rate, baseline, normalization_method
        )

    def _forward(self, model_inputs):
        """Return ``model_inputs`` unchanged.

        No model is run: ``preprocess`` already produces the final features.
        The hook is kept so an end-to-end ML step can be added here later.
        """
        outputs = model_inputs
        return outputs

    def postprocess(self, model_outputs):
        """Return ``model_outputs`` as-is; no post-processing is applied."""
        return model_outputs