Spaces:
Runtime error
Runtime error
# Using local sentence transformers with GPU | |
# from langchain_community.embeddings import HuggingFaceEmbeddings | |
from typing import List, Optional

import torch
from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer
class LocalHuggingFaceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` backed by a locally loaded SentenceTransformer.

    The model is loaded once at construction time. Unless a device is given
    explicitly, it is placed on CUDA if available, then Apple MPS, and
    otherwise CPU.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
    ):
        """Load the SentenceTransformer model.

        Args:
            model_name: Model id or path passed to ``SentenceTransformer``.
            device: Torch device string (e.g. ``"cuda"``, ``"cuda:1"``,
                ``"mps"``, ``"cpu"``). When ``None`` (the default), the
                best available device is detected automatically.
        """
        if device is None:
            self.device = self._detect_device()
        else:
            self.device = device
            print(f"Using {device} for embeddings")
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name

    @staticmethod
    def _detect_device() -> str:
        """Return the preferred torch device name: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            print("Using CUDA for embeddings")
            return "cuda"
        # torch.backends.mps may be absent on older torch builds, hence the
        # hasattr guard before querying it.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            print("Using MPS for embeddings")
            return "mps"
        print("Using CPU for embeddings")
        return "cpu"

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The query to embed.

        Returns:
            The embedding as a plain list of Python floats.
        """
        # encode() on a single string yields one numpy vector; tolist()
        # converts it to native floats for LangChain consumers.
        embedding = self.model.encode(text, show_progress_bar=False)
        return embedding.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents.

        Args:
            texts: The documents to embed.

        Returns:
            One embedding (a list of floats) per input text, in input order.
            An empty input yields an empty list.
        """
        # Short-circuit empty input so the result does not depend on how a
        # given sentence-transformers version handles an empty batch.
        if not texts:
            return []
        embeddings = self.model.encode(texts, show_progress_bar=True)
        return embeddings.tolist()
if __name__ == "__main__":
    # Smoke test: load the model, report accelerator availability, and embed
    # one query plus a small batch of documents.
    embeddings = LocalHuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2")

    print("Is CUDA Available? ", torch.cuda.is_available())
    # torch.backends.mps may be missing on older torch builds; guard before
    # calling it so the smoke test does not crash with AttributeError.
    mps_available = (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    print("Is MPS Available? ", mps_available)

    # Single query.
    sentence = "Hello, how are you?"
    embed = embeddings.embed_query(sentence)
    print(f"Embedding length: {len(embed)}")
    print(f"First few values: {embed[:5]}")

    # Batch of documents.
    sentences = [
        "Hello, how are you?",
        "I am fine, thank you.",
        "What is the weather like today?",
    ]
    embeds = embeddings.embed_documents(sentences)
    print(f"Number of embeddings: {len(embeds)}")
    print(f"Embedding dimensions: {len(embeds[0])}")