import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Page-level configuration. Streamlit requires this to be the first Streamlit
# command executed by the script, so it stays above every other `st.*` call.
st.set_page_config(
    page_title="📰 News Classifier & Q&A App",
    layout="wide",
)
# ----------------- Model Loader -----------------
@st.cache_resource
def load_text_classifier():
    """Build the news text-classification pipeline.

    Loads the tokenizer and the sequence-classification model from the
    ``MihanTilk/News_Classifier`` checkpoint and wraps both in a
    ``text-classification`` pipeline. The function is decorated with
    ``st.cache_resource``.

    Returns:
        The ``transformers`` pipeline object for text classification.
    """
    checkpoint = "MihanTilk/News_Classifier"
    # Dict literals evaluate in order, so the tokenizer is loaded first and
    # the model second.
    components = {
        "tokenizer": AutoTokenizer.from_pretrained(checkpoint),
        "model": AutoModelForSequenceClassification.from_pretrained(checkpoint),
    }
    return pipeline("text-classification", **components)
# Load Classifier & QA pipeline
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Uses the ``deepset/roberta-large-squad2`` checkpoint for both the model
    and the tokenizer. The loader is wrapped in ``st.cache_resource`` so the
    pipeline is not rebuilt each time Streamlit re-executes the script
    (which happens on every widget interaction).

    Returns:
        The ``transformers`` pipeline object for question answering.
    """
    qa_checkpoint = "deepset/roberta-large-squad2"
    return pipeline(
        "question-answering",
        model=qa_checkpoint,
        tokenizer=qa_checkpoint,
    )


classifier = load_text_classifier()
qa_pipeline = load_qa_pipeline()
# ----------------- CSS Styling -----------------
# NOTE(review): the style sheet passed here is currently just a newline; it
# looks like the original <style> markup was lost — confirm and restore.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)

# ----------------- App Title -----------------
st.title("📰 News Classification & Q&A App")
# A string that spans several lines must be triple-quoted; a single pair of
# double quotes across lines is a SyntaxError. The text itself is unchanged.
# NOTE(review): any HTML wrapper that once surrounded this sentence appears
# to be missing — confirm against the intended design.
st.markdown(
    """
