Ali Moughnieh committed
Commit 5446629 · 1 Parent(s): 62478f7

initial commit

Files changed (5):
  1. .gitignore +6 -0
  2. 1_curate_data.py +29 -0
  3. 2_ingest.py +68 -0
  4. app.py +88 -0
  5. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+ data
+ .git
+ .idea
+ __pycache__
+ venv
+ .env
1_curate_data.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import json
+ from datasets import load_dataset
+
+ full_dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split='train')
+
+ dataset = full_dataset.shuffle(seed=42).select(range(50000))
+
+ script_dir = os.getcwd()
+ data_folder = os.path.join(script_dir, 'data', 'raw_documents')
+
+ if not os.path.exists(data_folder):
+     os.makedirs(data_folder)
+
+ for article in dataset:
+     article_data = {
+         'id': article['id'],
+         'url': article['url'],
+         'title': article['title'],
+         'text': article['text'],
+     }
+     file_path = os.path.join(data_folder, f"{article['id']}.json")
+     if not os.path.exists(file_path):
+         with open(file_path, 'w', encoding='utf-8') as f:
+             print(f.name, 'does not exist. creating file..')
+             json.dump(article_data, f, indent=4)
+
+ if __name__ == '__main__':
+     pass
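
To spot-check the curated output before ingesting it, something like the following minimal sketch can be run from the repository root (not part of the commit; it only assumes the data/raw_documents layout and JSON keys produced by the script above):

import os
import json

# Same layout as 1_curate_data.py: one JSON file per article.
data_folder = os.path.join(os.getcwd(), 'data', 'raw_documents')
files = os.listdir(data_folder)
print(f"{len(files)} article files found")

# Inspect one exported article.
with open(os.path.join(data_folder, files[0]), 'r', encoding='utf-8') as f:
    sample = json.load(f)
print(sample['id'], '|', sample['title'], '|', sample['url'])
print(sample['text'][:200])
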
2_ingest.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ import json
+ from langchain_core.documents import Document
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_chroma import Chroma
+
+ embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ data_folder = os.path.join(script_dir, 'data', 'raw_documents')
+
+ files = os.listdir(data_folder)
+
+ db_path = os.path.join(script_dir, 'data', 'chroma_db')
+
+ if not os.path.exists(db_path):
+     document_to_store = []
+     for file in files:
+         with open(os.path.join(data_folder, file), 'r', encoding='utf-8') as f:
+             json_dict = json.load(f)
+             content = json_dict['text']
+             metadata = {key: value for key, value in json_dict.items() if key != 'text'}
+             document = Document(page_content=content,
+                                 metadata=metadata)
+             document_to_store.append(document)
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     texts = text_splitter.split_documents(document_to_store)
+     min_chunk_size = 50
+     long_texts = [doc for doc in texts if len(doc.page_content) > min_chunk_size]
+     print(f"Original number of chunks: {len(texts)}")
+     print(f"Number of chunks after filtering: {len(long_texts)}")
+
+     # creating vector database using filtered chunks
+     print('Creating the vector database...')
+     db = Chroma.from_documents(long_texts,
+                                embedding_function,
+                                persist_directory=db_path)
+
+     print('Finished creating the vector database.')
+
+ else:
+     print('Vector database already exists. Loading...')
+     db = Chroma(
+         persist_directory=db_path,
+         embedding_function=embedding_function
+     )
+     print('Vector database loaded')
+
+ print("Checking titles in the database...")
+
+ retrieved_items = db.get(
+     limit=1000000,
+     include=['metadatas']
+ )
+
+ unique_titles = set()
+ for metadata in retrieved_items['metadatas']:
+     if 'title' in metadata:
+         unique_titles.add((metadata['title'], metadata['id']))
+
+ print(f"\n--- {len(unique_titles)} Unique Article Titles Found ---")
+ for title in sorted(list(unique_titles)):
+     print(title)
+
+ if __name__ == '__main__':
+     pass
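
After the database has been persisted, a quick similarity search is a useful sanity check on the ingest step. A minimal sketch (not part of the commit; it reuses the db_path and embedding model from the script above, and the query string is only an example):

import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

script_dir = os.path.dirname(os.path.abspath(__file__))
db_path = os.path.join(script_dir, 'data', 'chroma_db')

# Load the persisted collection with the same embedding model used at ingest time.
embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
db = Chroma(persist_directory=db_path, embedding_function=embedding_function)

# Example query; scores are distances, so lower means a closer match.
for doc, score in db.similarity_search_with_score("history of the printing press", k=3):
    print(round(score, 3), doc.metadata.get('title'), doc.page_content[:100])
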
app.py ADDED
@@ -0,0 +1,88 @@
+ import streamlit as st
+
+
+ from langchain_chroma import Chroma
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain.chains import create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+
+ st.title("AI-Powered Wikipedia Explorer")
+
+ @st.cache_resource
+ def load_chain():
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     db_path = os.path.join(script_dir, 'data')
+
+     persist_directory = os.path.join(db_path, 'chroma_db')
+
+     embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+     db = Chroma(
+         persist_directory=persist_directory,
+         embedding_function=embedding_function
+     )
+     print(db._collection.metadata)
+
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-2.5-flash-lite",
+         google_api_key=os.getenv("GOOGLE_API_KEY")
+     )
+
+     template = '''
+     Answer the question based only on the following knowledge base:
+     {context}
+
+     Question: {input}
+
+     Please remember, if the knowledge base does not include relevant information
+     pertaining to the question, do not provide information from your own
+     memory, only provide information from the given knowledge base.
+     '''
+     prompt = ChatPromptTemplate.from_template(template)
+
+     retriever = db.as_retriever(
+         search_type="similarity_score_threshold",
+         search_kwargs={'score_threshold': 0.3,
+                        'k': 6}
+     )
+
+     document_chain = create_stuff_documents_chain(llm, prompt)
+
+     retrieval_chain = create_retrieval_chain(retriever, document_chain)
+
+     return retrieval_chain
+
+ chain = load_chain()
+
+ user_question = st.text_input("Ask a question about the articles:")
+
+ if st.button("Get Answer"):
+     if user_question:
+         with st.spinner("Thinking..."):
+             response = chain.invoke({"input": user_question})
+
+         if not response["context"]:
+             st.header("Answer")
+             st.write("I'm sorry, I couldn't find any relevant information in the documents to answer your question.")
+             with st.expander("Show Sources"):
+                 st.write("Number of documents: 0")
+         else:
+             st.header("Answer")
+             st.write(response["answer"])
+
+             with st.expander("Show Sources"):
+                 for doc in response["context"]:
+                     st.write(f"**Source:** {doc.metadata.get('title', 'Unknown Title')}, **ID:** {doc.metadata.get('id', 'Unknown ID')}")
+                     st.write(f"**URL:** {doc.metadata.get('url', 'No URL')}")
+                     st.write(f"**Content:** {doc.page_content}")
+                     st.write("---")
+                 st.write(f"Number of documents: {len(response['context'])}")
+     else:
+         st.warning("Please enter a question first.")
+
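
The retrieval chain can also be exercised outside the Streamlit UI, which makes it easier to debug retrieval quality. A minimal sketch of that idea (not part of the commit; it rebuilds the chain from app.py with a shortened prompt, assumes data/chroma_db was created by 2_ingest.py and GOOGLE_API_KEY is set in .env, and the question is only an example):

import os
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

load_dotenv()  # expects GOOGLE_API_KEY

script_dir = os.path.dirname(os.path.abspath(__file__))
db = Chroma(
    persist_directory=os.path.join(script_dir, 'data', 'chroma_db'),
    embedding_function=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
)
retriever = db.as_retriever(search_type="similarity_score_threshold",
                            search_kwargs={'score_threshold': 0.3, 'k': 6})

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite",
                             google_api_key=os.getenv("GOOGLE_API_KEY"))
prompt = ChatPromptTemplate.from_template(
    "Answer only from the following knowledge base:\n{context}\n\nQuestion: {input}"
)
chain = create_retrieval_chain(retriever, create_stuff_documents_chain(llm, prompt))

# Example question; prints the generated answer and the source article titles.
response = chain.invoke({"input": "Who was Ada Lovelace?"})
print(response["answer"])
print([doc.metadata.get('title') for doc in response["context"]])
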
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit
+ datasets
+ langchain
+ langchain-google-genai
+ langchain-chroma
+ langchain-huggingface
+ langchain-text-splitters
+ sentence-transformers
+ python-dotenv
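
With the dependencies installed (pip install -r requirements.txt), the run order follows the file numbering: python 1_curate_data.py downloads and exports the Wikipedia sample, python 2_ingest.py builds the Chroma database under data/chroma_db, and streamlit run app.py starts the explorer. app.py reads the Gemini API key as GOOGLE_API_KEY from a .env file, which the .gitignore above keeps out of the repository.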