Spaces:
Running
Running
Use cross encoder for QA
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from transformers import pipeline
|
4 |
-
from
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
import string
|
7 |
from nltk.tokenize import word_tokenize
|
@@ -99,8 +99,8 @@ st.markdown("<div class='custom-header'> 🧩 AI-Powered News Analyzer</div>", u
|
|
99 |
classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
|
100 |
qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
|
101 |
|
102 |
-
# Initialize
|
103 |
-
|
104 |
|
105 |
# Define preprocessing functions for classification
|
106 |
def preprocess_text(text):
|
@@ -200,15 +200,14 @@ with col2:
|
|
200 |
if 'content' in df.columns:
|
201 |
context = df['content'].dropna().tolist() # Use the content column as context
|
202 |
|
203 |
-
#
|
204 |
-
|
205 |
-
question_embedding = sentence_model.encode([user_question])
|
206 |
|
207 |
-
#
|
208 |
-
|
209 |
-
top_indices = similarities[0].argsort()[-5:][::-1] # Get top 5 similar rows
|
210 |
|
211 |
-
#
|
|
|
212 |
top_context = "\n".join([context[i] for i in top_indices])
|
213 |
|
214 |
# Get answer from Hugging Face model using top context
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from transformers import pipeline
|
4 |
+
from sentence_transformers import CrossEncoder
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
import string
|
7 |
from nltk.tokenize import word_tokenize
|
|
|
99 |
classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
|
100 |
qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
|
101 |
|
102 |
+
# Initialize Cross-Encoder for QA relevance scoring
|
103 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # Pre-trained Cross-Encoder model
|
104 |
|
105 |
# Define preprocessing functions for classification
|
106 |
def preprocess_text(text):
|
|
|
200 |
if 'content' in df.columns:
|
201 |
context = df['content'].dropna().tolist() # Use the content column as context
|
202 |
|
203 |
+
# Pair the user question with every candidate context row
|
204 |
+
pairs = [(user_question, c) for c in context]
|
|
|
205 |
|
206 |
+
# Score each pair using the Cross-Encoder
|
207 |
+
scores = cross_encoder.predict(pairs)
|
|
|
208 |
|
209 |
+
# Get top matches based on scores
|
210 |
+
top_indices = scores.argsort()[-5:][::-1] # Indices of the (up to) 5 highest-scoring contexts, best match first
|
211 |
top_context = "\n".join([context[i] for i in top_indices])
|
212 |
|
213 |
# Get answer from Hugging Face model using top context
|