"""
ranker.py - Module for ranking news articles based on relevance to query
This module uses SentenceTransformers to create embeddings and rank articles
based on their semantic similarity to the user query.
"""
import torch
from sentence_transformers import SentenceTransformer
# Global model cache dictionary to store loaded models
_MODEL_CACHE = {}
# Check if CUDA is available
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
class ArticleRanker:
"""A class to rank articles based on their relevance to a query using embeddings."""
def __init__(self, model_name=None):
"""
Initialize the ranker with a SentenceTransformer model.
Args:
            model_name (str, optional): Name of the SentenceTransformer model to use.
                Defaults to 'intfloat/multilingual-e5-base'.
"""
# Get model name or use default
self.model_name = model_name or 'intfloat/multilingual-e5-base'
try:
# Check if model is already in cache
if self.model_name in _MODEL_CACHE:
self.model = _MODEL_CACHE[self.model_name]
print(f"Using cached model: {self.model_name} on {DEVICE}")
else:
# Load model and add to cache, automatically using the appropriate device
self.model = SentenceTransformer(self.model_name, device=str(DEVICE))
_MODEL_CACHE[self.model_name] = self.model
print(f"Loaded model: {self.model_name} on {DEVICE}")
except Exception as e:
print(f"⚠️ Error loading model: {e}")
print("Make sure you have installed sentence-transformers:")
print("pip install sentence-transformers")
raise
def create_embeddings(self, query, article_texts):
"""
Create embeddings for the query and article texts.
Args:
query (str): The user's query/news statement
article_texts (list): List of article text strings
Returns:
tuple: (query_embedding, article_embeddings)
"""
# Use convert_to_tensor=True to get tensors that will be on the right device (CUDA or CPU)
query_embedding = self.model.encode(query, convert_to_tensor=True, device=DEVICE)
article_embeddings = self.model.encode(article_texts, convert_to_tensor=True, device=DEVICE)
return query_embedding, article_embeddings
def calculate_similarities(self, query_embedding, article_embeddings):
"""
Calculate cosine similarity between query and article embeddings.
Args:
query_embedding: Tensor of query embedding
article_embeddings: Tensor of article embeddings
Returns:
list: List of similarity scores
"""
similarities = torch.nn.functional.cosine_similarity(
query_embedding.unsqueeze(0), article_embeddings
)
return similarities.tolist()
def get_top_articles(self, similarities, articles, top_n, min_threshold):
"""
Get the top N articles based on similarity scores.
Args:
similarities (list): List of similarity scores
            articles (list): List of article dictionaries (accepted for interface
                consistency; not used when selecting indices)
top_n (int): Number of top results to return
min_threshold (float): Minimum similarity threshold
Returns:
list: List of indices of top articles
"""
        # Sort (index, similarity) pairs by similarity in descending order
        indexed_similarities = sorted(enumerate(similarities), key=lambda x: x[1], reverse=True)
        top_indices = []
        # Collect up to top_n indices whose score clears the threshold; because the
        # list is sorted, we can stop as soon as a score drops below the threshold
        # or enough results have been collected.
        for idx, score in indexed_similarities:
            if score < min_threshold or len(top_indices) >= top_n:
                break
            top_indices.append(idx)
        return top_indices
def format_results(self, top_indices, similarities, articles):
"""
Format the results with rank and similarity score.
Args:
top_indices (list): List of indices of top articles
similarities (list): List of similarity scores
articles (list): List of article dictionaries
Returns:
list: List of dictionaries with ranked articles and their similarity scores
"""
result = []
for i, idx in enumerate(top_indices, 1):
similarity_score = similarities[idx]
article = articles[idx]
# Extract source name (handling both object and string formats)
if isinstance(article['source'], dict):
source_name = article['source'].get('name', '')
else:
source_name = article['source']
# Use publishedAt if available, otherwise use published_at
published_at = article.get('publishedAt', article.get('published_at', ''))
result.append({
'rank': i,
'title': article['title'],
'source': source_name,
'url': article['url'],
'similarity_score': similarity_score,
'published_at': published_at
})
return result
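

# Minimal usage sketch (a hedged example, not part of the production pipeline):
# the article dicts below are hypothetical samples mirroring the fields this
# module reads ('title', 'source', 'url', 'publishedAt'); real inputs come from
# the rest of the application and may carry additional fields.
if __name__ == "__main__":
    sample_articles = [
        {
            'title': 'Central bank raises interest rates',
            'source': {'name': 'Example Wire'},
            'url': 'https://example.com/rates',
            'publishedAt': '2024-01-15T08:00:00Z',
        },
        {
            'title': 'Local team wins championship',
            'source': 'Example Sports',
            'url': 'https://example.com/sports',
            'publishedAt': '2024-01-14T20:00:00Z',
        },
    ]
    query = "Did the central bank raise interest rates?"
    texts = [a['title'] for a in sample_articles]

    ranker = ArticleRanker()
    query_emb, article_embs = ranker.create_embeddings(query, texts)
    sims = ranker.calculate_similarities(query_emb, article_embs)
    top = ranker.get_top_articles(sims, sample_articles, top_n=2, min_threshold=0.3)
    for row in ranker.format_results(top, sims, sample_articles):
        print(row)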