Spaces:
				
			
			
	
			
			
		No application file
		
	
	
	
			
			
	
	
	
	
		
		
		No application file
		
	Upload app.ipynb
Browse files
    	
        app.ipynb
    ADDED
    
    | @@ -0,0 +1,535 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": 18,
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "outputs": [],
         | 
| 8 | 
            +
               "source": [
         | 
| 9 | 
            +
                "# =============================================================================\n",
         | 
| 10 | 
            +
                "# Imports & Setup\n",
         | 
| 11 | 
            +
                "# =============================================================================\n",
         | 
| 12 | 
            +
                "import os\n",
         | 
| 13 | 
            +
                "import numpy as np\n",
         | 
| 14 | 
            +
                "import pandas as pd\n",
         | 
| 15 | 
            +
                "import faiss  # For fast vector similarity search\n",
         | 
| 16 | 
            +
                "from sentence_transformers import SentenceTransformer  # For generating text embeddings\n",
         | 
| 17 | 
            +
                "from rank_bm25 import BM25Okapi  # For BM25 keyword-based retrieval\n",
         | 
| 18 | 
            +
                "import spacy  # For tokenization\n",
         | 
| 19 | 
            +
                "from sklearn.metrics.pairwise import cosine_similarity  # For computing cosine similarity\n",
         | 
| 20 | 
            +
                "from sklearn.preprocessing import normalize  # For normalizing BM25 scores\n",
         | 
| 21 | 
            +
                "\n",
         | 
| 22 | 
            +
                "# For the Gradio UI\n",
         | 
| 23 | 
            +
                "import gradio as gr\n",
         | 
| 24 | 
            +
                "\n",
         | 
| 25 | 
            +
                "# For response generation using a small language model (we use FLAN-T5-Small)\n",
         | 
| 26 | 
            +
                "from transformers import pipeline, set_seed\n",
         | 
| 27 | 
            +
                "\n",
         | 
| 28 | 
            +
                "# Set a random seed for reproducibility\n",
         | 
| 29 | 
            +
                "set_seed(42)\n",
         | 
| 30 | 
            +
                "\n",
         | 
| 31 | 
            +
                "# Load SpaCy English model (make sure to download it with: python -m spacy download en_core_web_sm)\n",
         | 
| 32 | 
            +
                "nlp = spacy.load(\"en_core_web_sm\")\n"
         | 
| 33 | 
            +
               ]
         | 
| 34 | 
            +
              },
         | 
