Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,14 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from transformers import pipeline
|
4 |
-
from
|
5 |
-
|
6 |
-
import string
|
7 |
from nltk.tokenize import word_tokenize
|
8 |
from nltk.corpus import stopwords
|
9 |
from nltk.stem import WordNetLemmatizer
|
10 |
-
import nltk
|
11 |
|
12 |
# Download NLTK resources (run this once if not already downloaded)
|
13 |
nltk.download('punkt')
|
14 |
-
nltk.download('punkt_tab')
|
15 |
nltk.download('stopwords')
|
16 |
nltk.download('wordnet')
|
17 |
|
@@ -28,7 +25,6 @@ st.markdown("""
|
|
28 |
margin: 0;
|
29 |
padding: 0;
|
30 |
}
|
31 |
-
|
32 |
/* Header Styling */
|
33 |
.custom-header {
|
34 |
background: linear-gradient(to right, #1f4068, #1b1b2f);
|
@@ -40,7 +36,6 @@ st.markdown("""
|
|
40 |
font-weight: bold;
|
41 |
box-shadow: 0px 4px 15px rgba(0, 217, 255, 0.3);
|
42 |
}
|
43 |
-
|
44 |
/* Buttons */
|
45 |
.stButton>button {
|
46 |
background: linear-gradient(45deg, #0072ff, #00c6ff);
|
@@ -54,7 +49,6 @@ st.markdown("""
|
|
54 |
transform: scale(1.05);
|
55 |
box-shadow: 0px 4px 10px rgba(0, 255, 255, 0.5);
|
56 |
}
|
57 |
-
|
58 |
/* Text Input */
|
59 |
.stTextInput>div>div>input {
|
60 |
background-color: rgba(255, 255, 255, 0.1);
|
@@ -62,14 +56,12 @@ st.markdown("""
|
|
62 |
padding: 12px;
|
63 |
font-size: 18px;
|
64 |
}
|
65 |
-
|
66 |
/* Dataframe Container */
|
67 |
.dataframe-container {
|
68 |
background: rgba(255, 255, 255, 0.1);
|
69 |
padding: 15px;
|
70 |
border-radius: 12px;
|
71 |
}
|
72 |
-
|
73 |
/* Answer Display Box - Larger */
|
74 |
.answer-box {
|
75 |
background: rgba(0, 217, 255, 0.15);
|
@@ -86,7 +78,6 @@ st.markdown("""
|
|
86 |
justify-content: center;
|
87 |
transition: all 0.3s ease;
|
88 |
}
|
89 |
-
|
90 |
/* CSV Display Box */
|
91 |
.csv-box {
|
92 |
background: rgba(255, 255, 255, 0.1);
|
@@ -105,8 +96,8 @@ st.markdown("<div class='custom-header'> 🧩 AI-Powered News Analyzer</div>", u
|
|
105 |
classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
|
106 |
qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
|
107 |
|
108 |
-
# Initialize
|
109 |
-
|
110 |
|
111 |
# Define preprocessing functions for classification
|
112 |
def preprocess_text(text):
|
@@ -186,7 +177,6 @@ with col1:
|
|
186 |
st.markdown("<div class='csv-box'><h4>📜 CSV/Excel Preview</h4></div>", unsafe_allow_html=True)
|
187 |
st.dataframe(df_for_display, use_container_width=True)
|
188 |
|
189 |
-
|
190 |
# Right Section - Q&A Interface
|
191 |
with col2:
|
192 |
st.subheader("🤖 AI Assistant")
|
@@ -206,15 +196,14 @@ with col2:
|
|
206 |
if 'content' in df.columns:
|
207 |
context = df['content'].dropna().tolist() # Use the content column as context
|
208 |
|
209 |
-
#
|
210 |
-
|
211 |
-
question_embedding = sentence_model.encode([user_question])
|
212 |
|
213 |
-
#
|
214 |
-
|
215 |
-
top_indices = similarities[0].argsort()[-5:][::-1] # Get top 5 similar rows
|
216 |
|
217 |
-
#
|
|
|
218 |
top_context = "\n".join([context[i] for i in top_indices])
|
219 |
|
220 |
# Get answer from Hugging Face model using top context
|
@@ -225,5 +214,4 @@ with col2:
|
|
225 |
else:
|
226 |
answer = "⚠️ Please upload a valid file first!"
|
227 |
|
228 |
-
answer_placeholder.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)
|
229 |
-
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from transformers import pipeline
|
4 |
+
from sentence_transformers import CrossEncoder
|
5 |
+
import nltk
|
|
|
6 |
from nltk.tokenize import word_tokenize
|
7 |
from nltk.corpus import stopwords
|
8 |
from nltk.stem import WordNetLemmatizer
|
|
|
9 |
|
10 |
# Download NLTK resources (run this once if not already downloaded)
|
11 |
nltk.download('punkt')
|
|
|
12 |
nltk.download('stopwords')
|
13 |
nltk.download('wordnet')
|
14 |
|
|
|
25 |
margin: 0;
|
26 |
padding: 0;
|
27 |
}
|
|
|
28 |
/* Header Styling */
|
29 |
.custom-header {
|
30 |
background: linear-gradient(to right, #1f4068, #1b1b2f);
|
|
|
36 |
font-weight: bold;
|
37 |
box-shadow: 0px 4px 15px rgba(0, 217, 255, 0.3);
|
38 |
}
|
|
|
39 |
/* Buttons */
|
40 |
.stButton>button {
|
41 |
background: linear-gradient(45deg, #0072ff, #00c6ff);
|
|
|
49 |
transform: scale(1.05);
|
50 |
box-shadow: 0px 4px 10px rgba(0, 255, 255, 0.5);
|
51 |
}
|
|
|
52 |
/* Text Input */
|
53 |
.stTextInput>div>div>input {
|
54 |
background-color: rgba(255, 255, 255, 0.1);
|
|
|
56 |
padding: 12px;
|
57 |
font-size: 18px;
|
58 |
}
|
|
|
59 |
/* Dataframe Container */
|
60 |
.dataframe-container {
|
61 |
background: rgba(255, 255, 255, 0.1);
|
62 |
padding: 15px;
|
63 |
border-radius: 12px;
|
64 |
}
|
|
|
65 |
/* Answer Display Box - Larger */
|
66 |
.answer-box {
|
67 |
background: rgba(0, 217, 255, 0.15);
|
|
|
78 |
justify-content: center;
|
79 |
transition: all 0.3s ease;
|
80 |
}
|
|
|
81 |
/* CSV Display Box */
|
82 |
.csv-box {
|
83 |
background: rgba(255, 255, 255, 0.1);
|
|
|
96 |
classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
|
97 |
qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
|
98 |
|
99 |
+
# Initialize Cross-Encoder for QA relevance scoring
|
100 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # Pre-trained Cross-Encoder model
|
101 |
|
102 |
# Define preprocessing functions for classification
|
103 |
def preprocess_text(text):
|
|
|
177 |
st.markdown("<div class='csv-box'><h4>📜 CSV/Excel Preview</h4></div>", unsafe_allow_html=True)
|
178 |
st.dataframe(df_for_display, use_container_width=True)
|
179 |
|
|
|
180 |
# Right Section - Q&A Interface
|
181 |
with col2:
|
182 |
st.subheader("🤖 AI Assistant")
|
|
|
196 |
if 'content' in df.columns:
|
197 |
context = df['content'].dropna().tolist() # Use the content column as context
|
198 |
|
199 |
+
# Prepare pairs of (question, context)
|
200 |
+
pairs = [(user_question, c) for c in context]
|
|
|
201 |
|
202 |
+
# Score each pair using the Cross-Encoder
|
203 |
+
scores = cross_encoder.predict(pairs)
|
|
|
204 |
|
205 |
+
# Get top matches based on scores
|
206 |
+
top_indices = scores.argsort()[-5:][::-1] # Get indices of top 5 matches
|
207 |
top_context = "\n".join([context[i] for i in top_indices])
|
208 |
|
209 |
# Get answer from Hugging Face model using top context
|
|
|
214 |
else:
|
215 |
answer = "⚠️ Please upload a valid file first!"
|
216 |
|
217 |
+
answer_placeholder.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)
|
|