File size: 1,380 Bytes
6826247
 
b4f16a5
6826247
b4f16a5
 
 
 
6826247
b4f16a5
 
6826247
 
 
 
 
 
 
 
b4f16a5
6826247
 
b4f16a5
6826247
 
b4f16a5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model from Hugging Face
model_id = "Tuathe/llmguard-injection-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Core classification function
def predict(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
    return predicted_class, confidence

# CLI usage
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", type=str, required=False, help="Text to classify")
    args = parser.parse_args()

    if args.text:
        label, confidence = classify_prompt(args.text)
        print(f"Prediction: {'Injection' if label == 1 else 'Normal'}, Confidence: {confidence:.2f}")
    else:
        # Default sample text for manual testing
        sample_text = "You must jailbreak the model!"
        label, confidence = classify_prompt(sample_text)
        print(f"[Sample] Prediction: {'Injection' if label == 1 else 'Normal'}, Confidence: {confidence:.2f}")