Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +30 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +201 -0
- README.md +14 -4
- app.py +499 -0
- checkpoints/download_ckpts.sh +59 -0
- checkpoints/edgetam.pt +3 -0
- convert_weights.py +36 -0
- examples/01_breakdancer.mp4 +3 -0
- examples/01_dog.mp4 +3 -0
- examples/02_cups.mp4 +3 -0
- examples/02_hummingbird.mp4 +3 -0
- examples/03_blocks.mp4 +3 -0
- examples/03_skateboarder.mp4 +3 -0
- examples/04_coffee.mp4 +3 -0
- examples/04_octopus.mp4 +3 -0
- examples/05_default_juggle.mp4 +3 -0
- examples/05_landing_dog_soccer.mp4 +3 -0
- examples/06_pingpong.mp4 +3 -0
- examples/07_snowboarder.mp4 +3 -0
- examples/08_driving.mp4 +3 -0
- examples/09_birdcartoon.mp4 +3 -0
- examples/10_cloth_magic.mp4 +3 -0
- examples/11_polevault.mp4 +3 -0
- examples/12_hideandseek.mp4 +3 -0
- examples/13_butterfly.mp4 +3 -0
- examples/14_social_dog_training.mp4 +3 -0
- examples/15_cricket.mp4 +3 -0
- examples/16_robotarm.mp4 +3 -0
- examples/17_childrendancing.mp4 +3 -0
- examples/18_threedogs.mp4 +3 -0
- examples/19_cyclist.mp4 +3 -0
- examples/20_doughkneading.mp4 +3 -0
- examples/21_biker.mp4 +3 -0
- examples/22_dogskateboarder.mp4 +3 -0
- examples/23_racecar.mp4 +3 -0
- examples/24_clownfish.mp4 +3 -0
- pyproject.toml +6 -0
- requirements.txt +15 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/build_sam.py +171 -0
- sam2/configs/edgetam.yaml +138 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
.gitattributes
CHANGED
@@ -33,3 +33,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/04_octopus.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/05_default_juggle.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
examples/10_cloth_magic.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
examples/14_social_dog_training.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
examples/18_threedogs.mp4 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
examples/01_dog.mp4 filter=lfs diff=lfs merge=lfs -text
|
42 |
+
examples/09_birdcartoon.mp4 filter=lfs diff=lfs merge=lfs -text
|
43 |
+
examples/04_coffee.mp4 filter=lfs diff=lfs merge=lfs -text
|
44 |
+
examples/24_clownfish.mp4 filter=lfs diff=lfs merge=lfs -text
|
45 |
+
examples/06_pingpong.mp4 filter=lfs diff=lfs merge=lfs -text
|
46 |
+
examples/11_polevault.mp4 filter=lfs diff=lfs merge=lfs -text
|
47 |
+
examples/17_childrendancing.mp4 filter=lfs diff=lfs merge=lfs -text
|
48 |
+
examples/19_cyclist.mp4 filter=lfs diff=lfs merge=lfs -text
|
49 |
+
examples/20_doughkneading.mp4 filter=lfs diff=lfs merge=lfs -text
|
50 |
+
examples/02_hummingbird.mp4 filter=lfs diff=lfs merge=lfs -text
|
51 |
+
examples/03_blocks.mp4 filter=lfs diff=lfs merge=lfs -text
|
52 |
+
examples/23_racecar.mp4 filter=lfs diff=lfs merge=lfs -text
|
53 |
+
examples/01_breakdancer.mp4 filter=lfs diff=lfs merge=lfs -text
|
54 |
+
examples/15_cricket.mp4 filter=lfs diff=lfs merge=lfs -text
|
55 |
+
examples/16_robotarm.mp4 filter=lfs diff=lfs merge=lfs -text
|
56 |
+
examples/02_cups.mp4 filter=lfs diff=lfs merge=lfs -text
|
57 |
+
examples/08_driving.mp4 filter=lfs diff=lfs merge=lfs -text
|
58 |
+
examples/13_butterfly.mp4 filter=lfs diff=lfs merge=lfs -text
|
59 |
+
checkpoints/edgetam.pt filter=lfs diff=lfs merge=lfs -text
|
60 |
+
examples/05_landing_dog_soccer.mp4 filter=lfs diff=lfs merge=lfs -text
|
61 |
+
examples/07_snowboarder.mp4 filter=lfs diff=lfs merge=lfs -text
|
62 |
+
examples/12_hideandseek.mp4 filter=lfs diff=lfs merge=lfs -text
|
63 |
+
examples/21_biker.mp4 filter=lfs diff=lfs merge=lfs -text
|
64 |
+
examples/22_dogskateboarder.mp4 filter=lfs diff=lfs merge=lfs -text
|
65 |
+
examples/03_skateboarder.mp4 filter=lfs diff=lfs merge=lfs -text
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
+
the project or its community.
|
58 |
+
|
59 |
+
## Enforcement
|
60 |
+
|
61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
+
reported by contacting the project team at <[email protected]>. All
|
63 |
+
complaints will be reviewed and investigated and will result in a response that
|
64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
+
Further details of specific enforcement policies may be posted separately.
|
67 |
+
|
68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
+
faith may face temporary or permanent repercussions as determined by other
|
70 |
+
members of the project's leadership.
|
71 |
+
|
72 |
+
## Attribution
|
73 |
+
|
74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
+
|
77 |
+
[homepage]: https://www.contributor-covenant.org
|
78 |
+
|
79 |
+
For answers to common questions about this code of conduct, see
|
80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to EdgeTAM
|
2 |
+
We want to make contributing to this project as easy and transparent as
|
3 |
+
possible.
|
4 |
+
|
5 |
+
## Pull Requests
|
6 |
+
We actively welcome your pull requests.
|
7 |
+
|
8 |
+
1. Fork the repo and create your branch from `main`.
|
9 |
+
2. If you've added code that should be tested, add tests.
|
10 |
+
3. If you've changed APIs, update the documentation.
|
11 |
+
4. Ensure the test suite passes.
|
12 |
+
5. Make sure your code lints.
|
13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
14 |
+
|
15 |
+
## Contributor License Agreement ("CLA")
|
16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
17 |
+
to do this once to work on any of Meta's open source projects.
|
18 |
+
|
19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
20 |
+
|
21 |
+
## Issues
|
22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
24 |
+
|
25 |
+
Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
|
26 |
+
disclosure of security bugs. In those cases, please go through the process
|
27 |
+
outlined on that page and do not file a public issue.
|
28 |
+
|
29 |
+
## License
|
30 |
+
By contributing to EdgeTAM, you agree that your contributions will be licensed
|
31 |
+
under the LICENSE file in the root directory of this source tree.
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,23 @@
|
|
1 |
---
|
2 |
title: EdgeTAM
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: EdgeTAM
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.44.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
short_description: On-Device Track Anything Model
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
15 |
+
|
16 |
+
```
|
17 |
+
@misc{zhou2025edgetam,
|
18 |
+
title={EdgeTAM: On-Device Track Anything Model},
|
19 |
+
author={Zhou, Chong and Zhu, Chenchen and Xiong, Yunyang and Suri, Saksham and Xiao, Fanyi and Wu, Lemeng and Krishnamoorthi, Raghuraman and Dai, Bo and Loy, Chen Change and Chandra, Vikas and Soran, Bilge},
|
20 |
+
journal={arXiv preprint arXiv:2501.07256},
|
21 |
+
year={2025}
|
22 |
+
}
|
23 |
+
```
|
app.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import copy
|
8 |
+
import os
|
9 |
+
from datetime import datetime
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
|
13 |
+
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
|
14 |
+
import tempfile
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
import numpy as np
|
19 |
+
import spaces
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from moviepy.editor import ImageSequenceClip
|
23 |
+
from PIL import Image
|
24 |
+
from sam2.build_sam import build_sam2_video_predictor
|
25 |
+
|
26 |
+
# Description
|
27 |
+
title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
|
28 |
+
|
29 |
+
description_p = """# Instructions
|
30 |
+
<ol>
|
31 |
+
<li> Upload one video or click one example video</li>
|
32 |
+
<li> Click 'include' point type, select the object to segment and track</li>
|
33 |
+
<li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
|
34 |
+
<li> Click the 'Track' button to obtain the masked video </li>
|
35 |
+
</ol>
|
36 |
+
"""
|
37 |
+
|
38 |
+
# examples
|
39 |
+
examples = [
|
40 |
+
["examples/01_dog.mp4"],
|
41 |
+
["examples/02_cups.mp4"],
|
42 |
+
["examples/03_blocks.mp4"],
|
43 |
+
["examples/04_coffee.mp4"],
|
44 |
+
["examples/05_default_juggle.mp4"],
|
45 |
+
["examples/01_breakdancer.mp4"],
|
46 |
+
["examples/02_hummingbird.mp4"],
|
47 |
+
["examples/03_skateboarder.mp4"],
|
48 |
+
["examples/04_octopus.mp4"],
|
49 |
+
["examples/05_landing_dog_soccer.mp4"],
|
50 |
+
["examples/06_pingpong.mp4"],
|
51 |
+
["examples/07_snowboarder.mp4"],
|
52 |
+
["examples/08_driving.mp4"],
|
53 |
+
["examples/09_birdcartoon.mp4"],
|
54 |
+
["examples/10_cloth_magic.mp4"],
|
55 |
+
["examples/11_polevault.mp4"],
|
56 |
+
["examples/12_hideandseek.mp4"],
|
57 |
+
["examples/13_butterfly.mp4"],
|
58 |
+
["examples/14_social_dog_training.mp4"],
|
59 |
+
["examples/15_cricket.mp4"],
|
60 |
+
["examples/16_robotarm.mp4"],
|
61 |
+
["examples/17_childrendancing.mp4"],
|
62 |
+
["examples/18_threedogs.mp4"],
|
63 |
+
["examples/19_cyclist.mp4"],
|
64 |
+
["examples/20_doughkneading.mp4"],
|
65 |
+
["examples/21_biker.mp4"],
|
66 |
+
["examples/22_dogskateboarder.mp4"],
|
67 |
+
["examples/23_racecar.mp4"],
|
68 |
+
["examples/24_clownfish.mp4"],
|
69 |
+
]
|
70 |
+
|
71 |
+
OBJ_ID = 0
|
72 |
+
|
73 |
+
|
74 |
+
@spaces.GPU
|
75 |
+
def get_predictor(session_state):
|
76 |
+
if "predictor" not in session_state:
|
77 |
+
sam2_checkpoint = "checkpoints/edgetam.pt"
|
78 |
+
model_cfg = "edgetam.yaml"
|
79 |
+
predictor = build_sam2_video_predictor(
|
80 |
+
model_cfg, sam2_checkpoint, device="cuda"
|
81 |
+
)
|
82 |
+
print("predictor loaded")
|
83 |
+
|
84 |
+
# use bfloat16 for the entire demo
|
85 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
86 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
87 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
88 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
89 |
+
torch.backends.cudnn.allow_tf32 = True
|
90 |
+
|
91 |
+
session_state["predictor"] = predictor
|
92 |
+
return session_state["predictor"]
|
93 |
+
|
94 |
+
|
95 |
+
def get_video_fps(video_path):
|
96 |
+
# Open the video file
|
97 |
+
cap = cv2.VideoCapture(video_path)
|
98 |
+
|
99 |
+
if not cap.isOpened():
|
100 |
+
print("Error: Could not open video.")
|
101 |
+
return None
|
102 |
+
|
103 |
+
# Get the FPS of the video
|
104 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
105 |
+
|
106 |
+
return fps
|
107 |
+
|
108 |
+
|
109 |
+
@spaces.GPU
|
110 |
+
def reset(session_state):
|
111 |
+
predictor = get_predictor(session_state)
|
112 |
+
predictor.to("cuda")
|
113 |
+
session_state["input_points"] = []
|
114 |
+
session_state["input_labels"] = []
|
115 |
+
if session_state["inference_state"] is not None:
|
116 |
+
predictor.reset_state(session_state["inference_state"])
|
117 |
+
session_state["first_frame"] = None
|
118 |
+
session_state["all_frames"] = None
|
119 |
+
session_state["inference_state"] = None
|
120 |
+
return (
|
121 |
+
None,
|
122 |
+
gr.update(open=True),
|
123 |
+
None,
|
124 |
+
None,
|
125 |
+
gr.update(value=None, visible=False),
|
126 |
+
session_state,
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
@spaces.GPU
|
131 |
+
def clear_points(session_state):
|
132 |
+
predictor = get_predictor(session_state)
|
133 |
+
predictor.to("cuda")
|
134 |
+
session_state["input_points"] = []
|
135 |
+
session_state["input_labels"] = []
|
136 |
+
if session_state["inference_state"]["tracking_has_started"]:
|
137 |
+
predictor.reset_state(session_state["inference_state"])
|
138 |
+
return (
|
139 |
+
session_state["first_frame"],
|
140 |
+
None,
|
141 |
+
gr.update(value=None, visible=False),
|
142 |
+
session_state,
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
@spaces.GPU
|
147 |
+
def preprocess_video_in(video_path, session_state):
|
148 |
+
predictor = get_predictor(session_state)
|
149 |
+
predictor.to("cuda")
|
150 |
+
if video_path is None:
|
151 |
+
return (
|
152 |
+
gr.update(open=True), # video_in_drawer
|
153 |
+
None, # points_map
|
154 |
+
None, # output_image
|
155 |
+
gr.update(value=None, visible=False), # output_video
|
156 |
+
session_state,
|
157 |
+
)
|
158 |
+
|
159 |
+
# Read the first frame
|
160 |
+
cap = cv2.VideoCapture(video_path)
|
161 |
+
if not cap.isOpened():
|
162 |
+
print("Error: Could not open video.")
|
163 |
+
return (
|
164 |
+
gr.update(open=True), # video_in_drawer
|
165 |
+
None, # points_map
|
166 |
+
None, # output_image
|
167 |
+
gr.update(value=None, visible=False), # output_video
|
168 |
+
session_state,
|
169 |
+
)
|
170 |
+
|
171 |
+
frame_number = 0
|
172 |
+
first_frame = None
|
173 |
+
all_frames = []
|
174 |
+
|
175 |
+
while True:
|
176 |
+
ret, frame = cap.read()
|
177 |
+
if not ret:
|
178 |
+
break
|
179 |
+
|
180 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
181 |
+
frame = np.array(frame)
|
182 |
+
|
183 |
+
# Store the first frame
|
184 |
+
if frame_number == 0:
|
185 |
+
first_frame = frame
|
186 |
+
all_frames.append(frame)
|
187 |
+
|
188 |
+
frame_number += 1
|
189 |
+
|
190 |
+
cap.release()
|
191 |
+
session_state["first_frame"] = copy.deepcopy(first_frame)
|
192 |
+
session_state["all_frames"] = all_frames
|
193 |
+
|
194 |
+
session_state["inference_state"] = predictor.init_state(video_path=video_path)
|
195 |
+
session_state["input_points"] = []
|
196 |
+
session_state["input_labels"] = []
|
197 |
+
|
198 |
+
return [
|
199 |
+
gr.update(open=False), # video_in_drawer
|
200 |
+
first_frame, # points_map
|
201 |
+
None, # output_image
|
202 |
+
gr.update(value=None, visible=False), # output_video
|
203 |
+
session_state,
|
204 |
+
]
|
205 |
+
|
206 |
+
|
207 |
+
@spaces.GPU
|
208 |
+
def segment_with_points(
|
209 |
+
point_type,
|
210 |
+
session_state,
|
211 |
+
evt: gr.SelectData,
|
212 |
+
):
|
213 |
+
predictor = get_predictor(session_state)
|
214 |
+
predictor.to("cuda")
|
215 |
+
session_state["input_points"].append(evt.index)
|
216 |
+
print(f"TRACKING INPUT POINT: {session_state['input_points']}")
|
217 |
+
|
218 |
+
if point_type == "include":
|
219 |
+
session_state["input_labels"].append(1)
|
220 |
+
elif point_type == "exclude":
|
221 |
+
session_state["input_labels"].append(0)
|
222 |
+
print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
|
223 |
+
|
224 |
+
# Open the image and get its dimensions
|
225 |
+
transparent_background = Image.fromarray(session_state["first_frame"]).convert(
|
226 |
+
"RGBA"
|
227 |
+
)
|
228 |
+
w, h = transparent_background.size
|
229 |
+
|
230 |
+
# Define the circle radius as a fraction of the smaller dimension
|
231 |
+
fraction = 0.01 # You can adjust this value as needed
|
232 |
+
radius = int(fraction * min(w, h))
|
233 |
+
|
234 |
+
# Create a transparent layer to draw on
|
235 |
+
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
236 |
+
|
237 |
+
for index, track in enumerate(session_state["input_points"]):
|
238 |
+
if session_state["input_labels"][index] == 1:
|
239 |
+
cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
|
240 |
+
else:
|
241 |
+
cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
|
242 |
+
|
243 |
+
# Convert the transparent layer back to an image
|
244 |
+
transparent_layer = Image.fromarray(transparent_layer, "RGBA")
|
245 |
+
selected_point_map = Image.alpha_composite(
|
246 |
+
transparent_background, transparent_layer
|
247 |
+
)
|
248 |
+
|
249 |
+
# Let's add a positive click at (x, y) = (210, 350) to get started
|
250 |
+
points = np.array(session_state["input_points"], dtype=np.float32)
|
251 |
+
# for labels, `1` means positive click and `0` means negative click
|
252 |
+
labels = np.array(session_state["input_labels"], np.int32)
|
253 |
+
_, _, out_mask_logits = predictor.add_new_points(
|
254 |
+
inference_state=session_state["inference_state"],
|
255 |
+
frame_idx=0,
|
256 |
+
obj_id=OBJ_ID,
|
257 |
+
points=points,
|
258 |
+
labels=labels,
|
259 |
+
)
|
260 |
+
|
261 |
+
mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
|
262 |
+
first_frame_output = Image.alpha_composite(transparent_background, mask_image)
|
263 |
+
|
264 |
+
torch.cuda.empty_cache()
|
265 |
+
return selected_point_map, first_frame_output, session_state
|
266 |
+
|
267 |
+
|
268 |
+
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
|
269 |
+
if random_color:
|
270 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
271 |
+
else:
|
272 |
+
cmap = plt.get_cmap("tab10")
|
273 |
+
cmap_idx = 0 if obj_id is None else obj_id
|
274 |
+
color = np.array([*cmap(cmap_idx)[:3], 0.6])
|
275 |
+
h, w = mask.shape[-2:]
|
276 |
+
mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
277 |
+
mask = (mask * 255).astype(np.uint8)
|
278 |
+
if convert_to_image:
|
279 |
+
mask = Image.fromarray(mask, "RGBA")
|
280 |
+
return mask
|
281 |
+
|
282 |
+
|
283 |
+
@spaces.GPU
|
284 |
+
def propagate_to_all(
|
285 |
+
video_in,
|
286 |
+
session_state,
|
287 |
+
):
|
288 |
+
predictor = get_predictor(session_state)
|
289 |
+
predictor.to("cuda")
|
290 |
+
if (
|
291 |
+
len(session_state["input_points"]) == 0
|
292 |
+
or video_in is None
|
293 |
+
or session_state["inference_state"] is None
|
294 |
+
):
|
295 |
+
return (
|
296 |
+
None,
|
297 |
+
session_state,
|
298 |
+
)
|
299 |
+
|
300 |
+
# run propagation throughout the video and collect the results in a dict
|
301 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
302 |
+
print("starting propagate_in_video")
|
303 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
|
304 |
+
session_state["inference_state"]
|
305 |
+
):
|
306 |
+
video_segments[out_frame_idx] = {
|
307 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
308 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
309 |
+
}
|
310 |
+
|
311 |
+
# obtain the segmentation results every few frames
|
312 |
+
vis_frame_stride = 1
|
313 |
+
|
314 |
+
output_frames = []
|
315 |
+
for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
|
316 |
+
transparent_background = Image.fromarray(
|
317 |
+
session_state["all_frames"][out_frame_idx]
|
318 |
+
).convert("RGBA")
|
319 |
+
out_mask = video_segments[out_frame_idx][OBJ_ID]
|
320 |
+
mask_image = show_mask(out_mask)
|
321 |
+
output_frame = Image.alpha_composite(transparent_background, mask_image)
|
322 |
+
output_frame = np.array(output_frame)
|
323 |
+
output_frames.append(output_frame)
|
324 |
+
|
325 |
+
torch.cuda.empty_cache()
|
326 |
+
|
327 |
+
# Create a video clip from the image sequence
|
328 |
+
original_fps = get_video_fps(video_in)
|
329 |
+
fps = original_fps # Frames per second
|
330 |
+
clip = ImageSequenceClip(output_frames, fps=fps)
|
331 |
+
# Write the result to a file
|
332 |
+
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
|
333 |
+
final_vid_output_path = f"output_video_{unique_id}.mp4"
|
334 |
+
final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
|
335 |
+
|
336 |
+
# Write the result to a file
|
337 |
+
clip.write_videofile(final_vid_output_path, codec="libx264")
|
338 |
+
|
339 |
+
return (
|
340 |
+
gr.update(value=final_vid_output_path),
|
341 |
+
session_state,
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
def update_ui():
|
346 |
+
return gr.update(visible=True)
|
347 |
+
|
348 |
+
|
349 |
+
with gr.Blocks() as demo:
|
350 |
+
session_state = gr.State(
|
351 |
+
{
|
352 |
+
"first_frame": None,
|
353 |
+
"all_frames": None,
|
354 |
+
"input_points": [],
|
355 |
+
"input_labels": [],
|
356 |
+
"inference_state": None,
|
357 |
+
}
|
358 |
+
)
|
359 |
+
|
360 |
+
with gr.Column():
|
361 |
+
# Title
|
362 |
+
gr.Markdown(title)
|
363 |
+
with gr.Row():
|
364 |
+
|
365 |
+
with gr.Column():
|
366 |
+
# Instructions
|
367 |
+
gr.Markdown(description_p)
|
368 |
+
|
369 |
+
with gr.Accordion("Input Video", open=True) as video_in_drawer:
|
370 |
+
video_in = gr.Video(label="Input Video", format="mp4")
|
371 |
+
|
372 |
+
with gr.Row():
|
373 |
+
point_type = gr.Radio(
|
374 |
+
label="point type",
|
375 |
+
choices=["include", "exclude"],
|
376 |
+
value="include",
|
377 |
+
scale=2,
|
378 |
+
)
|
379 |
+
propagate_btn = gr.Button("Track", scale=1, variant="primary")
|
380 |
+
clear_points_btn = gr.Button("Clear Points", scale=1)
|
381 |
+
reset_btn = gr.Button("Reset", scale=1)
|
382 |
+
|
383 |
+
points_map = gr.Image(
|
384 |
+
label="Frame with Point Prompt", type="numpy", interactive=False
|
385 |
+
)
|
386 |
+
|
387 |
+
with gr.Column():
|
388 |
+
gr.Markdown("# Try some of the examples below ⬇️")
|
389 |
+
gr.Examples(
|
390 |
+
examples=examples,
|
391 |
+
inputs=[
|
392 |
+
video_in,
|
393 |
+
],
|
394 |
+
examples_per_page=8,
|
395 |
+
)
|
396 |
+
gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
|
397 |
+
gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
|
398 |
+
gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
|
399 |
+
output_image = gr.Image(label="Reference Mask")
|
400 |
+
|
401 |
+
output_video = gr.Video(visible=False)
|
402 |
+
|
403 |
+
# When new video is uploaded
|
404 |
+
video_in.upload(
|
405 |
+
fn=preprocess_video_in,
|
406 |
+
inputs=[
|
407 |
+
video_in,
|
408 |
+
session_state,
|
409 |
+
],
|
410 |
+
outputs=[
|
411 |
+
video_in_drawer, # Accordion to hide uploaded video player
|
412 |
+
points_map, # Image component where we add new tracking points
|
413 |
+
output_image,
|
414 |
+
output_video,
|
415 |
+
session_state,
|
416 |
+
],
|
417 |
+
queue=False,
|
418 |
+
)
|
419 |
+
|
420 |
+
video_in.change(
|
421 |
+
fn=preprocess_video_in,
|
422 |
+
inputs=[
|
423 |
+
video_in,
|
424 |
+
session_state,
|
425 |
+
],
|
426 |
+
outputs=[
|
427 |
+
video_in_drawer, # Accordion to hide uploaded video player
|
428 |
+
points_map, # Image component where we add new tracking points
|
429 |
+
output_image,
|
430 |
+
output_video,
|
431 |
+
session_state,
|
432 |
+
],
|
433 |
+
queue=False,
|
434 |
+
)
|
435 |
+
|
436 |
+
# triggered when we click on image to add new points
|
437 |
+
points_map.select(
|
438 |
+
fn=segment_with_points,
|
439 |
+
inputs=[
|
440 |
+
point_type, # "include" or "exclude"
|
441 |
+
session_state,
|
442 |
+
],
|
443 |
+
outputs=[
|
444 |
+
points_map, # updated image with points
|
445 |
+
output_image,
|
446 |
+
session_state,
|
447 |
+
],
|
448 |
+
queue=False,
|
449 |
+
)
|
450 |
+
|
451 |
+
# Clear every points clicked and added to the map
|
452 |
+
clear_points_btn.click(
|
453 |
+
fn=clear_points,
|
454 |
+
inputs=session_state,
|
455 |
+
outputs=[
|
456 |
+
points_map,
|
457 |
+
output_image,
|
458 |
+
output_video,
|
459 |
+
session_state,
|
460 |
+
],
|
461 |
+
queue=False,
|
462 |
+
)
|
463 |
+
|
464 |
+
reset_btn.click(
|
465 |
+
fn=reset,
|
466 |
+
inputs=session_state,
|
467 |
+
outputs=[
|
468 |
+
video_in,
|
469 |
+
video_in_drawer,
|
470 |
+
points_map,
|
471 |
+
output_image,
|
472 |
+
output_video,
|
473 |
+
session_state,
|
474 |
+
],
|
475 |
+
queue=False,
|
476 |
+
)
|
477 |
+
|
478 |
+
propagate_btn.click(
|
479 |
+
fn=update_ui,
|
480 |
+
inputs=[],
|
481 |
+
outputs=output_video,
|
482 |
+
queue=False,
|
483 |
+
).then(
|
484 |
+
fn=propagate_to_all,
|
485 |
+
inputs=[
|
486 |
+
video_in,
|
487 |
+
session_state,
|
488 |
+
],
|
489 |
+
outputs=[
|
490 |
+
output_video,
|
491 |
+
session_state,
|
492 |
+
],
|
493 |
+
concurrency_limit=10,
|
494 |
+
queue=False,
|
495 |
+
)
|
496 |
+
|
497 |
+
|
498 |
+
demo.queue()
|
499 |
+
demo.launch()
|
checkpoints/download_ckpts.sh
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
# All rights reserved.
|
5 |
+
|
6 |
+
# This source code is licensed under the license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
|
9 |
+
# Use either wget or curl to download the checkpoints
|
10 |
+
if command -v wget &> /dev/null; then
|
11 |
+
CMD="wget"
|
12 |
+
elif command -v curl &> /dev/null; then
|
13 |
+
CMD="curl -L -O"
|
14 |
+
else
|
15 |
+
echo "Please install wget or curl to download the checkpoints."
|
16 |
+
exit 1
|
17 |
+
fi
|
18 |
+
|
19 |
+
# Define the URLs for SAM 2 checkpoints
|
20 |
+
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
|
21 |
+
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
|
22 |
+
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
|
23 |
+
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
|
24 |
+
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
|
25 |
+
|
26 |
+
# Download each of the four checkpoints using wget
|
27 |
+
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
|
28 |
+
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
|
29 |
+
|
30 |
+
# echo "Downloading sam2_hiera_small.pt checkpoint..."
|
31 |
+
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
|
32 |
+
|
33 |
+
# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
|
34 |
+
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
|
35 |
+
|
36 |
+
# echo "Downloading sam2_hiera_large.pt checkpoint..."
|
37 |
+
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
|
38 |
+
|
39 |
+
# Define the URLs for SAM 2.1 checkpoints
|
40 |
+
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
|
41 |
+
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
|
42 |
+
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
|
43 |
+
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
|
44 |
+
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
|
45 |
+
|
46 |
+
# SAM 2.1 checkpoints
|
47 |
+
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
|
48 |
+
$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
|
49 |
+
|
50 |
+
echo "Downloading sam2.1_hiera_small.pt checkpoint..."
|
51 |
+
$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
|
52 |
+
|
53 |
+
echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
|
54 |
+
$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
|
55 |
+
|
56 |
+
echo "Downloading sam2.1_hiera_large.pt checkpoint..."
|
57 |
+
$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
|
58 |
+
|
59 |
+
echo "All checkpoints are downloaded successfully."
|
checkpoints/edgetam.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed2d4850b8792c239689b043c47046ec239b6e808a3d9b6ae676c803fd8780df
|
3 |
+
size 56116523
|
convert_weights.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
sd = torch.load(args.src, map_location="cpu")["model"]
|
14 |
+
sd = {k: v for k, v in sd.items() if "teacher" not in k}
|
15 |
+
sd = {
|
16 |
+
k.replace("backbone.vision_backbone", "image_encoder"): v for k, v in sd.items()
|
17 |
+
}
|
18 |
+
sd = {k.replace("mlp.fc1", "mlp.layers.0"): v for k, v in sd.items()}
|
19 |
+
sd = {k.replace("mlp.fc2", "mlp.layers.1"): v for k, v in sd.items()}
|
20 |
+
sd = {k.replace("convs", "neck.convs"): v for k, v in sd.items()}
|
21 |
+
sd = {
|
22 |
+
k.replace("transformer.encoder", "memory_attention"): v for k, v in sd.items()
|
23 |
+
}
|
24 |
+
sd = {k.replace("maskmem_backbone", "memory_encoder"): v for k, v in sd.items()}
|
25 |
+
sd = {k.replace("maskmem_backbone", "memory_encoder"): v for k, v in sd.items()}
|
26 |
+
sd = {k.replace("mlp.lin1", "mlp.layers.0"): v for k, v in sd.items()}
|
27 |
+
sd = {k.replace("mlp.lin2", "mlp.layers.1"): v for k, v in sd.items()}
|
28 |
+
torch.save({"model": sd}, args.src.replace(".pt", "_converted.pt"))
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument("--src", type=str, required=True)
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
main(args)
|
examples/01_breakdancer.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6f6b15f61e741576f658e41939d0109f15f012691a62a86820418d3ae10d1f04
|
3 |
+
size 5251367
|
examples/01_dog.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21425eb9b7fce7d3fb23fa17b1c07bcbe99ba9e64e3b5347d8052ba6d9c38924
|
3 |
+
size 1738970
|
examples/02_cups.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6320759eede8c82c7d43fe1cc569f3a389758ec37b2b1dc9e9b7b4ce4110ab13
|
3 |
+
size 2543710
|
examples/02_hummingbird.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79018494f7cda9ea1fc57bc474fd468ba1291624d85a1b74312029013ed9eb1e
|
3 |
+
size 1098274
|
examples/03_blocks.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2868f2c62f5cfa28dabbe925fae95da7f896531630a0627e5edbc784ea42413d
|
3 |
+
size 1808832
|
examples/03_skateboarder.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8403b992e4c114cf01ccb56c95aeba58d38aac16a5d7611580c1ceb5ef0f3ca
|
3 |
+
size 2132383
|
examples/04_coffee.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9a439a8bf66910db3e2a9d9ea853bc24c760c8d4ca227f4c33051b1cef28c03
|
3 |
+
size 1177415
|
examples/04_octopus.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d705ebd61259495da51e254cba576b2362310033099173e7781624c14fc51e3d
|
3 |
+
size 4395411
|
examples/05_default_juggle.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:60d23773ad50389a5c2df9ba4ab05e9b4d39d21dfa7b75d15b46666ff7a4e2bb
|
3 |
+
size 1842699
|
examples/05_landing_dog_soccer.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65d828933e37694b47d8d22573d812a8b838b22b722a17af362a50daecfe92d9
|
3 |
+
size 2774490
|
examples/06_pingpong.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c5ba5fcd0abd07701f030bbdbeed937f56bd0950103a13dad664fc66a264f98
|
3 |
+
size 1341363
|
examples/07_snowboarder.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:28108f8ae296c2b97b48768ab5efe556d57c3b631cb7d86be3b012a93fdc35ae
|
3 |
+
size 3579270
|
examples/08_driving.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6752358cdb184ce27e157ba9c7622dba0596ad8c9951c013ff2d027802079ae1
|
3 |
+
size 1353646
|
examples/09_birdcartoon.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:89b3adfc1972b4a400fab3cc52da30a4bd6ee87feb7d3b5ee00e1030885f7f92
|
3 |
+
size 1839837
|
examples/10_cloth_magic.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:407ed3c96d070d4da764f4ae6a114b0ade5bce4fe8900a43c118ff2d636032c5
|
3 |
+
size 1377117
|
examples/11_polevault.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:86e7ca662e5282796b1af115c76afb15d81d2ce33a15a12504c90b0831e53911
|
3 |
+
size 1712657
|
examples/12_hideandseek.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:82fe69d56a9fa3c228e2dd4a4db05c36db36b75a2b132409400fb1f4af1102c9
|
3 |
+
size 2799452
|
examples/13_butterfly.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:284bb1cf4d6520d0ee1e7f3b0f100e38943410031c7f994a5f8fd95930ccdc02
|
3 |
+
size 2434416
|
examples/14_social_dog_training.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e364b566cfb8ef28f53615a15bf4079bdb1c6dada49828954706c8b6a96ac3f
|
3 |
+
size 1810664
|
examples/15_cricket.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e9a7b88ba84dcc16defe2845f6c3b47cc4d2ef25e5752e87721b79ac3540e1e9
|
3 |
+
size 4976505
|
examples/16_robotarm.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f79112ba6937ddccc4be748d6a2ca64822ac85cdbbdeba7a7dc4e58ebfa2ce80
|
3 |
+
size 3575535
|
examples/17_childrendancing.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff42b6651f65431e967c810c24ebbaea1425524055735e932638e3843b42f55d
|
3 |
+
size 5722084
|
examples/18_threedogs.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:76f5ad121c6e10b9ed6b8ea3453efe66c48952fabc964c3b1d9ab3fda0d8b5be
|
3 |
+
size 4254672
|
examples/19_cyclist.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0c862441a39dd9c5ad823fc1709737208f923d3676cfa13d6567451b2d01f34f
|
3 |
+
size 2471696
|
examples/20_doughkneading.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1eaefb076968272b3e6c04def0ef8fd99524f53bfade63c58a54e4dc47fa8af9
|
3 |
+
size 1943984
|
examples/21_biker.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3bd0799a3a5ae035f26c326baa2ae0906cb8ac57c3016cb4c0ca04959c442c08
|
3 |
+
size 1031840
|
examples/22_dogskateboarder.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e44764e451cce362511ff9f8774d98f5c2f971567397ec1def9d71f52767833
|
3 |
+
size 3948915
|
examples/23_racecar.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df58709b52635efdb09a0ce2a3a60fa2ac7111413c004d206639dca9e1487cd0
|
3 |
+
size 5147963
|
examples/24_clownfish.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e60c4972915e1c8537801e2805b7729338c4a122d759028ade59703d18972d81
|
3 |
+
size 4649052
|
pyproject.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = [
|
3 |
+
"setuptools>=61.0",
|
4 |
+
"torch>=2.3.1",
|
5 |
+
]
|
6 |
+
build-backend = "setuptools.build_meta"
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.3.1
|
2 |
+
torchvision>=0.18.1
|
3 |
+
numpy>=1.24.4
|
4 |
+
tqdm>=4.66.1
|
5 |
+
hydra-core>=1.3.2
|
6 |
+
iopath>=0.1.10
|
7 |
+
pillow>=9.4.0
|
8 |
+
gradio==4.44.0
|
9 |
+
gradio_client==1.3.0
|
10 |
+
gradio_image_prompter==0.1.0
|
11 |
+
opencv-python==4.10.0.84
|
12 |
+
moviepy==1.0.3
|
13 |
+
pydantic==2.10.6
|
14 |
+
timm==1.0.15
|
15 |
+
eva-decord==0.6.1
|
sam2/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from hydra import initialize_config_module
|
8 |
+
from hydra.core.global_hydra import GlobalHydra
|
9 |
+
|
10 |
+
if not GlobalHydra.instance().is_initialized():
|
11 |
+
initialize_config_module("sam2", version_base="1.2")
|
sam2/automatic_mask_generator.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
13 |
+
|
14 |
+
from sam2.modeling.sam2_base import SAM2Base
|
15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
16 |
+
from sam2.utils.amg import (
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
MaskData,
|
28 |
+
remove_small_regions,
|
29 |
+
rle_to_mask,
|
30 |
+
uncrop_boxes_xyxy,
|
31 |
+
uncrop_masks,
|
32 |
+
uncrop_points,
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class SAM2AutomaticMaskGenerator:
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
model: SAM2Base,
|
40 |
+
points_per_side: Optional[int] = 32,
|
41 |
+
points_per_batch: int = 64,
|
42 |
+
pred_iou_thresh: float = 0.8,
|
43 |
+
stability_score_thresh: float = 0.95,
|
44 |
+
stability_score_offset: float = 1.0,
|
45 |
+
mask_threshold: float = 0.0,
|
46 |
+
box_nms_thresh: float = 0.7,
|
47 |
+
crop_n_layers: int = 0,
|
48 |
+
crop_nms_thresh: float = 0.7,
|
49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
50 |
+
crop_n_points_downscale_factor: int = 1,
|
51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
52 |
+
min_mask_region_area: int = 0,
|
53 |
+
output_mode: str = "binary_mask",
|
54 |
+
use_m2m: bool = False,
|
55 |
+
multimask_output: bool = True,
|
56 |
+
**kwargs,
|
57 |
+
) -> None:
|
58 |
+
"""
|
59 |
+
Using a SAM 2 model, generates masks for the entire image.
|
60 |
+
Generates a grid of point prompts over the image, then filters
|
61 |
+
low quality and duplicate masks. The default settings are chosen
|
62 |
+
for SAM 2 with a HieraL backbone.
|
63 |
+
|
64 |
+
Arguments:
|
65 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
66 |
+
points_per_side (int or None): The number of points to be sampled
|
67 |
+
along one side of the image. The total number of points is
|
68 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
69 |
+
point sampling.
|
70 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
71 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
72 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
73 |
+
model's predicted mask quality.
|
74 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
75 |
+
the stability of the mask under changes to the cutoff used to binarize
|
76 |
+
the model's mask predictions.
|
77 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
78 |
+
calculated the stability score.
|
79 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
80 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
81 |
+
suppression to filter duplicate masks.
|
82 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
83 |
+
crops of the image. Sets the number of layers to run, where each
|
84 |
+
layer has 2**i_layer number of image crops.
|
85 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
86 |
+
suppression to filter duplicate masks between different crops.
|
87 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
88 |
+
In the first crop layer, crops will overlap by this fraction of
|
89 |
+
the image length. Later layers with more crops scale down this overlap.
|
90 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
91 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
92 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
93 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
94 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
95 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
96 |
+
to remove disconnected regions and holes in masks with area smaller
|
97 |
+
than min_mask_region_area. Requires opencv.
|
98 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
99 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
100 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
101 |
+
memory.
|
102 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
103 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
104 |
+
"""
|
105 |
+
|
106 |
+
assert (points_per_side is None) != (
|
107 |
+
point_grids is None
|
108 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
109 |
+
if points_per_side is not None:
|
110 |
+
self.point_grids = build_all_layer_point_grids(
|
111 |
+
points_per_side,
|
112 |
+
crop_n_layers,
|
113 |
+
crop_n_points_downscale_factor,
|
114 |
+
)
|
115 |
+
elif point_grids is not None:
|
116 |
+
self.point_grids = point_grids
|
117 |
+
else:
|
118 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
119 |
+
|
120 |
+
assert output_mode in [
|
121 |
+
"binary_mask",
|
122 |
+
"uncompressed_rle",
|
123 |
+
"coco_rle",
|
124 |
+
], f"Unknown output_mode {output_mode}."
|
125 |
+
if output_mode == "coco_rle":
|
126 |
+
try:
|
127 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
128 |
+
except ImportError as e:
|
129 |
+
print("Please install pycocotools")
|
130 |
+
raise e
|
131 |
+
|
132 |
+
self.predictor = SAM2ImagePredictor(
|
133 |
+
model,
|
134 |
+
max_hole_area=min_mask_region_area,
|
135 |
+
max_sprinkle_area=min_mask_region_area,
|
136 |
+
)
|
137 |
+
self.points_per_batch = points_per_batch
|
138 |
+
self.pred_iou_thresh = pred_iou_thresh
|
139 |
+
self.stability_score_thresh = stability_score_thresh
|
140 |
+
self.stability_score_offset = stability_score_offset
|
141 |
+
self.mask_threshold = mask_threshold
|
142 |
+
self.box_nms_thresh = box_nms_thresh
|
143 |
+
self.crop_n_layers = crop_n_layers
|
144 |
+
self.crop_nms_thresh = crop_nms_thresh
|
145 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
146 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
147 |
+
self.min_mask_region_area = min_mask_region_area
|
148 |
+
self.output_mode = output_mode
|
149 |
+
self.use_m2m = use_m2m
|
150 |
+
self.multimask_output = multimask_output
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
154 |
+
"""
|
155 |
+
Load a pretrained model from the Hugging Face hub.
|
156 |
+
|
157 |
+
Arguments:
|
158 |
+
model_id (str): The Hugging Face repository ID.
|
159 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
(SAM2AutomaticMaskGenerator): The loaded model.
|
163 |
+
"""
|
164 |
+
from sam2.build_sam import build_sam2_hf
|
165 |
+
|
166 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
167 |
+
return cls(sam_model, **kwargs)
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
171 |
+
"""
|
172 |
+
Generates masks for the given image.
|
173 |
+
|
174 |
+
Arguments:
|
175 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
179 |
+
a dict containing the following keys:
|
180 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
181 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
182 |
+
is a dictionary containing the RLE.
|
183 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
184 |
+
area (int): The area in pixels of the mask.
|
185 |
+
predicted_iou (float): The model's own prediction of the mask's
|
186 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
187 |
+
point_coords (list(list(float))): The point coordinates input
|
188 |
+
to the model to generate this mask.
|
189 |
+
stability_score (float): A measure of the mask's quality. This
|
190 |
+
is filtered on using the stability_score_thresh parameter.
|
191 |
+
crop_box (list(float)): The crop of the image used to generate
|
192 |
+
the mask, given in XYWH format.
|
193 |
+
"""
|
194 |
+
|
195 |
+
# Generate masks
|
196 |
+
mask_data = self._generate_masks(image)
|
197 |
+
|
198 |
+
# Encode masks
|
199 |
+
if self.output_mode == "coco_rle":
|
200 |
+
mask_data["segmentations"] = [
|
201 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
202 |
+
]
|
203 |
+
elif self.output_mode == "binary_mask":
|
204 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
205 |
+
else:
|
206 |
+
mask_data["segmentations"] = mask_data["rles"]
|
207 |
+
|
208 |
+
# Write mask records
|
209 |
+
curr_anns = []
|
210 |
+
for idx in range(len(mask_data["segmentations"])):
|
211 |
+
ann = {
|
212 |
+
"segmentation": mask_data["segmentations"][idx],
|
213 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
214 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
215 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
216 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
217 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
218 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
219 |
+
}
|
220 |
+
curr_anns.append(ann)
|
221 |
+
|
222 |
+
return curr_anns
|
223 |
+
|
224 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
225 |
+
orig_size = image.shape[:2]
|
226 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
227 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
228 |
+
)
|
229 |
+
|
230 |
+
# Iterate over image crops
|
231 |
+
data = MaskData()
|
232 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
233 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
234 |
+
data.cat(crop_data)
|
235 |
+
|
236 |
+
# Remove duplicate masks between crops
|
237 |
+
if len(crop_boxes) > 1:
|
238 |
+
# Prefer masks from smaller crops
|
239 |
+
scores = 1 / box_area(data["crop_boxes"])
|
240 |
+
scores = scores.to(data["boxes"].device)
|
241 |
+
keep_by_nms = batched_nms(
|
242 |
+
data["boxes"].float(),
|
243 |
+
scores,
|
244 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
245 |
+
iou_threshold=self.crop_nms_thresh,
|
246 |
+
)
|
247 |
+
data.filter(keep_by_nms)
|
248 |
+
data.to_numpy()
|
249 |
+
return data
|
250 |
+
|
251 |
+
def _process_crop(
|
252 |
+
self,
|
253 |
+
image: np.ndarray,
|
254 |
+
crop_box: List[int],
|
255 |
+
crop_layer_idx: int,
|
256 |
+
orig_size: Tuple[int, ...],
|
257 |
+
) -> MaskData:
|
258 |
+
# Crop the image and calculate embeddings
|
259 |
+
x0, y0, x1, y1 = crop_box
|
260 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
261 |
+
cropped_im_size = cropped_im.shape[:2]
|
262 |
+
self.predictor.set_image(cropped_im)
|
263 |
+
|
264 |
+
# Get points for this crop
|
265 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
266 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
267 |
+
|
268 |
+
# Generate masks for this crop in batches
|
269 |
+
data = MaskData()
|
270 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
271 |
+
batch_data = self._process_batch(
|
272 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True
|
273 |
+
)
|
274 |
+
data.cat(batch_data)
|
275 |
+
del batch_data
|
276 |
+
self.predictor.reset_predictor()
|
277 |
+
|
278 |
+
# Remove duplicates within this crop.
|
279 |
+
keep_by_nms = batched_nms(
|
280 |
+
data["boxes"].float(),
|
281 |
+
data["iou_preds"],
|
282 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
283 |
+
iou_threshold=self.box_nms_thresh,
|
284 |
+
)
|
285 |
+
data.filter(keep_by_nms)
|
286 |
+
|
287 |
+
# Return to the original image frame
|
288 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
289 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
290 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
291 |
+
|
292 |
+
return data
|
293 |
+
|
294 |
+
def _process_batch(
|
295 |
+
self,
|
296 |
+
points: np.ndarray,
|
297 |
+
im_size: Tuple[int, ...],
|
298 |
+
crop_box: List[int],
|
299 |
+
orig_size: Tuple[int, ...],
|
300 |
+
normalize=False,
|
301 |
+
) -> MaskData:
|
302 |
+
orig_h, orig_w = orig_size
|
303 |
+
|
304 |
+
# Run model on this batch
|
305 |
+
points = torch.as_tensor(
|
306 |
+
points, dtype=torch.float32, device=self.predictor.device
|
307 |
+
)
|
308 |
+
in_points = self.predictor._transforms.transform_coords(
|
309 |
+
points, normalize=normalize, orig_hw=im_size
|
310 |
+
)
|
311 |
+
in_labels = torch.ones(
|
312 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
313 |
+
)
|
314 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
315 |
+
in_points[:, None, :],
|
316 |
+
in_labels[:, None],
|
317 |
+
multimask_output=self.multimask_output,
|
318 |
+
return_logits=True,
|
319 |
+
)
|
320 |
+
|
321 |
+
# Serialize predictions and store in MaskData
|
322 |
+
data = MaskData(
|
323 |
+
masks=masks.flatten(0, 1),
|
324 |
+
iou_preds=iou_preds.flatten(0, 1),
|
325 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
326 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
327 |
+
)
|
328 |
+
del masks
|
329 |
+
|
330 |
+
if not self.use_m2m:
|
331 |
+
# Filter by predicted IoU
|
332 |
+
if self.pred_iou_thresh > 0.0:
|
333 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
334 |
+
data.filter(keep_mask)
|
335 |
+
|
336 |
+
# Calculate and filter by stability score
|
337 |
+
data["stability_score"] = calculate_stability_score(
|
338 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
339 |
+
)
|
340 |
+
if self.stability_score_thresh > 0.0:
|
341 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
342 |
+
data.filter(keep_mask)
|
343 |
+
else:
|
344 |
+
# One step refinement using previous mask predictions
|
345 |
+
in_points = self.predictor._transforms.transform_coords(
|
346 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
347 |
+
)
|
348 |
+
labels = torch.ones(
|
349 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
350 |
+
)
|
351 |
+
masks, ious = self.refine_with_m2m(
|
352 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch
|
353 |
+
)
|
354 |
+
data["masks"] = masks.squeeze(1)
|
355 |
+
data["iou_preds"] = ious.squeeze(1)
|
356 |
+
|
357 |
+
if self.pred_iou_thresh > 0.0:
|
358 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
359 |
+
data.filter(keep_mask)
|
360 |
+
|
361 |
+
data["stability_score"] = calculate_stability_score(
|
362 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
363 |
+
)
|
364 |
+
if self.stability_score_thresh > 0.0:
|
365 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
366 |
+
data.filter(keep_mask)
|
367 |
+
|
368 |
+
# Threshold masks and calculate boxes
|
369 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
370 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
371 |
+
|
372 |
+
# Filter boxes that touch crop boundaries
|
373 |
+
keep_mask = ~is_box_near_crop_edge(
|
374 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
375 |
+
)
|
376 |
+
if not torch.all(keep_mask):
|
377 |
+
data.filter(keep_mask)
|
378 |
+
|
379 |
+
# Compress to RLE
|
380 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
381 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
382 |
+
del data["masks"]
|
383 |
+
|
384 |
+
return data
|
385 |
+
|
386 |
+
@staticmethod
|
387 |
+
def postprocess_small_regions(
|
388 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
389 |
+
) -> MaskData:
|
390 |
+
"""
|
391 |
+
Removes small disconnected regions and holes in masks, then reruns
|
392 |
+
box NMS to remove any new duplicates.
|
393 |
+
|
394 |
+
Edits mask_data in place.
|
395 |
+
|
396 |
+
Requires open-cv as a dependency.
|
397 |
+
"""
|
398 |
+
if len(mask_data["rles"]) == 0:
|
399 |
+
return mask_data
|
400 |
+
|
401 |
+
# Filter small disconnected regions and holes
|
402 |
+
new_masks = []
|
403 |
+
scores = []
|
404 |
+
for rle in mask_data["rles"]:
|
405 |
+
mask = rle_to_mask(rle)
|
406 |
+
|
407 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
408 |
+
unchanged = not changed
|
409 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
410 |
+
unchanged = unchanged and not changed
|
411 |
+
|
412 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
413 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
414 |
+
# so NMS will prefer ones that didn't need postprocessing
|
415 |
+
scores.append(float(unchanged))
|
416 |
+
|
417 |
+
# Recalculate boxes and remove any new duplicates
|
418 |
+
masks = torch.cat(new_masks, dim=0)
|
419 |
+
boxes = batched_mask_to_box(masks)
|
420 |
+
keep_by_nms = batched_nms(
|
421 |
+
boxes.float(),
|
422 |
+
torch.as_tensor(scores),
|
423 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
424 |
+
iou_threshold=nms_thresh,
|
425 |
+
)
|
426 |
+
|
427 |
+
# Only recalculate RLEs for masks that have changed
|
428 |
+
for i_mask in keep_by_nms:
|
429 |
+
if scores[i_mask] == 0.0:
|
430 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
431 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
432 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
433 |
+
mask_data.filter(keep_by_nms)
|
434 |
+
|
435 |
+
return mask_data
|
436 |
+
|
437 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
|
438 |
+
new_masks = []
|
439 |
+
new_iou_preds = []
|
440 |
+
|
441 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
442 |
+
points_per_batch, points, point_labels, low_res_masks
|
443 |
+
):
|
444 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
445 |
+
cur_points[:, None, :],
|
446 |
+
cur_point_labels[:, None],
|
447 |
+
mask_input=low_res_mask[:, None, :],
|
448 |
+
multimask_output=False,
|
449 |
+
return_logits=True,
|
450 |
+
)
|
451 |
+
new_masks.append(best_masks)
|
452 |
+
new_iou_preds.append(best_iou_preds)
|
453 |
+
masks = torch.cat(new_masks, dim=0)
|
454 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
sam2/build_sam.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
|
10 |
+
import sam2
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from hydra import compose
|
14 |
+
from hydra.utils import instantiate
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
|
17 |
+
# Check if the user is running Python from the parent directory of the sam2 repo
|
18 |
+
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
19 |
+
# it could shadow the sam2 package and cause issues.
|
20 |
+
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
21 |
+
# If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
22 |
+
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
23 |
+
# This typically happens because the user is running Python from the parent directory
|
24 |
+
# that contains the sam2 repo they cloned.
|
25 |
+
raise RuntimeError(
|
26 |
+
"You're likely running Python from the parent directory of the sam2 repository "
|
27 |
+
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
28 |
+
"This is not supported since the `sam2` Python package could be shadowed by the "
|
29 |
+
"repository name (the repository is also named `sam2` and contains the Python package "
|
30 |
+
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
31 |
+
"rather than its parent dir, or from your home directory) after installing SAM 2."
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
HF_MODEL_ID_TO_FILENAMES = {
|
36 |
+
"facebook/sam2-hiera-tiny": (
|
37 |
+
"configs/sam2/sam2_hiera_t.yaml",
|
38 |
+
"sam2_hiera_tiny.pt",
|
39 |
+
),
|
40 |
+
"facebook/sam2-hiera-small": (
|
41 |
+
"configs/sam2/sam2_hiera_s.yaml",
|
42 |
+
"sam2_hiera_small.pt",
|
43 |
+
),
|
44 |
+
"facebook/sam2-hiera-base-plus": (
|
45 |
+
"configs/sam2/sam2_hiera_b+.yaml",
|
46 |
+
"sam2_hiera_base_plus.pt",
|
47 |
+
),
|
48 |
+
"facebook/sam2-hiera-large": (
|
49 |
+
"configs/sam2/sam2_hiera_l.yaml",
|
50 |
+
"sam2_hiera_large.pt",
|
51 |
+
),
|
52 |
+
"facebook/sam2.1-hiera-tiny": (
|
53 |
+
"configs/sam2.1/sam2.1_hiera_t.yaml",
|
54 |
+
"sam2.1_hiera_tiny.pt",
|
55 |
+
),
|
56 |
+
"facebook/sam2.1-hiera-small": (
|
57 |
+
"configs/sam2.1/sam2.1_hiera_s.yaml",
|
58 |
+
"sam2.1_hiera_small.pt",
|
59 |
+
),
|
60 |
+
"facebook/sam2.1-hiera-base-plus": (
|
61 |
+
"configs/sam2.1/sam2.1_hiera_b+.yaml",
|
62 |
+
"sam2.1_hiera_base_plus.pt",
|
63 |
+
),
|
64 |
+
"facebook/sam2.1-hiera-large": (
|
65 |
+
"configs/sam2.1/sam2.1_hiera_l.yaml",
|
66 |
+
"sam2.1_hiera_large.pt",
|
67 |
+
),
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
def build_sam2(
|
72 |
+
config_file,
|
73 |
+
ckpt_path=None,
|
74 |
+
device="cuda",
|
75 |
+
mode="eval",
|
76 |
+
hydra_overrides_extra=[],
|
77 |
+
apply_postprocessing=True,
|
78 |
+
**kwargs,
|
79 |
+
):
|
80 |
+
|
81 |
+
if apply_postprocessing:
|
82 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
83 |
+
hydra_overrides_extra += [
|
84 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
85 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
86 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
87 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
88 |
+
]
|
89 |
+
# Read config and init model
|
90 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
91 |
+
OmegaConf.resolve(cfg)
|
92 |
+
model = instantiate(cfg.model, _recursive_=True)
|
93 |
+
_load_checkpoint(model, ckpt_path)
|
94 |
+
model = model.to(device)
|
95 |
+
if mode == "eval":
|
96 |
+
model.eval()
|
97 |
+
return model
|
98 |
+
|
99 |
+
|
100 |
+
def build_sam2_video_predictor(
|
101 |
+
config_file,
|
102 |
+
ckpt_path=None,
|
103 |
+
device="cuda",
|
104 |
+
mode="eval",
|
105 |
+
hydra_overrides_extra=[],
|
106 |
+
apply_postprocessing=True,
|
107 |
+
**kwargs,
|
108 |
+
):
|
109 |
+
hydra_overrides = [
|
110 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
111 |
+
]
|
112 |
+
if apply_postprocessing:
|
113 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
114 |
+
hydra_overrides_extra += [
|
115 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
116 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
117 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
118 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
119 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
120 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
121 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
122 |
+
"++model.fill_hole_area=8",
|
123 |
+
]
|
124 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
125 |
+
|
126 |
+
# Read config and init model
|
127 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
128 |
+
OmegaConf.resolve(cfg)
|
129 |
+
print("configuration solved")
|
130 |
+
model = instantiate(cfg.model, _recursive_=True)
|
131 |
+
print("model instantiated")
|
132 |
+
_load_checkpoint(model, ckpt_path)
|
133 |
+
print("checkpoint loaded")
|
134 |
+
model = model.to(device)
|
135 |
+
if mode == "eval":
|
136 |
+
model.eval()
|
137 |
+
print("model ready")
|
138 |
+
return model
|
139 |
+
|
140 |
+
|
141 |
+
def _hf_download(model_id):
|
142 |
+
from huggingface_hub import hf_hub_download
|
143 |
+
|
144 |
+
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
|
145 |
+
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
146 |
+
return config_name, ckpt_path
|
147 |
+
|
148 |
+
|
149 |
+
def build_sam2_hf(model_id, **kwargs):
|
150 |
+
config_name, ckpt_path = _hf_download(model_id)
|
151 |
+
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
152 |
+
|
153 |
+
|
154 |
+
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
155 |
+
config_name, ckpt_path = _hf_download(model_id)
|
156 |
+
return build_sam2_video_predictor(
|
157 |
+
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def _load_checkpoint(model, ckpt_path):
|
162 |
+
if ckpt_path is not None:
|
163 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
164 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
165 |
+
if missing_keys:
|
166 |
+
logging.error(missing_keys)
|
167 |
+
raise RuntimeError()
|
168 |
+
if unexpected_keys:
|
169 |
+
logging.error(unexpected_keys)
|
170 |
+
raise RuntimeError()
|
171 |
+
logging.info("Loaded checkpoint sucessfully")
|
sam2/configs/edgetam.yaml
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.timm.TimmBackbone
|
11 |
+
name: repvit_m1.dist_in1k
|
12 |
+
features:
|
13 |
+
- layer0
|
14 |
+
- layer1
|
15 |
+
- layer2
|
16 |
+
- layer3
|
17 |
+
neck:
|
18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
19 |
+
position_encoding:
|
20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
21 |
+
num_pos_feats: 256
|
22 |
+
normalize: true
|
23 |
+
scale: null
|
24 |
+
temperature: 10000
|
25 |
+
d_model: 256
|
26 |
+
backbone_channel_list: [384, 192, 96, 48]
|
27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
28 |
+
fpn_interp_model: nearest
|
29 |
+
|
30 |
+
memory_attention:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
32 |
+
d_model: 256
|
33 |
+
pos_enc_at_input: true
|
34 |
+
layer:
|
35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
36 |
+
activation: relu
|
37 |
+
dim_feedforward: 2048
|
38 |
+
dropout: 0.1
|
39 |
+
pos_enc_at_attn: false
|
40 |
+
self_attention:
|
41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
42 |
+
rope_theta: 10000.0
|
43 |
+
feat_sizes: [32, 32]
|
44 |
+
embedding_dim: 256
|
45 |
+
num_heads: 1
|
46 |
+
downsample_rate: 1
|
47 |
+
dropout: 0.1
|
48 |
+
d_model: 256
|
49 |
+
pos_enc_at_cross_attn_keys: true
|
50 |
+
pos_enc_at_cross_attn_queries: false
|
51 |
+
cross_attention:
|
52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttentionv2
|
53 |
+
rope_theta: 10000.0
|
54 |
+
q_sizes: [64, 64]
|
55 |
+
k_sizes: [16, 16]
|
56 |
+
embedding_dim: 256
|
57 |
+
num_heads: 1
|
58 |
+
downsample_rate: 1
|
59 |
+
dropout: 0.1
|
60 |
+
kv_in_dim: 64
|
61 |
+
num_layers: 2
|
62 |
+
|
63 |
+
memory_encoder:
|
64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
65 |
+
out_dim: 64
|
66 |
+
position_encoding:
|
67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
68 |
+
num_pos_feats: 64
|
69 |
+
normalize: true
|
70 |
+
scale: null
|
71 |
+
temperature: 10000
|
72 |
+
mask_downsampler:
|
73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
74 |
+
kernel_size: 3
|
75 |
+
stride: 2
|
76 |
+
padding: 1
|
77 |
+
fuser:
|
78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
79 |
+
layer:
|
80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
81 |
+
dim: 256
|
82 |
+
kernel_size: 7
|
83 |
+
padding: 3
|
84 |
+
layer_scale_init_value: 1e-6
|
85 |
+
use_dwconv: True # depth-wise convs
|
86 |
+
num_layers: 2
|
87 |
+
|
88 |
+
spatial_perceiver:
|
89 |
+
_target_: sam2.modeling.perceiver.PerceiverResampler
|
90 |
+
depth: 2
|
91 |
+
dim: 64
|
92 |
+
dim_head: 64
|
93 |
+
heads: 1
|
94 |
+
ff_mult: 4
|
95 |
+
hidden_dropout_p: 0.
|
96 |
+
attention_dropout_p: 0.
|
97 |
+
pos_enc_at_key_value: true # implicit pos
|
98 |
+
concat_kv_latents: false
|
99 |
+
num_latents: 256
|
100 |
+
num_latents_2d: 256
|
101 |
+
position_encoding:
|
102 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
103 |
+
num_pos_feats: 64
|
104 |
+
normalize: true
|
105 |
+
scale: null
|
106 |
+
temperature: 10000
|
107 |
+
use_self_attn: true
|
108 |
+
|
109 |
+
num_maskmem: 7
|
110 |
+
image_size: 1024
|
111 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
112 |
+
sigmoid_scale_for_mem_enc: 20.0
|
113 |
+
sigmoid_bias_for_mem_enc: -10.0
|
114 |
+
use_mask_input_as_output_without_sam: true
|
115 |
+
# Memory
|
116 |
+
directly_add_no_mem_embed: true
|
117 |
+
# use high-resolution feature map in the SAM mask decoder
|
118 |
+
use_high_res_features_in_sam: true
|
119 |
+
# output 3 masks on the first click on initial conditioning frames
|
120 |
+
multimask_output_in_sam: true
|
121 |
+
# SAM heads
|
122 |
+
iou_prediction_use_sigmoid: True
|
123 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
124 |
+
use_obj_ptrs_in_encoder: true
|
125 |
+
add_tpos_enc_to_obj_ptrs: false
|
126 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
127 |
+
# object occlusion prediction
|
128 |
+
pred_obj_scores: true
|
129 |
+
pred_obj_scores_mlp: true
|
130 |
+
fixed_no_obj_ptr: true
|
131 |
+
# multimask tracking settings
|
132 |
+
multimask_output_for_tracking: true
|
133 |
+
use_multimask_token_for_obj_ptr: true
|
134 |
+
multimask_min_pt_num: 0
|
135 |
+
multimask_max_pt_num: 1
|
136 |
+
use_mlp_for_obj_ptr_proj: true
|
137 |
+
# Compilation flag
|
138 |
+
compile_image_encoder: false
|
sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 112
|
12 |
+
num_heads: 2
|
13 |
+
neck:
|
14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
15 |
+
position_encoding:
|
16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
17 |
+
num_pos_feats: 256
|
18 |
+
normalize: true
|
19 |
+
scale: null
|
20 |
+
temperature: 10000
|
21 |
+
d_model: 256
|
22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
24 |
+
fpn_interp_model: nearest
|
25 |
+
|
26 |
+
memory_attention:
|
27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
28 |
+
d_model: 256
|
29 |
+
pos_enc_at_input: true
|
30 |
+
layer:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
32 |
+
activation: relu
|
33 |
+
dim_feedforward: 2048
|
34 |
+
dropout: 0.1
|
35 |
+
pos_enc_at_attn: false
|
36 |
+
self_attention:
|
37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
38 |
+
rope_theta: 10000.0
|
39 |
+
feat_sizes: [32, 32]
|
40 |
+
embedding_dim: 256
|
41 |
+
num_heads: 1
|
42 |
+
downsample_rate: 1
|
43 |
+
dropout: 0.1
|
44 |
+
d_model: 256
|
45 |
+
pos_enc_at_cross_attn_keys: true
|
46 |
+
pos_enc_at_cross_attn_queries: false
|
47 |
+
cross_attention:
|
48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
49 |
+
rope_theta: 10000.0
|
50 |
+
feat_sizes: [32, 32]
|
51 |
+
rope_k_repeat: True
|
52 |
+
embedding_dim: 256
|
53 |
+
num_heads: 1
|
54 |
+
downsample_rate: 1
|
55 |
+
dropout: 0.1
|
56 |
+
kv_in_dim: 64
|
57 |
+
num_layers: 4
|
58 |
+
|
59 |
+
memory_encoder:
|
60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
61 |
+
out_dim: 64
|
62 |
+
position_encoding:
|
63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
64 |
+
num_pos_feats: 64
|
65 |
+
normalize: true
|
66 |
+
scale: null
|
67 |
+
temperature: 10000
|
68 |
+
mask_downsampler:
|
69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
70 |
+
kernel_size: 3
|
71 |
+
stride: 2
|
72 |
+
padding: 1
|
73 |
+
fuser:
|
74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
75 |
+
layer:
|
76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
77 |
+
dim: 256
|
78 |
+
kernel_size: 7
|
79 |
+
padding: 3
|
80 |
+
layer_scale_init_value: 1e-6
|
81 |
+
use_dwconv: True # depth-wise convs
|
82 |
+
num_layers: 2
|
83 |
+
|
84 |
+
num_maskmem: 7
|
85 |
+
image_size: 1024
|
86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
89 |
+
use_mask_input_as_output_without_sam: true
|
90 |
+
# Memory
|
91 |
+
directly_add_no_mem_embed: true
|
92 |
+
no_obj_embed_spatial: true
|
93 |
+
# use high-resolution feature map in the SAM mask decoder
|
94 |
+
use_high_res_features_in_sam: true
|
95 |
+
# output 3 masks on the first click on initial conditioning frames
|
96 |
+
multimask_output_in_sam: true
|
97 |
+
# SAM heads
|
98 |
+
iou_prediction_use_sigmoid: True
|
99 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
100 |
+
use_obj_ptrs_in_encoder: true
|
101 |
+
add_tpos_enc_to_obj_ptrs: true
|
102 |
+
proj_tpos_enc_in_obj_ptrs: true
|
103 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
105 |
+
# object occlusion prediction
|
106 |
+
pred_obj_scores: true
|
107 |
+
pred_obj_scores_mlp: true
|
108 |
+
fixed_no_obj_ptr: true
|
109 |
+
# multimask tracking settings
|
110 |
+
multimask_output_for_tracking: true
|
111 |
+
use_multimask_token_for_obj_ptr: true
|
112 |
+
multimask_min_pt_num: 0
|
113 |
+
multimask_max_pt_num: 1
|
114 |
+
use_mlp_for_obj_ptr_proj: true
|
115 |
+
# Compilation flag
|
116 |
+
compile_image_encoder: False
|
sam2/configs/sam2.1/sam2.1_hiera_l.yaml
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 144
|
12 |
+
num_heads: 2
|
13 |
+
stages: [2, 6, 36, 4]
|
14 |
+
global_att_blocks: [23, 33, 43]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
window_spec: [8, 4, 16, 8]
|
17 |
+
neck:
|
18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
19 |
+
position_encoding:
|
20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
21 |
+
num_pos_feats: 256
|
22 |
+
normalize: true
|
23 |
+
scale: null
|
24 |
+
temperature: 10000
|
25 |
+
d_model: 256
|
26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
28 |
+
fpn_interp_model: nearest
|
29 |
+
|
30 |
+
memory_attention:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
32 |
+
d_model: 256
|
33 |
+
pos_enc_at_input: true
|
34 |
+
layer:
|
35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
36 |
+
activation: relu
|
37 |
+
dim_feedforward: 2048
|
38 |
+
dropout: 0.1
|
39 |
+
pos_enc_at_attn: false
|
40 |
+
self_attention:
|
41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
42 |
+
rope_theta: 10000.0
|
43 |
+
feat_sizes: [32, 32]
|
44 |
+
embedding_dim: 256
|
45 |
+
num_heads: 1
|
46 |
+
downsample_rate: 1
|
47 |
+
dropout: 0.1
|
48 |
+
d_model: 256
|
49 |
+
pos_enc_at_cross_attn_keys: true
|
50 |
+
pos_enc_at_cross_attn_queries: false
|
51 |
+
cross_attention:
|
52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
53 |
+
rope_theta: 10000.0
|
54 |
+
feat_sizes: [32, 32]
|
55 |
+
rope_k_repeat: True
|
56 |
+
embedding_dim: 256
|
57 |
+
num_heads: 1
|
58 |
+
downsample_rate: 1
|
59 |
+
dropout: 0.1
|
60 |
+
kv_in_dim: 64
|
61 |
+
num_layers: 4
|
62 |
+
|
63 |
+
memory_encoder:
|
64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
65 |
+
out_dim: 64
|
66 |
+
position_encoding:
|
67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
68 |
+
num_pos_feats: 64
|
69 |
+
normalize: true
|
70 |
+
scale: null
|
71 |
+
temperature: 10000
|
72 |
+
mask_downsampler:
|
73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
74 |
+
kernel_size: 3
|
75 |
+
stride: 2
|
76 |
+
padding: 1
|
77 |
+
fuser:
|
78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
79 |
+
layer:
|
80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
81 |
+
dim: 256
|
82 |
+
kernel_size: 7
|
83 |
+
padding: 3
|
84 |
+
layer_scale_init_value: 1e-6
|
85 |
+
use_dwconv: True # depth-wise convs
|
86 |
+
num_layers: 2
|
87 |
+
|
88 |
+
num_maskmem: 7
|
89 |
+
image_size: 1024
|
90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
no_obj_embed_spatial: true
|
97 |
+
# use high-resolution feature map in the SAM mask decoder
|
98 |
+
use_high_res_features_in_sam: true
|
99 |
+
# output 3 masks on the first click on initial conditioning frames
|
100 |
+
multimask_output_in_sam: true
|
101 |
+
# SAM heads
|
102 |
+
iou_prediction_use_sigmoid: True
|
103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
104 |
+
use_obj_ptrs_in_encoder: true
|
105 |
+
add_tpos_enc_to_obj_ptrs: true
|
106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
109 |
+
# object occlusion prediction
|
110 |
+
pred_obj_scores: true
|
111 |
+
pred_obj_scores_mlp: true
|
112 |
+
fixed_no_obj_ptr: true
|
113 |
+
# multimask tracking settings
|
114 |
+
multimask_output_for_tracking: true
|
115 |
+
use_multimask_token_for_obj_ptr: true
|
116 |
+
multimask_min_pt_num: 0
|
117 |
+
multimask_max_pt_num: 1
|
118 |
+
use_mlp_for_obj_ptr_proj: true
|
119 |
+
# Compilation flag
|
120 |
+
compile_image_encoder: False
|
sam2/configs/sam2.1/sam2.1_hiera_s.yaml
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 11, 2]
|
14 |
+
global_att_blocks: [7, 10, 13]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
92 |
+
use_mask_input_as_output_without_sam: true
|
93 |
+
# Memory
|
94 |
+
directly_add_no_mem_embed: true
|
95 |
+
no_obj_embed_spatial: true
|
96 |
+
# use high-resolution feature map in the SAM mask decoder
|
97 |
+
use_high_res_features_in_sam: true
|
98 |
+
# output 3 masks on the first click on initial conditioning frames
|
99 |
+
multimask_output_in_sam: true
|
100 |
+
# SAM heads
|
101 |
+
iou_prediction_use_sigmoid: True
|
102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
103 |
+
use_obj_ptrs_in_encoder: true
|
104 |
+
add_tpos_enc_to_obj_ptrs: true
|
105 |
+
proj_tpos_enc_in_obj_ptrs: true
|
106 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
107 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
108 |
+
# object occlusion prediction
|
109 |
+
pred_obj_scores: true
|
110 |
+
pred_obj_scores_mlp: true
|
111 |
+
fixed_no_obj_ptr: true
|
112 |
+
# multimask tracking settings
|
113 |
+
multimask_output_for_tracking: true
|
114 |
+
use_multimask_token_for_obj_ptr: true
|
115 |
+
multimask_min_pt_num: 0
|
116 |
+
multimask_max_pt_num: 1
|
117 |
+
use_mlp_for_obj_ptr_proj: true
|
118 |
+
# Compilation flag
|
119 |
+
compile_image_encoder: False
|
sam2/configs/sam2.1/sam2.1_hiera_t.yaml
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 7, 2]
|
14 |
+
global_att_blocks: [5, 7, 9]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
# SAM decoder
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
no_obj_embed_spatial: true
|
97 |
+
# use high-resolution feature map in the SAM mask decoder
|
98 |
+
use_high_res_features_in_sam: true
|
99 |
+
# output 3 masks on the first click on initial conditioning frames
|
100 |
+
multimask_output_in_sam: true
|
101 |
+
# SAM heads
|
102 |
+
iou_prediction_use_sigmoid: True
|
103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
104 |
+
use_obj_ptrs_in_encoder: true
|
105 |
+
add_tpos_enc_to_obj_ptrs: true
|
106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
109 |
+
# object occlusion prediction
|
110 |
+
pred_obj_scores: true
|
111 |
+
pred_obj_scores_mlp: true
|
112 |
+
fixed_no_obj_ptr: true
|
113 |
+
# multimask tracking settings
|
114 |
+
multimask_output_for_tracking: true
|
115 |
+
use_multimask_token_for_obj_ptr: true
|
116 |
+
multimask_min_pt_num: 0
|
117 |
+
multimask_max_pt_num: 1
|
118 |
+
use_mlp_for_obj_ptr_proj: true
|
119 |
+
# Compilation flag
|
120 |
+
# HieraT does not currently support compilation, should always be set to False
|
121 |
+
compile_image_encoder: False
|
sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
scratch:
|
4 |
+
resolution: 1024
|
5 |
+
train_batch_size: 1
|
6 |
+
num_train_workers: 10
|
7 |
+
num_frames: 8
|
8 |
+
max_num_objects: 3
|
9 |
+
base_lr: 5.0e-6
|
10 |
+
vision_lr: 3.0e-06
|
11 |
+
phases_per_epoch: 1
|
12 |
+
num_epochs: 40
|
13 |
+
|
14 |
+
dataset:
|
15 |
+
# PATHS to Dataset
|
16 |
+
img_folder: null # PATH to MOSE JPEGImages folder
|
17 |
+
gt_folder: null # PATH to MOSE Annotations folder
|
18 |
+
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
19 |
+
multiplier: 2
|
20 |
+
|
21 |
+
# Video transforms
|
22 |
+
vos:
|
23 |
+
train_transforms:
|
24 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
25 |
+
transforms:
|
26 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
27 |
+
consistent_transform: True
|
28 |
+
- _target_: training.dataset.transforms.RandomAffine
|
29 |
+
degrees: 25
|
30 |
+
shear: 20
|
31 |
+
image_interpolation: bilinear
|
32 |
+
consistent_transform: True
|
33 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
34 |
+
sizes: ${scratch.resolution}
|
35 |
+
square: true
|
36 |
+
consistent_transform: True
|
37 |
+
- _target_: training.dataset.transforms.ColorJitter
|
38 |
+
consistent_transform: True
|
39 |
+
brightness: 0.1
|
40 |
+
contrast: 0.03
|
41 |
+
saturation: 0.03
|
42 |
+
hue: null
|
43 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
44 |
+
p: 0.05
|
45 |
+
consistent_transform: True
|
46 |
+
- _target_: training.dataset.transforms.ColorJitter
|
47 |
+
consistent_transform: False
|
48 |
+
brightness: 0.1
|
49 |
+
contrast: 0.05
|
50 |
+
saturation: 0.05
|
51 |
+
hue: null
|
52 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
53 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
54 |
+
mean: [0.485, 0.456, 0.406]
|
55 |
+
std: [0.229, 0.224, 0.225]
|
56 |
+
|
57 |
+
trainer:
|
58 |
+
_target_: training.trainer.Trainer
|
59 |
+
mode: train_only
|
60 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
61 |
+
accelerator: cuda
|
62 |
+
seed_value: 123
|
63 |
+
|
64 |
+
model:
|
65 |
+
_target_: training.model.sam2.SAM2Train
|
66 |
+
image_encoder:
|
67 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
68 |
+
scalp: 1
|
69 |
+
trunk:
|
70 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
71 |
+
embed_dim: 112
|
72 |
+
num_heads: 2
|
73 |
+
drop_path_rate: 0.1
|
74 |
+
neck:
|
75 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
76 |
+
position_encoding:
|
77 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
78 |
+
num_pos_feats: 256
|
79 |
+
normalize: true
|
80 |
+
scale: null
|
81 |
+
temperature: 10000
|
82 |
+
d_model: 256
|
83 |
+
backbone_channel_list: [896, 448, 224, 112]
|
84 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
85 |
+
fpn_interp_model: nearest
|
86 |
+
|
87 |
+
memory_attention:
|
88 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
89 |
+
d_model: 256
|
90 |
+
pos_enc_at_input: true
|
91 |
+
layer:
|
92 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
93 |
+
activation: relu
|
94 |
+
dim_feedforward: 2048
|
95 |
+
dropout: 0.1
|
96 |
+
pos_enc_at_attn: false
|
97 |
+
self_attention:
|
98 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
99 |
+
rope_theta: 10000.0
|
100 |
+
feat_sizes: [32, 32]
|
101 |
+
embedding_dim: 256
|
102 |
+
num_heads: 1
|
103 |
+
downsample_rate: 1
|
104 |
+
dropout: 0.1
|
105 |
+
d_model: 256
|
106 |
+
pos_enc_at_cross_attn_keys: true
|
107 |
+
pos_enc_at_cross_attn_queries: false
|
108 |
+
cross_attention:
|
109 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
110 |
+
rope_theta: 10000.0
|
111 |
+
feat_sizes: [32, 32]
|
112 |
+
rope_k_repeat: True
|
113 |
+
embedding_dim: 256
|
114 |
+
num_heads: 1
|
115 |
+
downsample_rate: 1
|
116 |
+
dropout: 0.1
|
117 |
+
kv_in_dim: 64
|
118 |
+
num_layers: 4
|
119 |
+
|
120 |
+
memory_encoder:
|
121 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
122 |
+
out_dim: 64
|
123 |
+
position_encoding:
|
124 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
125 |
+
num_pos_feats: 64
|
126 |
+
normalize: true
|
127 |
+
scale: null
|
128 |
+
temperature: 10000
|
129 |
+
mask_downsampler:
|
130 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
131 |
+
kernel_size: 3
|
132 |
+
stride: 2
|
133 |
+
padding: 1
|
134 |
+
fuser:
|
135 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
136 |
+
layer:
|
137 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
138 |
+
dim: 256
|
139 |
+
kernel_size: 7
|
140 |
+
padding: 3
|
141 |
+
layer_scale_init_value: 1e-6
|
142 |
+
use_dwconv: True # depth-wise convs
|
143 |
+
num_layers: 2
|
144 |
+
|
145 |
+
num_maskmem: 7
|
146 |
+
image_size: ${scratch.resolution}
|
147 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
148 |
+
sigmoid_scale_for_mem_enc: 20.0
|
149 |
+
sigmoid_bias_for_mem_enc: -10.0
|
150 |
+
use_mask_input_as_output_without_sam: true
|
151 |
+
# Memory
|
152 |
+
directly_add_no_mem_embed: true
|
153 |
+
no_obj_embed_spatial: true
|
154 |
+
# use high-resolution feature map in the SAM mask decoder
|
155 |
+
use_high_res_features_in_sam: true
|
156 |
+
# output 3 masks on the first click on initial conditioning frames
|
157 |
+
multimask_output_in_sam: true
|
158 |
+
# SAM heads
|
159 |
+
iou_prediction_use_sigmoid: True
|
160 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
161 |
+
use_obj_ptrs_in_encoder: true
|
162 |
+
add_tpos_enc_to_obj_ptrs: true
|
163 |
+
proj_tpos_enc_in_obj_ptrs: true
|
164 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
165 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
166 |
+
# object occlusion prediction
|
167 |
+
pred_obj_scores: true
|
168 |
+
pred_obj_scores_mlp: true
|
169 |
+
fixed_no_obj_ptr: true
|
170 |
+
# multimask tracking settings
|
171 |
+
multimask_output_for_tracking: true
|
172 |
+
use_multimask_token_for_obj_ptr: true
|
173 |
+
multimask_min_pt_num: 0
|
174 |
+
multimask_max_pt_num: 1
|
175 |
+
use_mlp_for_obj_ptr_proj: true
|
176 |
+
# Compilation flag
|
177 |
+
# compile_image_encoder: False
|
178 |
+
|
179 |
+
####### Training specific params #######
|
180 |
+
# box/point input and corrections
|
181 |
+
prob_to_use_pt_input_for_train: 0.5
|
182 |
+
prob_to_use_pt_input_for_eval: 0.0
|
183 |
+
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
|
184 |
+
prob_to_use_box_input_for_eval: 0.0
|
185 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
186 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
187 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
188 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
189 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
190 |
+
# maximum 2 initial conditioning frames
|
191 |
+
num_init_cond_frames_for_train: 2
|
192 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
193 |
+
num_correction_pt_per_frame: 7
|
194 |
+
use_act_ckpt_iterative_pt_sampling: false
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
199 |
+
forward_backbone_per_frame_for_eval: True
|
200 |
+
|
201 |
+
|
202 |
+
data:
|
203 |
+
train:
|
204 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
205 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
206 |
+
batch_sizes:
|
207 |
+
- ${scratch.train_batch_size}
|
208 |
+
|
209 |
+
datasets:
|
210 |
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
211 |
+
dataset:
|
212 |
+
_target_: training.dataset.utils.ConcatDataset
|
213 |
+
datasets:
|
214 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
215 |
+
transforms: ${vos.train_transforms}
|
216 |
+
training: true
|
217 |
+
video_dataset:
|
218 |
+
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
|
219 |
+
img_folder: ${dataset.img_folder}
|
220 |
+
gt_folder: ${dataset.gt_folder}
|
221 |
+
file_list_txt: ${dataset.file_list_txt}
|
222 |
+
sampler:
|
223 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
224 |
+
num_frames: ${scratch.num_frames}
|
225 |
+
max_num_objects: ${scratch.max_num_objects}
|
226 |
+
multiplier: ${dataset.multiplier}
|
227 |
+
shuffle: True
|
228 |
+
num_workers: ${scratch.num_train_workers}
|
229 |
+
pin_memory: True
|
230 |
+
drop_last: True
|
231 |
+
collate_fn:
|
232 |
+
_target_: training.utils.data_utils.collate_fn
|
233 |
+
_partial_: true
|
234 |
+
dict_key: all
|
235 |
+
|
236 |
+
optim:
|
237 |
+
amp:
|
238 |
+
enabled: True
|
239 |
+
amp_dtype: bfloat16
|
240 |
+
|
241 |
+
optimizer:
|
242 |
+
_target_: torch.optim.AdamW
|
243 |
+
|
244 |
+
gradient_clip:
|
245 |
+
_target_: training.optimizer.GradientClipper
|
246 |
+
max_norm: 0.1
|
247 |
+
norm_type: 2
|
248 |
+
|
249 |
+
param_group_modifiers:
|
250 |
+
- _target_: training.optimizer.layer_decay_param_modifier
|
251 |
+
_partial_: True
|
252 |
+
layer_decay_value: 0.9
|
253 |
+
apply_to: 'image_encoder.trunk'
|
254 |
+
overrides:
|
255 |
+
- pattern: '*pos_embed*'
|
256 |
+
value: 1.0
|
257 |
+
|
258 |
+
options:
|
259 |
+
lr:
|
260 |
+
- scheduler:
|
261 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
262 |
+
start_value: ${scratch.base_lr}
|
263 |
+
end_value: ${divide:${scratch.base_lr},10}
|
264 |
+
- scheduler:
|
265 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
266 |
+
start_value: ${scratch.vision_lr}
|
267 |
+
end_value: ${divide:${scratch.vision_lr},10}
|
268 |
+
param_names:
|
269 |
+
- 'image_encoder.*'
|
270 |
+
weight_decay:
|
271 |
+
- scheduler:
|
272 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
273 |
+
value: 0.1
|
274 |
+
- scheduler:
|
275 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
276 |
+
value: 0.0
|
277 |
+
param_names:
|
278 |
+
- '*bias*'
|
279 |
+
module_cls_names: ['torch.nn.LayerNorm']
|
280 |
+
|
281 |
+
loss:
|
282 |
+
all:
|
283 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
284 |
+
weight_dict:
|
285 |
+
loss_mask: 20
|
286 |
+
loss_dice: 1
|
287 |
+
loss_iou: 1
|
288 |
+
loss_class: 1
|
289 |
+
supervise_all_iou: true
|
290 |
+
iou_use_l1_loss: true
|
291 |
+
pred_obj_scores: true
|
292 |
+
focal_gamma_obj_score: 0.0
|
293 |
+
focal_alpha_obj_score: -1.0
|
294 |
+
|
295 |
+
distributed:
|
296 |
+
backend: nccl
|
297 |
+
find_unused_parameters: True
|
298 |
+
|
299 |
+
logging:
|
300 |
+
tensorboard_writer:
|
301 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
302 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
303 |
+
flush_secs: 120
|
304 |
+
should_log: True
|
305 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
306 |
+
log_freq: 10
|
307 |
+
|
308 |
+
# initialize from a SAM 2 checkpoint
|
309 |
+
checkpoint:
|
310 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
311 |
+
save_freq: 0 # 0 only last checkpoint is saved.
|
312 |
+
model_weight_initializer:
|
313 |
+
_partial_: True
|
314 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
315 |
+
strict: True
|
316 |
+
ignore_unexpected_keys: null
|
317 |
+
ignore_missing_keys: null
|
318 |
+
|
319 |
+
state_dict:
|
320 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
321 |
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
322 |
+
ckpt_state_dict_keys: ['model']
|
323 |
+
|
324 |
+
launcher:
|
325 |
+
num_nodes: 1
|
326 |
+
gpus_per_node: 8
|
327 |
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
328 |
+
|
329 |
+
# SLURM args if running on a cluster
|
330 |
+
submitit:
|
331 |
+
partition: null
|
332 |
+
account: null
|
333 |
+
qos: null
|
334 |
+
cpus_per_task: 10
|
335 |
+
use_cluster: false
|
336 |
+
timeout_hour: 24
|
337 |
+
name: null
|
338 |
+
port_range: [10000, 65000]
|
339 |
+
|
sam2/configs/sam2/sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 112
|
12 |
+
num_heads: 2
|
13 |
+
neck:
|
14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
15 |
+
position_encoding:
|
16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
17 |
+
num_pos_feats: 256
|
18 |
+
normalize: true
|
19 |
+
scale: null
|
20 |
+
temperature: 10000
|
21 |
+
d_model: 256
|
22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
24 |
+
fpn_interp_model: nearest
|
25 |
+
|
26 |
+
memory_attention:
|
27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
28 |
+
d_model: 256
|
29 |
+
pos_enc_at_input: true
|
30 |
+
layer:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
32 |
+
activation: relu
|
33 |
+
dim_feedforward: 2048
|
34 |
+
dropout: 0.1
|
35 |
+
pos_enc_at_attn: false
|
36 |
+
self_attention:
|
37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
38 |
+
rope_theta: 10000.0
|
39 |
+
feat_sizes: [32, 32]
|
40 |
+
embedding_dim: 256
|
41 |
+
num_heads: 1
|
42 |
+
downsample_rate: 1
|
43 |
+
dropout: 0.1
|
44 |
+
d_model: 256
|
45 |
+
pos_enc_at_cross_attn_keys: true
|
46 |
+
pos_enc_at_cross_attn_queries: false
|
47 |
+
cross_attention:
|
48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
49 |
+
rope_theta: 10000.0
|
50 |
+
feat_sizes: [32, 32]
|
51 |
+
rope_k_repeat: True
|
52 |
+
embedding_dim: 256
|
53 |
+
num_heads: 1
|
54 |
+
downsample_rate: 1
|
55 |
+
dropout: 0.1
|
56 |
+
kv_in_dim: 64
|
57 |
+
num_layers: 4
|
58 |
+
|
59 |
+
memory_encoder:
|
60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
61 |
+
out_dim: 64
|
62 |
+
position_encoding:
|
63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
64 |
+
num_pos_feats: 64
|
65 |
+
normalize: true
|
66 |
+
scale: null
|
67 |
+
temperature: 10000
|
68 |
+
mask_downsampler:
|
69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
70 |
+
kernel_size: 3
|
71 |
+
stride: 2
|
72 |
+
padding: 1
|
73 |
+
fuser:
|
74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
75 |
+
layer:
|
76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
77 |
+
dim: 256
|
78 |
+
kernel_size: 7
|
79 |
+
padding: 3
|
80 |
+
layer_scale_init_value: 1e-6
|
81 |
+
use_dwconv: True # depth-wise convs
|
82 |
+
num_layers: 2
|
83 |
+
|
84 |
+
num_maskmem: 7
|
85 |
+
image_size: 1024
|
86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
89 |
+
use_mask_input_as_output_without_sam: true
|
90 |
+
# Memory
|
91 |
+
directly_add_no_mem_embed: true
|
92 |
+
# use high-resolution feature map in the SAM mask decoder
|
93 |
+
use_high_res_features_in_sam: true
|
94 |
+
# output 3 masks on the first click on initial conditioning frames
|
95 |
+
multimask_output_in_sam: true
|
96 |
+
# SAM heads
|
97 |
+
iou_prediction_use_sigmoid: True
|
98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
99 |
+
use_obj_ptrs_in_encoder: true
|
100 |
+
add_tpos_enc_to_obj_ptrs: false
|
101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
102 |
+
# object occlusion prediction
|
103 |
+
pred_obj_scores: true
|
104 |
+
pred_obj_scores_mlp: true
|
105 |
+
fixed_no_obj_ptr: true
|
106 |
+
# multimask tracking settings
|
107 |
+
multimask_output_for_tracking: true
|
108 |
+
use_multimask_token_for_obj_ptr: true
|
109 |
+
multimask_min_pt_num: 0
|
110 |
+
multimask_max_pt_num: 1
|
111 |
+
use_mlp_for_obj_ptr_proj: true
|
112 |
+
# Compilation flag
|
113 |
+
compile_image_encoder: False
|