File size: 694 Bytes
c0139e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d84f238
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from typing import Dict, List, Any
import pandas as pd
import webvtt
from datetime import datetime
import torch
import spacy

# Token budget for text handed to the classifier. NOTE(review): not referenced
# in the visible code; presumably used to chunk transcript text -- confirm.
token_limit = 200

# spaCy's small English pipeline is loaded once, at import time, so importing
# this module requires the "en_core_web_sm" package to be installed. Only its
# tokenizer is exposed separately for reuse.
nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.tokenizer

class EndpointHandler:
    """Inference endpoint wrapping a RoBERTa ``ClassificationModel``.

    The model is loaded from ``path`` when the handler is constructed.
    Calling the instance takes the filename of a ``.vtt`` transcript;
    NOTE(review): the call is currently a stub that always returns ``[]``.
    """

    def __init__(self, path="."):
        """Load the RoBERTa classifier found at ``path``.

        Args:
            path: Passed as the second positional argument (the model
                name/path) to ``ClassificationModel``. Defaults to the
                current directory.

        CUDA is enabled exactly when ``torch.cuda.is_available()`` is true.
        """
        print("Loading models...")
        self.model = ClassificationModel(
            "roberta",
            path,
            use_cuda=torch.cuda.is_available(),
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Run the handler on a transcript file.

        Args:
            data_file: Filename of a ``.vtt`` file.

        Returns:
            A list of result dictionaries. NOTE(review): not implemented
            yet -- ``data_file`` is never opened and the list is always empty.
        """
        results: List[Dict[str, Any]] = []
        return results