File size: 1,382 Bytes
efca5ee 45856e0 efca5ee fb2fb4c 9b37b0e efca5ee b9f7415 9b37b0e efca5ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from transformers import Pipeline
from .eyetrack2saccade import Eye2SacExtractor
class Eye2SacPipeline(Pipeline):
    """Hugging Face ``Pipeline`` that runs eye-tracking input through ``Eye2SacExtractor``.

    All of the work happens in :meth:`preprocess`, which delegates to
    ``Eye2SacExtractor.extract_features``. :meth:`_forward` and
    :meth:`postprocess` return their input unchanged, so this pipeline never
    invokes ``self.model``.
    """

    # NOTE: instantiated once, when the class body executes (at import time),
    # and shared by every instance of this pipeline. Kept as a class attribute
    # so existing code reading ``Eye2SacPipeline.eye2SacExtractor`` still works.
    eye2SacExtractor = Eye2SacExtractor()

    # Call-time keyword arguments that are routed to ``preprocess``.
    _PREPROCESS_PARAM_NAMES = (
        "missing",
        "minlen",
        "maxvel",
        "maxacc",
        "time_header",
        "x_headers",
        "y_headers",
    )

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into per-stage parameter dicts.

        Only the names in ``_PREPROCESS_PARAM_NAMES`` are kept, and they are
        routed to ``preprocess``. Any other keyword is ignored by this method.
        The forward and postprocess stages take no parameters.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple. The last two are always empty dicts.
        """
        preprocess_kwargs = {
            name: kwargs[name]
            for name in self._PREPROCESS_PARAM_NAMES
            if name in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        inputs,
        missing=0.0,
        minlen=5,
        maxvel=40,
        maxacc=30,
        time_header="Time",
        x_headers="X",
        y_headers="Y",
    ):
        """Run ``Eye2SacExtractor.extract_features`` on ``inputs``.

        Every argument is forwarded unchanged to the extractor. The meanings
        below are inferred from the names and should be confirmed against
        ``Eye2SacExtractor``.

        Args:
            inputs: Raw eye-tracking data passed straight through (accepted
                types not visible here — TODO confirm).
            missing: Presumably the value that marks missing samples.
            minlen: Presumably a minimum event length (units unverified).
            maxvel: Presumably a velocity threshold (units unverified).
            maxacc: Presumably an acceleration threshold (units unverified).
            time_header: Presumably the name of the timestamp column.
            x_headers: Presumably the name(s) of the horizontal-position column(s).
            y_headers: Presumably the name(s) of the vertical-position column(s).

        Returns:
            Whatever ``extract_features`` returns. It is later passed unchanged
            through ``_forward`` and ``postprocess``.
        """
        # extract_features is called positionally, so this order (headers
        # first, then thresholds) must match its signature exactly.
        return self.eye2SacExtractor.extract_features(
            inputs, time_header, x_headers, y_headers, missing, minlen, maxvel, maxacc
        )

    def _forward(self, model_inputs):
        """Return ``model_inputs`` unchanged; no model is run in this pipeline."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Return ``model_outputs`` unchanged as the pipeline's final result."""
        return model_outputs