AC2513 commited on
Commit
c0fc237
·
1 Parent(s): b87bea7

revert change

Browse files
Files changed (1) hide show
  1. app.py +47 -116
app.py CHANGED
@@ -17,44 +17,20 @@ from loguru import logger
17
  from PIL import Image
18
 
19
  dotenv_path = find_dotenv()
 
20
  load_dotenv(dotenv_path)
21
 
22
- MODEL_CONFIGS = {
23
- "Gemma 3 4B IT": {
24
- "id": os.getenv("MODEL_ID_27", "google/gemma-3-4b-it"),
25
- "supports_video": True,
26
- "supports_pdf": False
27
- },
28
- "Gemma 3 1B IT": {
29
- "id": os.getenv("MODEL_ID_12", "google/gemma-3-1b-it"),
30
- "supports_video": True,
31
- "supports_pdf": False
32
- },
33
- "Gemma 3N E4B IT": {
34
- "id": os.getenv("MODEL_ID_3N", "google/gemma-3n-E4B-it"),
35
- "supports_video": False,
36
- "supports_pdf": False
37
- }
38
- }
39
-
40
- # Load all models and processors
41
- models = {}
42
- processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it")
43
-
44
- for model_name, config in MODEL_CONFIGS.items():
45
- logger.info(f"Loading {model_name}...")
46
-
47
- models[model_name] = Gemma3ForConditionalGeneration.from_pretrained(
48
- config["id"],
49
- torch_dtype=torch.bfloat16,
50
- device_map="auto",
51
- attn_implementation="eager",
52
- )
53
-
54
- logger.info(f"✓ {model_name} loaded successfully")
55
 
56
- # Current model selection (default)
57
- current_model = "Gemma 3 27B IT"
58
 
59
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
60
  frames: list[tuple[Image.Image, float]] = []
@@ -147,25 +123,10 @@ def process_history(history: list[dict]) -> list[dict]:
147
  return messages
148
 
149
 
150
- def get_supported_file_types(model_name: str) -> list[str]:
151
- """Get supported file types for the selected model."""
152
- config = MODEL_CONFIGS[model_name]
153
-
154
- base_types = [".jpg", ".png", ".jpeg", ".gif", ".bmp", ".webp"]
155
-
156
- if config["supports_video"]:
157
- base_types.extend([".mp4", ".mov", ".avi"])
158
-
159
- if config["supports_pdf"]:
160
- base_types.append(".pdf")
161
-
162
- return base_types
163
-
164
  @spaces.GPU(duration=120)
