File size: 4,499 Bytes
246133f
 
5cdc45c
237a63b
5cdc45c
ec68c76
73c0f99
f72fd34
 
 
 
 
 
 
 
 
 
6afa531
f72fd34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3412bc8
 
73c0f99
3412bc8
 
 
 
 
9ff5a0e
 
2ec8bb5
9ff5a0e
90bfd68
2ec8bb5
9ff5a0e
f72fd34
 
9ff5a0e
f72fd34
9ff5a0e
 
 
f72fd34
9ff5a0e
 
 
 
f72fd34
9ff5a0e
f72fd34
9ff5a0e
 
f72fd34
9ff5a0e
6afa531
 
 
9ff5a0e
 
6afa531
9ff5a0e
6afa531
 
9ff5a0e
 
 
f72fd34
9ff5a0e
 
 
 
 
f72fd34
9ff5a0e
 
f72fd34
 
9ff5a0e
f72fd34
 
9ff5a0e
 
f72fd34
9ff5a0e
 
 
 
f72fd34
9ff5a0e
f72fd34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering

# Streamlit UI
# Page-wide settings: browser tab title and a wide (full-width) layout.
# Keep this as the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="News Classifier & Q&A")

# Custom stylesheet injected as raw HTML: red headings and primary buttons,
# a green download button, and rounded borders on the text inputs.
_CUSTOM_CSS = """
    <style>
        body {
            background-color: #f4f4f4;
            color: #333333;
            font-family: 'Arial', sans-serif;
        }
        .stApp {
            background-color: black;
            padding: 20px;
            border-radius: 10px;
            box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
        }
        h1, h2 {
            color: #ff4b4b;
        }
        .stButton>button {
            background-color: #ff4b4b !important;
            color: white;
            font-size: 16px;
            border-radius: 5px;
        }
        .stDownloadButton>button {
            background-color: #28a745 !important;
            color: white;
            font-size: 16px;
            border-radius: 5px;
        }
        .stTextInput>div>div>input {
            border-radius: 5px;
            border: 1px solid #ccc;
        }
        .stTextArea>div>textarea {
            border-radius: 5px;
            border: 1px solid #ccc;
        }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Load fine-tuned models and tokenizers.
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are wrapped in st.cache_resource: weights are loaded once per server
# process and the same objects are returned on later reruns.
model_name_classification = "TAgroup5/news-classification-model"
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_classification_model(name):
    """Return the ``(model, tokenizer)`` pair for sequence-classification checkpoint *name*."""
    return (
        AutoModelForSequenceClassification.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


@st.cache_resource
def _load_qa_model(name):
    """Return the ``(model, tokenizer)`` pair for extractive-QA checkpoint *name*."""
    return (
        AutoModelForQuestionAnswering.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


model, tokenizer = _load_classification_model(model_name_classification)
model_qa, tokenizer_qa = _load_qa_model(model_name_qa)

# Initialize pipelines
text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
# The QA pipeline must be built from the QA model/tokenizer: the
# sequence-classification model above cannot predict answer spans.
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)

# Streamlit App
st.title(" News Classification and Q&A ")

## ====================== Component 1: News Classification ====================== ##

_NON_LETTERS_RE = re.compile(r'[^a-z\s]')
_WHITESPACE_RE = re.compile(r'\s+')


def _read_uploaded_csv(file):
    """Parse an uploaded CSV, falling back to ISO-8859-1 if it is not valid UTF-8.

    The failed UTF-8 attempt may already have consumed part (or all) of the
    upload buffer, so it is rewound with ``seek(0)`` before the retry.
    """
    try:
        return pd.read_csv(file, encoding="utf-8")
    except UnicodeDecodeError:
        file.seek(0)
        return pd.read_csv(file, encoding="ISO-8859-1")


def preprocess_text(text):
    """Normalize one article for the classifier.

    Lower-cases the text, removes every character that is not an ASCII letter
    or whitespace, then collapses whitespace runs to one space and strips the
    ends. Missing cells (NaN/None) become ``""``; other non-string cells are
    converted with ``str`` instead of raising ``AttributeError``.
    """
    if pd.isna(text):
        return ""
    text = str(text).lower()                 # Convert to lowercase
    text = _NON_LETTERS_RE.sub('', text)     # Remove special characters & numbers
    # Collapse whitespace *after* removal so "a - b" becomes "a b", not "a  b".
    return _WHITESPACE_RE.sub(' ', text).strip()


def _classify_texts(texts):
    """Return one predicted label per text; blank texts are labelled "Unknown".

    Non-blank texts go to the classifier in a single batched call with
    ``truncation=True`` so over-long articles do not fail on sequence length.
    """
    labels = ["Unknown"] * len(texts)
    positions = [i for i, t in enumerate(texts) if t.strip()]
    if positions:
        predictions = text_classification_pipeline(
            [texts[i] for i in positions], truncation=True
        )
        for i, prediction in zip(positions, predictions):
            labels[i] = prediction['label']
    return labels


st.header("📌 Classify News Articles")
st.markdown("Upload a CSV file with a **'content'** column to classify news into categories.")

uploaded_file = st.file_uploader("📂 Choose a CSV file", type="csv")

if uploaded_file is not None:
    try:
        df = _read_uploaded_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # Report unreadable uploads in the UI instead of crashing the script
        # (which would also hide the Q&A component below).
        st.error(f"❌ Error: Could not read the uploaded CSV ({err}).")
        df = None

    if df is None:
        pass  # Parse error already reported above.
    elif 'content' not in df.columns:
        st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("✅ Preview of uploaded data:")
        st.dataframe(df.head())

        # Apply preprocessing and classification
        df['processed_content'] = df['content'].apply(preprocess_text)

        # Classify each record into one of the model's classes
        df['class'] = _classify_texts(df['processed_content'].tolist())

        # Show results
        st.write("🔍 Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide CSV download (UTF-8 with BOM so spreadsheet tools detect the encoding)
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button(label="📥 Download Classified News", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")

## ====================== Component 2: Q&A ====================== ##
st.header("💬 Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an AI-generated answer.")

question = st.text_input("❓ Ask a question:")
context = st.text_area("📰 Provide the news article or content:", height=150)

# Run inference only when both inputs contain non-whitespace text.
if question.strip() and context.strip():
    # Build the pipeline from the QA checkpoint already loaded at startup
    # (model_qa / tokenizer_qa) instead of fetching a second, different
    # checkpoint and constructing a pipeline from scratch on every rerun.
    qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
    result = qa_pipeline(question=question, context=context)

    answer = result.get('answer') if isinstance(result, dict) else None
    if answer:
        st.success(f"✅ Answer: {answer}")
    else:
        st.warning("⚠️ No answer found in the provided content.")