|
import json
from pathlib import Path

import gradio as gr
from transformers import pipeline
|
|
|
|
|
# Hugging Face text-classification pipeline for the CIRCL CWE model.
# return_all_scores=True requests a score for every label rather than only the
# top one, so predict_cwe can rank all candidate CWEs itself.
# NOTE(review): recent transformers releases deprecate return_all_scores in
# favour of top_k=None; confirm the output nesting predict_cwe expects is
# unchanged before migrating.
classifier = pipeline(
    "text-classification",
    model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
    return_all_scores=True,
)
|
|
|
|
|
# Lookup table from child CWE id (string key) to parent CWE id, consumed by
# predict_cwe. The file is resolved next to this script rather than the current
# working directory, so the app works regardless of where it is launched from.
# JSON text is UTF-8 by specification, so decode it explicitly instead of using
# the platform's default encoding.
_MAPPING_PATH = Path(__file__).resolve().parent / "child_to_parent_mapping.json"
with _MAPPING_PATH.open("r", encoding="utf-8") as mapping_file:
    child_to_parent = json.load(mapping_file)
|
|
|
def predict_cwe(commit_message: str) -> dict[str, float]:
    """
    Predict CWE(s) from a commit message and map to parent CWEs.

    The text is scored by the module-level ``classifier`` pipeline, labels are
    ranked by descending score, and each ``CWE-<id>`` label is mapped to its
    parent through ``child_to_parent`` (ids without an entry are kept as-is).

    Args:
        commit_message: Git commit message or vulnerability description.
            Blank input (empty, whitespace-only, or ``None``) is not sent to
            the model.

    Returns:
        Up to five ``{"CWE-<parent id>": score}`` entries, scores rounded to
        4 decimals. When several predicted children share a parent, the parent
        keeps the highest of their scores. Blank input yields ``{}``.
    """
    top_k = 5

    # Don't run the model on empty input: any label it produced would be noise.
    if not commit_message or not commit_message.strip():
        return {}

    # truncation=True keeps long descriptions within the model's maximum input
    # length instead of failing inside the model. Presumably that limit comes
    # from the tokenizer's model_max_length; verify for this checkpoint.
    output = classifier(commit_message, truncation=True)

    # With return_all_scores=True a single string comes back wrapped in an outer
    # list ([[{...}, ...]]); also accept a flat [{...}, ...] in case the
    # pipeline is configured differently.
    results = output[0] if output and isinstance(output[0], list) else output

    ranked = sorted(results, key=lambda x: x["score"], reverse=True)

    mapped_results: dict[str, float] = {}
    for item in ranked:
        child_cwe = item["label"].removeprefix("CWE-")
        parent_cwe = child_to_parent.get(child_cwe, child_cwe)
        key = f"CWE-{parent_cwe}"
        # Scores are descending, so the first child seen for a parent carries
        # the highest score; a later, weaker sibling must not overwrite it.
        if key in mapped_results:
            continue
        mapped_results[key] = round(float(item["score"]), 4)
        # Keep scanning past duplicates until top_k distinct parents are found.
        if len(mapped_results) == top_k:
            break

    return mapped_results
|
|
|
|
|
# Gradio front end: a three-line text box feeds predict_cwe, and its
# {"CWE-<id>": score} dict is rendered by a Label showing at most 5 classes.
_commit_input = gr.Textbox(
    lines=3,
    placeholder="Enter your commit message or vulnerability description here...",
)
_cwe_output = gr.Label(num_top_classes=5)

demo = gr.Interface(
    fn=predict_cwe,
    inputs=_commit_input,
    outputs=_cwe_output,
    title="CWE Prediction from Commit Message or Vulnerability Description",
    description=(
        "This tool uses a fine-tuned model to predict CWE categories from Git "
        "commit messages and vulnerability descriptions. "
        "Predicted child CWEs are mapped to their parent CWEs if applicable."
    ),
    examples=[
        [
            "A vulnerability has been found in cfire24 ajaxlife up to 0.3.2 and "
            "classified as problematic. This vulnerability affects unknown code. "
            "The manipulation leads to cross site scripting. The attack can be "
            "initiated remotely. Upgrading to version 0.3.3 is able to address "
            "this issue. "
        ],
    ],
)
|
|
|
def _main() -> None:
    """Start the Gradio app; only invoked when this file is run as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()