# %%
# git clone https://huggingface.co/nyanko7/LLaMA-7B
# python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu117/torch2.00/index.html
# apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.embeddings import LlamaCppEmbeddings, HuggingFaceInstructEmbeddings, OpenAIEmbeddings
from langchain.llms import LlamaCpp, HuggingFacePipeline, OpenAI
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.indexes import VectorstoreIndexCreator
from chromadb.config import Settings
import chromadb
import torch
import streamlit as st
import cloudpickle
import os
import multiprocessing
import pathlib

# Allow TF32 matmuls for faster float32 GPU math on Ampere and newer cards.
torch.backends.cuda.matmul.allow_tf32 = True

current_path = str(pathlib.Path(__file__).parent.resolve())
print(current_path)
persist_directory = current_path + "/VectorStore"

# %%
def load_cpu_model():
    """Currently broken: the CPU model is not persisted, so this path does not work yet."""
    model_path = "./mymodels/LLaMA-7B/ggml-model-q4_0.bin"
    llm = LlamaCpp(
        model_path=model_path,
        n_ctx=6000,
        n_threads=multiprocessing.cpu_count(),
        temperature=0.6,
        top_p=0.95,
    )
    # Embeddings from the same ggml model; not returned yet, kept for the
    # planned CPU-only retrieval path.
    llama_embeddings = LlamaCppEmbeddings(model_path=model_path)
    return llm
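
# Usage sketch for the CPU path, kept as comments since the function is marked
# broken above. The prompt is hypothetical; the ggml file must already exist at
# model_path, e.g. converted from the LLaMA-7B checkout cloned at the top:
# llm = load_cpu_model()
# print(llm("Q: What is this app for? A:"))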

def load_gpu_model(used_model):
    torch.cuda.empty_cache()
    tokenizer = LlamaTokenizer.from_pretrained(used_model)
    if not torch.cuda.is_available():
        # CPU fallback: pin everything to CPU and let bitsandbytes offload
        # the fp32 parts.
        device_map = {"": "cpu"}
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
        )
    else:
        # No explicit offload needed here: device_map="auto" places layers itself.
        device_map = "auto"
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
        )
    base_model = LlamaForCausalLM.from_pretrained(
        used_model,
        device_map=device_map,
        offload_folder=current_path + "/models_gpt/",
        low_cpu_mem_usage=True,
        quantization_config=quantization_config,
        cache_dir=current_path + "/mymodels/",
    )
    pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=8000,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.2,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
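
# Usage sketch: the repo id below is an illustrative assumption, not a value
# this app necessarily passes in. Wrapping the call in @st.cache_resource (as
# hinted for load_openai_model below) would avoid reloading the weights on
# every Streamlit rerun:
# llm = load_gpu_model("decapoda-research/llama-7b-hf")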

# @st.cache_resource
def load_openai_model(temperature=0.9):
    return OpenAI(temperature=temperature)


def load_openai_embedding():
    return OpenAIEmbeddings()

def load_embedding(model_name):
    embeddings = HuggingFaceInstructEmbeddings(
        query_instruction="Represent the query for retrieval: ",
        model_name=model_name,
        cache_folder=current_path + "/mymodels/",
    )
    return embeddings

def load_vectorstore(model_name, collection, metadata):
    embeddings = load_embedding(model_name)
    client_settings = Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=persist_directory,
        anonymized_telemetry=False,
    )
    vectorstore = Chroma(
        collection_name=collection,
        embedding_function=embeddings,
        client_settings=client_settings,
        persist_directory=persist_directory,
        collection_metadata=metadata,
    )
    return vectorstore
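
# Ingestion sketch: load_vectorstore only opens a persisted collection, so
# documents must be embedded and written beforehand. The file path, chunk
# sizes, embedding model, and collection metadata below are illustrative
# assumptions, not values taken from this app:
# from langchain.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# docs = TextLoader("docs/manual.txt").load()
# chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
# store = load_vectorstore("hkunlp/instructor-large", "demo", metadata={"hnsw:space": "cosine"})
# store.add_documents(chunks)
# store.persist()  # writes the duckdb+parquet files under ./VectorStore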

def create_chain(_llm, collection, model_name, metadata):
    vectorstore = load_vectorstore(model_name, collection, metadata=metadata)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    chain = RetrievalQA.from_chain_type(
        llm=_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return chain
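
# End-to-end sketch of how these helpers compose; the collection and embedding
# model names are illustrative assumptions:
# llm = load_openai_model(temperature=0.0)
# chain = create_chain(llm, collection="demo",
#                      model_name="hkunlp/instructor-large",
#                      metadata={"hnsw:space": "cosine"})
# result = chain({"query": "What do the indexed documents say about setup?"})
# st.write(result["result"])            # the generated answer
# st.write(result["source_documents"])  # retrieved chunks, since return_source_documents=True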

# %%