"""Gradio app that predicts CWE ancestor categories from commit messages.

A fine-tuned RoBERTa sequence classifier scores the input text; each
predicted CWE above a threshold is rolled up to its ancestor using
``deep_child_to_ancestor.json`` and the strongest ancestors are displayed.
"""

import json
import random

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Seed every RNG in play and request deterministic kernels so repeated runs
# on the same input produce identical scores.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
torch.use_deterministic_algorithms(True)

model_path = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# top_k=None makes the pipeline return a score for every label. The older
# ``return_all_scores=True`` flag is deprecated and is ignored once top_k is
# given, so it is not passed.
classifier = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)
model.eval()

# Read the label map from the loaded model's config instead of opening
# f"{model_path}/config.json": model_path is a Hub repo id, so that local file
# only exists if the repository happens to be checked out beside this script.
config = model.config.to_dict()
id_to_cwe = {int(k): v for k, v in model.config.id2label.items()}
valid_cwes = set(id_to_cwe.values())

with open("deep_child_to_ancestor.json", "r", encoding="utf-8") as f:
    child_to_ancestor = json.load(f)


def _label_to_cwe(label, id_to_cwe, valid):
    """Resolve one pipeline label to a CWE identifier, or None if unknown.

    The text-classification pipeline names each score with
    ``model.config.id2label[i]``, so the label is normally already one of the
    CWE identifiers in *valid*. Generic ``LABEL_<i>`` names (used when a config
    has no custom labels) are mapped back through *id_to_cwe* by index.
    """
    if label in valid:
        return label
    try:
        index = int(str(label).rsplit("_", 1)[-1])  # "LABEL_123" -> 123
    except ValueError:
        return None
    return id_to_cwe.get(index)


def map_prediction_to_valid_cwes(predictions, id_to_cwe, child_to_ancestor, threshold=0.2, top_k=5):
    """
    Map model predictions to CWE ancestors and return top_k valid results.

    Args:
        predictions: Pipeline output for one input: either a list of
            ``{"label", "score"}`` dicts, or a list containing such a list
            (the shape depends on the transformers version).
        id_to_cwe: Mapping of class index -> CWE identifier (from id2label).
        child_to_ancestor: Mapping of CWE identifier -> ancestor identifier;
            CWEs missing from it are treated as their own ancestor.
        threshold: Minimum score for a label to be considered.
        top_k: Maximum number of ancestors to return (None returns all).

    Returns:
        Dict of ``"CWE-<ancestor>"`` -> best score (rounded to 4 places),
        ordered from highest to lowest score. Only ancestors that are
        themselves model labels are kept.
    """
    valid = set(id_to_cwe.values())

    # Normalise a flat list of score dicts to the nested form.
    if predictions and isinstance(predictions[0], dict):
        predictions = [predictions]

    best = {}
    for item in predictions:
        for entry in item:
            score = entry["score"]
            if score < threshold:
                continue
            cwe = _label_to_cwe(entry["label"], id_to_cwe, valid)
            if cwe is None:
                continue
            ancestor = child_to_ancestor.get(cwe, cwe)
            if ancestor not in valid:
                continue
            key = f"CWE-{ancestor}"
            # Several children can share an ancestor; keep the strongest one.
            best[key] = max(best.get(key, 0), round(score, 4))

    ranked = sorted(best.items(), key=lambda kv: kv[1], reverse=True)
    return dict(ranked[:top_k])


def predict_cwe(commit_message: str):
    """Classify *commit_message*; return ``{"CWE-<id>": confidence}`` for gr.Label."""
    if not commit_message or not commit_message.strip():
        return {}
    raw_preds = classifier(commit_message)
    return map_prediction_to_valid_cwes(raw_preds, id_to_cwe, child_to_ancestor)


demo = gr.Interface(
    fn=predict_cwe,
    inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
    outputs=gr.Label(num_top_classes=5),
    title="CWE Prediction from Commit Message and vulnerability description",
    description="This tool predicts CWE ancestor categories from Git commit messages and vulnerability descriptions, based on a fine-tuned transformer model.",
    examples=[
        ["Fixed buffer overflow in input parsing"],
        ["SQL injection possible in login flow"],
        ["Improved input validation to prevent XSS"],
    ],
)

if __name__ == "__main__":
    demo.launch()