| 35 | 
            +
              {
         | 
| 36 | 
            +
               "cell_type": "code",
         | 
| 37 | 
            +
               "execution_count": 35,
         | 
| 38 | 
            +
               "metadata": {},
         | 
| 39 | 
            +
               "outputs": [
         | 
| 40 | 
            +
                {
         | 
| 41 | 
            +
                 "name": "stdout",
         | 
| 42 | 
            +
                 "output_type": "stream",
         | 
| 43 | 
            +
                 "text": [
         | 
| 44 | 
            +
                  "<class 'pandas.core.frame.DataFrame'>\n",
         | 
| 45 | 
            +
                  "RangeIndex: 9800 entries, 0 to 9799\n",
         | 
| 46 | 
            +
                  "Data columns (total 7 columns):\n",
         | 
| 47 | 
            +
                  " #   Column     Non-Null Count  Dtype  \n",
         | 
| 48 | 
            +
                  "---  ------     --------------  -----  \n",
         | 
| 49 | 
            +
                  " 0   Date       9800 non-null   object \n",
         | 
| 50 | 
            +
                  " 1   Open       9800 non-null   float64\n",
         | 
| 51 | 
            +
                  " 2   High       9800 non-null   float64\n",
         | 
| 52 | 
            +
                  " 3   Low        9800 non-null   float64\n",
         | 
| 53 | 
            +
                  " 4   Close      9800 non-null   float64\n",
         | 
| 54 | 
            +
                  " 5   Adj Close  9800 non-null   float64\n",
         | 
| 55 | 
            +
                  " 6   Volume     9800 non-null   int64  \n",
         | 
| 56 | 
            +
                  "dtypes: float64(5), int64(1), object(1)\n",
         | 
| 57 | 
            +
                  "memory usage: 536.1+ KB\n",
         | 
| 58 | 
            +
                  "None\n",
         | 
| 59 | 
            +
                  "   Year  Open_Min  Open_Max  Close_Min  Close_Max    Avg_Volume  \\\n",
         | 
| 60 | 
            +
                  "0  1986  0.088542  0.177083   0.090278   0.177083  3.620005e+07   \n",
         | 
| 61 | 
            +
                  "1  1987  0.165799  0.548611   0.165799   0.548611  9.454613e+07   \n",
         | 
| 62 | 
            +
                  "2  1988  0.319444  0.484375   0.319444   0.483507  6.906268e+07   \n",
         | 
| 63 | 
            +
                  "3  1989  0.322049  0.618056   0.322917   0.614583  7.735760e+07   \n",
         | 
| 64 | 
            +
                  "4  1990  0.591146  1.102431   0.598090   1.100694  7.408945e+07   \n",
         | 
| 65 | 
            +
                  "\n",
         | 
| 66 | 
            +
                  "                                             Summary  \n",
         | 
| 67 | 
            +
                  "0  In 1986.0, the stock opened between $0.09 and ...  \n",
         | 
| 68 | 
            +
                  "1  In 1987.0, the stock opened between $0.17 and ...  \n",
         | 
| 69 | 
            +
                  "2  In 1988.0, the stock opened between $0.32 and ...  \n",
         | 
| 70 | 
            +
                  "3  In 1989.0, the stock opened between $0.32 and ...  \n",
         | 
| 71 | 
            +
                  "4  In 1990.0, the stock opened between $0.59 and ...  \n"
         | 
| 72 | 
            +
                 ]
         | 
| 73 | 
            +
                }
         | 
| 74 | 
            +
               ],
         | 
| 75 | 
            +
               "source": [
         | 
| 76 | 
            +
                "# =============================================================================\n",
         | 
| 77 | 
            +
                "# 1. Data Collection & Preprocessing\n",
         | 
| 78 | 
            +
                "# =============================================================================\n",
         | 
| 79 | 
            +
                "# Load the CSV file containing financial data.\n",
         | 
| 80 | 
            +
                "# (Make sure the CSV file \"MSFT_1986-03-13_2025-02-04.csv\" is in the \"data\" folder)\n",
         | 
| 81 | 
            +
                "csv_file_path = r\"D:\\ConvAI_Code\\MSFT_1986-03-13_2025-02-04.csv\"  # Adjust the path if necessary\n",
         | 
| 82 | 
            +
                "# Load the CSV file into a DataFrame\n",
         | 
| 83 | 
            +
                "df = pd.read_csv(csv_file_path)\n",
         | 
| 84 | 
            +
                "\n",
         | 
| 85 | 
            +
                "# Display basic info about the dataset\n",
         | 
| 86 | 
            +
                "print(df.info())\n",
         | 
| 87 | 
            +
                "\n",
         | 
| 88 | 
            +
                "# Data Cleaning & Structuring\n",
         | 
| 89 | 
            +
                "\n",
         | 
| 90 | 
            +
                "# Convert 'Date' column to datetime format\n",
         | 
| 91 | 
            +
                "df['Date'] = pd.to_datetime(df['Date'])\n",
         | 
| 92 | 
            +
                "\n",
         | 
| 93 | 
            +
                "# Sort data by Date\n",
         | 
| 94 | 
            +
                "df = df.sort_values(by='Date')\n",
         | 
| 95 | 
            +
                "\n",
         | 
| 96 | 
            +
                "# Extract Year from Date\n",
         | 
| 97 | 
            +
                "df['Year'] = df['Date'].dt.year\n",
         | 
| 98 | 
            +
                "\n",
         | 
| 99 | 
            +
                "# Aggregate data by Year to generate financial summaries\n",
         | 
| 100 | 
            +
                "yearly_summary = df.groupby('Year').agg(\n",
         | 
| 101 | 
            +
                "    Open_Min=('Open', 'min'),\n",
         | 
| 102 | 
            +
                "    Open_Max=('Open', 'max'),\n",
         | 
| 103 | 
            +
                "    Close_Min=('Close', 'min'),\n",
         | 
| 104 | 
            +
                "    Close_Max=('Close', 'max'),\n",
         | 
| 105 | 
            +
                "    Avg_Volume=('Volume', 'mean')\n",
         | 
| 106 | 
            +
                ").reset_index()\n",
         | 
| 107 | 
            +
                "\n",
         | 
| 108 | 
            +
                "# Create a textual summary for each year\n",
         | 
| 109 | 
            +
                "yearly_summary['Summary'] = yearly_summary.apply(\n",
         | 
| 110 | 
            +
                "    lambda row: f\"In {row['Year']}, the stock opened between ${row['Open_Min']:.2f} and ${row['Open_Max']:.2f}, \"\n",
         | 
| 111 | 
            +
                "                f\"while closing between ${row['Close_Min']:.2f} and ${row['Close_Max']:.2f}. \"\n",
         | 
| 112 | 
            +
                "                f\"The average trading volume was {row['Avg_Volume']:,.0f} shares.\",\n",
         | 
| 113 | 
            +
                "    axis=1\n",
         | 
| 114 | 
            +
                ")\n",
         | 
| 115 | 
            +
                "\n",
         | 
| 116 | 
            +
                "# Display the cleaned and structured data\n",
         | 
| 117 | 
            +
                "print(yearly_summary.head())  # Use this for terminal/console\n",
         | 
| 118 | 
            +
                "# yearly_summary.head()  # Use this in Jupyter Notebook\n",
         | 
| 119 | 
            +
                "\n"
         | 
| 120 | 
            +
               ]
         | 
| 121 | 
            +
              },
         | 
