uditk99 committed
Commit 59ead7a · verified · 1 Parent(s): 7289e75

Upload 3 files

Files changed (3)
  1. MSFT_1986-03-13_2025-02-04.csv +0 -0
  2. app.ipynb +551 -0
  3. requirements.txt +10 -0
MSFT_1986-03-13_2025-02-04.csv ADDED
The diff for this file is too large to render. See raw diff
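Since the CSV diff is not rendered, a quick way to inspect the file locally is sketched below; the column names and row count are taken from the notebook's `df.info()` output further down, not from the raw file itself.

```python
# Peek at the unrendered CSV. Per the notebook's df.info(), it has 9800 rows with
# columns: Date, Open, High, Low, Close, Adj Close, Volume.
import pandas as pd

preview = pd.read_csv("MSFT_1986-03-13_2025-02-04.csv", nrows=5)
print(preview.columns.tolist())
print(preview)
```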
 
app.ipynb ADDED
@@ -0,0 +1,551 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "d:\\ConvAI_Code\\bits_mtech\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ "  from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# Imports & Setup\n",
+ "# =============================================================================\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import faiss  # For fast vector similarity search\n",
+ "from sentence_transformers import SentenceTransformer  # For generating text embeddings\n",
+ "from rank_bm25 import BM25Okapi  # For BM25 keyword-based retrieval\n",
+ "import spacy  # For tokenization\n",
+ "from sklearn.metrics.pairwise import cosine_similarity  # For computing cosine similarity\n",
+ "from sklearn.preprocessing import normalize  # For normalizing BM25 scores\n",
+ "\n",
+ "# For the Gradio UI\n",
+ "import gradio as gr\n",
+ "\n",
+ "# For response generation using a small language model (we use FLAN-T5-Small)\n",
+ "from transformers import pipeline, set_seed\n",
+ "\n",
+ "# Set a random seed for reproducibility\n",
+ "set_seed(42)\n",
+ "\n",
+ "# Load SpaCy English model (make sure to download it with: python -m spacy download en_core_web_sm)\n",
+ "nlp = spacy.load(\"en_core_web_sm\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'pandas.core.frame.DataFrame'>\n",
+ "RangeIndex: 9800 entries, 0 to 9799\n",
+ "Data columns (total 7 columns):\n",
+ " # Column Non-Null Count Dtype \n",
+ "--- ------ -------------- ----- \n",
+ " 0 Date 9800 non-null object \n",
+ " 1 Open 9800 non-null float64\n",
+ " 2 High 9800 non-null float64\n",
+ " 3 Low 9800 non-null float64\n",
+ " 4 Close 9800 non-null float64\n",
+ " 5 Adj Close 9800 non-null float64\n",
+ " 6 Volume 9800 non-null int64 \n",
+ "dtypes: float64(5), int64(1), object(1)\n",
+ "memory usage: 536.1+ KB\n",
+ "None\n",
+ " Year Open_Min Open_Max Close_Min Close_Max Avg_Volume \\\n",
+ "0 1986 0.088542 0.177083 0.090278 0.177083 3.620005e+07 \n",
+ "1 1987 0.165799 0.548611 0.165799 0.548611 9.454613e+07 \n",
+ "2 1988 0.319444 0.484375 0.319444 0.483507 6.906268e+07 \n",
+ "3 1989 0.322049 0.618056 0.322917 0.614583 7.735760e+07 \n",
+ "4 1990 0.591146 1.102431 0.598090 1.100694 7.408945e+07 \n",
+ "\n",
+ " Summary \n",
+ "0 In 1986.0, the stock opened between $0.09 and ... \n",
+ "1 In 1987.0, the stock opened between $0.17 and ... \n",
+ "2 In 1988.0, the stock opened between $0.32 and ... \n",
+ "3 In 1989.0, the stock opened between $0.32 and ... \n",
+ "4 In 1990.0, the stock opened between $0.59 and ... \n"
+ ]
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 1. Data Collection & Preprocessing\n",
+ "# =============================================================================\n",
+ "# Load the CSV file containing financial data.\n",
+ "# (Make sure the CSV file \"MSFT_1986-03-13_2025-02-04.csv\" is in the working directory)\n",
+ "csv_file_path = r\"MSFT_1986-03-13_2025-02-04.csv\"  # Adjust the path if necessary\n",
+ "# Load the CSV file into a DataFrame\n",
+ "df = pd.read_csv(csv_file_path)\n",
+ "\n",
+ "# Display basic info about the dataset\n",
+ "print(df.info())\n",
+ "\n",
+ "# Data Cleaning & Structuring\n",
+ "\n",
+ "# Convert 'Date' column to datetime format\n",
+ "df['Date'] = pd.to_datetime(df['Date'])\n",
+ "\n",
+ "# Sort data by Date\n",
+ "df = df.sort_values(by='Date')\n",
+ "\n",
+ "# Extract Year from Date\n",
+ "df['Year'] = df['Date'].dt.year\n",
+ "\n",
+ "# Aggregate data by Year to generate financial summaries\n",
+ "yearly_summary = df.groupby('Year').agg(\n",
+ "    Open_Min=('Open', 'min'),\n",
+ "    Open_Max=('Open', 'max'),\n",
+ "    Close_Min=('Close', 'min'),\n",
+ "    Close_Max=('Close', 'max'),\n",
+ "    Avg_Volume=('Volume', 'mean')\n",
+ ").reset_index()\n",
+ "\n",
+ "# Create a textual summary for each year.\n",
+ "# (Cast Year to int: a row-wise apply() upcasts it to float, which would render as \"1986.0\".)\n",
+ "yearly_summary['Summary'] = yearly_summary.apply(\n",
+ "    lambda row: f\"In {int(row['Year'])}, the stock opened between ${row['Open_Min']:.2f} and ${row['Open_Max']:.2f}, \"\n",
+ "                f\"while closing between ${row['Close_Min']:.2f} and ${row['Close_Max']:.2f}. \"\n",
+ "                f\"The average trading volume was {row['Avg_Volume']:,.0f} shares.\",\n",
+ "    axis=1\n",
+ ")\n",
+ "\n",
+ "# Display the cleaned and structured data\n",
+ "print(yearly_summary.head())  # Use this for terminal/console\n",
+ "# yearly_summary.head()  # Use this in Jupyter Notebook\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "40"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 2. Basic RAG Implementation\n",
+ "# =============================================================================\n",
+ "# Convert financial summaries into text chunks and generate vector embeddings.\n",
+ "embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
+ "\n",
+ "# Convert yearly financial summaries into vector embeddings\n",
+ "summary_texts = yearly_summary[\"Summary\"].tolist()  # Extract summaries as text\n",
+ "summary_embeddings = embedding_model.encode(summary_texts, convert_to_numpy=True)  # Generate embeddings\n",
+ "\n",
+ "# Store embeddings as a NumPy array for further processing\n",
+ "summary_embeddings.shape  # This should be (num_years, embedding_size)\n",
+ "\n",
+ "# Take the embedding dimension from the encoded output (384 for all-MiniLM-L6-v2)\n",
+ "embedding_dim = summary_embeddings.shape[1]\n",
+ "\n",
+ "# Create a FAISS index (Flat index for now, can be optimized later)\n",
+ "faiss_index = faiss.IndexFlatL2(embedding_dim)\n",
+ "\n",
+ "# Convert embeddings to float32 (FAISS requires this format)\n",
+ "summary_embeddings = summary_embeddings.astype('float32')\n",
+ "\n",
+ "# Add embeddings to the FAISS index\n",
+ "faiss_index.add(summary_embeddings)\n",
+ "\n",
+ "# Store the year information for retrieval\n",
+ "year_map = {i: yearly_summary[\"Year\"].iloc[i] for i in range(len(yearly_summary))}\n",
+ "\n",
+ "# Verify that embeddings are stored successfully\n",
+ "faiss_index.ntotal\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merged summaries shape: (12, 2)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>Year</th>\n",
+ "      <th>Merged Summary</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>1986</td>\n",
+ "      <td>In 1986.0, the stock opened between $0.09 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>1990</td>\n",
+ "      <td>In 1989.0, the stock opened between $0.32 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>1992</td>\n",
+ "      <td>In 1991.0, the stock opened between $1.03 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>3</th>\n",
+ "      <td>1996</td>\n",
+ "      <td>In 1994.0, the stock opened between $2.45 and ...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>4</th>\n",
+ "      <td>1999</td>\n",
+ "      <td>In 1997.0, the stock opened between $10.25 and...</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " Year Merged Summary\n",
+ "0 1986 In 1986.0, the stock opened between $0.09 and ...\n",
+ "1 1990 In 1989.0, the stock opened between $0.32 and ...\n",
+ "2 1992 In 1991.0, the stock opened between $1.03 and ...\n",
+ "3 1996 In 1994.0, the stock opened between $2.45 and ...\n",
+ "4 1999 In 1997.0, the stock opened between $10.25 and..."
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 3. Advanced RAG Implementation\n",
+ "# =============================================================================\n",
+ "# 3.1: BM25 for Keyword-Based Search\n",
+ "# Tokenize each summary using SpaCy (tokens are converted to lowercase).\n",
+ "tokenized_summaries = [[token.text.lower() for token in nlp(summary)] for summary in summary_texts]\n",
+ "# Build the BM25 index.\n",
+ "bm25 = BM25Okapi(tokenized_summaries)\n",
+ "\n",
+ "# 3.2: Define Retrieval Functions\n",
+ "\n",
+ "def retrieve_similar_summaries(query_text, top_k=3):\n",
+ "    \"\"\"\n",
+ "    Retrieve similar financial summaries using FAISS vector search.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    distances, indices = faiss_index.search(query_embedding, top_k)\n",
+ "    results = []\n",
+ "    for idx in indices[0]:\n",
+ "        results.append((year_map[idx], yearly_summary.iloc[idx][\"Summary\"]))\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def hybrid_retrieve(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Hybrid retrieval combining FAISS (vector search) and BM25 (keyword search).\n",
+ "    Each retriever's rank positions are converted to scores and combined with weight 'alpha'.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
+ "\n",
+ "    bm25_scores = bm25.get_scores([token.text.lower() for token in nlp(query_text)])\n",
+ "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
+ "\n",
+ "    combined_scores = {}\n",
+ "    for rank, idx in enumerate(faiss_indices[0]):\n",
+ "        combined_scores[idx] = alpha * (top_k - rank)\n",
+ "    for rank, idx in enumerate(bm25_top_indices):\n",
+ "        if idx in combined_scores:\n",
+ "            combined_scores[idx] += (1 - alpha) * (top_k - rank)\n",
+ "        else:\n",
+ "            combined_scores[idx] = (1 - alpha) * (top_k - rank)\n",
+ "\n",
+ "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
+ "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def adaptive_retrieve(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Adaptive retrieval re-ranks results by combining FAISS and BM25 scores.\n",
+ "    \"\"\"\n",
+ "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
+ "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
+ "\n",
+ "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
+ "    bm25_scores = bm25.get_scores(query_tokens)\n",
+ "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
+ "\n",
+ "    faiss_scores = np.linspace(1, 0, num=top_k)\n",
+ "    bm25_norm_scores = normalize([bm25_scores])[0]\n",
+ "\n",
+ "    combined_scores = {}\n",
+ "    for rank, idx in enumerate(faiss_indices[0]):\n",
+ "        combined_scores[idx] = alpha * faiss_scores[rank]\n",
+ "    for idx in bm25_top_indices:\n",
+ "        if idx in combined_scores:\n",
+ "            combined_scores[idx] += (1 - alpha) * bm25_norm_scores[idx]\n",
+ "        else:\n",
+ "            combined_scores[idx] = (1 - alpha) * bm25_norm_scores[idx]\n",
+ "\n",
+ "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
+ "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
+ "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
+ "\n",
+ "def merge_similar_chunks(threshold=0.95):\n",
+ "    \"\"\"\n",
+ "    Chunk Merging: Merge similar financial summaries based on cosine similarity.\n",
+ "    This reduces redundancy when multiple chunks are very similar.\n",
+ "    \"\"\"\n",
+ "    merged_summaries = []\n",
+ "    used_indices = set()\n",
+ "    for i in range(len(summary_embeddings)):\n",
+ "        if i in used_indices:\n",
+ "            continue\n",
+ "        similarities = cosine_similarity([summary_embeddings[i]], summary_embeddings)[0]\n",
+ "        similar_indices = np.where(similarities >= threshold)[0]\n",
+ "        merged_text = \" \".join(yearly_summary.iloc[idx][\"Summary\"] for idx in similar_indices)\n",
+ "        merged_summaries.append((yearly_summary.iloc[i][\"Year\"], merged_text))\n",
+ "        used_indices.update(similar_indices)\n",
+ "    return pd.DataFrame(merged_summaries, columns=[\"Year\", \"Merged Summary\"])\n",
+ "\n",
+ "# Optional: Check merged summaries for debugging.\n",
+ "merged_summary_df = merge_similar_chunks(threshold=0.95)\n",
+ "print(\"Merged summaries shape:\", merged_summary_df.shape)\n",
+ "merged_summary_df.head()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7860\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 4. Guardrail & UI Development using Gradio\n",
+ "# =============================================================================\n",
+ "def generate_response(query_text, top_k=3, alpha=0.5):\n",
+ "    \"\"\"\n",
+ "    Generate an answer for a financial query by:\n",
+ "    - Validating the query with an input-side guardrail.\n",
+ "    - Retrieving context using adaptive retrieval.\n",
+ "    - Generating a refined answer using FLAN-T5-Small.\n",
+ "    Returns:\n",
+ "        answer (str): The generated answer.\n",
+ "        confidence (float): A mock confidence score based on BM25 scores.\n",
+ "    \"\"\"\n",
+ "    # -----------------------------------------------------------------------------\n",
+ "    # Guardrail Implementation (Input-Side)\n",
+ "    # -----------------------------------------------------------------------------\n",
+ "    financial_keywords = [\"open\", \"close\", \"stock\", \"price\", \"volume\", \"trading\"]\n",
+ "    if not any(keyword in query_text.lower() for keyword in financial_keywords):\n",
+ "        return (\"Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\"), 0.0\n",
+ "\n",
+ "    # Retrieve context using adaptive retrieval.\n",
+ "    context_df = adaptive_retrieve(query_text, top_k=top_k, alpha=alpha)\n",
+ "    context_text = \" \".join(context_df[\"Summary\"].tolist())\n",
+ "\n",
+ "    # Adjust the prompt to provide clear instructions.\n",
+ "    prompt = f\"Given the following financial data:\\n{context_text}\\nAnswer this question: {query_text}.\"\n",
+ "\n",
+ "    # Use FLAN-T5-Small for text generation via the text2text-generation pipeline.\n",
+ "    # Increase max_length to allow longer answers. (Building the pipeline on every\n",
+ "    # call reloads the model; hoist it to module level for repeated queries.)\n",
+ "    generator = pipeline('text2text-generation', model='google/flan-t5-small')\n",
+ "    generated = generator(prompt, max_length=200, num_return_sequences=1)\n",
+ "    answer = generated[0]['generated_text'].replace(prompt, \"\").strip()\n",
+ "\n",
+ "    # Fallback message if answer is empty.\n",
+ "    if not answer:\n",
+ "        answer = \"I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\"\n",
+ "\n",
+ "    # Compute a mock confidence score using normalized BM25 scores.\n",
+ "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
+ "    bm25_scores = bm25.get_scores(query_tokens)\n",
+ "    max_score = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1\n",
+ "    confidence = round(np.mean(bm25_scores) / max_score, 2)\n",
+ "\n",
+ "    return answer, confidence\n",
+ "\n",
+ "# Create the Gradio interface.\n",
+ "iface = gr.Interface(\n",
+ "    fn=generate_response,\n",
+ "    inputs=gr.Textbox(lines=2, placeholder=\"Enter your financial question here...\"),\n",
+ "    outputs=[gr.Textbox(label=\"Answer\"), gr.Textbox(label=\"Confidence Score\")],\n",
+ "    title=\"Financial RAG Model Interface\",\n",
+ "    description=\"Ask questions based on the company's financial summaries.\"\n",
+ ")\n",
+ "\n",
+ "# Launch the Gradio interface.\n",
+ "iface.launch()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cpu\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: What year had the lowest stock prices?\n",
+ "Answer: I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\n",
+ "Confidence Score: 1.0\n",
+ "--------------------------------------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cpu\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: How did the trading volume vary?\n",
+ "Answer: In 1994.0, the stock opened between $2.45 and $4.05, while closing between $2.46 and $4.04.\n",
+ "Confidence Score: 1.0\n",
+ "--------------------------------------------------\n",
+ "Question: What is the capital of France?\n",
+ "Answer: Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\n",
+ "Confidence Score: 0.0\n",
+ "--------------------------------------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "# =============================================================================\n",
+ "# 5. Testing & Validation\n",
+ "# =============================================================================\n",
+ "def print_test_results(query_text, top_k=3, alpha=0.5):\n",
+ "    answer, confidence = generate_response(query_text, top_k, alpha)\n",
+ "    print(\"Question: \", query_text)\n",
+ "    print(\"Answer: \", answer)\n",
+ "    print(\"Confidence Score: \", confidence)\n",
+ "    print(\"-\" * 50)\n",
+ "\n",
+ "# Test 1: High-confidence financial query.\n",
+ "query_high = \"What year had the lowest stock prices?\"\n",
+ "print_test_results(query_high)\n",
+ "\n",
+ "# Test 2: Low-confidence financial query.\n",
+ "query_low = \"How did the trading volume vary?\"\n",
+ "print_test_results(query_low)\n",
+ "\n",
+ "# Test 3: Irrelevant query (should trigger guardrail).\n",
+ "query_irrelevant = \"What is the capital of France?\"\n",
+ "print_test_results(query_irrelevant)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
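
For reviewers who want to trace the retrieval logic without downloading the embedding model, below is a minimal, self-contained sketch of the rank-based score fusion that `hybrid_retrieve` performs. The toy 2-D vectors and the brute-force L2 search are stand-ins for the MiniLM embeddings and the FAISS index; only `rank_bm25` and NumPy are exercised for real, and the three documents are invented for illustration.

```python
# Standalone sketch of hybrid_retrieve's rank-based fusion (toy data).
import numpy as np
from rank_bm25 import BM25Okapi

docs = [
    "the stock opened low in 1986",
    "average trading volume peaked in 1999",
    "the stock closed high in 2021",
]
doc_vecs = np.array([[0.1, 0.9], [0.8, 0.2], [0.4, 0.5]], dtype="float32")
query_vec = np.array([0.2, 0.8], dtype="float32")
top_k, alpha = 2, 0.5

# "Vector search": brute-force L2 distance, closest first (stands in for FAISS).
faiss_order = np.argsort(np.linalg.norm(doc_vecs - query_vec, axis=1))[:top_k]

# Keyword search: BM25 over whitespace tokens (the notebook tokenizes with SpaCy).
bm25 = BM25Okapi([d.split() for d in docs])
bm25_order = np.argsort(bm25.get_scores("stock opened".split()))[::-1][:top_k]

# Fusion: each retriever contributes (top_k - rank), weighted by alpha.
combined = {}
for rank, idx in enumerate(faiss_order):
    combined[int(idx)] = alpha * (top_k - rank)
for rank, idx in enumerate(bm25_order):
    combined[int(idx)] = combined.get(int(idx), 0.0) + (1 - alpha) * (top_k - rank)

for idx, score in sorted(combined.items(), key=lambda x: x[1], reverse=True):
    print(f"{score:.2f}  {docs[idx]}")
```

With `alpha=0.5` the two retrievers are weighted equally; pushing `alpha` toward 1 reduces the ranking to pure vector search, toward 0 to pure BM25.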
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pandas
+ numpy
+ faiss-cpu
+ sentence-transformers
+ rank_bm25
+ spacy
+ transformers
+ gradio
+ scikit-learn
+
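
Note that requirements.txt does not pin versions and does not cover the `en_core_web_sm` model that the notebook's first cell loads. A small bootstrap sketch, assuming network access on the first run; `spacy.cli.download` is SpaCy's programmatic equivalent of the `python -m spacy download en_core_web_sm` command the notebook mentions:

```python
# Load the SpaCy model, downloading it on first use.
import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
```

The TqdmWarning recorded in the first cell also suggests adding `ipywidgets` to the list if notebook progress bars are wanted.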