chongzhou committed on
Commit
9bc4638
·
1 Parent(s): e80adb9

Initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .gitattributes +30 -0
  2. CODE_OF_CONDUCT.md +80 -0
  3. CONTRIBUTING.md +31 -0
  4. LICENSE +201 -0
  5. README.md +14 -4
  6. app.py +499 -0
  7. checkpoints/download_ckpts.sh +59 -0
  8. checkpoints/edgetam.pt +3 -0
  9. convert_weights.py +36 -0
  10. examples/01_breakdancer.mp4 +3 -0
  11. examples/01_dog.mp4 +3 -0
  12. examples/02_cups.mp4 +3 -0
  13. examples/02_hummingbird.mp4 +3 -0
  14. examples/03_blocks.mp4 +3 -0
  15. examples/03_skateboarder.mp4 +3 -0
  16. examples/04_coffee.mp4 +3 -0
  17. examples/04_octopus.mp4 +3 -0
  18. examples/05_default_juggle.mp4 +3 -0
  19. examples/05_landing_dog_soccer.mp4 +3 -0
  20. examples/06_pingpong.mp4 +3 -0
  21. examples/07_snowboarder.mp4 +3 -0
  22. examples/08_driving.mp4 +3 -0
  23. examples/09_birdcartoon.mp4 +3 -0
  24. examples/10_cloth_magic.mp4 +3 -0
  25. examples/11_polevault.mp4 +3 -0
  26. examples/12_hideandseek.mp4 +3 -0
  27. examples/13_butterfly.mp4 +3 -0
  28. examples/14_social_dog_training.mp4 +3 -0
  29. examples/15_cricket.mp4 +3 -0
  30. examples/16_robotarm.mp4 +3 -0
  31. examples/17_childrendancing.mp4 +3 -0
  32. examples/18_threedogs.mp4 +3 -0
  33. examples/19_cyclist.mp4 +3 -0
  34. examples/20_doughkneading.mp4 +3 -0
  35. examples/21_biker.mp4 +3 -0
  36. examples/22_dogskateboarder.mp4 +3 -0
  37. examples/23_racecar.mp4 +3 -0
  38. examples/24_clownfish.mp4 +3 -0
  39. pyproject.toml +6 -0
  40. requirements.txt +15 -0
  41. sam2/__init__.py +11 -0
  42. sam2/automatic_mask_generator.py +454 -0
  43. sam2/build_sam.py +171 -0
  44. sam2/configs/edgetam.yaml +138 -0
  45. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  46. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  47. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  48. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  49. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  50. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
.gitattributes CHANGED
@@ -33,3 +33,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/04_octopus.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/05_default_juggle.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/10_cloth_magic.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ examples/14_social_dog_training.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ examples/18_threedogs.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ examples/01_dog.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ examples/09_birdcartoon.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ examples/04_coffee.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ examples/24_clownfish.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ examples/06_pingpong.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ examples/11_polevault.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ examples/17_childrendancing.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ examples/19_cyclist.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ examples/20_doughkneading.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ examples/02_hummingbird.mp4 filter=lfs diff=lfs merge=lfs -text
51
+ examples/03_blocks.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ examples/23_racecar.mp4 filter=lfs diff=lfs merge=lfs -text
53
+ examples/01_breakdancer.mp4 filter=lfs diff=lfs merge=lfs -text
54
+ examples/15_cricket.mp4 filter=lfs diff=lfs merge=lfs -text
55
+ examples/16_robotarm.mp4 filter=lfs diff=lfs merge=lfs -text
56
+ examples/02_cups.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ examples/08_driving.mp4 filter=lfs diff=lfs merge=lfs -text
58
+ examples/13_butterfly.mp4 filter=lfs diff=lfs merge=lfs -text
59
+ checkpoints/edgetam.pt filter=lfs diff=lfs merge=lfs -text
60
+ examples/05_landing_dog_soccer.mp4 filter=lfs diff=lfs merge=lfs -text
61
+ examples/07_snowboarder.mp4 filter=lfs diff=lfs merge=lfs -text
62
+ examples/12_hideandseek.mp4 filter=lfs diff=lfs merge=lfs -text
63
+ examples/21_biker.mp4 filter=lfs diff=lfs merge=lfs -text
64
+ examples/22_dogskateboarder.mp4 filter=lfs diff=lfs merge=lfs -text
65
+ examples/03_skateboarder.mp4 filter=lfs diff=lfs merge=lfs -text
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to EdgeTAM
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Meta's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to EdgeTAM, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,23 @@
1
  ---
2
  title: EdgeTAM
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.28.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: EdgeTAM
3
+ emoji: 🚀
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: On-Device Track Anything Model
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ ```
17
+ @misc{zhou2025edgetam,
18
+ title={EdgeTAM: On-Device Track Anything Model},
19
+ author={Zhou, Chong and Zhu, Chenchen and Xiong, Yunyang and Suri, Saksham and Xiao, Fanyi and Wu, Lemeng and Krishnamoorthi, Raghuraman and Dai, Bo and Loy, Chen Change and Chandra, Vikas and Soran, Bilge},
20
+ journal={arXiv preprint arXiv:2501.07256},
21
+ year={2025}
22
+ }
23
+ ```
app.py ADDED
@@ -0,0 +1,499 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ import os
9
+ from datetime import datetime
10
+
11
+ import gradio as gr
12
+
13
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
14
+ import tempfile
15
+
16
+ import cv2
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ import spaces
20
+ import torch
21
+
22
+ from moviepy.editor import ImageSequenceClip
23
+ from PIL import Image
24
+ from sam2.build_sam import build_sam2_video_predictor
25
+
26
+ # Description
27
+ title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
28
+
29
+ description_p = """# Instructions
30
+ <ol>
31
+ <li> Upload one video or click one example video</li>
32
+ <li> Click 'include' point type, select the object to segment and track</li>
33
+ <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
34
+ <li> Click the 'Track' button to obtain the masked video </li>
35
+ </ol>
36
+ """
37
+
38
+ # examples
39
+ examples = [
40
+ ["examples/01_dog.mp4"],
41
+ ["examples/02_cups.mp4"],
42
+ ["examples/03_blocks.mp4"],
43
+ ["examples/04_coffee.mp4"],
44
+ ["examples/05_default_juggle.mp4"],
45
+ ["examples/01_breakdancer.mp4"],
46
+ ["examples/02_hummingbird.mp4"],
47
+ ["examples/03_skateboarder.mp4"],
48
+ ["examples/04_octopus.mp4"],
49
+ ["examples/05_landing_dog_soccer.mp4"],
50
+ ["examples/06_pingpong.mp4"],
51
+ ["examples/07_snowboarder.mp4"],
52
+ ["examples/08_driving.mp4"],
53
+ ["examples/09_birdcartoon.mp4"],
54
+ ["examples/10_cloth_magic.mp4"],
55
+ ["examples/11_polevault.mp4"],
56
+ ["examples/12_hideandseek.mp4"],
57
+ ["examples/13_butterfly.mp4"],
58
+ ["examples/14_social_dog_training.mp4"],
59
+ ["examples/15_cricket.mp4"],
60
+ ["examples/16_robotarm.mp4"],
61
+ ["examples/17_childrendancing.mp4"],
62
+ ["examples/18_threedogs.mp4"],
63
+ ["examples/19_cyclist.mp4"],
64
+ ["examples/20_doughkneading.mp4"],
65
+ ["examples/21_biker.mp4"],
66
+ ["examples/22_dogskateboarder.mp4"],
67
+ ["examples/23_racecar.mp4"],
68
+ ["examples/24_clownfish.mp4"],
69
+ ]
70
+
71
+ OBJ_ID = 0
72
+
73
+
74
+ @spaces.GPU
75
+ def get_predictor(session_state):
76
+ if "predictor" not in session_state:
77
+ sam2_checkpoint = "checkpoints/edgetam.pt"
78
+ model_cfg = "edgetam.yaml"
79
+ predictor = build_sam2_video_predictor(
80
+ model_cfg, sam2_checkpoint, device="cuda"
81
+ )
82
+ print("predictor loaded")
83
+
84
+ # use bfloat16 for the entire demo
85
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
86
+ if torch.cuda.get_device_properties(0).major >= 8:
87
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
88
+ torch.backends.cuda.matmul.allow_tf32 = True
89
+ torch.backends.cudnn.allow_tf32 = True
90
+
91
+ session_state["predictor"] = predictor
92
+ return session_state["predictor"]
93
+
94
+
95
+ def get_video_fps(video_path):
96
+ # Open the video file
97
+ cap = cv2.VideoCapture(video_path)
98
+
99
+ if not cap.isOpened():
100
+ print("Error: Could not open video.")
101
+ return None
102
+
103
+ # Get the FPS of the video
104
+ fps = cap.get(cv2.CAP_PROP_FPS)
105
+
106
+ return fps
107
+
108
+
109
+ @spaces.GPU
110
+ def reset(session_state):
111
+ predictor = get_predictor(session_state)
112
+ predictor.to("cuda")
113
+ session_state["input_points"] = []
114
+ session_state["input_labels"] = []
115
+ if session_state["inference_state"] is not None:
116
+ predictor.reset_state(session_state["inference_state"])
117
+ session_state["first_frame"] = None
118
+ session_state["all_frames"] = None
119
+ session_state["inference_state"] = None
120
+ return (
121
+ None,
122
+ gr.update(open=True),
123
+ None,
124
+ None,
125
+ gr.update(value=None, visible=False),
126
+ session_state,
127
+ )
128
+
129
+
130
+ @spaces.GPU
131
+ def clear_points(session_state):
132
+ predictor = get_predictor(session_state)
133
+ predictor.to("cuda")
134
+ session_state["input_points"] = []
135
+ session_state["input_labels"] = []
136
+ if session_state["inference_state"]["tracking_has_started"]:
137
+ predictor.reset_state(session_state["inference_state"])
138
+ return (
139
+ session_state["first_frame"],
140
+ None,
141
+ gr.update(value=None, visible=False),
142
+ session_state,
143
+ )
144
+
145
+
146
+ @spaces.GPU
147
+ def preprocess_video_in(video_path, session_state):
148
+ predictor = get_predictor(session_state)
149
+ predictor.to("cuda")
150
+ if video_path is None:
151
+ return (
152
+ gr.update(open=True), # video_in_drawer
153
+ None, # points_map
154
+ None, # output_image
155
+ gr.update(value=None, visible=False), # output_video
156
+ session_state,
157
+ )
158
+
159
+ # Read the first frame
160
+ cap = cv2.VideoCapture(video_path)
161
+ if not cap.isOpened():
162
+ print("Error: Could not open video.")
163
+ return (
164
+ gr.update(open=True), # video_in_drawer
165
+ None, # points_map
166
+ None, # output_image
167
+ gr.update(value=None, visible=False), # output_video
168
+ session_state,
169
+ )
170
+
171
+ frame_number = 0
172
+ first_frame = None
173
+ all_frames = []
174
+
175
+ while True:
176
+ ret, frame = cap.read()
177
+ if not ret:
178
+ break
179
+
180
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
181
+ frame = np.array(frame)
182
+
183
+ # Store the first frame
184
+ if frame_number == 0:
185
+ first_frame = frame
186
+ all_frames.append(frame)
187
+
188
+ frame_number += 1
189
+
190
+ cap.release()
191
+ session_state["first_frame"] = copy.deepcopy(first_frame)
192
+ session_state["all_frames"] = all_frames
193
+
194
+ session_state["inference_state"] = predictor.init_state(video_path=video_path)
195
+ session_state["input_points"] = []
196
+ session_state["input_labels"] = []
197
+
198
+ return [
199
+ gr.update(open=False), # video_in_drawer
200
+ first_frame, # points_map
201
+ None, # output_image
202
+ gr.update(value=None, visible=False), # output_video
203
+ session_state,
204
+ ]
205
+
206
+
207
+ @spaces.GPU
208
+ def segment_with_points(
209
+ point_type,
210
+ session_state,
211
+ evt: gr.SelectData,
212
+ ):
213
+ predictor = get_predictor(session_state)
214
+ predictor.to("cuda")
215
+ session_state["input_points"].append(evt.index)
216
+ print(f"TRACKING INPUT POINT: {session_state['input_points']}")
217
+
218
+ if point_type == "include":
219
+ session_state["input_labels"].append(1)
220
+ elif point_type == "exclude":
221
+ session_state["input_labels"].append(0)
222
+ print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
223
+
224
+ # Open the image and get its dimensions
225
+ transparent_background = Image.fromarray(session_state["first_frame"]).convert(
226
+ "RGBA"
227
+ )
228
+ w, h = transparent_background.size
229
+
230
+ # Define the circle radius as a fraction of the smaller dimension
231
+ fraction = 0.01 # You can adjust this value as needed
232
+ radius = int(fraction * min(w, h))
233
+
234
+ # Create a transparent layer to draw on
235
+ transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
236
+
237
+ for index, track in enumerate(session_state["input_points"]):
238
+ if session_state["input_labels"][index] == 1:
239
+ cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
240
+ else:
241
+ cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
242
+
243
+ # Convert the transparent layer back to an image
244
+ transparent_layer = Image.fromarray(transparent_layer, "RGBA")
245
+ selected_point_map = Image.alpha_composite(
246
+ transparent_background, transparent_layer
247
+ )
248
+
249
+ # Let's add a positive click at (x, y) = (210, 350) to get started
250
+ points = np.array(session_state["input_points"], dtype=np.float32)
251
+ # for labels, `1` means positive click and `0` means negative click
252
+ labels = np.array(session_state["input_labels"], np.int32)
253
+ _, _, out_mask_logits = predictor.add_new_points(
254
+ inference_state=session_state["inference_state"],
255
+ frame_idx=0,
256
+ obj_id=OBJ_ID,
257
+ points=points,
258
+ labels=labels,
259
+ )
260
+
261
+ mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
262
+ first_frame_output = Image.alpha_composite(transparent_background, mask_image)
263
+
264
+ torch.cuda.empty_cache()
265
+ return selected_point_map, first_frame_output, session_state
266
+
267
+
268
+ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
269
+ if random_color:
270
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
271
+ else:
272
+ cmap = plt.get_cmap("tab10")
273
+ cmap_idx = 0 if obj_id is None else obj_id
274
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
275
+ h, w = mask.shape[-2:]
276
+ mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
277
+ mask = (mask * 255).astype(np.uint8)
278
+ if convert_to_image:
279
+ mask = Image.fromarray(mask, "RGBA")
280
+ return mask
281
+
282
+
283
+ @spaces.GPU
284
+ def propagate_to_all(
285
+ video_in,
286
+ session_state,
287
+ ):
288
+ predictor = get_predictor(session_state)
289
+ predictor.to("cuda")
290
+ if (
291
+ len(session_state["input_points"]) == 0
292
+ or video_in is None
293
+ or session_state["inference_state"] is None
294
+ ):
295
+ return (
296
+ None,
297
+ session_state,
298
+ )
299
+
300
+ # run propagation throughout the video and collect the results in a dict
301
+ video_segments = {} # video_segments contains the per-frame segmentation results
302
+ print("starting propagate_in_video")
303
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
304
+ session_state["inference_state"]
305
+ ):
306
+ video_segments[out_frame_idx] = {
307
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
308
+ for i, out_obj_id in enumerate(out_obj_ids)
309
+ }
310
+
311
+ # obtain the segmentation results every few frames
312
+ vis_frame_stride = 1
313
+
314
+ output_frames = []
315
+ for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
316
+ transparent_background = Image.fromarray(
317
+ session_state["all_frames"][out_frame_idx]
318
+ ).convert("RGBA")
319
+ out_mask = video_segments[out_frame_idx][OBJ_ID]
320
+ mask_image = show_mask(out_mask)
321
+ output_frame = Image.alpha_composite(transparent_background, mask_image)
322
+ output_frame = np.array(output_frame)
323
+ output_frames.append(output_frame)
324
+
325
+ torch.cuda.empty_cache()
326
+
327
+ # Create a video clip from the image sequence
328
+ original_fps = get_video_fps(video_in)
329
+ fps = original_fps # Frames per second
330
+ clip = ImageSequenceClip(output_frames, fps=fps)
331
+ # Write the result to a file
332
+ unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
333
+ final_vid_output_path = f"output_video_{unique_id}.mp4"
334
+ final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
335
+
336
+ # Write the result to a file
337
+ clip.write_videofile(final_vid_output_path, codec="libx264")
338
+
339
+ return (
340
+ gr.update(value=final_vid_output_path),
341
+ session_state,
342
+ )
343
+
344
+
345
+ def update_ui():
346
+ return gr.update(visible=True)
347
+
348
+
349
+ with gr.Blocks() as demo:
350
+ session_state = gr.State(
351
+ {
352
+ "first_frame": None,
353
+ "all_frames": None,
354
+ "input_points": [],
355
+ "input_labels": [],
356
+ "inference_state": None,
357
+ }
358
+ )
359
+
360
+ with gr.Column():
361
+ # Title
362
+ gr.Markdown(title)
363
+ with gr.Row():
364
+
365
+ with gr.Column():
366
+ # Instructions
367
+ gr.Markdown(description_p)
368
+
369
+ with gr.Accordion("Input Video", open=True) as video_in_drawer:
370
+ video_in = gr.Video(label="Input Video", format="mp4")
371
+
372
+ with gr.Row():
373
+ point_type = gr.Radio(
374
+ label="point type",
375
+ choices=["include", "exclude"],
376
+ value="include",
377
+ scale=2,
378
+ )
379
+ propagate_btn = gr.Button("Track", scale=1, variant="primary")
380
+ clear_points_btn = gr.Button("Clear Points", scale=1)
381
+ reset_btn = gr.Button("Reset", scale=1)
382
+
383
+ points_map = gr.Image(
384
+ label="Frame with Point Prompt", type="numpy", interactive=False
385
+ )
386
+
387
+ with gr.Column():
388
+ gr.Markdown("# Try some of the examples below ⬇️")
389
+ gr.Examples(
390
+ examples=examples,
391
+ inputs=[
392
+ video_in,
393
+ ],
394
+ examples_per_page=8,
395
+ )
396
+ gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
397
+ gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
398
+ gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
399
+ output_image = gr.Image(label="Reference Mask")
400
+
401
+ output_video = gr.Video(visible=False)
402
+
403
+ # When new video is uploaded
404
+ video_in.upload(
405
+ fn=preprocess_video_in,
406
+ inputs=[
407
+ video_in,
408
+ session_state,
409
+ ],
410
+ outputs=[
411
+ video_in_drawer, # Accordion to hide uploaded video player
412
+ points_map, # Image component where we add new tracking points
413
+ output_image,
414
+ output_video,
415
+ session_state,
416
+ ],
417
+ queue=False,
418
+ )
419
+
420
+ video_in.change(
421
+ fn=preprocess_video_in,
422
+ inputs=[
423
+ video_in,
424
+ session_state,
425
+ ],
426
+ outputs=[
427
+ video_in_drawer, # Accordion to hide uploaded video player
428
+ points_map, # Image component where we add new tracking points
429
+ output_image,
430
+ output_video,
431
+ session_state,
432
+ ],
433
+ queue=False,
434
+ )
435
+
436
+ # triggered when we click on image to add new points
437
+ points_map.select(
438
+ fn=segment_with_points,
439
+ inputs=[
440
+ point_type, # "include" or "exclude"
441
+ session_state,
442
+ ],
443
+ outputs=[
444
+ points_map, # updated image with points
445
+ output_image,
446
+ session_state,
447
+ ],
448
+ queue=False,
449
+ )
450
+
451
+ # Clear every points clicked and added to the map
452
+ clear_points_btn.click(
453
+ fn=clear_points,
454
+ inputs=session_state,
455
+ outputs=[
456
+ points_map,
457
+ output_image,
458
+ output_video,
459
+ session_state,
460
+ ],
461
+ queue=False,
462
+ )
463
+
464
+ reset_btn.click(
465
+ fn=reset,
466
+ inputs=session_state,
467
+ outputs=[
468
+ video_in,
469
+ video_in_drawer,
470
+ points_map,
471
+ output_image,
472
+ output_video,
473
+ session_state,
474
+ ],
475
+ queue=False,
476
+ )
477
+
478
+ propagate_btn.click(
479
+ fn=update_ui,
480
+ inputs=[],
481
+ outputs=output_video,
482
+ queue=False,
483
+ ).then(
484
+ fn=propagate_to_all,
485
+ inputs=[
486
+ video_in,
487
+ session_state,
488
+ ],
489
+ outputs=[
490
+ output_video,
491
+ session_state,
492
+ ],
493
+ concurrency_limit=10,
494
+ queue=False,
495
+ )
496
+
497
+
498
+ demo.queue()
499
+ demo.launch()
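
Taken together, app.py drives the video predictor through four calls: build_sam2_video_predictor, init_state, add_new_points, and propagate_in_video. Below is a condensed, hedged sketch of that flow with the Gradio wiring stripped out; it assumes a CUDA GPU and the repo's committed checkpoint, and the example video and click coordinates are only illustrative.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Paths as committed in this repo; a CUDA device is assumed (as in app.py).
predictor = build_sam2_video_predictor(
    "edgetam.yaml", "checkpoints/edgetam.pt", device="cuda"
)

# app.py runs the whole demo under bfloat16 autocast.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path="examples/01_dog.mp4")

    # One positive click on frame 0; label 1 = include, 0 = exclude.
    _, _, mask_logits = predictor.add_new_points(
        inference_state=state,
        frame_idx=0,
        obj_id=0,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], np.int32),
    )
    first_frame_mask = (mask_logits[0] > 0.0).cpu().numpy()

    # Propagate the prompt through the rest of the clip.
    video_segments = {}
    for frame_idx, obj_ids, logits in predictor.propagate_in_video(state):
        video_segments[frame_idx] = {
            obj_id: (logits[i] > 0.0).cpu().numpy()
            for i, obj_id in enumerate(obj_ids)
        }
```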
checkpoints/download_ckpts.sh ADDED
@@ -0,0 +1,59 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ # Use either wget or curl to download the checkpoints
10
+ if command -v wget &> /dev/null; then
11
+ CMD="wget"
12
+ elif command -v curl &> /dev/null; then
13
+ CMD="curl -L -O"
14
+ else
15
+ echo "Please install wget or curl to download the checkpoints."
16
+ exit 1
17
+ fi
18
+
19
+ # Define the URLs for SAM 2 checkpoints
20
+ # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
21
+ # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
22
+ # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
23
+ # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
24
+ # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
25
+
26
+ # Download each of the four checkpoints using wget
27
+ # echo "Downloading sam2_hiera_tiny.pt checkpoint..."
28
+ # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
29
+
30
+ # echo "Downloading sam2_hiera_small.pt checkpoint..."
31
+ # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
32
+
33
+ # echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
34
+ # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
35
+
36
+ # echo "Downloading sam2_hiera_large.pt checkpoint..."
37
+ # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
38
+
39
+ # Define the URLs for SAM 2.1 checkpoints
40
+ SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
41
+ sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
42
+ sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
43
+ sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
44
+ sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
45
+
46
+ # SAM 2.1 checkpoints
47
+ echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
48
+ $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
49
+
50
+ echo "Downloading sam2.1_hiera_small.pt checkpoint..."
51
+ $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
52
+
53
+ echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
54
+ $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
55
+
56
+ echo "Downloading sam2.1_hiera_large.pt checkpoint..."
57
+ $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
58
+
59
+ echo "All checkpoints are downloaded successfully."
checkpoints/edgetam.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed2d4850b8792c239689b043c47046ec239b6e808a3d9b6ae676c803fd8780df
3
+ size 56116523
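
The three lines above are a Git LFS pointer (spec URL, SHA-256 object id, and size, roughly 56 MB), not the weights themselves; the examples/*.mp4 entries further down follow the same pattern. After cloning, the actual binaries are typically fetched with `git lfs pull` (assuming git-lfs is installed and the paths are tracked in .gitattributes, as added above).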
convert_weights.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+
9
+ import torch
10
+
11
+
12
+ def main(args):
13
+ sd = torch.load(args.src, map_location="cpu")["model"]
14
+ sd = {k: v for k, v in sd.items() if "teacher" not in k}
15
+ sd = {
16
+ k.replace("backbone.vision_backbone", "image_encoder"): v for k, v in sd.items()
17
+ }
18
+ sd = {k.replace("mlp.fc1", "mlp.layers.0"): v for k, v in sd.items()}
19
+ sd = {k.replace("mlp.fc2", "mlp.layers.1"): v for k, v in sd.items()}
20
+ sd = {k.replace("convs", "neck.convs"): v for k, v in sd.items()}
21
+ sd = {
22
+ k.replace("transformer.encoder", "memory_attention"): v for k, v in sd.items()
23
+ }
24
+ sd = {k.replace("maskmem_backbone", "memory_encoder"): v for k, v in sd.items()}
25
+ sd = {k.replace("maskmem_backbone", "memory_encoder"): v for k, v in sd.items()}
26
+ sd = {k.replace("mlp.lin1", "mlp.layers.0"): v for k, v in sd.items()}
27
+ sd = {k.replace("mlp.lin2", "mlp.layers.1"): v for k, v in sd.items()}
28
+ torch.save({"model": sd}, args.src.replace(".pt", "_converted.pt"))
29
+
30
+
31
+ if __name__ == "__main__":
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument("--src", type=str, required=True)
34
+ args = parser.parse_args()
35
+
36
+ main(args)
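
As a hedged usage note: convert_weights.py only drops teacher weights and renames state-dict keys from the training-time layout (backbone.vision_backbone, mlp.fc1/fc2, mlp.lin1/lin2, transformer.encoder, maskmem_backbone) to the names the sam2 modules here expect (image_encoder, mlp.layers.0/1, memory_attention, memory_encoder). With a hypothetical input path, `python convert_weights.py --src edgetam_raw.pt` would write `edgetam_raw_converted.pt` next to the input.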
examples/01_breakdancer.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f6b15f61e741576f658e41939d0109f15f012691a62a86820418d3ae10d1f04
3
+ size 5251367
examples/01_dog.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21425eb9b7fce7d3fb23fa17b1c07bcbe99ba9e64e3b5347d8052ba6d9c38924
3
+ size 1738970
examples/02_cups.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6320759eede8c82c7d43fe1cc569f3a389758ec37b2b1dc9e9b7b4ce4110ab13
3
+ size 2543710
examples/02_hummingbird.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79018494f7cda9ea1fc57bc474fd468ba1291624d85a1b74312029013ed9eb1e
3
+ size 1098274
examples/03_blocks.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2868f2c62f5cfa28dabbe925fae95da7f896531630a0627e5edbc784ea42413d
3
+ size 1808832
examples/03_skateboarder.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8403b992e4c114cf01ccb56c95aeba58d38aac16a5d7611580c1ceb5ef0f3ca
3
+ size 2132383
examples/04_coffee.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9a439a8bf66910db3e2a9d9ea853bc24c760c8d4ca227f4c33051b1cef28c03
3
+ size 1177415
examples/04_octopus.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d705ebd61259495da51e254cba576b2362310033099173e7781624c14fc51e3d
3
+ size 4395411
examples/05_default_juggle.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d23773ad50389a5c2df9ba4ab05e9b4d39d21dfa7b75d15b46666ff7a4e2bb
3
+ size 1842699
examples/05_landing_dog_soccer.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65d828933e37694b47d8d22573d812a8b838b22b722a17af362a50daecfe92d9
3
+ size 2774490
examples/06_pingpong.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c5ba5fcd0abd07701f030bbdbeed937f56bd0950103a13dad664fc66a264f98
3
+ size 1341363
examples/07_snowboarder.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28108f8ae296c2b97b48768ab5efe556d57c3b631cb7d86be3b012a93fdc35ae
3
+ size 3579270
examples/08_driving.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6752358cdb184ce27e157ba9c7622dba0596ad8c9951c013ff2d027802079ae1
3
+ size 1353646
examples/09_birdcartoon.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89b3adfc1972b4a400fab3cc52da30a4bd6ee87feb7d3b5ee00e1030885f7f92
3
+ size 1839837
examples/10_cloth_magic.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:407ed3c96d070d4da764f4ae6a114b0ade5bce4fe8900a43c118ff2d636032c5
3
+ size 1377117
examples/11_polevault.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86e7ca662e5282796b1af115c76afb15d81d2ce33a15a12504c90b0831e53911
3
+ size 1712657
examples/12_hideandseek.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82fe69d56a9fa3c228e2dd4a4db05c36db36b75a2b132409400fb1f4af1102c9
3
+ size 2799452
examples/13_butterfly.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:284bb1cf4d6520d0ee1e7f3b0f100e38943410031c7f994a5f8fd95930ccdc02
3
+ size 2434416
examples/14_social_dog_training.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e364b566cfb8ef28f53615a15bf4079bdb1c6dada49828954706c8b6a96ac3f
3
+ size 1810664
examples/15_cricket.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9a7b88ba84dcc16defe2845f6c3b47cc4d2ef25e5752e87721b79ac3540e1e9
3
+ size 4976505
examples/16_robotarm.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f79112ba6937ddccc4be748d6a2ca64822ac85cdbbdeba7a7dc4e58ebfa2ce80
3
+ size 3575535
examples/17_childrendancing.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff42b6651f65431e967c810c24ebbaea1425524055735e932638e3843b42f55d
3
+ size 5722084
examples/18_threedogs.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76f5ad121c6e10b9ed6b8ea3453efe66c48952fabc964c3b1d9ab3fda0d8b5be
3
+ size 4254672
examples/19_cyclist.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c862441a39dd9c5ad823fc1709737208f923d3676cfa13d6567451b2d01f34f
3
+ size 2471696
examples/20_doughkneading.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1eaefb076968272b3e6c04def0ef8fd99524f53bfade63c58a54e4dc47fa8af9
3
+ size 1943984
examples/21_biker.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bd0799a3a5ae035f26c326baa2ae0906cb8ac57c3016cb4c0ca04959c442c08
3
+ size 1031840
examples/22_dogskateboarder.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e44764e451cce362511ff9f8774d98f5c2f971567397ec1def9d71f52767833
3
+ size 3948915
examples/23_racecar.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df58709b52635efdb09a0ce2a3a60fa2ac7111413c004d206639dca9e1487cd0
3
+ size 5147963
examples/24_clownfish.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e60c4972915e1c8537801e2805b7729338c4a122d759028ade59703d18972d81
3
+ size 4649052
pyproject.toml ADDED
@@ -0,0 +1,6 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=61.0",
4
+ "torch>=2.3.1",
5
+ ]
6
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch>=2.3.1
2
+ torchvision>=0.18.1
3
+ numpy>=1.24.4
4
+ tqdm>=4.66.1
5
+ hydra-core>=1.3.2
6
+ iopath>=0.1.10
7
+ pillow>=9.4.0
8
+ gradio==4.44.0
9
+ gradio_client==1.3.0
10
+ gradio_image_prompter==0.1.0
11
+ opencv-python==4.10.0.84
12
+ moviepy==1.0.3
13
+ pydantic==2.10.6
14
+ timm==1.0.15
15
+ eva-decord==0.6.1
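
A minimal local run, assuming a CUDA-capable machine (app.py hard-codes device="cuda"), would be `pip install -r requirements.txt` followed by `python app.py`; the pinned gradio==4.44.0 matches the sdk_version declared in the updated README.md. Note that app.py also imports `spaces`, which is not listed here and is presumably provided by the Hugging Face Spaces runtime.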
sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+ from hydra.core.global_hydra import GlobalHydra
9
+
10
+ if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("sam2", version_base="1.2")
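
This one-time guard registers the sam2 package as a Hydra config module (version_base 1.2), which is presumably what lets build_sam2_video_predictor in app.py resolve the model config by the bare name "edgetam.yaml" rather than by an absolute path.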
sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ **kwargs,
57
+ ) -> None:
58
+ """
59
+ Using a SAM 2 model, generates masks for the entire image.
60
+ Generates a grid of point prompts over the image, then filters
61
+ low quality and duplicate masks. The default settings are chosen
62
+ for SAM 2 with a HieraL backbone.
63
+
64
+ Arguments:
65
+ model (Sam): The SAM 2 model to use for mask prediction.
66
+ points_per_side (int or None): The number of points to be sampled
67
+ along one side of the image. The total number of points is
68
+ points_per_side**2. If None, 'point_grids' must provide explicit
69
+ point sampling.
70
+ points_per_batch (int): Sets the number of points run simultaneously
71
+ by the model. Higher numbers may be faster but use more GPU memory.
72
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
73
+ model's predicted mask quality.
74
+ stability_score_thresh (float): A filtering threshold in [0,1], using
75
+ the stability of the mask under changes to the cutoff used to binarize
76
+ the model's mask predictions.
77
+ stability_score_offset (float): The amount to shift the cutoff when
78
+ calculated the stability score.
79
+ mask_threshold (float): Threshold for binarizing the mask logits
80
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
81
+ suppression to filter duplicate masks.
82
+ crop_n_layers (int): If >0, mask prediction will be run again on
83
+ crops of the image. Sets the number of layers to run, where each
84
+ layer has 2**i_layer number of image crops.
85
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
86
+ suppression to filter duplicate masks between different crops.
87
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
88
+ In the first crop layer, crops will overlap by this fraction of
89
+ the image length. Later layers with more crops scale down this overlap.
90
+ crop_n_points_downscale_factor (int): The number of points-per-side
91
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
92
+ point_grids (list(np.ndarray) or None): A list over explicit grids
93
+ of points used for sampling, normalized to [0,1]. The nth grid in the
94
+ list is used in the nth crop layer. Exclusive with points_per_side.
95
+ min_mask_region_area (int): If >0, postprocessing will be applied
96
+ to remove disconnected regions and holes in masks with area smaller
97
+ than min_mask_region_area. Requires opencv.
98
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
99
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
100
+ For large resolutions, 'binary_mask' may consume large amounts of
101
+ memory.
102
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
103
+ multimask_output (bool): Whether to output multimask at each point of the grid.
104
+ """
105
+
106
+ assert (points_per_side is None) != (
107
+ point_grids is None
108
+ ), "Exactly one of points_per_side or point_grid must be provided."
109
+ if points_per_side is not None:
110
+ self.point_grids = build_all_layer_point_grids(
111
+ points_per_side,
112
+ crop_n_layers,
113
+ crop_n_points_downscale_factor,
114
+ )
115
+ elif point_grids is not None:
116
+ self.point_grids = point_grids
117
+ else:
118
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
119
+
120
+ assert output_mode in [
121
+ "binary_mask",
122
+ "uncompressed_rle",
123
+ "coco_rle",
124
+ ], f"Unknown output_mode {output_mode}."
125
+ if output_mode == "coco_rle":
126
+ try:
127
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
128
+ except ImportError as e:
129
+ print("Please install pycocotools")
130
+ raise e
131
+
132
+ self.predictor = SAM2ImagePredictor(
133
+ model,
134
+ max_hole_area=min_mask_region_area,
135
+ max_sprinkle_area=min_mask_region_area,
136
+ )
137
+ self.points_per_batch = points_per_batch
138
+ self.pred_iou_thresh = pred_iou_thresh
139
+ self.stability_score_thresh = stability_score_thresh
140
+ self.stability_score_offset = stability_score_offset
141
+ self.mask_threshold = mask_threshold
142
+ self.box_nms_thresh = box_nms_thresh
143
+ self.crop_n_layers = crop_n_layers
144
+ self.crop_nms_thresh = crop_nms_thresh
145
+ self.crop_overlap_ratio = crop_overlap_ratio
146
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
147
+ self.min_mask_region_area = min_mask_region_area
148
+ self.output_mode = output_mode
149
+ self.use_m2m = use_m2m
150
+ self.multimask_output = multimask_output
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
154
+ """
155
+ Load a pretrained model from the Hugging Face hub.
156
+
157
+ Arguments:
158
+ model_id (str): The Hugging Face repository ID.
159
+ **kwargs: Additional arguments to pass to the model constructor.
160
+
161
+ Returns:
162
+ (SAM2AutomaticMaskGenerator): The loaded model.
163
+ """
164
+ from sam2.build_sam import build_sam2_hf
165
+
166
+ sam_model = build_sam2_hf(model_id, **kwargs)
167
+ return cls(sam_model, **kwargs)
168
+
169
+ @torch.no_grad()
170
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
171
+ """
172
+ Generates masks for the given image.
173
+
174
+ Arguments:
175
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
176
+
177
+ Returns:
178
+ list(dict(str, any)): A list over records for masks. Each record is
179
+ a dict containing the following keys:
180
+ segmentation (dict(str, any) or np.ndarray): The mask. If
181
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
182
+ is a dictionary containing the RLE.
183
+ bbox (list(float)): The box around the mask, in XYWH format.
184
+ area (int): The area in pixels of the mask.
185
+ predicted_iou (float): The model's own prediction of the mask's
186
+ quality. This is filtered by the pred_iou_thresh parameter.
187
+ point_coords (list(list(float))): The point coordinates input
188
+ to the model to generate this mask.
189
+ stability_score (float): A measure of the mask's quality. This
190
+ is filtered using the stability_score_thresh parameter.
191
+ crop_box (list(float)): The crop of the image used to generate
192
+ the mask, given in XYWH format.
193
+ """
194
+
195
+ # Generate masks
196
+ mask_data = self._generate_masks(image)
197
+
198
+ # Encode masks
199
+ if self.output_mode == "coco_rle":
200
+ mask_data["segmentations"] = [
201
+ coco_encode_rle(rle) for rle in mask_data["rles"]
202
+ ]
203
+ elif self.output_mode == "binary_mask":
204
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
205
+ else:
206
+ mask_data["segmentations"] = mask_data["rles"]
207
+
208
+ # Write mask records
209
+ curr_anns = []
210
+ for idx in range(len(mask_data["segmentations"])):
211
+ ann = {
212
+ "segmentation": mask_data["segmentations"][idx],
213
+ "area": area_from_rle(mask_data["rles"][idx]),
214
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
215
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
216
+ "point_coords": [mask_data["points"][idx].tolist()],
217
+ "stability_score": mask_data["stability_score"][idx].item(),
218
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
219
+ }
220
+ curr_anns.append(ann)
221
+
222
+ return curr_anns
223
+
224
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
225
+ orig_size = image.shape[:2]
226
+ crop_boxes, layer_idxs = generate_crop_boxes(
227
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
228
+ )
229
+
230
+ # Iterate over image crops
231
+ data = MaskData()
232
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
233
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
234
+ data.cat(crop_data)
235
+
236
+ # Remove duplicate masks between crops
237
+ if len(crop_boxes) > 1:
238
+ # Prefer masks from smaller crops
239
+ scores = 1 / box_area(data["crop_boxes"])
240
+ scores = scores.to(data["boxes"].device)
241
+ keep_by_nms = batched_nms(
242
+ data["boxes"].float(),
243
+ scores,
244
+ torch.zeros_like(data["boxes"][:, 0]), # categories
245
+ iou_threshold=self.crop_nms_thresh,
246
+ )
247
+ data.filter(keep_by_nms)
248
+ data.to_numpy()
249
+ return data
250
+
251
+ def _process_crop(
252
+ self,
253
+ image: np.ndarray,
254
+ crop_box: List[int],
255
+ crop_layer_idx: int,
256
+ orig_size: Tuple[int, ...],
257
+ ) -> MaskData:
258
+ # Crop the image and calculate embeddings
259
+ x0, y0, x1, y1 = crop_box
260
+ cropped_im = image[y0:y1, x0:x1, :]
261
+ cropped_im_size = cropped_im.shape[:2]
262
+ self.predictor.set_image(cropped_im)
263
+
264
+ # Get points for this crop
265
+ points_scale = np.array(cropped_im_size)[None, ::-1]
266
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
267
+
268
+ # Generate masks for this crop in batches
269
+ data = MaskData()
270
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
271
+ batch_data = self._process_batch(
272
+ points, cropped_im_size, crop_box, orig_size, normalize=True
273
+ )
274
+ data.cat(batch_data)
275
+ del batch_data
276
+ self.predictor.reset_predictor()
277
+
278
+ # Remove duplicates within this crop.
279
+ keep_by_nms = batched_nms(
280
+ data["boxes"].float(),
281
+ data["iou_preds"],
282
+ torch.zeros_like(data["boxes"][:, 0]), # categories
283
+ iou_threshold=self.box_nms_thresh,
284
+ )
285
+ data.filter(keep_by_nms)
286
+
287
+ # Return to the original image frame
288
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
289
+ data["points"] = uncrop_points(data["points"], crop_box)
290
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
291
+
292
+ return data
293
+
294
+ def _process_batch(
295
+ self,
296
+ points: np.ndarray,
297
+ im_size: Tuple[int, ...],
298
+ crop_box: List[int],
299
+ orig_size: Tuple[int, ...],
300
+ normalize=False,
301
+ ) -> MaskData:
302
+ orig_h, orig_w = orig_size
303
+
304
+ # Run model on this batch
305
+ points = torch.as_tensor(
306
+ points, dtype=torch.float32, device=self.predictor.device
307
+ )
308
+ in_points = self.predictor._transforms.transform_coords(
309
+ points, normalize=normalize, orig_hw=im_size
310
+ )
311
+ in_labels = torch.ones(
312
+ in_points.shape[0], dtype=torch.int, device=in_points.device
313
+ )
314
+ masks, iou_preds, low_res_masks = self.predictor._predict(
315
+ in_points[:, None, :],
316
+ in_labels[:, None],
317
+ multimask_output=self.multimask_output,
318
+ return_logits=True,
319
+ )
320
+
321
+ # Serialize predictions and store in MaskData
322
+ data = MaskData(
323
+ masks=masks.flatten(0, 1),
324
+ iou_preds=iou_preds.flatten(0, 1),
325
+ points=points.repeat_interleave(masks.shape[1], dim=0),
326
+ low_res_masks=low_res_masks.flatten(0, 1),
327
+ )
328
+ del masks
329
+
330
+ if not self.use_m2m:
331
+ # Filter by predicted IoU
332
+ if self.pred_iou_thresh > 0.0:
333
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
334
+ data.filter(keep_mask)
335
+
336
+ # Calculate and filter by stability score
337
+ data["stability_score"] = calculate_stability_score(
338
+ data["masks"], self.mask_threshold, self.stability_score_offset
339
+ )
340
+ if self.stability_score_thresh > 0.0:
341
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
342
+ data.filter(keep_mask)
343
+ else:
344
+ # One step refinement using previous mask predictions
345
+ in_points = self.predictor._transforms.transform_coords(
346
+ data["points"], normalize=normalize, orig_hw=im_size
347
+ )
348
+ labels = torch.ones(
349
+ in_points.shape[0], dtype=torch.int, device=in_points.device
350
+ )
351
+ masks, ious = self.refine_with_m2m(
352
+ in_points, labels, data["low_res_masks"], self.points_per_batch
353
+ )
354
+ data["masks"] = masks.squeeze(1)
355
+ data["iou_preds"] = ious.squeeze(1)
356
+
357
+ if self.pred_iou_thresh > 0.0:
358
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
359
+ data.filter(keep_mask)
360
+
361
+ data["stability_score"] = calculate_stability_score(
362
+ data["masks"], self.mask_threshold, self.stability_score_offset
363
+ )
364
+ if self.stability_score_thresh > 0.0:
365
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
366
+ data.filter(keep_mask)
367
+
368
+ # Threshold masks and calculate boxes
369
+ data["masks"] = data["masks"] > self.mask_threshold
370
+ data["boxes"] = batched_mask_to_box(data["masks"])
371
+
372
+ # Filter boxes that touch crop boundaries
373
+ keep_mask = ~is_box_near_crop_edge(
374
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
375
+ )
376
+ if not torch.all(keep_mask):
377
+ data.filter(keep_mask)
378
+
379
+ # Compress to RLE
380
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
381
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
382
+ del data["masks"]
383
+
384
+ return data
385
+
386
+ @staticmethod
387
+ def postprocess_small_regions(
388
+ mask_data: MaskData, min_area: int, nms_thresh: float
389
+ ) -> MaskData:
390
+ """
391
+ Removes small disconnected regions and holes in masks, then reruns
392
+ box NMS to remove any new duplicates.
393
+
394
+ Edits mask_data in place.
395
+
396
+ Requires opencv as a dependency.
397
+ """
398
+ if len(mask_data["rles"]) == 0:
399
+ return mask_data
400
+
401
+ # Filter small disconnected regions and holes
402
+ new_masks = []
403
+ scores = []
404
+ for rle in mask_data["rles"]:
405
+ mask = rle_to_mask(rle)
406
+
407
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
408
+ unchanged = not changed
409
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
410
+ unchanged = unchanged and not changed
411
+
412
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
413
+ # Give score=0 to changed masks and score=1 to unchanged masks
414
+ # so NMS will prefer ones that didn't need postprocessing
415
+ scores.append(float(unchanged))
416
+
417
+ # Recalculate boxes and remove any new duplicates
418
+ masks = torch.cat(new_masks, dim=0)
419
+ boxes = batched_mask_to_box(masks)
420
+ keep_by_nms = batched_nms(
421
+ boxes.float(),
422
+ torch.as_tensor(scores),
423
+ torch.zeros_like(boxes[:, 0]), # categories
424
+ iou_threshold=nms_thresh,
425
+ )
426
+
427
+ # Only recalculate RLEs for masks that have changed
428
+ for i_mask in keep_by_nms:
429
+ if scores[i_mask] == 0.0:
430
+ mask_torch = masks[i_mask].unsqueeze(0)
431
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
432
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
433
+ mask_data.filter(keep_by_nms)
434
+
435
+ return mask_data
436
+
437
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
438
+ new_masks = []
439
+ new_iou_preds = []
440
+
441
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
442
+ points_per_batch, points, point_labels, low_res_masks
443
+ ):
444
+ best_masks, best_iou_preds, _ = self.predictor._predict(
445
+ cur_points[:, None, :],
446
+ cur_point_labels[:, None],
447
+ mask_input=low_res_mask[:, None, :],
448
+ multimask_output=False,
449
+ return_logits=True,
450
+ )
451
+ new_masks.append(best_masks)
452
+ new_iou_preds.append(best_iou_preds)
453
+ masks = torch.cat(new_masks, dim=0)
454
+ return masks, torch.cat(new_iou_preds, dim=0)
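To make the API above concrete, here is a minimal usage sketch for SAM2AutomaticMaskGenerator. The config name "configs/edgetam.yaml" and checkpoint "./checkpoints/edgetam.pt" are assumptions based on this repository's layout, the threshold values are illustrative rather than the defaults, and the record keys read at the end follow the generate() docstring.

# Hedged sketch: automatic mask generation with the EdgeTAM model shipped in this repo.
import numpy as np
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Paths below are assumptions from the repo layout (checkpoints/edgetam.pt, sam2/configs/edgetam.yaml).
model = build_sam2("configs/edgetam.yaml", ckpt_path="./checkpoints/edgetam.pt", device="cuda")
mask_generator = SAM2AutomaticMaskGenerator(
    model,
    points_per_side=32,          # 32**2 = 1024 point prompts over the image
    pred_iou_thresh=0.8,         # drop masks the model itself scores below 0.8
    stability_score_thresh=0.9,  # drop masks that are unstable under cutoff shifts
    output_mode="binary_mask",
)

image = np.array(Image.open("example.jpg").convert("RGB"))  # HWC uint8, as generate() expects
records = mask_generator.generate(image)

# Each record carries the keys documented in generate().
for rec in records[:3]:
    print(rec["bbox"], rec["area"], rec["predicted_iou"], rec["stability_score"])
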
sam2/build_sam.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import sam2
11
+
12
+ import torch
13
+ from hydra import compose
14
+ from hydra.utils import instantiate
15
+ from omegaconf import OmegaConf
16
+
17
+ # Check if the user is running Python from the parent directory of the sam2 repo
18
+ # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
+ # it could shadow the sam2 package and cause issues.
20
+ if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
+ # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
22
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
+ # This typically happens because the user is running Python from the parent directory
24
+ # that contains the sam2 repo they cloned.
25
+ raise RuntimeError(
26
+ "You're likely running Python from the parent directory of the sam2 repository "
27
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
+ "This is not supported since the `sam2` Python package could be shadowed by the "
29
+ "repository name (the repository is also named `sam2` and contains the Python package "
30
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
32
+ )
33
+
34
+
35
+ HF_MODEL_ID_TO_FILENAMES = {
36
+ "facebook/sam2-hiera-tiny": (
37
+ "configs/sam2/sam2_hiera_t.yaml",
38
+ "sam2_hiera_tiny.pt",
39
+ ),
40
+ "facebook/sam2-hiera-small": (
41
+ "configs/sam2/sam2_hiera_s.yaml",
42
+ "sam2_hiera_small.pt",
43
+ ),
44
+ "facebook/sam2-hiera-base-plus": (
45
+ "configs/sam2/sam2_hiera_b+.yaml",
46
+ "sam2_hiera_base_plus.pt",
47
+ ),
48
+ "facebook/sam2-hiera-large": (
49
+ "configs/sam2/sam2_hiera_l.yaml",
50
+ "sam2_hiera_large.pt",
51
+ ),
52
+ "facebook/sam2.1-hiera-tiny": (
53
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
54
+ "sam2.1_hiera_tiny.pt",
55
+ ),
56
+ "facebook/sam2.1-hiera-small": (
57
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
58
+ "sam2.1_hiera_small.pt",
59
+ ),
60
+ "facebook/sam2.1-hiera-base-plus": (
61
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
62
+ "sam2.1_hiera_base_plus.pt",
63
+ ),
64
+ "facebook/sam2.1-hiera-large": (
65
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
66
+ "sam2.1_hiera_large.pt",
67
+ ),
68
+ }
69
+
70
+
71
+ def build_sam2(
72
+ config_file,
73
+ ckpt_path=None,
74
+ device="cuda",
75
+ mode="eval",
76
+ hydra_overrides_extra=[],
77
+ apply_postprocessing=True,
78
+ **kwargs,
79
+ ):
80
+
81
+ if apply_postprocessing:
82
+ hydra_overrides_extra = hydra_overrides_extra.copy()
83
+ hydra_overrides_extra += [
84
+ # dynamically fall back to multi-mask if the single mask is not stable
85
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
+ ]
89
+ # Read config and init model
90
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
+ OmegaConf.resolve(cfg)
92
+ model = instantiate(cfg.model, _recursive_=True)
93
+ _load_checkpoint(model, ckpt_path)
94
+ model = model.to(device)
95
+ if mode == "eval":
96
+ model.eval()
97
+ return model
98
+
99
+
100
+ def build_sam2_video_predictor(
101
+ config_file,
102
+ ckpt_path=None,
103
+ device="cuda",
104
+ mode="eval",
105
+ hydra_overrides_extra=[],
106
+ apply_postprocessing=True,
107
+ **kwargs,
108
+ ):
109
+ hydra_overrides = [
110
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
111
+ ]
112
+ if apply_postprocessing:
113
+ hydra_overrides_extra = hydra_overrides_extra.copy()
114
+ hydra_overrides_extra += [
115
+ # dynamically fall back to multi-mask if the single mask is not stable
116
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
117
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
118
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
119
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly what users see from clicking
120
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
121
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
122
+ "++model.fill_hole_area=8",
123
+ ]
124
+ hydra_overrides.extend(hydra_overrides_extra)
125
+
126
+ # Read config and init model
127
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
128
+ OmegaConf.resolve(cfg)
129
+ print("configuration solved")
130
+ model = instantiate(cfg.model, _recursive_=True)
131
+ print("model instantiated")
132
+ _load_checkpoint(model, ckpt_path)
133
+ print("checkpoint loaded")
134
+ model = model.to(device)
135
+ if mode == "eval":
136
+ model.eval()
137
+ print("model ready")
138
+ return model
139
+
140
+
141
+ def _hf_download(model_id):
142
+ from huggingface_hub import hf_hub_download
143
+
144
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
145
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
146
+ return config_name, ckpt_path
147
+
148
+
149
+ def build_sam2_hf(model_id, **kwargs):
150
+ config_name, ckpt_path = _hf_download(model_id)
151
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
152
+
153
+
154
+ def build_sam2_video_predictor_hf(model_id, **kwargs):
155
+ config_name, ckpt_path = _hf_download(model_id)
156
+ return build_sam2_video_predictor(
157
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
158
+ )
159
+
160
+
161
+ def _load_checkpoint(model, ckpt_path):
162
+ if ckpt_path is not None:
163
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
164
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
165
+ if missing_keys:
166
+ logging.error(missing_keys)
167
+ raise RuntimeError()
168
+ if unexpected_keys:
169
+ logging.error(unexpected_keys)
170
+ raise RuntimeError()
171
+ logging.info("Loaded checkpoint successfully")
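A hedged sketch of how the three builders above are called. The local EdgeTAM paths are assumptions from this repository's layout, while the Hugging Face model id comes straight from HF_MODEL_ID_TO_FILENAMES.

# Sketch only; the exact local config/checkpoint names are assumptions.
from sam2.build_sam import build_sam2, build_sam2_hf, build_sam2_video_predictor

# 1) Local config + checkpoint (Hydra resolves the config name against the sam2 package).
edgetam = build_sam2(
    config_file="configs/edgetam.yaml",
    ckpt_path="./checkpoints/edgetam.pt",
    device="cuda",
    apply_postprocessing=True,  # appends the dynamic multimask-stability overrides above
)

# 2) A SAM 2.1 model from the Hugging Face hub: the id is looked up in
#    HF_MODEL_ID_TO_FILENAMES and the checkpoint is fetched via hf_hub_download.
sam21_large = build_sam2_hf("facebook/sam2.1-hiera-large", device="cuda")

# 3) The video predictor swaps the target class to SAM2VideoPredictor via a Hydra override.
predictor = build_sam2_video_predictor(
    config_file="configs/edgetam.yaml",
    ckpt_path="./checkpoints/edgetam.pt",
    device="cuda",
)
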
sam2/configs/edgetam.yaml ADDED
@@ -0,0 +1,138 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.timm.TimmBackbone
11
+ name: repvit_m1.dist_in1k
12
+ features:
13
+ - layer0
14
+ - layer1
15
+ - layer2
16
+ - layer3
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [384, 192, 96, 48]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttentionv2
53
+ rope_theta: 10000.0
54
+ q_sizes: [64, 64]
55
+ k_sizes: [16, 16]
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 2
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ spatial_perceiver:
89
+ _target_: sam2.modeling.perceiver.PerceiverResampler
90
+ depth: 2
91
+ dim: 64
92
+ dim_head: 64
93
+ heads: 1
94
+ ff_mult: 4
95
+ hidden_dropout_p: 0.
96
+ attention_dropout_p: 0.
97
+ pos_enc_at_key_value: true # implicit pos
98
+ concat_kv_latents: false
99
+ num_latents: 256
100
+ num_latents_2d: 256
101
+ position_encoding:
102
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
103
+ num_pos_feats: 64
104
+ normalize: true
105
+ scale: null
106
+ temperature: 10000
107
+ use_self_attn: true
108
+
109
+ num_maskmem: 7
110
+ image_size: 1024
111
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
112
+ sigmoid_scale_for_mem_enc: 20.0
113
+ sigmoid_bias_for_mem_enc: -10.0
114
+ use_mask_input_as_output_without_sam: true
115
+ # Memory
116
+ directly_add_no_mem_embed: true
117
+ # use high-resolution feature map in the SAM mask decoder
118
+ use_high_res_features_in_sam: true
119
+ # output 3 masks on the first click on initial conditioning frames
120
+ multimask_output_in_sam: true
121
+ # SAM heads
122
+ iou_prediction_use_sigmoid: True
123
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
124
+ use_obj_ptrs_in_encoder: true
125
+ add_tpos_enc_to_obj_ptrs: false
126
+ only_obj_ptrs_in_the_past_for_eval: true
127
+ # object occlusion prediction
128
+ pred_obj_scores: true
129
+ pred_obj_scores_mlp: true
130
+ fixed_no_obj_ptr: true
131
+ # multimask tracking settings
132
+ multimask_output_for_tracking: true
133
+ use_multimask_token_for_obj_ptr: true
134
+ multimask_min_pt_num: 0
135
+ multimask_max_pt_num: 1
136
+ use_mlp_for_obj_ptr_proj: true
137
+ # Compilation flag
138
+ compile_image_encoder: false
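For reference, a config like the one above is turned into a model object exactly the way build_sam2 does it: compose by name, optionally append "++model...." overrides, resolve, and instantiate the model node. A small sketch, assuming the sam2 package has already registered its config module with Hydra on import (the override strings are copied from build_sam2's apply_postprocessing branch):

# Hedged sketch of Hydra consuming this config, mirroring build_sam2().
import sam2  # noqa: F401  (importing the package registers its Hydra config module)
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

overrides = [
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
cfg = compose(config_name="configs/edgetam.yaml", overrides=overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)  # SAM2Base with the RepViT (timm) trunk
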
sam2/configs/sam2.1/sam2.1_hiera_b+.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ no_obj_embed_spatial: true
93
+ # use high-resolution feature map in the SAM mask decoder
94
+ use_high_res_features_in_sam: true
95
+ # output 3 masks on the first click on initial conditioning frames
96
+ multimask_output_in_sam: true
97
+ # SAM heads
98
+ iou_prediction_use_sigmoid: True
99
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100
+ use_obj_ptrs_in_encoder: true
101
+ add_tpos_enc_to_obj_ptrs: true
102
+ proj_tpos_enc_in_obj_ptrs: true
103
+ use_signed_tpos_enc_to_obj_ptrs: true
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
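One relationship worth noting in this config, and in the large/small/tiny variants that follow, is that backbone_channel_list is simply the trunk's embed_dim scaled by 8/4/2/1: the channel width doubles at each Hiera stage, and the list is ordered deepest level first. A quick arithmetic check against the values used in these files:

# backbone_channel_list = embed_dim * (8, 4, 2, 1), deepest FPN level first.
def hiera_channel_list(embed_dim: int) -> list:
    return [embed_dim * m for m in (8, 4, 2, 1)]

assert hiera_channel_list(112) == [896, 448, 224, 112]   # base-plus (this file)
assert hiera_channel_list(144) == [1152, 576, 288, 144]  # large
assert hiera_channel_list(96) == [768, 384, 192, 96]     # small and tiny
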
sam2/configs/sam2.1/sam2.1_hiera_l.yaml ADDED
@@ -0,0 +1,120 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ compile_image_encoder: False
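The large, small, and tiny variants differ mainly in the trunk's embed_dim, stages, and global_att_blocks. Since the global-attention indices refer to blocks inside the trunk, they must fall below the total block count, which is the sum of stages. A short sanity-check sketch over the values in these three configs (the small and tiny files follow below):

# Global attention block indices vs. total depth (sum of stages) per variant.
variants = {
    "hiera_l": ([2, 6, 36, 4], [23, 33, 43]),
    "hiera_s": ([1, 2, 11, 2], [7, 10, 13]),
    "hiera_t": ([1, 2, 7, 2], [5, 7, 9]),
}
for name, (stages, global_blocks) in variants.items():
    depth = sum(stages)
    assert all(0 <= b < depth for b in global_blocks), name
    print(f"{name}: {depth} blocks, global attention at {global_blocks}")
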
sam2/configs/sam2.1/sam2.1_hiera_s.yaml ADDED
@@ -0,0 +1,119 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ no_obj_embed_spatial: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: true
105
+ proj_tpos_enc_in_obj_ptrs: true
106
+ use_signed_tpos_enc_to_obj_ptrs: true
107
+ only_obj_ptrs_in_the_past_for_eval: true
108
+ # object occlusion prediction
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ # multimask tracking settings
113
+ multimask_output_for_tracking: true
114
+ use_multimask_token_for_obj_ptr: true
115
+ multimask_min_pt_num: 0
116
+ multimask_max_pt_num: 1
117
+ use_mlp_for_obj_ptr_proj: true
118
+ # Compilation flag
119
+ compile_image_encoder: False
sam2/configs/sam2.1/sam2.1_hiera_t.yaml ADDED
@@ -0,0 +1,121 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ # HieraT does not currently support compilation, should always be set to False
121
+ compile_image_encoder: False
sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml ADDED
@@ -0,0 +1,339 @@
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 8
8
+ max_num_objects: 3
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 40
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: null # PATH to MOSE JPEGImages folder
17
+ gt_folder: null # PATH to MOSE Annotations folder
18
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
+ multiplier: 2
20
+
21
+ # Video transforms
22
+ vos:
23
+ train_transforms:
24
+ - _target_: training.dataset.transforms.ComposeAPI
25
+ transforms:
26
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
27
+ consistent_transform: True
28
+ - _target_: training.dataset.transforms.RandomAffine
29
+ degrees: 25
30
+ shear: 20
31
+ image_interpolation: bilinear
32
+ consistent_transform: True
33
+ - _target_: training.dataset.transforms.RandomResizeAPI
34
+ sizes: ${scratch.resolution}
35
+ square: true
36
+ consistent_transform: True
37
+ - _target_: training.dataset.transforms.ColorJitter
38
+ consistent_transform: True
39
+ brightness: 0.1
40
+ contrast: 0.03
41
+ saturation: 0.03
42
+ hue: null
43
+ - _target_: training.dataset.transforms.RandomGrayscale
44
+ p: 0.05
45
+ consistent_transform: True
46
+ - _target_: training.dataset.transforms.ColorJitter
47
+ consistent_transform: False
48
+ brightness: 0.1
49
+ contrast: 0.05
50
+ saturation: 0.05
51
+ hue: null
52
+ - _target_: training.dataset.transforms.ToTensorAPI
53
+ - _target_: training.dataset.transforms.NormalizeAPI
54
+ mean: [0.485, 0.456, 0.406]
55
+ std: [0.229, 0.224, 0.225]
56
+
57
+ trainer:
58
+ _target_: training.trainer.Trainer
59
+ mode: train_only
60
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
61
+ accelerator: cuda
62
+ seed_value: 123
63
+
64
+ model:
65
+ _target_: training.model.sam2.SAM2Train
66
+ image_encoder:
67
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
+ scalp: 1
69
+ trunk:
70
+ _target_: sam2.modeling.backbones.hieradet.Hiera
71
+ embed_dim: 112
72
+ num_heads: 2
73
+ drop_path_rate: 0.1
74
+ neck:
75
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
+ position_encoding:
77
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
+ num_pos_feats: 256
79
+ normalize: true
80
+ scale: null
81
+ temperature: 10000
82
+ d_model: 256
83
+ backbone_channel_list: [896, 448, 224, 112]
84
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
85
+ fpn_interp_model: nearest
86
+
87
+ memory_attention:
88
+ _target_: sam2.modeling.memory_attention.MemoryAttention
89
+ d_model: 256
90
+ pos_enc_at_input: true
91
+ layer:
92
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
+ activation: relu
94
+ dim_feedforward: 2048
95
+ dropout: 0.1
96
+ pos_enc_at_attn: false
97
+ self_attention:
98
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
99
+ rope_theta: 10000.0
100
+ feat_sizes: [32, 32]
101
+ embedding_dim: 256
102
+ num_heads: 1
103
+ downsample_rate: 1
104
+ dropout: 0.1
105
+ d_model: 256
106
+ pos_enc_at_cross_attn_keys: true
107
+ pos_enc_at_cross_attn_queries: false
108
+ cross_attention:
109
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
110
+ rope_theta: 10000.0
111
+ feat_sizes: [32, 32]
112
+ rope_k_repeat: True
113
+ embedding_dim: 256
114
+ num_heads: 1
115
+ downsample_rate: 1
116
+ dropout: 0.1
117
+ kv_in_dim: 64
118
+ num_layers: 4
119
+
120
+ memory_encoder:
121
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
+ out_dim: 64
123
+ position_encoding:
124
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
+ num_pos_feats: 64
126
+ normalize: true
127
+ scale: null
128
+ temperature: 10000
129
+ mask_downsampler:
130
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
+ kernel_size: 3
132
+ stride: 2
133
+ padding: 1
134
+ fuser:
135
+ _target_: sam2.modeling.memory_encoder.Fuser
136
+ layer:
137
+ _target_: sam2.modeling.memory_encoder.CXBlock
138
+ dim: 256
139
+ kernel_size: 7
140
+ padding: 3
141
+ layer_scale_init_value: 1e-6
142
+ use_dwconv: True # depth-wise convs
143
+ num_layers: 2
144
+
145
+ num_maskmem: 7
146
+ image_size: ${scratch.resolution}
147
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
148
+ sigmoid_scale_for_mem_enc: 20.0
149
+ sigmoid_bias_for_mem_enc: -10.0
150
+ use_mask_input_as_output_without_sam: true
151
+ # Memory
152
+ directly_add_no_mem_embed: true
153
+ no_obj_embed_spatial: true
154
+ # use high-resolution feature map in the SAM mask decoder
155
+ use_high_res_features_in_sam: true
156
+ # output 3 masks on the first click on initial conditioning frames
157
+ multimask_output_in_sam: true
158
+ # SAM heads
159
+ iou_prediction_use_sigmoid: True
160
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
161
+ use_obj_ptrs_in_encoder: true
162
+ add_tpos_enc_to_obj_ptrs: true
163
+ proj_tpos_enc_in_obj_ptrs: true
164
+ use_signed_tpos_enc_to_obj_ptrs: true
165
+ only_obj_ptrs_in_the_past_for_eval: true
166
+ # object occlusion prediction
167
+ pred_obj_scores: true
168
+ pred_obj_scores_mlp: true
169
+ fixed_no_obj_ptr: true
170
+ # multimask tracking settings
171
+ multimask_output_for_tracking: true
172
+ use_multimask_token_for_obj_ptr: true
173
+ multimask_min_pt_num: 0
174
+ multimask_max_pt_num: 1
175
+ use_mlp_for_obj_ptr_proj: true
176
+ # Compilation flag
177
+ # compile_image_encoder: False
178
+
179
+ ####### Training specific params #######
180
+ # box/point input and corrections
181
+ prob_to_use_pt_input_for_train: 0.5
182
+ prob_to_use_pt_input_for_eval: 0.0
183
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
184
+ prob_to_use_box_input_for_eval: 0.0
185
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
186
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
187
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
188
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
189
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
190
+ # maximum 2 initial conditioning frames
191
+ num_init_cond_frames_for_train: 2
192
+ rand_init_cond_frames_for_train: True # random 1~2
193
+ num_correction_pt_per_frame: 7
194
+ use_act_ckpt_iterative_pt_sampling: false
195
+
196
+
197
+
198
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
199
+ forward_backbone_per_frame_for_eval: True
200
+
201
+
202
+ data:
203
+ train:
204
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
205
+ phases_per_epoch: ${scratch.phases_per_epoch}
206
+ batch_sizes:
207
+ - ${scratch.train_batch_size}
208
+
209
+ datasets:
210
+ - _target_: training.dataset.utils.RepeatFactorWrapper
211
+ dataset:
212
+ _target_: training.dataset.utils.ConcatDataset
213
+ datasets:
214
+ - _target_: training.dataset.vos_dataset.VOSDataset
215
+ transforms: ${vos.train_transforms}
216
+ training: true
217
+ video_dataset:
218
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
219
+ img_folder: ${dataset.img_folder}
220
+ gt_folder: ${dataset.gt_folder}
221
+ file_list_txt: ${dataset.file_list_txt}
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: ${dataset.multiplier}
227
+ shuffle: True
228
+ num_workers: ${scratch.num_train_workers}
229
+ pin_memory: True
230
+ drop_last: True
231
+ collate_fn:
232
+ _target_: training.utils.data_utils.collate_fn
233
+ _partial_: true
234
+ dict_key: all
235
+
236
+ optim:
237
+ amp:
238
+ enabled: True
239
+ amp_dtype: bfloat16
240
+
241
+ optimizer:
242
+ _target_: torch.optim.AdamW
243
+
244
+ gradient_clip:
245
+ _target_: training.optimizer.GradientClipper
246
+ max_norm: 0.1
247
+ norm_type: 2
248
+
249
+ param_group_modifiers:
250
+ - _target_: training.optimizer.layer_decay_param_modifier
251
+ _partial_: True
252
+ layer_decay_value: 0.9
253
+ apply_to: 'image_encoder.trunk'
254
+ overrides:
255
+ - pattern: '*pos_embed*'
256
+ value: 1.0
257
+
258
+ options:
259
+ lr:
260
+ - scheduler:
261
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
262
+ start_value: ${scratch.base_lr}
263
+ end_value: ${divide:${scratch.base_lr},10}
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.vision_lr}
267
+ end_value: ${divide:${scratch.vision_lr},10}
268
+ param_names:
269
+ - 'image_encoder.*'
270
+ weight_decay:
271
+ - scheduler:
272
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
273
+ value: 0.1
274
+ - scheduler:
275
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
276
+ value: 0.0
277
+ param_names:
278
+ - '*bias*'
279
+ module_cls_names: ['torch.nn.LayerNorm']
280
+
281
+ loss:
282
+ all:
283
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
284
+ weight_dict:
285
+ loss_mask: 20
286
+ loss_dice: 1
287
+ loss_iou: 1
288
+ loss_class: 1
289
+ supervise_all_iou: true
290
+ iou_use_l1_loss: true
291
+ pred_obj_scores: true
292
+ focal_gamma_obj_score: 0.0
293
+ focal_alpha_obj_score: -1.0
294
+
295
+ distributed:
296
+ backend: nccl
297
+ find_unused_parameters: True
298
+
299
+ logging:
300
+ tensorboard_writer:
301
+ _target_: training.utils.logger.make_tensorboard_logger
302
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
303
+ flush_secs: 120
304
+ should_log: True
305
+ log_dir: ${launcher.experiment_log_dir}/logs
306
+ log_freq: 10
307
+
308
+ # initialize from a SAM 2 checkpoint
309
+ checkpoint:
310
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
311
+ save_freq: 0 # 0 only last checkpoint is saved.
312
+ model_weight_initializer:
313
+ _partial_: True
314
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
315
+ strict: True
316
+ ignore_unexpected_keys: null
317
+ ignore_missing_keys: null
318
+
319
+ state_dict:
320
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
+ ckpt_state_dict_keys: ['model']
323
+
324
+ launcher:
325
+ num_nodes: 1
326
+ gpus_per_node: 8
327
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
328
+
329
+ # SLURM args if running on a cluster
330
+ submitit:
331
+ partition: null
332
+ account: null
333
+ qos: null
334
+ cpus_per_task: 10
335
+ use_cluster: false
336
+ timeout_hour: 24
337
+ name: null
338
+ port_range: [10000, 65000]
339
+
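To make the schedule arithmetic in this training config concrete: the trainer runs num_epochs * phases_per_epoch epochs, both cosine learning-rate schedules decay to one tenth of their start values via the ${divide:...} resolver, and the probability of a box prompt is the product of the two 0.5 probabilities noted in the comments. A small sketch of the resolved numbers (the times/divide OmegaConf resolvers themselves are assumed to be registered by the training launcher):

# Values implied by the scratch/optim sections above.
num_epochs, phases_per_epoch = 40, 1
base_lr, vision_lr = 5.0e-6, 3.0e-06

max_epochs = num_epochs * phases_per_epoch  # ${times:...} -> 40
base_lr_end = base_lr / 10                  # ${divide:...} -> 5.0e-07
vision_lr_end = vision_lr / 10              # -> 3.0e-07
p_box_prompt = 0.5 * 0.5                    # pt-input prob * box-given-pt prob = 0.25

print(max_epochs, base_lr_end, vision_lr_end, p_box_prompt)
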
sam2/configs/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False