from transformers import Pipeline

from .eyetrack2saccade import Eye2SacExtractor


class Eye2SacPipeline(Pipeline):
    """Hugging Face ``Pipeline`` that turns eye-tracking samples into saccade features.

    All of the work happens in :meth:`preprocess`, which delegates to
    ``Eye2SacExtractor.extract_features``. ``_forward`` and ``postprocess``
    return their input unchanged, so this class itself performs no model
    inference.
    """

    # One extractor, created when the class body executes (i.e. at import
    # time) and shared by every instance of this pipeline.
    eye2SacExtractor = Eye2SacExtractor()

    # Keyword arguments that are routed to ``preprocess``. Keep this in sync
    # with the ``preprocess`` signature below; order is preserved in the
    # resulting kwargs dict.
    _PREPROCESS_PARAMS = (
        "missing",
        "minlen",
        "maxvel",
        "maxacc",
        "time_header",
        "x_headers",
        "y_headers",
    )

    def _sanitize_parameters(self, **kwargs):
        """Split user keyword arguments into per-stage parameter dicts.

        Only the keys listed in ``_PREPROCESS_PARAMS`` are kept; any other
        key is ignored.

        Returns:
            A 3-tuple ``(preprocess_kwargs, {}, {})``. Only the first dict is
            ever populated; the other two stages take no parameters.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._PREPROCESS_PARAMS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        inputs,
        missing=0.0,
        minlen=5,
        maxvel=40,
        maxacc=30,
        time_header="Time",
        x_headers="X",
        y_headers="Y",
    ):
        """Extract saccade features from ``inputs`` via the shared extractor.

        Args:
            inputs: Raw eye-tracking data given to the pipeline. Its expected
                type is whatever ``Eye2SacExtractor.extract_features``
                accepts (presumably a table with time/X/Y columns -- TODO
                confirm in ``eyetrack2saccade``).
            missing: Forwarded to the extractor; presumably the value that
                marks a missing sample -- confirm.
            minlen: Forwarded to the extractor; presumably a minimum saccade
                length -- units not visible here.
            maxvel: Forwarded to the extractor; presumably a velocity
                threshold -- units not visible here.
            maxacc: Forwarded to the extractor; presumably an acceleration
                threshold -- units not visible here.
            time_header: Forwarded to the extractor; presumably the name of
                the timestamp column.
            x_headers: Forwarded to the extractor; presumably the X
                coordinate column name(s).
            y_headers: Forwarded to the extractor; presumably the Y
                coordinate column name(s).

        Returns:
            Exactly what ``extract_features`` returns.
        """
        # NOTE: extract_features is called positionally with the header
        # arguments first and the thresholds second, which is a different
        # order from this method's own parameter list.
        return self.eye2SacExtractor.extract_features(
            inputs,
            time_header,
            x_headers,
            y_headers,
            missing,
            minlen,
            maxvel,
            maxacc,
        )

    def _forward(self, model_inputs):
        """Return the preprocessed features unchanged (no model is invoked)."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Return the features unchanged; no post-processing is applied."""
        return model_outputs