zionia committed on
Commit
682e36a
·
verified ·
1 Parent(s): 1583b76
Files changed (1) hide show
  1. app.py +21 -136
app.py CHANGED
@@ -1,139 +1,23 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, DistilBertTokenizer, DistilBertConfig, DistilBertModel, DistilBertPreTrainedModel
3
- import torch
4
- import torch.nn as nn
5
- import re
6
-
7
class DistilBertForSequenceClassificationWithFeatures(DistilBertPreTrainedModel):
    """DistilBERT sequence classifier whose head also sees seven scalar features.

    The first-token hidden state is passed through ``pre_classifier``, ReLU and
    dropout, then concatenated with seven per-example features before the final
    linear ``classifier`` layer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # +7: one extra input per hand-crafted feature stacked in forward().
        self.classifier = nn.Linear(config.dim + 7, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        self.post_init()

    def forward(
        self, input_ids=None, attention_mask=None, labels=None,
        text_length=None, token_count=None, avg_token_length=None,
        num_date_tokens=None, has_attachment_reference=None,
        has_operational_keywords=None, has_phishy_keywords=None
    ):
        """Return the raw classification logits (a tensor, not a ModelOutput).

        ``labels`` is accepted for interface compatibility but is not used;
        no loss is computed here. Each feature argument is a 1-D tensor with
        one entry per example in the batch.
        """
        output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = output.last_hidden_state[:, 0]
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        additional_features = torch.stack([
            text_length.float(), token_count.float(), avg_token_length.float(),
            num_date_tokens.float(), has_attachment_reference.float(),
            has_operational_keywords.float(), has_phishy_keywords.float()
        ], dim=1)

        # Standardise each feature across the batch.
        centered = additional_features - additional_features.mean(dim=0)
        if additional_features.size(0) > 1:
            additional_features = centered / (additional_features.std(dim=0) + 1e-8)
        else:
            # With a single row the sample std (Bessel-corrected) is NaN, which
            # would turn every logit into NaN. The centred row is all zeros, so
            # use it directly.
            additional_features = centered

        combined = torch.cat((pooled_output, additional_features), dim=1)
        logits = self.classifier(combined)

        return logits
41
-
42
# Hub repository that holds the config, fine-tuned weights and tokenizer.
MODEL_ID = "zionia/email-phishing-detector"

# Loaded once at import time so every request reuses the same objects.
config = DistilBertConfig.from_pretrained(MODEL_ID)
model = DistilBertForSequenceClassificationWithFeatures.from_pretrained(MODEL_ID, config=config)
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_ID)
45
-
46
# Month names/abbreviations plus the years 2001-2025, matched against
# whitespace-separated lower-cased tokens.
DATE_KEYWORDS = {
    "jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec",
    "january", "february", "march", "april", "june", "july", "august",
    "september", "october", "november", "december",
    *map(str, range(2001, 2026)),
}

PHISHY_KEYWORDS = {
    "verify", "urgent", "login", "click", "bank", "account", "update", "password",
    "security", "alert", "confirm", "immediately",
}

ATTACHMENT_KEYWORDS = {".xls", ".xlsx", ".pdf", ".doc", ".docx", "attachment", "attached", "file"}

OPERATIONAL_KEYWORDS = {"nom", "actual", "vols", "schedule", "attached", "report", "data", "summary"}


def explain_features(email_text):
    """Highlight known keywords in *email_text* and list which ones occur.

    Returns a ``(highlighted_html, features_detected)`` tuple.
    ``highlighted_html`` is the email text, HTML-escaped, with every
    whole-word keyword match wrapped in a coloured ``<mark>`` element.
    ``features_detected`` maps a category label to the sorted list of
    keywords found: substring matches for the phishy / attachment /
    operational sets, whole-token matches for dates.
    """
    # Local import so this block does not depend on the file's import header.
    import html

    lower_text = email_text.lower()
    words = lower_text.split()

    def highlight_keywords(text, keywords, colour):
        # sorted() keeps the substitution order, and hence the output, stable.
        for kw in sorted(keywords):
            pattern = re.compile(rf"\b({re.escape(kw)})\b", re.IGNORECASE)
            text = pattern.sub(rf"<mark style='background-color:{colour}; font-weight:bold'>\1</mark>", text)
        return text

    # The email body is untrusted and ends up in an HTML component, so escape
    # it before injecting our own markup. quote=False: the text is only used
    # as element content, never inside an attribute value.
    highlighted_text = html.escape(email_text, quote=False)
    highlighted_text = highlight_keywords(highlighted_text, PHISHY_KEYWORDS, "#ffcccc")  # red
    highlighted_text = highlight_keywords(highlighted_text, ATTACHMENT_KEYWORDS, "#cce5ff")  # blue
    highlighted_text = highlight_keywords(highlighted_text, DATE_KEYWORDS, "#d4edda")  # green
    highlighted_text = highlight_keywords(highlighted_text, OPERATIONAL_KEYWORDS, "#fff3cd")  # yellow

    # Sets iterate in hash order, which changes between interpreter runs;
    # sort so callers that truncate the lists always show the same items.
    features_detected = {
        "Phishy keywords": sorted(kw for kw in PHISHY_KEYWORDS if kw in lower_text),
        "Attachment refs": sorted(kw for kw in ATTACHMENT_KEYWORDS if kw in lower_text),
        "Operational terms": sorted(kw for kw in OPERATIONAL_KEYWORDS if kw in lower_text),
        "Date mentions": sorted(kw for kw in DATE_KEYWORDS if kw in words),
    }

    return highlighted_text, features_detected
82
-
83
-
84
def detect_and_explain(email_text):
    """Classify *email_text* and return an HTML report explaining the verdict.

    The report contains the decision with its softmax confidence, a bullet
    list of detected keyword categories (at most five keywords each), and
    the highlighted email text produced by ``explain_features``. Blank or
    missing input yields an "Invalid input" message without touching the
    tokenizer or the model.
    """
    lower_text = (email_text or "").lower()
    tokens = lower_text.split()
    token_count = len(tokens)

    # Guard first: nothing to tokenise or classify for blank input.
    if token_count == 0:
        return "<strong style='color:orange'>Invalid input</strong>: Empty email text"

    inputs = tokenizer(email_text, return_tensors="pt", truncation=True, padding="max_length", max_length=256)

    def as_feature(value):
        # Each hand-crafted feature is passed to the model as a one-element float tensor.
        return torch.tensor([value], dtype=torch.float32)

    features = {
        'text_length': as_feature(len(email_text)),
        'token_count': as_feature(token_count),
        'avg_token_length': as_feature(sum(len(tok) for tok in tokens) / token_count),
        'num_date_tokens': as_feature(sum(1 for tok in tokens if tok in DATE_KEYWORDS)),
        'has_attachment_reference': as_feature(float(any(k in lower_text for k in ATTACHMENT_KEYWORDS))),
        'has_operational_keywords': as_feature(float(any(k in lower_text for k in OPERATIONAL_KEYWORDS))),
        'has_phishy_keywords': as_feature(float(any(k in lower_text for k in PHISHY_KEYWORDS))),
    }

    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            **features
        )

    # Accept a bare logits tensor, an object exposing .logits, or a tuple.
    if isinstance(outputs, torch.Tensor):
        logits = outputs
    elif hasattr(outputs, 'logits'):
        logits = outputs.logits
    else:
        logits = outputs[0]

    probs = torch.nn.functional.softmax(logits, dim=1)
    confidence, pred = torch.max(probs, dim=1)
    confidence = confidence.item()
    pred = pred.item()

    highlight_html, features_dict = explain_features(email_text)

    # Class index 1 is reported as phishing.
    if pred == 1:
        decision = f"<strong style='color:red'>Phishing detected!</strong> (confidence: {confidence:.2%})"
    else:
        decision = f"<strong style='color:green'>Legitimate email</strong> (confidence: {confidence:.2%})"

    indicator_rows = "".join(
        f"<li><strong>{category}:</strong> {', '.join(items[:5])}</li>"  # limit to 5 items per category
        for category, items in features_dict.items()
        if items
    )
    feature_html = f"<br><u><strong>Detected Indicators:</strong></u><ul>{indicator_rows}</ul>"

    return f"{decision}<br><br>{feature_html}<u><strong>Email Highlight View:</strong></u><br>{highlight_html}"
136
 
 
137
  examples = [
138
  ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
139
  ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
@@ -142,12 +26,13 @@ examples = [
142
  ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
143
  ]
144
 
 
# Gradio UI: one text box in, an HTML report from detect_and_explain out.
app = gr.Interface(
    fn=detect_and_explain,
    inputs=gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=5),
    outputs=gr.HTML(label="Detection & Explanation"),
    title="COS 720: Email Phishing Detector",
    description="A lightweight AI-powered phishing email detector that analyzes text and metadata to classify emails as phishing or legitimate with explainable insights to highlight suspicious content.",
    examples=examples,
    theme="soft"
)
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+
4
# detect_phishing() calls `model(text)` and reads result[0]['label'] /
# result[0]['score'], which is the text-classification pipeline output format,
# so build a pipeline (the only transformers name this file imports).
model = pipeline("text-classification", model="zionia/email-phishing-detector")
# Keep the module-level `tokenizer` name available; reuse the pipeline's own.
tokenizer = model.tokenizer
6
+
7
def detect_phishing(email_text):
    """Classify *email_text* and return a one-line, human-readable verdict.

    The module-level ``model`` callable is expected to return a list whose
    first item is a dict with ``'label'`` and ``'score'`` keys. The label
    ``"LABEL_1"`` is reported as phishing; anything else as legitimate.
    Blank or missing input returns an "Invalid input" message without
    invoking the model.
    """
    if not email_text or not email_text.strip():
        return "Invalid input: Empty email text"

    # truncation=True keeps very long emails within the model's input limit
    # instead of failing inside the forward pass.
    result = model(email_text, truncation=True)
    top = result[0]
    label = top['label']
    score = top['score']

    if label == "LABEL_1":
        return f"Phishing detected! (confidence: {score:.2%})"
    else:
        return f"Legitimate email (confidence: {score:.2%})"
 
 
 
 
 
 
 
 
 
19
 
20
+ # Example emails
21
  examples = [
22
  ["Dear customer, your account has been compromised. Click here to verify your identity: http://bit.ly/2XyZABC"],
23
  ["Hi team, please review the attached document for our quarterly meeting tomorrow."],
 
26
  ["You've won a $1000 Amazon gift card! Click to claim your prize within 24 hours!"]
27
  ]
28
 
29
# Create the Gradio interface: one text box in, the verdict string out.
app = gr.Interface(
    fn=detect_phishing,
    inputs=gr.Textbox(label="Email Text", placeholder="Paste the email content here...", lines=5),
    outputs=gr.Textbox(label="Detection Result"),
    title="Email Phishing Detector",
    description="Detect whether an email is phishing or legitimate using AI.",
    examples=examples,
    theme="soft"
)