| 122 | 
            +
              {
         | 
| 123 | 
            +
               "cell_type": "code",
         | 
| 124 | 
            +
               "execution_count": 20,
         | 
| 125 | 
            +
               "metadata": {},
         | 
| 126 | 
            +
               "outputs": [
         | 
| 127 | 
            +
                {
         | 
| 128 | 
            +
                 "data": {
         | 
| 129 | 
            +
                  "text/plain": [
         | 
| 130 | 
            +
                   "40"
         | 
| 131 | 
            +
                  ]
         | 
| 132 | 
            +
                 },
         | 
| 133 | 
            +
                 "execution_count": 20,
         | 
| 134 | 
            +
                 "metadata": {},
         | 
| 135 | 
            +
                 "output_type": "execute_result"
         | 
| 136 | 
            +
                }
         | 
| 137 | 
            +
               ],
         | 
| 138 | 
            +
               "source": [
         | 
| 139 | 
            +
                "# =============================================================================\n",
         | 
| 140 | 
            +
                "# 2. Basic RAG Implementation\n",
         | 
| 141 | 
            +
                "# =============================================================================\n",
         | 
| 142 | 
            +
                "# Convert financial summaries into text chunks and generate vector embeddings.\n",
         | 
| 143 | 
            +
                "embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
         | 
| 144 | 
            +
                "\n",
         | 
| 145 | 
            +
                "# Convert yearly financial summaries into vector embeddings\n",
         | 
| 146 | 
            +
                "summary_texts = yearly_summary[\"Summary\"].tolist()  # Extract summaries as text\n",
         | 
| 147 | 
            +
                "summary_embeddings = embedding_model.encode(summary_texts, convert_to_numpy=True)  # Generate embeddings\n",
         | 
| 148 | 
            +
                "\n",
         | 
| 149 | 
            +
                "# Store embeddings as a NumPy array for further processing\n",
         | 
| 150 | 
            +
                "summary_embeddings.shape  # This should be (num_years, embedding_size)\n",
         | 
| 151 | 
            +
                "\n",
         | 
| 152 | 
            +
                "# Define the dimension of embeddings (384 from MiniLM model)\n",
         | 
| 153 | 
            +
                "embedding_dim = 384\n",
         | 
| 154 | 
            +
                "\n",
         | 
| 155 | 
            +
                "# Create a FAISS index (Flat index for now, can be optimized later)\n",
         | 
| 156 | 
            +
                "faiss_index = faiss.IndexFlatL2(embedding_dim)\n",
         | 
| 157 | 
            +
                "\n",
         | 
| 158 | 
            +
                "# Convert embeddings to float32 (FAISS requires this format)\n",
         | 
| 159 | 
            +
                "summary_embeddings = summary_embeddings.astype('float32')\n",
         | 
| 160 | 
            +
                "\n",
         | 
| 161 | 
            +
                "# Add embeddings to the FAISS index\n",
         | 
| 162 | 
            +
                "faiss_index.add(summary_embeddings)\n",
         | 
| 163 | 
            +
                "\n",
         | 
| 164 | 
            +
                "# Store the year information for retrieval\n",
         | 
| 165 | 
            +
                "year_map = {i: yearly_summary[\"Year\"].iloc[i] for i in range(len(yearly_summary))}\n",
         | 
| 166 | 
            +
                "\n",
         | 
| 167 | 
            +
                "# Verify that embeddings are stored successfully\n",
         | 
| 168 | 
            +
                "faiss_index.ntotal\n"
         | 
| 169 | 
            +
               ]
         | 
| 170 | 
            +
              },
         | 
