Sandini commited on
Commit
6ce73bd
·
verified ·
1 Parent(s): 1651757

Use cross encoder for QA

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
  from transformers import pipeline
4
- from sklearn.metrics.pairwise import cosine_similarity
5
  from sentence_transformers import SentenceTransformer
6
  import string
7
  from nltk.tokenize import word_tokenize
@@ -99,8 +99,8 @@ st.markdown("<div class='custom-header'> 🧩 AI-Powered News Analyzer</div>", u
99
  classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
100
  qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
101
 
102
- # Initialize sentence transformer model for QA similarity
103
- sentence_model = SentenceTransformer('all-MiniLM-L6-v2') # Pre-trained sentence model
104
 
105
  # Define preprocessing functions for classification
106
  def preprocess_text(text):
@@ -200,15 +200,14 @@ with col2:
200
  if 'content' in df.columns:
201
  context = df['content'].dropna().tolist() # Use the content column as context
202
 
203
- # Generate embeddings for the context and the question
204
- context_embeddings = sentence_model.encode(context)
205
- question_embedding = sentence_model.encode([user_question])
206
 
207
- # Calculate cosine similarity
208
- similarities = cosine_similarity(question_embedding, context_embeddings)
209
- top_indices = similarities[0].argsort()[-5:][::-1] # Get top 5 similar rows
210
 
211
- # Prepare the top 5 similar context rows
 
212
  top_context = "\n".join([context[i] for i in top_indices])
213
 
214
  # Get answer from Hugging Face model using top context
 
1
  import streamlit as st
2
  import pandas as pd
3
  from transformers import pipeline
4
+ from sentence_transformers import CrossEncoder
5
  from sentence_transformers import SentenceTransformer
6
  import string
7
  from nltk.tokenize import word_tokenize
 
99
  classifier = pipeline("text-classification", model="Sandini/news-classifier") # Classification pipeline
100
  qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad") # QA pipeline
101
 
102
+ # Initialize Cross-Encoder for QA relevance scoring
103
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # Pre-trained Cross-Encoder model
104
 
105
  # Define preprocessing functions for classification
106
  def preprocess_text(text):
 
200
  if 'content' in df.columns:
201
  context = df['content'].dropna().tolist() # Use the content column as context
202
 
203
+ # Prepare pairs of (question, context)
204
+ pairs = [(user_question, c) for c in context]
 
205
 
206
+ # Score each pair using the Cross-Encoder
207
+ scores = cross_encoder.predict(pairs)
 
208
 
209
+ # Get top matches based on scores
210
+ top_indices = scores.argsort()[-5:][::-1] # Get indices of top 5 matches
211
  top_context = "\n".join([context[i] for i in top_indices])
212
 
213
  # Get answer from Hugging Face model using top context