TAgroup5 committed on
Commit
3412bc8
·
verified ·
1 Parent(s): 73c0f99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -97
app.py CHANGED
@@ -4,110 +4,93 @@ import re
4
  import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
- from streamlit_extras.app_logo import add_logo # For adding a logo
8
 
9
- # Custom Styling
10
- st.set_page_config(page_title="News Classifier & Q&A", page_icon="πŸ“°", layout="wide")
 
 
11
 
12
- # CSS for styling
 
 
 
 
 
 
 
 
 
13
  st.markdown(
14
  """
15
  <style>
16
- body {
17
- background-color: #f5f5f5;
18
- }
19
- .stApp {
20
- background-color: white;
21
- border-radius: 10px;
22
- padding: 20px;
23
- box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
24
- }
25
- .stTitle, .stHeader {
26
- color: #0073e6;
27
- text-align: center;
28
- }
29
- .stButton>button {
30
- background-color: #0073e6 !important;
31
- color: white !important;
32
- border-radius: 8px !important;
33
- font-size: 16px !important;
34
- }
35
- .stDownloadButton>button {
36
- background-color: #28a745 !important;
37
- color: white !important;
38
- border-radius: 8px !important;
39
- }
40
  </style>
41
  """,
42
  unsafe_allow_html=True,
43
  )
44
 
45
- # Add a logo (optional, replace with your logo URL)
46
- # add_logo("https://your-logo-url.png", height=50)
47
-
48
- st.title("πŸ“° News Classification & Q&A")
49
-
50
- ## ====================== Component 1: News Classification ====================== ##
51
- st.header("πŸ“Œ Classify News Articles")
52
- st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
53
-
54
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
55
-
56
- if uploaded_file is not None:
57
- try:
58
- df = pd.read_csv(uploaded_file, encoding="utf-8")
59
- except UnicodeDecodeError:
60
- df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
61
-
62
- if 'content' not in df.columns:
63
- st.error("❌ The uploaded CSV must contain a 'content' column.")
64
- else:
65
- st.success("βœ… File uploaded successfully!")
66
- st.write("Preview of uploaded data:")
67
- st.dataframe(df.head())
68
-
69
- # Preprocessing function
70
- def preprocess_text(text):
71
- text = text.lower()
72
- text = re.sub(r'\s+', ' ', text)
73
- text = re.sub(r'[^a-z\s]', '', text)
74
- return text
75
-
76
- # Apply preprocessing
77
- df['processed_content'] = df['content'].apply(preprocess_text)
78
-
79
- # Load Model
80
- model_name_classification = "TAgroup5/news-classification-model"
81
- model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
82
- tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
83
- text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
84
-
85
- # Classify each record
86
- df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
87
-
88
- # Display results
89
- st.write("πŸ“Œ Classification Results:")
90
- st.dataframe(df[['content', 'class']])
91
-
92
- # Provide CSV download
93
- output = io.BytesIO()
94
- df.to_csv(output, index=False, encoding="utf-8-sig")
95
- st.download_button(label="πŸ“₯ Download Classified News", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
96
-
97
- ## ====================== Component 2: Q&A ====================== ##
98
- st.header("πŸ’¬ Ask a Question About the News")
99
- st.markdown("Enter a question and provide a news article to get an answer.")
100
-
101
- question = st.text_input("πŸ” Ask a question:")
102
- context = st.text_area("πŸ“ Provide the news article content:", height=150)
103
-
104
- if question and context.strip():
105
- model_name_qa = "distilbert-base-uncased-distilled-squad"
106
- qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
107
- result = qa_pipeline(question=question, context=context)
108
-
109
- # Display answer
110
- if 'answer' in result and result['answer']:
111
- st.success(f"**πŸ—£ Answer:** {result['answer']}")
112
- else:
113
- st.warning("⚠️ No answer found in the provided content.")
 
4
  import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
 
7
 
8
# Load the fine-tuned news classifier together with its own tokenizer.
model_name_classification = "TAgroup5/news-classification-model"  # Replace with the correct model name
model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
tokenizer = AutoTokenizer.from_pretrained(model_name_classification)

# Load the extractive question-answering model together with its own tokenizer.
model_name_qa = "distilbert-base-cased-distilled-squad"
model_qa = AutoModelForQuestionAnswering.from_pretrained(model_name_qa)
tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)

