# NOTE(review): the original header lines ("Spaces:", "Runtime error") were
# scrape artifacts from a Hugging Face Space status page, not code — replaced
# with this comment so the module parses.
# Standard library
import os

# Third-party
import torch
from langchain.chat_models.base import BaseChatModel
from langchain_chroma import Chroma
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
def get_llm(provider: str = "groq") -> BaseChatModel:
    """Return a configured chat model for the given provider.

    Args:
        provider: One of ``"groq"``, ``"huggingface"``, ``"openai_local"``,
            or ``"openai"``. Defaults to ``"groq"``.

    Returns:
        A LangChain chat model instance for the selected provider.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "groq":
        # Model catalogue: https://console.groq.com/docs/models
        # Alternatives: qwen-qwq-32b, gemma2-9b-it
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        # TODO: finalize the Hugging Face endpoint configuration.
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                model="Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    if provider == "openai_local":
        # OpenAI-compatible local server (e.g. LM Studio).
        # NOTE(review): LM Studio's default port is 1234 — confirm 11432 is intended.
        return ChatOpenAI(
            base_url="http://localhost:11432/v1",  # default LM Studio endpoint
            api_key="not-used",  # required by interface but ignored #type: ignore
            model="mistral-nemo-instruct-2407",
            temperature=0.2,
        )
    if provider == "openai":
        return ChatOpenAI(
            model="gpt-4o",
            temperature=0.2,
        )
    # Keep the message in sync with the branches above — the original listed
    # only two of the four supported providers.
    raise ValueError(
        "Invalid provider. Choose 'groq', 'huggingface', 'openai_local' or 'openai'.")
# Embedding model backing the vector store.
# Bug fix: torch device strings are "cuda"/"cpu" — "gpu" is not a valid torch
# device and would raise when sentence-transformers tries to move the model.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "token": os.getenv("HF_TOKEN"),  # optional HF auth token; None if unset
    },
    show_progress=True,
)

# Initialize empty (in-memory) Chroma vector store used as the agent's memory.
vector_store = Chroma(
    embedding_function=embeddings,
    collection_name="agent_memory",
)