# File: ssandy_agents/local_gpu_embeddings.py
# Author: Sheshank Joshi
# Last commit: 554ef85 ("important changes"); file size: 2.37 kB
# Using local sentence transformers with GPU
# from langchain_community.embeddings import HuggingFaceEmbeddings
from typing import List, Optional

import torch
from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer
class LocalHuggingFaceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by a local SentenceTransformer.

    The model is loaded once in the constructor onto the selected device and
    reused by every ``embed_query`` / ``embed_documents`` call.

    Attributes set by the constructor:
        device: Device string the model was loaded on.
        model: The loaded ``SentenceTransformer`` instance.
        model_name: The model identifier that was requested.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        *,
        device: Optional[str] = None,
    ):
        """Load ``model_name`` onto ``device``.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
            device: Optional device string forwarded to
                ``SentenceTransformer`` (e.g. ``"cuda"``, ``"mps"``,
                ``"cpu"``). When omitted, the device is auto-detected with
                the preference order CUDA, then MPS, then CPU.
        """
        # An explicit device wins; otherwise probe the hardware.
        self.device = device if device is not None else self._detect_device()
        # For auto-detected devices this prints exactly
        # "Using CUDA/MPS/CPU for embeddings".
        print(f"Using {self.device.upper()} for embeddings")

        # Load the model
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name

    @staticmethod
    def _detect_device() -> str:
        """Return the preferred available device: "cuda", "mps" or "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps does not exist on every torch build, so look it
        # up defensively instead of accessing the attribute directly.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The query to embed.

        Returns:
            The embedding as a plain list of floats.
        """
        vector = self.model.encode(text, show_progress_bar=False)
        # encode() hands back an array-like; LangChain expects plain lists.
        return vector.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents.

        Args:
            texts: The documents to embed.

        Returns:
            One embedding (list of floats) per input document, in the same
            order. An empty input yields an empty list.
        """
        # Nothing to embed: skip the model call (and its progress bar).
        if not texts:
            return []
        vectors = self.model.encode(texts, show_progress_bar=True)
        # Convert the stacked embeddings to nested plain lists.
        return vectors.tolist()
if __name__ == "__main__":
    # Manual smoke test: load the default model, report which accelerators
    # torch can see, then embed one sentence and a small batch.

    # Create the local embeddings model
    embeddings = LocalHuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2")

    print("Is CUDA Available? ", torch.cuda.is_available())
    # torch.backends.mps is missing on some torch builds; guard the lookup
    # so the script reports False instead of raising AttributeError.
    mps_available = (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    print("Is MPS Available? ", mps_available)

    # Test embeddings
    sentence = "Hello, how are you?"
    embed = embeddings.embed_query(sentence)
    print(f"Embedding length: {len(embed)}")
    print(f"First few values: {embed[:5]}")

    # Test with multiple sentences
    sentences = ["Hello, how are you?", "I am fine, thank you.",
                 "What is the weather like today?"]
    embeds = embeddings.embed_documents(sentences)
    print(f"Number of embeddings: {len(embeds)}")
    print(f"Embedding dimensions: {len(embeds[0])}")