Ogero79 committed on
Commit 93c3afd · verified · 1 Parent(s): 1fa1383

Update app.py

Files changed (1)
  1. app.py +173 -26
app.py CHANGED
@@ -1,34 +1,181 @@
 
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer

- # Load the tokenizer
- model_name = "Ogero79/threatscope-cyberthreat-analyst"  # Ensure this matches your public model name
- tokenizer = AutoTokenizer.from_pretrained(model_name)

- # Load model in FP32 on CPU (no quantization)
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch.float32,  # Use FP32
-     low_cpu_mem_usage=True  # Use this to reduce memory usage
  )

- def generate_response(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
-
-     # Reduce max_length to fit CPU memory
-     with torch.no_grad():
-         outputs = model.generate(**inputs, max_length=100)
-
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Create Gradio Interface
- demo = gr.Interface(
-     fn=generate_response,
-     inputs="text",
-     outputs="text",
-     examples=["Phishing email detected", "Potential DDoS attack"]
  )

- if __name__ == "__main__":
-     demo.launch()
+ # app.py
  import gradio as gr
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from peft import PeftModel
+ import json
+ import os

+ # --- 1. Configuration ---
+ adapter_model_name = "Ogero79/threatscope-cyberthreat-analyst"
+ base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

+ # --- 2. Model Loading ---
+ print("--- Loading Model and Tokenizer ---")
+ # Load the tokenizer from the adapter repo
+ tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
+
+ # Load the base Llama 3 model.
+ # We use float16 to save memory on the CPU Space.
+ # device_map="auto" will intelligently place the model on the CPU.
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_name,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     token=os.environ.get("HF_TOKEN"),  # Use the token from Space secrets
  )

+ # Load the PEFT adapter and merge it into the base model for faster inference.
+ model = PeftModel.from_pretrained(base_model, adapter_model_name)
+ model = model.merge_and_unload()
+ model.eval()
+
+ # Create the text-generation pipeline. device=-1 ensures it runs on CPU.
+ generator = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     device=-1,  # Explicitly set to CPU
+     torch_dtype=torch.float16
  )

+ print("✅ Model and pipeline loaded successfully!")
+
+
+ # --- 3. Inference Function (copied and adapted from your notebook) ---
+ def generate_response(prompt_text, max_new_tokens=512, temperature=0.01):
+     # Define the safe/default JSON structure for non-threats
+     safe_default_response = {
+         "summary": "No actionable cybersecurity threat detected",
+         "threat_type": "Non-Threat",
+         "risk_score": 0,
+         "risk_level": "None",
+         "suggested_defense": "No action required",
+         "iocs": [],
+         "threat_actor": "None",
+         "geographical_scope": "None"
+     }
+
+     messages = [
+         {
+             "role": "system",
+             "content": (
+                 "You are an expert cybersecurity analyst. Analyze input and return JSON with these fields:\n"
+                 "- summary: If input describes a threat, summarize it. Otherwise, state no threat detected\n"
+                 "- threat_type: Threat category if valid, otherwise 'Non-Threat'\n"
+                 "- risk_score: 0-100 (0 for non-threats)\n"
+                 "- risk_level: Critical/High/Medium/Low/None\n"
+                 "- suggested_defense: Recommendations or 'No action required'\n"
+                 "- iocs: Empty list for non-threats\n"
+                 "- threat_actor: 'None' for non-threats\n"
+                 "- geographical_scope: 'None' for non-threats\n"
+                 "For CLEAR non-threats (e.g., 'Hello', weather queries), return the safe default format immediately."
+             )
+         },
+         {
+             "role": "user",
+             "content": f"Analyze this input for cybersecurity threats: {prompt_text}\n"
+                        f"Return ONLY the JSON output with all fields populated."
+         }
+     ]
+
+     try:
+         prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         outputs = generator(prompt, max_new_tokens=max_new_tokens, temperature=temperature,
+                             top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+         generated_full_text = outputs[0]["generated_text"]
+         response = generated_full_text[len(prompt):].strip()
+
+         # First try to find and parse JSON
+         try:
+             first_brace = response.find('{')
+             last_brace = response.rfind('}')
+             if first_brace == -1 or last_brace == -1:
+                 raise ValueError("No JSON detected in response")
+
+             parsed_json = json.loads(response[first_brace:last_brace+1])
+
+             # Validate required fields exist
+             required_fields = {
+                 'summary': str, 'threat_type': str, 'risk_score': int,
+                 'risk_level': str, 'suggested_defense': str, 'iocs': list,
+                 'threat_actor': str, 'geographical_scope': str
+             }
+
+             for field, field_type in required_fields.items():
+                 if field not in parsed_json or not isinstance(parsed_json.get(field), field_type):
+                     parsed_json[field] = safe_default_response[field]
+
+             return parsed_json
+
+         except (json.JSONDecodeError, ValueError):
+             # If JSON parsing fails, analyze the raw response for threat indicators
+             threat_keywords = ["malware", "attack", "phishing", "breach", "exploit", "hack", "ransomware"]
+             if any(keyword in response.lower() for keyword in threat_keywords):
+                 # If threat keywords found but JSON invalid, return error with the raw analysis
+                 return {
+                     **safe_default_response,
+                     "summary": f"Potential threat detected but invalid format. Analyst review recommended. Raw response: {response[:200]}...",
+                     "threat_type": "Unknown (Format Error)",
+                     "risk_score": 50,
+                     "risk_level": "Medium"
+                 }
+             else:
+                 # No threat keywords detected - definitely safe to return default
+                 return safe_default_response
+
+     except Exception as e:
+         # Critical error case - return safe format with error details
+         safe_default_response["summary"] = f"System error: {str(e)}. Default safe response returned"
+         return safe_default_response
+
+ # --- 4. Gradio Interface ---
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 900px;
+ }
+ """
+
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(
+             """
+             # 🤖 ThreatScope: AI Cybersecurity Analyst
+
+             Enter a description of a potential security event below. The fine-tuned Llama 3 model will analyze it and return a structured JSON response with a risk assessment and suggested actions.
+
+             **Note:** This is an 8B parameter model running on a CPU. The first inference may be slow, but subsequent ones will be faster.
+             """
+         )
+
+         with gr.Row():
+             prompt_input = gr.Textbox(
+                 label="Enter Threat Description",
+                 placeholder="e.g., Our DNS server is being flooded with requests from thousands of botnet IPs.",
+                 lines=4
+             )
+
+         analyze_button = gr.Button("Analyze Threat")
+
+         output_json = gr.JSON(label="Analysis Result")
+
+         gr.Examples(
+             [
+                 "A misconfigured cloud storage bucket exposed sensitive customer data online for months.",
+                 "Urgent: Employee received a suspicious email with a malicious attachment claiming to be from HR.",
+                 "An ex-employee's credentials were used to log into the main database at 2 AM.",
+                 "What's the capital of France?",
+             ],
+             inputs=prompt_input,
+             outputs=output_json,
+             fn=generate_response,
+         )
+
+     analyze_button.click(
+         fn=generate_response,
+         inputs=prompt_input,
+         outputs=output_json
+     )
+
+ demo.launch()
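
A quick way to sanity-check this commit before the Space rebuilds is to exercise generate_response directly, for example in a REPL with demo.launch() commented out. A minimal sketch, not part of the commit: the sample prompt is invented, and it assumes HF_TOKEN is set so the gated Llama 3 base model can be downloaded and that the host has enough RAM for the 8B float16 weights.

    # Hypothetical smoke test of the inference path (not in the commit).
    sample = "Multiple failed SSH logins from one IP, followed by a successful root login."
    print(json.dumps(generate_response(sample), indent=2))

The new imports also imply extra Space dependencies; a plausible requirements.txt inferred from them (accelerate is needed because of device_map="auto"):

    gradio
    torch
    transformers
    peft
    accelerate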