GLiClass ONNX
Collection
GLiClass models converted to ONNX format, including 8-bit quantized variants
•
5 items
•
Updated
•
2
Original model here
Code for converting to ONNX and quantizing: here
Dependencies:
pip install huggingface-hub onnx onnxruntime numpy tokenizers
Inference code:
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
import onnxruntime as ort
import numpy as np
class GLiClassOnnxInference:
    """Zero-shot text classification with a GLiClass model exported to ONNX.

    The ONNX graph and the ``tokenizers`` JSON file are downloaded from the
    Hugging Face Hub; inference runs through ONNX Runtime.
    """

    def __init__(
        self,
        model_id: str,
        use_int8_quant: bool = False,
        *,
        providers: "list[str] | None" = None,
    ):
        """Download the model files and create the ONNX Runtime session.

        Args:
            model_id: Hub repository containing ``model.onnx`` /
                ``model_i8.onnx`` and ``tokenizer.json``.
            use_int8_quant: Load the 8-bit quantized graph (``model_i8.onnx``)
                instead of the full-precision ``model.onnx``.
            providers: Optional ONNX Runtime execution providers, e.g.
                ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.
                ``None`` keeps ONNX Runtime's default provider selection.
        """
        model_file = "model_i8.onnx" if use_int8_quant else "model.onnx"
        self.onnx_runtime_session = ort.InferenceSession(
            hf_hub_download(repo_id=model_id, filename=model_file),
            providers=providers,
        )
        self.tokenizer = Tokenizer.from_file(
            hf_hub_download(repo_id=model_id, filename="tokenizer.json")
        )

    def encode(self, text: str, max_length: int = 512, pad: bool = True):
        """Tokenize *text* into a batch of one, padded/truncated to *max_length*.

        Args:
            text: Input string.
            max_length: Maximum sequence length. Longer encodings are cut to
                this length (which can drop special tokens the tokenizer
                appended at the end).
            pad: Right-pad shorter encodings with the ``[PAD]`` token id up to
                *max_length*; padded positions get attention-mask value 0.

        Returns:
            ``(input_ids, attention_mask)`` as ``int64`` arrays of shape
            ``(1, seq_len)``.

        Raises:
            ValueError: Padding is needed but the tokenizer has no ``[PAD]``
                token.
        """
        encoded = self.tokenizer.encode(text)
        # Copy so that extending/truncating never touches the Encoding object.
        ids = list(encoded.ids)
        mask = list(encoded.attention_mask)

        pad_len = max_length - len(ids)
        if pad and pad_len > 0:
            pad_id = self.tokenizer.token_to_id("[PAD]")
            if pad_id is None:
                # token_to_id returns None for unknown tokens; without this
                # check numpy would fail later with an opaque TypeError.
                raise ValueError(
                    "Tokenizer has no '[PAD]' token; cannot pad the input."
                )
            ids.extend([pad_id] * pad_len)
            mask.extend([0] * pad_len)

        ids = ids[:max_length]
        mask = mask[:max_length]
        return np.array([ids], dtype=np.int64), np.array([mask], dtype=np.int64)

    def onnx_predict(self, text: str, labels: list[str], max_length: int = 512):
        """Score *text* against each candidate label.

        The model input is built as ``<<LABEL>>name`` for every label,
        followed by ``<<SEP>>`` and the text.

        Args:
            text: Text to classify.
            labels: Candidate label names.
            max_length: Sequence length passed to :meth:`encode`.

        Returns:
            ``[{"label": str, "score": float}, ...]`` in the order of *labels*.
            Each score is the sigmoid of the matching logit, so scores are not
            normalized across labels. If the model returns a different number
            of logits than labels, ``zip`` pairs them up to the shorter length.
        """
        prompt = "".join(f"<<LABEL>>{label}" for label in labels) + "<<SEP>>" + text
        ids, mask = self.encode(prompt, max_length=max_length)
        ort_inputs = {"input_ids": ids, "attention_mask": mask}
        logits = self.onnx_runtime_session.run(None, ort_inputs)[0]
        # Very negative logits overflow exp() to inf, which still yields the
        # correct limit 0.0 -- only the RuntimeWarning is suppressed here.
        with np.errstate(over="ignore"):
            probs = 1.0 / (1.0 + np.exp(-logits[0]))
        return [
            {"label": label, "score": float(prob)}
            for label, prob in zip(labels, probs)
        ]
# Example usage: load the full-precision (non-quantized) model, score one
# sentence against five candidate labels, and print a "label => score" line
# for each of them.
inference_session = GLiClassOnnxInference(
    "cnmoro/gliclass-large-v3.0-onnx",
    use_int8_quant=False,
)
results = inference_session.onnx_predict(
    labels=["travel", "dreams", "sport", "science", "politics"],
    text="One day I will see the world!",
)
for r in results:
    print(f"{r['label']} => {r['score']:.3f}")
Base model
knowledgator/gliclass-large-v3.0