eyetrack-to-sacc-pipeline / eyetrack_2_saccade_pipeline.py
Georg Willer
Change default header names to start with an uppercase letter and add checks for alternative time_header names
b9f7415
from transformers import Pipeline
from .eyetrack2saccade import Eye2SacExtractor
class Eye2SacPipeline(Pipeline):
    """Pipeline that delegates all real work to ``Eye2SacExtractor``.

    Feature extraction happens entirely in :meth:`preprocess`. The forward and
    postprocess stages return their input unchanged, so this class itself
    never runs a model.
    """

    # Built once, when the class body executes, and shared by every instance.
    eye2SacExtractor = Eye2SacExtractor()

    def _sanitize_parameters(self, **kwargs):
        """Route recognised call-time options to :meth:`preprocess`.

        Only the option names listed below are forwarded; any other key in
        ``kwargs`` is ignored here. The forward and postprocess stages take no
        options.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            triple; the last two are always empty dicts.
        """
        forwarded_names = (
            "missing",
            "minlen",
            "maxvel",
            "maxacc",
            "time_header",
            "x_headers",
            "y_headers",
        )
        preprocess_params = {
            name: kwargs[name] for name in forwarded_names if name in kwargs
        }
        return preprocess_params, {}, {}

    def preprocess(self, inputs, missing=0.0, minlen=5, maxvel=40, maxacc=30, time_header = 'Time', x_headers = 'X', y_headers= 'Y'):
        """Run the shared extractor's ``extract_features`` on ``inputs``.

        Args:
            inputs: Raw eye-tracking data. The accepted type is whatever
                ``Eye2SacExtractor.extract_features`` accepts (TODO confirm,
                e.g. a DataFrame or a file path).
            missing: Passed through to the extractor. Presumably this is the
                value that marks missing samples; verify against the extractor.
            minlen: Passed through to the extractor. Presumably a minimum
                saccade length; verify units.
            maxvel: Passed through to the extractor. Presumably a velocity
                threshold; verify units.
            maxacc: Passed through to the extractor. Presumably an
                acceleration threshold; verify units.
            time_header: Passed through. Presumably the name of the time column.
            x_headers: Passed through. Presumably the x-coordinate header name(s).
            y_headers: Passed through. Presumably the y-coordinate header name(s).

        Returns:
            Whatever ``extract_features`` returns, unchanged.
        """
        extractor = self.eye2SacExtractor
        # Arguments go positionally: the header names first, then the
        # thresholds. Do not reorder them.
        return extractor.extract_features(
            inputs,
            time_header,
            x_headers,
            y_headers,
            missing,
            minlen,
            maxvel,
            maxacc,
        )

    def _forward(self, model_inputs):
        """Return ``model_inputs`` unchanged; no model is invoked."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Return ``model_outputs`` unchanged."""
        return model_outputs