svijayanand committed
Commit 6da76e5 · verified · 1 Parent(s): aa28a4f

Upload 4 files

Files changed (4)
  1. __init__.py +0 -0
  2. ingest_data.py +104 -0
  3. requirements.txt +116 -0
  4. week2.py +105 -0
__init__.py ADDED
File without changes
ingest_data.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from dotenv import load_dotenv
+
+ from langchain_community.vectorstores import FAISS
+ from datasets import load_dataset
+ from langchain.document_loaders.csv_loader import (
+     CSVLoader,
+ )  # import to load our imdb.csv file
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_openai import OpenAIEmbeddings
+ from langchain.embeddings import CacheBackedEmbeddings
+ from langchain.storage import LocalFileStore
+ from langchain_community.vectorstores import FAISS
+
+ load_dotenv()
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+ underlying_embeddings = OpenAIEmbeddings(api_key=openai_api_key)
+
+ def download_data_and_create_embedding():
+     # Download an IMDB dataset from Hugging Face Hub: load the ShubhamChoksi/IMDB_Movies dataset
+     dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
+     print(dataset)
+
+     # store imdb.csv from ShubhamChoksi/IMDB_Movies
+     dataset_dict = dataset
+     dataset_dict["train"].to_csv("imdb.csv")
+
+     # load the exported CSV file into documents
+     loader = CSVLoader("imdb.csv")
+     data = loader.load()
+     print(len(data))  # ensure we have actually loaded data into a format LangChain can recognize
+
+     """# Chunk the loaded data to improve retrieval performance
+     In a RAG system, the model needs to be able to quickly and accurately retrieve relevant information
+     from a knowledge base or other data sources to assist in generating high-quality responses.
+     However, working with large, unstructured datasets can be computationally expensive and time-consuming,
+     especially during the retrieval process.
+
+     By splitting the data into these smaller, overlapping chunks, the RAG system can more efficiently search
+     and retrieve the most relevant information to include in the generated response. This can lead to improved performance,
+     as the model doesn't have to process the entire dataset at once, and can focus on the most relevant parts of the data.
+     """
+
+     # create a text splitter with 1000-character chunks and 100-character overlap
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+     chunked_documents = text_splitter.split_documents(
+         data
+     )  # chunk the loaded documents
+     print(len(chunked_documents))  # ensure we have actually split the data into chunks
+
+     """# Use OpenAI embeddings to create a vector store
+     The first step in creating a vector store is to create embeddings from the data that you want the RAG system to be able
+     to retrieve. This is done using an embedding model, which transforms text data into a high-dimensional vector representation.
+     Each piece of text (such as a document, paragraph, or sentence) is converted into a vector that captures its semantic meaning.
+     For this exercise, we will use OpenAI's embedding model.
+     """
+
+     openai_api_key = os.getenv("OPENAI_API_KEY")
+     # create our embedding model
+     embedding_model = OpenAIEmbeddings(
+         model="text-embedding-3-large", api_key=openai_api_key
+     )
+
+     """# Create embedder
+     We will create our embedder using the `CacheBackedEmbeddings` class. This class is designed to optimize the process of generating embeddings by
+     caching the results of expensive embedding computations. This caching mechanism prevents the need to recompute embeddings for the same text
+     multiple times, which can be computationally expensive and time-consuming.
+     """
+
+     # create a local file store for our cached embeddings
+     store = LocalFileStore(
+         "./cache/"
+     )
+     embedder = CacheBackedEmbeddings.from_bytes_store(
+         underlying_embeddings, store, namespace=underlying_embeddings.model
+     )
+
+     # Create the vector store using Facebook AI Similarity Search (FAISS)
+     vector_store = FAISS.from_documents(
+         documents=chunked_documents, embedding=embedder
+     )
+     print(vector_store.index.ntotal)
+
+
+     # save our vector store locally
+     vector_store.save_local("faiss_index")
+
+     query_embedding(vector_store=vector_store)
+
+     return vector_store
+
+ def query_embedding(vector_store) -> None:
+     # Ask your RAG system a question!
+     query = "What are some good sci-fi movies from the 1980s?"
+
+     # embed our query
+     embedded_query = underlying_embeddings.embed_query(query)
+     similar_documents = vector_store.similarity_search_by_vector(
+         embedded_query
+     )  # similarity search for documents close to our query
+
+     for page in similar_documents:
+         # print the similar documents that the similarity search returns
+         print(page.page_content)
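
Note: ingest_data.py as committed has no entry point of its own; the module is imported by week2.py below. A minimal sketch of how the ingestion step could be run standalone (the `__main__` guard here is an assumption, not part of the upload):

    if __name__ == "__main__":
        # Build the FAISS index once; download_data_and_create_embedding()
        # persists it to ./faiss_index, so later runs can simply reload it.
        store = download_data_and_create_embedding()
        print(f"Indexed {store.index.ntotal} chunks into ./faiss_index")
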
requirements.txt ADDED
@@ -0,0 +1,116 @@
+ aiofiles==23.2.1
+ aiohttp==3.9.5
+ aiosignal==1.3.1
+ annotated-types==0.6.0
+ anyio==3.7.1
+ async-timeout==4.0.3
+ asyncer==0.0.2
+ attrs==23.2.0
+ backoff==2.2.1
+ bidict==0.23.1
+ black==24.4.2
+ boilerpy3==1.0.7
+ certifi==2024.2.2
+ chainlit==1.1.101
+ charset-normalizer==3.3.2
+ chevron==0.14.0
+ click==8.1.7
+ dataclasses-json==0.5.14
+ datasets==2.19.1
+ Deprecated==1.2.14
+ dill==0.3.8
+ distro==1.9.0
+ exceptiongroup==1.2.1
+ faiss-cpu==1.8.0
+ fastapi==0.110.3
+ fastapi-socketio==0.0.10
+ filelock==3.14.0
+ filetype==1.2.0
+ frozenlist==1.4.1
+ fsspec==2024.3.1
+ googleapis-common-protos==1.63.0
+ greenlet==3.0.3
+ grpcio==1.63.0
+ h11==0.14.0
+ haystack-ai==2.1.2
+ haystack-bm25==1.0.2
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.23.0
+ idna==3.7
+ importlib-metadata==7.0.0
+ Jinja2==3.1.4
+ jsonpatch==1.33
+ jsonpointer==2.4
+ langchain==0.1.20
+ langchain-community==0.0.38
+ langchain-core==0.1.52
+ langchain-openai==0.1.7
+ langchain-text-splitters==0.0.2
+ langsmith==0.1.59
+ Lazify==0.4.0
+ lazy-imports==0.3.1
+ literalai==0.0.601
+ MarkupSafe==2.1.5
+ marshmallow==3.21.2
+ monotonic==1.6
+ more-itertools==10.2.0
+ multidict==6.0.5
+ multiprocess==0.70.16
+ mypy-extensions==1.0.0
+ nest-asyncio==1.6.0
+ networkx==3.3
+ numpy==1.26.4
+ openai==1.30.1
+ opentelemetry-api==1.24.0
+ opentelemetry-exporter-otlp==1.24.0
+ opentelemetry-exporter-otlp-proto-common==1.24.0
+ opentelemetry-exporter-otlp-proto-grpc==1.24.0
+ opentelemetry-exporter-otlp-proto-http==1.24.0
+ opentelemetry-instrumentation==0.45b0
+ opentelemetry-proto==1.24.0
+ opentelemetry-sdk==1.24.0
+ opentelemetry-semantic-conventions==0.45b0
+ orjson==3.10.3
+ packaging==23.2
+ pandas==2.2.2
+ pathspec==0.12.1
+ platformdirs==4.2.2
+ posthog==3.5.0
+ protobuf==4.25.3
+ pyarrow==16.1.0
+ pyarrow-hotfix==0.6
+ pydantic==2.7.1
+ pydantic_core==2.18.2
+ PyJWT==2.8.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-engineio==4.9.1
+ python-multipart==0.0.9
+ python-socketio==5.11.2
+ pytz==2024.1
+ PyYAML==6.0.1
+ regex==2024.5.15
+ requests==2.31.0
+ simple-websocket==1.0.0
+ six==1.16.0
+ sniffio==1.3.1
+ SQLAlchemy==2.0.30
+ starlette==0.37.2
+ syncer==2.0.3
+ tenacity==8.3.0
+ tiktoken==0.7.0
+ tomli==2.0.1
+ tqdm==4.66.4
+ typing-inspect==0.9.0
+ typing_extensions==4.11.0
+ tzdata==2024.1
+ uptrace==1.24.0
+ urllib3==2.2.1
+ uvicorn==0.25.0
+ watchfiles==0.20.0
+ wrapt==1.16.0
+ wsproto==1.2.0
+ xxhash==3.4.1
+ yarl==1.9.4
+ zipp==3.18.2
week2.py ADDED
@@ -0,0 +1,105 @@
+ import asyncio
+ from dotenv import load_dotenv
+ from pathlib import Path
+ from ingest_data import download_data_and_create_embedding
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.runnables.passthrough import RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_openai import ChatOpenAI
+ from ingest_data import underlying_embeddings, openai_api_key
+
+ from langchain.chat_models import ChatOpenAI
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema import StrOutputParser
+
+ import chainlit as cl
+
+ # load env variables
+ load_dotenv()
+
+ # Specify the path to the file you want to check
+ file_path = Path('./faiss_index/index.faiss')
+
+ # Check if the file exists
+ if file_path.exists():
+     print("Embeddings already done, use the saved index")
+     # Combine the retrieved data with the output of the LLM
+     vector_store = FAISS.load_local(
+         "faiss_index", underlying_embeddings, allow_dangerous_deserialization=True
+     )
+ else:
+     vector_store = download_data_and_create_embedding()
+
+
+ # create a prompt template to send to our LLM that incorporates the documents from our retriever
+ # along with the question we ask the chat model
+ prompt_template = ChatPromptTemplate.from_template(
+     "Answer the {question} based on the following {context}."
+ )
+
+ # create a retriever for our documents
+ retriever = vector_store.as_retriever()
+
+ # create a chat model / LLM
+ chat_model = ChatOpenAI(
+     model="gpt-4o-2024-05-13", temperature=0, api_key=openai_api_key
+ )
+
+ # create a parser to parse the output of our LLM
+ parser = StrOutputParser()
+
+ # 💻 Create the sequence (recipe)
+ runnable_chain = (
+     # chain the retriever, prompt, chat model, and output parser into one runnable
+     {"context": retriever, "question": RunnablePassthrough()}
+     | prompt_template
+     | chat_model
+     | StrOutputParser()
+ )
+
+
+ # Asynchronous execution (e.g., for a better chatbot user experience)
+ async def call_chain_async(question):
+     output_chunks = await runnable_chain.ainvoke(question)
+     return output_chunks
+
+
+ # output_stream = asyncio.run(call_chain_async("What are some good sci-fi movies from the 1980s?"))
+ # print("".join(output_stream))
+
+ @cl.on_chat_start
+ async def on_chat_start():
+     model = ChatOpenAI(streaming=True)
+     prompt = ChatPromptTemplate.from_messages(
+         [
+             (
+                 "system",
+                 "You're a very knowledgeable historian who provides accurate and eloquent answers to historical questions.",
+             ),
+             ("human", "{question}"),
+         ]
+     )
+     runnable = prompt | model | StrOutputParser()
+     cl.user_session.set("runnable", runnable)
+
+
+ # @cl.on_message
+ # async def on_message(message: cl.Message):
+ #     runnable = cl.user_session.get("runnable")  # type: Runnable
+
+ #     msg = cl.Message(content="")
+
+ #     async for chunk in runnable.astream(
+ #         {"question": message.content},
+ #         config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+ #     ):
+ #         await msg.stream_token(chunk)
+
+ #     await msg.send()
+
+ @cl.on_message
+ async def main(question):
+     response = await call_chain_async(question.content)
+     await cl.Message(content=response).send()
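
The app is served with the Chainlit CLI (`chainlit run week2.py`). The streaming handler is left commented out in the file and would additionally need RunnableConfig imported; a rough sketch of that variant, assuming the historian runnable stored in on_chat_start is the one being streamed:

    from langchain_core.runnables import RunnableConfig

    @cl.on_message
    async def on_message(message: cl.Message):
        # retrieve the chain stored by on_chat_start and stream its tokens back
        runnable = cl.user_session.get("runnable")
        msg = cl.Message(content="")
        async for chunk in runnable.astream(
            {"question": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)
        await msg.send()

Only one @cl.on_message handler can be registered at a time, so this sketch would replace the non-streaming main above rather than sit alongside it.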