# app.py — author: TAgroup5; last commit 90bfd68 "Update app.py" (verified); file size 4.5 kB.
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering
# Streamlit UI
# App-wide CSS overrides, injected as raw HTML (hence unsafe_allow_html=True):
# page and app-container background/padding, red headings, red action
# buttons, a green download button, and rounded borders on text inputs.
_PAGE_CSS = """
<style>
body {
background-color: #f4f4f4;
color: #333333;
font-family: 'Arial', sans-serif;
}
.stApp {
background-color: black;
padding: 20px;
border-radius: 10px;
box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
}
h1, h2 {
color: #ff4b4b;
}
.stButton>button {
background-color: #ff4b4b !important;
color: white;
font-size: 16px;
border-radius: 5px;
}
.stDownloadButton>button {
background-color: #28a745 !important;
color: white;
font-size: 16px;
border-radius: 5px;
}
.stTextInput>div>div>input {
border-radius: 5px;
border: 1px solid #ccc;
}
.stTextArea>div>textarea {
border-radius: 5px;
border: 1px solid #ccc;
}
</style>
"""

# Browser-tab title and full-width layout, then the custom styling.
st.set_page_config(layout="wide", page_title="News Classifier & Q&A")
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Load fine-tuned models and tokenizers.
# Streamlit re-executes this whole script on every widget interaction, so the
# from_pretrained calls are wrapped in st.cache_resource: they run once per
# server process and the same objects are handed back on later reruns.
model_name_classification = "TAgroup5/news-classification-model"
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_models(classification_name, qa_name):
    """Load both checkpoints and their tokenizers (cached across reruns).

    Args:
        classification_name: Model id for the sequence-classification model
            and its tokenizer.
        qa_name: Model id for the question-answering model and its tokenizer.

    Returns:
        Tuple ``(classification_model, classification_tokenizer,
        qa_model, qa_tokenizer)``.
    """
    return (
        AutoModelForSequenceClassification.from_pretrained(classification_name),
        AutoTokenizer.from_pretrained(classification_name),
        AutoModelForQuestionAnswering.from_pretrained(qa_name),
        AutoTokenizer.from_pretrained(qa_name),
    )


model, tokenizer, model_qa, tokenizer_qa = _load_models(
    model_name_classification, model_name_qa
)

# Initialize pipelines.
text_classification_pipeline = pipeline(
    "text-classification", model=model, tokenizer=tokenizer
)
# Each pipeline gets its own model/tokenizer pair. This one was previously
# built from the sequence-classification model, which has no QA head.
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
# Streamlit App
st.title(" News Classification and Q&A ")

## ====================== Component 1: News Classification ====================== ##

# Compiled once at module level; used by preprocess_text below.
_NON_LETTER_RE = re.compile(r"[^a-z\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_text(text):
    """Normalize one article before classification.

    Lower-cases the text and deletes every character that is not a lowercase
    ASCII letter or whitespace (digits, punctuation, accented letters, ...).
    It then collapses whitespace runs into single spaces. Characters are
    removed *before* collapsing, so "a - b" becomes "a b" rather than "a  b".

    Non-string cells yield "" instead of crashing on ``.lower()``. pandas
    reads an empty CSV cell as NaN, which is one such case. The caller labels
    empty results "Unknown".
    """
    if not isinstance(text, str):
        return ""
    return _WHITESPACE_RE.sub(" ", _NON_LETTER_RE.sub("", text.lower()))


@st.cache_data(show_spinner="Classifying articles...")
def _classify_texts(texts):
    """Return the top predicted label for each text in the tuple *texts*.

    Cached on the texts, so reruns triggered by other widgets do not repeat
    inference over the whole upload. Typing in the Q&A section below is one
    such trigger. ``truncation=True`` clips inputs longer than the tokenizer's
    maximum length instead of passing over-long sequences to the model.
    """
    predictions = text_classification_pipeline(list(texts), truncation=True)
    return [prediction["label"] for prediction in predictions]


st.header("📌 Classify News Articles")
st.markdown("Upload a CSV file with a **'content'** column to classify news into categories.")
uploaded_file = st.file_uploader("📂 Choose a CSV file", type="csv")
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file, encoding="utf-8")
    except UnicodeDecodeError:
        # The failed UTF-8 attempt already consumed part of the upload;
        # rewind so the latin-1 retry parses the file from the beginning.
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

    if "content" not in df.columns:
        st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("✅ Preview of uploaded data:")
        st.dataframe(df.head())

        df["processed_content"] = df["content"].apply(preprocess_text)

        # Classify every non-blank record with one batched call, which yields
        # one of the model's labels per record. Blank records get "Unknown".
        df["class"] = "Unknown"
        has_text = df["processed_content"].str.strip() != ""
        if has_text.any():
            df.loc[has_text, "class"] = _classify_texts(
                tuple(df.loc[has_text, "processed_content"])
            )

        st.write("🔍 Classification Results:")
        st.dataframe(df[["content", "class"]])

        # utf-8-sig prepends a BOM so spreadsheet apps detect the encoding.
        st.download_button(
            label="📥 Download Classified News",
            data=df.to_csv(index=False).encode("utf-8-sig"),
            file_name="classified_news.csv",
            mime="text/csv",
        )
## ====================== Component 2: Q&A ====================== ##
st.header("💬 Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an AI-generated answer.")
question = st.text_input("❓ Ask a question:")
context = st.text_area("📰 Provide the news article or content:", height=150)

# Answer only once both inputs contain something other than whitespace.
if question.strip() and context.strip():
    # Build the pipeline from the QA checkpoint already loaded earlier in
    # this file (model_qa / tokenizer_qa). This used to fetch a second,
    # different checkpoint ("distilbert-base-uncased-distilled-squad") on
    # every rerun while the loaded one went unused.
    qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
    with st.spinner("Searching the article for an answer..."):
        result = qa_pipeline(question=question, context=context)
    # A single question/context pair yields a single result dict.
    if result.get("answer"):
        st.success(f"✅ Answer: {result['answer']}")
    else:
        st.warning("⚠️ No answer found in the provided content.")