Spaces:
Runtime error
Runtime error
# app.py | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
from peft import PeftModel | |
import json | |
import os | |
# --- 1. Configuration ---
# Hub repo holding the fine-tuned PEFT adapter (the tokenizer is loaded from it too).
adapter_model_name = "Ogero79/threatscope-cyberthreat-analyst"
# Base checkpoint the adapter is applied to; the HF_TOKEN secret is passed when loading it.
base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# --- 2. Model Loading ---
print("--- Loading Model and Tokenizer ---")

# Load the tokenizer from the adapter repo
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)

# Load the base Llama 3 model in float16 to reduce memory use.
# device_map="auto" delegates placement to accelerate; on a CPU-only Space
# that resolves to the CPU.
# NOTE(review): float16 inference on CPU can be slow or unsupported for some
# ops depending on the PyTorch build - bfloat16 may be a better fit; confirm.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),  # Use the token from Space secrets
)

# Load the PEFT adapter and merge it into the base model for faster inference.
model = PeftModel.from_pretrained(base_model, adapter_model_name)
model = model.merge_and_unload()
model.eval()

# Create the text-generation pipeline.
# No `device=` argument on purpose: the model was dispatched by accelerate
# (device_map="auto" above), and recent transformers releases raise a
# ValueError when such a model is combined with an explicit pipeline device.
# Without it the pipeline follows the model's own placement.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
)
print("β Model and pipeline loaded successfully!")
# --- 3. Inference Function (copied and adapted from your notebook) ---
# Words whose presence in an unparseable model reply triggers the
# "potential threat, analyst review recommended" fallback.
_THREAT_KEYWORDS = (
    "malware", "attack", "phishing", "breach", "exploit", "hack", "ransomware",
)

# Expected Python type for every field of the analysis JSON.
_REQUIRED_FIELDS = {
    "summary": str,
    "threat_type": str,
    "risk_score": int,
    "risk_level": str,
    "suggested_defense": str,
    "iocs": list,
    "threat_actor": str,
    "geographical_scope": str,
}


def _safe_default_response():
    """Return a fresh copy of the 'no threat' result (safe for callers to mutate)."""
    return {
        "summary": "No actionable cybersecurity threat detected",
        "threat_type": "Non-Threat",
        "risk_score": 0,
        "risk_level": "None",
        "suggested_defense": "No action required",
        "iocs": [],
        "threat_actor": "None",
        "geographical_scope": "None",
    }


def _build_messages(prompt_text):
    """Build the system/user chat messages that ask the model for a JSON analysis."""
    return [
        {
            "role": "system",
            "content": (
                "You are an expert cybersecurity analyst. Analyze input and return JSON with these fields:\n"
                "- summary: If input describes a threat, summarize it. Otherwise, state no threat detected\n"
                "- threat_type: Threat category if valid, otherwise 'Non-Threat'\n"
                "- risk_score: 0-100 (0 for non-threats)\n"
                "- risk_level: Critical/High/Medium/Low/None\n"
                "- suggested_defense: Recommendations or 'No action required'\n"
                "- iocs: Empty list for non-threats\n"
                "- threat_actor: 'None' for non-threats\n"
                "- geographical_scope: 'None' for non-threats\n"
                "For CLEAR non-threats (e.g., 'Hello', weather queries), return the safe default format immediately."
            ),
        },
        {
            "role": "user",
            "content": f"Analyze this input for cybersecurity threats: {prompt_text}\n"
                       f"Return ONLY the JSON output with all fields populated.",
        },
    ]


