zionia committed on
Commit
bd40926
·
verified ·
1 Parent(s): f1ee5a2

update app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -21
app.py CHANGED
@@ -1,23 +1,268 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
 
3
 
4
- model = AutoModelForSequenceClassification.from_pretrained("zionia/email-phishing-detector")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def detect_phishing(email_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
- Detect if the input email text is phishing or not
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
- result = model(email_text)
12
- label = result[0]['label']
13
- score = result[0]['score']
14
 
15
- if label == "LABEL_1":
16
- return f"Phishing detected! (confidence: {score:.2%})"
17
- else:
18
- return f"Legitimate email (confidence: {score:.2%})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Example emails
21
  examples = [
22
  ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
23
  ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
@@ -26,15 +271,32 @@ examples = [
26
  ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
27
  ]
28
 
29
- # Create the Gradio interface
30
- app = gr.Interface(
31
- fn=detect_phishing,
32
- inputs=gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=5),
33
- outputs=gr.Textbox(label="Detection Result"),
34
- title="Email Phishing Detector",
35
- description="Detect whether an email is phishing or legitimate using AI.",
36
- examples=examples,
37
- theme="soft"
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  app.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, DistilBertConfig, DistilBertModel, DistilBertPreTrainedModel
3
+ import torch
4
+ import torch.nn as nn
5
+ import re
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
 
9
class DistilBertForSequenceClassificationWithFeatures(DistilBertPreTrainedModel):
    """DistilBERT sequence classifier whose head also receives seven scalar features.

    The representation at sequence position 0 is passed through a linear layer,
    ReLU and dropout, then concatenated with seven per-example features
    (z-scored across the batch) before the final linear classifier.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # +7: one input per hand-crafted feature accepted by forward().
        self.classifier = nn.Linear(config.dim + 7, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        self.post_init()

    def forward(
        self, input_ids=None, attention_mask=None, labels=None,
        text_length=None, token_count=None, avg_token_length=None,
        num_date_tokens=None, has_attachment_reference=None,
        has_operational_keywords=None, has_phishy_keywords=None
    ):
        """Return the raw classification logits for a batch.

        Args:
            input_ids, attention_mask: Tokenizer outputs forwarded to DistilBERT.
            labels: Accepted for signature compatibility; not used, no loss is
                computed here.
            text_length ... has_phishy_keywords: The seven feature tensors, one
                value per example (they are stacked along ``dim=1``, so each is
                expected to be 1-D with the batch size as its length).

        Returns:
            Tensor of logits produced by the final linear layer.
        """
        output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = output.last_hidden_state[:, 0]
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        additional_features = torch.stack([
            text_length.float(), token_count.float(), avg_token_length.float(),
            num_date_tokens.float(), has_attachment_reference.float(),
            has_operational_keywords.float(), has_phishy_keywords.float()
        ], dim=1)

        # Normalize features across the batch.  The sample standard deviation
        # of a single row is NaN, which would turn every logit into NaN for
        # batch size 1 (the normal case when serving one email), so a
        # one-example batch uses a zero spread instead: the centred features
        # are exactly zero and stay zero after the division.
        feature_mean = additional_features.mean(dim=0)
        if additional_features.size(0) > 1:
            feature_std = additional_features.std(dim=0)
        else:
            feature_std = torch.zeros_like(feature_mean)
        additional_features = (additional_features - feature_mean) / (feature_std + 1e-8)

        combined = torch.cat((pooled_output, additional_features), dim=1)
        logits = self.classifier(combined)

        return logits
44
+
45
# Load the configuration, the custom classifier weights and the tokenizer
# from the same Hub repository.
_MODEL_REPO = "zionia/email-phishing-detector"

config = DistilBertConfig.from_pretrained(_MODEL_REPO)
model = DistilBertForSequenceClassificationWithFeatures.from_pretrained(
    _MODEL_REPO,
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
49
 
50
# Keyword sets used for feature extraction and highlighting.

_MONTHS = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
)

# Three-letter month abbreviations, full month names and the years 2001-2025.
DATE_KEYWORDS = (
    {month[:3] for month in _MONTHS}
    | set(_MONTHS)
    | {str(year) for year in range(2001, 2026)}
)

# Words and phrases treated as phishing indicators.
PHISHY_KEYWORDS = {
    "verify",
    "urgent",
    "login",
    "click",
    "bank",
    "account",
    "update",
    "password",
    "security",
    "alert",
    "confirm",
    "immediately",
    "action required",
    "suspended",
    "verify your account",
    "limited time",
    "unauthorized access",
}

# File extensions and words that point at an attachment.
ATTACHMENT_KEYWORDS = {
    ".xls",
    ".xlsx",
    ".pdf",
    ".doc",
    ".docx",
    "attachment",
    "attached",
    "file",
    "document",
}

# Terms counted as routine operational vocabulary.
OPERATIONAL_KEYWORDS = {
    "nom",
    "actual",
    "vols",
    "schedule",
    "attached",
    "report",
    "data",
    "summary",
    "meeting",
    "agenda",
    "minutes",
    "review",
    "quarterly",
    "project",
}
64
+
65
def highlight_keywords(text, keywords, colour):
    """Wrap every whole-word occurrence of ``keywords`` in a coloured ``<mark>``.

    Matching is case-insensitive and the matched text keeps its original case.
    All keywords are searched in a single pass with longer phrases tried first,
    so a phrase such as "verify your account" is marked once instead of having
    the shorter "verify" marked again inside it.  Spans that are already
    wrapped in ``<mark>...</mark>`` are copied through untouched, so calling
    this function repeatedly on the same text never nests highlights.

    Args:
        text: The text (or previously highlighted HTML) to process.
        keywords: Iterable of literal keywords/phrases to highlight.
        colour: CSS colour used as the highlight background.

    Returns:
        ``text`` with the matches wrapped in ``<mark>`` elements.
    """
    if not keywords:
        # An empty alternation would match the empty string at every boundary.
        return text

    # Longest first so phrases win over their own prefixes; the secondary
    # alphabetical key keeps the pattern deterministic for set inputs.
    ordered = sorted(keywords, key=lambda kw: (-len(kw), kw))
    alternatives = "|".join(re.escape(kw) for kw in ordered)
    pattern = re.compile(
        rf"(?P<marked><mark\b[^>]*>.*?</mark>)|\b(?P<kw>{alternatives})\b",
        re.IGNORECASE | re.DOTALL,
    )

    def _wrap(match):
        already_marked = match.group("marked")
        if already_marked is not None:
            return already_marked
        return (
            f"<mark style='background-color:{colour}; font-weight:bold'>"
            f"{match.group('kw')}</mark>"
        )

    return pattern.sub(_wrap, text)
70
+
71
def extract_features(email_text):
    """Compute the numeric model features and the keyword hits for an email.

    Tokens are the whitespace-separated pieces of the lower-cased text.  Date
    keywords are matched against whole tokens; the phishy, attachment and
    operational keywords are matched as substrings of the lower-cased text.

    Returns:
        A ``(features, detected_keywords)`` tuple: ``features`` maps the seven
        feature names to numbers, ``detected_keywords`` maps each display
        category to the list of keywords found.
    """
    lowered = email_text.lower()
    tokens = lowered.split()
    n_tokens = len(tokens)

    def _found_in_text(vocabulary):
        # Substring search over the whole lower-cased text.
        return [term for term in vocabulary if term in lowered]

    phishy_hits = _found_in_text(PHISHY_KEYWORDS)
    attachment_hits = _found_in_text(ATTACHMENT_KEYWORDS)
    operational_hits = _found_in_text(OPERATIONAL_KEYWORDS)
    date_hits = [term for term in DATE_KEYWORDS if term in tokens]

    features = {
        'text_length': len(email_text),
        'token_count': n_tokens,
        # max(..., 1) avoids dividing by zero for text without tokens.
        'avg_token_length': sum(map(len, tokens)) / max(n_tokens, 1),
        'num_date_tokens': sum(1 for token in tokens if token in DATE_KEYWORDS),
        'has_attachment_reference': float(bool(attachment_hits)),
        'has_operational_keywords': float(bool(operational_hits)),
        'has_phishy_keywords': float(bool(phishy_hits)),
    }

    detected_keywords = {
        "Phishy Keywords": phishy_hits,
        "Attachment References": attachment_hits,
        "Operational Terms": operational_hits,
        "Date Mentions": date_hits,
    }

    return features, detected_keywords
94
+
95
def create_feature_plot(features):
    """Draw a horizontal bar chart of the seven extracted features.

    Bar lengths are min-max normalised across the seven values so they fit on a
    0-1 axis; the raw value is printed next to each bar.

    Args:
        features: Mapping containing the seven feature keys read below.

    Returns:
        The matplotlib ``Figure`` holding the chart.  The figure is detached
        from pyplot's global registry before it is returned so that one figure
        is not kept alive for every request served.
    """
    feature_names = [
        'Text Length', 'Token Count', 'Avg Token Length',
        'Date Keywords', 'Attachment Ref', 'Operational Terms', 'Phishy Terms'
    ]
    feature_keys = [
        'text_length', 'token_count', 'avg_token_length',
        'num_date_tokens', 'has_attachment_reference',
        'has_operational_keywords', 'has_phishy_keywords'
    ]
    values = [features[key] for key in feature_keys]

    # Normalize the values for better visualization; the epsilon keeps the
    # division defined when every value is identical.
    lowest = min(values)
    spread = max(values) - lowest + 1e-8
    normalized_values = [(value - lowest) / spread for value in values]

    fig, ax = plt.subplots(figsize=(10, 4))
    bars = ax.barh(feature_names, normalized_values, color='skyblue')
    ax.set_xlim(0, 1)
    ax.set_title('Normalized Feature Values')

    # Add value labels
    for bar, val in zip(bars, values):
        ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
                f'{val:.1f}', va='center')

    # Lay out this specific figure rather than pyplot's implicit current one.
    fig.tight_layout()
    # pyplot keeps a reference to every figure it creates until it is closed;
    # the Figure object itself remains usable by the caller afterwards.
    plt.close(fig)
    return fig
121
+
122
def detect_phishing(email_text):
    """Classify an email as phishing or legitimate.

    Args:
        email_text: Raw email text; ``None``, empty and whitespace-only input
            are all treated as invalid.

    Returns:
        A dict with the keys ``decision``, ``confidence``, ``prediction``,
        ``features`` and ``detected_keywords``.  For invalid input
        ``prediction`` is ``-1`` and ``confidence`` is ``0``; the feature and
        keyword entries are still present so callers can index them
        unconditionally.
    """
    if not email_text or not email_text.strip():
        features, detected_keywords = extract_features(email_text or "")
        return {
            "decision": "Invalid input: Empty email text",
            "confidence": 0,
            "prediction": -1,
            "features": features,
            "detected_keywords": detected_keywords,
        }

    features, detected_keywords = extract_features(email_text)

    inputs = tokenizer(
        email_text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=256
    )

    # Convert features to one-element tensors (a batch of one) whose names
    # match the keyword arguments of the model's forward().
    feature_tensors = {k: torch.tensor([v], dtype=torch.float32) for k, v in features.items()}

    with torch.no_grad():
        logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            **feature_tensors
        )

    probs = torch.nn.functional.softmax(logits, dim=1)
    confidence, pred = torch.max(probs, dim=1)
    confidence = confidence.item()
    pred = pred.item()

    return {
        "decision": "Phishing" if pred == 1 else "Legitimate",
        "confidence": confidence,
        "prediction": pred,
        "features": features,
        "detected_keywords": detected_keywords
    }
158
+
159
def create_highlighted_text(email_text):
    """Return the email as HTML with each keyword category colour-highlighted.

    The email is untrusted input that ends up inside an HTML component, so its
    markup characters (``&``, ``<``, ``>``) are escaped before any highlight
    tags are inserted.  Quotes are left alone because the text is only ever
    placed in element content, never inside an attribute.
    """
    import html  # local import: only needed by this function

    highlighted = html.escape(email_text, quote=False)
    highlighted = highlight_keywords(highlighted, PHISHY_KEYWORDS, "#ffcccc")  # red
    highlighted = highlight_keywords(highlighted, ATTACHMENT_KEYWORDS, "#cce5ff")  # blue
    highlighted = highlight_keywords(highlighted, DATE_KEYWORDS, "#d4edda")  # green
    highlighted = highlight_keywords(highlighted, OPERATIONAL_KEYWORDS, "#fff3cd")  # yellow
    return highlighted
166
+
167
def create_keyword_table(detected_keywords):
    """Render the detected keywords as an HTML table, one row per category.

    Categories without any hit are left out.  At most five keywords are listed
    per row, followed by ``...`` when more were found; the count column always
    shows the full number.  The "Phishy Keywords" row gets a red background.
    """
    pieces = ["""
    <table style="width:100%; border-collapse: collapse;">
    <tr style="background-color: #f2f2f2;">
    <th style="padding: 8px; border: 1px solid #ddd;">Category</th>
    <th style="padding: 8px; border: 1px solid #ddd;">Detected Keywords</th>
    <th style="padding: 8px; border: 1px solid #ddd;">Count</th>
    </tr>
    """]

    for category, hits in detected_keywords.items():
        total = len(hits)
        if total <= 0:
            continue

        if category == "Phishy Keywords":
            row_colour = "#ffcccc"
        else:
            row_colour = "#ffffff"
        listed = ', '.join(hits[:5])
        overflow = '...' if total > 5 else ''

        pieces.append(f"""
        <tr style="background-color: {row_colour};">
        <td style="padding: 8px; border: 1px solid #ddd;"><strong>{category}</strong></td>
        <td style="padding: 8px; border: 1px solid #ddd;">{listed}{overflow}</td>
        <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{total}</td>
        </tr>
        """)

    pieces.append("</table>")
    return "".join(pieces)
191
+
192
def create_decision_output(result):
    """Render the classification verdict as an HTML snippet.

    A ``prediction`` of ``-1`` yields a short "invalid input" notice.
    Otherwise the snippet is a bordered box (red for "Phishing", green for
    anything else) containing the decision, the confidence as a percentage
    with one decimal, and a one-sentence explanation.
    """
    if result["prediction"] == -1:
        return "<strong style='color:orange'>Invalid input</strong>: Empty email text"

    verdict = result["decision"]
    if verdict == "Phishing":
        color = "red"
        explanation = "This email contains suspicious characteristics commonly found in phishing attempts."
    else:
        color = "green"
        explanation = "This email appears to be legitimate based on its content and characteristics."

    confidence_pct = result["confidence"] * 100

    return f"""
    <div style='border: 2px solid {color}; padding: 15px; border-radius: 5px;'>
    <h2 style='color:{color}; margin-top: 0;'>Decision: {verdict}</h2>
    <p><strong>Confidence:</strong> {confidence_pct:.1f}%</p>
    <p><strong>Explanation:</strong> {explanation}</p>
    </div>
    """
210
+
211
def analyze_email(email_text):
    """Run the detector on ``email_text`` and build the result tabs.

    The function creates Gradio components (tabs with HTML, a plot and
    Markdown) as a side effect and returns nothing, so it has to run in a
    context where newly created components are attached to the page.

    When the detector reports invalid input (``prediction == -1``) only the
    notice is shown: there is nothing meaningful to plot or tabulate, and the
    invalid result is not guaranteed to carry feature data.
    """
    result = detect_phishing(email_text)

    if result["prediction"] == -1:
        gr.HTML(create_decision_output(result))
        return

    with gr.Tabs():
        with gr.TabItem("Decision"):
            gr.HTML(create_decision_output(result))

        with gr.TabItem("Highlighted Text"):
            highlighted = create_highlighted_text(email_text)
            gr.HTML(f"""
            <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px; background-color: white;'>
            <h3 style='margin-top: 0;'>Email Content with Detected Features</h3>
            <div style='background-color: #f9f9f9; padding: 10px; border: 1px solid #eee;'>
            {highlighted}
            </div>
            <div style='margin-top: 15px;'>
            <span style='display: inline-block; width: 15px; height: 15px; background-color: #ffcccc; margin-right: 5px;'></span> Phishy Keywords
            <span style='display: inline-block; width: 15px; height: 15px; background-color: #cce5ff; margin-right: 5px; margin-left: 10px;'></span> Attachment References
            <span style='display: inline-block; width: 15px; height: 15px; background-color: #d4edda; margin-right: 5px; margin-left: 10px;'></span> Date Mentions
            <span style='display: inline-block; width: 15px; height: 15px; background-color: #fff3cd; margin-right: 5px; margin-left: 10px;'></span> Operational Terms
            </div>
            </div>
            """)

        with gr.TabItem("Detected Features"):
            fig = create_feature_plot(result["features"])
            gr.Plot(fig)

        with gr.TabItem("Keyword Analysis"):
            table_html = create_keyword_table(result["detected_keywords"])
            gr.HTML(f"""
            <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px;'>
            <h3 style='margin-top: 0;'>Detected Keywords by Category</h3>
            {table_html}
            </div>
            """)

        with gr.TabItem("About"):
            gr.Markdown("""
            ## COS 720: Email Phishing Detector

            This tool analyzes emails to detect potential phishing attempts using:
            - **Text content analysis** with DistilBERT model
            - **Structural features** like length and token statistics
            - **Keyword detection** for known phishing indicators

            **How to use:**
            1. Paste email text in the input box
            2. Click "Analyze Email"
            3. Explore the different tabs for detailed analysis

            **Disclaimer:** This is a research tool and may produce false positives/negatives.
            Always use additional verification methods for important communications.
            """)
265
 
 
266
  examples = [
267
  ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
268
  ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
 
271
  ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
272
  ]
273
 
274
# Page layout: input box, button and examples on the left; the analysis on the
# right.
with gr.Blocks(title="COS 720: Email Phishing Detector", theme="soft") as app:
    gr.Markdown("# COS 720: Email Phishing Detector")
    gr.Markdown("A lightweight AI-powered phishing email detector that analyzes text and metadata to classify emails with explainable insights.")

    with gr.Row():
        with gr.Column():
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste the email content here...",
                lines=8,
                elem_id="email-input"
            )
            analyze_btn = gr.Button("Analyze Email", variant="primary")
            gr.Examples(
                examples=examples,
                inputs=email_input,
                label="Try these examples:"
            )

        with gr.Column():
            # analyze_email() builds its result components itself and returns
            # nothing, so it cannot feed a fixed output component through a
            # plain click handler.  A render block re-runs it on every button
            # click and places whatever it creates in this column.
            @gr.render(inputs=email_input, triggers=[analyze_btn.click])
            def render_analysis(email_text):
                analyze_email(email_text)

app.launch()