| 171 | 
            +
              {
         | 
| 172 | 
            +
               "cell_type": "code",
         | 
| 173 | 
            +
               "execution_count": 21,
         | 
| 174 | 
            +
               "metadata": {},
         | 
| 175 | 
            +
               "outputs": [
         | 
| 176 | 
            +
                {
         | 
| 177 | 
            +
                 "name": "stdout",
         | 
| 178 | 
            +
                 "output_type": "stream",
         | 
| 179 | 
            +
                 "text": [
         | 
| 180 | 
            +
                  "Merged summaries shape: (12, 2)\n"
         | 
| 181 | 
            +
                 ]
         | 
| 182 | 
            +
                },
         | 
| 183 | 
            +
                {
         | 
| 184 | 
            +
                 "data": {
         | 
| 185 | 
            +
                  "text/html": [
         | 
| 186 | 
            +
                   "<div>\n",
         | 
| 187 | 
            +
                   "<style scoped>\n",
         | 
| 188 | 
            +
                   "    .dataframe tbody tr th:only-of-type {\n",
         | 
| 189 | 
            +
                   "        vertical-align: middle;\n",
         | 
| 190 | 
            +
                   "    }\n",
         | 
| 191 | 
            +
                   "\n",
         | 
| 192 | 
            +
                   "    .dataframe tbody tr th {\n",
         | 
| 193 | 
            +
                   "        vertical-align: top;\n",
         | 
| 194 | 
            +
                   "    }\n",
         | 
| 195 | 
            +
                   "\n",
         | 
| 196 | 
            +
                   "    .dataframe thead th {\n",
         | 
| 197 | 
            +
                   "        text-align: right;\n",
         | 
| 198 | 
            +
                   "    }\n",
         | 
| 199 | 
            +
                   "</style>\n",
         | 
| 200 | 
            +
                   "<table border=\"1\" class=\"dataframe\">\n",
         | 
| 201 | 
            +
                   "  <thead>\n",
         | 
| 202 | 
            +
                   "    <tr style=\"text-align: right;\">\n",
         | 
| 203 | 
            +
                   "      <th></th>\n",
         | 
| 204 | 
            +
                   "      <th>Year</th>\n",
         | 
| 205 | 
            +
                   "      <th>Merged Summary</th>\n",
         | 
| 206 | 
            +
                   "    </tr>\n",
         | 
| 207 | 
            +
                   "  </thead>\n",
         | 
| 208 | 
            +
                   "  <tbody>\n",
         | 
| 209 | 
            +
                   "    <tr>\n",
         | 
| 210 | 
            +
                   "      <th>0</th>\n",
         | 
| 211 | 
            +
                   "      <td>1986</td>\n",
         | 
| 212 | 
            +
                   "      <td>In 1986.0, the stock opened between $0.09 and ...</td>\n",
         | 
| 213 | 
            +
                   "    </tr>\n",
         | 
| 214 | 
            +
                   "    <tr>\n",
         | 
| 215 | 
            +
                   "      <th>1</th>\n",
         | 
| 216 | 
            +
                   "      <td>1990</td>\n",
         | 
| 217 | 
            +
                   "      <td>In 1989.0, the stock opened between $0.32 and ...</td>\n",
         | 
| 218 | 
            +
                   "    </tr>\n",
         | 
| 219 | 
            +
                   "    <tr>\n",
         | 
| 220 | 
            +
                   "      <th>2</th>\n",
         | 
| 221 | 
            +
                   "      <td>1992</td>\n",
         | 
| 222 | 
            +
                   "      <td>In 1991.0, the stock opened between $1.03 and ...</td>\n",
         | 
| 223 | 
            +
                   "    </tr>\n",
         | 
| 224 | 
            +
                   "    <tr>\n",
         | 
| 225 | 
            +
                   "      <th>3</th>\n",
         | 
| 226 | 
            +
                   "      <td>1996</td>\n",
         | 
| 227 | 
            +
                   "      <td>In 1994.0, the stock opened between $2.45 and ...</td>\n",
         | 
| 228 | 
            +
                   "    </tr>\n",
         | 
| 229 | 
            +
                   "    <tr>\n",
         | 
| 230 | 
            +
                   "      <th>4</th>\n",
         | 
| 231 | 
            +
                   "      <td>1999</td>\n",
         | 
| 232 | 
            +
                   "      <td>In 1997.0, the stock opened between $10.25 and...</td>\n",
         | 
| 233 | 
            +
                   "    </tr>\n",
         | 
| 234 | 
            +
                   "  </tbody>\n",
         | 
| 235 | 
            +
                   "</table>\n",
         | 
| 236 | 
            +
                   "</div>"
         | 
| 237 | 
            +
                  ],
         | 
| 238 | 
            +
                  "text/plain": [
         | 
| 239 | 
            +
                   "   Year                                     Merged Summary\n",
         | 
| 240 | 
            +
                   "0  1986  In 1986.0, the stock opened between $0.09 and ...\n",
         | 
| 241 | 
            +
                   "1  1990  In 1989.0, the stock opened between $0.32 and ...\n",
         | 
| 242 | 
            +
                   "2  1992  In 1991.0, the stock opened between $1.03 and ...\n",
         | 
| 243 | 
            +
                   "3  1996  In 1994.0, the stock opened between $2.45 and ...\n",
         | 
| 244 | 
            +
                   "4  1999  In 1997.0, the stock opened between $10.25 and..."
         | 
| 245 | 
            +
                  ]
         | 
| 246 | 
            +
                 },
         | 
| 247 | 
            +
                 "execution_count": 21,
         | 
| 248 | 
            +
                 "metadata": {},
         | 
| 249 | 
            +
                 "output_type": "execute_result"
         | 
| 250 | 
            +
                }
         | 
| 251 | 
            +
               ],
         | 
| 252 | 
            +
               "source": [
         | 
| 253 | 
            +
                "# =============================================================================\n",
         | 
| 254 | 
            +
                "# 3. Advanced RAG Implementation\n",
         | 
| 255 | 
            +
                "# =============================================================================\n",
         | 
| 256 | 
            +
                "# 3.1: BM25 for Keyword-Based Search\n",
         | 
| 257 | 
            +
                "# Tokenize each summary using SpaCy (tokens are converted to lowercase).\n",
         | 
| 258 | 
            +
                "tokenized_summaries = [[token.text.lower() for token in nlp(summary)] for summary in summary_texts]\n",
         | 
| 259 | 
            +
                "# Build the BM25 index.\n",
         | 
| 260 | 
            +
                "bm25 = BM25Okapi(tokenized_summaries)\n",
         | 
| 261 | 
            +
                "\n",
         | 
| 262 | 
            +
                "# 3.2: Define Retrieval Functions\n",
         | 
| 263 | 
            +
                "\n",
         | 
| 264 | 
            +
                "def retrieve_similar_summaries(query_text, top_k=3):\n",
         | 
| 265 | 
            +
                "    \"\"\"\n",
         | 
| 266 | 
            +
                "    Retrieve similar financial summaries using FAISS vector search.\n",
         | 
| 267 | 
            +
                "    \"\"\"\n",
         | 
| 268 | 
            +
                "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
         | 
| 269 | 
            +
                "    distances, indices = faiss_index.search(query_embedding, top_k)\n",
         | 
| 270 | 
            +
                "    results = []\n",
         | 
| 271 | 
            +
                "    for idx in indices[0]:\n",
         | 
| 272 | 
            +
                "        results.append((year_map[idx], yearly_summary.iloc[idx][\"Summary\"]))\n",
         | 
| 273 | 
            +
                "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
         | 
| 274 | 
            +
                "\n",
         | 
| 275 | 
            +
                "def hybrid_retrieve(query_text, top_k=3, alpha=0.5):\n",
         | 
| 276 | 
            +
                "    \"\"\"\n",
         | 
| 277 | 
            +
                "    Hybrid retrieval combining FAISS (vector search) and BM25 (keyword search).\n",
         | 
| 278 | 
            +
                "    Scores are combined using the weighting factor 'alpha'.\n",
         | 
| 279 | 
            +
                "    \"\"\"\n",
         | 
| 280 | 
            +
                "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
         | 
| 281 | 
            +
                "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
         | 
| 282 | 
            +
                "    \n",
         | 
| 283 | 
            +
                "    bm25_scores = bm25.get_scores([token.text.lower() for token in nlp(query_text)])\n",
         | 
| 284 | 
            +
                "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
         | 
| 285 | 
            +
                "    \n",
         | 
| 286 | 
            +
                "    combined_scores = {}\n",
         | 
| 287 | 
            +
                "    for rank, idx in enumerate(faiss_indices[0]):\n",
         | 
| 288 | 
            +
                "        combined_scores[idx] = alpha * (top_k - rank)\n",
         | 
| 289 | 
            +
                "    bm25_norm_scores = normalize([bm25_scores])[0]\n",
         | 
| 290 | 
            +
                "    for rank, idx in enumerate(bm25_top_indices):\n",
         | 
| 291 | 
            +
                "        if idx in combined_scores:\n",
         | 
| 292 | 
            +
                "            combined_scores[idx] += (1 - alpha) * (top_k - rank)\n",
         | 
| 293 | 
            +
                "        else:\n",
         | 
| 294 | 
            +
                "            combined_scores[idx] = (1 - alpha) * (top_k - rank)\n",
         | 
| 295 | 
            +
                "    \n",
         | 
| 296 | 
            +
                "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
         | 
| 297 | 
            +
                "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
         | 
| 298 | 
            +
                "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
         | 
| 299 | 
            +
                "\n",
         | 
| 300 | 
            +
                "def adaptive_retrieve(query_text, top_k=3, alpha=0.5):\n",
         | 
| 301 | 
            +
                "    \"\"\"\n",
         | 
| 302 | 
            +
                "    Adaptive retrieval re-ranks results by combining FAISS and BM25 scores.\n",
         | 
| 303 | 
            +
                "    \"\"\"\n",
         | 
| 304 | 
            +
                "    query_embedding = embedding_model.encode([query_text], convert_to_numpy=True).astype('float32')\n",
         | 
| 305 | 
            +
                "    _, faiss_indices = faiss_index.search(query_embedding, top_k)\n",
         | 
| 306 | 
            +
                "    \n",
         | 
| 307 | 
            +
                "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
         | 
| 308 | 
            +
                "    bm25_scores = bm25.get_scores(query_tokens)\n",
         | 
| 309 | 
            +
                "    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]\n",
         | 
| 310 | 
            +
                "    \n",
         | 
| 311 | 
            +
                "    faiss_scores = np.linspace(1, 0, num=top_k)\n",
         | 
| 312 | 
            +
                "    bm25_norm_scores = normalize([bm25_scores])[0]\n",
         | 
| 313 | 
            +
                "    \n",
         | 
| 314 | 
            +
                "    combined_scores = {}\n",
         | 
| 315 | 
            +
                "    for rank, idx in enumerate(faiss_indices[0]):\n",
         | 
| 316 | 
            +
                "        combined_scores[idx] = alpha * faiss_scores[rank]\n",
         | 
| 317 | 
            +
                "    for idx in bm25_top_indices:\n",
         | 
| 318 | 
            +
                "        if idx in combined_scores:\n",
         | 
| 319 | 
            +
                "            combined_scores[idx] += (1 - alpha) * bm25_norm_scores[idx]\n",
         | 
| 320 | 
            +
                "        else:\n",
         | 
| 321 | 
            +
                "            combined_scores[idx] = (1 - alpha) * bm25_norm_scores[idx]\n",
         | 
| 322 | 
            +
                "    \n",
         | 
| 323 | 
            +
                "    sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)\n",
         | 
| 324 | 
            +
                "    results = [(year_map[idx], yearly_summary.iloc[idx][\"Summary\"]) for idx, _ in sorted_results]\n",
         | 
| 325 | 
            +
                "    return pd.DataFrame(results, columns=[\"Year\", \"Summary\"])\n",
         | 
| 326 | 
            +
                "\n",
         | 
| 327 | 
            +
                "def merge_similar_chunks(threshold=0.95):\n",
         | 
| 328 | 
            +
                "    \"\"\"\n",
         | 
| 329 | 
            +
                "    Chunk Merging: Merge similar financial summaries based on cosine similarity.\n",
         | 
| 330 | 
            +
                "    This reduces redundancy when multiple chunks are very similar.\n",
         | 
| 331 | 
            +
                "    \"\"\"\n",
         | 
| 332 | 
            +
                "    merged_summaries = []\n",
         | 
| 333 | 
            +
                "    used_indices = set()\n",
         | 
| 334 | 
            +
                "    for i in range(len(summary_embeddings)):\n",
         | 
| 335 | 
            +
                "        if i in used_indices:\n",
         | 
| 336 | 
            +
                "            continue\n",
         | 
| 337 | 
            +
                "        similarities = cosine_similarity([summary_embeddings[i]], summary_embeddings)[0]\n",
         | 
| 338 | 
            +
                "        similar_indices = np.where(similarities >= threshold)[0]\n",
         | 
| 339 | 
            +
                "        merged_text = \" \".join(yearly_summary.iloc[idx][\"Summary\"] for idx in similar_indices)\n",
         | 
| 340 | 
            +
                "        merged_summaries.append((yearly_summary.iloc[i][\"Year\"], merged_text))\n",
         | 
| 341 | 
            +
                "        used_indices.update(similar_indices)\n",
         | 
| 342 | 
            +
                "    return pd.DataFrame(merged_summaries, columns=[\"Year\", \"Merged Summary\"])\n",
         | 
| 343 | 
            +
                "\n",
         | 
| 344 | 
            +
                "# Optional: Check merged summaries for debugging.\n",
         | 
| 345 | 
            +
                "merged_summary_df = merge_similar_chunks(threshold=0.95)\n",
         | 
| 346 | 
            +
                "print(\"Merged summaries shape:\", merged_summary_df.shape)\n",
         | 
| 347 | 
            +
                "merged_summary_df.head()\n"
         | 
| 348 | 
            +
               ]
         | 
| 349 | 
            +
              },
         | 