165
  def run(
166
  message: dict,
167
  history: list[dict],
168
- model_name: str,
169
  system_prompt: str,
170
  max_new_tokens: int,
171
  max_images: int,
@@ -174,25 +135,12 @@ def run(
174
  top_k: int,
175
  repetition_penalty: float,
176
  ) -> Iterator[str]:
177
-
178
- global current_model
179
-
180
- if model_name != current_model:
181
- current_model = model_name
182
- logger.info(f"Switched to model: {model_name}")
183
-
184
  logger.debug(
185
- f"\n message: {message} \n history: {history} \n model: {model_name} \n "
186
- f"system_prompt: {system_prompt} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
187
  )
188
 
189
- config = MODEL_CONFIGS[model_name]
190
- if not config["supports_video"] and message.get("files"):
191
- for file_path in message["files"]:
192
- if file_path.endswith((".mp4", ".mov", ".avi")):
193
- yield "Error: Selected model does not support video files. Please choose a video-capable model."
194
- return
195
-
196
  messages = []
197
  if system_prompt:
198
  messages.append(
@@ -203,16 +151,16 @@ def run(
203
  {"role": "user", "content": process_user_input(message, max_images)}
204
  )
205
 
206
- inputs = processor.apply_chat_template(
207
  messages,
208
  add_generation_prompt=True,
209
  tokenize=True,
210
  return_dict=True,
211
  return_tensors="pt",
212
- ).to(device=models[current_model].device, dtype=torch.bfloat16)
213
 
214
  streamer = TextIteratorStreamer(
215
- processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
216
  )
217
  generate_kwargs = dict(
218
  inputs,
@@ -224,7 +172,7 @@ def run(
224
  repetition_penalty=repetition_penalty,
225
  do_sample=True,
226
  )
227
- t = Thread(target=models[current_model].generate, kwargs=generate_kwargs)
228
  t.start()
229
 
230
  output = ""
@@ -232,53 +180,36 @@ def run(
232
  output += delta
233
  yield output
234
 
235
- def create_interface():
236
- """Create interface with model selector."""
237
-
238
- initial_file_types = get_supported_file_types(current_model)
239
-
240
- demo = gr.ChatInterface(
241
- fn=run,
242
- type="messages",
243
- chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
244
- textbox=gr.MultimodalTextbox(
245
- file_types=initial_file_types,
246
- file_count="multiple",
247
- autofocus=True
248
- ),
249
- multimodal=True,
250
- additional_inputs=[
251
- gr.Dropdown(
252
- label="Model",
253
- choices=list(MODEL_CONFIGS.keys()),
254
- value=current_model,
255
- info="Select which model to use for generation"
256
- ),
257
- gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
258
- gr.Slider(
259
- label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
260
- ),
261
- gr.Slider(label="Max Images", minimum=1, maximum=8, step=1, value=2),
262
- gr.Slider(
263
- label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
264
- ),
265
- gr.Slider(
266
- label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
267
- ),
268
- gr.Slider(
269
- label="Top K", minimum=1, maximum=100, step=1, value=50
270
- ),
271
- gr.Slider(
272
- label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
273
- ),
274
- ],
275
- stop_btn=False,
276
- title="Multi-Model Gemma Chat"
277
- )
278
-
279
- return demo
280
 
281
- demo = create_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  if __name__ == "__main__":
284
  demo.launch()
 
17
  from PIL import Image
18
 
19
  dotenv_path = find_dotenv()
20
+
21
  load_dotenv(dotenv_path)
22
 
23
+ model_id = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
24
+
25
+ input_processor = Gemma3Processor.from_pretrained(model_id)
26
+
27
+ model = Gemma3ForConditionalGeneration.from_pretrained(
28
+ model_id,
29
+ torch_dtype=torch.bfloat16,
30
+ device_map="auto",
31
+ attn_implementation="eager",
32
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
34
 
35
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
36
  frames: list[tuple[Image.Image, float]] = []
 
123
  return messages
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  @spaces.GPU(duration=120)
127
  def run(
128
  message: dict,
129
  history: list[dict],
 
130
  system_prompt: str,
131
  max_new_tokens: int,
132
  max_images: int,
 
135
  top_k: int,
136
  repetition_penalty: float,
137
  ) -> Iterator[str]:
138
+
 
 
 
 
 
 
139
  logger.debug(
140
+ f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
141
+ f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
142
  )
143
 
 
 
 
 
 
 
 
144
  messages = []
145
  if system_prompt:
146
  messages.append(
 
151
  {"role": "user", "content": process_user_input(message, max_images)}
152
  )
153
 
154
+ inputs = input_processor.apply_chat_template(
155
  messages,
156
  add_generation_prompt=True,
157
  tokenize=True,
158
  return_dict=True,
159
  return_tensors="pt",
160
+ ).to(device=model.device, dtype=torch.bfloat16)
161
 
162
  streamer = TextIteratorStreamer(
163
+ input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
164
  )
165
  generate_kwargs = dict(
166
  inputs,
 
172
  repetition_penalty=repetition_penalty,
173
  do_sample=True,
174
  )
175
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
176
  t.start()
177
 
178
  output = ""
 
180
  output += delta
181
  yield output
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ demo = gr.ChatInterface(
185
+ fn=run,
186
+ type="messages",
187
+ chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
188
+ textbox=gr.MultimodalTextbox(
189
+ file_types=[".mp4", ".jpg", ".png"], file_count="multiple", autofocus=True
190
+ ),
191
+ multimodal=True,
192
+ additional_inputs=[
193
+ gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
194
+ gr.Slider(
195
+ label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
196
+ ),
197
+ gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
198
+ gr.Slider(
199
+ label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
200
+ ),
201
+ gr.Slider(
202
+ label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
203
+ ),
204
+ gr.Slider(
205
+ label="Top K", minimum=1, maximum=100, step=1, value=50
206
+ ),
207
+ gr.Slider(
208
+ label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
209
+ )
210
+ ],
211
+ stop_btn=False,
212
+ )
213
 
214
  if __name__ == "__main__":
215
  demo.launch()