KennyAI01 committed on
Commit
79e9767
·
verified ·
1 Parent(s): e6f19f8

Initialise the app

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit app
3
+ """
4
+ try:
5
+ __import__('pysqlite3')
6
+ import sys
7
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
8
+ except:
9
+ pass
10
+
11
+ import os
12
+ import streamlit as st
13
+ from ragvizexpander import RAGVizChain
14
+ from ragvizexpander.llms import *
15
+ from ragvizexpander.embeddings import *
16
+ from ragvizexpander.splitters import RecursiveChar2TokenSplitter
17
+
18
# Page chrome — must be the first Streamlit command executed on a run.
st.set_page_config(
    page_title="RAGVizExpander Demo",
    page_icon="🔬",
    layout="wide"
)

# Mirror deployment secrets into the environment for libraries that read them there.
for _name in ("OPENAI_API_KEY", "HF_API_KEY"):
    os.environ[_name] = st.secrets[_name]

# Seed per-session defaults exactly once; later reruns keep the stored values.
for _key, _default in (("chart", None), ("loaded", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("RAGVizExpander Demo🔬")
st.markdown("📦 More details can be found at the GitHub repo [here](https://github.com/KKenny0/RAGVizExpander)")
36
+
37
if not st.session_state['loaded']:
    # --- Configuration screen: shown until a vector DB has been built. ---
    # Placeholders so the whole form can be cleared once building starts.
    main_page = st.empty()
    main_button = st.empty()
    with main_page.container():
        uploaded_file = st.file_uploader("Upload your file",
                                         label_visibility="collapsed",
                                         type=['pdf', 'docx', 'txt', 'pptx'])

        # --- setting llm model
        st.markdown("### Settings for *LLM* model")

        st.session_state["llm_model_type"] = st.radio("Select type of llm model",
                                                      ["OpenAI", "Ollama"],
                                                      horizontal=True)

        if st.session_state["llm_model_type"] == "OpenAI":
            st.session_state["openai_llm_base_url"] = st.text_input("Enter OpenAI LLM API Base")
            st.session_state["openai_llm_api_key"] = st.text_input("Enter OpenAI LLM API Key")
            st.session_state["openai_llm_model"] = st.text_input("Enter OpenAI LLM model name")
            st.session_state["chosen_llm_model"] = ChatOpenAI(
                base_url=st.session_state["openai_llm_base_url"],
                api_key=st.session_state["openai_llm_api_key"],
                model_name=st.session_state["openai_llm_model"],
            )
        else:
            st.session_state["ollama_llm_model"] = st.text_input("Enter Ollama model name")
            st.session_state["chosen_llm_model"] = ChatOllama(model_name=st.session_state["ollama_llm_model"])

        st.markdown("""---""")

        # --- setting embedding model
        st.markdown("### Settings for *EMBEDDING* model")

        st.session_state["embedding_model_type"] = st.radio("Select type of embedding model",
                                                            ["OpenAI", "SentenceTransformer", "HuggingFace", "TEI"],
                                                            horizontal=True)

        if st.session_state["embedding_model_type"] == "OpenAI":
            st.session_state["openai_embed_model"] = st.selectbox("Select embedding model",
                                                                  ["text-embedding-3-small",
                                                                   "text-embedding-3-large",
                                                                   "text-embedding-ada-002"])
            st.session_state["openai_embed_api_key"] = st.text_input("Enter OpenAI Embedding API Key")
            st.session_state["openai_embed_api_base"] = st.text_input("Enter OpenAI Embedding API Base")
            st.session_state["chosen_embedding_model"] = OpenAIEmbeddings(
                api_base=st.session_state["openai_embed_api_base"],
                api_key=st.session_state["openai_embed_api_key"],
                model_name=st.session_state["openai_embed_model"],
            )

        elif st.session_state["embedding_model_type"] == "HuggingFace":
            st.session_state["hf_embed_model"] = st.text_input("Enter HF repository name")
            st.session_state["hf_api_key"] = st.text_input("Enter HF API key")
            st.session_state["chosen_embedding_model"] = HuggingFaceEmbeddings(
                model_name=st.session_state["hf_embed_model"],
                api_key=st.session_state["hf_api_key"]
            )

        else:
            # NOTE(review): BOTH "SentenceTransformer" and "TEI" fall into this
            # branch, so selecting "SentenceTransformer" silently configures TEI
            # embeddings. A dedicated SentenceTransformer branch is needed —
            # confirm the matching class name in ragvizexpander.embeddings.
            st.session_state["tei_api_url"] = st.text_input("Enter TEI(Text-Embedding-Inference) api url")
            st.session_state["chosen_embedding_model"] = TEIEmbeddings(
                api_url=st.session_state["tei_api_url"]
            )

        st.markdown("""---""")

        # --- setting chunking parameters
        st.markdown("### Settings for *CHUNKING* model")
        st.session_state["chunk_size"] = st.number_input("Chunk size", value=500, min_value=100, max_value=1000, step=100)
        st.session_state["chunk_overlap"] = st.number_input("Chunk overlap", value=0, min_value=0, max_value=100, step=10)
        st.session_state["split_func"] = RecursiveChar2TokenSplitter(
            chunk_size=st.session_state["chunk_size"],
            chunk_overlap=st.session_state["chunk_overlap"],
        )

    if st.button("Build Vector DB"):
        if uploaded_file is None:
            # Guard: without this, load_data(None) fails downstream with an
            # obscure error when the user never uploaded a file.
            st.error("Please upload a file before building the vector DB.")
        else:
            st.session_state["client"] = RAGVizChain(embedding_model=st.session_state["chosen_embedding_model"],
                                                     llm=st.session_state["chosen_llm_model"],
                                                     split_func=st.session_state["split_func"])
            # Clear the configuration form before the (slow) indexing step.
            main_page.empty()
            main_button.empty()
            with st.spinner("Building Vector DB"):
                st.session_state["client"].load_data(uploaded_file)
            st.session_state['loaded'] = True
            st.rerun()
else:
    # --- Query screen: vector DB is ready; left column = controls, right = chart. ---
    col1, col2 = st.columns(2)
    st.session_state['query'] = col1.text_area("Enter your query here")
    # Label typo fixed: "retrival" -> "retrieval".
    st.session_state['technique'] = col1.radio("Select retrieval technique", ["naive", "HyAE", "multi_qns"], horizontal=True)
    st.session_state['top_k'] = col1.number_input("Top k", value=5, min_value=1, max_value=10, step=1)
    if col1.button("Execute Query"):
        st.session_state['chart'] = st.session_state["client"].visualize_query(
            st.session_state['query'],
            retrieval_method=st.session_state['technique'],
            top_k=st.session_state['top_k'])
    # Chart lives in session_state so it survives reruns triggered by other widgets.
    if st.session_state['chart'] is not None:
        col2.plotly_chart(st.session_state['chart'])

    if col1.button("Reset Application"):
        st.session_state['loaded'] = False
        st.rerun()