schroneko committed on
Commit
e2fac8d
1 Parent(s): 29e0785

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -4
app.py CHANGED
@@ -1,9 +1,50 @@
 
 
1
  import gradio as gr
2
  import spaces
3
 
 
 
 
 
 
 
4
  @spaces.GPU
5
- def inference():
6
- return gr.load("models/meta-llama/Llama-Guard-3-8B-INT8")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- demo = inference()
9
- demo.launch()
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
  import gradio as gr
4
  import spaces
5
 
6
# Hugging Face model repo for the Llama Guard safety classifier (INT8 variant).
model_id = "meta-llama/Llama-Guard-3-8B-INT8"
# Prefer GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Compute dtype requested at load time.
dtype = torch.bfloat16

# Load weights in 8-bit via bitsandbytes to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
11
+
12
@spaces.GPU
def load_model():
    """Load the Llama Guard tokenizer and 8-bit quantized model.

    Returns:
        A ``(tokenizer, model)`` pair ready for inference, with the model
        placed according to the module-level ``device``.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map=device,
        torch_dtype=dtype,
    )
    return tok, lm
22
+
23
+ tokenizer, model = load_model()
24
+
25
def moderate(user_input, assistant_response):
    """Run Llama Guard over one user/assistant exchange.

    Args:
        user_input: The user's message.
        assistant_response: The assistant's reply to be classified.

    Returns:
        The decoded moderation verdict text produced by the model
        (generated tokens only, prompt stripped).
    """
    chat = [
        {"role": "user", "content": user_input},
        {"role": "assistant", "content": assistant_response},
    ]
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    # Inference only: disable autograd so generation does not build a
    # computation graph and hold activations in memory.
    with torch.no_grad():
        # NOTE(review): pad_token_id=0 is hardcoded — presumably matches the
        # tokenizer's pad/unk id; confirm against the model's config.
        output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
    # Slice off the prompt tokens; return only the newly generated text.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
34
+
35
def gradio_moderate(user_input, assistant_response):
    """Gradio callback: delegate straight to :func:`moderate`."""
    verdict = moderate(user_input, assistant_response)
    return verdict
37
+
38
# Web UI: two free-text inputs (user message, assistant reply) and one
# text output showing the moderation verdict.
_input_widgets = [
    gr.Textbox(lines=3, label="User Input"),
    gr.Textbox(lines=3, label="Assistant Response"),
]

iface = gr.Interface(
    fn=gradio_moderate,
    inputs=_input_widgets,
    outputs=gr.Textbox(label="Moderation Result"),
    title="Llama Guard Moderation",
    description="Enter a user input and an assistant response to check for content moderation.",
)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()