revert
Browse files
app.py
CHANGED
@@ -1,139 +1,23 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import pipeline
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
self.dropout = nn.Dropout(config.seq_classif_dropout)
|
15 |
-
self.post_init()
|
16 |
-
|
17 |
-
def forward(
|
18 |
-
self, input_ids=None, attention_mask=None, labels=None,
|
19 |
-
text_length=None, token_count=None, avg_token_length=None,
|
20 |
-
num_date_tokens=None, has_attachment_reference=None,
|
21 |
-
has_operational_keywords=None, has_phishy_keywords=None
|
22 |
-
):
|
23 |
-
output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
|
24 |
-
pooled_output = output.last_hidden_state[:, 0]
|
25 |
-
pooled_output = self.pre_classifier(pooled_output)
|
26 |
-
pooled_output = nn.ReLU()(pooled_output)
|
27 |
-
pooled_output = self.dropout(pooled_output)
|
28 |
-
|
29 |
-
additional_features = torch.stack([
|
30 |
-
text_length.float(), token_count.float(), avg_token_length.float(),
|
31 |
-
num_date_tokens.float(), has_attachment_reference.float(),
|
32 |
-
has_operational_keywords.float(), has_phishy_keywords.float()
|
33 |
-
], dim=1)
|
34 |
-
|
35 |
-
additional_features = (additional_features - additional_features.mean(dim=0)) / (additional_features.std(dim=0) + 1e-8)
|
36 |
-
|
37 |
-
combined = torch.cat((pooled_output, additional_features), dim=1)
|
38 |
-
logits = self.classifier(combined)
|
39 |
-
|
40 |
-
return logits
|
41 |
-
|
42 |
-
config = DistilBertConfig.from_pretrained("zionia/email-phishing-detector")
|
43 |
-
model = DistilBertForSequenceClassificationWithFeatures.from_pretrained("zionia/email-phishing-detector", config=config)
|
44 |
-
tokenizer = DistilBertTokenizer.from_pretrained("zionia/email-phishing-detector")
|
45 |
-
|
46 |
-
DATE_KEYWORDS = {"jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec",
|
47 |
-
"january", "february", "march", "april", "may", "june", "july", "august",
|
48 |
-
"september", "october", "november", "december",
|
49 |
-
*map(str, range(2001, 2026))}
|
50 |
-
|
51 |
-
PHISHY_KEYWORDS = {"verify", "urgent", "login", "click", "bank", "account", "update", "password",
|
52 |
-
"security", "alert", "confirm", "immediately"}
|
53 |
-
|
54 |
-
ATTACHMENT_KEYWORDS = {".xls", ".xlsx", ".pdf", ".doc", ".docx", "attachment", "attached", "file"}
|
55 |
-
|
56 |
-
OPERATIONAL_KEYWORDS = {"nom", "actual", "vols", "schedule", "attached", "report", "data", "summary"}
|
57 |
-
|
58 |
-
def explain_features(email_text):
|
59 |
-
lower_text = email_text.lower()
|
60 |
-
words = lower_text.split()
|
61 |
-
|
62 |
-
def highlight_keywords(text, keywords, colour):
|
63 |
-
for kw in keywords:
|
64 |
-
pattern = re.compile(rf"\b({re.escape(kw)})\b", re.IGNORECASE)
|
65 |
-
text = pattern.sub(rf"<mark style='background-color:{colour}; font-weight:bold'>\1</mark>", text)
|
66 |
-
return text
|
67 |
-
|
68 |
-
highlighted_text = email_text
|
69 |
-
highlighted_text = highlight_keywords(highlighted_text, PHISHY_KEYWORDS, "#ffcccc") # red
|
70 |
-
highlighted_text = highlight_keywords(highlighted_text, ATTACHMENT_KEYWORDS, "#cce5ff") # blue
|
71 |
-
highlighted_text = highlight_keywords(highlighted_text, DATE_KEYWORDS, "#d4edda") # green
|
72 |
-
highlighted_text = highlight_keywords(highlighted_text, OPERATIONAL_KEYWORDS, "#fff3cd") # yellow
|
73 |
-
|
74 |
-
features_detected = {
|
75 |
-
"Phishy keywords": [kw for kw in PHISHY_KEYWORDS if kw in lower_text],
|
76 |
-
"Attachment refs": [kw for kw in ATTACHMENT_KEYWORDS if kw in lower_text],
|
77 |
-
"Operational terms": [kw for kw in OPERATIONAL_KEYWORDS if kw in lower_text],
|
78 |
-
"Date mentions": [kw for kw in DATE_KEYWORDS if kw in words],
|
79 |
-
}
|
80 |
-
|
81 |
-
return highlighted_text, features_detected
|
82 |
-
|
83 |
-
|
84 |
-
def detect_and_explain(email_text):
|
85 |
-
inputs = tokenizer(email_text, return_tensors="pt", truncation=True, padding="max_length", max_length=256)
|
86 |
-
|
87 |
-
lower_text = email_text.lower()
|
88 |
-
tokens = lower_text.split()
|
89 |
-
token_count = len(tokens)
|
90 |
|
91 |
-
if
|
92 |
-
return "
|
93 |
-
|
94 |
-
features = {
|
95 |
-
'text_length': torch.tensor([len(email_text)], dtype=torch.float32),
|
96 |
-
'token_count': torch.tensor([token_count], dtype=torch.float32),
|
97 |
-
'avg_token_length': torch.tensor([sum(len(tok) for tok in tokens) / max(token_count, 1)], dtype=torch.float32),
|
98 |
-
'num_date_tokens': torch.tensor([sum(1 for tok in tokens if tok in DATE_KEYWORDS)], dtype=torch.float32),
|
99 |
-
'has_attachment_reference': torch.tensor([float(any(k in lower_text for k in ATTACHMENT_KEYWORDS))], dtype=torch.float32),
|
100 |
-
'has_operational_keywords': torch.tensor([float(any(k in lower_text for k in OPERATIONAL_KEYWORDS))], dtype=torch.float32),
|
101 |
-
'has_phishy_keywords': torch.tensor([float(any(k in lower_text for k in PHISHY_KEYWORDS))], dtype=torch.float32),
|
102 |
-
}
|
103 |
-
|
104 |
-
with torch.no_grad():
|
105 |
-
outputs = model(
|
106 |
-
input_ids=inputs['input_ids'],
|
107 |
-
attention_mask=inputs['attention_mask'],
|
108 |
-
**features
|
109 |
-
)
|
110 |
-
|
111 |
-
if isinstance(outputs, torch.Tensor):
|
112 |
-
logits = outputs
|
113 |
-
else:
|
114 |
-
logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
|
115 |
-
|
116 |
-
probs = torch.nn.functional.softmax(logits, dim=1)
|
117 |
-
confidence, pred = torch.max(probs, dim=1)
|
118 |
-
confidence = confidence.item()
|
119 |
-
pred = pred.item()
|
120 |
-
|
121 |
-
highlight_html, features_dict = explain_features(email_text)
|
122 |
-
|
123 |
-
if pred == 1:
|
124 |
-
decision = f"<strong style='color:red'>Phishing detected!</strong> (confidence: {confidence:.2%})"
|
125 |
else:
|
126 |
-
|
127 |
-
|
128 |
-
feature_html = "<br><u><strong>Detected Indicators:</strong></u><ul>"
|
129 |
-
for category, items in features_dict.items():
|
130 |
-
if items:
|
131 |
-
item_str = ", ".join(items[:5]) # Limit to 5 items per category
|
132 |
-
feature_html += f"<li><strong>{category}:</strong> {item_str}</li>"
|
133 |
-
feature_html += "</ul>"
|
134 |
-
|
135 |
-
return f"{decision}<br><br>{feature_html}<u><strong>Email Highlight View:</strong></u><br>{highlight_html}"
|
136 |
|
|
|
137 |
examples = [
|
138 |
["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
|
139 |
["Hi team, please review the attached document for our quarterly meeting tomorrow."],
|
@@ -142,12 +26,13 @@ examples = [
|
|
142 |
["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
|
143 |
]
|
144 |
|
|
|
145 |
app = gr.Interface(
|
146 |
-
fn=
|
147 |
inputs=gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=5),
|
148 |
-
outputs=gr.
|
149 |
-
title="
|
150 |
-
description="
|
151 |
examples=examples,
|
152 |
theme="soft"
|
153 |
)
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import pipeline
|
3 |
+
|
4 |
+
odel = AutoModelForSequenceClassification.from_pretrained("zionia/email-phishing-detector")
|
5 |
+
tokenizer = AutoTokenizer.from_pretrained("zionia/email-phishing-detector")
|
6 |
+
|
7 |
+
def detect_phishing(email_text):
|
8 |
+
"""
|
9 |
+
Detect if the input email text is phishing or not
|
10 |
+
"""
|
11 |
+
result = model(email_text)
|
12 |
+
label = result[0]['label']
|
13 |
+
score = result[0]['score']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
if label == "LABEL_1":
|
16 |
+
return f"Phishing detected! (confidence: {score:.2%})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
else:
|
18 |
+
return f"Legitimate email (confidence: {score:.2%})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
# Example emails
|
21 |
examples = [
|
22 |
["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
|
23 |
["Hi team, please review the attached document for our quarterly meeting tomorrow."],
|
|
|
26 |
["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
|
27 |
]
|
28 |
|
29 |
+
# Create the Gradio interface
|
30 |
app = gr.Interface(
|
31 |
+
fn=detect_phishing,
|
32 |
inputs=gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=5),
|
33 |
+
outputs=gr.Textbox(label="Detection Result"),
|
34 |
+
title="Email Phishing Detector",
|
35 |
+
description="Detect whether an email is phishing or legitimate using AI.",
|
36 |
examples=examples,
|
37 |
theme="soft"
|
38 |
)
|