update app.py
Browse files
app.py
CHANGED
@@ -1,23 +1,268 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def detect_phishing(email_text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
"""
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"""
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
# Example emails
|
21 |
examples = [
|
22 |
["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
|
23 |
["Hi team, please review the attached document for our quarterly meeting tomorrow."],
|
@@ -26,15 +271,32 @@ examples = [
|
|
26 |
["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
|
27 |
]
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
app.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DistilBertConfig, DistilBertModel, DistilBertPreTrainedModel
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import re
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
|
9 |
+
class DistilBertForSequenceClassificationWithFeatures(DistilBertPreTrainedModel):
    """DistilBERT classifier whose final layer also receives seven scalar features.

    The hidden state of the first token is passed through ``pre_classifier``,
    ReLU and dropout, concatenated with the seven extra features, and mapped to
    ``config.num_labels`` logits by ``classifier``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # "+ 7": one extra input per hand-crafted feature accepted by forward().
        self.classifier = nn.Linear(config.dim + 7, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        self.post_init()

    def forward(
        self, input_ids=None, attention_mask=None, labels=None,
        text_length=None, token_count=None, avg_token_length=None,
        num_date_tokens=None, has_attachment_reference=None,
        has_operational_keywords=None, has_phishy_keywords=None
    ):
        """Return classification logits for a batch of emails.

        Args:
            input_ids: Token ids forwarded to the DistilBERT encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Accepted for signature compatibility; not used here.
            text_length, token_count, avg_token_length, num_date_tokens,
            has_attachment_reference, has_operational_keywords,
            has_phishy_keywords: One tensor per feature. They are stacked
                along dim 1 and concatenated with the pooled encoder output,
                so each is expected to hold one value per batch item.

        Returns:
            The raw logits produced by ``classifier`` (no softmax applied).
        """
        output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = output.last_hidden_state[:, 0]
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        additional_features = torch.stack([
            text_length.float(), token_count.float(), avg_token_length.float(),
            num_date_tokens.float(), has_attachment_reference.float(),
            has_operational_keywords.float(), has_phishy_keywords.float()
        ], dim=1)

        # Normalize each feature with the statistics of the current batch.
        batch_mean = additional_features.mean(dim=0)
        if additional_features.size(0) > 1:
            batch_std = additional_features.std(dim=0)
        else:
            # The unbiased std of a single sample is NaN, which would turn
            # every logit into NaN. With one sample the centred features are
            # all zero anyway, so use a zero spread and let the epsilon keep
            # the division finite.
            batch_std = torch.zeros_like(batch_mean)
        additional_features = (additional_features - batch_mean) / (batch_std + 1e-8)

        combined = torch.cat((pooled_output, additional_features), dim=1)
        logits = self.classifier(combined)

        return logits
|
44 |
+
|
45 |
+
# Load the config, the custom classification model and the tokenizer, all
# from the same checkpoint identifier.
_MODEL_ID = "zionia/email-phishing-detector"

config = DistilBertConfig.from_pretrained(_MODEL_ID)
model = DistilBertForSequenceClassificationWithFeatures.from_pretrained(
    _MODEL_ID,
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
|
49 |
|
50 |
+
# Keyword vocabularies (all lower-case) shared by feature extraction and
# highlighting.

# Month abbreviations, full month names ("may" is both), and the years
# 2001-2025 as strings.
DATE_KEYWORDS = {
    "jan", "feb", "mar", "apr", "may", "jun",
    "jul", "aug", "sep", "oct", "nov", "dec",
    "january", "february", "march", "april", "june", "july",
    "august", "september", "october", "november", "december",
    *(str(year) for year in range(2001, 2026)),
}

# Single words and multi-word phrases treated as phishing indicators.
PHISHY_KEYWORDS = {
    "verify", "urgent", "login", "click", "bank", "account", "update",
    "password", "security", "alert", "confirm", "immediately", "suspended",
    "action required", "verify your account", "limited time",
    "unauthorized access",
}

# File extensions and words that refer to an attachment.
ATTACHMENT_KEYWORDS = {
    ".xls", ".xlsx", ".pdf", ".doc", ".docx",
    "attachment", "attached", "file", "document",
}

# Routine business vocabulary; "attached" is also an attachment keyword.
OPERATIONAL_KEYWORDS = {
    "nom", "actual", "vols", "schedule", "attached", "report", "data",
    "summary", "meeting", "agenda", "minutes", "review", "quarterly",
    "project",
}
|
64 |
+
|
65 |
+
def highlight_keywords(text, keywords, colour):
    """Wrap whole-word keyword occurrences in a coloured ``<mark>`` tag.

    Matching is case-insensitive and anchored on word boundaries. All keywords
    are matched in one pass with the longest alternatives tried first, so a
    phrase such as "verify your account" is wrapped once; its inner words
    ("verify", "account") are not wrapped a second time inside that tag.

    Args:
        text: The text to annotate.
        keywords: Iterable of literal keywords or phrases; they are
            regex-escaped before use.
        colour: CSS colour used as the mark's background.

    Returns:
        ``text`` with every match wrapped, or ``text`` unchanged when there
        are no keywords to look for.
    """
    # Longest first so multi-word phrases win over the words they contain.
    ordered = sorted((kw for kw in keywords if kw), key=len, reverse=True)
    if not ordered:
        return text

    alternation = "|".join(re.escape(kw) for kw in ordered)
    pattern = re.compile(rf"\b({alternation})\b", re.IGNORECASE)
    opening_tag = f"<mark style='background-color:{colour}; font-weight:bold'>"

    # A callable replacement keeps `colour` out of the regex template syntax.
    return pattern.sub(lambda match: f"{opening_tag}{match.group(1)}</mark>", text)
|
70 |
+
|
71 |
+
def extract_features(email_text):
    """Derive the numeric features and the matched keywords for one email.

    Date keywords are matched against whitespace-separated tokens. The other
    three vocabularies are matched as plain substrings of the lower-cased
    text, so a keyword also counts when it occurs inside a longer word.

    Args:
        email_text: Raw email text.

    Returns:
        A ``(features, detected_keywords)`` tuple. ``features`` maps the seven
        feature names to numbers; ``detected_keywords`` maps a display
        category to the list of keywords found for it.
    """
    lowered = email_text.lower()
    tokens = lowered.split()
    n_tokens = len(tokens)

    def _found_in_text(vocabulary):
        # Substring search over the whole lower-cased text.
        return [term for term in vocabulary if term in lowered]

    phishy_hits = _found_in_text(PHISHY_KEYWORDS)
    attachment_hits = _found_in_text(ATTACHMENT_KEYWORDS)
    operational_hits = _found_in_text(OPERATIONAL_KEYWORDS)
    date_hits = [term for term in DATE_KEYWORDS if term in tokens]

    total_token_chars = 0
    date_token_total = 0
    for token in tokens:
        total_token_chars += len(token)
        if token in DATE_KEYWORDS:
            date_token_total += 1

    features = {
        'text_length': len(email_text),
        'token_count': n_tokens,
        # max(..., 1) avoids dividing by zero for text without tokens.
        'avg_token_length': total_token_chars / max(n_tokens, 1),
        'num_date_tokens': date_token_total,
        'has_attachment_reference': float(bool(attachment_hits)),
        'has_operational_keywords': float(bool(operational_hits)),
        'has_phishy_keywords': float(bool(phishy_hits)),
    }

    detected_keywords = {
        "Phishy Keywords": phishy_hits,
        "Attachment References": attachment_hits,
        "Operational Terms": operational_hits,
        "Date Mentions": date_hits,
    }

    return features, detected_keywords
|
94 |
+
|
95 |
+
def create_feature_plot(features):
    """Draw a horizontal bar chart of the seven extracted features.

    Bar lengths are min-max normalized across the seven values so they share
    one 0-1 axis; the label next to each bar shows the raw value.

    Args:
        features: Mapping that contains the seven feature names read below.

    Returns:
        A matplotlib ``Figure``.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps a
    # reference to every figure it creates until it is closed, and this
    # function hands the figure to its caller without ever closing it.
    from matplotlib.figure import Figure

    feature_names = [
        'Text Length', 'Token Count', 'Avg Token Length',
        'Date Keywords', 'Attachment Ref', 'Operational Terms', 'Phishy Terms'
    ]
    feature_keys = [
        'text_length', 'token_count', 'avg_token_length',
        'num_date_tokens', 'has_attachment_reference',
        'has_operational_keywords', 'has_phishy_keywords'
    ]
    values = [features[key] for key in feature_keys]

    # Normalize the values for better visualization; the epsilon keeps the
    # division defined when all values are equal.
    lowest = min(values)
    spread = max(values) - lowest + 1e-8
    normalized_values = [(value - lowest) / spread for value in values]

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    bars = ax.barh(feature_names, normalized_values, color='skyblue')
    ax.set_xlim(0, 1)
    ax.set_title('Normalized Feature Values')

    # Add value labels (raw, un-normalized values).
    for bar, val in zip(bars, values):
        ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
                f'{val:.1f}', va='center')

    fig.tight_layout()
    return fig
|
121 |
+
|
122 |
def detect_phishing(email_text):
    """Classify one email and report the evidence behind the decision.

    Args:
        email_text: Raw email text. ``None`` is treated like an empty string.

    Returns:
        A dict with the keys ``decision``, ``confidence``, ``prediction``,
        ``features`` and ``detected_keywords``. For empty or whitespace-only
        input ``prediction`` is -1 and ``confidence`` is 0; the feature and
        keyword entries are still present so callers can read every key
        without checking for the invalid case first.
    """
    text = email_text or ""
    features, detected_keywords = extract_features(text)

    if not text.strip():
        return {
            "decision": "Invalid input: Empty email text",
            "confidence": 0,
            "prediction": -1,
            "features": features,
            "detected_keywords": detected_keywords,
        }

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=256
    )

    # Convert features to tensors (one-element batch per feature).
    feature_tensors = {k: torch.tensor([v], dtype=torch.float32) for k, v in features.items()}

    with torch.no_grad():
        logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            **feature_tensors
        )

    probs = torch.nn.functional.softmax(logits, dim=1)
    confidence, pred = torch.max(probs, dim=1)
    confidence = confidence.item()
    pred = pred.item()

    return {
        # Class index 1 is reported as phishing; any other index as legitimate.
        "decision": "Phishing" if pred == 1 else "Legitimate",
        "confidence": confidence,
        "prediction": pred,
        "features": features,
        "detected_keywords": detected_keywords
    }
|
158 |
+
|
159 |
+
def create_highlighted_text(email_text):
    """Return the email as HTML with keywords of each category colour-marked.

    The email text is untrusted, so it is HTML-escaped before any ``<mark>``
    tags are added; markup contained in the email is shown as text instead of
    being interpreted by the browser.

    Args:
        email_text: Raw email text.

    Returns:
        An HTML fragment containing the escaped text plus highlight tags.
    """
    import html

    # quote=False: the result is used as element content, not inside an
    # attribute, so only &, < and > need escaping.
    highlighted = html.escape(email_text, quote=False)

    colour_by_category = (
        (PHISHY_KEYWORDS, "#ffcccc"),       # red
        (ATTACHMENT_KEYWORDS, "#cce5ff"),   # blue
        (DATE_KEYWORDS, "#d4edda"),         # green
        (OPERATIONAL_KEYWORDS, "#fff3cd"),  # yellow
    )
    for keywords, colour in colour_by_category:
        highlighted = highlight_keywords(highlighted, keywords, colour)
    return highlighted
|
166 |
+
|
167 |
+
def create_keyword_table(detected_keywords):
    """Render the detected keywords as an HTML table, one row per category.

    Categories without any keyword are left out. At most five keywords are
    listed per row, followed by "..." when there are more; the count column
    always shows the full number.

    Args:
        detected_keywords: Mapping of category name to a list of keywords.

    Returns:
        The table as an HTML string.
    """
    cell_style = "padding: 8px; border: 1px solid #ddd;"

    parts = [f"""
    <table style="width:100%; border-collapse: collapse;">
        <tr style="background-color: #f2f2f2;">
            <th style="{cell_style}">Category</th>
            <th style="{cell_style}">Detected Keywords</th>
            <th style="{cell_style}">Count</th>
        </tr>
    """]

    for category, keywords in detected_keywords.items():
        count = len(keywords)
        if count == 0:
            continue

        # Only the phishing row gets a tinted background.
        if category == "Phishy Keywords":
            color = "#ffcccc"
        else:
            color = "#ffffff"

        listed = ', '.join(keywords[:5])
        if count > 5:
            listed += '...'

        parts.append(f"""
        <tr style="background-color: {color};">
            <td style="{cell_style}"><strong>{category}</strong></td>
            <td style="{cell_style}">{listed}</td>
            <td style="{cell_style} text-align: center;">{count}</td>
        </tr>
        """)

    parts.append("</table>")
    return "".join(parts)
|
191 |
+
|
192 |
+
def create_decision_output(result):
    """Format the classification result as an HTML summary card.

    Args:
        result: Dict with ``prediction`` and, unless ``prediction`` is -1,
            ``decision`` and ``confidence`` (a fraction that is shown as a
            percentage).

    Returns:
        An HTML string: a short warning for the invalid-input marker (-1),
        otherwise a bordered card that is red for "Phishing" and green for
        any other decision.
    """
    if result["prediction"] == -1:
        return "<strong style='color:orange'>Invalid input</strong>: Empty email text"

    decision = result["decision"]
    if decision == "Phishing":
        color = "red"
        explanation = "This email contains suspicious characteristics commonly found in phishing attempts."
    else:
        color = "green"
        explanation = "This email appears to be legitimate based on its content and characteristics."

    confidence_pct = result["confidence"] * 100

    return f"""
    <div style='border: 2px solid {color}; padding: 15px; border-radius: 5px;'>
        <h2 style='color:{color}; margin-top: 0;'>Decision: {decision}</h2>
        <p><strong>Confidence:</strong> {confidence_pct:.1f}%</p>
        <p><strong>Explanation:</strong> {explanation}</p>
    </div>
    """
|
210 |
+
|
211 |
+
def analyze_email(email_text):
    """Build the tabbed analysis view (Gradio components) for one email.

    The "Decision" and "About" tabs are always created. The highlighted text,
    feature plot and keyword table tabs are created only for a valid result:
    an invalid-input result (``prediction == -1``) is not guaranteed to carry
    ``features`` or ``detected_keywords``, and indexing them would raise
    ``KeyError``.

    Args:
        email_text: Raw email text entered by the user.

    Returns:
        None. The components are created as a side effect.
        NOTE(review): nothing is returned for an ``outputs=`` component, so
        whether these components are ever displayed depends on how the
        function is registered with Gradio — confirm against the UI wiring.
    """
    result = detect_phishing(email_text)
    has_analysis = result["prediction"] != -1

    with gr.Tabs() as tabs:
        with gr.TabItem("Decision"):
            gr.HTML(create_decision_output(result))

        if has_analysis:
            with gr.TabItem("Highlighted Text"):
                highlighted = create_highlighted_text(email_text)
                gr.HTML(f"""
                <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px; background-color: white;'>
                    <h3 style='margin-top: 0;'>Email Content with Detected Features</h3>
                    <div style='background-color: #f9f9f9; padding: 10px; border: 1px solid #eee;'>
                        {highlighted}
                    </div>
                    <div style='margin-top: 15px;'>
                        <span style='display: inline-block; width: 15px; height: 15px; background-color: #ffcccc; margin-right: 5px;'></span> Phishy Keywords
                        <span style='display: inline-block; width: 15px; height: 15px; background-color: #cce5ff; margin-right: 5px; margin-left: 10px;'></span> Attachment References
                        <span style='display: inline-block; width: 15px; height: 15px; background-color: #d4edda; margin-right: 5px; margin-left: 10px;'></span> Date Mentions
                        <span style='display: inline-block; width: 15px; height: 15px; background-color: #fff3cd; margin-right: 5px; margin-left: 10px;'></span> Operational Terms
                    </div>
                </div>
                """)

            with gr.TabItem("Detected Features"):
                fig = create_feature_plot(result["features"])
                gr.Plot(fig)

            with gr.TabItem("Keyword Analysis"):
                table_html = create_keyword_table(result["detected_keywords"])
                gr.HTML(f"""
                <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px;'>
                    <h3 style='margin-top: 0;'>Detected Keywords by Category</h3>
                    {table_html}
                </div>
                """)

        with gr.TabItem("About"):
            gr.Markdown("""
            ## COS 720: Email Phishing Detector

            This tool analyzes emails to detect potential phishing attempts using:
            - **Text content analysis** with DistilBERT model
            - **Structural features** like length and token statistics
            - **Keyword detection** for known phishing indicators

            **How to use:**
            1. Paste email text in the input box
            2. Click "Analyze Email"
            3. Explore the different tabs for detailed analysis

            **Disclaimer:** This is a research tool and may produce false positives/negatives.
            Always use additional verification methods for important communications.
            """)
|
265 |
|
|
|
266 |
examples = [
|
267 |
["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
|
268 |
["Hi team, please review the attached document for our quarterly meeting tomorrow."],
|
|
|
271 |
["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
|
272 |
]
|
273 |
|
274 |
+
# Page layout. The result components are declared here, up front, and the
# click handler returns one value per component; an event handler can only
# update components that are part of the layout and listed in `outputs`.
with gr.Blocks(title="COS 720: Email Phishing Detector", theme="soft") as app:
    gr.Markdown("# COS 720: Email Phishing Detector")
    gr.Markdown("A lightweight AI-powered phishing email detector that analyzes text and metadata to classify emails with explainable insights.")

    with gr.Row():
        with gr.Column():
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste the email content here...",
                lines=8,
                elem_id="email-input"
            )
            analyze_btn = gr.Button("Analyze Email", variant="primary")
            gr.Examples(
                examples=examples,
                inputs=email_input,
                label="Try these examples:"
            )

        with gr.Column():
            with gr.Tabs() as analysis_output:
                with gr.TabItem("Decision"):
                    decision_view = gr.HTML()

                with gr.TabItem("Highlighted Text"):
                    highlighted_view = gr.HTML()

                with gr.TabItem("Detected Features"):
                    feature_view = gr.Plot()

                with gr.TabItem("Keyword Analysis"):
                    keyword_view = gr.HTML()

                with gr.TabItem("About"):
                    gr.Markdown("""
                    ## COS 720: Email Phishing Detector

                    This tool analyzes emails to detect potential phishing attempts using:
                    - **Text content analysis** with DistilBERT model
                    - **Structural features** like length and token statistics
                    - **Keyword detection** for known phishing indicators

                    **How to use:**
                    1. Paste email text in the input box
                    2. Click "Analyze Email"
                    3. Explore the different tabs for detailed analysis

                    **Disclaimer:** This is a research tool and may produce false positives/negatives.
                    Always use additional verification methods for important communications.
                    """)

    def _render_analysis(email_text):
        """Return the decision, highlight, plot and keyword-table values."""
        result = detect_phishing(email_text)
        decision_html = create_decision_output(result)

        # Invalid input: show only the message and clear the other tabs. The
        # result may not contain feature/keyword data in this case.
        if result["prediction"] == -1:
            return decision_html, "", None, ""

        highlighted = create_highlighted_text(email_text)
        highlighted_html = f"""
        <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px; background-color: white;'>
            <h3 style='margin-top: 0;'>Email Content with Detected Features</h3>
            <div style='background-color: #f9f9f9; padding: 10px; border: 1px solid #eee;'>
                {highlighted}
            </div>
            <div style='margin-top: 15px;'>
                <span style='display: inline-block; width: 15px; height: 15px; background-color: #ffcccc; margin-right: 5px;'></span> Phishy Keywords
                <span style='display: inline-block; width: 15px; height: 15px; background-color: #cce5ff; margin-right: 5px; margin-left: 10px;'></span> Attachment References
                <span style='display: inline-block; width: 15px; height: 15px; background-color: #d4edda; margin-right: 5px; margin-left: 10px;'></span> Date Mentions
                <span style='display: inline-block; width: 15px; height: 15px; background-color: #fff3cd; margin-right: 5px; margin-left: 10px;'></span> Operational Terms
            </div>
        </div>
        """

        table_html = create_keyword_table(result["detected_keywords"])
        keyword_html = f"""
        <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px;'>
            <h3 style='margin-top: 0;'>Detected Keywords by Category</h3>
            {table_html}
        </div>
        """

        return (
            decision_html,
            highlighted_html,
            create_feature_plot(result["features"]),
            keyword_html,
        )

    analyze_btn.click(
        fn=_render_analysis,
        inputs=email_input,
        outputs=[decision_view, highlighted_view, feature_view, keyword_view]
    )
|
301 |
|
302 |
app.launch()
|