avans06 commited on
Commit
8b989ce
·
1 Parent(s): 4e136fc

fix: allow YOLOv5 custom model loading under PyTorch 2.6+ by using safe_globals context

Browse files
Files changed (1) hide show
  1. image_processing/model.py +39 -18
image_processing/model.py CHANGED
@@ -2,7 +2,7 @@ class Model:
2
  def __init__(self):
3
  self.model = None
4
  self.imported = False
5
-
6
  def load(self):
7
  if self.model is None:
8
  self.__load()
@@ -12,38 +12,59 @@ class Model:
12
  self.imported = True
13
  import torch
14
  import torch.serialization
 
15
  import pathlib
16
  import sys
17
  import os
18
  from myutils.respath import resource_path
19
  from yolov5.models.yolo import DetectionModel
20
- torch.serialization.add_safe_globals([DetectionModel]) # Exception: Weights only load failed.
21
 
 
 
 
 
 
 
22
  # Redirect sys.stderr to a file or a valid stream
23
  if sys.stderr is None:
24
  sys.stderr = open(os.devnull, 'w')
25
-
26
  # Check if the current operating system is Windows
27
  is_windows = (sys.platform == "win32")
28
 
29
- if is_windows:
30
- # If on Windows, apply the patch temporarily
31
- temp = pathlib.PosixPath
32
- pathlib.PosixPath = pathlib.WindowsPath
33
- try:
34
- # Load the model with the patch applied
35
- self.model = torch.hub.load('ultralytics/yolov5', 'custom', path=resource_path('ai-models/2024-11-00/best.pt'))
36
- finally:
37
- # CRITICAL: Always restore the original class, even if loading fails
38
- pathlib.PosixPath = temp
39
- else:
40
- # If on Linux, macOS, or other systems, load the model directly
41
- self.model = torch.hub.load('ultralytics/yolov5', 'custom', path=resource_path('ai-models/2024-11-00/best.pt'))
42
-
43
-
 
 
 
 
 
 
 
 
 
44
  def __call__(self, *args, **kwds):
45
  if self.model is None:
46
  self.__load()
47
  return self.model(*args, **kwds)
48
 
 
 
 
 
 
 
49
  model = Model()
 
2
  def __init__(self):
3
  self.model = None
4
  self.imported = False
5
+
6
  def load(self):
7
  if self.model is None:
8
  self.__load()
 
12
  self.imported = True
13
  import torch
14
  import torch.serialization
15
+ from torch.serialization import safe_globals
16
  import pathlib
17
  import sys
18
  import os
19
  from myutils.respath import resource_path
20
  from yolov5.models.yolo import DetectionModel
 
21
 
22
+ # Add DetectionModel to safe globals—must be inside context that covers load
23
+ torch.serialization.add_safe_globals([DetectionModel])
24
+
25
+ # Context manager to ensure allowlisted globals during load
26
+ self._safe_ctx = safe_globals([DetectionModel])
27
+
28
  # Redirect sys.stderr to a file or a valid stream
29
  if sys.stderr is None:
30
  sys.stderr = open(os.devnull, 'w')
31
+
32
  # Check if the current operating system is Windows
33
  is_windows = (sys.platform == "win32")
34
 
35
+ # Use context during actual load
36
+ with getattr(self, '_safe_ctx', dummy_context()):
37
+ if is_windows:
38
+ # If on Windows, apply the patch temporarily
39
+ temp = pathlib.PosixPath
40
+ pathlib.PosixPath = pathlib.WindowsPath
41
+ try:
42
+ # Load the model with the patch applied
43
+ self.model = torch.hub.load(
44
+ 'ultralytics/yolov5', 'custom',
45
+ path=resource_path('ai-models/2024-11-00/best.pt'),
46
+ force_reload=True
47
+ )
48
+ finally:
49
+ # CRITICAL: Always restore the original class, even if loading fails
50
+ pathlib.PosixPath = temp
51
+ else:
52
+ # If on Linux, macOS, or other systems, load the model directly
53
+ self.model = torch.hub.load(
54
+ 'ultralytics/yolov5', 'custom',
55
+ path=resource_path('ai-models/2024-11-00/best.pt'),
56
+ force_reload=True
57
+ )
58
+
59
  def __call__(self, *args, **kwds):
60
  if self.model is None:
61
  self.__load()
62
  return self.model(*args, **kwds)
63
 
64
+ # A no-op context in case safe_globals isn't set
65
+ from contextlib import contextmanager
66
+ @contextmanager
67
+ def dummy_context():
68
+ yield
69
+
70
  model = Model()