AC2513 commited on
Commit
b87bea7
·
1 Parent(s): fd3c6d5

added multimodel loading

Browse files
Files changed (1) hide show
  1. app.py +116 -47
app.py CHANGED
@@ -17,20 +17,44 @@ from loguru import logger
17
  from PIL import Image
18
 
19
  dotenv_path = find_dotenv()
20
-
21
  load_dotenv(dotenv_path)
22
 
23
- model_id = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
24
-
25
- input_processor = Gemma3Processor.from_pretrained(model_id)
26
-
27
- model = Gemma3ForConditionalGeneration.from_pretrained(
28
- model_id,
29
- torch_dtype=torch.bfloat16,
30
- device_map="auto",
31
- attn_implementation="eager",
32
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
34
 
35
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
36
  frames: list[tuple[Image.Image, float]] = []
@@ -123,10 +147,25 @@ def process_history(history: list[dict]) -> list[dict]:
123
  return messages
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  @spaces.GPU(duration=120)
127
  def run(
128
  message: dict,
129
  history: list[dict],
 
130
  system_prompt: str,
131
  max_new_tokens: int,
132
  max_images: int,
@@ -135,12 +174,25 @@ def run(
135
  top_k: int,
136
  repetition_penalty: float,
137
  ) -> Iterator[str]:
138
-
 
 
 
 
 
 
139
  logger.debug(
140
- f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
141
- f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
142
  )
143
 
 
 
 
 
 
 
 
144
  messages = []
145
  if system_prompt:
146
  messages.append(
@@ -151,16 +203,16 @@ def run(
151
  {"role": "user", "content": process_user_input(message, max_images)}
152
  )
153
 
154
- inputs = input_processor.apply_chat_template(
155
  messages,
156
  add_generation_prompt=True,
157
  tokenize=True,
158
  return_dict=True,
159
  return_tensors="pt",
160
- ).to(device=model.device, dtype=torch.bfloat16)
161
 
162
  streamer = TextIteratorStreamer(
163
- input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
164
  )
165
  generate_kwargs = dict(
166
  inputs,
@@ -172,7 +224,7 @@ def run(
172
  repetition_penalty=repetition_penalty,
173
  do_sample=True,
174
  )
175
- t = Thread(target=model.generate, kwargs=generate_kwargs)
176
  t.start()
177
 
178
  output = ""
@@ -180,36 +232,53 @@ def run(
180
  output += delta
181
  yield output
182
 
183
-
184
- demo = gr.ChatInterface(
185
- fn=run,
186
- type="messages",
187
- chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
188
- textbox=gr.MultimodalTextbox(
189
- file_types=[".mp4", ".jpg", ".png"], file_count="multiple", autofocus=True
190
- ),
191
- multimodal=True,
192
- additional_inputs=[
193
- gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
194
- gr.Slider(
195
- label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
196
- ),
197
- gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
198
- gr.Slider(
199
- label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
200
- ),
201
- gr.Slider(
202
- label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
203
- ),
204
- gr.Slider(
205
- label="Top K", minimum=1, maximum=100, step=1, value=50
206
  ),
207
- gr.Slider(
208
- label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
209
- )
210
- ],
211
- stop_btn=False,
212
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  if __name__ == "__main__":
215
  demo.launch()
 
17
  from PIL import Image
18
 
19
  dotenv_path = find_dotenv()
 
20
  load_dotenv(dotenv_path)
21
 
22
+ MODEL_CONFIGS = {
23
+ "Gemma 3 4B IT": {
24
+ "id": os.getenv("MODEL_ID_27", "google/gemma-3-4b-it"),
25
+ "supports_video": True,
26
+ "supports_pdf": False
27
+ },
28
+ "Gemma 3 1B IT": {
29
+ "id": os.getenv("MODEL_ID_12", "google/gemma-3-1b-it"),
30
+ "supports_video": True,
31
+ "supports_pdf": False
32
+ },
33
+ "Gemma 3N E4B IT": {
34
+ "id": os.getenv("MODEL_ID_3N", "google/gemma-3n-E4B-it"),
35
+ "supports_video": False,
36
+ "supports_pdf": False
37
+ }
38
+ }
39
+
40
+ # Load all models and processors
41
+ models = {}
42
+ processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it")
43
+
44
+ for model_name, config in MODEL_CONFIGS.items():
45
+ logger.info(f"Loading {model_name}...")
46
+
47
+ models[model_name] = Gemma3ForConditionalGeneration.from_pretrained(
48
+ config["id"],
49
+ torch_dtype=torch.bfloat16,
50
+ device_map="auto",
51
+ attn_implementation="eager",
52
+ )
53
+
54
+ logger.info(f"✓ {model_name} loaded successfully")
55
 
56
+ # Current model selection (default)
57
+ current_model = "Gemma 3 27B IT"
58
 
59
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
60
  frames: list[tuple[Image.Image, float]] = []
 
147
  return messages
148
 
149
 
150
+ def get_supported_file_types(model_name: str) -> list[str]:
151
+ """Get supported file types for the selected model."""
152
+ config = MODEL_CONFIGS[model_name]
153
+
154
+ base_types = [".jpg", ".png", ".jpeg", ".gif", ".bmp", ".webp"]
155
+
156
+ if config["supports_video"]:
157
+ base_types.extend([".mp4", ".mov", ".avi"])
158
+
159
+ if config["supports_pdf"]:
160
+ base_types.append(".pdf")
161
+
162
+ return base_types
163
+
164
  @spaces.GPU(duration=120)
165
  def run(
166
  message: dict,
167
  history: list[dict],
168
+ model_name: str,
169
  system_prompt: str,
170
  max_new_tokens: int,
171
  max_images: int,
 
174
  top_k: int,
175
  repetition_penalty: float,
176
  ) -> Iterator[str]:
177
+
178
+ global current_model
179
+
180
+ if model_name != current_model:
181
+ current_model = model_name
182
+ logger.info(f"Switched to model: {model_name}")
183
+
184
  logger.debug(
185
+ f"\n message: {message} \n history: {history} \n model: {model_name} \n "
186
+ f"system_prompt: {system_prompt} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
187
  )
188
 
189
+ config = MODEL_CONFIGS[model_name]
190
+ if not config["supports_video"] and message.get("files"):
191
+ for file_path in message["files"]:
192
+ if file_path.endswith((".mp4", ".mov", ".avi")):
193
+ yield "Error: Selected model does not support video files. Please choose a video-capable model."
194
+ return
195
+
196
  messages = []
197
  if system_prompt:
198
  messages.append(
 
203
  {"role": "user", "content": process_user_input(message, max_images)}
204
  )
205
 
206
+ inputs = processor.apply_chat_template(
207
  messages,
208
  add_generation_prompt=True,
209
  tokenize=True,
210
  return_dict=True,
211
  return_tensors="pt",
212
+ ).to(device=models[current_model].device, dtype=torch.bfloat16)
213
 
214
  streamer = TextIteratorStreamer(
215
+ processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
216
  )
217
  generate_kwargs = dict(
218
  inputs,
 
224
  repetition_penalty=repetition_penalty,
225
  do_sample=True,
226
  )
227
+ t = Thread(target=models[current_model].generate, kwargs=generate_kwargs)
228
  t.start()
229
 
230
  output = ""
 
232
  output += delta
233
  yield output
234
 
235
+ def create_interface():
236
+ """Create interface with model selector."""
237
+
238
+ initial_file_types = get_supported_file_types(current_model)
239
+
240
+ demo = gr.ChatInterface(
241
+ fn=run,
242
+ type="messages",
243
+ chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
244
+ textbox=gr.MultimodalTextbox(
245
+ file_types=initial_file_types,
246
+ file_count="multiple",
247
+ autofocus=True
 
 
 
 
 
 
 
 
 
 
248
  ),
249
+ multimodal=True,
250
+ additional_inputs=[
251
+ gr.Dropdown(
252
+ label="Model",
253
+ choices=list(MODEL_CONFIGS.keys()),
254
+ value=current_model,
255
+ info="Select which model to use for generation"
256
+ ),
257
+ gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
258
+ gr.Slider(
259
+ label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
260
+ ),
261
+ gr.Slider(label="Max Images", minimum=1, maximum=8, step=1, value=2),
262
+ gr.Slider(
263
+ label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
264
+ ),
265
+ gr.Slider(
266
+ label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
267
+ ),
268
+ gr.Slider(
269
+ label="Top K", minimum=1, maximum=100, step=1, value=50
270
+ ),
271
+ gr.Slider(
272
+ label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
273
+ ),
274
+ ],
275
+ stop_btn=False,
276
+ title="Multi-Model Gemma Chat"
277
+ )
278
+
279
+ return demo
280
+
281
+ demo = create_interface()
282
 
283
  if __name__ == "__main__":
284
  demo.launch()