AC2513 committed on
Commit
97a997a
·
1 Parent(s): 89f9a59

adjusted model loading

Browse files
Files changed (2) hide show
  1. app.py +29 -34
  2. tests/conftest.py +12 -0
app.py CHANGED
@@ -28,37 +28,30 @@ model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
28
  MAX_VIDEO_SIZE = 100 * 1024 * 1024 # 100 MB
29
  MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10 MB
30
 
31
- # Global variables to hold models (loaded lazily)
32
- input_processor = None
33
- model_12 = None
34
- model_3n = None
35
-
36
- def load_models():
37
- """Load models lazily when needed."""
38
- global input_processor, model_12, model_3n
39
-
40
- # Skip model loading during testing
41
- if os.getenv("SKIP_MODEL_LOADING") == "true" or "pytest" in os.getenv("_", ""):
42
- return
43
-
44
- if input_processor is None:
45
- input_processor = Gemma3Processor.from_pretrained(model_12_id)
46
-
47
- if model_12 is None:
48
- model_12 = Gemma3ForConditionalGeneration.from_pretrained(
49
- model_12_id,
50
- torch_dtype=torch.bfloat16,
51
- device_map="auto",
52
- attn_implementation="eager",
53
- )
54
-
55
- if model_3n is None:
56
- model_3n = Gemma3nForConditionalGeneration.from_pretrained(
57
- model_3n_id,
58
- torch_dtype=torch.bfloat16,
59
- device_map="auto",
60
- attn_implementation="eager",
61
- )
62
 
63
 
64
  def check_file_size(file_path: str) -> bool:
@@ -229,9 +222,6 @@ def run(
229
  repetition_penalty: float,
230
  ) -> Iterator[str]:
231
 
232
- # Load models only when needed (not during testing)
233
- load_models()
234
-
235
  # Define preset system prompts
236
  preset_prompts = {
237
  "General Assistant": "You are a helpful AI assistant capable of analyzing images, videos, and PDF documents. Provide clear, accurate, and helpful responses to user queries.",
@@ -259,6 +249,11 @@ def run(
259
  )
260
 
261
  selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
 
 
 
 
 
262
 
263
  messages = []
264
  if system_prompt:
 
28
  MAX_VIDEO_SIZE = 100 * 1024 * 1024 # 100 MB
29
  MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10 MB
30
 
31
+ # Skip model loading during tests
32
+ SKIP_MODEL_LOADING = os.getenv("SKIP_MODEL_LOADING", "false").lower() == "true"
33
+
34
+ if not SKIP_MODEL_LOADING:
35
+ input_processor = Gemma3Processor.from_pretrained(model_12_id)
36
+
37
+ model_12 = Gemma3ForConditionalGeneration.from_pretrained(
38
+ model_12_id,
39
+ torch_dtype=torch.bfloat16,
40
+ device_map="auto",
41
+ attn_implementation="eager",
42
+ )
43
+
44
+ model_3n = Gemma3nForConditionalGeneration.from_pretrained(
45
+ model_3n_id,
46
+ torch_dtype=torch.bfloat16,
47
+ device_map="auto",
48
+ attn_implementation="eager",
49
+ )
50
+ else:
51
+ # Mock objects for testing
52
+ input_processor = None
53
+ model_12 = None
54
+ model_3n = None
 
 
 
 
 
 
 
55
 
56
 
57
  def check_file_size(file_path: str) -> bool:
 
222
  repetition_penalty: float,
223
  ) -> Iterator[str]:
224
 
 
 
 
225
  # Define preset system prompts
226
  preset_prompts = {
227
  "General Assistant": "You are a helpful AI assistant capable of analyzing images, videos, and PDF documents. Provide clear, accurate, and helpful responses to user queries.",
 
249
  )
250
 
251
  selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
252
+
253
+ # If models are skipped (during testing), return a mock response
254
+ if SKIP_MODEL_LOADING or selected_model is None or input_processor is None:
255
+ yield "Mock response for testing - models are not loaded."
256
+ return
257
 
258
  messages = []
259
  if system_prompt:
tests/conftest.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ # Set environment variable before any imports
5
+ os.environ["SKIP_MODEL_LOADING"] = "true"
6
+
7
+ @pytest.fixture(autouse=True)
8
+ def skip_model_loading():
9
+ """Automatically set environment variable to skip model loading for all tests."""
10
+ # Environment variable is already set above
11
+ yield
12
+ # Keep the variable set for the entire test session