jatingocodeo committed
Commit ce62d55 · verified · 1 Parent(s): 29fbf12

Update app.py

Files changed (1):
  app.py +127 -0
app.py CHANGED
@@ -0,0 +1,127 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ # Load model and tokenizer
+ def load_model(model_id):
+     # First load the base model
+     base_model_id = "microsoft/phi-2"
+     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+
+     # Ensure tokenizer has a padding token
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     base_model = AutoModelForCausalLM.from_pretrained(
+         base_model_id,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+
+     # Load the LoRA adapter on top of the base model
+     model = PeftModel.from_pretrained(base_model, model_id)
+     return model, tokenizer
+
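Note: PeftModel.from_pretrained only attaches the adapter here; the LoRA weights stay separate at inference time. If folding them into the base model is preferred (slightly faster generation, no adapter indirection), a minimal sketch using PEFT's merge_and_unload at the end of load_model:

    model = PeftModel.from_pretrained(base_model, model_id)
    model = model.merge_and_unload()  # folds the LoRA deltas into the base weights
    return model, tokenizer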
+ def generate_response(instruction, model, tokenizer, max_length=200, temperature=0.7, top_p=0.9):
+     # Format the input text
+     input_text = instruction.strip()
+
+     # Tokenize input
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+     # Generate response
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_length,
+             temperature=temperature,
+             top_p=top_p,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True
+         )
+
+     # Decode and return the response
+     full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Extract only the response part (what comes after the instruction)
+     if len(input_text) < len(full_text):
+         response = full_text[len(input_text):].strip()
+         return response
+     return full_text.strip()
+
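The character-offset slice above assumes the decoded text reproduces the prompt exactly, which tokenizer round-trips do not always guarantee (whitespace can shift). A more robust variant, sketched with the same variables, slices at the token level instead:

    # Decode only the newly generated tokens, skipping the prompt tokens
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()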
+ def create_demo(model_id):
+     # Load model and tokenizer
+     model, tokenizer = load_model(model_id)
+
+     # Define the interface
+     def process_input(instruction, max_length, temperature, top_p):
+         try:
+             return generate_response(
+                 instruction,
+                 model,
+                 tokenizer,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+         except Exception as e:
+             return f"Error generating response: {str(e)}"
+
+     # Create the interface
+     demo = gr.Interface(
+         fn=process_input,
+         inputs=[
+             gr.Textbox(
+                 label="Input Text",
+                 placeholder="Enter your text here...",
+                 lines=4
+             ),
+             gr.Slider(
+                 minimum=50,
+                 maximum=500,
+                 value=150,
+                 step=10,
+                 label="Maximum Length"
+             ),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.7,
+                 step=0.1,
+                 label="Temperature"
+             ),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.1,
+                 label="Top P"
+             )
+         ],
+         outputs=gr.Textbox(label="Completion", lines=8),
+         title="Phi-2 GRPO Model Demo",
+         description="""This is a generative model trained using GRPO (Group Relative Policy Optimization)
+         on the TLDR dataset. The model was trained to generate completions of around 150 characters.
+
+         You can adjust the generation parameters:
+         - **Maximum Length**: Controls the maximum length of the generated response
+         - **Temperature**: Higher values make the output more random; lower values make it more focused
+         - **Top P**: Controls the cumulative probability threshold for nucleus sampling
+         """,
+         # Each example row supplies a value for every input component
+         examples=[
+             ["The quick brown fox jumps over the lazy dog.", 150, 0.7, 0.9],
+             ["In this tutorial, we will explore how to build a neural network for image classification.", 150, 0.7, 0.9],
+             ["The best way to prepare for an interview is to", 150, 0.7, 0.9],
+             ["Python is a popular programming language because", 150, 0.7, 0.9]
+         ]
+     )
+     return demo
+
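The Temperature and Top P sliders map directly onto model.generate's sampling arguments. A toy sketch (illustrative values only) of what they do to a next-token distribution:

    import torch
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])     # toy next-token logits
    probs = torch.softmax(logits / 0.7, dim=-1)      # temperature < 1 sharpens the distribution
    sorted_probs, order = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(0)
    nucleus = cumulative - sorted_probs < 0.9        # smallest set covering ~90% of the mass
    print(order[nucleus])                            # sampling is restricted to these tokens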
+ if __name__ == "__main__":
+     # Use your model ID
+     model_id = "jatingocodeo/phi2-grpo"
+     demo = create_demo(model_id)
+     demo.launch()
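For a quick check outside the Gradio UI, the two helpers can be exercised directly. A sketch, assuming the usual Space dependencies (gradio, torch, transformers, peft, accelerate) and enough memory for phi-2 in fp16:

    # Hypothetical smoke test: load the adapter and print one completion
    model, tokenizer = load_model("jatingocodeo/phi2-grpo")
    print(generate_response("Python is a popular programming language because",
                            model, tokenizer, max_length=60))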