AmiyendraOP committed on
Commit
60b634d
·
verified ·
1 Parent(s): 0232990

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -2,21 +2,27 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  import torch
3
  import gradio as gr
4
 
 
5
 
 
 
 
6
 
7
- model = AutoModelForCausalLM.from_pretrained("AmiyendraOP/llama3-legal-finetuned", device_map="auto")
8
- tokenizer = AutoTokenizer.from_pretrained("AmiyendraOP/llama3-legal-finetuned")
9
-
10
-
11
-
12
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
13
-
14
 
 
 
15
 
 
16
  def chat(prompt):
17
  response = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)[0]["generated_text"]
18
  return response
19
 
20
-
21
-
22
- gr.Interface(fn=chat, inputs="text", outputs="text", title="LLaMA 3 Legal Chatbot").launch()
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
 
5
# Hugging Face Hub repository holding the fine-tuned legal LLaMA 3 checkpoint.
model_id = "AmiyendraOP/llama3-legal-finetuned"

# Target GPU 0 when CUDA is available; -1 tells the pipeline to stay on CPU.
device = 0 if torch.cuda.is_available() else -1

# Fetch the tokenizer and model weights from the Hub (downloads on first run).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Wrap everything in a text-generation pipeline used by the chat handler.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
16
 
17
def chat(prompt):
    """Generate a model reply for *prompt* via the shared text-generation pipeline.

    Sampling is enabled (temperature 0.7) and the reply is capped at 256 new
    tokens; the full generated text (prompt included) is returned as a string.
    """
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
    return outputs[0]["generated_text"]
21
 
22
# Build and launch the Gradio UI: one text box in, one text box out.
demo = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(lines=4, placeholder="Enter legal question...", label="Your Question"),
    outputs=gr.Textbox(label="Response"),
    title="LLaMA 3 Legal Chatbot (Fine-tuned)",
)
demo.launch()