Upload a CSV to classify news headlines and ask questions!
""",
    unsafe_allow_html=True,
)
# ----------------- Upload CSV -----------------
# Palette shared by the pie chart and the bar chart.
CHART_COLORS = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0']
# Category buttons are laid out in rows of this many columns.
BUTTONS_PER_ROW = 5


@st.cache_data(show_spinner=False)
def classify_texts(texts):
    """Classify every text once and return ``(label, score)`` pairs.

    Streamlit re-executes the whole script on each widget interaction, so the
    result is cached on the list of texts; clicking a button then no longer
    re-runs the model over the entire file. ``truncation=True`` is forwarded
    to the tokenizer so that over-long articles are truncated to the model's
    maximum input length.

    Args:
        texts: List of strings to classify.

    Returns:
        A list of ``(label, score)`` tuples, one per input text, in order.
    """
    predictions = classifier(texts, truncation=True)
    return [(pred['label'], pred['score']) for pred in predictions]


st.subheader("📂 Upload a CSV File")
uploaded_file = st.file_uploader("Choose a CSV file...", type=["csv"])

if uploaded_file:
    # Read and preprocess
    df = pd.read_csv(uploaded_file, encoding='utf-8')
    if "content" not in df.columns:
        st.error("❌ The uploaded CSV must contain a 'content' column.")
        st.stop()
    if df.empty:
        st.warning("⚠️ The uploaded CSV contains no rows to classify.")
        st.stop()

    # Empty cells are read as NaN (a float); normalise to strings once so
    # joining, slicing and classification never hit a non-string value.
    texts = df['content'].fillna("").astype(str)

    # Preprocess text
    df['cleaned_text'] = texts.str.lower().str.strip()
    st.write("📊 Preview of Uploaded Data:", df.head())

    # ----------------- Classification -----------------
    with st.spinner("🔍 Classifying news articles..."):
        predictions = classify_texts(df['cleaned_text'].tolist())
    df['class'] = [label for label, _ in predictions]
    # Keep the score of the same prediction that produced the label, instead
    # of re-running the model per article when it is displayed.
    df['confidence'] = [score for _, score in predictions]
    st.success("✅ Classification Complete!")
    st.write("🔎 Classified Results:", df[['content', 'class']].head())

    # ----------------- Download -----------------
    st.subheader("📥 Download Results")
    output_df = df[['content', 'class']]
    # to_csv() without a path returns a str; the BOM is added when encoding.
    csv_output = output_df.to_csv(index=False).encode('utf-8-sig')
    st.download_button(
        "Download Output CSV",
        data=csv_output,
        file_name="output.csv",
        mime="text/csv",
    )

    # ----------------- Q&A Section -----------------
    st.subheader("💬 Ask a Question")
    question = st.text_input("🔍 What do you want to know about the content?")
    if st.button("Get Answer"):
        context = " ".join(texts.tolist())
        if not question.strip():
            # Avoid handing an empty question to the QA pipeline.
            st.warning("Please enter a question first.")
        elif not context.strip():
            st.warning("The uploaded file has no text to search for an answer.")
        else:
            with st.spinner("Answering..."):
                result = qa_pipeline(question=question, context=context)
            st.success(f"📝 Answer: {result['answer']}")

    # ----------------- Visualization Section -----------------
    st.subheader("📊 Data Visualizations")
    class_counts = df['class'].value_counts()

    # Create two main columns (60/40 split)
    main_col1, main_col2 = st.columns([3, 2])

    with main_col1:
        # ----------------- Topic Distribution -----------------
        st.markdown("*Topic Distribution*")
        # Create sub-columns for the charts
        chart_col1, chart_col2 = st.columns(2)

        with chart_col1:
            # Compact Pie Chart
            fig1, ax1 = plt.subplots(figsize=(4, 4))
            class_counts.plot.pie(
                autopct='%1.1f%%',
                startangle=90,
                ax=ax1,
                colors=CHART_COLORS,
                wedgeprops={'linewidth': 0.5, 'edgecolor': 'white'},
            )
            ax1.set_ylabel('')
            st.pyplot(fig1, use_container_width=True)
            # pyplot keeps every figure registered until it is closed; close
            # it so reruns do not accumulate figures in memory.
            plt.close(fig1)

        with chart_col2:
            # Compact Bar Chart
            fig2, ax2 = plt.subplots(figsize=(4, 4))
            class_counts.plot.bar(
                color=CHART_COLORS,
                ax=ax2,
                width=0.7,
            )
            ax2.set_xlabel('')
            ax2.set_ylabel('Count')
            # Rotate the labels on this axes explicitly rather than on
            # whatever pyplot considers the "current" axes.
            plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
            st.pyplot(fig2, use_container_width=True)
            plt.close(fig2)

    with main_col2:
        # ----------------- Compact Word Cloud -----------------
        st.markdown("*Word Cloud*")
        text = " ".join(texts.tolist())
        try:
            wordcloud = WordCloud(
                width=300,
                height=200,
                background_color="white",
                collocations=False,
                max_words=100,
            ).generate(text)
        except ValueError:
            # WordCloud raises ValueError when the text yields no words.
            st.info("Not enough text to build a word cloud.")
        else:
            fig3, ax3 = plt.subplots(figsize=(4, 3))
            ax3.imshow(wordcloud, interpolation="bilinear")
            ax3.axis("off")
            st.pyplot(fig3, use_container_width=True)
            plt.close(fig3)

    # ----------------- Detailed Stats (below) -----------------
    with st.expander("📈 Detailed Statistics", expanded=False):
        stats_col1, stats_col2 = st.columns(2)

        with stats_col1:
            st.write("*Category Breakdown:*")
            stats_df = class_counts.rename_axis('Category').reset_index(name='Count')
            stats_df['Percentage'] = (
                stats_df['Count'] / stats_df['Count'].sum() * 100
            ).round(1)
            st.dataframe(stats_df, height=200)

        with stats_col2:
            if 'date' in df.columns:
                st.write("*Monthly Trends*")
                try:
                    # Unparseable dates become NaT and drop out of the
                    # grouping instead of aborting the whole chart.
                    dates = pd.to_datetime(df['date'], errors='coerce')
                    # Month start timestamps chart more reliably than Period
                    # values.
                    months = dates.dt.to_period('M').dt.to_timestamp()
                except (ValueError, TypeError, AttributeError):
                    months = None
                if months is None or months.isna().all():
                    st.warning("Date parsing failed")
                else:
                    trends = df.groupby([months, 'class']).size().unstack()
                    st.line_chart(trends)

    # ----------------- News Category Explorer -----------------
    st.subheader("🔍 Explore News by Category")
    # Get unique categories
    categories = df['class'].unique()

    # Lay the buttons out in rows of BUTTONS_PER_ROW so that more than five
    # categories do not index past the available columns.
    for row_start in range(0, len(categories), BUTTONS_PER_ROW):
        row_categories = categories[row_start:row_start + BUTTONS_PER_ROW]
        cols = st.columns(BUTTONS_PER_ROW)
        for col, category in zip(cols, row_categories):
            with col:
                if st.button(category, key=f"btn_{category}"):
                    # Create pop-up window
                    with st.popover(f"📰 Articles in {category}", use_container_width=True):
                        st.markdown(f"### {category} Articles")
                        in_category = df['class'] == category
                        # Display articles with expandable content
                        for idx, content, confidence in zip(
                            df.index[in_category],
                            texts[in_category],
                            df.loc[in_category, 'confidence'],
                        ):
                            with st.expander(f"Article {idx + 1}: {content[:50]}...", expanded=False):
                                st.write(content)
                                st.caption(f"Classification confidence: {confidence:.2f}")
# ----------------- Footer -----------------
st.markdown("---")
# A string that spans several lines must be triple-quoted; a single pair of
# double quotes across lines is a SyntaxError. The text itself is unchanged.
# NOTE(review): "Built with using" reads as if an emoji or HTML fragment was
# lost between "with" and "using" — confirm the intended wording.
st.markdown(
    """🚀 Built with using Streamlit & Hugging Face
""",
    unsafe_allow_html=True,
)