# Initialize pipelines. Each pipeline must receive the model/tokenizer pair
# loaded for its task: the question-answering pipeline previously received the
# sequence-classification model and no tokenizer at all.
text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
20
+
21
+ # Streamlit App Styling
22
+ st.set_page_config(page_title="News Classification & Q&A", page_icon="πŸ“°", layout="wide")
23
  st.markdown(
24
  """
25
  <style>
26
+ body {background-color: #f4f4f4;}
27
+ .title {text-align: center; font-size: 36px; font-weight: bold; color: #ff4b4b;}
28
+ .subheader {font-size: 24px; color: #333; margin-bottom: 20px; text-align: right;}
29
+ .stTextInput>div>div>input {border-radius: 10px;}
30
+ .stTextArea>div>div>textarea {border-radius: 10px;}
31
+ .stButton>button {border-radius: 10px; background-color: #ff4b4b; color: white; font-weight: bold;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  </style>
33
  """,
34
  unsafe_allow_html=True,
35
  )
36
 
37
+ st.markdown('<h1 class="title">πŸ“° News Classification & Q&A App</h1>', unsafe_allow_html=True)
38
+
39
+ col1, col2 = st.columns([2, 1])
40
+ with col2:
41
+ # ====================== Component 1: News Classification ====================== #
42
+ st.markdown('<h2 class="subheader">πŸ“Œ Classify News Articles</h2>', unsafe_allow_html=True)
43
+ st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
44
+
45
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
46
+
47
+ if uploaded_file is not None:
48
try:
    df = pd.read_csv(uploaded_file, encoding="utf-8")  # Handle encoding issues
except UnicodeDecodeError:
    # The failed UTF-8 attempt has already consumed part of the upload
    # buffer; rewind so the fallback decoder parses the file from the start
    # instead of from wherever the first read stopped.
    uploaded_file.seek(0)
    df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
52
+
53
+ if 'content' not in df.columns:
54
+ st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
55
+ else:
56
+ st.success("βœ… File successfully uploaded!")
57
+ st.write("Preview of uploaded data:")
58
+ st.dataframe(df.head())
59
+
60
+ # Preprocessing function to clean the text
61
def preprocess_text(text):
    """Normalize raw article text before it is passed to the classifier.

    Lowercases the input, removes every character that is not an ASCII
    letter or whitespace, and collapses whitespace runs into single spaces.

    Args:
        text: Article content. Non-string values (for example the NaN that
            pandas produces for an empty CSV cell) are treated as empty text.

    Returns:
        The cleaned text, or "" when ``text`` is not a string.
    """
    if not isinstance(text, str):
        # Missing cells arrive as float NaN; calling .lower() on them would raise.
        return ""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
    # Collapse whitespace last so gaps left by the removed characters are
    # merged as well (e.g. "a 1 b" -> "a b" rather than "a  b").
    text = re.sub(r'\s+', ' ', text)  # Remove extra spaces
    return text
66
+
67
+ # Apply preprocessing and classification
68
+ df['processed_content'] = df['content'].apply(preprocess_text)
69
+ df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
70
+
71
+ # Show results
72
+ st.markdown("### πŸ”Ή Classification Results:")
73
+ st.dataframe(df[['content', 'class']])
74
+
75
+ # Provide CSV download
76
+ output = io.BytesIO()
77
+ df.to_csv(output, index=False, encoding="utf-8-sig")
78
+ st.download_button(label="⬇️ Download classified news", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
79
+
80
+ # ====================== Component 2: Q&A ====================== #
81
+ st.markdown('<h2 class="subheader">❓ Ask a Question About the News</h2>', unsafe_allow_html=True)
82
+ st.markdown("Enter a question and provide a news article to get an answer.")
83
+
84
+ question = st.text_input("πŸ” Ask a question:")
85
+ context = st.text_area("πŸ“ Provide the news article or content:", height=150)
86
+
87
+ if question and context.strip():
88
+ model_name_qa = "distilbert-base-uncased-distilled-squad"
89
+ qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
90
+ result = qa_pipeline(question=question, context=context)
91
+
92
+ # Display Answer
93
+ if 'answer' in result and result['answer']:
94
+ st.markdown(f"### βœ… Answer: {result['answer']}")
95
+ else:
96
+ st.markdown("### ❌ No answer found in the provided content.")