# app.py — Hugging Face Space by TAgroup5
# Last commit: 6fcbed4 "Update app.py" (verified) · 5.09 kB
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering
# ---------------------------------------------------------------------------
# Streamlit UI: page configuration and custom theme
# ---------------------------------------------------------------------------
# Configure the browser tab title and the wide layout before anything else is
# rendered (older Streamlit versions reject set_page_config after other calls).
st.set_page_config(layout="wide", page_title="News Classifier & Q&A")

# Stylesheet injected as raw HTML: a background image on the app container,
# a red centred <h1>, and rounded colour-filled buttons (blue for regular
# buttons, green for download buttons).
_CUSTOM_CSS = """
<style>
body {
background-color: #f4f4f4;
color: #333333;
font-family: 'Arial', sans-serif;
}
.stApp {
background-image: url('https://i.pinimg.com/474x/9c/68/86/9c6886dd642a4869f3fa4578f9fe34ef.jpg');
background-size: cover;
background-position: center;
padding: 20px;
border-radius: 10px;
box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
}
h1 {
color: #ff4b4b;
text-align: center;
}
.stButton>button {
background-color: #088da5 !important;
color: white !important;
font-size: 18px !important;
border-radius: 10px !important;
width: 100%;
padding: 10px;
}
.stDownloadButton>button {
background-color: #28a745 !important;
color: white !important;
font-size: 16px !important;
border-radius: 10px !important;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Load fine-tuned models
model_name_classification = "TAgroup5/news-classification-model"
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_models(classification_name, qa_name):
    """Load both Hugging Face models/tokenizers and wrap them in pipelines.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the weights are loaded
    once and reused on later reruns instead of being reloaded on every click.

    Args:
        classification_name: Hub id of the sequence-classification model.
        qa_name: Hub id of the question-answering model.

    Returns:
        Tuple ``(model, tokenizer, model_qa, tokenizer_qa,
        text_classification_pipeline, qa_pipeline)``.
    """
    clf_model = AutoModelForSequenceClassification.from_pretrained(classification_name)
    clf_tokenizer = AutoTokenizer.from_pretrained(classification_name)
    qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_name)
    qa_tokenizer = AutoTokenizer.from_pretrained(qa_name)
    # Initialize pipelines for both models
    clf_pipeline = pipeline("text-classification", model=clf_model, tokenizer=clf_tokenizer)
    qa_pipe = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
    return clf_model, clf_tokenizer, qa_model, qa_tokenizer, clf_pipeline, qa_pipe


# Keep the original module-level names so the rest of the script is unchanged.
(
    model,
    tokenizer,
    model_qa,
    tokenizer_qa,
    text_classification_pipeline,
    qa_pipeline,
) = _load_models(model_name_classification, model_name_qa)
# Streamlit App
st.title(" News Classification and Q&A ")

## ====================== News Classification ====================== ##

# (button label, model label) pairs for the category filter. The model labels
# must match the classifier's output labels exactly, because filtering compares
# them with the predicted 'class' column.
FILTER_BUTTONS = [
    ("All", "All"),
    ("📈 Business", "Business"),
    ("🗣 Opinion", "Opinion"),
    ("🏛 Political Gossip", "Political_gossip"),
    ("⚽ Sports", "Sports"),
    ("🌎 World News", "World_news"),
]


def preprocess_text(text):
    """Normalise one article before classification.

    Lower-cases the text, collapses runs of whitespace to a single space, then
    drops every character that is not an ASCII letter or whitespace (digits
    and punctuation are removed).

    Args:
        text: Raw cell value. Missing cells (NaN/None) become ``""``; other
            non-string values are converted with ``str``.

    Returns:
        The cleaned string; may be empty.
    """
    if not isinstance(text, str):
        # Empty CSV cells arrive as float NaN, which has no .lower().
        text = "" if pd.isna(text) else str(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^a-z\s]', '', text)
    return text


def _find_content_column(columns):
    """Return the column named 'content' (exact match preferred, else case-insensitive), or None."""
    if 'content' in columns:
        return 'content'
    for name in columns:
        if str(name).strip().lower() == 'content':
            return name
    return None


def _classify_one(processed_text):
    """Return the predicted label for one preprocessed article, or "Unknown" if it is blank."""
    if not processed_text.strip():
        return "Unknown"
    # truncation=True: long articles would otherwise exceed the model's
    # maximum input length and make the pipeline raise.
    return text_classification_pipeline(processed_text, truncation=True)[0]['label']


@st.cache_data(show_spinner="Classifying articles…")
def _classify_all(processed_texts):
    """Classify every article; cached so reruns (filter clicks, Q&A input) don't re-run the model."""
    return [_classify_one(text) for text in processed_texts]


def _render_category_filter(df):
    """Render the category buttons and show the rows of the selected class.

    The chosen category is stored in ``st.session_state``. A plain variable
    would fall back to 'All' whenever another widget triggers a rerun.
    """
    st.write("🔍 **Filter by Category**")
    state = st.session_state
    if "selected_category" not in state:
        state["selected_category"] = "All"
    for column, (label, category) in zip(st.columns(len(FILTER_BUTTONS)), FILTER_BUTTONS):
        with column:
            if st.button(label):
                state["selected_category"] = category
    selected_category = state["selected_category"]
    if selected_category != 'All':
        filtered_df = df[df['class'] == selected_category]
    else:
        filtered_df = df
    st.write(f"🔎 Showing news articles in category: {selected_category}")
    st.dataframe(filtered_df[['content', 'class']])


st.header("📌 Classify News Articles")
st.markdown("Upload a CSV file containing a **'content'** column to classify news into pre-defined categories.")
uploaded_file = st.file_uploader("📂 Choose a CSV file", type="csv")

df = None
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file, encoding="utf-8")
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as err:
        # Report the problem but keep rendering the rest of the app (Q&A).
        st.error(f"❌ Error: Could not read the uploaded file as a UTF-8 CSV ({err}).")

if df is not None:
    content_column = _find_content_column(df.columns)
    if content_column is None:
        st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
    else:
        df = df.rename(columns={content_column: 'content'})
        st.write("✅ Preview of uploaded data:")
        st.dataframe(df.head())
        df['processed_content'] = df['content'].apply(preprocess_text)
        df['class'] = _classify_all(tuple(df['processed_content']))
        st.write("🔍 Classification Results:")
        st.dataframe(df[['content', 'class']])
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button("📥 Download Classified News", data=output.getvalue(), file_name="output.csv", mime="text/csv")
        # App Component 3: filter the classified articles by predicted category
        _render_category_filter(df)

# Add a separator
st.markdown("---")
## ====================== Q&A ====================== ##
st.header("💬 Ask a Question About the News")
question = st.text_input("❓ Ask a question:")
context = st.text_area("📰 Provide the news article or content:", height=150)
# Run the QA model only when both fields contain non-whitespace text; a
# whitespace-only question would otherwise be sent to the model.
if question.strip() and context.strip():
    with st.spinner("Finding the answer…"):
        result = qa_pipeline(question=question, context=context)
    st.success(f"✅ Answer: {result['answer']}")