| 350 | 
            +
              {
         | 
| 351 | 
            +
               "cell_type": "code",
         | 
| 352 | 
            +
               "execution_count": 34,
         | 
| 353 | 
            +
               "metadata": {},
         | 
| 354 | 
            +
               "outputs": [
         | 
| 355 | 
            +
                {
         | 
| 356 | 
            +
                 "name": "stdout",
         | 
| 357 | 
            +
                 "output_type": "stream",
         | 
| 358 | 
            +
                 "text": [
         | 
| 359 | 
            +
                  "* Running on local URL:  http://127.0.0.1:7864\n",
         | 
| 360 | 
            +
                  "\n",
         | 
| 361 | 
            +
                  "To create a public link, set `share=True` in `launch()`.\n"
         | 
| 362 | 
            +
                 ]
         | 
| 363 | 
            +
                },
         | 
| 364 | 
            +
                {
         | 
| 365 | 
            +
                 "data": {
         | 
| 366 | 
            +
                  "text/html": [
         | 
| 367 | 
            +
                   "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
         | 
| 368 | 
            +
                  ],
         | 
| 369 | 
            +
                  "text/plain": [
         | 
| 370 | 
            +
                   "<IPython.core.display.HTML object>"
         | 
| 371 | 
            +
                  ]
         | 
| 372 | 
            +
                 },
         | 
| 373 | 
            +
                 "metadata": {},
         | 
| 374 | 
            +
                 "output_type": "display_data"
         | 
| 375 | 
            +
                },
         | 
| 376 | 
            +
                {
         | 
| 377 | 
            +
                 "data": {
         | 
| 378 | 
            +
                  "text/plain": []
         | 
| 379 | 
            +
                 },
         | 
| 380 | 
            +
                 "execution_count": 34,
         | 
| 381 | 
            +
                 "metadata": {},
         | 
| 382 | 
            +
                 "output_type": "execute_result"
         | 
| 383 | 
            +
                }
         | 
| 384 | 
            +
               ],
         | 
| 385 | 
            +
               "source": [
         | 
| 386 | 
            +
                "# =============================================================================\n",
         | 
| 387 | 
            +
                "# 4. UI Development using Gradio (Updated for newer API)\n",
         | 
| 388 | 
            +
                "# =============================================================================\n",
         | 
| 389 | 
            +
                "def generate_response(query_text, top_k=3, alpha=0.5):\n",
         | 
| 390 | 
            +
                "    \"\"\"\n",
         | 
| 391 | 
            +
                "    Generate an answer for a financial query by:\n",
         | 
| 392 | 
            +
                "      - Validating the query with an input-side guardrail.\n",
         | 
| 393 | 
            +
                "      - Retrieving context using adaptive retrieval.\n",
         | 
| 394 | 
            +
                "      - Generating a refined answer using FLAN-T5-Small.\n",
         | 
| 395 | 
            +
                "    Returns:\n",
         | 
| 396 | 
            +
                "      answer (str): The generated answer.\n",
         | 
| 397 | 
            +
                "      confidence (float): A mock confidence score based on BM25 scores.\n",
         | 
| 398 | 
            +
                "    \"\"\"\n",
         | 
| 399 | 
            +
                "    # -----------------------------------------------------------------------------\n",
         | 
| 400 | 
            +
                "    # Guard Rail Implementation (Input-Side)\n",
         | 
| 401 | 
            +
                "    # -----------------------------------------------------------------------------\n",
         | 
| 402 | 
            +
                "    financial_keywords = [\"open\", \"close\", \"stock\", \"price\", \"volume\", \"trading\"]\n",
         | 
| 403 | 
            +
                "    if not any(keyword in query_text.lower() for keyword in financial_keywords):\n",
         | 
| 404 | 
            +
                "        return (\"Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\"), 0.0\n",
         | 
| 405 | 
            +
                "\n",
         | 
| 406 | 
            +
                "    # Retrieve context using adaptive retrieval.\n",
         | 
| 407 | 
            +
                "    context_df = adaptive_retrieve(query_text, top_k=top_k, alpha=alpha)\n",
         | 
| 408 | 
            +
                "    context_text = \" \".join(context_df[\"Summary\"].tolist())\n",
         | 
| 409 | 
            +
                "    \n",
         | 
| 410 | 
            +
                "    # Adjust the prompt to provide clear instructions.\n",
         | 
| 411 | 
            +
                "    prompt = f\"Given the following financial data:\\n{context_text}\\nAnswer this question: {query_text}.\"\n",
         | 
| 412 | 
            +
                "    \n",
         | 
| 413 | 
            +
                "    # Use FLAN-T5-Small for text generation via the text2text-generation pipeline.\n",
         | 
| 414 | 
            +
                "    # Increase max_length to allow longer answers.\n",
         | 
| 415 | 
            +
                "    generator = pipeline('text2text-generation', model='google/flan-t5-small')\n",
         | 
| 416 | 
            +
                "    generated = generator(prompt, max_length=200, num_return_sequences=1)\n",
         | 
| 417 | 
            +
                "    answer = generated[0]['generated_text'].replace(prompt, \"\").strip()\n",
         | 
| 418 | 
            +
                "    \n",
         | 
| 419 | 
            +
                "    # Fallback message if answer is empty.\n",
         | 
| 420 | 
            +
                "    if not answer:\n",
         | 
| 421 | 
            +
                "        answer = \"I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\"\n",
         | 
| 422 | 
            +
                "    \n",
         | 
| 423 | 
            +
                "    # Compute a mock confidence score using normalized BM25 scores.\n",
         | 
| 424 | 
            +
                "    query_tokens = [token.text.lower() for token in nlp(query_text)]\n",
         | 
| 425 | 
            +
                "    bm25_scores = bm25.get_scores(query_tokens)\n",
         | 
| 426 | 
            +
                "    max_score = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1\n",
         | 
| 427 | 
            +
                "    confidence = round(np.mean(bm25_scores) / max_score, 2)\n",
         | 
| 428 | 
            +
                "    \n",
         | 
| 429 | 
            +
                "    return answer, confidence\n",
         | 
| 430 | 
            +
                "\n",
         | 
| 431 | 
            +
                "# Create the Gradio interface using the new API.\n",
         | 
| 432 | 
            +
                "iface = gr.Interface(\n",
         | 
| 433 | 
            +
                "    fn=generate_response,\n",
         | 
| 434 | 
            +
                "    inputs=gr.Textbox(lines=2, placeholder=\"Enter your financial question here...\"),\n",
         | 
| 435 | 
            +
                "    outputs=[gr.Textbox(label=\"Answer\"), gr.Textbox(label=\"Confidence Score\")],\n",
         | 
| 436 | 
            +
                "    title=\"Financial RAG Model Interface\",\n",
         | 
| 437 | 
            +
                "    description=(\"Ask questions based on the company's financial summaries  \"\n",
         | 
| 438 | 
            +
                "                 )\n",
         | 
| 439 | 
            +
                ")\n",
         | 
| 440 | 
            +
                "\n",
         | 
| 441 | 
            +
                "# Launch the Gradio interface.\n",
         | 
| 442 | 
            +
                "iface.launch()\n"
         | 
| 443 | 
            +
               ]
         | 
| 444 | 
            +
              },
         | 
