File size: 1,382 Bytes
efca5ee
45856e0
efca5ee
 
 
 
 
 
 
fb2fb4c
 
 
 
 
 
 
 
9b37b0e
 
 
 
 
 
efca5ee
 
b9f7415
9b37b0e
efca5ee
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import Pipeline
from .eyetrack2saccade import Eye2SacExtractor

class Eye2SacPipeline(Pipeline):
    """Hugging Face ``Pipeline`` that turns eye-tracking input into saccade features.

    No model inference happens here: ``preprocess`` delegates all the work to
    ``Eye2SacExtractor.extract_features``, and ``_forward`` / ``postprocess``
    return their input unchanged.
    """

    # One extractor instance, created when the class is defined and shared by
    # every pipeline instance.
    # NOTE(review): if Eye2SacExtractor holds per-call state, this sharing could
    # leak state between calls/instances — confirm it is stateless.
    eye2SacExtractor = Eye2SacExtractor()

    # Call-time keyword arguments that are routed to ``preprocess``. Order is
    # kept so the resulting dict has the same key order as before.
    _PREPROCESS_KEYS = (
        "missing",
        "minlen",
        "maxvel",
        "maxacc",
        "time_header",
        "x_headers",
        "y_headers",
    )

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess) kwargs.

        Only keys listed in ``_PREPROCESS_KEYS`` are kept, and all of them go
        to ``preprocess``; any other keyword is ignored by this method.
        ``_forward`` and ``postprocess`` take no extra parameters, so their
        dicts are always empty.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._PREPROCESS_KEYS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, missing=0.0, minlen=5, maxvel=40, maxacc=30,
                   time_header='Time', x_headers='X', y_headers='Y'):
        """Extract saccade features from ``inputs`` via ``Eye2SacExtractor``.

        Args:
            inputs: Raw eye-tracking data, passed straight through to
                ``extract_features`` (its expected type is defined by the
                extractor, not visible here).
            missing, minlen, maxvel, maxacc: Numeric settings forwarded to the
                extractor (presumably missing-value marker, minimum length and
                velocity/acceleration thresholds — confirm against
                ``Eye2SacExtractor``).
            time_header, x_headers, y_headers: Names the extractor uses to
                locate the time and X/Y data in ``inputs`` (presumably column
                headers — verify against the extractor).

        Returns:
            Whatever ``extract_features`` returns; it becomes the pipeline output.
        """
        # Arguments are passed positionally: header names first, then the
        # numeric settings. This order must match extract_features' signature.
        return self.eye2SacExtractor.extract_features(
            inputs, time_header, x_headers, y_headers,
            missing, minlen, maxvel, maxacc,
        )

    def _forward(self, model_inputs):
        """Pass-through: there is no model, so preprocessed features are returned as-is."""
        return model_inputs

    def postprocess(self, model_outputs):
        """Pass-through: return the extracted features unchanged."""
        return model_outputs