Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from wordcloud import WordCloud | |
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification | |
# NOTE: the page configuration call has to be the first Streamlit command
# issued by this script, so keep it above every other ``st.*`` call.
st.set_page_config(
    layout="wide",
    page_title="π° News Classifier & Q&A App",
)
# ----------------- Model Loader -----------------
@st.cache_resource
def load_text_classifier():
    """Build the news text-classification pipeline.

    Loads the tokenizer and the sequence-classification model published as
    ``MihanTilk/News_Classifier`` and wraps them in a Hugging Face
    ``text-classification`` pipeline.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the weights are loaded
    once and the same pipeline object is reused on later reruns instead of
    being rebuilt each time.

    Returns:
        The ``text-classification`` pipeline built from the model and
        tokenizer above.
    """
    model_name = "MihanTilk/News_Classifier"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)
# Load Classifier & QA pipeline
classifier = load_text_classifier()


@st.cache_resource
def _load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Wrapped in ``st.cache_resource`` because Streamlit reruns this script on
    every interaction; without the cache the large RoBERTa checkpoint would
    be re-instantiated on each rerun.

    Returns:
        A ``question-answering`` pipeline using
        ``deepset/roberta-large-squad2`` for both model and tokenizer.
    """
    qa_model_name = "deepset/roberta-large-squad2"
    return pipeline(
        "question-answering",
        model=qa_model_name,
        tokenizer=qa_model_name,
    )


qa_pipeline = _load_qa_pipeline()
# ----------------- CSS Styling -----------------
# Page-wide stylesheet: light-blue inputs and uploader, red action buttons,
# green download button, blue headings/paragraphs and a tinted dataframe.
_CUSTOM_CSS = """
<style>
/* Main background and text colors */
.main {
background-color: #f4f4f4;
}
/* Text input boxes - light blue theme */
.stTextInput>div>div>input,
.stTextArea>div>div>textarea {
background-color: #e6f2ff;
border: 1px solid #b3d1ff;
border-radius: 8px;
color: #003366;
}
/* File uploader - matching style */
.stFileUploader>div>div {
background-color: #e6f2ff;
border: 1px solid #b3d1ff;
border-radius: 8px;
}
/* Buttons - keeping your original style */
.stButton>button {
background-color: #ff4b4b;
color: white;
border-radius: 10px;
border: none;
}
.stDownloadButton>button {
background-color: #4CAF50;
color: white;
border-radius: 10px;
border: none;
}
/* Text colors */
h1, h2, h3, h4, h5, h6 {
color: #003366; /* Dark blue for headers */
}
p {
color: #336699; /* Medium blue for paragraphs */
}
/* Dataframe styling */
.dataframe {
background-color: #e6f2ff;
border: 1px solid #b3d1ff;
}
</style>
"""

# The stylesheet is raw HTML, so Streamlit must be told not to escape it.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ----------------- App Title -----------------
st.title("π° News Classification & Q&A App")

# Red tagline rendered as raw HTML directly under the title.
_tagline_html = "<h4 style='color:#ff4b4b;'>Upload a CSV to classify news headlines and ask questions!</h4>"
st.markdown(_tagline_html, unsafe_allow_html=True)

# ----------------- Upload CSV -----------------
st.subheader("π Upload a CSV File")
# Only files with a .csv extension are accepted by the uploader widget.
uploaded_file = st.file_uploader(
    "Choose a CSV file...",
    type=["csv"],
)
# ----------------- Helpers for the uploaded-CSV workflow -----------------
# Palette shared by the pie and bar charts.
_CHART_COLORS = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0']
# Number of category buttons placed on one row of columns.
_BUTTONS_PER_ROW = 5


@st.cache_data(show_spinner=False)
def _classify_texts(texts):
    """Classify every text once and return ``(label, score)`` pairs.

    Streamlit reruns the script on each interaction (typing a question,
    pressing a button), so the result is cached on the tuple of input texts;
    the model only runs again when the uploaded content changes.

    Args:
        texts: Tuple of strings to classify.

    Returns:
        A list with one ``(label, score)`` tuple per input, in input order.
    """
    # truncation=True asks the tokenizer to clip over-long inputs instead of
    # letting the model fail on sequences beyond its maximum length.
    predictions = classifier(list(texts), truncation=True)
    return [(pred['label'], float(pred['score'])) for pred in predictions]


def _render_class_charts(class_counts):
    """Draw the pie and bar charts of articles per predicted class."""
    chart_col1, chart_col2 = st.columns(2)
    with chart_col1:
        # Compact Pie Chart
        fig1, ax1 = plt.subplots(figsize=(4, 4))
        class_counts.plot.pie(
            autopct='%1.1f%%',
            startangle=90,
            ax=ax1,
            colors=_CHART_COLORS,
            wedgeprops={'linewidth': 0.5, 'edgecolor': 'white'},
        )
        ax1.set_ylabel('')
        st.pyplot(fig1, use_container_width=True)
        # Close explicitly: figures otherwise pile up across script reruns.
        plt.close(fig1)
    with chart_col2:
        # Compact Bar Chart
        fig2, ax2 = plt.subplots(figsize=(4, 4))
        class_counts.plot.bar(color=_CHART_COLORS, ax=ax2, width=0.7)
        ax2.set_xlabel('')
        ax2.set_ylabel('Count')
        # Rotate the labels on this axes explicitly rather than relying on
        # pyplot's implicit "current axes".
        plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
        st.pyplot(fig2, use_container_width=True)
        plt.close(fig2)


def _render_word_cloud(text):
    """Draw a word cloud for ``text``, or a notice when it has no words."""
    try:
        wordcloud = WordCloud(
            width=300,
            height=200,
            background_color="white",
            collocations=False,
            max_words=100,
        ).generate(text)
    except ValueError:
        # WordCloud raises ValueError when no usable words are left.
        st.info("Not enough text to build a word cloud.")
        return
    fig3, ax3 = plt.subplots(figsize=(4, 3))
    ax3.imshow(wordcloud, interpolation="bilinear")
    ax3.axis("off")
    st.pyplot(fig3, use_container_width=True)
    plt.close(fig3)


def _render_monthly_trends(df):
    """Plot article counts per month and class from the ``date`` column."""
    st.write("*Monthly Trends*")
    try:
        dates = pd.to_datetime(df['date'])
        trends = (
            df.groupby([dates.dt.to_period('M'), 'class'])
            .size()
            # A month without articles of a class is a count of zero.
            .unstack(fill_value=0)
        )
        # Convert the monthly periods to timestamps so the chart receives a
        # plain datetime index.
        trends.index = trends.index.to_timestamp()
    except (ValueError, TypeError, AttributeError):
        # Unparseable or mixed-type dates: keep the rest of the page alive.
        st.warning("Date parsing failed")
    else:
        st.line_chart(trends)


def _render_category_explorer(df):
    """Show one button per predicted class that lists that class's articles."""
    categories = df['class'].unique()
    # Lay the buttons out in rows of _BUTTONS_PER_ROW so that any number of
    # categories fits (a single fixed row overflows past five categories).
    for start in range(0, len(categories), _BUTTONS_PER_ROW):
        row_categories = categories[start:start + _BUTTONS_PER_ROW]
        cols = st.columns(_BUTTONS_PER_ROW)
        for col, category in zip(cols, row_categories):
            with col:
                if not st.button(category, key=f"btn_{category}"):
                    continue
                # Create pop-up window
                with st.popover(f"π° Articles in {category}", use_container_width=True):
                    st.markdown(f"### {category} Articles")
                    articles = df[df['class'] == category]
                    # Display articles with expandable content
                    for idx, row in articles.iterrows():
                        with st.expander(f"Article {idx + 1}: {row['content'][:50]}...", expanded=False):
                            st.write(row['content'])
                            # Score stored at classification time, so it
                            # belongs to the label shown and needs no
                            # second model call.
                            st.caption(f"Classification confidence: {row['confidence']:.2f}")


