import os
import logging

import click
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
from langchain.callbacks.manager import CallbackManager
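
# Inference only: disable gradient tracking globally and stream generated tokens to stdout as they are produced.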
torch.set_grad_enabled(False)
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

from prompt_template_utils import get_prompt_template
from langchain.vectorstores import Chroma
from transformers import (
    GenerationConfig,
    pipeline,
    TextStreamer,
)

from load_models import (
    load_quantized_model_gguf_ggml,
    load_quantized_model_qptq,
    load_full_model,
)

from constants import (
    EMBEDDING_MODEL_NAME,
    PERSIST_DIRECTORY,
    MODEL_ID,
    MODEL_BASENAME,
    MAX_NEW_TOKENS,
    MODELS_PATH,
)


def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
    """
    Select a model for text generation using the HuggingFace library.
    If you are running this for the first time, it will download a model for you.
    Subsequent runs will use the model from the disk.

    Args:
        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
        model_id (str): Identifier of the model to load from HuggingFace's model hub.
        model_basename (str, optional): Basename of the model if using quantized models.
            Defaults to None.

    Returns:
        HuggingFacePipeline: A pipeline object for text generation using the loaded model.

    Raises:
        ValueError: If an unsupported model or device type is provided.
    """
| logging.info(f"Loading Model: {model_id}, on: {device_type}") | |
| logging.info("This action can take a few minutes!") | |
| if model_basename is not None: | |
| if ".gguf" in model_basename.lower(): | |
| llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING) | |
| return llm | |
| elif ".ggml" in model_basename.lower(): | |
| model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING) | |
| else: | |
| model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING) | |
| else: | |
| model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING) | |
    # Load configuration from the model to avoid warnings
    generation_config = GenerationConfig.from_pretrained(model_id)
    # see here for details:
    # https://huggingface.co/docs/transformers/
    # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns

    # Create a pipeline for text generation
    streamer = TextStreamer(tokenizer)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=MAX_NEW_TOKENS,
        temperature=0.15,
        top_p=0.1,
        top_k=40,
        repetition_penalty=1.0,
        generation_config=generation_config,
        streamer=streamer,
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    logging.info("Local LLM Loaded")

    return local_llm
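
# Minimal, hypothetical usage sketch (not part of the pipeline below): load the default model
# from constants.py onto the CPU and generate from a single prompt string.
#
#   llm = load_model("cpu", model_id=MODEL_ID, model_basename=MODEL_BASENAME)
#   print(llm("What does this project do?"))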


def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
    """
    Initializes and returns a retrieval-based Question Answering (QA) pipeline.

    This function sets up a QA system that retrieves relevant information using embeddings
    from the HuggingFace library. It then answers questions based on the retrieved information.

    Parameters:
    - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'cuda', etc.
    - use_history (bool): Flag to determine whether to use chat history or not.
    - promptTemplate_type (str): Prompt template family passed to `get_prompt_template`, e.g., 'llama'. Defaults to "llama".

    Returns:
    - RetrievalQA: An initialized retrieval-based QA system.

    Notes:
    - The function uses embeddings from the HuggingFace library, either instruction-based or regular.
    - The Chroma class is used to load a vector store containing pre-computed embeddings.
    - The retriever fetches relevant documents or data based on a query.
    - The prompt and memory, obtained from the `get_prompt_template` function, might be used in the QA system.
    - The model is loaded onto the specified device using its ID and basename.
    - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
    """
    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
    # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

    # load the vectorstore
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
    )
    retriever = db.as_retriever()
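    # Optional tweak (assumption, not required): the number of chunks handed to the LLM can be
    # tuned via search_kwargs, e.g. retriever = db.as_retriever(search_kwargs={"k": 4}).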
    # get the prompt template and memory if set by the user.
    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)

    # load the llm pipeline
    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
    if use_history:
        qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # try other chain types as well: refine, map_reduce, map_rerank
            retriever=retriever,
            return_source_documents=True,  # verbose=True,
            callbacks=callback_manager,
            chain_type_kwargs={"prompt": prompt, "memory": memory},
        )
    else:
        qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # try other chain types as well: refine, map_reduce, map_rerank
            retriever=retriever,
            return_source_documents=True,  # verbose=True,
            callbacks=callback_manager,
            chain_type_kwargs={
                "prompt": prompt,
            },
        )

    return qa
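
# Hypothetical usage sketch: build the chain once and run a single query programmatically
# (the interactive loop in main() below does the same thing in a REPL style).
#
#   qa = retrieval_qa_pipline("cuda", use_history=False, promptTemplate_type="llama")
#   print(qa("What is in my documents?")["result"])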


# Choose the device type to run on, as well as whether to show source documents.
# The defaults below (auto-detected device, "llama" prompt template) are assumptions; adjust as needed.
@click.command()
@click.option("--device_type", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on, e.g. cpu, cuda, mps.")
@click.option("--show_sources", "-s", is_flag=True, help="Show the source documents used for each answer.")
@click.option("--use_history", "-h", is_flag=True, help="Use chat history in the prompt.")
@click.option("--model_type", default="llama", help="Prompt template type, e.g. llama.")
def main(device_type, show_sources, use_history, model_type):
| """ | |
| Implements the main information retrieval task for a localGPT. | |
| This function sets up the QA system by loading the necessary embeddings, vectorstore, and LLM model. | |
| It then enters an interactive loop where the user can input queries and receive answers. Optionally, | |
| the source documents used to derive the answers can also be displayed. | |
| Parameters: | |
| - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'mps', 'cuda', etc. | |
| - show_sources (bool): Flag to determine whether to display the source documents used for answering. | |
| - use_history (bool): Flag to determine whether to use chat history or not. | |
| Notes: | |
| - Logging information includes the device type, whether source documents are displayed, and the use of history. | |
| - If the models directory does not exist, it creates a new one to store models. | |
| - The user can exit the interactive loop by entering "exit". | |
| - The source documents are displayed if the show_sources flag is set to True. | |
| """ | |
| logging.info(f"Running on: {device_type}") | |
| logging.info(f"Display Source Documents set to: {show_sources}") | |
| logging.info(f"Use history set to: {use_history}") | |
| # check if models directory do not exist, create a new one and store models here. | |
| if not os.path.exists(MODELS_PATH): | |
| os.mkdir(MODELS_PATH) | |
| qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type=model_type) | |
    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        # Get the answer from the chain
        res = qa(query)
        answer, docs = res["result"], res["source_documents"]

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        if show_sources:  # this is a flag that you can set to disable showing the source documents.
            # Print the relevant sources used for the answer
            print("----------------------------------SOURCE DOCUMENTS---------------------------")
            for document in docs:
                print("\n> " + document.metadata["source"] + ":")
                print(document.page_content)
            print("----------------------------------SOURCE DOCUMENTS---------------------------")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    main()
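
# Example invocation (assuming the click options defined above; the script name is illustrative):
#   python run_localGPT.py --device_type cpu --show_sources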