zionia committed
Commit d332f5d · verified · 1 Parent(s): b52b7c2

update for retrained model

Files changed (1)
  1. app.py +70 -282
app.py CHANGED
@@ -1,302 +1,90 @@
  import gradio as gr
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, DistilBertConfig, DistilBertModel, DistilBertPreTrainedModel
- import torch
- import torch.nn as nn
  import re
- import matplotlib.pyplot as plt
  import numpy as np

- class DistilBertForSequenceClassificationWithFeatures(DistilBertPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-         self.num_labels = config.num_labels
-         self.distilbert = DistilBertModel(config)
-         self.pre_classifier = nn.Linear(config.dim, config.dim)
-         self.classifier = nn.Linear(config.dim + 7, config.num_labels)
-         self.dropout = nn.Dropout(config.seq_classif_dropout)
-         self.post_init()
-
-     def forward(
-         self, input_ids=None, attention_mask=None, labels=None,
-         text_length=None, token_count=None, avg_token_length=None,
-         num_date_tokens=None, has_attachment_reference=None,
-         has_operational_keywords=None, has_phishy_keywords=None
-     ):
-         output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
-         pooled_output = output.last_hidden_state[:, 0]
-         pooled_output = self.pre_classifier(pooled_output)
-         pooled_output = nn.ReLU()(pooled_output)
-         pooled_output = self.dropout(pooled_output)
-
-         additional_features = torch.stack([
-             text_length.float(), token_count.float(), avg_token_length.float(),
-             num_date_tokens.float(), has_attachment_reference.float(),
-             has_operational_keywords.float(), has_phishy_keywords.float()
-         ], dim=1)
-
-         # Normalize features
-         additional_features = (additional_features - additional_features.mean(dim=0)) / (additional_features.std(dim=0) + 1e-8)
-
-         combined = torch.cat((pooled_output, additional_features), dim=1)
-         logits = self.classifier(combined)
-
-         return logits
-
- # Load model and tokenizer
- config = DistilBertConfig.from_pretrained("zionia/email-phishing-detector")
- model = DistilBertForSequenceClassificationWithFeatures.from_pretrained("zionia/email-phishing-detector", config=config)
  tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")

- # Keyword sets
- DATE_KEYWORDS = {"jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec",
-                  "january", "february", "march", "april", "may", "june", "july", "august",
-                  "september", "october", "november", "december",
-                  *map(str, range(2001, 2026))}

- PHISHY_KEYWORDS = {"verify", "urgent", "login", "click", "bank", "account", "update", "password",
-                    "security", "alert", "confirm", "immediately", "action required", "suspended",
-                    "verify your account", "limited time", "unauthorized access"}
-
- ATTACHMENT_KEYWORDS = {".xls", ".xlsx", ".pdf", ".doc", ".docx", "attachment", "attached", "file", "document"}
-
- OPERATIONAL_KEYWORDS = {"nom", "actual", "vols", "schedule", "attached", "report", "data", "summary",
-                         "meeting", "agenda", "minutes", "review", "quarterly", "project"}
-
- def highlight_keywords(text, keywords, colour):
-     for kw in sorted(keywords, key=len, reverse=True): # Sort by length to match longer phrases first
-         pattern = re.compile(rf"\b({re.escape(kw)})\b", re.IGNORECASE)
-         text = pattern.sub(rf"<mark style='background-color:{colour}; font-weight:bold'>\1</mark>", text)
-     return text

  def extract_features(email_text):
-     lower_text = email_text.lower()
-     words = lower_text.split()
-     token_count = len(words)

      features = {
-         'text_length': len(email_text),
-         'token_count': token_count,
-         'avg_token_length': sum(len(tok) for tok in words) / max(token_count, 1),
-         'num_date_tokens': sum(1 for tok in words if tok in DATE_KEYWORDS),
-         'has_attachment_reference': float(any(k in lower_text for k in ATTACHMENT_KEYWORDS)),
-         'has_operational_keywords': float(any(k in lower_text for k in OPERATIONAL_KEYWORDS)),
-         'has_phishy_keywords': float(any(k in lower_text for k in PHISHY_KEYWORDS)),
-     }
-
-     detected_keywords = {
-         "Phishy Keywords": [kw for kw in PHISHY_KEYWORDS if kw in lower_text],
-         "Attachment References": [kw for kw in ATTACHMENT_KEYWORDS if kw in lower_text],
-         "Operational Terms": [kw for kw in OPERATIONAL_KEYWORDS if kw in lower_text],
-         "Date Mentions": [kw for kw in DATE_KEYWORDS if kw in words],
      }

-     return features, detected_keywords

- def create_feature_plot(features):
-     feature_names = [
-         'Text Length', 'Token Count', 'Avg Token Length',
-         'Date Keywords', 'Attachment Ref', 'Operational Terms', 'Phishy Terms'
      ]
-     values = [features[k] for k in [
-         'text_length', 'token_count', 'avg_token_length',
-         'num_date_tokens', 'has_attachment_reference',
-         'has_operational_keywords', 'has_phishy_keywords'
-     ]]
-
-     # Normalize the values for better visualization
-     normalized_values = [(v - min(values)) / (max(values) - min(values) + 1e-8) for v in values]
-
-     fig, ax = plt.subplots(figsize=(10, 4))
-     bars = ax.barh(feature_names, normalized_values, color='skyblue')
-     ax.set_xlim(0, 1)
-     ax.set_title('Normalized Feature Values')
-
-     # Add value labels
-     for bar, val in zip(bars, values):
-         ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
-                 f'{val:.1f}', va='center')
-
-     plt.tight_layout()
-     return fig

- def detect_phishing(email_text):
-     if not email_text.strip():
-         return {"decision": "Invalid input: Empty email text", "confidence": 0, "prediction": -1}
-
-     features, detected_keywords = extract_features(email_text)
-
-     inputs = tokenizer(
-         email_text,
-         return_tensors="pt",
-         truncation=True,
-         padding="max_length",
-         max_length=256
      )
-
-     # Convert features to tensors
-     feature_tensors = {k: torch.tensor([v], dtype=torch.float32) for k, v in features.items()}
-
-     with torch.no_grad():
-         logits = model(
-             input_ids=inputs['input_ids'],
-             attention_mask=inputs['attention_mask'],
-             **feature_tensors
-         )
-
-     probs = torch.nn.functional.softmax(logits, dim=1)
-     confidence, pred = torch.max(probs, dim=1)
-     confidence = confidence.item()
-     pred = pred.item()
-
-     return {
-         "decision": "Phishing" if pred == 1 else "Legitimate",
-         "confidence": confidence,
-         "prediction": pred,
-         "features": features,
-         "detected_keywords": detected_keywords
-     }

- def create_highlighted_text(email_text):
-     highlighted = email_text
-     highlighted = highlight_keywords(highlighted, PHISHY_KEYWORDS, "#ffcccc") # red
-     highlighted = highlight_keywords(highlighted, ATTACHMENT_KEYWORDS, "#cce5ff") # blue
-     highlighted = highlight_keywords(highlighted, DATE_KEYWORDS, "#d4edda") # green
-     highlighted = highlight_keywords(highlighted, OPERATIONAL_KEYWORDS, "#fff3cd") # yellow
-     return highlighted
-
- def create_keyword_table(detected_keywords):
-     table_html = """
-     <table style="width:100%; border-collapse: collapse;">
-         <tr style="background-color: #f2f2f2;">
-             <th style="padding: 8px; border: 1px solid #ddd;">Category</th>
-             <th style="padding: 8px; border: 1px solid #ddd;">Detected Keywords</th>
-             <th style="padding: 8px; border: 1px solid #ddd;">Count</th>
-         </tr>
-     """
-
-     for category, keywords in detected_keywords.items():
-         count = len(keywords)
-         if count > 0:
-             color = "#ffcccc" if category == "Phishy Keywords" else "#ffffff"
-             table_html += f"""
-             <tr style="background-color: {color};">
-                 <td style="padding: 8px; border: 1px solid #ddd;"><strong>{category}</strong></td>
-                 <td style="padding: 8px; border: 1px solid #ddd;">{', '.join(keywords[:5])}{'...' if len(keywords) > 5 else ''}</td>
-                 <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{count}</td>
-             </tr>
-             """
-
-     table_html += "</table>"
-     return table_html

- def create_decision_output(result):
-     if result["prediction"] == -1:
-         return "<strong style='color:orange'>Invalid input</strong>: Empty email text"
-
-     color = "red" if result["decision"] == "Phishing" else "green"
-     confidence_pct = result["confidence"] * 100
-
-     return f"""
-     <div style='border: 2px solid {color}; padding: 15px; border-radius: 5px;'>
-         <h2 style='color:{color}; margin-top: 0;'>Decision: {result["decision"]}</h2>
-         <p><strong>Confidence:</strong> {confidence_pct:.1f}%</p>
-         <p><strong>Explanation:</strong> {
-             "This email contains suspicious characteristics commonly found in phishing attempts."
-             if result["decision"] == "Phishing" else
-             "This email appears to be legitimate based on its content and characteristics."
-         }</p>
-     </div>
-     """
-
- def analyze_email(email_text):
-     result = detect_phishing(email_text)
-
-     with gr.Tabs() as tabs:
-         with gr.TabItem("Decision"):
-             gr.HTML(create_decision_output(result))
-
-         with gr.TabItem("Highlighted Text"):
-             highlighted = create_highlighted_text(email_text)
-             gr.HTML(f"""
-             <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px; background-color: white;'>
-                 <h3 style='margin-top: 0;'>Email Content with Detected Features</h3>
-                 <div style='background-color: #f9f9f9; padding: 10px; border: 1px solid #eee;'>
-                     {highlighted}
-                 </div>
-                 <div style='margin-top: 15px;'>
-                     <span style='display: inline-block; width: 15px; height: 15px; background-color: #ffcccc; margin-right: 5px;'></span> Phishy Keywords
-                     <span style='display: inline-block; width: 15px; height: 15px; background-color: #cce5ff; margin-right: 5px; margin-left: 10px;'></span> Attachment References
-                     <span style='display: inline-block; width: 15px; height: 15px; background-color: #d4edda; margin-right: 5px; margin-left: 10px;'></span> Date Mentions
-                     <span style='display: inline-block; width: 15px; height: 15px; background-color: #fff3cd; margin-right: 5px; margin-left: 10px;'></span> Operational Terms
-                 </div>
-             </div>
-             """)
-
-         with gr.TabItem("Detected Features"):
-             fig = create_feature_plot(result["features"])
-             gr.Plot(fig)
-
-         with gr.TabItem("Keyword Analysis"):
-             table_html = create_keyword_table(result["detected_keywords"])
-             gr.HTML(f"""
-             <div style='border: 1px solid #ddd; padding: 15px; border-radius: 5px;'>
-                 <h3 style='margin-top: 0;'>Detected Keywords by Category</h3>
-                 {table_html}
-             </div>
-             """)
-
-         with gr.TabItem("About"):
-             gr.Markdown("""
-             ## COS 720: Email Phishing Detector
-
-             This tool analyzes emails to detect potential phishing attempts using:
-             - **Text content analysis** with DistilBERT model
-             - **Structural features** like length and token statistics
-             - **Keyword detection** for known phishing indicators
-
-             **How to use:**
-             1. Paste email text in the input box
-             2. Click "Analyze Email"
-             3. Explore the different tabs for detailed analysis
-
-             **Disclaimer:** This is a research tool and may produce false positives/negatives.
-             Always use additional verification methods for important communications.
-             """)
-
- examples = [
-     ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
-     ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
-     ["URGENT: Your PayPal account will be suspended unless you confirm your details now!"],
-     ["Hello John, just following up on our conversation yesterday about the project timeline."],
-     ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
- ]
-
- with gr.Blocks(title="COS 720: Email Phishing Detector", theme="soft") as app:
-     gr.Markdown("# COS 720: Email Phishing Detector")
-     gr.Markdown("A lightweight AI-powered phishing email detector that analyzes text and metadata to classify emails with explainable insights.")
-
-     with gr.Row():
-         with gr.Column():
-             email_input = gr.Textbox(
-                 label="Email Text",
-                 placeholder="Paste the email content here...",
-                 lines=8,
-                 elem_id="email-input"
-             )
-             analyze_btn = gr.Button("Analyze Email", variant="primary")
-             gr.Examples(
-                 examples=examples,
-                 inputs=email_input,
-                 label="Try these examples:"
-             )
-
-         with gr.Column():
-             analysis_output = gr.Tabs()
-
-     analyze_btn.click(
-         fn=analyze_email,
-         inputs=email_input,
-         outputs=analysis_output
-     )

- app.launch()

  import gradio as gr
  import re
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
  import numpy as np

+ model = AutoModelForSequenceClassification.from_pretrained("zionia/email-phishing-detector")
  tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")
+ pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

+ PHISHY_KEYWORDS = ["verify", "urgent", "login", "click", "bank", "account", "update", "password",
+                    "security", "alert", "confirm", "immediately"]
+ ATTACHMENT_KEYWORDS = [".xls", ".xlsx", ".pdf", ".doc", ".docx", "attachment", "attached", "file"]
+ OPERATIONAL_KEYWORDS = ["nom", "actual", "vols", "schedule", "attached", "report", "data", "summary"]
+ DATE_RELATED = {"jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec",
+                 "january", "february", "march", "april", "may", "june", "july", "august",
+                 "september", "october", "november", "december"} | {str(y) for y in range(2001, 2026)}

+ def detect_phishing(email_text):
+     result = pipe(email_text)
+     label = result[0]['label']
+     score = result[0]['score']
+     if label == "LABEL_1":
+         return f"Phishing detected! (Confidence: {score:.2%})"
+     else:
+         return f"Legitimate email (Confidence: {score:.2%})"
+
+ def highlight_suspicious_text(email_text):
+     highlighted = email_text
+     for word in PHISHY_KEYWORDS:
+         pattern = re.compile(rf'\b({re.escape(word)})\b', re.IGNORECASE)
+         highlighted = pattern.sub(r'<mark style="background-color: #ffcccc">\1</mark>', highlighted)
+     return highlighted

  def extract_features(email_text):
+     tokens = email_text.lower().split()
+     token_count = len(tokens)
+     avg_token_len = sum(len(token) for token in tokens) / token_count if token_count > 0 else 0
+     date_tokens = sum(1 for token in tokens if token in DATE_RELATED)
+     attachment_present = any(ext in email_text.lower() for ext in ATTACHMENT_KEYWORDS)
+     operational_terms = any(word in email_text.lower() for word in OPERATIONAL_KEYWORDS)
+     phishy_terms = [word for word in PHISHY_KEYWORDS if word in email_text.lower()]

      features = {
+         "Text Length": len(email_text),
+         "Token Count": token_count,
+         "Avg Token Length": round(avg_token_len, 2),
+         "Date References": date_tokens,
+         "Contains Attachment": "Yes" if attachment_present else "No",
+         "Operational Terms Present": "Yes" if operational_terms else "No",
+         "Suspicious Keywords": ", ".join(phishy_terms) if phishy_terms else "None"
      }

+     feature_str = "\n".join([f"{k}: {v}" for k, v in features.items()])
+     return feature_str

+ with gr.Blocks(title="Email Phishing Detector") as app:
+     gr.Markdown("# Email Phishing Detector")
+     gr.Markdown("Use this tool to analyse suspicious emails. It’ll flag phishing attempts and show you what looks dodgy.")
+
+     with gr.Row():
+         email_input = gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=10)
+
+     with gr.Tabs():
+         with gr.TabItem("Detection"):
+             detection_output = gr.Textbox(label="Result")
+         with gr.TabItem("Suspicious Highlights"):
+             suspicious_output = gr.HTML(label="Suspicious Keywords Highlighted")
+         with gr.TabItem("Feature Breakdown"):
+             feature_output = gr.Textbox(label="Analysed Features", lines=8)
+
+     examples = [
+         ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
+         ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
+         ["URGENT: Your PayPal account will be suspended unless you confirm your details now!"],
+         ["Hello John, just following up on our conversation yesterday about the project timeline."],
+         ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
      ]

+     gr.Examples(
+         examples=examples,
+         inputs=email_input
      )

+     def full_analysis(email_text):
+         return detect_phishing(email_text), highlight_suspicious_text(email_text), extract_features(email_text)

+     email_input.change(fn=full_analysis, inputs=email_input,
+                        outputs=[detection_output, suspicious_output, feature_output])

+ app.launch()
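
Because the updated app.py now relies entirely on the standard `text-classification` pipeline, the retrained checkpoint can also be sanity-checked outside the Gradio UI. The snippet below is a minimal sketch based on the calls in the new code, not part of the commit; it assumes the retrained `zionia/email-phishing-detector` checkpoint keeps the default `LABEL_0`/`LABEL_1` label names that `detect_phishing` maps to legitimate/phishing.

```python
# Minimal sketch: exercise the retrained checkpoint without the Gradio UI.
# Assumes the checkpoint exposes the default LABEL_0 (legitimate) / LABEL_1 (phishing)
# labels, mirroring the mapping used by detect_phishing in the new app.py.
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

model = AutoModelForSequenceClassification.from_pretrained("zionia/email-phishing-detector")
tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

samples = [
    "URGENT: Your PayPal account will be suspended unless you confirm your details now!",
    "Hi team, please review the attached document for our quarterly meeting tomorrow.",
]
for text in samples:
    result = pipe(text)[0]  # e.g. {'label': 'LABEL_1', 'score': 0.97}
    verdict = "Phishing" if result["label"] == "LABEL_1" else "Legitimate"
    print(f"{verdict} ({result['score']:.2%}): {text[:60]}")
```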