if uploaded_file is not None:
    # Read and preprocess
    try:
        df = pd.read_csv(uploaded_file, encoding='utf-8')
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()
    if "content" not in df.columns:
        st.error("β The uploaded CSV must contain a 'content' column.")
        st.stop()

    # Rows without content cannot be classified; drop them and renumber so
    # article numbers stay contiguous.
    df = df.dropna(subset=['content']).reset_index(drop=True)
    if df.empty:
        st.warning("The uploaded CSV has no rows with content to classify.")
        st.stop()

    # Preprocess text. Normalising 'content' to str up front keeps the later
    # joins and slices safe for numeric or mixed-type columns.
    df['content'] = df['content'].astype(str)
    df['cleaned_text'] = df['content'].str.lower().str.strip()
    st.write("π Preview of Uploaded Data:", df.head())

    # ----------------- Classification -----------------
    with st.spinner("π Classifying news articles..."):
        predictions = _classify_texts(tuple(df['cleaned_text']))
        df['class'] = [label for label, _ in predictions]
        df['confidence'] = [score for _, score in predictions]
    st.success("β Classification Complete!")
    st.write("π Classified Results:", df[['content', 'class']].head())

    # ----------------- Download -----------------
    st.subheader("π₯ Download Results")
    # to_csv() without a path returns a str; encode once, with a BOM so
    # spreadsheet tools detect UTF-8.
    csv_output = df[['content', 'class']].to_csv(index=False).encode('utf-8-sig')
    st.download_button("Download Output CSV", data=csv_output, file_name="output.csv", mime="text/csv")

    # ----------------- Q&A Section -----------------
    st.subheader("π¬ Ask a Question")
    question = st.text_input("π What do you want to know about the content?")
    if st.button("Get Answer"):
        if not question.strip():
            # Guard against sending an empty question to the QA pipeline.
            st.warning("Please enter a question first.")
        else:
            context = " ".join(df['content'])
            with st.spinner("Answering..."):
                result = qa_pipeline(question=question, context=context)
            st.success(f"π Answer: {result['answer']}")

    # ----------------- Visualization Section -----------------
    st.subheader("π Data Visualizations")
    # Computed once and shared by both charts and the statistics table.
    class_counts = df['class'].value_counts()
    # Create two main columns (60/40 split)
    main_col1, main_col2 = st.columns([3, 2])
    with main_col1:
        # ----------------- Topic Distribution -----------------
        st.markdown("*Topic Distribution*")
        _render_class_charts(class_counts)
    with main_col2:
        # ----------------- Compact Word Cloud -----------------
        st.markdown("*Word Cloud*")
        _render_word_cloud(" ".join(df['content']))

    # ----------------- Detailed Stats (below) -----------------
    with st.expander("π Detailed Statistics", expanded=False):
        stats_col1, stats_col2 = st.columns(2)
        with stats_col1:
            st.write("*Category Breakdown:*")
            stats_df = class_counts.reset_index()
            stats_df.columns = ['Category', 'Count']
            stats_df['Percentage'] = (stats_df['Count'] / stats_df['Count'].sum() * 100).round(1)
            st.dataframe(stats_df, height=200)
        with stats_col2:
            if 'date' in df.columns:
                _render_monthly_trends(df)

    # ----------------- News Category Explorer -----------------
    st.subheader("π Explore News by Category")
    _render_category_explorer(df)
# ----------------- Footer -----------------
# Horizontal rule followed by a centred, grey credit line (raw HTML).
_footer_html = "<p style='text-align:center; color:#666;'>π Built with using Streamlit & Hugging Face</p>"
st.markdown("---")
st.markdown(_footer_html, unsafe_allow_html=True)