Dhruv-Ty committed on
Commit ac239ba · 1 Parent(s): 8ce0600

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. LICENSE +201 -0
  2. app.py +948 -5
  3. checkpoints/MedSAM2_2411.pt +3 -0
  4. checkpoints/MedSAM2_CTLesion.pt +3 -0
  5. checkpoints/MedSAM2_MRI_LiverLesion.pt +3 -0
  6. checkpoints/MedSAM2_US_Heart.pt +3 -0
  7. checkpoints/MedSAM2_latest.pt +3 -0
  8. checkpoints/README.md +10 -0
  9. download.sh +35 -0
  10. download_checkpoints.py +21 -0
  11. gitignore +13 -0
  12. medsam2_infer_3D_CT.py +304 -0
  13. medsam2_infer_video.py +570 -0
  14. multi_node_train.sh +48 -0
  15. notebooks/MedSAM2_Inference_Video.ipynb +0 -0
  16. notebooks/MedSAM2_inference_CT_Lesion.ipynb +0 -0
  17. pyproject.toml +6 -0
  18. requirements.txt +16 -0
  19. sam2/__init__.py +11 -0
  20. sam2/__pycache__/__init__.cpython-312.pyc +0 -0
  21. sam2/__pycache__/build_sam.cpython-312.pyc +0 -0
  22. sam2/__pycache__/sam2_image_predictor.cpython-312.pyc +0 -0
  23. sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc +0 -0
  24. sam2/build_sam.py +207 -0
  25. sam2/configs/sam2.1_hiera_t512.yaml +121 -0
  26. sam2/configs/sam2.1_hiera_tiny_finetune512.yaml +389 -0
  27. sam2/csrc/connected_components.cu +289 -0
  28. sam2/modeling/__init__.py +5 -0
  29. sam2/modeling/__pycache__/__init__.cpython-312.pyc +0 -0
  30. sam2/modeling/__pycache__/memory_attention.cpython-312.pyc +0 -0
  31. sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc +0 -0
  32. sam2/modeling/__pycache__/position_encoding.cpython-312.pyc +0 -0
  33. sam2/modeling/__pycache__/sam2_base.cpython-312.pyc +0 -0
  34. sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc +0 -0
  35. sam2/modeling/backbones/__init__.py +5 -0
  36. sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc +0 -0
  37. sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc +0 -0
  38. sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc +0 -0
  39. sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc +0 -0
  40. sam2/modeling/backbones/hieradet.py +317 -0
  41. sam2/modeling/backbones/image_encoder.py +134 -0
  42. sam2/modeling/backbones/utils.py +95 -0
  43. sam2/modeling/memory_attention.py +169 -0
  44. sam2/modeling/memory_encoder.py +181 -0
  45. sam2/modeling/position_encoding.py +221 -0
  46. sam2/modeling/sam/__init__.py +5 -0
  47. sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc +0 -0
  48. sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc +0 -0
  49. sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc +0 -0
  50. sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc +0 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
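
The appendix above says the notice should be wrapped in the comment syntax of each file format. For the Python sources added in this commit, that boilerplate could look like the following sketch (the bracketed year and owner are placeholders, left unfilled as in the license text):

```python
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```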
app.py CHANGED
@@ -1,7 +1,950 @@
1
- import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
+ """
2
+ Gradio app for interactive medical video segmentation using MedSAM2.
3
+ Please use gradio==3.38.0
4
+ """
5
 
6
+ import datetime
7
+ import gc
8
+ from glob import glob
9
+ import hashlib
10
+ import math
11
+ import multiprocessing as mp
12
+ import platform
13
+ import os
14
+ from os.path import basename, splitext, dirname
15
+ import threading
16
+ import time
17
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
18
+ import shutil
19
+ import ffmpeg
20
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
21
+ import zipfile
22
+ import torch
23
+ import numpy as np
24
+ import matplotlib.pyplot as plt
25
+ from PIL import Image
26
+ from sam2.build_sam import build_sam2
27
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
28
+ from sam2.build_sam import build_sam2_video_predictor
29
+ import cv2
30
 
31
+
32
+ user_processes = {}
33
+ PROCESS_TIMEOUT = datetime.timedelta(minutes=15)
34
+
35
+ def reset(seg_tracker):
36
+ if seg_tracker is not None:
37
+ predictor, inference_state, image_predictor = seg_tracker
38
+ predictor.reset_state(inference_state)
39
+ del predictor
40
+ del inference_state
41
+ del image_predictor
42
+ del seg_tracker
43
+ gc.collect()
44
+ torch.cuda.empty_cache()
45
+ return None, ({}, {}), None, None, 0, None, None, None, 0, 0,
46
+
47
+ def extract_video_info(input_video):
48
+ if input_video is None:
49
+ return 4, 4, None, None, None, None, None
50
+ cap = cv2.VideoCapture(input_video)
51
+ fps = cap.get(cv2.CAP_PROP_FPS)
52
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
53
+ cap.release()
54
+ return fps, total_frames, None, None, None, None, None
55
+
56
+ def get_meta_from_video(session_id, input_video, scale_slider, config_path, checkpoint_path):
57
+ output_dir = f'/tmp/output_frames/{session_id}'
58
+ output_masks_dir = f'/tmp/output_masks/{session_id}'
59
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
60
+ clear_folder(output_dir)
61
+ clear_folder(output_masks_dir)
62
+ clear_folder(output_combined_dir)
63
+ if input_video is None:
64
+ return None, ({}, {}), None, None, (4, 1, 4), None, None, None, 0, 0
65
+ cap = cv2.VideoCapture(input_video)
66
+ fps = cap.get(cv2.CAP_PROP_FPS)
67
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
68
+ cap.release()
69
+ frame_interval = max(1, int(fps // scale_slider))
70
+ print(f"frame_interval: {frame_interval}")
71
+ try:
72
+ ffmpeg.input(input_video, hwaccel='cuda').output(
73
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
74
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
75
+ ).run()
76
+ except Exception:
77
+ print("ffmpeg CUDA decoding failed, falling back to CPU decoding")
78
+ ffmpeg.input(input_video).output(
79
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
80
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
81
+ ).run()
82
+
83
+ first_frame_path = os.path.join(output_dir, '0000000.jpg')
84
+ first_frame = cv2.imread(first_frame_path)
85
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
86
+
87
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
88
+ if torch.cuda.get_device_properties(0).major >= 8:
89
+ torch.backends.cuda.matmul.allow_tf32 = True
90
+ torch.backends.cudnn.allow_tf32 = True
91
+
92
+ predictor = build_sam2_video_predictor(config_path, checkpoint_path, device="cuda")
93
+ sam2_model = build_sam2(config_path, checkpoint_path, device="cuda")
94
+ image_predictor = SAM2ImagePredictor(sam2_model)
95
+ inference_state = predictor.init_state(video_path=output_dir)
96
+ predictor.reset_state(inference_state)
97
+ return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, (fps, frame_interval, total_frames), None, None, None, 0, 0
98
+
99
+ def mask2bbox(mask):
100
+ if len(np.where(mask > 0)[0]) == 0:
101
+ print('empty mask: no foreground pixels')
102
+ return np.array([0, 0, 0, 0]).astype(np.int64), False
103
+ x_ = np.sum(mask, axis=0)
104
+ y_ = np.sum(mask, axis=1)
105
+ x0 = np.min(np.nonzero(x_)[0])
106
+ x1 = np.max(np.nonzero(x_)[0])
107
+ y0 = np.min(np.nonzero(y_)[0])
108
+ y1 = np.max(np.nonzero(y_)[0])
109
+ return np.array([x0, y0, x1, y1]).astype(np.int64), True
110
+
111
+ def sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id):
112
+ predictor, inference_state, image_predictor = seg_tracker
113
+ image_path = f'/tmp/output_frames/{session_id}/{frame_num:07d}.jpg'
114
+ image = cv2.imread(image_path)
115
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
116
+ display_image = drawing_board["image"]
117
+ image_predictor.set_image(image)
118
+ input_mask = drawing_board["mask"]
119
+ input_mask[input_mask != 0] = 255
120
+ if last_draw is not None:
121
+ diff_mask = cv2.absdiff(input_mask, last_draw)
122
+ input_mask = diff_mask
123
+ bbox, hasMask = mask2bbox(input_mask[:, :, 0])
124
+ if not hasMask :
125
+ return seg_tracker, display_image, display_image, None
126
+ masks, scores, logits = image_predictor.predict( point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,)
127
+ mask = masks > 0.0
128
+ masked_frame = show_mask(mask, display_image, ann_obj_id)
129
+ masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id)
130
+ frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0])
131
+ last_draw = drawing_board["mask"]
132
+ return seg_tracker, masked_with_rect, masked_with_rect, last_draw
133
+
134
+ def draw_rect(image, bbox, obj_id):
135
+ cmap = plt.get_cmap("tab10")
136
+ color = np.array(cmap(obj_id)[:3])
137
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
138
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
139
+ x0, y0, x1, y1 = bbox
140
+ image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2)
141
+ return image_with_rect
142
+
143
+ def sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point):
144
+ points_dict, labels_dict = click_stack
145
+ predictor, inference_state, image_predictor = seg_tracker
146
+ ann_frame_idx = frame_num # the frame index we interact with
147
+ print(f'ann_frame_idx: {ann_frame_idx}')
148
+ if point_mode == "Positive":
149
+ label = np.array([1], np.int32)
150
+ else:
151
+ label = np.array([0], np.int32)
152
+
153
+ if ann_frame_idx not in points_dict:
154
+ points_dict[ann_frame_idx] = {}
155
+ if ann_frame_idx not in labels_dict:
156
+ labels_dict[ann_frame_idx] = {}
157
+
158
+ if ann_obj_id not in points_dict[ann_frame_idx]:
159
+ points_dict[ann_frame_idx][ann_obj_id] = np.empty((0, 2), dtype=np.float32)
160
+ if ann_obj_id not in labels_dict[ann_frame_idx]:
161
+ labels_dict[ann_frame_idx][ann_obj_id] = np.empty((0,), dtype=np.int32)
162
+
163
+ points_dict[ann_frame_idx][ann_obj_id] = np.append(points_dict[ann_frame_idx][ann_obj_id], point, axis=0)
164
+ labels_dict[ann_frame_idx][ann_obj_id] = np.append(labels_dict[ann_frame_idx][ann_obj_id], label, axis=0)
165
+
166
+ click_stack = (points_dict, labels_dict)
167
+
168
+ frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
169
+ inference_state=inference_state,
170
+ frame_idx=ann_frame_idx,
171
+ obj_id=ann_obj_id,
172
+ points=points_dict[ann_frame_idx][ann_obj_id],
173
+ labels=labels_dict[ann_frame_idx][ann_obj_id],
174
+ )
175
+
176
+ image_path = f'/tmp/output_frames/{session_id}/{ann_frame_idx:07d}.jpg'
177
+ image = cv2.imread(image_path)
178
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
179
+
180
+ masked_frame = image.copy()
181
+ for i, obj_id in enumerate(out_obj_ids):
182
+ mask = (out_mask_logits[i] > 0.0).cpu().numpy()
183
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
184
+ masked_frame_with_markers = draw_markers(masked_frame, points_dict[ann_frame_idx], labels_dict[ann_frame_idx])
185
+
186
+ return seg_tracker, masked_frame_with_markers, masked_frame_with_markers, click_stack
187
+
188
+ def draw_markers(image, points_dict, labels_dict):
189
+ cmap = plt.get_cmap("tab10")
190
+ image_h, image_w = image.shape[:2]
191
+ marker_size = max(1, int(min(image_h, image_w) * 0.05))
192
+
193
+ for obj_id in points_dict:
194
+ color = np.array(cmap(obj_id)[:3])
195
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
196
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
197
+ for point, label in zip(points_dict[obj_id], labels_dict[obj_id]):
198
+ x, y = int(point[0]), int(point[1])
199
+ if label == 1:
200
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2)
201
+ else:
202
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=int(marker_size / np.sqrt(2)), thickness=2)
203
+
204
+ return image
205
+
206
+ def show_mask(mask, image=None, obj_id=None):
207
+ cmap = plt.get_cmap("tab10")
208
+ cmap_idx = 0 if obj_id is None else obj_id
209
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
210
+
211
+ h, w = mask.shape[-2:]
212
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
213
+ mask_image = (mask_image * 255).astype(np.uint8)
214
+ if image is not None:
215
+ image_h, image_w = image.shape[:2]
216
+ if (image_h, image_w) != (h, w):
217
+ raise ValueError(f"Image dimensions ({image_h}, {image_w}) and mask dimensions ({h}, {w}) do not match")
218
+ colored_mask = np.zeros_like(image, dtype=np.uint8)
219
+ for c in range(3):
220
+ colored_mask[..., c] = mask_image[..., c]
221
+ alpha_mask = mask_image[..., 3] / 255.0
222
+ for c in range(3):
223
+ image[..., c] = np.where(alpha_mask > 0, (1 - alpha_mask) * image[..., c] + alpha_mask * colored_mask[..., c], image[..., c])
224
+ return image
225
+ return mask_image
226
+
227
+ def show_res_by_slider(session_id, frame_per, click_stack):
228
+ image_path = f'/tmp/output_frames/{session_id}'
229
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
230
+
231
+ combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
232
+ if combined_frames:
233
+ output_masked_frame_path = combined_frames
234
+ else:
235
+ original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)])
236
+ output_masked_frame_path = original_frames
237
+
238
+ total_frames_num = len(output_masked_frame_path)
239
+ if total_frames_num == 0:
240
+ print("No output results found")
241
+ return None, None, 0
242
+ else:
243
+ frame_num = math.floor(total_frames_num * frame_per)
244
+ if frame_num >= total_frames_num:
245
+ frame_num = total_frames_num - 1
246
+ chosen_frame_path = output_masked_frame_path[frame_num]
247
+ print(f"{chosen_frame_path}")
248
+ chosen_frame_show = cv2.imread(chosen_frame_path)
249
+ chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
250
+ points_dict, labels_dict = click_stack
251
+ if frame_num in points_dict and frame_num in labels_dict:
252
+ chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
253
+ return chosen_frame_show, chosen_frame_show, frame_num
254
+
255
+ def clear_folder(folder_path):
256
+ if os.path.exists(folder_path):
257
+ shutil.rmtree(folder_path)
258
+ os.makedirs(folder_path)
259
+
260
+ def zip_folder(folder_path, output_zip_path):
261
+ with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_STORED) as zipf:
262
+ for root, _, files in os.walk(folder_path):
263
+ for file in files:
264
+ file_path = os.path.join(root, file)
265
+ zipf.write(file_path, os.path.relpath(file_path, folder_path))
266
+
267
+ def tracking_objects(session_id, seg_tracker, frame_num, input_video):
268
+ output_dir = f'/tmp/output_frames/{session_id}'
269
+ output_masks_dir = f'/tmp/output_masks/{session_id}'
270
+ output_combined_dir = f'/tmp/output_combined/{session_id}'
271
+ output_files_dir = f'/tmp/output_files/{session_id}'
272
+ output_video_path = f'{output_files_dir}/output_video.mp4'
273
+ output_zip_path = f'{output_files_dir}/output_masks.zip'
274
+ clear_folder(output_masks_dir)
275
+ clear_folder(output_combined_dir)
276
+ clear_folder(output_files_dir)
277
+ video_segments = {}
278
+ predictor, inference_state, image_predictor = seg_tracker
279
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
280
+ video_segments[out_frame_idx] = {
281
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
282
+ for i, out_obj_id in enumerate(out_obj_ids)
283
+ }
284
+ frame_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.jpg')])
285
+ # for frame_idx in sorted(video_segments.keys()):
286
+ for frame_file in frame_files:
287
+ frame_idx = int(os.path.splitext(frame_file)[0])
288
+ frame_path = os.path.join(output_dir, frame_file)
289
+ image = cv2.imread(frame_path)
290
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
291
+ masked_frame = image.copy()
292
+ if frame_idx in video_segments:
293
+ for obj_id, mask in video_segments[frame_idx].items():
294
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
295
+ mask_output_path = os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png')
296
+ cv2.imwrite(mask_output_path, show_mask(mask))
297
+ combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png')
298
+ combined_image_bgr = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
299
+ cv2.imwrite(combined_output_path, combined_image_bgr)
300
+ if frame_idx == frame_num:
301
+ final_masked_frame = masked_frame
302
+
303
+ cap = cv2.VideoCapture(input_video)
304
+ fps = cap.get(cv2.CAP_PROP_FPS)
305
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
306
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
307
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
308
+ cap.release()
309
+ # output_frames = int(total_frames * scale_slider)
310
+ output_frames = len([name for name in os.listdir(output_combined_dir) if os.path.isfile(os.path.join(output_combined_dir, name)) and name.endswith('.png')])
311
+ out_fps = fps * output_frames / total_frames
312
+
313
+ # ffmpeg.input(os.path.join(output_combined_dir, '%07d.png'), framerate=out_fps).output(output_video_path, vcodec='h264_nvenc', pix_fmt='yuv420p').run()
314
+
315
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
316
+ # out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (frame_width, frame_height))
317
+ # for i in range(output_frames):
318
+ # frame_path = os.path.join(output_combined_dir, f'{i:07d}.png')
319
+ # frame = cv2.imread(frame_path)
320
+ # out.write(frame)
321
+ # out.release()
322
+
323
+ image_files = [os.path.join(output_combined_dir, f'{i:07d}.png') for i in range(output_frames)]
324
+ clip = ImageSequenceClip(image_files, fps=out_fps)
325
+ clip.write_videofile(output_video_path, codec="libx264", fps=out_fps)
326
+
327
+ zip_folder(output_masks_dir, output_zip_path)
328
+ print("done")
329
+ return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path, ({}, {})
330
+
331
+ def increment_ann_obj_id(max_obj_id):
332
+ max_obj_id += 1
333
+ ann_obj_id = max_obj_id
334
+ return ann_obj_id, max_obj_id
335
+
336
+ def update_current_id(ann_obj_id):
337
+ return ann_obj_id
338
+
339
+ def drawing_board_get_input_first_frame(input_first_frame):
340
+ return input_first_frame
341
+
342
+ def process_video(queue, result_queue, session_id):
343
+ seg_tracker = None
344
+ click_stack = ({}, {})
345
+ frame_num = int(0)
346
+ ann_obj_id = int(0)
347
+ last_draw = None
348
+ while True:
349
+ task = queue.get()
350
+ if task["command"] == "exit":
351
+ print(f"Process for {session_id} exiting.")
352
+ break
353
+ elif task["command"] == "extract_video_info":
354
+ input_video = task["input_video"]
355
+ fps, total_frames, input_first_frame, drawing_board, output_video, output_mp4, output_mask = extract_video_info(input_video)
356
+ result_queue.put({"fps": fps, "total_frames": total_frames, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask})
357
+ elif task["command"] == "get_meta_from_video":
358
+ input_video = task["input_video"]
359
+ scale_slider = task["scale_slider"]
360
+ config_path = task["config_path"]
361
+ checkpoint_path = task["checkpoint_path"]
362
+ seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id = get_meta_from_video(session_id, input_video, scale_slider, config_path, checkpoint_path)
363
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id, "max_obj_id": max_obj_id})
364
+ elif task["command"] == "sam_stroke":
365
+ drawing_board = task["drawing_board"]
366
+ last_draw = task["last_draw"]
367
+ frame_num = task["frame_num"]
368
+ ann_obj_id = task["ann_obj_id"]
369
+ seg_tracker, input_first_frame, drawing_board, last_draw = sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id)
370
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw})
371
+ elif task["command"] == "sam_click":
372
+ frame_num = task["frame_num"]
373
+ point_mode = task["point_mode"]
374
+ click_stack = task["click_stack"]
375
+ ann_obj_id = task["ann_obj_id"]
376
+ point = task["point"]
377
+ seg_tracker, input_first_frame, drawing_board, last_draw = sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point)
378
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw})
379
+ elif task["command"] == "increment_ann_obj_id":
380
+ max_obj_id = task["max_obj_id"]
381
+ ann_obj_id, max_obj_id = increment_ann_obj_id(max_obj_id)
382
+ result_queue.put({"ann_obj_id": ann_obj_id, "max_obj_id": max_obj_id})
383
+ elif task["command"] == "update_current_id":
384
+ ann_obj_id = task["ann_obj_id"]
385
+ ann_obj_id = update_current_id(ann_obj_id)
386
+ result_queue.put({"ann_obj_id": ann_obj_id})
387
+ elif task["command"] == "drawing_board_get_input_first_frame":
388
+ input_first_frame = task["input_first_frame"]
389
+ input_first_frame = drawing_board_get_input_first_frame(input_first_frame)
390
+ result_queue.put({"input_first_frame": input_first_frame})
391
+ elif task["command"] == "reset":
392
+ seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id = reset(seg_tracker)
393
+ result_queue.put({"click_stack": click_stack, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id, "max_obj_id": max_obj_id})
394
+ elif task["command"] == "show_res_by_slider":
395
+ frame_per = task["frame_per"]
396
+ click_stack = task["click_stack"]
397
+ input_first_frame, drawing_board, frame_num = show_res_by_slider(session_id, frame_per, click_stack)
398
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_num": frame_num})
399
+ elif task["command"] == "tracking_objects":
400
+ frame_num = task["frame_num"]
401
+ input_video = task["input_video"]
402
+ input_first_frame, drawing_board, output_video, output_mp4, output_mask, click_stack = tracking_objects(session_id, seg_tracker, frame_num, input_video)
403
+ result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "click_stack": click_stack})
404
+ else:
405
+ print(f"Unknown command {task['command']} for {session_id}")
406
+ result_queue.put("Unknown command")
407
+
408
+ def start_process(session_id):
409
+ if session_id not in user_processes:
410
+ queue = mp.Queue()
411
+ result_queue = mp.Queue()
412
+ process = mp.Process(target=process_video, args=(queue, result_queue, session_id))
413
+ process.start()
414
+ user_processes[session_id] = {
415
+ "process": process,
416
+ "queue": queue,
417
+ "result_queue": result_queue,
418
+ "last_active": datetime.datetime.now()
419
+ }
420
+ else:
421
+ user_processes[session_id]["last_active"] = datetime.datetime.now()
422
+ return user_processes[session_id]["queue"]
423
+
424
+ # def clean_up_processes(session_id, init_clean = False):
425
+ # now = datetime.datetime.now()
426
+ # to_remove = []
427
+ # for s_id, process_info in user_processes.items():
428
+ # if (now - process_info["last_active"] > PROCESS_TIMEOUT) or (s_id == session_id and init_clean):
429
+ # process_info["queue"].put({"command": "exit"})
430
+ # process_info["process"].terminate()
431
+ # process_info["process"].join()
432
+ # to_remove.append(s_id)
433
+ # for s_id in to_remove:
434
+ # del user_processes[s_id]
435
+ # print(f"Cleaned up process for session {s_id}.")
436
+
437
+ def monitor_and_cleanup_processes():
438
+ while True:
439
+ now = datetime.datetime.now()
440
+ to_remove = []
441
+ for session_id, process_info in user_processes.items():
442
+ if now - process_info["last_active"] > PROCESS_TIMEOUT:
443
+ process_info["queue"].put({"command": "exit"})
444
+ process_info["process"].terminate()
445
+ process_info["process"].join()
446
+ to_remove.append(session_id)
447
+ for session_id in to_remove:
448
+ del user_processes[session_id]
449
+ print(f"Automatically cleaned up process for session {session_id}.")
450
+ time.sleep(10)
451
+
452
+ def seg_track_app():
453
+ # Only supports gradio==3.38.0
454
+ import gradio as gr
455
+
456
+ def extract_session_id_from_request(request: gr.Request):
457
+ session_id = hashlib.sha256(f'{request.client.host}:{request.client.port}'.encode('utf-8')).hexdigest()
458
+ # cookies = request.kwargs["headers"].get('cookie', '')
459
+ # session_id = None
460
+ # if '_gid=' in cookies:
461
+ # session_id = cookies.split('_gid=')[1].split(';')[0]
462
+ # else:
463
+ # session_id = str(uuid.uuid4())
464
+ print(f"session_id {session_id}")
465
+ return session_id
466
+
467
+ def handle_extract_video_info(session_id, input_video):
468
+ # clean_up_processes(session_id, init_clean=True)
469
+ if input_video == None:
470
+ return 0, 0, {
471
+ "minimum": 0.0,
472
+ "maximum": 100,
473
+ "step": 0.01,
474
+ "value": 0.0,
475
+ }, None, None, None, None, None
476
+ queue = start_process(session_id)
477
+ result_queue = user_processes[session_id]["result_queue"]
478
+ queue.put({"command": "extract_video_info", "input_video": input_video})
479
+ result = result_queue.get()
480
+ fps = result.get("fps")
481
+ total_frames = result.get("total_frames")
482
+ input_first_frame = result.get("input_first_frame")
483
+ drawing_board = result.get("drawing_board")
484
+ output_video = result.get("output_video")
485
+ output_mp4 = result.get("output_mp4")
486
+ output_mask = result.get("output_mask")
487
+ scale_slider = gr.Slider.update(minimum=1.0,
488
+ maximum=fps,
489
+ step=1.0,
490
+ value=fps,)
491
+ frame_per = gr.Slider.update(minimum= 0.0,
492
+ maximum= total_frames / fps,
493
+ step=1.0/fps,
494
+ value=0.0,)
495
+ slider_state = {
496
+ "minimum": 0.0,
497
+ "maximum": total_frames / fps,
498
+ "step": 1.0/fps,
499
+ "value": 0.0,
500
+ }
501
+ return scale_slider, frame_per, slider_state, input_first_frame, drawing_board, output_video, output_mp4, output_mask
502
+
503
+ def handle_get_meta_from_video(session_id, input_video, scale_slider, selected_config, selected_checkpoint):
504
+ config_path = config_file_map[selected_config]
505
+ checkpoint_path = checkpoint_file_map[selected_checkpoint]
506
+ # clean_up_processes(session_id)
507
+ queue = start_process(session_id)
508
+ result_queue = user_processes[session_id]["result_queue"]
509
+ queue.put({"command": "get_meta_from_video", "input_video": input_video, "scale_slider": scale_slider, "config_path": config_path, "checkpoint_path": checkpoint_path})
510
+ result = result_queue.get()
511
+ input_first_frame = result.get("input_first_frame")
512
+ drawing_board = result.get("drawing_board")
513
+ (fps, frame_interval, total_frames) = result.get("frame_per")
514
+ output_video = result.get("output_video")
515
+ output_mp4 = result.get("output_mp4")
516
+ output_mask = result.get("output_mask")
517
+ ann_obj_id = result.get("ann_obj_id")
518
+ max_obj_id = result.get("max_obj_id")
519
+ frame_per = gr.Slider.update(minimum= 0.0,
520
+ maximum= total_frames / fps,
521
+ step=frame_interval / fps / 2,
522
+ value=0.0,)
523
+ slider_state = {
524
+ "minimum": 0.0,
525
+ "maximum": total_frames / fps,
526
+ "step": frame_interval/fps / 2 ,
527
+ "value": 0.0,
528
+ }
529
+ obj_id_slider = gr.Slider.update(
530
+ maximum=max_obj_id,
531
+ value=ann_obj_id
532
+ )
533
+ return input_first_frame, drawing_board, frame_per, slider_state, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id, obj_id_slider
534
+
535
+ def handle_sam_stroke(session_id, drawing_board, last_draw, frame_num, ann_obj_id):
536
+ # clean_up_processes(session_id)
537
+ queue = start_process(session_id)
538
+ result_queue = user_processes[session_id]["result_queue"]
539
+ queue.put({"command": "sam_stroke", "drawing_board": drawing_board, "last_draw": last_draw, "frame_num": frame_num, "ann_obj_id": ann_obj_id})
540
+ result = result_queue.get()
541
+ input_first_frame = result.get("input_first_frame")
542
+ drawing_board = result.get("drawing_board")
543
+ last_draw = result.get("last_draw")
544
+ return input_first_frame, drawing_board, last_draw
545
+
546
+ def handle_sam_click(session_id, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
547
+ # clean_up_processes(session_id)
548
+ queue = start_process(session_id)
549
+ result_queue = user_processes[session_id]["result_queue"]
550
+ point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32)
551
+ queue.put({"command": "sam_click", "frame_num": frame_num, "point_mode": point_mode, "click_stack": click_stack, "ann_obj_id": ann_obj_id, "point": point})
552
+ result = result_queue.get()
553
+ input_first_frame = result.get("input_first_frame")
554
+ drawing_board = result.get("drawing_board")
555
+ last_draw = result.get("last_draw")
556
+ return input_first_frame, drawing_board, last_draw
557
+
558
+ def handle_increment_ann_obj_id(session_id, max_obj_id):
559
+ # clean_up_processes(session_id)
560
+ queue = start_process(session_id)
561
+ result_queue = user_processes[session_id]["result_queue"]
562
+ queue.put({"command": "increment_ann_obj_id", "max_obj_id": max_obj_id})
563
+ result = result_queue.get()
564
+ ann_obj_id = result.get("ann_obj_id")
565
+ max_obj_id = result.get("max_obj_id")
566
+ obj_id_slider = gr.Slider.update(maximum=max_obj_id, value=ann_obj_id)
567
+ return ann_obj_id, max_obj_id, obj_id_slider
568
+
569
+ def handle_update_current_id(session_id, ann_obj_id):
570
+ # clean_up_processes(session_id)
571
+ queue = start_process(session_id)
572
+ result_queue = user_processes[session_id]["result_queue"]
573
+ queue.put({"command": "update_current_id", "ann_obj_id": ann_obj_id})
574
+ result = result_queue.get()
575
+ ann_obj_id = result.get("ann_obj_id")
576
+ return ann_obj_id
577
+
578
+ def handle_drawing_board_get_input_first_frame(session_id, input_first_frame):
579
+ # clean_up_processes(session_id)
580
+ queue = start_process(session_id)
581
+ result_queue = user_processes[session_id]["result_queue"]
582
+ queue.put({"command": "drawing_board_get_input_first_frame", "input_first_frame": input_first_frame})
583
+ result = result_queue.get()
584
+ input_first_frame = result.get("input_first_frame")
585
+ return input_first_frame
586
+
587
+ def handle_reset(session_id):
588
+ # clean_up_processes(session_id)
589
+ queue = start_process(session_id)
590
+ result_queue = user_processes[session_id]["result_queue"]
591
+ queue.put({"command": "reset"})
592
+ result = result_queue.get()
593
+ click_stack = result.get("click_stack")
594
+ input_first_frame = result.get("input_first_frame")
595
+ drawing_board = result.get("drawing_board")
596
+ slider_state = {
597
+ "minimum": 0.0,
598
+ "maximum": 100,
599
+ "step": 0.01,
600
+ "value": 0.0,
601
+ }
602
+ output_video = result.get("output_video")
603
+ output_mp4 = result.get("output_mp4")
604
+ output_mask = result.get("output_mask")
605
+ ann_obj_id = result.get("ann_obj_id")
606
+ max_obj_id = result.get("max_obj_id")
607
+ obj_id_slider = gr.Slider.update(
608
+ maximum=max_obj_id,
609
+ value=ann_obj_id)
610
+ return click_stack, input_first_frame, drawing_board, frame_per, slider_state, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id, obj_id_slider
611
+
612
+ def handle_show_res_by_slider(session_id, frame_per, slider_state, click_stack):
613
+ # clean_up_processes(session_id)
614
+ queue = start_process(session_id)
615
+ frame_per = frame_per/slider_state["maximum"]
616
+ result_queue = user_processes[session_id]["result_queue"]
617
+ queue.put({"command": "show_res_by_slider", "frame_per": frame_per, "click_stack": click_stack})
618
+ result = result_queue.get()
619
+ input_first_frame = result.get("input_first_frame")
620
+ drawing_board = result.get("drawing_board")
621
+ frame_num = result.get("frame_num")
622
+ return input_first_frame, drawing_board, frame_num
623
+
624
+ def handle_tracking_objects(session_id, frame_num, input_video):
625
+ # clean_up_processes(session_id)
626
+ queue = start_process(session_id)
627
+ result_queue = user_processes[session_id]["result_queue"]
628
+ queue.put({"command": "tracking_objects", "frame_num": frame_num, "input_video": input_video})
629
+ result = result_queue.get()
630
+ input_first_frame = result.get("input_first_frame")
631
+ drawing_board = result.get("drawing_board")
632
+ output_video = result.get("output_video")
633
+ output_mp4 = result.get("output_mp4")
634
+ output_mask = result.get("output_mask")
635
+ click_stack = result.get("click_stack")
636
+ return input_first_frame, drawing_board, output_video, output_mp4, output_mask, click_stack
637
+
638
+ ##########################################################
639
+ ###################### Front-end ########################
640
+ ##########################################################
641
+ css = """
642
+ #input_output_video video {
643
+ max-height: 550px;
644
+ max-width: 100%;
645
+ height: auto;
646
+ }
647
+ """
648
+
649
+ if platform.system() == "Windows":
650
+ config_path = os.path.abspath(os.environ.get("CONFIG_PATH", "sam2/configs/"))
651
+ checkpoint_path = os.environ.get("CHECKPOINT_PATH", "checkpoints/")
652
+
653
+ config_files = glob(os.path.join(config_path, "*.yaml"))
654
+ config_files.sort(key=lambda x: '_t.' not in basename(x))
655
+
656
+ checkpoint_files = glob(os.path.join(checkpoint_path, "*.pt"))
657
+ checkpoint_files.sort(key=lambda x: 'tiny' not in basename(x))
658
+
659
+ medsam_checkpoints = glob("checkpoints/*.pt")
660
+ else:
661
+ config_path = "/" + os.path.abspath(os.environ.get("CONFIG_PATH", "./sam2/configs/"))
662
+ checkpoint_path = os.environ.get("CHECKPOINT_PATH", "./checkpoints")
663
+
664
+ config_files = glob(os.path.join(config_path, "*.yaml"))
665
+ config_files.sort(key=lambda x: '_t.' not in basename(x))
666
+
667
+ checkpoint_files = glob(os.path.join(checkpoint_path, "*.pt"))
668
+ checkpoint_files.sort(key=lambda x: 'tiny' not in basename(x))
669
+
670
+ medsam_checkpoints = glob("./checkpoints/*.pt")
671
+
672
+ config_display = [splitext(basename(f))[0] for f in config_files]
673
+ medsam_display = [
674
+ f"{os.path.basename(dirname(dirname(path)))} / {splitext(basename(path))[0]}"
675
+ for path in medsam_checkpoints
676
+ ]
677
+ checkpoint_display = [
678
+ splitext(basename(f))[0] for f in checkpoint_files
679
+ ] + medsam_display
680
+ checkpoint_files.extend(medsam_checkpoints)
681
+
682
+ config_file_map = dict(zip(config_display, config_files))
683
+ checkpoint_file_map = dict(zip(checkpoint_display, checkpoint_files))
684
+
685
+ app = gr.Blocks(css=css)
686
+ with app:
687
+ session_id = gr.State()
688
+ app.load(extract_session_id_from_request, None, session_id)
689
+ gr.Markdown(
690
+ '''
691
+ <div style="text-align:center; margin-bottom:20px;">
692
+ <span style="font-size:3em; font-weight:bold;">MedSAM2: Segment Anything in 3D Medical Images and Videos</span>
693
+ </div>
694
+ <div style="text-align:center; margin-bottom:20px;">
695
+ <a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2">
696
+ <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
697
+ </a>
698
+ <a href="https://arxiv.org/abs/2408.03322">
699
+ <img src="https://img.shields.io/badge/arXiv-2408.03322-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
700
+ </a>
701
+ <a href="https://github.com/bowang-lab/MedSAMSlicer/tree/MedSAM2">
702
+ <img src="https://img.shields.io/badge/3D-Slicer-Plugin" alt="3D Slicer Plugin" style="display:inline-block; margin-right:10px;">
703
+ </a>
704
+ </div>
705
+ <div style="text-align:left; margin-bottom:20px;">
706
+ This API supports box prompts (generated from a scribble) and point prompts for medical video segmentation.
707
+ </div>
708
+ <div style="margin-bottom:20px;">
709
+ <ol style="list-style:none; padding-left:0;">
710
+ <li>1. Upload video file</li>
711
+ <li>2. Select model size and downsample frame rate and run <b>Preprocess</b></li>
712
+ <li>3. Use <b>Stroke to Box Prompt</b> to draw box on the first frame or <b>Point Prompt</b> to click on the first frame.</li>
713
+ <li>&nbsp;&nbsp;&nbsp;Note: The bounding rectangle of the stroke should be able to cover the segmentation target.</li>
714
+ <li>4. Click <b>Segment</b> to get the segmentation result</li>
715
+ <li>5. Click <b>Add New Object</b> to add a new object</li>
716
+ <li>6. Click <b>Start Tracking</b> to track objects in the video</li>
717
+ <li>7. Click <b>Reset</b> to reset the app</li>
718
+ <li>8. Download the video with segmentation results</li>
719
+ </ol>
720
+ </div>
721
+ <div style="text-align:left; line-height:1.8;">
722
+ If you find these tools useful, please consider citing the following papers:
723
+ </div>
724
+ <div style="text-align:left; line-height:1.8;">
725
+ Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. ICLR 2025
726
+ </div>
727
+ <div style="text-align:left; line-height:1.8;">
728
+ Ma, J.*, Yang, Z.*, Kim, S., Chen, B., Baharoon, M., Fallahpour, A, Asakereh, R., Lyu, H., Wang, B.: MedSAM2: Segment Anything in Medical Images and Videos. arXiv preprint (2025)
729
+ </div>
730
+ '''
731
+ )
732
+
733
+ click_stack = gr.State(({}, {}))
734
+ frame_num = gr.State(value=(int(0)))
735
+ ann_obj_id = gr.State(value=(int(0)))
736
+ max_obj_id = gr.State(value=(int(0)))
737
+ last_draw = gr.State(None)
738
+ slider_state = gr.State(value={
739
+ "minimum": 0.0,
740
+ "maximum": 100,
741
+ "step": 0.01,
742
+ "value": 0.0,
743
+ })
744
+
745
+ with gr.Row():
746
+ with gr.Column(scale=0.5):
747
+ with gr.Row():
748
+ tab_video_input = gr.Tab(label="Video input")
749
+ with tab_video_input:
750
+ input_video = gr.Video(label='Input video', type=["mp4", "mov", "avi"], elem_id="input_output_video")
751
+ with gr.Row():
752
+ # checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
753
+ config_dropdown = gr.Dropdown(
754
+ choices=config_display,
755
+ value=config_display[0],
756
+ label="Select Config File"
757
+ )
758
+
759
+ checkpoint_dropdown = gr.Dropdown(
760
+ choices=checkpoint_display,
761
+ value=checkpoint_display[0],
762
+ label="Select Checkpoint File"
763
+ )
764
+ scale_slider = gr.Slider(
765
+ label="Downsample Frame Rate (fps)",
766
+ minimum=0.0,
767
+ maximum=1.0,
768
+ step=0.25,
769
+ value=1.0,
770
+ interactive=True
771
+ )
772
+ preprocess_button = gr.Button(
773
+ value="Preprocess",
774
+ interactive=True,
775
+ )
776
+
777
+ with gr.Row():
778
+ tab_stroke = gr.Tab(label="Stroke to Box Prompt")
779
+ with tab_stroke:
780
+ drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
781
+ with gr.Row():
782
+ seg_acc_stroke = gr.Button(value="Segment", interactive=True)
783
+
784
+ tab_click = gr.Tab(label="Point Prompt")
785
+ with tab_click:
786
+ input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
787
+ with gr.Row():
788
+ point_mode = gr.Radio(
789
+ choices=["Positive", "Negative"],
790
+ value="Positive",
791
+ label="Point Prompt",
792
+ interactive=True)
793
+
794
+ with gr.Row():
795
+ with gr.Column():
796
+ frame_per = gr.Slider(
797
+ label = "Time (seconds)",
798
+ minimum= 0.0,
799
+ maximum= 100.0,
800
+ step=0.01,
801
+ value=0.0,
802
+ )
803
+ with gr.Row():
804
+ with gr.Column():
805
+ obj_id_slider = gr.Slider(
806
+ minimum=0,
807
+ maximum=0,
808
+ step=1,
809
+ interactive=True,
810
+ label="Current Object ID"
811
+ )
812
+ with gr.Column():
813
+ new_object_button = gr.Button(
814
+ value="Add New Object",
815
+ interactive=True
816
+ )
817
+ track_for_video = gr.Button(
818
+ value="Start Tracking",
819
+ interactive=True,
820
+ )
821
+ reset_button = gr.Button(
822
+ value="Reset",
823
+ interactive=True, visible=False,
824
+ )
825
+
826
+ with gr.Column(scale=0.5):
827
+ output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
828
+ output_mp4 = gr.File(label="Predicted video")
829
+ output_mask = gr.File(label="Predicted masks")
830
+
831
+ gr.Markdown(
832
+ '''
833
+ <div style="text-align:center; margin-top: 20px;">
834
+ The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
835
+ The interface was built on <a href="https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/tutorial/tutorial%20for%20WebUI-1.0-Version.md" target="_blank">SegTracker</a>, which is also an amazing tool for video segmentation tracking.
836
+ <a href="https://docs.google.com/document/d/1idDBV0faOjdjVs-iAHr0uSrw_9_ZzLGrUI2FEdK-lso/edit?usp=sharing" target="_blank">Data source</a>
837
+ </div>
838
+ '''
839
+ )
840
+
841
+ ##########################################################
842
+ ###################### back-end #########################
843
+ ##########################################################
844
+
845
+ # listen to the preprocess button click to get the first frame of video with scaling
846
+ preprocess_button.click(
847
+ fn=handle_get_meta_from_video,
848
+ inputs=[
849
+ session_id,
850
+ input_video,
851
+ scale_slider,
852
+ config_dropdown,
853
+ checkpoint_dropdown
854
+ ],
855
+ outputs=[
856
+ input_first_frame, drawing_board, frame_per, slider_state, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id, obj_id_slider
857
+ ], queue=False, every=15
858
+ )
859
+
860
+ frame_per.release(
861
+ fn=handle_show_res_by_slider,
862
+ inputs=[
863
+ session_id, frame_per, slider_state, click_stack
864
+ ],
865
+ outputs=[
866
+ input_first_frame, drawing_board, frame_num
867
+ ]
868
+ )
869
+
870
+ # Interactively modify the mask according to clicks
871
+ input_first_frame.select(
872
+ fn=handle_sam_click,
873
+ inputs=[
874
+ session_id, frame_num, point_mode, click_stack, ann_obj_id
875
+ ],
876
+ outputs=[
877
+ input_first_frame, drawing_board, click_stack
878
+ ]
879
+ )
880
+
881
+ # Track object in video
882
+ track_for_video.click(
883
+ fn=handle_tracking_objects,
884
+ inputs=[
885
+ session_id,
886
+ frame_num,
887
+ input_video,
888
+ ],
889
+ outputs=[
890
+ input_first_frame,
891
+ drawing_board,
892
+ output_video,
893
+ output_mp4,
894
+ output_mask,
895
+ click_stack
896
+ ], queue=False, every=15
897
+ )
898
+
899
+ reset_button.click(
900
+ fn=handle_reset,
901
+ inputs=[session_id],
902
+ outputs=[
903
+ click_stack, input_first_frame, drawing_board, frame_per, slider_state, output_video, output_mp4, output_mask, ann_obj_id, max_obj_id, obj_id_slider
904
+ ]
905
+ )
906
+
907
+ new_object_button.click(
908
+ fn=handle_increment_ann_obj_id,
909
+ inputs=[ session_id, max_obj_id ],
910
+ outputs=[ ann_obj_id, max_obj_id, obj_id_slider ]
911
+ )
912
+
913
+ obj_id_slider.change(
914
+ fn=handle_update_current_id,
915
+ inputs=[session_id, obj_id_slider],
916
+ outputs=[ann_obj_id]
917
+ )
918
+
919
+ tab_stroke.select(
920
+ fn=handle_drawing_board_get_input_first_frame,
921
+ inputs=[session_id, input_first_frame],
922
+ outputs=[drawing_board,],
923
+ )
924
+
925
+ seg_acc_stroke.click(
926
+ fn=handle_sam_stroke,
927
+ inputs=[
928
+ session_id, drawing_board, last_draw, frame_num, ann_obj_id
929
+ ],
930
+ outputs=[
931
+ input_first_frame, drawing_board, last_draw
932
+ ]
933
+ )
934
+
935
+ input_video.change(
936
+ fn=handle_extract_video_info,
937
+ inputs=[session_id, input_video],
938
+ outputs=[scale_slider, frame_per, slider_state, input_first_frame, drawing_board, output_video, output_mp4, output_mask], queue=False, every=15
939
+ )
940
+
941
+ app.queue(concurrency_count=1)
942
+ app.launch(debug=True, enable_queue=True, share=False, server_name="0.0.0.0", server_port=18862)
943
+ # app.launch(debug=True, enable_queue=True, share=True)
944
+
945
+ if __name__ == "__main__":
946
+ mp.set_start_method("spawn")
947
+ monitor_thread = threading.Thread(target=monitor_and_cleanup_processes)
948
+ monitor_thread.daemon = True
949
+ monitor_thread.start()
950
+ seg_track_app()
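
The wiring above follows the standard Gradio Blocks pattern: each widget event is bound to a handler with explicit `inputs` and `outputs` lists, and state flows through components such as `session_id` and `click_stack`. A minimal sketch of that pattern, using the Gradio 3.x API pinned in `requirements.txt` and a hypothetical handler that is not part of this repo:

```python
# Minimal sketch of the Blocks event-wiring pattern used in app.py.
# `run_segmentation` is a hypothetical placeholder handler.
import gradio as gr

def run_segmentation(image, point_mode):
    # ...call the MedSAM2 predictor here and overlay the mask...
    return image  # return values map positionally onto `outputs`

with gr.Blocks() as demo:
    image_in = gr.Image(label="First frame", interactive=True)
    mode = gr.Radio(choices=["Positive", "Negative"], value="Positive", label="Point Prompt")
    run_btn = gr.Button(value="Segment")
    image_out = gr.Image(label="Segmentation result")

    # inputs/outputs are matched positionally to the handler's arguments/returns
    run_btn.click(fn=run_segmentation, inputs=[image_in, mode], outputs=[image_out])

demo.queue(concurrency_count=1)   # Gradio 3.x queue API, as used above
demo.launch(server_name="0.0.0.0")
```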
checkpoints/MedSAM2_2411.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcd946a4d934f553236866fc7e8af77f7e931430e9c044f4ac9d6a723630a870
3
+ size 156039179
checkpoints/MedSAM2_CTLesion.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78f7e125418dfd6fec22f4afe90bcd85cb1d4423d0a9df36f7a87ed63aa1a5f5
3
+ size 156041079
checkpoints/MedSAM2_MRI_LiverLesion.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3632fc77def3a136d7ae8d734613472d187a803b4a1846370b45419622072b2b
3
+ size 156044532
checkpoints/MedSAM2_US_Heart.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:295c0ff8912c99947c364287bbecd1cd36963f0c0ac67a042d292f0dedf8d933
3
+ size 156041079
checkpoints/MedSAM2_latest.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c92743b99f00d078bf32a3afcc38aaa9faf1c1692dffe3eaa7a90938c1991060
3
+ size 156040129
checkpoints/README.md ADDED
@@ -0,0 +1,10 @@
1
+
2
+ Download the checkpoints: `sh download.sh`
3
+
4
+ - `MedSAM2_2411.pt`: The base model trained in Nov. 2024
5
+ - `MedSAM2_US_Heart.pt`: Fine-tuned model for heart ultrasound video segmentation
6
+ - `MedSAM2_MRI_LiverLesion.pt`: Fine-tuned model for liver lesion MRI segmentation
7
+ - `MedSAM2_CTLesion.pt`: Fine-tuned model for CT lesion segmentation
8
+ - `MedSAM2_latest.pt` (recommended): Latest model trained on the combination of existing public datasets and newly annotated datasets
9
+
10
+
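
Once a checkpoint is downloaded, it can be paired with the tiny-Hiera config shipped in `sam2/configs`. A hedged sketch, mirroring the defaults of `medsam2_infer_3D_CT.py` (adjust the paths if you store the checkpoints elsewhere):

```python
# Sketch: build a MedSAM2 video predictor from a downloaded checkpoint.
from sam2.build_sam import build_sam2_video_predictor_npz

predictor = build_sam2_video_predictor_npz(
    config_file="configs/sam2.1_hiera_t512.yaml",
    ckpt_path="checkpoints/MedSAM2_latest.pt",
)
```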
download.sh ADDED
@@ -0,0 +1,35 @@
1
+ #!/bin/sh
2
+ # Script to download MedSAM2 model checkpoints
3
+ # Create checkpoints directory if it doesn't exist
4
+ mkdir -p checkpoints
5
+ # Use either wget or curl to download the checkpoints
6
+ if command -v wget > /dev/null 2>&1; then
7
+ CMD="wget -P checkpoints"
8
+ elif command -v curl > /dev/null 2>&1; then
9
+ CMD="curl -L -o"
10
+ CURL=1
11
+ else
12
+ echo "Please install wget or curl to download the checkpoints."
13
+ exit 1
14
+ fi
15
+ # Define the base URL for MedSAM2 models on Hugging Face
16
+ HF_BASE_URL="https://huggingface.co/wanglab/MedSAM2/resolve/main"
17
+ # Define the model checkpoint files (as separate variables instead of an array)
18
+ MODEL1="MedSAM2_2411.pt"
19
+ MODEL2="MedSAM2_US_Heart.pt"
20
+ MODEL3="MedSAM2_MRI_LiverLesion.pt"
21
+ MODEL4="MedSAM2_CTLesion.pt"
22
+ MODEL5="MedSAM2_latest.pt"
23
+
24
+ # Download each checkpoint
25
+ for model in $MODEL1 $MODEL2 $MODEL3 $MODEL4 $MODEL5; do
26
+ echo "Downloading ${model}..."
27
+ model_url="${HF_BASE_URL}/${model}"
28
+
29
+ if [ -n "$CURL" ]; then
30
+ $CMD "checkpoints/${model}" "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; }
31
+ else
32
+ $CMD "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; }
33
+ fi
34
+ done
35
+ echo "All MedSAM2 model checkpoints have been downloaded successfully to the 'checkpoints' directory."
download_checkpoints.py ADDED
@@ -0,0 +1,21 @@
1
+ # Run this as a Python script in the terminal (or via Python shell)
2
+ from huggingface_hub import hf_hub_download
3
+ import os
4
+
5
+ os.makedirs("checkpoints", exist_ok=True)
6
+
7
+ model_files = [
8
+ "MedSAM2_2411.pt",
9
+ "MedSAM2_US_Heart.pt",
10
+ "MedSAM2_MRI_LiverLesion.pt",
11
+ "MedSAM2_CTLesion.pt",
12
+ "MedSAM2_latest.pt"
13
+ ]
14
+
15
+ for filename in model_files:
16
+ hf_hub_download(
17
+ repo_id="wanglab/MedSAM2",
18
+ filename=filename,
19
+ local_dir="checkpoints",
20
+ local_dir_use_symlinks=False
21
+ )
gitignore ADDED
@@ -0,0 +1,13 @@
1
+ .vscode/
2
+ .DS_Store
3
+ __pycache__/
4
+ *-checkpoint.ipynb
5
+ .venv
6
+ *.egg*
7
+ build/*
8
+ _C.*
9
+ *.nii.gz
10
+ *.csv
11
+ outputs/*
12
+ checkpoints/*.pt
13
+ *.pt
medsam2_infer_3D_CT.py ADDED
@@ -0,0 +1,304 @@
1
+ from glob import glob
2
+ from tqdm import tqdm
3
+ import os
4
+ from os.path import join, basename
5
+ import re
6
+ import matplotlib.pyplot as plt
7
+ from collections import OrderedDict
8
+ import pandas as pd
9
+ import numpy as np
10
+ import argparse
11
+
12
+ from PIL import Image
13
+ import SimpleITK as sitk
14
+ import torch
15
+ import torch.multiprocessing as mp
16
+ from sam2.build_sam import build_sam2_video_predictor_npz
17
+ import SimpleITK as sitk
18
+ from skimage import measure, morphology
19
+
20
+ torch.set_float32_matmul_precision('high')
21
+ torch.manual_seed(2024)
22
+ torch.cuda.manual_seed(2024)
23
+ np.random.seed(2024)
24
+
25
+ parser = argparse.ArgumentParser()
26
+
27
+ parser.add_argument(
28
+ '--checkpoint',
29
+ type=str,
30
+ default="checkpoints/MedSAM2_latest.pt",
31
+ help='checkpoint path',
32
+ )
33
+ parser.add_argument(
34
+ '--cfg',
35
+ type=str,
36
+ default="configs/sam2.1_hiera_t512.yaml",
37
+ help='model config',
38
+ )
39
+
40
+ parser.add_argument(
41
+ '-i',
42
+ '--imgs_path',
43
+ type=str,
44
+ default="CT_DeepLesion/images",
45
+ help='imgs path',
46
+ )
47
+ parser.add_argument(
48
+ '--gts_path',
49
+ default=None,
50
+ help='simulate prompts based on ground truth',
51
+ )
52
+ parser.add_argument(
53
+ '-o',
54
+ '--pred_save_dir',
55
+ type=str,
56
+ default="./DeeLesion_results",
57
+ help='path to save segmentation results',
58
+ )
59
+ # add option to propagate with either box or mask
60
+ parser.add_argument(
61
+ '--propagate_with_box',
62
+ default=True,
63
+ action='store_true',
64
+ help='whether to propagate with box'
65
+ )
66
+
67
+ args = parser.parse_args()
68
+ checkpoint = args.checkpoint
69
+ model_cfg = args.cfg
70
+ imgs_path = args.imgs_path
71
+ gts_path = args.gts_path
72
+ pred_save_dir = args.pred_save_dir
73
+ os.makedirs(pred_save_dir, exist_ok=True)
74
+ propagate_with_box = args.propagate_with_box
75
+
76
+ def getLargestCC(segmentation):
77
+ labels = measure.label(segmentation)
78
+ largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
79
+ return largestCC
80
+
81
+ def dice_multi_class(preds, targets):
82
+ smooth = 1.0
83
+ assert preds.shape == targets.shape
84
+ labels = np.unique(targets)[1:]
85
+ dices = []
86
+ for label in labels:
87
+ pred = preds == label
88
+ target = targets == label
89
+ intersection = (pred * target).sum()
90
+ dices.append((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))
91
+ return np.mean(dices)
92
+
93
+ def show_mask(mask, ax, mask_color=None, alpha=0.5):
94
+ """
95
+ show mask on the image
96
+
97
+ Parameters
98
+ ----------
99
+ mask : numpy.ndarray
100
+ mask of the image
101
+ ax : matplotlib.axes.Axes
102
+ axes to plot the mask
103
+ mask_color : numpy.ndarray
104
+ color of the mask
105
+ alpha : float
106
+ transparency of the mask
107
+ """
108
+ if mask_color is not None:
109
+ color = np.concatenate([mask_color, np.array([alpha])], axis=0)
110
+ else:
111
+ color = np.array([251/255, 252/255, 30/255, alpha])
112
+ h, w = mask.shape[-2:]
113
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
114
+ ax.imshow(mask_image)
115
+
116
+
117
+ def show_box(box, ax, edgecolor='blue'):
118
+ """
119
+ show bounding box on the image
120
+
121
+ Parameters
122
+ ----------
123
+ box : numpy.ndarray
124
+ bounding box coordinates in the original image
125
+ ax : matplotlib.axes.Axes
126
+ axes to plot the bounding box
127
+ edgecolor : str
128
+ color of the bounding box
129
+ """
130
+ x0, y0 = box[0], box[1]
131
+ w, h = box[2] - box[0], box[3] - box[1]
132
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))
133
+
134
+
135
+ def resize_grayscale_to_rgb_and_resize(array, image_size):
136
+ """
137
+ Resize a 3D grayscale NumPy array to an RGB image and then resize it.
138
+
139
+ Parameters:
140
+ array (np.ndarray): Input array of shape (d, h, w).
141
+ image_size (int): Desired size for the width and height.
142
+
143
+ Returns:
144
+ np.ndarray: Resized array of shape (d, 3, image_size, image_size).
145
+ """
146
+ d, h, w = array.shape
147
+ resized_array = np.zeros((d, 3, image_size, image_size))
148
+
149
+ for i in range(d):
150
+ img_pil = Image.fromarray(array[i].astype(np.uint8))
151
+ img_rgb = img_pil.convert("RGB")
152
+ img_resized = img_rgb.resize((image_size, image_size))
153
+ img_array = np.array(img_resized).transpose(2, 0, 1) # (3, image_size, image_size)
154
+ resized_array[i] = img_array
155
+
156
+ return resized_array
157
+
158
+ def mask2D_to_bbox(gt2D, max_shift=20):
159
+ y_indices, x_indices = np.where(gt2D > 0)
160
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
161
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
162
+ H, W = gt2D.shape
163
+ bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
164
+ x_min = max(0, x_min - bbox_shift)
165
+ x_max = min(W-1, x_max + bbox_shift)
166
+ y_min = max(0, y_min - bbox_shift)
167
+ y_max = min(H-1, y_max + bbox_shift)
168
+ boxes = np.array([x_min, y_min, x_max, y_max])
169
+ return boxes
170
+
171
+ def mask3D_to_bbox(gt3D, max_shift=20):
172
+ z_indices, y_indices, x_indices = np.where(gt3D > 0)
173
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
174
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
175
+ z_min, z_max = np.min(z_indices), np.max(z_indices)
176
+ D, H, W = gt3D.shape
177
+ bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
178
+ x_min = max(0, x_min - bbox_shift)
179
+ x_max = min(W-1, x_max + bbox_shift)
180
+ y_min = max(0, y_min - bbox_shift)
181
+ y_max = min(H-1, y_max + bbox_shift)
182
+ z_min = max(0, z_min)
183
+ z_max = min(D-1, z_max)
184
+ boxes3d = np.array([x_min, y_min, z_min, x_max, y_max, z_max])
185
+ return boxes3d
186
+
187
+
188
+ DL_info = pd.read_csv('CT_DeepLesion/DeepLesion_Dataset_Info.csv')
189
+ nii_fnames = sorted(os.listdir(imgs_path))
190
+ nii_fnames = [i for i in nii_fnames if i.endswith('.nii.gz')]
191
+ nii_fnames = [i for i in nii_fnames if not i.startswith('._')]
192
+ print(f'Processing {len(nii_fnames)} nii files')
193
+ seg_info = OrderedDict()
194
+ seg_info['nii_name'] = []
195
+ seg_info['key_slice_index'] = []
196
+ seg_info['DICOM_windows'] = []
197
+ # initialized predictor
198
+ predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)
199
+
200
+ for nii_fname in tqdm(nii_fnames):
201
+ # get corresponding case info
202
+ range_suffix = re.findall(r'\d{3}-\d{3}', nii_fname)[0]
203
+ slice_range = range_suffix.split('-')
204
+ slice_range = [str(int(s)) for s in slice_range]
205
+ slice_range = ', '.join(slice_range)
206
+ nii_image = sitk.ReadImage(join(imgs_path, nii_fname))
207
+ nii_image_data = sitk.GetArrayFromImage(nii_image)
208
+
209
+ case_name = re.findall(r'^(\d{6}_\d{2}_\d{2})', nii_fname)[0]
210
+ case_df = DL_info[
211
+ DL_info['File_name'].str.contains(case_name) &
212
+ DL_info['Slice_range'].str.contains(slice_range)
213
+ ].copy()
214
+
215
+ segs_3D = np.zeros(nii_image_data.shape, dtype=np.uint8)
216
+
217
+ for row_id, row in case_df.iterrows():
218
+ # print(f'Processing {case_name} tumor {tumor_idx}')
219
+ # get the key slice info
220
+ lower_bound, upper_bound = row['DICOM_windows'].split(',')
221
+ lower_bound, upper_bound = float(lower_bound), float(upper_bound)
222
+ nii_image_data_pre = np.clip(nii_image_data, lower_bound, upper_bound)
223
+ nii_image_data_pre = (nii_image_data_pre - np.min(nii_image_data_pre))/(np.max(nii_image_data_pre)-np.min(nii_image_data_pre))*255.0
224
+ nii_image_data_pre = np.uint8(nii_image_data_pre)
225
+ key_slice_idx = row['Key_slice_index']
226
+ key_slice_idx = int(key_slice_idx)
227
+ slice_range = row['Slice_range']
228
+ slice_idx_start, slice_idx_end = slice_range.split(',')
229
+ slice_idx_start, slice_idx_end = int(slice_idx_start), int(slice_idx_end)
230
+ bbox_coords = row['Bounding_boxes']
231
+ bbox_coords = bbox_coords.split(',')
232
+ bbox_coords = [int(float(coord)) for coord in bbox_coords]
233
+ #bbox_coords = expand_box(bbox_coords)
234
+ bbox = np.array(bbox_coords) # y_min, x_min, y_max, x_max
235
+ bbox = np.array([bbox[1], bbox[0], bbox[3], bbox[2]])
236
+
237
+ key_slice_idx_offset = key_slice_idx - slice_idx_start
238
+ key_slice_img = nii_image_data_pre[key_slice_idx_offset, :,:]
239
+
240
+ img_3D_ori = nii_image_data_pre
241
+ assert np.max(img_3D_ori) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3D_ori)}'
242
+
243
+ video_height = key_slice_img.shape[0]
244
+ video_width = key_slice_img.shape[1]
245
+ img_resized = resize_grayscale_to_rgb_and_resize(img_3D_ori, 512)
246
+ img_resized = img_resized / 255.0
247
+ img_resized = torch.from_numpy(img_resized).cuda()
248
+ img_mean=(0.485, 0.456, 0.406)
249
+ img_std=(0.229, 0.224, 0.225)
250
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].cuda()
251
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].cuda()
252
+ img_resized -= img_mean
253
+ img_resized /= img_std
254
+ z_mids = []
255
+
256
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
257
+ inference_state = predictor.init_state(img_resized, video_height, video_width)
258
+ if propagate_with_box:
259
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
260
+ inference_state=inference_state,
261
+ frame_idx=key_slice_idx_offset,
262
+ obj_id=1,
263
+ box=bbox,
264
+ )
265
+ else: # gt
266
+ pass
267
+
268
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
269
+ segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
270
+ predictor.reset_state(inference_state)
271
+ if propagate_with_box:
272
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
273
+ inference_state=inference_state,
274
+ frame_idx=key_slice_idx_offset,
275
+ obj_id=1,
276
+ box=bbox,
277
+ )
278
+ else: # gt
279
+ pass
280
+
281
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=True):
282
+ segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
283
+ predictor.reset_state(inference_state)
284
+ if np.max(segs_3D) > 0:
285
+ segs_3D = getLargestCC(segs_3D)
286
+ segs_3D = np.uint8(segs_3D)
287
+ sitk_image = sitk.GetImageFromArray(img_3D_ori)
288
+ sitk_image.CopyInformation(nii_image)
289
+ sitk_mask = sitk.GetImageFromArray(segs_3D)
290
+ sitk_mask.CopyInformation(nii_image)
291
+ # save single lesion
292
+ key_slice_idx = row['Key_slice_index']
293
+ save_seg_name = nii_fname.split('.nii.gz')[0] + f'_k{key_slice_idx}_mask.nii.gz'
294
+ sitk.WriteImage(sitk_image, os.path.join(pred_save_dir, nii_fname.replace('.nii.gz', '_img.nii.gz')))
295
+ sitk.WriteImage(sitk_mask, os.path.join(pred_save_dir, save_seg_name))
296
+ seg_info['nii_name'].append(save_seg_name)
297
+ seg_info['key_slice_index'].append(key_slice_idx)
298
+ seg_info['DICOM_windows'].append(row['DICOM_windows'])
299
+
300
+ seg_info_df = pd.DataFrame(seg_info)
301
+ seg_info_df.to_csv(join(pred_save_dir, 'tiny_seg_info202412.csv'), index=False)
302
+
303
+
304
+
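
The preprocessing above clips each CT volume to the lesion's DICOM window and rescales it to [0, 255] before resizing and normalization. A small self-contained sketch of that windowing step (the window of [-160, 240] HU is illustrative, not a value taken from the dataset):

```python
# Sketch: DICOM windowing as applied in medsam2_infer_3D_CT.py.
import numpy as np

def apply_window(volume_hu: np.ndarray, lower: float, upper: float) -> np.ndarray:
    """Clip a CT volume (in HU) to a window and rescale it to uint8 [0, 255]."""
    clipped = np.clip(volume_hu, lower, upper)
    scaled = (clipped - clipped.min()) / (clipped.max() - clipped.min()) * 255.0
    return scaled.astype(np.uint8)

# Example with an illustrative soft-tissue window:
# volume = sitk.GetArrayFromImage(sitk.ReadImage("case.nii.gz"))  # (D, H, W) in HU
# windowed = apply_window(volume, -160.0, 240.0)
```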
medsam2_infer_video.py ADDED
@@ -0,0 +1,570 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ from collections import defaultdict
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image
14
+ from sam2.build_sam import build_sam2_video_predictor
15
+
16
+ # the PNG palette for DAVIS 2017 dataset
17
+ DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
18
+
19
+
20
+ def load_ann_png(path):
21
+ """Load a PNG file as a mask and its palette."""
22
+ mask = Image.open(path)
23
+ palette = mask.getpalette()
24
+ mask = np.array(mask).astype(np.uint8)
25
+ return mask, palette
26
+
27
+
28
+ def save_ann_png(path, mask, palette):
29
+ """Save a mask as a PNG file with the given palette."""
30
+ assert mask.dtype == np.uint8
31
+ assert mask.ndim == 2
32
+ output_mask = Image.fromarray(mask)
33
+ output_mask.putpalette(palette)
34
+ output_mask.save(path)
35
+
36
+
37
+ def get_per_obj_mask(mask):
38
+ """Split a mask into per-object masks."""
39
+ object_ids = np.unique(mask)
40
+ object_ids = object_ids[object_ids > 0].tolist()
41
+ per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
42
+ return per_obj_mask
43
+
44
+
45
+ def put_per_obj_mask(per_obj_mask, height, width):
46
+ """Combine per-object masks into a single mask."""
47
+ mask = np.zeros((height, width), dtype=np.uint8)
48
+ object_ids = sorted(per_obj_mask)[::-1]
49
+ for object_id in object_ids:
50
+ object_mask = per_obj_mask[object_id]
51
+ object_mask = object_mask.reshape(height, width)
52
+ mask[object_mask] = object_id
53
+ return mask
54
+
55
+
56
+ def load_masks_from_dir(
57
+ input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
58
+ ):
59
+ """Load masks from a directory as a dict of per-object masks."""
60
+ if not per_obj_png_file:
61
+ input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
62
+ if allow_missing and not os.path.exists(input_mask_path):
63
+ return {}, None
64
+ input_mask, input_palette = load_ann_png(input_mask_path)
65
+ per_obj_input_mask = get_per_obj_mask(input_mask)
66
+ else:
67
+ per_obj_input_mask = {}
68
+ input_palette = None
69
+ # each object is a directory in "{object_id:%03d}" format
70
+ for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
71
+ object_id = int(object_name)
72
+ input_mask_path = os.path.join(
73
+ input_mask_dir, video_name, object_name, f"{frame_name}.png"
74
+ )
75
+ if allow_missing and not os.path.exists(input_mask_path):
76
+ continue
77
+ input_mask, input_palette = load_ann_png(input_mask_path)
78
+ per_obj_input_mask[object_id] = input_mask > 0
79
+
80
+ return per_obj_input_mask, input_palette
81
+
82
+
83
+ def save_palette_masks_to_dir(
84
+ output_mask_dir,
85
+ video_name,
86
+ frame_name,
87
+ per_obj_output_mask,
88
+ height,
89
+ width,
90
+ per_obj_png_file,
91
+ output_palette,
92
+ ):
93
+ """Save masks to a directory as PNG files."""
94
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
95
+ if not per_obj_png_file:
96
+ output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
97
+ output_mask_path = os.path.join(
98
+ output_mask_dir, video_name, f"{frame_name}.png"
99
+ )
100
+ save_ann_png(output_mask_path, output_mask, output_palette)
101
+ else:
102
+ for object_id, object_mask in per_obj_output_mask.items():
103
+ object_name = f"{object_id:03d}"
104
+ os.makedirs(
105
+ os.path.join(output_mask_dir, video_name, object_name),
106
+ exist_ok=True,
107
+ )
108
+ output_mask = object_mask.reshape(height, width).astype(np.uint8)
109
+ output_mask_path = os.path.join(
110
+ output_mask_dir, video_name, object_name, f"{frame_name}.png"
111
+ )
112
+ save_ann_png(output_mask_path, output_mask, output_palette)
113
+
114
+
115
+ def save_masks_to_dir(
116
+ output_mask_dir,
117
+ video_name,
118
+ frame_name,
119
+ per_obj_output_mask,
120
+ height,
121
+ width,
122
+ per_obj_png_file,
123
+ ):
124
+ """Save masks to a directory as greyscale PNG files."""
125
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
126
+ if not per_obj_png_file:
127
+ output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
128
+ output_mask_path = os.path.join(
129
+ output_mask_dir, video_name, f"{frame_name}.png"
130
+ )
131
+ assert output_mask.dtype == np.uint8
132
+ assert output_mask.ndim == 2
133
+ output_mask = Image.fromarray(output_mask)
134
+ output_mask.save(output_mask_path)
135
+ else:
136
+ for object_id, object_mask in per_obj_output_mask.items():
137
+ object_name = f"{object_id:03d}"
138
+ os.makedirs(
139
+ os.path.join(output_mask_dir, video_name, object_name),
140
+ exist_ok=True,
141
+ )
142
+ output_mask = object_mask.reshape(height, width).astype(np.uint8)
143
+ output_mask_path = os.path.join(
144
+ output_mask_dir, video_name, object_name, f"{frame_name}.png"
145
+ )
146
+ assert output_mask.dtype == np.uint8
147
+ assert output_mask.ndim == 2
148
+ output_mask = Image.fromarray(output_mask)
149
+ output_mask.save(output_mask_path)
150
+
151
+ @torch.inference_mode()
152
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
153
+ def vos_inference(
154
+ predictor,
155
+ base_video_dir,
156
+ input_mask_dir,
157
+ output_mask_dir,
158
+ video_name,
159
+ score_thresh=0.0,
160
+ use_all_masks=False,
161
+ per_obj_png_file=False,
162
+ save_palette_png=False,
163
+ ):
164
+ """Run inference on a single video with the given predictor."""
165
+ # load the video frames and initialize the inference state on this video
166
+ video_dir = os.path.join(base_video_dir, video_name)
167
+ frame_names = [
168
+ os.path.splitext(p)[0]
169
+ for p in os.listdir(video_dir)
170
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
171
+ ]
172
+ frame_names = list(sorted(frame_names))
173
+ inference_state = predictor.init_state(
174
+ video_path=video_dir, async_loading_frames=False
175
+ )
176
+ height = inference_state["video_height"]
177
+ width = inference_state["video_width"]
178
+ input_palette = None
179
+
180
+ # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
181
+ if not use_all_masks:
182
+ # use only the first video's ground-truth mask as the input mask
183
+ input_frame_inds = [0]
184
+ else:
185
+ # use all mask files available in the input_mask_dir as the input masks
186
+ if not per_obj_png_file:
187
+ input_frame_inds = [
188
+ idx
189
+ for idx, name in enumerate(frame_names)
190
+ if os.path.exists(
191
+ os.path.join(input_mask_dir, video_name, f"{name}.png")
192
+ )
193
+ ]
194
+ else:
195
+ input_frame_inds = [
196
+ idx
197
+ for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
198
+ for idx, name in enumerate(frame_names)
199
+ if os.path.exists(
200
+ os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
201
+ )
202
+ ]
203
+ # check and make sure we got at least one input frame
204
+ if len(input_frame_inds) == 0:
205
+ raise RuntimeError(
206
+ f"In {video_name=}, got no input masks in {input_mask_dir=}. "
207
+ "Please make sure the input masks are available in the correct format."
208
+ )
209
+ input_frame_inds = sorted(set(input_frame_inds))
210
+
211
+ # add those input masks to SAM 2 inference state before propagation
212
+ object_ids_set = None
213
+ for input_frame_idx in input_frame_inds:
214
+ try:
215
+ per_obj_input_mask, input_palette = load_masks_from_dir(
216
+ input_mask_dir=input_mask_dir,
217
+ video_name=video_name,
218
+ frame_name=frame_names[input_frame_idx],
219
+ per_obj_png_file=per_obj_png_file,
220
+ )
221
+ except FileNotFoundError as e:
222
+ raise RuntimeError(
223
+ f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
224
+ "Please add the `--track_object_appearing_later_in_video` flag "
225
+ "for VOS datasets that don't have all objects to track appearing "
226
+ "in the first frame (such as LVOS or YouTube-VOS)."
227
+ ) from e
228
+
229
+ # get the list of object ids to track from the first input frame
230
+ if object_ids_set is None:
231
+ object_ids_set = set(per_obj_input_mask)
232
+ for object_id, object_mask in per_obj_input_mask.items():
233
+ # check and make sure no new object ids appear only in later frames
234
+ if object_id not in object_ids_set:
235
+ raise RuntimeError(
236
+ f"In {video_name=}, got a new {object_id=} appearing only in a "
237
+ f"later {input_frame_idx=} (but not appearing in the first frame). "
238
+ "Please add the `--track_object_appearing_later_in_video` flag "
239
+ "for VOS datasets that don't have all objects to track appearing "
240
+ "in the first frame (such as LVOS or YouTube-VOS)."
241
+ )
242
+ predictor.add_new_mask(
243
+ inference_state=inference_state,
244
+ frame_idx=input_frame_idx,
245
+ obj_id=object_id,
246
+ mask=object_mask,
247
+ )
248
+
249
+ # check and make sure we have at least one object to track
250
+ if object_ids_set is None or len(object_ids_set) == 0:
251
+ raise RuntimeError(
252
+ f"In {video_name=}, got no object ids on {input_frame_inds=}. "
253
+ "Please add the `--track_object_appearing_later_in_video` flag "
254
+ "for VOS datasets that don't have all objects to track appearing "
255
+ "in the first frame (such as LVOS or YouTube-VOS)."
256
+ )
257
+
258
+ # run propagation throughout the video and collect the results in a dict
259
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
260
+ output_palette = input_palette or DAVIS_PALETTE
261
+ video_segments = {} # video_segments contains the per-frame segmentation results
262
+
263
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
264
+ inference_state
265
+ ):
266
+ per_obj_output_mask = {
267
+ out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
268
+ for i, out_obj_id in enumerate(out_obj_ids)
269
+ }
270
+ video_segments[out_frame_idx] = per_obj_output_mask
271
+
272
+ # write the output masks as palette PNG files to output_mask_dir
273
+ for out_frame_idx, per_obj_output_mask in video_segments.items():
274
+ if save_palette_png:
275
+ # save palette PNG prediction results
276
+ save_palette_masks_to_dir(
277
+ output_mask_dir=output_mask_dir,
278
+ video_name=video_name,
279
+ frame_name=frame_names[out_frame_idx],
280
+ per_obj_output_mask=per_obj_output_mask,
281
+ height=height,
282
+ width=width,
283
+ per_obj_png_file=per_obj_png_file,
284
+ output_palette=output_palette,
285
+ )
286
+ else:
287
+ # save raw prediction results
288
+ save_masks_to_dir(
289
+ output_mask_dir=output_mask_dir,
290
+ video_name=video_name,
291
+ frame_name=frame_names[out_frame_idx],
292
+ per_obj_output_mask=per_obj_output_mask,
293
+ height=height,
294
+ width=width,
295
+ per_obj_png_file=per_obj_png_file,
296
+ )
297
+
298
+ @torch.inference_mode()
299
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
300
+ def vos_separate_inference_per_object(
301
+ predictor,
302
+ base_video_dir,
303
+ input_mask_dir,
304
+ output_mask_dir,
305
+ video_name,
306
+ score_thresh=0.0,
307
+ use_all_masks=False,
308
+ per_obj_png_file=False,
309
+ ):
310
+ """
311
+ Run inference on a single video with the given predictor.
312
+
313
+ Unlike `vos_inference`, this function run inference separately for each object
314
+ in a video, which could be applied to datasets like LVOS or YouTube-VOS that
315
+ don't have all objects to track appearing in the first frame (i.e. some objects
316
+ might appear only later in the video).
317
+ """
318
+ # load the video frames and initialize the inference state on this video
319
+ video_dir = os.path.join(base_video_dir, video_name)
320
+ frame_names = [
321
+ os.path.splitext(p)[0]
322
+ for p in os.listdir(video_dir)
323
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
324
+ ]
325
+ frame_names = list(sorted(frame_names))
326
+ inference_state = predictor.init_state(
327
+ video_path=video_dir, async_loading_frames=False
328
+ )
329
+ height = inference_state["video_height"]
330
+ width = inference_state["video_width"]
331
+ input_palette = None
332
+
333
+ # collect all the object ids and their input masks
334
+ inputs_per_object = defaultdict(dict)
335
+ for idx, name in enumerate(frame_names):
336
+ if per_obj_png_file or os.path.exists(
337
+ os.path.join(input_mask_dir, video_name, f"{name}.png")
338
+ ):
339
+ per_obj_input_mask, input_palette = load_masks_from_dir(
340
+ input_mask_dir=input_mask_dir,
341
+ video_name=video_name,
342
+ frame_name=frame_names[idx],
343
+ per_obj_png_file=per_obj_png_file,
344
+ allow_missing=True,
345
+ )
346
+ for object_id, object_mask in per_obj_input_mask.items():
347
+ # skip empty masks
348
+ if not np.any(object_mask):
349
+ continue
350
+ # if `use_all_masks=False`, we only use the first mask for each object
351
+ if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
352
+ continue
353
+ print(f"adding mask from frame {idx} as input for {object_id=}")
354
+ inputs_per_object[object_id][idx] = object_mask
355
+
356
+
357
+ # run inference separately for each object in the video
358
+ object_ids = sorted(inputs_per_object)
359
+ output_scores_per_object = defaultdict(dict)
360
+ for object_id in object_ids:
361
+ # add those input masks to SAM 2 inference state before propagation
362
+ input_frame_inds = sorted(inputs_per_object[object_id])
363
+ predictor.reset_state(inference_state)
364
+ for input_frame_idx in input_frame_inds:
365
+ predictor.add_new_mask(
366
+ inference_state=inference_state,
367
+ frame_idx=input_frame_idx,
368
+ obj_id=object_id,
369
+ mask=inputs_per_object[object_id][input_frame_idx],
370
+ )
371
+
372
+ # run propagation throughout the video and collect the results in a dict
373
+ for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
374
+ inference_state,
375
+ start_frame_idx=min(input_frame_inds),
376
+ reverse=False,
377
+ ):
378
+ obj_scores = out_mask_logits.cpu().numpy()
379
+ output_scores_per_object[object_id][out_frame_idx] = obj_scores
380
+
381
+ # post-processing: consolidate the per-object scores into per-frame masks
382
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
383
+ output_palette = input_palette or DAVIS_PALETTE
384
+
385
+ video_segments = {} # video_segments contains the per-frame segmentation results
386
+ for frame_idx in range(len(frame_names)):
387
+ scores = torch.full(
388
+ size=(len(object_ids), 1, height, width),
389
+ fill_value=-1024.0,
390
+ dtype=torch.float32,
391
+ )
392
+ for i, object_id in enumerate(object_ids):
393
+ if frame_idx in output_scores_per_object[object_id]:
394
+ scores[i] = torch.from_numpy(
395
+ output_scores_per_object[object_id][frame_idx]
396
+ )
397
+
398
+ if not per_obj_png_file:
399
+ scores = predictor._apply_non_overlapping_constraints(scores)
400
+ per_obj_output_mask = {
401
+ object_id: (scores[i] > score_thresh).cpu().numpy()
402
+ for i, object_id in enumerate(object_ids)
403
+ }
404
+ video_segments[frame_idx] = per_obj_output_mask
405
+
406
+ # write the output masks as palette PNG files to output_mask_dir
407
+ for frame_idx, per_obj_output_mask in video_segments.items():
408
+ save_palette_masks_to_dir(
409
+ output_mask_dir=output_mask_dir,
410
+ video_name=video_name,
411
+ frame_name=frame_names[frame_idx],
412
+ per_obj_output_mask=per_obj_output_mask,
413
+ height=height,
414
+ width=width,
415
+ per_obj_png_file=per_obj_png_file,
416
+ output_palette=output_palette,
417
+ )
418
+
419
+
420
+ def main():
421
+ parser = argparse.ArgumentParser()
422
+ parser.add_argument(
423
+ "--sam2_cfg",
424
+ type=str,
425
+ default="configs/sam2.1_hiera_t512.yaml",
426
+ help="MedSAM2 model configuration file",
427
+ )
428
+ parser.add_argument(
429
+ "--sam2_checkpoint",
430
+ type=str,
431
+ default="./checkpoints/MedSAM2_latest.pt",
432
+ help="path to the MedSAM2 model checkpoint",
433
+ )
434
+ parser.add_argument(
435
+ "-i",
436
+ "--base_video_dir",
437
+ type=str,
438
+ required=True,
439
+ help="directory containing videos (as JPEG files) to run inference on",
440
+ )
441
+ parser.add_argument(
442
+ "-m",
443
+ "--input_mask_dir",
444
+ type=str,
445
+ required=True,
446
+ help="directory containing input masks (as PNG files) of each video",
447
+ )
448
+ parser.add_argument(
449
+ "--video_list_file",
450
+ type=str,
451
+ default=None,
452
+ help="text file containing the list of video names to run inference on",
453
+ )
454
+ parser.add_argument(
455
+ "-o",
456
+ "--output_mask_dir",
457
+ type=str,
458
+ required=True,
459
+ help="directory to save the output masks (as PNG files)",
460
+ )
461
+ parser.add_argument(
462
+ "--score_thresh",
463
+ type=float,
464
+ default=0.0,
465
+ help="threshold for the output mask logits (default: 0.0)",
466
+ )
467
+ parser.add_argument(
468
+ "--use_all_masks",
469
+ action="store_true",
470
+ help="whether to use all available PNG files in input_mask_dir "
471
+ "(default without this flag: just the first PNG file as input to the SAM 2 model; "
472
+ "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
473
+ )
474
+ parser.add_argument(
475
+ "--per_obj_png_file",
476
+ action="store_true",
477
+ help="whether use separate per-object PNG files for input and output masks "
478
+ "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
479
+ "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
480
+ )
481
+ parser.add_argument(
482
+ "--save_palette_png",
483
+ action="store_true",
484
+ help="whether to save palette PNG files for output masks "
485
+ "(default without this flag: all object masks are saved as grayscale PNG files (np.uint8) without palette)",
486
+ )
487
+ parser.add_argument(
488
+ "--apply_postprocessing",
489
+ action="store_true",
490
+ help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
491
+ "(we don't apply such post-processing in the SAM 2 model evaluation)",
492
+ )
493
+ parser.add_argument(
494
+ "--track_object_appearing_later_in_video",
495
+ action="store_true",
496
+ help="whether to track objects that appear later in the video (i.e. not on the first frame; "
497
+ "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
498
+ )
499
+ parser.add_argument(
500
+ "--use_vos_optimized_video_predictor",
501
+ action="store_true",
502
+ help="whether to use vos optimized video predictor with all modules compiled",
503
+ )
504
+ args = parser.parse_args()
505
+
506
+ # if we use per-object PNG files, they could possibly overlap in inputs and outputs
507
+ hydra_overrides_extra = [
508
+ "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
509
+ ]
510
+ predictor = build_sam2_video_predictor(
511
+ config_file=args.sam2_cfg,
512
+ ckpt_path=args.sam2_checkpoint,
513
+ apply_postprocessing=args.apply_postprocessing,
514
+ hydra_overrides_extra=hydra_overrides_extra,
515
+ vos_optimized=args.use_vos_optimized_video_predictor,
516
+ )
517
+
518
+ if args.use_all_masks:
519
+ print("using all available masks in input_mask_dir as input to the MedSAM2 model")
520
+ else:
521
+ print(
522
+ "using only the first frame's mask in input_mask_dir as input to the MedSAM2 model"
523
+ )
524
+ # if a video list file is provided, read the video names from the file
525
+ # (otherwise, we use all subdirectories in base_video_dir)
526
+ if args.video_list_file is not None:
527
+ with open(args.video_list_file, "r") as f:
528
+ video_names = [v.strip() for v in f.readlines()]
529
+ else:
530
+ video_names = [
531
+ p
532
+ for p in os.listdir(args.base_video_dir)
533
+ if os.path.isdir(os.path.join(args.base_video_dir, p))
534
+ ]
535
+ print(f"running inference on {len(video_names)} videos:\n{video_names}")
536
+
537
+ for n_video, video_name in enumerate(video_names):
538
+ print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
539
+ if not args.track_object_appearing_later_in_video:
540
+ vos_inference(
541
+ predictor=predictor,
542
+ base_video_dir=args.base_video_dir,
543
+ input_mask_dir=args.input_mask_dir,
544
+ output_mask_dir=args.output_mask_dir,
545
+ video_name=video_name,
546
+ score_thresh=args.score_thresh,
547
+ use_all_masks=args.use_all_masks,
548
+ per_obj_png_file=args.per_obj_png_file,
549
+ save_palette_png=args.save_palette_png,
550
+ )
551
+ else:
552
+ vos_separate_inference_per_object(
553
+ predictor=predictor,
554
+ base_video_dir=args.base_video_dir,
555
+ input_mask_dir=args.input_mask_dir,
556
+ output_mask_dir=args.output_mask_dir,
557
+ video_name=video_name,
558
+ score_thresh=args.score_thresh,
559
+ use_all_masks=args.use_all_masks,
560
+ per_obj_png_file=args.per_obj_png_file,
561
+ )
562
+
563
+ print(
564
+ f"completed inference on {len(video_names)} videos -- "
565
+ f"output masks saved to {args.output_mask_dir}"
566
+ )
567
+
568
+
569
+ if __name__ == "__main__":
570
+ main()
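
`main()` is driven entirely by command-line flags, but the two entry points can also be called programmatically. A hedged sketch with placeholder directory names (nothing below ships with this repo):

```python
# Sketch: programmatic use of vos_inference with a MedSAM2 checkpoint.
from medsam2_infer_video import vos_inference
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    config_file="configs/sam2.1_hiera_t512.yaml",
    ckpt_path="./checkpoints/MedSAM2_latest.pt",
)
vos_inference(
    predictor=predictor,
    base_video_dir="data/JPEGImages",    # JPEG frames, one subfolder per video
    input_mask_dir="data/Annotations",   # first-frame PNG masks per video
    output_mask_dir="results/masks",
    video_name="case_0001",
)
```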
multi_node_train.sh ADDED
@@ -0,0 +1,48 @@
1
+ #!/bin/bash
2
+ #SBATCH -t 7-00:0:0
3
+ #SBATCH -J medsam2-tr-tiny
4
+ #SBATCH --mem=450G
5
+ #SBATCH -c 60
6
+ #SBATCH -N 3
7
+ #SBATCH --ntasks-per-node=1
8
+ #SBATCH --gres=gpu:4
9
+ #SBATCH -o out_mnodes_tiny.out
10
+
11
+ export PATH=/usr/local/cuda/bin:$PATH
12
+ timestamp=$(date +"%Y%m%d-%H%M")
13
+
14
+ # Set the master node address (first node in the allocation)
15
+ export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
16
+ # export MASTER_PORT=29500
17
+ export MASTER_PORT=$(python - <<EOF
18
+ import socket
19
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
20
+ sock.bind(('', 0)) # OS will allocate a free port
21
+ free_port = sock.getsockname()[1]
22
+ sock.close()
23
+ print(free_port)
24
+ EOF
25
+ )
26
+
27
+ # Print some information
28
+ echo "Master node: $MASTER_ADDR"
29
+ echo "Master port: $MASTER_PORT"
30
+ echo "Number of nodes: $SLURM_NNODES"
31
+ echo "GPUs per node: $SLURM_GPUS_ON_NODE"
32
+
33
+ config=configs/sam2.1_hiera_tiny_finetune512.yaml
34
+ output_path=./exp_log/mnode_tiny
35
+
36
+ # Function to run the training script
37
+ srun --exclusive python training/train.py \
38
+ -c $config \
39
+ --output-path $output_path \
40
+ --use-cluster 0 \
41
+ --num-gpus $SLURM_GPUS_ON_NODE \
42
+ --num-nodes $SLURM_NNODES \
43
+ --master-addr $MASTER_ADDR \
44
+ --main-port $MASTER_PORT
45
+
46
+ echo "training done"
47
+
48
+
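
The heredoc above picks a free TCP port on the master node; together with `MASTER_ADDR`, that port is what PyTorch's default `env://` rendezvous reads on every rank. A hedged sketch of the same logic in plain Python (the `init_process_group` call is shown for context only and assumes the usual `RANK`/`WORLD_SIZE` variables are set by the launcher):

```python
# Sketch: free-port selection (as in the heredoc) plus the env:// rendezvous it feeds.
import os
import socket

import torch.distributed as dist

def find_free_port() -> int:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))               # the OS allocates a free port
    port = sock.getsockname()[1]
    sock.close()
    return port

# On the launching host: export MASTER_ADDR / MASTER_PORT before spawning workers.
# os.environ["MASTER_ADDR"] = "<master hostname>"
# os.environ["MASTER_PORT"] = str(find_free_port())

# Each worker then joins the same rendezvous:
# dist.init_process_group(backend="nccl", init_method="env://",
#                         rank=int(os.environ["RANK"]),
#                         world_size=int(os.environ["WORLD_SIZE"]))
```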
notebooks/MedSAM2_Inference_Video.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/MedSAM2_inference_CT_Lesion.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,6 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=61.0",
4
+ "torch>=2.5.1",
5
+ ]
6
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ gradio==3.38.0
2
+ torch>=2.0
3
+ torchvision
4
+ numpy
5
+ SimpleITK
6
+ nibabel
7
+ opencv-python-headless
8
+ imageio
9
+ tqdm
10
+ matplotlib
11
+ einops
12
+ omegaconf
13
+ ffmpeg-python
14
+ moviepy
15
+ huggingface_hub
16
+ hydra-core
sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+ from hydra.core.global_hydra import GlobalHydra
9
+
10
+ if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("sam2", version_base="1.2")
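
Registering `sam2` as a Hydra config module is what lets `sam2/build_sam.py` resolve package-relative config names such as `configs/sam2.1_hiera_t512.yaml` via `compose()`. A hedged sketch of that lookup in isolation:

```python
# Sketch: how the Hydra config module registered above is consumed.
import sam2  # importing the package runs initialize_config_module("sam2", ...)
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = compose(config_name="configs/sam2.1_hiera_t512.yaml")
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)  # builds the SAM2 model graph
```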
sam2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (431 Bytes). View file
 
sam2/__pycache__/build_sam.cpython-312.pyc ADDED
Binary file (5.16 kB). View file
 
sam2/__pycache__/sam2_image_predictor.cpython-312.pyc ADDED
Binary file (22.7 kB). View file
 
sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc ADDED
Binary file (38 kB). View file
 
sam2/build_sam.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from hydra import compose
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+
14
+ HF_MODEL_ID_TO_FILENAMES = {
15
+ "facebook/sam2-hiera-tiny": (
16
+ "configs/sam2/sam2_hiera_t.yaml",
17
+ "sam2_hiera_tiny.pt",
18
+ ),
19
+ "facebook/sam2-hiera-small": (
20
+ "configs/sam2/sam2_hiera_s.yaml",
21
+ "sam2_hiera_small.pt",
22
+ ),
23
+ "facebook/sam2-hiera-base-plus": (
24
+ "configs/sam2/sam2_hiera_b+.yaml",
25
+ "sam2_hiera_base_plus.pt",
26
+ ),
27
+ "facebook/sam2-hiera-large": (
28
+ "configs/sam2/sam2_hiera_l.yaml",
29
+ "sam2_hiera_large.pt",
30
+ ),
31
+ "facebook/sam2.1-hiera-tiny": (
32
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
33
+ "sam2.1_hiera_tiny.pt",
34
+ ),
35
+ "facebook/sam2.1-hiera-small": (
36
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
37
+ "sam2.1_hiera_small.pt",
38
+ ),
39
+ "facebook/sam2.1-hiera-base-plus": (
40
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
41
+ "sam2.1_hiera_base_plus.pt",
42
+ ),
43
+ "facebook/sam2.1-hiera-large": (
44
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
45
+ "sam2.1_hiera_large.pt",
46
+ ),
47
+ }
48
+
49
+
50
+ def get_best_available_device():
51
+ """
52
+ Get the best available device in the order: CUDA, MPS, CPU
53
+ Returns: device string for torch.device
54
+ """
55
+ if torch.cuda.is_available():
56
+ return "cuda"
57
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
58
+ return "mps"
59
+ else:
60
+ return "cpu"
61
+
62
+
63
+ def build_sam2(
64
+ config_file,
65
+ ckpt_path=None,
66
+ device=None,
67
+ mode="eval",
68
+ hydra_overrides_extra=[],
69
+ apply_postprocessing=True,
70
+ **kwargs,
71
+ ):
72
+ # Use the provided device or get the best available one
73
+ device = device or get_best_available_device()
74
+ logging.info(f"Using device: {device}")
75
+
76
+ if apply_postprocessing:
77
+ hydra_overrides_extra = hydra_overrides_extra.copy()
78
+ hydra_overrides_extra += [
79
+ # dynamically fall back to multi-mask if the single mask is not stable
80
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
81
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
82
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
83
+ ]
84
+ # Read config and init model
85
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
86
+ OmegaConf.resolve(cfg)
87
+ model = instantiate(cfg.model, _recursive_=True)
88
+ _load_checkpoint(model, ckpt_path)
89
+ model = model.to(device)
90
+ if mode == "eval":
91
+ model.eval()
92
+ return model
93
+
94
+
95
+ def build_sam2_video_predictor(
96
+ config_file,
97
+ ckpt_path=None,
98
+ device=None,
99
+ mode="eval",
100
+ hydra_overrides_extra=[],
101
+ apply_postprocessing=True,
102
+ **kwargs,
103
+ ):
104
+ # Use the provided device or get the best available one
105
+ device = device or get_best_available_device()
106
+ logging.info(f"Using device: {device}")
107
+
108
+ hydra_overrides = [
109
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
110
+ ]
111
+ if apply_postprocessing:
112
+ hydra_overrides_extra = hydra_overrides_extra.copy()
113
+ hydra_overrides_extra += [
114
+ # dynamically fall back to multi-mask if the single mask is not stable
115
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
116
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
117
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
118
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
119
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
120
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
121
+ "++model.fill_hole_area=8",
122
+ ]
123
+ hydra_overrides.extend(hydra_overrides_extra)
124
+
125
+ # Read config and init model
126
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
127
+ OmegaConf.resolve(cfg)
128
+ model = instantiate(cfg.model, _recursive_=True)
129
+ _load_checkpoint(model, ckpt_path)
130
+ model = model.to(device)
131
+ if mode == "eval":
132
+ model.eval()
133
+ return model
134
+
135
+ def build_sam2_video_predictor_npz(
136
+ config_file,
137
+ ckpt_path=None,
138
+ device=None,
139
+ mode="eval",
140
+ hydra_overrides_extra=[],
141
+ apply_postprocessing=True,
142
+ **kwargs,
143
+ ):
144
+ # Use the provided device or get the best available one
145
+ device = device or get_best_available_device()
146
+ logging.info(f"Using device: {device}")
147
+
148
+ hydra_overrides = [
149
+ "++model._target_=sam2.sam2_video_predictor_npz.SAM2VideoPredictorNPZ",
150
+ ]
151
+ if apply_postprocessing:
152
+ hydra_overrides_extra = hydra_overrides_extra.copy()
153
+ hydra_overrides_extra += [
154
+ # dynamically fall back to multi-mask if the single mask is not stable
155
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
156
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
157
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
158
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
159
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
160
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
161
+ "++model.fill_hole_area=8",
162
+ ]
163
+ hydra_overrides.extend(hydra_overrides_extra)
164
+
165
+ # Read config and init model
166
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
167
+ OmegaConf.resolve(cfg)
168
+ model = instantiate(cfg.model, _recursive_=True)
169
+ _load_checkpoint(model, ckpt_path)
170
+ model = model.to(device)
171
+ if mode == "eval":
172
+ model.eval()
173
+ return model
174
+
175
+
176
+
177
+ def _hf_download(model_id):
178
+ from huggingface_hub import hf_hub_download
179
+
180
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
181
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
182
+ return config_name, ckpt_path
183
+
184
+
185
+ def build_sam2_hf(model_id, **kwargs):
186
+ config_name, ckpt_path = _hf_download(model_id)
187
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
188
+
189
+
190
+ def build_sam2_video_predictor_hf(model_id, **kwargs):
191
+ config_name, ckpt_path = _hf_download(model_id)
192
+ return build_sam2_video_predictor(
193
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
194
+ )
195
+
196
+
197
+ def _load_checkpoint(model, ckpt_path):
198
+ if ckpt_path is not None:
199
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
200
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
201
+ if missing_keys:
202
+ logging.error(missing_keys)
203
+ raise RuntimeError()
204
+ if unexpected_keys:
205
+ logging.error(unexpected_keys)
206
+ raise RuntimeError()
207
+ logging.info("Loaded checkpoint successfully")
sam2/configs/sam2.1_hiera_t512.yaml ADDED
@@ -0,0 +1,121 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 512
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ # HieraT does not currently support compilation, should always be set to False
121
+ compile_image_encoder: False
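The config above wires up the full `SAM2Base` model at a 512 input resolution. A minimal sketch of how such a config is typically consumed, assuming the repo keeps the upstream `build_sam2` helper; the config name and checkpoint path below are placeholders and may need adjusting to how Hydra's search path is set up here:

```python
# Sketch only: config name and checkpoint path are assumptions, not the
# repo's documented entry point.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam2(
    config_file="configs/sam2.1_hiera_t512.yaml",    # assumed config name
    ckpt_path="path/to/your_medsam2_checkpoint.pt",  # placeholder path
    device=device,
)
predictor = SAM2ImagePredictor(model)
```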
sam2/configs/sam2.1_hiera_tiny_finetune512.yaml ADDED
@@ -0,0 +1,389 @@
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 512
5
+ train_video_batch_size: 8
6
+ num_train_workers: 15
7
+ num_frames: 8
8
+ max_num_objects: 5
9
+ base_lr: 5.0e-5
10
+ vision_lr: 3.0e-05
11
+ phases_per_epoch: 1
12
+ num_epochs: 75
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ folder: # PATH to Med NPZ folder
17
+ multiplier: 1
18
+
19
+ # Video transforms
20
+ vos:
21
+ train_transforms:
22
+ - _target_: training.dataset.transforms.ComposeAPI
23
+ transforms:
24
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
25
+ consistent_transform: True
26
+ - _target_: training.dataset.transforms.RandomAffine
27
+ degrees: 25
28
+ shear: 20
29
+ image_interpolation: bilinear
30
+ consistent_transform: True
31
+ - _target_: training.dataset.transforms.RandomResizeAPI
32
+ sizes: ${scratch.resolution}
33
+ square: true
34
+ consistent_transform: True
35
+ - _target_: training.dataset.transforms.ColorJitter
36
+ consistent_transform: True
37
+ brightness: 0.1
38
+ contrast: 0.03
39
+ saturation: 0.03
40
+ hue: null
41
+ - _target_: training.dataset.transforms.RandomGrayscale
42
+ p: 0.05
43
+ consistent_transform: True
44
+ - _target_: training.dataset.transforms.ColorJitter
45
+ consistent_transform: False
46
+ brightness: 0.1
47
+ contrast: 0.05
48
+ saturation: 0.05
49
+ hue: null
50
+ - _target_: training.dataset.transforms.ToTensorAPI
51
+ - _target_: training.dataset.transforms.NormalizeAPI
52
+ mean: [0.485, 0.456, 0.406]
53
+ std: [0.229, 0.224, 0.225]
54
+
55
+
56
+ trainer:
57
+ _target_: training.trainer.Trainer
58
+ mode: train_only
59
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
60
+ accelerator: cuda
61
+ seed_value: 123
62
+
63
+ model:
64
+ _target_: training.model.sam2.SAM2Train
65
+ image_encoder:
66
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
67
+ scalp: 1
68
+ trunk:
69
+ _target_: sam2.modeling.backbones.hieradet.Hiera
70
+ embed_dim: 96
71
+ num_heads: 1
72
+ stages: [1, 2, 7, 2]
73
+ global_att_blocks: [5, 7, 9]
74
+ window_pos_embed_bkg_spatial_size: [7, 7]
75
+ neck:
76
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
77
+ position_encoding:
78
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
79
+ num_pos_feats: 256
80
+ normalize: true
81
+ scale: null
82
+ temperature: 10000
83
+ d_model: 256
84
+ backbone_channel_list: [768, 384, 192, 96]
85
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
86
+ fpn_interp_model: nearest
87
+
88
+ memory_attention:
89
+ _target_: sam2.modeling.memory_attention.MemoryAttention
90
+ d_model: 256
91
+ pos_enc_at_input: true
92
+ layer:
93
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
94
+ activation: relu
95
+ dim_feedforward: 2048
96
+ dropout: 0.1
97
+ pos_enc_at_attn: false
98
+ self_attention:
99
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
100
+ rope_theta: 10000.0
101
+ feat_sizes: [32, 32]
102
+ embedding_dim: 256
103
+ num_heads: 1
104
+ downsample_rate: 1
105
+ dropout: 0.1
106
+ d_model: 256
107
+ pos_enc_at_cross_attn_keys: true
108
+ pos_enc_at_cross_attn_queries: false
109
+ cross_attention:
110
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
111
+ rope_theta: 10000.0
112
+ feat_sizes: [32, 32]
113
+ rope_k_repeat: True
114
+ embedding_dim: 256
115
+ num_heads: 1
116
+ downsample_rate: 1
117
+ dropout: 0.1
118
+ kv_in_dim: 64
119
+ num_layers: 4
120
+
121
+ memory_encoder:
122
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
123
+ out_dim: 64
124
+ position_encoding:
125
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
126
+ num_pos_feats: 64
127
+ normalize: true
128
+ scale: null
129
+ temperature: 10000
130
+ mask_downsampler:
131
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
132
+ kernel_size: 3
133
+ stride: 2
134
+ padding: 1
135
+ fuser:
136
+ _target_: sam2.modeling.memory_encoder.Fuser
137
+ layer:
138
+ _target_: sam2.modeling.memory_encoder.CXBlock
139
+ dim: 256
140
+ kernel_size: 7
141
+ padding: 3
142
+ layer_scale_init_value: 1e-6
143
+ use_dwconv: True # depth-wise convs
144
+ num_layers: 2
145
+
146
+ num_maskmem: 7
147
+ image_size: ${scratch.resolution}
148
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
149
+ # SAM decoder
150
+ sigmoid_scale_for_mem_enc: 20.0
151
+ sigmoid_bias_for_mem_enc: -10.0
152
+ use_mask_input_as_output_without_sam: true
153
+ # Memory
154
+ directly_add_no_mem_embed: true
155
+ no_obj_embed_spatial: true
156
+ # use high-resolution feature map in the SAM mask decoder
157
+ use_high_res_features_in_sam: true
158
+ # output 3 masks on the first click on initial conditioning frames
159
+ multimask_output_in_sam: true
160
+ # SAM heads
161
+ iou_prediction_use_sigmoid: True
162
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
163
+ use_obj_ptrs_in_encoder: true
164
+ add_tpos_enc_to_obj_ptrs: true
165
+ proj_tpos_enc_in_obj_ptrs: true
166
+ use_signed_tpos_enc_to_obj_ptrs: true
167
+ only_obj_ptrs_in_the_past_for_eval: true
168
+ # object occlusion prediction
169
+ pred_obj_scores: true
170
+ pred_obj_scores_mlp: true
171
+ fixed_no_obj_ptr: true
172
+ # multimask tracking settings
173
+ multimask_output_for_tracking: true
174
+ use_multimask_token_for_obj_ptr: true
175
+ multimask_min_pt_num: 0
176
+ multimask_max_pt_num: 1
177
+ use_mlp_for_obj_ptr_proj: true
178
+ # Compilation flag
179
+ # compile_image_encoder: False
180
+
181
+ ####### Training specific params #######
182
+ # box/point input and corrections
183
+ prob_to_use_pt_input_for_train: 0.5
184
+ prob_to_use_pt_input_for_eval: 0.0
185
+ prob_to_use_box_input_for_train: 1.0
186
+ prob_to_use_box_input_for_eval: 0.0
187
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
188
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
189
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
190
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
191
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
192
+ # maximum 2 initial conditioning frames
193
+ num_init_cond_frames_for_train: 2
194
+ rand_init_cond_frames_for_train: True # random 1~2
195
+ num_correction_pt_per_frame: 7
196
+ use_act_ckpt_iterative_pt_sampling: false
197
+
198
+
199
+
200
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
201
+ forward_backbone_per_frame_for_eval: True
202
+
203
+
204
+ data:
205
+ train:
206
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
207
+ phases_per_epoch: ${scratch.phases_per_epoch}
208
+ batch_sizes:
209
+ - ${scratch.train_video_batch_size}
210
+ datasets:
211
+ - _target_: training.dataset.utils.RepeatFactorWrapper
212
+ dataset:
213
+ _target_: training.dataset.utils.ConcatDataset
214
+ datasets:
215
+ # CT
216
+ - _target_: training.dataset.vos_dataset.VOSDataset
217
+ transforms: ${vos.train_transforms}
218
+ training: true
219
+ video_dataset:
220
+ _target_: training.dataset.vos_raw_dataset.NPZRawDataset
221
+ folder: CVPR25/3D_train_npz_random_10percent_16G/CT
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: 1
227
+ # MR
228
+ - _target_: training.dataset.vos_dataset.VOSDataset
229
+ transforms: ${vos.train_transforms}
230
+ training: true
231
+ video_dataset:
232
+ _target_: training.dataset.vos_raw_dataset.NPZRawDataset
233
+ folder: CVPR25/3D_train_npz_random_10percent_16G/MR
234
+ sampler:
235
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
236
+ num_frames: ${scratch.num_frames}
237
+ max_num_objects: ${scratch.max_num_objects}
238
+ multiplier: 1
239
+ # PET
240
+ - _target_: training.dataset.vos_dataset.VOSDataset
241
+ transforms: ${vos.train_transforms}
242
+ training: true
243
+ video_dataset:
244
+ _target_: training.dataset.vos_raw_dataset.NPZRawDataset
245
+ folder: CVPR25/3D_train_npz_random_10percent_16G/PET
246
+ sampler:
247
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
248
+ num_frames: ${scratch.num_frames}
249
+ max_num_objects: ${scratch.max_num_objects}
250
+ multiplier: 10
251
+ # Ultrasound 3D
252
+ - _target_: training.dataset.vos_dataset.VOSDataset
253
+ transforms: ${vos.train_transforms}
254
+ training: true
255
+ video_dataset:
256
+ _target_: training.dataset.vos_raw_dataset.NPZRawDataset
257
+ folder: CVPR25/3D_train_npz_random_10percent_16G/US3D
258
+ sampler:
259
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
260
+ num_frames: ${scratch.num_frames}
261
+ max_num_objects: ${scratch.max_num_objects}
262
+ multiplier: 1
263
+ # Microscopy 3D
264
+ - _target_: training.dataset.vos_dataset.VOSDataset
265
+ transforms: ${vos.train_transforms}
266
+ training: true
267
+ video_dataset:
268
+ _target_: training.dataset.vos_raw_dataset.NPZRawDataset
269
+ folder: CVPR25/3D_train_npz_random_10percent_16G/Microscopy
270
+ sampler:
271
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
272
+ num_frames: ${scratch.num_frames}
273
+ max_num_objects: ${scratch.max_num_objects}
274
+ multiplier: 1
275
+
276
+ shuffle: True
277
+ num_workers: ${scratch.num_train_workers}
278
+ pin_memory: True
279
+ drop_last: True
280
+ collate_fn:
281
+ _target_: training.utils.data_utils.collate_fn
282
+ _partial_: true
283
+ dict_key: all
284
+
285
+ optim:
286
+ amp:
287
+ enabled: True
288
+ amp_dtype: bfloat16
289
+
290
+ optimizer:
291
+ _target_: torch.optim.AdamW
292
+
293
+ gradient_clip:
294
+ _target_: training.optimizer.GradientClipper
295
+ max_norm: 0.1
296
+ norm_type: 2
297
+
298
+ param_group_modifiers:
299
+ - _target_: training.optimizer.layer_decay_param_modifier
300
+ _partial_: True
301
+ layer_decay_value: 0.9
302
+ apply_to: 'image_encoder.trunk'
303
+ overrides:
304
+ - pattern: '*pos_embed*'
305
+ value: 1.0
306
+
307
+ options:
308
+ lr:
309
+ - scheduler:
310
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
311
+ start_value: ${scratch.base_lr}
312
+ end_value: ${divide:${scratch.base_lr},10}
313
+ - scheduler:
314
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
315
+ start_value: ${scratch.vision_lr}
316
+ end_value: ${divide:${scratch.vision_lr},10}
317
+ param_names:
318
+ - 'image_encoder.*'
319
+ weight_decay:
320
+ - scheduler:
321
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
322
+ value: 0.1
323
+ - scheduler:
324
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
325
+ value: 0.0
326
+ param_names:
327
+ - '*bias*'
328
+ module_cls_names: ['torch.nn.LayerNorm']
329
+
330
+ loss:
331
+ all:
332
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
333
+ weight_dict:
334
+ loss_mask: 20
335
+ loss_dice: 1
336
+ loss_iou: 1
337
+ loss_class: 1
338
+ supervise_all_iou: true
339
+ iou_use_l1_loss: true
340
+ pred_obj_scores: true
341
+ focal_gamma_obj_score: 0.0
342
+ focal_alpha_obj_score: -1.0
343
+
344
+ distributed:
345
+ backend: nccl # gloo or nccl
346
+ find_unused_parameters: True
347
+
348
+ logging:
349
+ tensorboard_writer:
350
+ _target_: training.utils.logger.make_tensorboard_logger
351
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
352
+ flush_secs: 120
353
+ should_log: True
354
+ log_dir: ${launcher.experiment_log_dir}/logs
355
+ log_freq: 10
356
+
357
+ # initialize from a SAM 2 checkpoint
358
+ checkpoint:
359
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
360
+ save_freq: 10 # if set to 0, only the last checkpoint is saved
361
+ model_weight_initializer:
362
+ _partial_: True
363
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
364
+ strict: True
365
+ ignore_unexpected_keys: null
366
+ ignore_missing_keys: null
367
+
368
+ state_dict:
369
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
370
+ checkpoint_path: checkpoints/sam2.1_hiera_tiny.pt # PATH to SAM 2.1 checkpoint
371
+ ckpt_state_dict_keys: ['model']
372
+
373
+ launcher:
374
+ num_nodes: 1
375
+ gpus_per_node: 4
376
+ experiment_log_dir: exp_log # Path to log directory, defaults to ./sam2_logs/${config_name}
377
+
378
+ # SLURM args if running on a cluster
379
+ submitit:
380
+ partition: gpu_bwanggroup
381
+ account: null
382
+ qos: null
383
+ cpus_per_task: 10
384
+ use_cluster: false
385
+ timeout_hour: 24
386
+ name: null
387
+ port_range: [10000, 65000]
388
+
389
+
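The finetuning config above relies on custom OmegaConf resolvers for the `${times:...}` and `${divide:...}` interpolations (used by `trainer.max_epochs` and the LR schedulers). A minimal sketch of how those interpolations resolve, assuming resolvers registered under the same names; the actual training entry point presumably performs this registration itself:

```python
# Sketch only: resolver names mirror the interpolation keys used in the
# config above; shown purely to illustrate the mechanics.
import operator
from functools import reduce

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("times", lambda *xs: reduce(operator.mul, xs))
OmegaConf.register_new_resolver("divide", lambda a, b: a / b)

cfg = OmegaConf.load("sam2/configs/sam2.1_hiera_tiny_finetune512.yaml")
print(cfg.trainer.max_epochs)                       # times(75, 1) -> 75
print(cfg.optim.options.lr[0].scheduler.end_value)  # divide(5.0e-5, 10)
```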
sam2/csrc/connected_components.cu ADDED
@@ -0,0 +1,289 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // adapted from https://github.com/zsef123/Connected_components_PyTorch
8
+ // with license found in the LICENSE_cctorch file in the root directory.
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/script.h>
14
+ #include <vector>
15
+
16
+ // 2d
17
+ #define BLOCK_ROWS 16
18
+ #define BLOCK_COLS 16
19
+
20
+ namespace cc2d {
21
+
22
+ template <typename T>
23
+ __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24
+ return (bitmap >> pos) & 1;
25
+ }
26
+
27
+ __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28
+ while (s_buf[n] != n)
29
+ n = s_buf[n];
30
+ return n;
31
+ }
32
+
33
+ __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34
+ const int32_t id = n;
35
+ while (s_buf[n] != n) {
36
+ n = s_buf[n];
37
+ s_buf[id] = n;
38
+ }
39
+ return n;
40
+ }
41
+
42
+ __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43
+ bool done;
44
+ do {
45
+ a = find(s_buf, a);
46
+ b = find(s_buf, b);
47
+
48
+ if (a < b) {
49
+ int32_t old = atomicMin(s_buf + b, a);
50
+ done = (old == b);
51
+ b = old;
52
+ } else if (b < a) {
53
+ int32_t old = atomicMin(s_buf + a, b);
54
+ done = (old == a);
55
+ a = old;
56
+ } else
57
+ done = true;
58
+
59
+ } while (!done);
60
+ }
61
+
62
+ __global__ void
63
+ init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66
+ const uint32_t idx = row * W + col;
67
+
68
+ if (row < H && col < W)
69
+ label[idx] = idx;
70
+ }
71
+
72
+ __global__ void
73
+ merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76
+ const uint32_t idx = row * W + col;
77
+
78
+ if (row >= H || col >= W)
79
+ return;
80
+
81
+ uint32_t P = 0;
82
+
83
+ if (img[idx])
84
+ P |= 0x777;
85
+ if (row + 1 < H && img[idx + W])
86
+ P |= 0x777 << 4;
87
+ if (col + 1 < W && img[idx + 1])
88
+ P |= 0x777 << 1;
89
+
90
+ if (col == 0)
91
+ P &= 0xEEEE;
92
+ if (col + 1 >= W)
93
+ P &= 0x3333;
94
+ else if (col + 2 >= W)
95
+ P &= 0x7777;
96
+
97
+ if (row == 0)
98
+ P &= 0xFFF0;
99
+ if (row + 1 >= H)
100
+ P &= 0xFF;
101
+
102
+ if (P > 0) {
103
+ // If bit 0 of P is set, also check the top-left neighbour pixel and
104
+ // merge with the top-left block when that pixel is foreground
105
+ if (hasBit(P, 0) && img[idx - W - 1]) {
106
+ union_(label, idx, idx - 2 * W - 2); // top left block
107
+ }
108
+
109
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110
+ union_(label, idx, idx - 2 * W); // top bottom block
111
+
112
+ if (hasBit(P, 3) && img[idx + 2 - W])
113
+ union_(label, idx, idx - 2 * W + 2); // top right block
114
+
115
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116
+ union_(label, idx, idx - 2); // just left block
117
+ }
118
+ }
119
+
120
+ __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123
+ const uint32_t idx = row * W + col;
124
+
125
+ if (row < H && col < W)
126
+ find_n_compress(label, idx);
127
+ }
128
+
129
+ __global__ void final_labeling(
130
+ const uint8_t* img,
131
+ int32_t* label,
132
+ const int32_t W,
133
+ const int32_t H) {
134
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136
+ const uint32_t idx = row * W + col;
137
+
138
+ if (row >= H || col >= W)
139
+ return;
140
+
141
+ int32_t y = label[idx] + 1;
142
+
143
+ if (img[idx])
144
+ label[idx] = y;
145
+ else
146
+ label[idx] = 0;
147
+
148
+ if (col + 1 < W) {
149
+ if (img[idx + 1])
150
+ label[idx + 1] = y;
151
+ else
152
+ label[idx + 1] = 0;
153
+
154
+ if (row + 1 < H) {
155
+ if (img[idx + W + 1])
156
+ label[idx + W + 1] = y;
157
+ else
158
+ label[idx + W + 1] = 0;
159
+ }
160
+ }
161
+
162
+ if (row + 1 < H) {
163
+ if (img[idx + W])
164
+ label[idx + W] = y;
165
+ else
166
+ label[idx + W] = 0;
167
+ }
168
+ }
169
+
170
+ __global__ void init_counting(
171
+ const int32_t* label,
172
+ int32_t* count_init,
173
+ const int32_t W,
174
+ const int32_t H) {
175
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177
+ const uint32_t idx = row * W + col;
178
+
179
+ if (row >= H || col >= W)
180
+ return;
181
+
182
+ int32_t y = label[idx];
183
+ if (y > 0) {
184
+ int32_t count_idx = y - 1;
185
+ atomicAdd(count_init + count_idx, 1);
186
+ }
187
+ }
188
+
189
+ __global__ void final_counting(
190
+ const int32_t* label,
191
+ const int32_t* count_init,
192
+ int32_t* count_final,
193
+ const int32_t W,
194
+ const int32_t H) {
195
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197
+ const uint32_t idx = row * W + col;
198
+
199
+ if (row >= H || col >= W)
200
+ return;
201
+
202
+ int32_t y = label[idx];
203
+ if (y > 0) {
204
+ int32_t count_idx = y - 1;
205
+ count_final[idx] = count_init[count_idx];
206
+ } else {
207
+ count_final[idx] = 0;
208
+ }
209
+ }
210
+
211
+ } // namespace cc2d
212
+
213
+ std::vector<torch::Tensor> get_connected_componnets(
214
+ const torch::Tensor& inputs) {
215
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217
+ AT_ASSERTM(
218
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219
+
220
+ const uint32_t N = inputs.size(0);
221
+ const uint32_t C = inputs.size(1);
222
+ const uint32_t H = inputs.size(2);
223
+ const uint32_t W = inputs.size(3);
224
+
225
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
227
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
228
+
229
+ // label must be uint32_t
230
+ auto label_options =
231
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235
+
236
+ dim3 grid = dim3(
237
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240
+ dim3 grid_count =
241
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244
+
245
+ for (int n = 0; n < N; n++) {
246
+ uint32_t offset = n * H * W;
247
+
248
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
249
+ labels.data_ptr<int32_t>() + offset, W, H);
250
+ cc2d::merge<<<grid, block, 0, stream>>>(
251
+ inputs.data_ptr<uint8_t>() + offset,
252
+ labels.data_ptr<int32_t>() + offset,
253
+ W,
254
+ H);
255
+ cc2d::compression<<<grid, block, 0, stream>>>(
256
+ labels.data_ptr<int32_t>() + offset, W, H);
257
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
258
+ inputs.data_ptr<uint8_t>() + offset,
259
+ labels.data_ptr<int32_t>() + offset,
260
+ W,
261
+ H);
262
+
263
+ // get the counting of each pixel
264
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265
+ labels.data_ptr<int32_t>() + offset,
266
+ counts_init.data_ptr<int32_t>() + offset,
267
+ W,
268
+ H);
269
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270
+ labels.data_ptr<int32_t>() + offset,
271
+ counts_init.data_ptr<int32_t>() + offset,
272
+ counts_final.data_ptr<int32_t>() + offset,
273
+ W,
274
+ H);
275
+ }
276
+
277
+ // returned values are [labels, counts]
278
+ std::vector<torch::Tensor> outputs;
279
+ outputs.push_back(labels);
280
+ outputs.push_back(counts_final);
281
+ return outputs;
282
+ }
283
+
284
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285
+ m.def(
286
+ "get_connected_componnets",
287
+ &get_connected_componnets,
288
+ "get_connected_componnets");
289
+ }
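A minimal sketch of calling this kernel from Python once the extension has been compiled, assuming it is exposed as `sam2._C` (the module name is an assumption and depends on how the build is configured):

```python
# Sketch only: `sam2._C` is an assumed extension name.
import torch
from sam2 import _C

# Batched binary masks: uint8, shape [N, 1, H, W]; H and W must be even.
masks = (torch.rand(2, 1, 64, 64, device="cuda") > 0.5).to(torch.uint8)

labels, counts = _C.get_connected_componnets(masks)  # upstream spelling
# `labels` gives every foreground pixel the id of its connected component
# (background stays 0); `counts` gives, per pixel, the size of that component.
print(labels.shape, counts.shape)  # torch.Size([2, 1, 64, 64]) for both
```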
sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes)
 
sam2/modeling/__pycache__/memory_attention.cpython-312.pyc ADDED
Binary file (6.79 kB)
 
sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc ADDED
Binary file (7.82 kB)
 
sam2/modeling/__pycache__/position_encoding.cpython-312.pyc ADDED
Binary file (14.5 kB)
 
sam2/modeling/__pycache__/sam2_base.cpython-312.pyc ADDED
Binary file (30.6 kB)
 
sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc ADDED
Binary file (17.4 kB)
 
sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes)
 
sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc ADDED
Binary file (13.3 kB)
 
sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc ADDED
Binary file (5.47 kB)
 
sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.31 kB)
 
sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import partial
9
+ from typing import List, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from iopath.common.file_io import g_pathmgr
15
+
16
+ from sam2.modeling.backbones.utils import (
17
+ PatchEmbed,
18
+ window_partition,
19
+ window_unpartition,
20
+ )
21
+
22
+ from sam2.modeling.sam2_utils import DropPath, MLP
23
+
24
+
25
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26
+ if pool is None:
27
+ return x
28
+ # (B, H, W, C) -> (B, C, H, W)
29
+ x = x.permute(0, 3, 1, 2)
30
+ x = pool(x)
31
+ # (B, C, H', W') -> (B, H', W', C)
32
+ x = x.permute(0, 2, 3, 1)
33
+ if norm:
34
+ x = norm(x)
35
+
36
+ return x
37
+
38
+
39
+ class MultiScaleAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ dim_out: int,
44
+ num_heads: int,
45
+ q_pool: nn.Module = None,
46
+ ):
47
+ super().__init__()
48
+
49
+ self.dim = dim
50
+ self.dim_out = dim_out
51
+ self.num_heads = num_heads
52
+ self.q_pool = q_pool
53
+ self.qkv = nn.Linear(dim, dim_out * 3)
54
+ self.proj = nn.Linear(dim_out, dim_out)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ B, H, W, _ = x.shape
58
+ # qkv with shape (B, H * W, 3, nHead, C)
59
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60
+ # q, k, v with shape (B, H * W, nheads, C)
61
+ q, k, v = torch.unbind(qkv, 2)
62
+
63
+ # Q pooling (for downsample at stage changes)
64
+ if self.q_pool:
65
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66
+ H, W = q.shape[1:3] # downsampled shape
67
+ q = q.reshape(B, H * W, self.num_heads, -1)
68
+
69
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70
+ x = F.scaled_dot_product_attention(
71
+ q.transpose(1, 2),
72
+ k.transpose(1, 2),
73
+ v.transpose(1, 2),
74
+ )
75
+ # Transpose back
76
+ x = x.transpose(1, 2)
77
+ x = x.reshape(B, H, W, -1)
78
+
79
+ x = self.proj(x)
80
+
81
+ return x
82
+
83
+
84
+ class MultiScaleBlock(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ dim_out: int,
89
+ num_heads: int,
90
+ mlp_ratio: float = 4.0,
91
+ drop_path: float = 0.0,
92
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
93
+ q_stride: Tuple[int, int] = None,
94
+ act_layer: nn.Module = nn.GELU,
95
+ window_size: int = 0,
96
+ ):
97
+ super().__init__()
98
+
99
+ if isinstance(norm_layer, str):
100
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101
+
102
+ self.dim = dim
103
+ self.dim_out = dim_out
104
+ self.norm1 = norm_layer(dim)
105
+
106
+ self.window_size = window_size
107
+
108
+ self.pool, self.q_stride = None, q_stride
109
+ if self.q_stride:
110
+ self.pool = nn.MaxPool2d(
111
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
112
+ )
113
+
114
+ self.attn = MultiScaleAttention(
115
+ dim,
116
+ dim_out,
117
+ num_heads=num_heads,
118
+ q_pool=self.pool,
119
+ )
120
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121
+
122
+ self.norm2 = norm_layer(dim_out)
123
+ self.mlp = MLP(
124
+ dim_out,
125
+ int(dim_out * mlp_ratio),
126
+ dim_out,
127
+ num_layers=2,
128
+ activation=act_layer,
129
+ )
130
+
131
+ if dim != dim_out:
132
+ self.proj = nn.Linear(dim, dim_out)
133
+
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ shortcut = x # B, H, W, C
136
+ x = self.norm1(x)
137
+
138
+ # Skip connection
139
+ if self.dim != self.dim_out:
140
+ shortcut = do_pool(self.proj(x), self.pool)
141
+
142
+ # Window partition
143
+ window_size = self.window_size
144
+ if window_size > 0:
145
+ H, W = x.shape[1], x.shape[2]
146
+ x, pad_hw = window_partition(x, window_size)
147
+
148
+ # Window Attention + Q Pooling (if stage change)
149
+ x = self.attn(x)
150
+ if self.q_stride:
151
+ # Shapes have changed due to Q pooling
152
+ window_size = self.window_size // self.q_stride[0]
153
+ H, W = shortcut.shape[1:3]
154
+
155
+ pad_h = (window_size - H % window_size) % window_size
156
+ pad_w = (window_size - W % window_size) % window_size
157
+ pad_hw = (H + pad_h, W + pad_w)
158
+
159
+ # Reverse window partition
160
+ if self.window_size > 0:
161
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
162
+
163
+ x = shortcut + self.drop_path(x)
164
+ # MLP
165
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
166
+ return x
167
+
168
+
169
+ class Hiera(nn.Module):
170
+ """
171
+ Reference: https://arxiv.org/abs/2306.00989
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ embed_dim: int = 96, # initial embed dim
177
+ num_heads: int = 1, # initial number of heads
178
+ drop_path_rate: float = 0.0, # stochastic depth
179
+ q_pool: int = 3, # number of q_pool stages
180
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
183
+ head_mul: float = 2.0, # head_mul factor at stage shift
184
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185
+ # window size per stage, when not using global att.
186
+ window_spec: Tuple[int, ...] = (
187
+ 8,
188
+ 4,
189
+ 14,
190
+ 7,
191
+ ),
192
+ # global attn in these blocks
193
+ global_att_blocks: Tuple[int, ...] = (
194
+ 12,
195
+ 16,
196
+ 20,
197
+ ),
198
+ weights_path=None,
199
+ return_interm_layers=True, # return feats from every stage
200
+ ):
201
+ super().__init__()
202
+
203
+ assert len(stages) == len(window_spec)
204
+ self.window_spec = window_spec
205
+
206
+ depth = sum(stages)
207
+ self.q_stride = q_stride
208
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
210
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211
+ self.return_interm_layers = return_interm_layers
212
+
213
+ self.patch_embed = PatchEmbed(
214
+ embed_dim=embed_dim,
215
+ )
216
+ # Which blocks have global att?
217
+ self.global_att_blocks = global_att_blocks
218
+
219
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
+ self.pos_embed = nn.Parameter(
222
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
+ )
224
+ self.pos_embed_window = nn.Parameter(
225
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
+ )
227
+
228
+ dpr = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
+ ] # stochastic depth decay rule
231
+
232
+ cur_stage = 1
233
+ self.blocks = nn.ModuleList()
234
+
235
+ for i in range(depth):
236
+ dim_out = embed_dim
237
+ # lags by a block, so first block of
238
+ # next stage uses an initial window size
239
+ # of previous stage and final window size of current stage
240
+ window_size = self.window_spec[cur_stage - 1]
241
+
242
+ if self.global_att_blocks is not None:
243
+ window_size = 0 if i in self.global_att_blocks else window_size
244
+
245
+ if i - 1 in self.stage_ends:
246
+ dim_out = int(embed_dim * dim_mul)
247
+ num_heads = int(num_heads * head_mul)
248
+ cur_stage += 1
249
+
250
+ block = MultiScaleBlock(
251
+ dim=embed_dim,
252
+ dim_out=dim_out,
253
+ num_heads=num_heads,
254
+ drop_path=dpr[i],
255
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
256
+ window_size=window_size,
257
+ )
258
+
259
+ embed_dim = dim_out
260
+ self.blocks.append(block)
261
+
262
+ self.channel_list = (
263
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
+ if return_interm_layers
265
+ else [self.blocks[-1].dim_out]
266
+ )
267
+
268
+ if weights_path is not None:
269
+ with g_pathmgr.open(weights_path, "rb") as f:
270
+ chkpt = torch.load(f, map_location="cpu")
271
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
272
+
273
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274
+ h, w = hw
275
+ window_embed = self.pos_embed_window
276
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277
+ pos_embed = pos_embed + window_embed.tile(
278
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279
+ )
280
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
281
+ return pos_embed
282
+
283
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284
+ x = self.patch_embed(x)
285
+ # x: (B, H, W, C)
286
+
287
+ # Add pos embed
288
+ x = x + self._get_pos_embed(x.shape[1:3])
289
+
290
+ outputs = []
291
+ for i, blk in enumerate(self.blocks):
292
+ x = blk(x)
293
+ if (i == self.stage_ends[-1]) or (
294
+ i in self.stage_ends and self.return_interm_layers
295
+ ):
296
+ feats = x.permute(0, 3, 1, 2)
297
+ outputs.append(feats)
298
+
299
+ return outputs
300
+
301
+ def get_layer_id(self, layer_name):
302
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303
+ num_layers = self.get_num_layers()
304
+
305
+ if layer_name.find("rel_pos") != -1:
306
+ return num_layers + 1
307
+ elif layer_name.find("pos_embed") != -1:
308
+ return 0
309
+ elif layer_name.find("patch_embed") != -1:
310
+ return 0
311
+ elif layer_name.find("blocks") != -1:
312
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313
+ else:
314
+ return num_layers + 1
315
+
316
+ def get_num_layers(self) -> int:
317
+ return len(self.blocks)
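A minimal sketch instantiating the Hiera trunk with the tiny settings used in the configs above and inspecting the four multi-scale feature maps it returns for a 512x512 input:

```python
# Runs on CPU; mirrors the trunk settings from the configs above.
import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(
    embed_dim=96,
    num_heads=1,
    stages=(1, 2, 7, 2),
    global_att_blocks=(5, 7, 9),
    window_pos_embed_bkg_spatial_size=(7, 7),
)

x = torch.randn(1, 3, 512, 512)  # image_size: 512 in the configs
for feat in trunk(x):
    print(tuple(feat.shape))
# (1, 96, 128, 128), (1, 192, 64, 64), (1, 384, 32, 32), (1, 768, 16, 16)
# i.e. strides 4/8/16/32, matching backbone_channel_list in reverse order.
```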
sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param trunk: the backbone
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param neck_norm: the normalization to use
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ self.d_model = d_model
75
+ for dim in backbone_channel_list:
76
+ current = nn.Sequential()
77
+ current.add_module(
78
+ "conv",
79
+ nn.Conv2d(
80
+ in_channels=dim,
81
+ out_channels=d_model,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding,
85
+ ),
86
+ )
87
+
88
+ self.convs.append(current)
89
+ self.fpn_interp_model = fpn_interp_model
90
+ assert fuse_type in ["sum", "avg"]
91
+ self.fuse_type = fuse_type
92
+
93
+ # levels to have top-down features in its outputs
94
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95
+ # have top-down propagation, while outputs of level 0 and level 1 have only
96
+ # lateral features from the same backbone level.
97
+ if fpn_top_down_levels is None:
98
+ # default is to have top-down features on all levels
99
+ fpn_top_down_levels = range(len(self.convs))
100
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
101
+
102
+ def forward(self, xs: List[torch.Tensor]):
103
+
104
+ out = [None] * len(self.convs)
105
+ pos = [None] * len(self.convs)
106
+ assert len(xs) == len(self.convs)
107
+ # fpn forward pass
108
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109
+ prev_features = None
110
+ # forward in top-down order (from low to high resolution)
111
+ n = len(self.convs) - 1
112
+ for i in range(n, -1, -1):
113
+ x = xs[i]
114
+ lateral_features = self.convs[n - i](x)
115
+ if i in self.fpn_top_down_levels and prev_features is not None:
116
+ top_down_features = F.interpolate(
117
+ prev_features.to(dtype=torch.float32),
118
+ scale_factor=2.0,
119
+ mode=self.fpn_interp_model,
120
+ align_corners=(
121
+ None if self.fpn_interp_model == "nearest" else False
122
+ ),
123
+ antialias=False,
124
+ )
125
+ prev_features = lateral_features + top_down_features
126
+ if self.fuse_type == "avg":
127
+ prev_features /= 2
128
+ else:
129
+ prev_features = lateral_features
130
+ x_out = prev_features
131
+ out[i] = x_out
132
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133
+
134
+ return out, pos
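A minimal sketch of the `FpnNeck` as parameterized above, fed with dummy Hiera-tiny feature maps; levels 0 and 1 keep pure lateral features, while levels 2 and 3 receive top-down fusion:

```python
import torch
from sam2.modeling.backbones.image_encoder import FpnNeck
from sam2.modeling.position_encoding import PositionEmbeddingSine

neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],
    fpn_top_down_levels=[2, 3],
    fpn_interp_model="nearest",
)

xs = [
    torch.randn(1, 96, 128, 128),  # level 0, stride 4
    torch.randn(1, 192, 64, 64),   # level 1, stride 8
    torch.randn(1, 384, 32, 32),   # level 2, stride 16
    torch.randn(1, 768, 16, 16),   # level 3, stride 32
]
out, pos = neck(xs)
print([tuple(o.shape) for o in out])  # all projected to 256 channels
# In ImageEncoder, `scalp: 1` then discards the lowest-resolution level.
```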
sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = (
36
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ )
38
+ return windows, (Hp, Wp)
39
+
40
+
41
+ def window_unpartition(windows, window_size, pad_hw, hw):
42
+ """
43
+ Window unpartition into original sequences and removing padding.
44
+ Args:
45
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46
+ window_size (int): window size.
47
+ pad_hw (Tuple): padded height and width (Hp, Wp).
48
+ hw (Tuple): original height and width (H, W) before padding.
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(
56
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57
+ )
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ class PatchEmbed(nn.Module):
66
+ """
67
+ Image to Patch Embedding.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ kernel_size: Tuple[int, ...] = (7, 7),
73
+ stride: Tuple[int, ...] = (4, 4),
74
+ padding: Tuple[int, ...] = (3, 3),
75
+ in_chans: int = 3,
76
+ embed_dim: int = 768,
77
+ ):
78
+ """
79
+ Args:
80
+ kernel_size (Tuple): kernel size of the projection layer.
81
+ stride (Tuple): stride of the projection layer.
82
+ padding (Tuple): padding size of the projection layer.
83
+ in_chans (int): Number of input image channels.
84
+ embed_dim (int): Patch embedding dimension.
85
+ """
86
+ super().__init__()
87
+ self.proj = nn.Conv2d(
88
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ x = self.proj(x)
93
+ # B C H W -> B H W C
94
+ x = x.permute(0, 2, 3, 1)
95
+ return x
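A minimal sketch round-tripping a feature map through `window_partition` and `window_unpartition` to show the padding bookkeeping:

```python
import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 30, 96)              # (B, H, W, C), not divisible by 8
windows, (Hp, Wp) = window_partition(x, 8)  # padded to 32x32 -> 16 windows each
print(windows.shape, (Hp, Wp))              # torch.Size([32, 8, 8, 96]) (32, 32)

y = window_unpartition(windows, 8, (Hp, Wp), (30, 30))
print(torch.equal(x, y))                    # True: padding is stripped again
```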
sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
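A minimal sketch of the MemoryAttention stack's tensor layout (sequence-first in, sequence-first out), using a toy single-head attention module in place of `RoPEAttention`. The toy module, the 256-dim memory tokens, and all shapes are illustrative assumptions; the real configuration above uses 64-dim memory keys/values via `kv_in_dim` and RoPE over 32x32 feature grids:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer


class ToyAttention(nn.Module):
    """Single-head SDPA with the q/k/v call signature the layer expects."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


layer = MemoryAttentionLayer(
    activation="relu",
    cross_attention=ToyAttention(),
    d_model=256,
    dim_feedforward=2048,
    dropout=0.1,
    pos_enc_at_attn=False,
    pos_enc_at_cross_attn_keys=True,
    pos_enc_at_cross_attn_queries=False,
    self_attention=ToyAttention(),
)
attn = MemoryAttention(d_model=256, pos_enc_at_input=True, layer=layer, num_layers=2)

B, L_cur, L_mem, C = 1, 32 * 32, 4 * 32 * 32, 256
curr = torch.randn(L_cur, B, C)        # current-frame tokens (seq-first)
curr_pos = torch.randn(L_cur, B, C)
memory = torch.randn(L_mem, B, C)      # e.g. four past frames of 32x32 tokens
memory_pos = torch.randn(L_mem, B, C)

out = attn(curr=curr, memory=memory, curr_pos=curr_pos, memory_pos=memory_pos)
print(out.shape)  # torch.Size([1024, 1, 256])
```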
sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # sigmoid, so that less domain shift from gt masks which are bool
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,221 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Any, Optional, Tuple
+
+ import numpy as np
+
+ import torch
+ from torch import nn
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention Is All You Need paper, generalized to work on images.
+     """
+
+     def __init__(
+         self,
+         num_pos_feats,
+         temperature: int = 10000,
+         normalize: bool = True,
+         scale: Optional[float] = None,
+     ):
+         super().__init__()
+         assert num_pos_feats % 2 == 0, "Expecting even model width"
+         self.num_pos_feats = num_pos_feats // 2
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+         self.cache = {}
+
+     def _encode_xy(self, x, y):
+         # The positions are expected to be normalized
+         assert len(x) == len(y) and x.ndim == y.ndim == 1
+         x_embed = x * self.scale
+         y_embed = y * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, None] / dim_t
+         pos_y = y_embed[:, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+         ).flatten(1)
+         pos_y = torch.stack(
+             (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+         ).flatten(1)
+         return pos_x, pos_y
+
+     @torch.no_grad()
+     def encode_boxes(self, x, y, w, h):
+         pos_x, pos_y = self._encode_xy(x, y)
+         pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+         return pos
+
+     encode = encode_boxes  # Backwards compatibility
+
+     @torch.no_grad()
+     def encode_points(self, x, y, labels):
+         (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+         assert bx == by and nx == ny and bx == bl and nx == nl
+         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+         pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+         pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+         return pos
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor):
+         cache_key = (x.shape[-2], x.shape[-1])
+         if cache_key in self.cache:
+             return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+         y_embed = (
+             torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+             .view(1, -1, 1)
+             .repeat(x.shape[0], 1, x.shape[-1])
+         )
+         x_embed = (
+             torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+             .view(1, 1, -1)
+             .repeat(x.shape[0], x.shape[-2], 1)
+         )
+
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         self.cache[cache_key] = pos[0]
+         return pos
+
+
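
As a quick orientation for the class above: `PositionEmbeddingSine.forward` only reads the input's shape and device, returns a `(B, num_pos_feats, H, W)` sine/cosine encoding, and caches one copy per spatial size. A hypothetical usage sketch (not code from this commit):

```python
import torch
from sam2.modeling.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=256)  # must be even
feat = torch.zeros(2, 256, 32, 32)             # only the shape and device matter here
pos = pe(feat)
print(pos.shape)                               # torch.Size([2, 256, 32, 32])
_ = pe(feat)                                   # same (H, W): served from self.cache
```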
+ class PositionEmbeddingRandom(nn.Module):
+     """
+     Positional encoding using random spatial frequencies.
+     """
+
+     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+         super().__init__()
+         if scale is None or scale <= 0.0:
+             scale = 1.0
+         self.register_buffer(
+             "positional_encoding_gaussian_matrix",
+             scale * torch.randn((2, num_pos_feats)),
+         )
+
+     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+         """Positionally encode points that are normalized to [0,1]."""
+         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+         coords = 2 * coords - 1
+         coords = coords @ self.positional_encoding_gaussian_matrix
+         coords = 2 * np.pi * coords
+         # outputs d_1 x ... x d_n x C shape
+         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+         """Generate positional encoding for a grid of the specified size."""
+         h, w = size
+         device: Any = self.positional_encoding_gaussian_matrix.device
+         grid = torch.ones((h, w), device=device, dtype=torch.float32)
+         y_embed = grid.cumsum(dim=0) - 0.5
+         x_embed = grid.cumsum(dim=1) - 0.5
+         y_embed = y_embed / h
+         x_embed = x_embed / w
+
+         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+         return pe.permute(2, 0, 1)  # C x H x W
+
+     def forward_with_coords(
+         self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+     ) -> torch.Tensor:
+         """Positionally encode points that are not normalized to [0,1]."""
+         coords = coords_input.clone()
+         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+         return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
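
For the random-frequency variant above, `forward` builds a dense grid encoding of shape `(2 * num_pos_feats, H, W)`, while `forward_with_coords` encodes unnormalized point coordinates given the image size. A hypothetical usage sketch:

```python
import torch
from sam2.modeling.position_encoding import PositionEmbeddingRandom

pe = PositionEmbeddingRandom(num_pos_feats=128)

dense = pe((64, 64))                                      # (256, 64, 64) grid encoding
points = torch.tensor([[[100.0, 200.0], [300.0, 50.0]]])  # B x N x 2 pixel coordinates
sparse = pe.forward_with_coords(points, image_size=(512, 512))
print(dense.shape, sparse.shape)                          # (256, 64, 64) and (1, 2, 256)
```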
+ # Rotary Positional Encoding, adapted from:
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+ # 2. https://github.com/naver-ai/rope-vit
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+ def init_t_xy(end_x: int, end_y: int):
+     t = torch.arange(end_x * end_y, dtype=torch.float32)
+     t_x = (t % end_x).float()
+     t_y = torch.div(t, end_x, rounding_mode="floor").float()
+     return t_x, t_y
+
+
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+     t_x, t_y = init_t_xy(end_x, end_y)
+     freqs_x = torch.outer(t_x, freqs_x)
+     freqs_y = torch.outer(t_y, freqs_y)
+     freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+     freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+     return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+     shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_enc(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+     repeat_freqs_k: bool = False,
+ ):
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = (
+         torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+         if xk.shape[-2] != 0
+         else None
+     )
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     if xk_ is None:
+         # no keys to rotate, due to dropout
+         return xq_out.type_as(xq).to(xq.device), xk
+     # repeat freqs along seq_len dim to match k seq_len
+     if repeat_freqs_k:
+         r = xk_.shape[-2] // xq_.shape[-2]
+         if freqs_cis.is_cuda:
+             freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+         else:
+             # torch.repeat on complex numbers may not be supported on non-CUDA devices
+             # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+             freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
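
To make the axial RoPE helpers concrete: `compute_axial_cis` precomputes complex rotation factors for an `end_x` x `end_y` token grid, and `apply_rotary_enc` rotates query/key tensors whose last dimension is the per-head channel count. The shapes below are illustrative assumptions, not values taken from the model configs in this commit:

```python
import torch
from sam2.modeling.position_encoding import compute_axial_cis, apply_rotary_enc

head_dim, grid = 64, 8                         # per-head channels, 8x8 token grid
freqs_cis = compute_axial_cis(dim=head_dim, end_x=grid, end_y=grid)
print(freqs_cis.shape)                         # torch.Size([64, 32]), complex factors

xq = torch.randn(1, 2, grid * grid, head_dim)  # (batch, heads, tokens, head_dim)
xk = torch.randn(1, 2, grid * grid, head_dim)
xq_rot, xk_rot = apply_rotary_enc(xq, xk, freqs_cis=freqs_cis)
print(xq_rot.shape, xk_rot.shape)              # both unchanged from the inputs
```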
sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (150 Bytes). View file
 
sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc ADDED
Binary file (12.6 kB). View file
 
sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc ADDED
Binary file (9.44 kB). View file
 
sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (15.3 kB). View file