|
from transformers import Pipeline |
|
from src.ecg2hrv import ECG2HRV |
|
|
|
class Ecg2HrvPipeline(Pipeline):
    """Hugging Face ``Pipeline`` wrapper around the project's ``ECG2HRV`` extractor.

    All of the work happens in :meth:`preprocess`, which delegates to
    ``ECG2HRV.extract_features``. The ``_forward`` and ``postprocess`` stages
    hand their input back unchanged, so calling the pipeline yields whatever
    ``extract_features`` produced.
    """

    # Evaluated once, when this class body executes (i.e. when the module is
    # imported), and shared by every instance of the pipeline.
    ecg2HrvExtractor = ECG2HRV()

    def _sanitize_parameters(self, **kwargs):
        """Route recognised keyword arguments to the preprocessing stage.

        Only ``sampling_rate``, ``baseline`` and ``normalization_method`` are
        recognised; each one that is present in ``kwargs`` is copied into the
        preprocess dict. Any other keyword argument is not forwarded anywhere.

        Returns:
            A 3-tuple of dicts. Per the transformers ``Pipeline`` contract these
            are the keyword arguments for ``preprocess``, ``_forward`` and
            ``postprocess`` respectively; the last two are always empty.
        """
        recognised = ("sampling_rate", "baseline", "normalization_method")
        for_preprocess = {key: kwargs[key] for key in recognised if key in kwargs}
        return for_preprocess, {}, {}

    def preprocess(
        self,
        inputs,
        sampling_rate=1000,
        baseline=None,
        normalization_method=None,
    ):
        """Extract features from ``inputs`` using the shared ``ECG2HRV`` extractor.

        Args:
            inputs: Whatever was passed to the pipeline call; handed straight to
                ``extract_features`` (its expected type is defined by ``ECG2HRV``).
            sampling_rate: Forwarded unchanged; defaults to 1000
                (presumably samples per second -- confirm against ``ECG2HRV``).
            baseline: Forwarded unchanged; ``None`` by default.
            normalization_method: Forwarded unchanged; ``None`` by default.

        Returns:
            The value returned by ``extract_features``, unmodified.
        """
        extractor = self.ecg2HrvExtractor
        # Positional call on purpose: the parameter names of
        # ``extract_features`` are defined elsewhere (src.ecg2hrv).
        features = extractor.extract_features(
            inputs, sampling_rate, baseline, normalization_method
        )
        return features

    def _forward(self, model_inputs):
        """Identity stage: no model is invoked; preprocessed features pass through."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Identity stage: return the forwarded features unchanged."""
        return model_outputs