from typing import Optional

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


class E5Encoder:
    """Lightweight wrapper for producing sentence embeddings with an E5 model."""

    def __init__(self, model_name: str = "intfloat/e5-base", device: Optional[str] = None):
        # Pick a GPU when available unless the caller specifies a device.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference only: make eval mode explicit
    def encode(self, text: str, normalize: bool = True) -> np.ndarray:
        # Empty input: return a zero vector of the model's hidden size.
        if not text:
            return np.zeros((self.model.config.hidden_size,), dtype=np.float32)

        # E5 expects a "query: " or "passage: " prefix; default to "passage: "
        # unless the caller already supplied one.
        prefixed_text = text if text.startswith(("query:", "passage:")) else f"passage: {text}"
        inputs = self.tokenizer(prefixed_text, return_tensors="pt", truncation=True, padding=True).to(self.device)

        with torch.no_grad():
            output = self.model(**inputs)

        # E5 is trained with average pooling: mean over token embeddings,
        # with padding positions masked out via the attention mask.
        mask = inputs["attention_mask"].unsqueeze(-1)
        emb = (output.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)

        emb = emb.squeeze(0)
        if normalize:
            emb = emb / torch.norm(emb, p=2)
        return emb.cpu().numpy()
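

# Minimal usage sketch (hypothetical example strings; assumes the weights can be
# fetched from the Hugging Face Hub). With normalize=True the vectors are unit
# length, so a plain dot product gives cosine similarity.
if __name__ == "__main__":
    encoder = E5Encoder()
    query_vec = encoder.encode("query: how do neural text encoders work?")
    passage_vec = encoder.encode("E5 maps sentences to dense vectors for retrieval.")
    print(float(np.dot(query_vec, passage_vec)))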