xke committed on
Commit
849e183
·
1 Parent(s): 3c1ee37

initial version

Browse files
Files changed (3) hide show
  1. Dockerfile +36 -0
  2. app.py +109 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # reference: https://huggingface.co/spaces/key2xanadu/hello-world-docker
3
+
4
+ # working on localhost --
5
+ #
6
+ # FROM python:3.11
7
+ # RUN useradd -m -u 1000 user
8
+ # USER user
9
+ # ENV HOME=/home/user \
10
+ # PATH=/home/user/.local/bin:$PATH
11
+ # WORKDIR $HOME/app
12
+ # COPY --chown=user . $HOME/app
13
+ # COPY ./requirements.txt ~/app/requirements.txt
14
+ # RUN pip install -r requirements.txt
15
+ # COPY . .
16
+ # CMD ["chainlit", "run", "app.py", "--port", "7860"]
17
+
18
FROM python:3.11

# Run as a non-root user with uid 1000.
RUN useradd -m -u 1000 user
USER user

# `useradd -m` creates /home/user, so HOME must point there: pip's per-user
# install directory ($HOME/.local) has to live somewhere this user can write.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

# Copy only the dependency list first so the install layer is reused when
# nothing but application code changes.
COPY --chown=user ./requirements.txt $HOME/app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r $HOME/app/requirements.txt

# Copy the application once, owned by the runtime user: app.py writes
# imdb.csv, ./cache/ and faiss_index into the working directory at startup.
COPY --chown=user . $HOME/app

CMD ["chainlit", "run", "app.py", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit as cl
2
+ from datasets import load_dataset
3
+ from langchain_community.document_loaders import CSVLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_openai import OpenAIEmbeddings
6
+ from langchain.embeddings import CacheBackedEmbeddings
7
+ from langchain.storage import LocalFileStore
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_core.runnables.base import RunnableSequence
10
+ from langchain_core.runnables.passthrough import RunnablePassthrough
11
+ from langchain_core.output_parsers import StrOutputParser
12
+ from langchain_core.prompts import ChatPromptTemplate
13
+ from langchain_openai import ChatOpenAI
14
+ from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
15
+ from langchain.callbacks.base import BaseCallbackHandler
16
+
17
def setup_data():
    """Build the FAISS vector store over the IMDB movies dataset.

    Loads the ``ShubhamChoksi/IMDB_Movies`` dataset, writes its ``train``
    split to ``imdb.csv``, reads that file back through ``CSVLoader``, splits
    the documents into chunks (size 1000, overlap 100), embeds them with
    OpenAI embeddings wrapped in a file-backed cache under ``./cache/``, and
    saves the resulting index to ``faiss_index``.

    Returns:
        The FAISS vector store built from the chunked documents.
    """
    csv_path = "imdb.csv"
    load_dataset("ShubhamChoksi/IMDB_Movies")["train"].to_csv(csv_path)

    rows = CSVLoader(file_path=csv_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(rows)

    # The cache namespace is the embedding model name, so cached vectors from
    # one model are kept apart from those of another.
    base_embeddings = OpenAIEmbeddings()
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        base_embeddings,
        LocalFileStore("./cache/"),
        namespace=base_embeddings.model,
    )

    index = FAISS.from_documents(chunks, cached_embeddings)
    index.save_local("faiss_index")
    return index
40
+
41
+
42
# Both objects are created once at import time and reused by every chat
# session; setup_data() performs the dataset load and index build, so module
# import blocks until that finishes.
doc_search = setup_data()
model = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0,
    streaming=True,
)
44
+
45
+
46
@cl.on_chat_start
async def on_chat_start():
    """Assemble the retrieval QA chain and stash it in the user session.

    The chain retrieves documents from the module-level ``doc_search`` store,
    joins their contents into the prompt's ``context`` slot, passes the
    incoming question through unchanged, and pipes the prompt into the
    module-level ``model`` followed by a string output parser. The result is
    stored under the session key ``"runnable"``.
    """
    prompt = ChatPromptTemplate.from_template(
        """Answer the question based only on the following context:

{context}

Question: {question}
"""
    )

    def format_docs(docs):
        # Concatenate retrieved page contents, separated by blank lines.
        return "\n\n".join(d.page_content for d in docs)

    context_chain = doc_search.as_retriever() | format_docs
    chain_inputs = {"context": context_chain, "question": RunnablePassthrough()}
    runnable = chain_inputs | prompt | model | StrOutputParser()

    cl.user_session.set("runnable", runnable)
69
+
70
+
71
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message by streaming the session's QA chain output.

    Fetches the chain stored under ``"runnable"`` in the user session, streams
    its output token by token into a new Chainlit message, and attaches the
    sources of the retrieved documents as an inline "Sources" text element.

    Args:
        message: The incoming Chainlit message; its ``content`` is the
            question passed to the chain.
    """
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = set()  # Unique, already-formatted source labels

        def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
            for d in documents:
                metadata = d.metadata or {}
                source = metadata.get("source", "unknown")
                # The index is built from CSVLoader documents, whose metadata
                # has "source" and "row" but no "page"; indexing ["page"]
                # directly raised KeyError and no sources were ever shown.
                if "page" in metadata:
                    label = f"{source}#page={metadata['page']}"
                elif "row" in metadata:
                    label = f"{source}#row={metadata['row']}"
                else:
                    label = str(source)
                self.sources.add(label)

        def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
            if self.sources:
                # Sort so the listing is stable across runs (sets are unordered).
                sources_text = "\n".join(sorted(self.sources))
                self.msg.elements.append(
                    cl.Text(name="Sources", content=sources_text, display="inline")
                )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg),
            ]),
        ):
            await msg.stream_token(chunk)

    await msg.send()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ langchain
3
+ langchain-community
4
+ langchain_openai
5
+ faiss-cpu
6
+ tiktoken
7
+ chainlit
8
+