File size: 1,910 Bytes
f8e978b 9ba88a4 f8e978b 2a5a21b f8e978b 9ba88a4 1d606fb 2a5a21b f8e978b 9ba88a4 6519258 9ba88a4 1d606fb 9ba88a4 1d606fb f8e978b 9ba88a4 1d606fb f8e978b 1d606fb 9ba88a4 1d606fb 9ba88a4 1d606fb f8e978b 1d606fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
import json
from transformers import pipeline
import torch
import random
import numpy as np
# Seed the PyTorch, Python and NumPy RNGs with the same value and require
# deterministic algorithms in PyTorch, for reproducible runs.
SEED = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn
torch.use_deterministic_algorithms(True)
# Load Hugging Face model (text classification)
classifier = pipeline(
    task="text-classification",
    model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
    # Return a score for every label rather than only the top one.
    # NOTE(review): newer transformers releases deprecate return_all_scores in
    # favour of top_k=None, which also changes the nesting of the returned
    # list -- check any code that indexes the output before migrating.
    return_all_scores=True
)
classifier.model.eval()  # put the model in inference (eval) mode
# Minimum (inclusive) score for a prediction to count. Filtering by this value
# belongs where per-request classifier output exists; there is nothing to
# filter at import time.
threshold = 0.2
# Load child-to-parent CWE mapping (read relative to the current working
# directory). JSON is UTF-8 by specification, so decode it explicitly instead
# of relying on the platform's default locale encoding.
with open("child_to_parent_mapping.json", "r", encoding="utf-8") as f:
    child_to_parent = json.load(f)
def predict_cwe(commit_message: str) -> dict:
    """
    Predict CWE(s) from a commit message and map them to parent CWEs.

    The classifier scores every label. Labels scoring below the module-level
    ``threshold`` are dropped, and the five best remaining labels are kept.
    Each kept label is mapped to its parent CWE through ``child_to_parent``.
    When no parent is listed, the label falls back to the child CWE itself.

    Args:
        commit_message: Free-text commit message to classify.

    Returns:
        A dict of ``"CWE-<id>"`` -> score (rounded to 4 decimals), as
        consumed by ``gr.Label``. When several children share a parent, the
        parent keeps the highest child score. The dict is empty if no label
        reaches ``threshold``.
    """
    # The pipeline returns one entry per input; [0] is the list of per-label
    # {"label", "score"} dicts for this single message.
    results = classifier(commit_message)[0]
    sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
    filtered_results = [item for item in sorted_results if item["score"] >= threshold]

    # Map predictions to parent CWE (if available)
    mapped_results = {}
    for item in filtered_results[:5]:
        child_cwe = item["label"].replace("CWE-", "")
        parent_cwe = child_to_parent.get(child_cwe, child_cwe)  # default to child if no parent
        key = f"CWE-{parent_cwe}"
        # Items arrive in descending score order, so the first child seen for
        # a parent carries its best score; don't let lower siblings overwrite it.
        if key not in mapped_results:
            mapped_results[key] = round(float(item["score"]), 4)
    return mapped_results
# Gradio UI: one commit-message textbox in, up to five CWE scores out.
_EXAMPLE_MESSAGES = (
    "Fixed buffer overflow in input parsing",
    "SQL injection possible in login flow",
    "Improved input validation to prevent XSS",
)

demo = gr.Interface(
    fn=predict_cwe,
    inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
    outputs=gr.Label(num_top_classes=5),
    title="CWE Prediction from Commit Message",
    description=(
        "This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
        "Predicted child CWEs are mapped to their parent CWEs if applicable."
    ),
    # One inner list per example, holding one value per input component.
    examples=[[message] for message in _EXAMPLE_MESSAGES],
)


if __name__ == "__main__":
    demo.launch()
|