Sandini committed
Commit ec85722 · verified · Parent: d141fe1

Update app.py

Files changed (1): app.py (+11 -23)

app.py CHANGED
@@ -1,17 +1,14 @@
 import streamlit as st
 import pandas as pd
 from transformers import pipeline
-from sklearn.metrics.pairwise import cosine_similarity
-from sentence_transformers import SentenceTransformer
-import string
+from sentence_transformers import CrossEncoder
+import nltk
 from nltk.tokenize import word_tokenize
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
-import nltk
 
 # Download NLTK resources (run this once if not already downloaded)
 nltk.download('punkt')
-nltk.download('punkt_tab')
 nltk.download('stopwords')
 nltk.download('wordnet')
 
@@ -28,7 +25,6 @@ st.markdown("""
         margin: 0;
         padding: 0;
     }
-
     /* Header Styling */
     .custom-header {
         background: linear-gradient(to right, #1f4068, #1b1b2f);
@@ -40,7 +36,6 @@ st.markdown("""
         font-weight: bold;
         box-shadow: 0px 4px 15px rgba(0, 217, 255, 0.3);
     }
-
     /* Buttons */
     .stButton>button {
         background: linear-gradient(45deg, #0072ff, #00c6ff);
@@ -54,7 +49,6 @@ st.markdown("""
         transform: scale(1.05);
         box-shadow: 0px 4px 10px rgba(0, 255, 255, 0.5);
     }
-
     /* Text Input */
     .stTextInput>div>div>input {
         background-color: rgba(255, 255, 255, 0.1);
@@ -62,14 +56,12 @@ st.markdown("""
         padding: 12px;
         font-size: 18px;
     }
-
     /* Dataframe Container */
     .dataframe-container {
         background: rgba(255, 255, 255, 0.1);
         padding: 15px;
         border-radius: 12px;
     }
-
     /* Answer Display Box - Larger */
     .answer-box {
         background: rgba(0, 217, 255, 0.15);
@@ -86,7 +78,6 @@ st.markdown("""
         justify-content: center;
         transition: all 0.3s ease;
     }
-
     /* CSV Display Box */
     .csv-box {
         background: rgba(255, 255, 255, 0.1);
@@ -105,8 +96,8 @@ st.markdown("<div class='custom-header'> 🧩 AI-Powered News Analyzer</div>", u
 classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
 qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
 
-# Initialize sentence transformer model for QA similarity
-sentence_model = SentenceTransformer('all-MiniLM-L6-v2') # Pre-trained sentence model
+# Initialize Cross-Encoder for QA relevance scoring
+cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # Pre-trained Cross-Encoder model
 
 # Define preprocessing functions for classification
 def preprocess_text(text):
@@ -186,7 +177,6 @@ with col1:
     st.markdown("<div class='csv-box'><h4>📜 CSV/Excel Preview</h4></div>", unsafe_allow_html=True)
     st.dataframe(df_for_display, use_container_width=True)
 
-
 # Right Section - Q&A Interface
 with col2:
     st.subheader("🤖 AI Assistant")
@@ -206,15 +196,14 @@ with col2:
         if 'content' in df.columns:
             context = df['content'].dropna().tolist() # Use the content column as context
 
-            # Generate embeddings for the context and the question
-            context_embeddings = sentence_model.encode(context)
-            question_embedding = sentence_model.encode([user_question])
+            # Prepare pairs of (question, context)
+            pairs = [(user_question, c) for c in context]
 
-            # Calculate cosine similarity
-            similarities = cosine_similarity(question_embedding, context_embeddings)
-            top_indices = similarities[0].argsort()[-5:][::-1] # Get top 5 similar rows
+            # Score each pair using the Cross-Encoder
+            scores = cross_encoder.predict(pairs)
 
-            # Prepare the top 5 similar context rows
+            # Get top matches based on scores
+            top_indices = scores.argsort()[-5:][::-1] # Get indices of top 5 matches
             top_context = "\n".join([context[i] for i in top_indices])
 
             # Get answer from Hugging Face model using top context
@@ -225,5 +214,4 @@ with col2:
         else:
             answer = "⚠️ Please upload a valid file first!"
 
-        answer_placeholder.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)
-
+        answer_placeholder.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)
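For reference, a minimal sketch of the retrieval flow this commit switches to, runnable on its own: instead of embedding the question and rows separately and ranking by cosine similarity, a Cross-Encoder scores each (question, row) pair directly, and the top-scoring rows are concatenated into the QA context. The model names are taken from the diff; the question, the sample rows, and the k variable are invented for illustration (the app hard-codes the top 5).

# Sketch of the Cross-Encoder re-ranking introduced in this commit (sample data is hypothetical).
from sentence_transformers import CrossEncoder
from transformers import pipeline

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
qa_pipeline = pipeline(
    "question-answering",
    model="distilbert/distilbert-base-cased-distilled-squad",
)

question = "Who won the championship?"  # hypothetical user question
context_rows = [                        # hypothetical values from the 'content' column
    "The city council approved a new budget on Monday.",
    "United won the championship after a 2-1 final.",
    "Heavy rain is expected across the region this weekend.",
]

# Score every (question, row) pair with the Cross-Encoder; higher score = more relevant.
pairs = [(question, row) for row in context_rows]
scores = cross_encoder.predict(pairs)

# Keep the top-k rows in descending score order and join them into one context string.
k = min(5, len(context_rows))
top_indices = scores.argsort()[-k:][::-1]
top_context = "\n".join(context_rows[i] for i in top_indices)

# Extract the answer from the concatenated top rows with the QA pipeline.
result = qa_pipeline(question=question, context=top_context)
print(result["answer"], result["score"])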