def _parse_analysis(response):
    """Convert the raw model reply into a complete result dict."""
    defaults = _safe_default_response()
    try:
        first_brace = response.find("{")
        last_brace = response.rfind("}")
        if first_brace == -1 or last_brace == -1:
            raise ValueError("No JSON detected in response")
        parsed_json = json.loads(response[first_brace:last_brace + 1])
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        # No usable JSON: fall back to a keyword scan of the raw reply.
        lowered = response.lower()
        if any(keyword in lowered for keyword in _THREAT_KEYWORDS):
            return {
                **defaults,
                "summary": f"Potential threat detected but invalid format. Analyst review recommended. Raw response: {response[:200]}...",
                "threat_type": "Unknown (Format Error)",
                "risk_score": 50,
                "risk_level": "Medium",
            }
        return defaults

    # Replace every missing or wrongly-typed field with its safe default.
    for field, field_type in _REQUIRED_FIELDS.items():
        value = parsed_json.get(field)
        # bool is a subclass of int, so reject it explicitly; otherwise a JSON
        # `true` would be accepted as a risk_score.
        if isinstance(value, bool) or not isinstance(value, field_type):
            parsed_json[field] = defaults[field]
    return parsed_json


def generate_response(prompt_text, max_new_tokens=512, temperature=0.01):
    """Analyze a free-text event description and return a structured assessment.

    Uses the module-level ``tokenizer`` and ``generator`` pipeline.

    Args:
        prompt_text: Description of the potential security event.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to the pipeline.

    Returns:
        A dict with the keys ``summary``, ``threat_type``, ``risk_score``,
        ``risk_level``, ``suggested_defense``, ``iocs``, ``threat_actor`` and
        ``geographical_scope``. Blank input yields the safe default; any
        runtime failure yields the safe default with the error text placed in
        ``summary``. The function does not raise.
    """
    # Nothing to analyze: skip the (slow) model call entirely.
    if prompt_text is None or not str(prompt_text).strip():
        return _safe_default_response()

    try:
        prompt = tokenizer.apply_chat_template(
            _build_messages(prompt_text),
            tokenize=False,
            add_generation_prompt=True,
        )
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_full_text = outputs[0]["generated_text"]
        # The pipeline output starts with the prompt; keep only the completion.
        # (The original sliced an undefined name here, so every request ended
        # in the error branch below.)
        response = generated_full_text[len(prompt):].strip()
        return _parse_analysis(response)
    except Exception as e:
        # Critical error case - return safe format with error details
        error_response = _safe_default_response()
        error_response["summary"] = f"System error: {str(e)}. Default safe response returned"
        return error_response
# --- 4. Gradio Interface ---
# Centers the main column and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 900px;
}
"""

# Intro text shown at the top of the page. Kept at column 0 so no Markdown
# renderer can mistake it for an indented code block.
intro_markdown = """
# π€ ThreatScope: AI Cybersecurity Analyst
Enter a description of a potential security event below. The fine-tuned Llama 3 model will analyze it and return a structured JSON response with a risk assessment and suggested actions.
**Note:** This is an 8B parameter model running on a CPU. The first inference may be slow, but subsequent ones will be faster.
"""

# NOTE(review): the widget nesting below is reconstructed; confirm that the
# button and JSON output are meant to sit under the row rather than inside it.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(intro_markdown)
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Threat Description",
                placeholder="e.g., Our DNS server is being flooded with requests from thousands of botnet IPs.",
                lines=4,
            )
        analyze_button = gr.Button("Analyze Threat")
        output_json = gr.JSON(label="Analysis Result")
        gr.Examples(
            [
                "A misconfigured cloud storage bucket exposed sensitive customer data online for months.",
                "Urgent: Employee received a suspicious email with a malicious attachment claiming to be from HR.",
                "An ex-employee's credentials were used to log into the main database at 2 AM.",
                "What's the capital of France?",
            ],
            inputs=prompt_input,
            outputs=output_json,
            fn=generate_response,
            # On Hugging Face Spaces example caching defaults to on when `fn`
            # and `outputs` are given, which would run every example through
            # the CPU-bound model at startup. Disable it so examples only fill
            # the input box.
            cache_examples=False,
        )
    # Run the analysis when the button is pressed.
    analyze_button.click(
        fn=generate_response,
        inputs=prompt_input,
        outputs=output_json,
    )

demo.launch()