# app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import json
import os
# --- 1. Configuration ---
adapter_model_name = "Ogero79/threatscope-cyberthreat-analyst"
base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
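# Note: Meta-Llama-3-8B-Instruct is a gated checkpoint on the Hub, so this
# Space needs an HF_TOKEN secret from an account that has accepted Meta's
# license (the token is read from the environment in the loading step below).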
# --- 2. Model Loading ---
print("--- Loading Model and Tokenizer ---")
# Load the tokenizer from the adapter repo
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
# Load the base Llama 3 model.
# float16 halves the memory footprint, which matters on a CPU Space (if CPU
# half-precision kernels raise errors, fall back to torch.float32).
# device_map="auto" lets accelerate place the model on the best available
# device, which on this Space is the CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),  # Use the token from Space secrets
)
# Load the PEFT adapter and merge it into the base model for faster inference.
model = PeftModel.from_pretrained(base_model, adapter_model_name)
model = model.merge_and_unload()
model.eval()
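# After merge_and_unload() the LoRA weights are folded into the base weights,
# so inference runs a plain Llama 3 forward pass with no adapter overhead.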
# Create the text-generation pipeline. Device placement and dtype were already
# fixed when the model was loaded (device_map="auto"), so they are not passed
# again here: pipeline() rejects an explicit device for a model dispatched by
# accelerate.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
print("βœ… Model and pipeline loaded successfully!")
# --- 3. Inference Function (adapted from the training notebook) ---
def generate_response(prompt_text, max_new_tokens=512, temperature=0.01):
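    """Analyze prompt_text with the fine-tuned model and return a dict.

    Always returns a dict with the eight fields defined below, falling back to
    safe_default_response whenever the model output cannot be parsed or
    validated, so the Gradio JSON component always receives valid data.
    """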
    # Define the safe/default JSON structure for non-threats
    safe_default_response = {
        "summary": "No actionable cybersecurity threat detected",
        "threat_type": "Non-Threat",
        "risk_score": 0,
        "risk_level": "None",
        "suggested_defense": "No action required",
        "iocs": [],
        "threat_actor": "None",
        "geographical_scope": "None"
    }
    messages = [
        {
            "role": "system",
            "content": (
                "You are an expert cybersecurity analyst. Analyze input and return JSON with these fields:\n"
                "- summary: If input describes a threat, summarize it. Otherwise, state no threat detected\n"
                "- threat_type: Threat category if valid, otherwise 'Non-Threat'\n"
                "- risk_score: 0-100 (0 for non-threats)\n"
                "- risk_level: Critical/High/Medium/Low/None\n"
                "- suggested_defense: Recommendations or 'No action required'\n"
                "- iocs: Empty list for non-threats\n"
                "- threat_actor: 'None' for non-threats\n"
                "- geographical_scope: 'None' for non-threats\n"
                "For CLEAR non-threats (e.g., 'Hello', weather queries), return the safe default format immediately."
            )
        },
        {
            "role": "user",
            "content": f"Analyze this input for cybersecurity threats: {prompt_text}\n"
                       "Return ONLY the JSON output with all fields populated."
        }
    ]
    try:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = generator(prompt, max_new_tokens=max_new_tokens, temperature=temperature,
                            top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id)
        # The pipeline returns the prompt plus the completion; keep only the completion.
        generated_full_text = outputs[0]["generated_text"]
        response = generated_full_text[len(prompt):].strip()
        # First, try to find and parse a JSON object in the completion
        try:
            first_brace = response.find('{')
            last_brace = response.rfind('}')
            if first_brace == -1 or last_brace == -1:
                raise ValueError("No JSON detected in response")
            parsed_json = json.loads(response[first_brace:last_brace + 1])
            # Validate that the required fields exist and have the expected types
            required_fields = {
                'summary': str, 'threat_type': str, 'risk_score': int,
                'risk_level': str, 'suggested_defense': str, 'iocs': list,
                'threat_actor': str, 'geographical_scope': str
            }
            for field, field_type in required_fields.items():
                if field not in parsed_json or not isinstance(parsed_json.get(field), field_type):
                    parsed_json[field] = safe_default_response[field]
            return parsed_json
        except (json.JSONDecodeError, ValueError):
            # JSON parsing failed; scan the raw response for threat indicators
            threat_keywords = ["malware", "attack", "phishing", "breach", "exploit", "hack", "ransomware"]
            if any(keyword in response.lower() for keyword in threat_keywords):
                # Threat keywords found but the format is invalid: flag for analyst review
                return {
                    **safe_default_response,
                    "summary": f"Potential threat detected but invalid format. Analyst review recommended. Raw response: {response[:200]}...",
                    "threat_type": "Unknown (Format Error)",
                    "risk_score": 50,
                    "risk_level": "Medium"
                }
            else:
                # No threat keywords detected, so the safe default is appropriate
                return safe_default_response
    except Exception as e:
        # Critical error case - return safe format with error details
        safe_default_response["summary"] = f"System error: {str(e)}. Default safe response returned"
        return safe_default_response
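# Illustrative usage (not executed at startup) -- the exact content depends on
# the model, but the returned dict always carries the same eight keys:
#
#   result = generate_response("Ransomware note found on the file server")
#   result["risk_level"]  # one of "Critical", "High", "Medium", "Low", "None"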
# --- 4. Gradio Interface ---
css = """
#col-container {
margin: 0 auto;
max-width: 900px;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            # πŸ€– ThreatScope: AI Cybersecurity Analyst
            Enter a description of a potential security event below. The fine-tuned Llama 3 model will analyze it and return a structured JSON response with a risk assessment and suggested actions.

            **Note:** This is an 8B parameter model running on a CPU. The first inference may be slow, but subsequent ones will be faster.
            """
        )
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Threat Description",
                placeholder="e.g., Our DNS server is being flooded with requests from thousands of botnet IPs.",
                lines=4
            )
        analyze_button = gr.Button("Analyze Threat")
        output_json = gr.JSON(label="Analysis Result")
        gr.Examples(
            [
                "A misconfigured cloud storage bucket exposed sensitive customer data online for months.",
                "Urgent: Employee received a suspicious email with a malicious attachment claiming to be from HR.",
                "An ex-employee's credentials were used to log into the main database at 2 AM.",
                "What's the capital of France?",
            ],
            inputs=prompt_input,
            outputs=output_json,
            fn=generate_response,
        )
        analyze_button.click(
            fn=generate_response,
            inputs=prompt_input,
            outputs=output_json
        )
demo.launch()