| 445 | 
            +
              {
         | 
| 446 | 
            +
               "cell_type": "code",
         | 
| 447 | 
            +
               "execution_count": 29,
         | 
| 448 | 
            +
               "metadata": {},
         | 
| 449 | 
            +
               "outputs": [
         | 
| 450 | 
            +
                {
         | 
| 451 | 
            +
                 "name": "stderr",
         | 
| 452 | 
            +
                 "output_type": "stream",
         | 
| 453 | 
            +
                 "text": [
         | 
| 454 | 
            +
                  "Device set to use cpu\n"
         | 
| 455 | 
            +
                 ]
         | 
| 456 | 
            +
                },
         | 
| 457 | 
            +
                {
         | 
| 458 | 
            +
                 "name": "stdout",
         | 
| 459 | 
            +
                 "output_type": "stream",
         | 
| 460 | 
            +
                 "text": [
         | 
| 461 | 
            +
                  "Question:  What year had the lowest stock prices?\n",
         | 
| 462 | 
            +
                  "Answer:  I'm sorry, I couldn't generate a clear answer. Please try rephrasing your question.\n",
         | 
| 463 | 
            +
                  "Confidence Score:  1.0\n",
         | 
| 464 | 
            +
                  "--------------------------------------------------\n"
         | 
| 465 | 
            +
                 ]
         | 
| 466 | 
            +
                },
         | 
| 467 | 
            +
                {
         | 
| 468 | 
            +
                 "name": "stderr",
         | 
| 469 | 
            +
                 "output_type": "stream",
         | 
| 470 | 
            +
                 "text": [
         | 
| 471 | 
            +
                  "Device set to use cpu\n"
         | 
| 472 | 
            +
                 ]
         | 
| 473 | 
            +
                },
         | 
| 474 | 
            +
                {
         | 
| 475 | 
            +
                 "name": "stdout",
         | 
| 476 | 
            +
                 "output_type": "stream",
         | 
| 477 | 
            +
                 "text": [
         | 
| 478 | 
            +
                  "Question:  How did the trading volume vary?\n",
         | 
| 479 | 
            +
                  "Answer:  The average trading volume was 23,244,919 shares\n",
         | 
| 480 | 
            +
                  "Confidence Score:  1.0\n",
         | 
| 481 | 
            +
                  "--------------------------------------------------\n",
         | 
| 482 | 
            +
                  "Question:  What is the capital of France?\n",
         | 
| 483 | 
            +
                  "Answer:  Guardrail Triggered: Your query does not appear to be related to financial data. Please ask a financial question.\n",
         | 
| 484 | 
            +
                  "Confidence Score:  0.0\n",
         | 
| 485 | 
            +
                  "--------------------------------------------------\n"
         | 
| 486 | 
            +
                 ]
         | 
| 487 | 
            +
                }
         | 
| 488 | 
            +
               ],
         | 
| 489 | 
            +
               "source": [
         | 
| 490 | 
            +
                "# =============================================================================\n",
         | 
| 491 | 
            +
                "# 6. Testing & Validation (Updated)\n",
         | 
| 492 | 
            +
                "# =============================================================================\n",
         | 
| 493 | 
            +
                "def print_test_results(query_text, top_k=3, alpha=0.5):\n",
         | 
| 494 | 
            +
                "    answer, confidence = generate_response(query_text, top_k, alpha)\n",
         | 
| 495 | 
            +
                "    print(\"Question: \", query_text)\n",
         | 
| 496 | 
            +
                "    print(\"Answer: \", answer)\n",
         | 
| 497 | 
            +
                "    print(\"Confidence Score: \", confidence)\n",
         | 
| 498 | 
            +
                "    print(\"-\" * 50)\n",
         | 
| 499 | 
            +
                "\n",
         | 
| 500 | 
            +
                "# Test 1: High-confidence financial query.\n",
         | 
| 501 | 
            +
                "query_high = \"What year had the lowest stock prices?\"\n",
         | 
| 502 | 
            +
                "print_test_results(query_high)\n",
         | 
| 503 | 
            +
                "\n",
         | 
| 504 | 
            +
                "# Test 2: Low-confidence financial query.\n",
         | 
| 505 | 
            +
                "query_low = \"How did the trading volume vary?\"\n",
         | 
| 506 | 
            +
                "print_test_results(query_low)\n",
         | 
| 507 | 
            +
                "\n",
         | 
| 508 | 
            +
                "# Test 3: Irrelevant query (should trigger guardrail).\n",
         | 
| 509 | 
            +
                "query_irrelevant = \"What is the capital of France?\"\n",
         | 
| 510 | 
            +
                "print_test_results(query_irrelevant)\n"
         | 
| 511 | 
            +
               ]
         | 
| 512 | 
            +
              }
         | 
| 513 | 
            +
             ],
         | 
| 514 | 
            +
             "metadata": {
         | 
| 515 | 
            +
              "kernelspec": {
         | 
| 516 | 
            +
               "display_name": "Python 3",
         | 
| 517 | 
            +
               "language": "python",
         | 
| 518 | 
            +
               "name": "python3"
         | 
| 519 | 
            +
              },
         | 
| 520 | 
            +
              "language_info": {
         | 
| 521 | 
            +
               "codemirror_mode": {
         | 
| 522 | 
            +
                "name": "ipython",
         | 
| 523 | 
            +
                "version": 3
         | 
| 524 | 
            +
               },
         | 
| 525 | 
            +
               "file_extension": ".py",
         | 
| 526 | 
            +
               "mimetype": "text/x-python",
         | 
| 527 | 
            +
               "name": "python",
         | 
| 528 | 
            +
               "nbconvert_exporter": "python",
         | 
| 529 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 530 | 
            +
               "version": "3.12.0"
         | 
| 531 | 
            +
              }
         | 
| 532 | 
            +
             },
         | 
| 533 | 
            +
             "nbformat": 4,
         | 
| 534 | 
            +
             "nbformat_minor": 2
         | 
| 535 | 
            +
            }
         |