# app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import json
import os

# --- 1. Configuration ---
adapter_model_name = "Ogero79/threatscope-cyberthreat-analyst"
base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# --- 2. Model Loading ---
print("--- Loading Model and Tokenizer ---")

# Load the tokenizer from the adapter repo.
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)

# Load the base Llama 3 model.
# We use float16 to save memory on the CPU Space.
# device_map="auto" will intelligently place the model on the CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),  # Use the token from Space secrets
)

# Load the PEFT adapter and merge it into the base model for faster inference.
model = PeftModel.from_pretrained(base_model, adapter_model_name)
model = model.merge_and_unload()
model.eval()

# Create the text-generation pipeline. Because the model was loaded with
# device_map="auto" (via accelerate), we must NOT also pass a `device`
# argument here: transformers raises an error when asked to move an
# accelerate-placed model. The model already carries its fp16 dtype.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
print("✅ Model and pipeline loaded successfully!")


# --- 3. Inference Function (adapted from the training notebook) ---
def generate_response(prompt_text, max_new_tokens=512, temperature=0.01):
    # Safe/default JSON structure returned for non-threats and error cases.
    safe_default_response = {
        "summary": "No actionable cybersecurity threat detected",
        "threat_type": "Non-Threat",
        "risk_score": 0,
        "risk_level": "None",
        "suggested_defense": "No action required",
        "iocs": [],
        "threat_actor": "None",
        "geographical_scope": "None",
    }

    messages = [
        {
            "role": "system",
            "content": (
                "You are an expert cybersecurity analyst. Analyze input and return JSON with these fields:\n"
                "- summary: If input describes a threat, summarize it. Otherwise, state no threat detected\n"
                "- threat_type: Threat category if valid, otherwise 'Non-Threat'\n"
                "- risk_score: 0-100 (0 for non-threats)\n"
                "- risk_level: Critical/High/Medium/Low/None\n"
                "- suggested_defense: Recommendations or 'No action required'\n"
                "- iocs: Empty list for non-threats\n"
                "- threat_actor: 'None' for non-threats\n"
                "- geographical_scope: 'None' for non-threats\n"
                "For CLEAR non-threats (e.g., 'Hello', weather queries), return the safe default format immediately."
            ),
        },
        {
            "role": "user",
            "content": f"Analyze this input for cybersecurity threats: {prompt_text}\n"
                       f"Return ONLY the JSON output with all fields populated.",
        },
    ]

    try:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_full_text = outputs[0]["generated_text"]
        # Strip the prompt so only the model's continuation remains.
        response = generated_full_text[len(prompt):].strip()

        # First, try to find and parse JSON in the response.
        try:
            first_brace = response.find("{")
            last_brace = response.rfind("}")
            if first_brace == -1 or last_brace == -1:
                raise ValueError("No JSON detected in response")
            parsed_json = json.loads(response[first_brace:last_brace + 1])

            # Validate that every required field exists with the right type;
            # fall back to the safe default for any field that does not.
            required_fields = {
                "summary": str,
                "threat_type": str,
                "risk_score": int,
                "risk_level": str,
                "suggested_defense": str,
                "iocs": list,
                "threat_actor": str,
                "geographical_scope": str,
            }
            for field, field_type in required_fields.items():
                if field not in parsed_json or not isinstance(parsed_json.get(field), field_type):
                    parsed_json[field] = safe_default_response[field]
            return parsed_json

        except (json.JSONDecodeError, ValueError):
            # JSON parsing failed: scan the raw response for threat indicators.
            threat_keywords = ["malware", "attack", "phishing", "breach", "exploit", "hack", "ransomware"]
            if any(keyword in response.lower() for keyword in threat_keywords):
                # Threat keywords found but the JSON was invalid: surface the
                # raw analysis and flag it for analyst review.
                return {
                    **safe_default_response,
                    "summary": (
                        "Potential threat detected but invalid format. Analyst review recommended. "
                        f"Raw response: {response[:200]}..."
                    ),
                    "threat_type": "Unknown (Format Error)",
                    "risk_score": 50,
                    "risk_level": "Medium",
                }
            # No threat keywords detected, so the safe default is appropriate.
            return safe_default_response

    except Exception as e:
        # Critical error case: return the safe format with the error details.
        safe_default_response["summary"] = f"System error: {str(e)}. Default safe response returned"
        return safe_default_response


# --- 4. Gradio Interface ---
css = """
#col-container {
    margin: 0 auto;
    max-width: 900px;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            # 🤖 ThreatScope: AI Cybersecurity Analyst
            Enter a description of a potential security event below. The fine-tuned Llama 3
            model will analyze it and return a structured JSON response with a risk
            assessment and suggested actions.

            **Note:** This is an 8B-parameter model running on a CPU. The first inference
            may be slow, but subsequent ones will be faster.
            """
        )
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Threat Description",
                placeholder="e.g., Our DNS server is being flooded with requests from thousands of botnet IPs.",
                lines=4,
            )
        analyze_button = gr.Button("Analyze Threat")
        output_json = gr.JSON(label="Analysis Result")

        gr.Examples(
            [
                "A misconfigured cloud storage bucket exposed sensitive customer data online for months.",
                "Urgent: Employee received a suspicious email with a malicious attachment claiming to be from HR.",
                "An ex-employee's credentials were used to log into the main database at 2 AM.",
                "What's the capital of France?",
            ],
            inputs=prompt_input,
            outputs=output_json,
            fn=generate_response,
        )

    analyze_button.click(
        fn=generate_response,
        inputs=prompt_input,
        outputs=output_json,
    )

demo.launch()
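
# --- 5. Optional local smoke test (not part of the original app) ---
# A minimal sketch for exercising generate_response() outside the Gradio UI.
# Assumptions: you run `python app.py` locally with HF_TOKEN set, and you
# temporarily comment out demo.launch() above, since launch() blocks and the
# lines below would otherwise never run. The sample input is illustrative only.
#
# sample = "Multiple failed SSH logins from one IP, followed by a successful root login."
# print(json.dumps(generate_response(sample), indent=2))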