Georg Willer
Change default header names to start with an uppercase letter & add checking for alternative time_header names
b9f7415
from transformers import Pipeline | |
from .eyetrack2saccade import Eye2SacExtractor | |
class Eye2SacPipeline(Pipeline):
    """``transformers.Pipeline`` that extracts saccade features from eye-tracking input.

    All of the work happens in :meth:`preprocess`, which delegates to a shared
    ``Eye2SacExtractor``. ``_forward`` and ``postprocess`` return their input
    unchanged, so the pipeline's result is exactly what
    ``Eye2SacExtractor.extract_features`` returns.
    """

    # A single extractor, created when the class body executes and shared by
    # every Eye2SacPipeline instance.
    eye2SacExtractor = Eye2SacExtractor()

    # Call-time keyword arguments that are routed to ``preprocess``.
    # Keep this in sync with the keyword parameters of ``preprocess`` below.
    _PREPROCESS_PARAMS = (
        "missing",
        "minlen",
        "maxvel",
        "maxacc",
        "time_header",
        "x_headers",
        "y_headers",
    )

    def _sanitize_parameters(self, **kwargs):
        """Route pipeline keyword arguments to the processing stages.

        Only keys listed in ``_PREPROCESS_PARAMS`` are kept, and all of them go
        to ``preprocess``. Any other keyword argument is dropped.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
            The forward and postprocess dicts are always empty.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._PREPROCESS_PARAMS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        inputs,
        missing=0.0,
        minlen=5,
        maxvel=40,
        maxacc=30,
        time_header="Time",
        x_headers="X",
        y_headers="Y",
    ):
        """Run saccade feature extraction on ``inputs``.

        Every argument is forwarded unchanged to
        ``Eye2SacExtractor.extract_features``. The meanings below are inferred
        from the names and should be confirmed against the extractor.

        Args:
            inputs: Eye-tracking data to process (presumably a table whose
                columns are named by the ``*_header(s)`` arguments — confirm).
            missing: Presumably the value that marks missing samples — confirm.
            minlen: Presumably the minimum saccade length — confirm units.
            maxvel: Presumably a velocity threshold — confirm units.
            maxacc: Presumably an acceleration threshold — confirm units.
            time_header: Name of the time column (default ``"Time"``).
            x_headers: Name(s) of the x-coordinate column(s) (default ``"X"``).
            y_headers: Name(s) of the y-coordinate column(s) (default ``"Y"``).

        Returns:
            Whatever ``extract_features`` returns.
        """
        # Positional order matters: this is the order extract_features is
        # called with (headers first, then the detection parameters).
        return self.eye2SacExtractor.extract_features(
            inputs,
            time_header,
            x_headers,
            y_headers,
            missing,
            minlen,
            maxvel,
            maxacc,
        )

    def _forward(self, model_inputs):
        """Return the preprocessed features unchanged (no model is run)."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Return the features unchanged as the final pipeline output."""
        return model_outputs