"""Load LLMs from huggingface, Groq, etc."""

from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    pipeline,
)

from langchain.llms import HuggingFacePipeline, HuggingFaceTextGenInference
from langchain_groq import ChatGroq


def get_llm_hf_online(inference_api_url=""):
    """Get LLM using huggingface inference."""

    if not inference_api_url:
        # Default to the hosted zephyr-7b-beta endpoint on the Hugging Face Inference API.
        inference_api_url = (
            "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
        )

    llm = HuggingFaceTextGenInference(
        verbose=True,
        max_new_tokens=1024,
        top_p=0.95,
        temperature=0.1,
        inference_server_url=inference_api_url,
        timeout=10,
    )

    return llm
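
# Example usage (a minimal sketch; assumes the zephyr-7b-beta endpoint above is reachable
# and that any required Hugging Face API token is configured for the client):
#
#     llm = get_llm_hf_online()
#     print(llm.invoke("Explain retrieval-augmented generation in one sentence."))
#     # On older LangChain releases, call llm("...") instead of llm.invoke("...").
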
def get_llm_hf_local(model_path):
    """Get local LLM from huggingface."""

    # Load the checkpoint across available devices, plus its matching tokenizer.
    model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Wrap the model in a text-generation pipeline so LangChain can call it.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=2048,
        model_kwargs={"temperature": 0.1},
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    return llm
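
# Example usage (a minimal sketch; the path below is hypothetical and assumes a locally
# downloaded Llama-family checkpoint plus enough memory to load it):
#
#     llm = get_llm_hf_local("/path/to/llama-2-7b-chat-hf")
#     print(llm.invoke("What does a vector store do?"))
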
def get_groq_chat(model_name="llama-3.1-70b-versatile"):
    """Get LLM from Groq."""

    llm = ChatGroq(temperature=0, model_name=model_name)
    return llm
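
# Example usage (a minimal sketch; ChatGroq reads the API key from the GROQ_API_KEY
# environment variable unless one is passed explicitly):
#
#     chat = get_groq_chat()
#     print(chat.invoke("Hello!").content)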