import csv
import os
import re
from datetime import datetime

from pysolr import Solr
from sentence_transformers import SentenceTransformer, util

from get_keywords import get_keywords
"""
This function creates top 15 articles from Solr and saves them in a csv file
Input:
query: str
num_articles: int
keyword_type: str (openai, rake, or na)
Output: path to csv file
"""
def sanitize_query(text: str) -> str:
    """Sanitize the query text for Solr."""
    # Remove special characters that could break Solr syntax
    sanitized = re.sub(r'[[\]{}()*+?\\^|;:!]', ' ', text)
    # Normalize whitespace
    sanitized = ' '.join(sanitized.split())
    return sanitized
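
# Illustrative example (an assumption, not taken from the original code): with the pattern
# above, sanitize_query('malaria (vaccine) trials: 2023!') returns 'malaria vaccine trials 2023',
# since the parentheses, colon, and exclamation mark become spaces and whitespace is collapsed.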
def save_solr_articles_full(query: str, num_articles: int, keyword_type: str = "openai") -> str:
    try:
        keywords = get_keywords(query, keyword_type)
        if keyword_type == "na":
            keywords = query
        # Sanitize keywords before building the Solr query
        keywords = sanitize_query(keywords)
        return save_solr_articles(keywords, num_articles)
    except Exception:
        raise
"""
Removes spaces and newlines from text
Input: text: str
Output: text: str
"""
def remove_spaces_newlines(text: str) -> str:
text = text.replace('\n', ' ')
text = text.replace(' ', ' ')
return text
# Truncates long articles to the first 1500 words
def truncate_article(text: str) -> str:
    split = text.split()
    if len(split) > 1500:
        split = split[:1500]
        text = ' '.join(split)
    return text
"""
Searches Solr for articles based on keywords and saves them in a csv file
Input:
keywords: str
num_articles: int
Output: path to csv file
Minor details:
Removes duplicate articles to start with.
Articles with dead urls are removed since those articles are often wierd.
Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes.
If one of title, uuid, cleaned_content, url are missing the article is skipped.
"""
def save_solr_articles(keywords: str, num_articles=15) -> str:
    """Save top articles from Solr search to CSV."""
    try:
        solr_key = os.getenv("SOLR_KEY")
        SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
        solr = Solr(SOLR_ARTICLES_URL, verify=False)

        # No duplicates and must be in English
        fq = ['-dups:0', 'is_english:(true)']

        # Construct the query (keywords are already sanitized by the caller)
        query = f'text:({keywords}) AND dead_url:(false)'
        print(f"Executing Solr query: {query}")

        # Use a boost function that combines the relevance score with recency,
        # giving higher weight to more recent articles while still considering relevance
        boost_query = "sum(score,product(0.3,recip(ms(NOW,year_month_day),3.16e-11,1,1)))"

        try:
            outputs = solr.search(
                query,
                fq=fq,
                sort=boost_query + " desc",
                rows=num_articles * 2,
                fl='*,score'  # Include score in results
            )
        except Exception as e:
            print(f"Solr query failed: {str(e)}")
            raise

        article_count = 0
        save_path = os.path.join("data", "articles.csv")
        if not os.path.exists(os.path.dirname(save_path)):
            os.makedirs(os.path.dirname(save_path))

        with open(save_path, 'w', newline='') as csvfile:
            fieldnames = ['title', 'uuid', 'content', 'url', 'domain', 'published_date']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
            writer.writeheader()

            # Track the first five words of saved titles to skip near-duplicate articles
            title_five_words = set()
            for d in outputs.docs:
                if article_count == num_articles:
                    break

                # Skip if required fields are missing
                if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
                    continue

                title_cleaned = remove_spaces_newlines(d['title'])

                # Skip duplicate titles based on the first five words
                split = title_cleaned.split()
                if len(split) >= 5:
                    five_words = ' '.join(split[:5])
                    if five_words in title_five_words:
                        continue
                    title_five_words.add(five_words)

                article_count += 1

                cleaned_content = remove_spaces_newlines(d['cleaned_content'])
                cleaned_content = truncate_article(cleaned_content)

                domain = d.get('domain', "Not Specified")
                raw_date = d.get('year_month_day', "Unknown Date")

                # Format the date as MM/DD/YYYY
                if raw_date != "Unknown Date":
                    try:
                        publication_date = datetime.strptime(raw_date, "%Y-%m-%d").strftime("%m/%d/%Y")
                    except ValueError:
                        publication_date = "Invalid Date"
                else:
                    publication_date = raw_date

                writer.writerow({
                    'title': title_cleaned,
                    'uuid': d['uuid'],
                    'content': cleaned_content,
                    'url': d['url'],
                    'domain': domain,
                    'published_date': publication_date
                })
                print(f"Article saved: {title_cleaned}, {d['uuid']}, {domain}, {publication_date}")

        return save_path

    except Exception as e:
        print(f"Error in save_solr_articles: {str(e)}")
        raise
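
"""
Ranks precomputed article embeddings against the query with a bi-encoder and saves
the closest matches to a csv file.
Input:
    query: str
    article_embeddings: precomputed embeddings for the candidate articles
    titles, contents, uuids, urls: lists aligned with article_embeddings
    num_articles: int
Output: path to csv file
"""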
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # Retrieve the most similar articles to the query embedding
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=num_articles)
    hits = hits[0]

    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        for i in range(min(num_articles, len(corpus_ids))):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})

    return save_path
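
# Minimal usage sketch (illustrative; not part of the original module). It assumes the
# SOLR_KEY environment variable is set, that the data/ directory is writable, and that
# get_keywords supports the "openai" keyword type; the example query is made up.
if __name__ == "__main__":
    csv_path = save_solr_articles_full("dengue outbreak surveillance", num_articles=15, keyword_type="openai")
    print(f"Articles written to {csv_path}")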