# Source: Hugging Face file "e5_encoder.py" (uploaded by Edgeev, commit 138e5ba)
# e5_encoder.py
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
class E5Encoder:
    """Sentence encoder wrapping an E5 model from Hugging Face Transformers.

    E5 models require two conventions (per the intfloat/e5-* model card):
    input text must carry a "query: " or "passage: " prefix, and sentence
    embeddings are produced by *mean pooling* the last hidden states over
    non-padded tokens — not by taking the [CLS] token.
    """

    def __init__(self, model_name: str = "intfloat/e5-base", device: str = None):
        # Prefer CUDA when available unless the caller pins a device.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        # Inference-only: disable dropout / batch-norm updates.
        self.model.eval()

    def encode(self, text: str, normalize: bool = True) -> np.ndarray:
        """Encode *text* into a 1-D float32 embedding.

        Args:
            text: Input string. May already carry an E5 "query:" or
                "passage:" prefix; otherwise "passage: " is prepended.
            normalize: If True, L2-normalize the embedding (recommended
                for cosine-similarity retrieval).

        Returns:
            np.ndarray of shape (hidden_size,). A zero vector for
            empty/falsy input.
        """
        if not text:
            return np.zeros((self.model.config.hidden_size,), dtype=np.float32)
        # E5 expects a task prefix; don't double-prefix already-tagged text.
        if text.startswith(("query:", "passage:")):
            prefixed_text = text
        else:
            prefixed_text = f"passage: {text}"
        inputs = self.tokenizer(
            prefixed_text, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)
        with torch.no_grad():
            output = self.model(**inputs)
        # Mean-pool over real (non-padding) tokens, as E5 was trained.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (output.last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard: avoid div-by-zero
        emb = (summed / counts).squeeze(0)
        if normalize:
            # Clamp the norm so a (degenerate) zero embedding stays finite.
            emb = emb / torch.norm(emb, p=2).clamp(min=1e-12)
        return emb.cpu().numpy()