"""medivocate/src/utilities/llm_models.py

Author: alexneakameni (HuggingFace)

Medivocate: An AI-powered platform exploring African history, culture,
and traditional medicine, fostering understanding and appreciation of
the continent's rich heritage.
"""
import os
from enum import Enum
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama, OllamaEmbeddings
from .embedding import CustomEmbedding
# Closed set of supported chat-LLM backends, built with the functional
# Enum API; member values mirror the langchain class names.
LLMModel = Enum("LLMModel", {"OLLAMA": "ChatOllama", "GROQ": "ChatGroq"})
def get_llm_model_chat(temperature=0.01, max_tokens: int = None):
    """Build the chat LLM selected via environment variables.

    When ``USE_OLLAMA_CHAT`` is set to ``"1"``, a local ``ChatOllama``
    instance is returned (model name from ``OLLAMA_MODEL``); otherwise a
    hosted ``ChatGroq`` instance is returned (model name from
    ``GROQ_MODEL_NAME``).

    :param temperature: sampling temperature forwarded to the model.
    :param max_tokens: generation cap; forwarded as ``num_predict`` for
        Ollama and ``max_tokens`` for Groq.
    """
    use_ollama = str(os.getenv("USE_OLLAMA_CHAT")) == "1"
    if use_ollama:
        # Local Ollama backend — Ollama calls the token cap "num_predict".
        return ChatOllama(
            model=os.getenv("OLLAMA_MODEL"),
            temperature=temperature,
            num_predict=max_tokens,
        )
    # Default: hosted Groq backend.
    return ChatGroq(
        model=os.getenv("GROQ_MODEL_NAME"),
        temperature=temperature,
        max_tokens=max_tokens,
    )
def get_llm_model_embedding():
    """Build the embedding model selected via environment variables.

    When ``USE_HF_EMBEDDING`` is set to ``"1"``, a ``CustomEmbedding``
    (HuggingFace-backed) instance is returned; otherwise an
    ``OllamaEmbeddings`` client is built, optionally pointed at a remote
    host (``OLLAMA_HOST``) with a bearer token (``OLLAMA_TOKEN``).
    """
    if str(os.getenv("USE_HF_EMBEDDING")) == "1":
        return CustomEmbedding()
    # Hoist the env lookup: the original read OLLAMA_HOST three times and
    # wrapped it in a redundant `x if x is not None else None` conditional.
    host = os.getenv("OLLAMA_HOST")
    client_kwargs = None
    if host is not None:
        # Only attach auth headers when targeting an explicit (remote) host.
        client_kwargs = {
            "headers": {
                "Authorization": "Bearer " + (os.getenv("OLLAMA_TOKEN") or "")
            }
        }
    # NOTE(review): env var name "OLLAM_EMB" looks like a typo for
    # "OLLAMA_EMB" — kept as-is since existing deployments may set it.
    return OllamaEmbeddings(
        model=os.getenv("OLLAM_EMB"),
        base_url=host,  # None lets the client fall back to its default host
        client_kwargs=client_kwargs,
    )