File size: 662 Bytes
9127367
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import os
from sentence_transformers import CrossEncoder
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Location of the GLiClass checkpoint. Read it once so the model and tokenizer
# are guaranteed to be loaded from the same place.
_gliclass_model_path = os.getenv("GLICLASS_MODEL_PATH")
if not _gliclass_model_path:
    # Fail fast with an actionable message. Passing None (or "") into
    # from_pretrained() would otherwise surface as an obscure loader error.
    raise RuntimeError(
        "Environment variable GLICLASS_MODEL_PATH must be set to the "
        "GLiClass model path or hub id"
    )

# GLiClass model switched to inference mode (.eval()) and moved to `device`.
model = GLiClassModel.from_pretrained(_gliclass_model_path).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(_gliclass_model_path)

# Zero-shot classification pipeline configured for multi-label output, sharing
# the same device as the model.
multi_label_pipeline = ZeroShotClassificationPipeline(
    model,
    tokenizer,
    classification_type='multi-label',
    device=device,
)

# Cross-encoder model (presumably used for query/passage reranking elsewhere;
# TODO confirm). The checkpoint can be overridden via CROSS_ENCODER_MODEL and
# defaults to the previously hard-coded MS MARCO MiniLM model.
st = CrossEncoder(
    os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L6-v2")
)