uditk99 committed · verified
Commit ea40f87 · 1 Parent(s): 3b1034c

Added Financial RAG Chatbot

MSFT_1986-03-13_2025-02-04.csv ADDED
The diff for this file is too large to render. See raw diff
 
convai_assignment.ipynb ADDED
@@ -0,0 +1,535 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# Imports & Setup\n",
+ "# =============================================================================\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import faiss  # For fast vector similarity search\n",
+ "from sentence_transformers import SentenceTransformer  # For generating text embeddings\n",
+ "from rank_bm25 import BM25Okapi  # For BM25 keyword-based retrieval\n",
+ "import spacy  # For tokenization\n",
+ "from sklearn.metrics.pairwise import cosine_similarity  # For computing cosine similarity\n",
+ "from sklearn.preprocessing import normalize  # For normalizing BM25 scores\n",
+ "\n",
+ "# For the Gradio UI\n",
+ "import gradio as gr\n",
+ "\n",
+ "# For response generation using a small language model (we use FLAN-T5-Small)\n",
+ "from transformers import pipeline, set_seed\n",
+ "\n",
+ "# Set a random seed for reproducibility\n",
+ "set_seed(42)\n",
+ "\n",
+ "# Load the SpaCy English model (make sure to download it with: python -m spacy download en_core_web_sm)\n",
+ "nlp = spacy.load(\"en_core_web_sm\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'pandas.core.frame.DataFrame'>\n",
+ "RangeIndex: 9800 entries, 0 to 9799\n",
+ "Data columns (total 7 columns):\n",
+ " #   Column     Non-Null Count  Dtype  \n",
+ "---  ------     --------------  -----  \n",
+ " 0   Date       9800 non-null   object \n",
+ " 1   Open       9800 non-null   float64\n",
+ " 2   High       9800 non-null   float64\n",
+ " 3   Low        9800 non-null   float64\n",
+ " 4   Close      9800 non-null   float64\n",
+ " 5   Adj Close  9800 non-null   float64\n",
+ " 6   Volume     9800 non-null   int64  \n",
+ "dtypes: float64(5), int64(1), object(1)\n",
+ "memory usage: 536.1+ KB\n",
+ "   Year  Open_Min  Open_Max  Close_Min  Close_Max    Avg_Volume  \\\n",
+ "0  1986  0.088542  0.177083   0.090278   0.177083  3.620005e+07   \n",
+ "1  1987  0.165799  0.548611   0.165799   0.548611  9.454613e+07   \n",
+ "2  1988  0.319444  0.484375   0.319444   0.483507  6.906268e+07   \n",
+ "3  1989  0.322049  0.618056   0.322917   0.614583  7.735760e+07   \n",
+ "4  1990  0.591146  1.102431   0.598090   1.100694  7.408945e+07   \n",
+ "\n",
+ "                                            Summary  \n",
+ "0  In 1986, the stock opened between $0.09 and ...  \n",
+ "1  In 1987, the stock opened between $0.17 and ...  \n",
+ "2  In 1988, the stock opened between $0.32 and ...  \n",
+ "3  In 1989, the stock opened between $0.32 and ...  \n",
+ "4  In 1990, the stock opened between $0.59 and ...  \n"
+ ]
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 1. Data Collection & Preprocessing\n",
+ "# =============================================================================\n",
+ "# Load the CSV file containing the financial data.\n",
+ "# (Make sure \"MSFT_1986-03-13_2025-02-04.csv\" is available locally and adjust the path below to match.)\n",
+ "csv_file_path = r\"D:\\ConvAI_Code\\MSFT_1986-03-13_2025-02-04.csv\"\n",
+ "# Load the CSV file into a DataFrame\n",
+ "df = pd.read_csv(csv_file_path)\n",
+ "\n",
+ "# Display basic info about the dataset\n",
+ "df.info()\n",
+ "\n",
+ "# Data Cleaning & Structuring\n",
+ "\n",
+ "# Convert 'Date' column to datetime format\n",
+ "df['Date'] = pd.to_datetime(df['Date'])\n",
+ "\n",
+ "# Sort data by Date\n",
+ "df = df.sort_values(by='Date')\n",
+ "\n",
+ "# Extract Year from Date\n",
+ "df['Year'] = df['Date'].dt.year\n",
+ "\n",
+ "# Aggregate data by Year to generate financial summaries\n",
+ "yearly_summary = df.groupby('Year').agg(\n",
+ "    Open_Min=('Open', 'min'),\n",
+ "    Open_Max=('Open', 'max'),\n",
+ "    Close_Min=('Close', 'min'),\n",
+ "    Close_Max=('Close', 'max'),\n",
+ "    Avg_Volume=('Volume', 'mean')\n",
+ ").reset_index()\n",
+ "\n",
+ "# Create a textual summary for each year.\n",
+ "# Year is cast to int because apply(axis=1) upcasts the row to float (which would print \"1986.0\").\n",
+ "yearly_summary['Summary'] = yearly_summary.apply(\n",
+ "    lambda row: f\"In {int(row['Year'])}, the stock opened between ${row['Open_Min']:.2f} and ${row['Open_Max']:.2f}, \"\n",
+ "                f\"while closing between ${row['Close_Min']:.2f} and ${row['Close_Max']:.2f}. \"\n",
+ "                f\"The average trading volume was {row['Avg_Volume']:,.0f} shares.\",\n",
+ "    axis=1\n",
+ ")\n",
+ "\n",
+ "# Display the cleaned and structured data\n",
+ "print(yearly_summary.head())  # Use print() in a terminal/console\n",
+ "# yearly_summary.head()  # Use this in a Jupyter Notebook for rich display\n",
+ "\n"
+ ]
+ },
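A quick sanity check can confirm the aggregation step. The sketch below is illustrative rather than part of the committed notebook: it assumes the preprocessing cell above has run (so `df` and `yearly_summary` are in scope) and picks 1999 as an arbitrary example year.

```python
# Minimal sanity check (illustrative, not part of the commit): verify that the
# yearly aggregates match the raw daily rows for one arbitrary year.
year = 1999
daily = df[df['Year'] == year]
agg = yearly_summary.loc[yearly_summary['Year'] == year].iloc[0]

assert daily['Open'].min() == agg['Open_Min']
assert daily['Close'].max() == agg['Close_Max']
assert abs(daily['Volume'].mean() - agg['Avg_Volume']) < 1e-6
print(agg['Summary'])  # e.g. "In 1999, the stock opened between $... and $..., ..."
```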
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "40"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 2. Basic RAG Implementation\n",
+ "# =============================================================================\n",
+ "# Convert financial summaries into text chunks and generate vector embeddings.\n",
+ "embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
+ "\n",
+ "# Convert yearly financial summaries into vector embeddings\n",
+ "summary_texts = yearly_summary[\"Summary\"].tolist()  # Extract summaries as text\n",
+ "summary_embeddings = embedding_model.encode(summary_texts, convert_to_numpy=True)  # Generate embeddings\n",
+ "\n",
+ "# The embeddings are a NumPy array of shape (num_years, embedding_size)\n",
+ "summary_embeddings.shape\n",
+ "\n",
+ "# Read the embedding dimension from the model output (384 for all-MiniLM-L6-v2)\n",
+ "embedding_dim = summary_embeddings.shape[1]\n",
+ "\n",
+ "# Create a FAISS index (Flat index for now, can be optimized later)\n",
+ "faiss_index = faiss.IndexFlatL2(embedding_dim)\n",
+ "\n",
+ "# Convert embeddings to float32 (FAISS requires this format)\n",
+ "summary_embeddings = summary_embeddings.astype('float32')\n",
+ "\n",
+ "# Add embeddings to the FAISS index\n",
+ "faiss_index.add(summary_embeddings)\n",
+ "\n",
+ "# Store the year information for retrieval\n",
+ "year_map = {i: yearly_summary[\"Year\"].iloc[i] for i in range(len(yearly_summary))}\n",
+ "\n",
+ "# Verify that embeddings are stored successfully (should equal the number of yearly summaries)\n",
+ "faiss_index.ntotal\n"
+ ]
+ },
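Before layering BM25 on top, the FAISS index can be probed directly. This sketch is illustrative (the query string is made up, and it assumes the cells above have run): it encodes a query, performs the exact L2 search that `IndexFlatL2` provides, and maps row indices back to years through `year_map`.

```python
# Illustrative FAISS probe (not in the original notebook).
query = "years with very low stock prices"  # hypothetical sample query
query_vec = embedding_model.encode([query], convert_to_numpy=True).astype('float32')

distances, indices = faiss_index.search(query_vec, 3)  # exact L2 search
for dist, idx in zip(distances[0], indices[0]):
    # Smaller L2 distance = closer match; idx indexes into yearly_summary
    print(year_map[idx], round(float(dist), 3))
```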
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merged summaries shape: (12, 2)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>Year</th>\n",
+ "      <th>Merged Summary</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>1986</td>\n",
+ "      <td>In 1986, the stock opened between $0.09 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>1990</td>\n",
+ "      <td>In 1989, the stock opened between $0.32 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>1992</td>\n",
+ "      <td>In 1991, the stock opened between $1.03 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>3</th>\n",
+ "      <td>1996</td>\n",
+ "      <td>In 1994, the stock opened between $2.45 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>4</th>\n",
+ "      <td>1999</td>\n",
+ "      <td>In 1997, the stock opened between $10.25 and...</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ "   Year                                    Merged Summary\n",
+ "0  1986  In 1986, the stock opened between $0.09 and ...\n",
+ "1  1990  In 1989, the stock opened between $0.32 and ...\n",
+ "2  1992  In 1991, the stock opened between $1.03 and ...\n",
+ "3  1996  In 1994, the stock opened between $2.45 and ...\n",
+ "4  1999  In 1997, the stock opened between $10.25 and..."
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 3. Advanced RAG Implementation\n",
+ "# =============================================================================\n",
+ "# 3.1: BM25 for Keyword-Based Search\n",
+ "# Tokenize each summary using SpaCy (tokens are converted to lowercase).\n",
+ "tokenized_summaries = [[token.text.lower() for token in nlp(summary)] for summary in summary_texts]\n",
+ "# Build the BM25 index.\n",
+ "bm25 = BM25Okapi(tokenized_summaries)\n",
+ "\n",
+ "# 3.2: Define Retrieval Functions\n",
+ "\n",
+ "def retrieve_similar_summaries(query_text, top_k=3):\n",
+ "    \"\"\"\n",
+ "    Retrieve similar financial summaries using FAISS vector search.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    distances, indices = faiss_index.search(query_embedding, top_k)\n",
+ "    results = []\n",
+ "    for idx in indices[0]:\n",
+ "        results.append((year_map[idx], yearly_summary.iloc[idx][\"Summary\"]))\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def hybrid_retrieve(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Hybrid retrieval combining FAISS (vector search) and BM25 (keyword search).\n",
+ "    Rank-based scores from both retrievers are combined using the weighting factor 'alpha'.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
+ "\n",
+ "    bm25_scores = bm25.get_scores([token.text.lower() for token in nlp(query_text)])\n",
+ "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
+ "\n",
+ "    combined_scores = {}\n",
+ "    for rank, idx in enumerate(faiss_indices[0]):\n",
+ "        combined_scores[idx] = alpha * (top_k - rank)\n",
+ "    for rank, idx in enumerate(bm25_top_indices):\n",
+ "        if idx in combined_scores:\n",
+ "            combined_scores[idx] += (1 - alpha) * (top_k - rank)\n",
+ "        else:\n",
+ "            combined_scores[idx] = (1 - alpha) * (top_k - rank)\n",
+ "\n",
+ "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
+ "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def adaptive_retrieve(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Adaptive retrieval re-ranks results by combining FAISS and BM25 scores.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
+ "\n",
+ "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
+ "    bm25_scores = bm25.get_scores(query_tokens)\n",
+ "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
+ "\n",
+ "    # Linearly decaying scores for the FAISS ranks; normalized BM25 scores.\n",
+ "    faiss_scores = np.linspace(1, 0, num=top_k)\n",
+ "    bm25_norm_scores = normalize([bm25_scores])[0]\n",
+ "\n",
+ "    combined_scores = {}\n",
+ "    for rank, idx in enumerate(faiss_indices[0]):\n",
+ "        combined_scores[idx] = alpha * faiss_scores[rank]\n",
+ "    for idx in bm25_top_indices:\n",
+ "        if idx in combined_scores:\n",
+ "            combined_scores[idx] += (1 - alpha) * bm25_norm_scores[idx]\n",
+ "        else:\n",
+ "            combined_scores[idx] = (1 - alpha) * bm25_norm_scores[idx]\n",
+ "\n",
+ "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
+ "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def merge_similar_chunks(threshold=0.95):\n",
+ "    \"\"\"\n",
+ "    Chunk Merging: Merge similar financial summaries based on cosine similarity.\n",
+ "    This reduces redundancy when multiple chunks are very similar.\n",
+ "    \"\"\"\n",
+ "    merged_summaries = []\n",
+ "    used_indices = set()\n",
+ "    for i in range(len(summary_embeddings)):\n",
+ "        if i in used_indices:\n",
+ "            continue\n",
+ "        similarities = cosine_similarity([summary_embeddings[i]], summary_embeddings)[0]\n",
+ "        similar_indices = np.where(similarities >= threshold)[0]\n",
+ "        merged_text = \" \".join(yearly_summary.iloc[idx][\"Summary\"] for idx in similar_indices)\n",
+ "        merged_summaries.append((yearly_summary.iloc[i][\"Year\"], merged_text))\n",
+ "        used_indices.update(similar_indices)\n",
+ "    return pd.DataFrame(merged_summaries, columns=[\"Year\", \"Merged Summary\"])\n",
+ "\n",
+ "# Optional: Check merged summaries for debugging.\n",
+ "merged_summary_df = merge_similar_chunks(threshold=0.95)\n",
+ "print(\"Merged summaries shape:\", merged_summary_df.shape)\n",
+ "merged_summary_df.head()\n"
+ ]
+ },
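With the three retrievers defined, a side-by-side call makes the `alpha` weight concrete: `alpha=1.0` reduces both hybrid functions to pure FAISS ranking, while `alpha=0.0` leaves only the BM25 contribution. A usage sketch with an illustrative query, assuming the cells above have run:

```python
# Compare the three retrieval strategies on one sample query (illustrative).
query = "How did the trading volume change in the early 1990s?"

print(retrieve_similar_summaries(query, top_k=3))    # vector search only
print(hybrid_retrieve(query, top_k=3, alpha=0.5))    # rank-weighted FAISS + BM25
print(adaptive_retrieve(query, top_k=3, alpha=0.5))  # score-weighted FAISS + BM25
```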
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7864\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 4. UI Development using Gradio (Updated for newer API)\n",
+ "# =============================================================================\n",
+ "# Load FLAN-T5-Small once at module level so the model is not re-created on every query.\n",
+ "generator = pipeline('text2text-generation', model='google/flan-t5-small')\n",
+ "\n",
+ "def generate_response(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Generate an answer for a financial query by:\n",
+ "      - Validating the query with an input-side guardrail.\n",
+ "      - Retrieving context using adaptive retrieval.\n",
+ "      - Generating a refined answer using FLAN-T5-Small.\n",
+ "    Returns:\n",
+ "        answer (str): The generated answer.\n",
+ "        confidence (float): A mock confidence score based on BM25 scores.\n",
+ "    \"\"\"\n",
+ "    # -------------------------------------------------------------------------\n",
+ "    # Guardrail Implementation (Input-Side)\n",
+ "    # -------------------------------------------------------------------------\n",
+ "    financial_keywords = [\"open\", \"close\", \"stock\", \"price\", \"volume\", \"trading\"]\n",
+ "    if not any(keyword in query_text.lower() for keyword in financial_keywords):\n",
+ "        return (\"Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\", 0.0)\n",
+ "\n",
+ "    # Retrieve context using adaptive retrieval.\n",
+ "    context_df = adaptive_retrieve(query_text, top_k=top_k, alpha=alpha)\n",
+ "    context_text = \" \".join(context_df[\"Summary\"].tolist())\n",
+ "\n",
+ "    # Build a prompt with clear instructions for the model.\n",
+ "    prompt = f\"Given the following financial data:\\n{context_text}\\nAnswer this question: {query_text}.\"\n",
+ "\n",
+ "    # Generate with FLAN-T5-Small; max_length=200 allows longer answers.\n",
+ "    generated = generator(prompt, max_length=200, num_return_sequences=1)\n",
+ "    answer = generated[0]['generated_text'].replace(prompt, \"\").strip()\n",
+ "\n",
+ "    # Fallback message if the answer is empty.\n",
+ "    if not answer:\n",
+ "        answer = \"I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\"\n",
+ "\n",
+ "    # Compute a mock confidence score using normalized BM25 scores.\n",
+ "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
+ "    bm25_scores = bm25.get_scores(query_tokens)\n",
+ "    max_score = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1\n",
+ "    confidence = round(np.mean(bm25_scores) / max_score, 2)\n",
+ "\n",
+ "    return answer, confidence\n",
+ "\n",
+ "# Create the Gradio interface using the new API.\n",
+ "iface = gr.Interface(\n",
+ "    fn=generate_response,\n",
+ "    inputs=gr.Textbox(lines=2, placeholder=\"Enter your financial question here...\"),\n",
+ "    outputs=[gr.Textbox(label=\"Answer\"), gr.Textbox(label=\"Confidence Score\")],\n",
+ "    title=\"Financial RAG Model Interface\",\n",
+ "    description=\"Ask questions based on the company's yearly financial summaries.\"\n",
+ ")\n",
+ "\n",
+ "# Launch the Gradio interface.\n",
+ "iface.launch()\n"
+ ]
+ },
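As the Gradio log in the output above notes, the demo runs only on a local URL by default; passing `share=True` to `launch()` creates a temporary public link:

```python
# Optional: serve the interface through a temporary public URL instead of
# only on http://127.0.0.1 (per the Gradio hint in the output above).
iface.launch(share=True)
```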
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cpu\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: What year had the lowest stock prices?\n",
+ "Answer: I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\n",
+ "Confidence Score: 1.0\n",
+ "--------------------------------------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cpu\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: How did the trading volume vary?\n",
+ "Answer: The average trading volume was 23,244,919 shares\n",
+ "Confidence Score: 1.0\n",
+ "--------------------------------------------------\n",
+ "Question: What is the capital of France?\n",
+ "Answer: Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\n",
+ "Confidence Score: 0.0\n",
+ "--------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 6. Testing & Validation (Updated)\n",
+ "# =============================================================================\n",
+ "def print_test_results(query_text, top_k=3, alpha=0.5):\n",
+ "    answer, confidence = generate_response(query_text, top_k, alpha)\n",
+ "    print(\"Question: \", query_text)\n",
+ "    print(\"Answer: \", answer)\n",
+ "    print(\"Confidence Score: \", confidence)\n",
+ "    print(\"-\" * 50)\n",
+ "\n",
+ "# Test 1: High-confidence financial query.\n",
+ "query_high = \"What year had the lowest stock prices?\"\n",
+ "print_test_results(query_high)\n",
+ "\n",
+ "# Test 2: Low-confidence financial query.\n",
+ "query_low = \"How did the trading volume vary?\"\n",
+ "print_test_results(query_low)\n",
+ "\n",
+ "# Test 3: Irrelevant query (should trigger guardrail).\n",
+ "query_irrelevant = \"What is the capital of France?\"\n",
+ "print_test_results(query_irrelevant)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
requirements.txt CHANGED
@@ -1 +1,10 @@
- huggingface_hub==0.25.2
+ pandas
+ numpy
+ faiss-cpu
+ sentence-transformers
+ rank_bm25
+ spacy
+ transformers
+ gradio
+ scikit-learn
+
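Note that requirements.txt covers the Python packages but not the spaCy model the notebook loads; `en_core_web_sm` is distributed separately. A setup sketch that downloads it on first use (the try/except pattern here is a common convention, not part of this commit):

```python
# One-time setup for the spaCy model used by the notebook (not in requirements.txt).
import spacy
from spacy.cli import download

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    download("en_core_web_sm")  # same as: python -m spacy download en_core_web_sm
    nlp = spacy.load("en_core_web_sm")
```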