bullerwins committed on
Commit 1a3ec73 · verified · 1 Parent(s): e66a3c7

Add files using upload-large-folder tool

LICENSE ADDED
@@ -0,0 +1,77 @@
1
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hunyuan A13B Release Date: June 27, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
15
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan A13B released at [https://github.com/Tencent-Hunyuan/Hunyuan-A13B].
16
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19
+ n. “including” shall mean including but not limited to.
20
+ 2. GRANT OF RIGHTS.
21
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22
+ 3. DISTRIBUTION.
23
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
26
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29
+ 4. ADDITIONAL COMMERCIAL TERMS.
30
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31
+ 5. RULES OF USE.
32
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35
+ 6. INTELLECTUAL PROPERTY.
36
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44
+ 8. SURVIVAL AND TERMINATION.
45
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47
+ 9. GOVERNING LAW AND JURISDICTION.
48
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50
+
51
+ EXHIBIT A
52
+ ACCEPTABLE USE POLICY
53
+
54
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
55
+ Last modified: November 5, 2024
56
+
57
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58
+ 1. Outside the Territory;
59
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60
+ 3. To harm Yourself or others;
61
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66
+ 9. To intentionally defame, disparage or otherwise harass others;
67
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70
+ 13. To impersonate another individual without consent, authorization, or legal right;
71
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76
+ 19. For military purposes;
77
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
Notice.txt ADDED
@@ -0,0 +1,160 @@
1
+ Usage and Legal Notices:
2
+
3
+ Tencent is pleased to support the open source community by making Tencent Hunyuan A13B available.
4
+
5
+ Copyright (C) Tencent. All rights reserved. The below software and/or models in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
6
+
7
+ Tencent Hunyuan A13B is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent Hunyuan A13B does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
8
+
9
+ For avoidance of doubts, Tencent Hunyuan A13B refers to the inference code, training code, parameters and the weights of Tencent Hunyuan A13B only, which are made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
10
+
11
+
12
+ Other dependencies and licenses:
13
+
14
+
15
+ Open Source Software Licensed under the Apache License Version 2.0:
16
+ The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
17
+ --------------------------------------------------------------------
18
+ 1. pytorch
19
+ Copyright 2016-2017 TorchAPI
20
+ Copyright 2016-2017 Contributors
21
+
22
+ 2. VLLM
23
+ Copyright (c) vllm original author and authors
24
+ Please note this software has been modified by Tencent in this distribution.
25
+
26
+ 3. transformers
27
+ Copyright 2018- The Hugging Face team. All rights reserved.
28
+
29
+ 4. accelerate
30
+ Copyright (c) accelerate original author and authors
31
+
32
+
33
+ Terms of the Apache License Version 2.0:
34
+ --------------------------------------------------------------------
35
+ Apache License
36
+
37
+ Version 2.0, January 2004
38
+
39
+ http://www.apache.org/licenses/
40
+
41
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
42
+ 1. Definitions.
43
+
44
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
45
+
46
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
47
+
48
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
49
+
50
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
51
+
52
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
53
+
54
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
55
+
56
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
57
+
58
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
59
+
60
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
63
+
64
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
65
+
66
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
67
+
68
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
69
+
70
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
71
+
72
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
73
+
74
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
75
+
76
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
77
+
78
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
79
+
80
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
81
+
82
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
83
+
84
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
85
+
86
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
87
+
88
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
89
+
90
+ END OF TERMS AND CONDITIONS
91
+
92
+
93
+
94
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
95
+ --------------------------------------------------------------------
96
+ 1. pytorch
97
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
98
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
99
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
100
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
101
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
102
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
103
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
104
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
105
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
106
+
107
+
108
+ Terms of the BSD 3-Clause:
109
+ --------------------------------------------------------------------
110
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
111
+
112
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
113
+
114
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
115
+
116
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
117
+
118
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
119
+
120
+ For the license of other third party components, please refer to the following URL:
121
+ https://github.com/pytorch/pytorch/blob/v2.1.1/NOTICE
122
+ https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
123
+
124
+
125
+ Open Source Software Licensed under the BSD 3-Clause License:
126
+ --------------------------------------------------------------------
127
+ 1. flash_attn
128
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
129
+ All rights reserved.
130
+
131
+
132
+ A copy of the BSD 3-Clause is included in this file.
133
+
134
+
135
+
136
+ Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
137
+ The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
138
+ --------------------------------------------------------------------
139
+ 1. sglang
140
+ Copyright 2023-2024 SGLang Team
141
+
142
+
143
+ A copy of the Apache 2.0 is included in this file.
144
+
145
+ For the license of other third party components, please refer to the following URL:
146
+ https://github.com/sgl-project/sglang/tree/v0.4.7/3rdparty/amd
147
+
148
+
149
+
150
+ Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
151
+ The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
152
+ --------------------------------------------------------------------
153
+ 1. TensorRT-LLM
154
+ Copyright (c) TensorRT-LLM original author and authors
155
+
156
+
157
+ A copy of the Apache 2.0 is included in this file.
158
+
159
+ For the license of other third party components, please refer to the following URL:
160
+ https://github.com/NVIDIA/TensorRT-LLM/tree/v0.20.0/3rdparty
README.md ADDED
@@ -0,0 +1,270 @@
1
+ ---
2
+ license: other
3
+ license_name: tencent-hunyuan-a13b
4
+ license_link: https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/LICENSE
5
+ ---
6
+
7
+ <p align="center">
8
+ <img src="https://dscache.tencent-cloud.cn/upload/uploader/hunyuan-64b418fd052c033b228e04bc77bbc4b54fd7f5bc.png" width="400"/> <br>
9
+ </p><p></p>
10
+
11
+
12
+ <p align="center">
13
+ 🫣&nbsp;<a href="https://huggingface.co/tencent/Hunyuan-A13B-Instruct"><b>Hugging Face</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
14
+ 🖥️&nbsp;<a href="https://llm.hunyuan.tencent.com/" style="color: red;"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
15
+ 🕖&nbsp;<a href="https://cloud.tencent.com/product/hunyuan"><b>HunyuanAPI</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
16
+ 🕹️&nbsp;<a href="https://hunyuan.tencent.com/?model=hunyuan-a13b"><b>Demo</b></a>
17
+ </p>
18
+
19
+
20
+ <p align="center">
21
+ <a href="https://github.com/Tencent-Hunyuan/Hunyuan-A13B"><b>GITHUB</b></a> |
22
+ <a href="https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/LICENSE"><b>LICENSE</b></a>
23
+ </p>
24
+
25
+
26
+
27
+ Welcome to the official repository of **Hunyuan-A13B**, an innovative and open-source large language model (LLM) built on a fine-grained Mixture-of-Experts (MoE) architecture. Designed for efficiency and scalability, Hunyuan-A13B delivers cutting-edge performance with minimal computational overhead, making it an ideal choice for advanced reasoning and general-purpose applications, especially in resource-constrained environments.
28
+
29
+ ## Model Introduction
30
+
31
+ With the rapid advancement of artificial intelligence technology, large language models (LLMs) have achieved remarkable progress in natural language processing, computer vision, and scientific tasks. However, as model scales continue to expand, optimizing resource consumption while maintaining high performance has become a critical challenge. To address this, we have explored Mixture of Experts (MoE) architectures. The newly introduced Hunyuan-A13B model features a total of 80 billion parameters with 13 billion active parameters. It not only delivers high-performance results but also achieves optimal resource efficiency, successfully balancing computational power and resource utilization.
32
+
33
+ ### Key Features and Advantages
34
+
35
+ - **Compact yet Powerful**: With only 13 billion active parameters (out of a total of 80 billion), the model delivers competitive performance on a wide range of benchmark tasks, rivaling much larger models.
36
+ - **Hybrid Inference Support**: Supports both fast and slow thinking modes, allowing users to flexibly choose according to their needs.
37
+ - **Ultra-Long Context Understanding**: Natively supports a 256K context window, maintaining stable performance on long-text tasks.
38
+ - **Enhanced Agent Capabilities**: Optimized for agent tasks, achieving leading results on benchmarks such as BFCL-v3 and τ-Bench.
39
+ - **Efficient Inference**: Utilizes Grouped Query Attention (GQA) and supports multiple quantization formats, enabling highly efficient inference.
40
+
41
+ ### Why Choose Hunyuan-A13B?
42
+
43
+ As a powerful yet computationally efficient large model, Hunyuan-A13B is an ideal choice for researchers and developers seeking high performance under resource constraints. Whether for academic research, cost-effective AI solution development, or innovative application exploration, this model provides a robust foundation for advancement.
44
+
45
+ &nbsp;
46
+
47
+ ## Related News
48
+ * 2025.6.27 We have open-sourced **Hunyuan-A13B-Pretrain**, **Hunyuan-A13B-Instruct**, **Hunyuan-A13B-Instruct-FP8**, and **Hunyuan-A13B-Instruct-GPTQ-Int4** on Hugging Face.
49
+ <br>
50
+
51
+
52
+ ## Benchmark
53
+
54
+ Note: The following benchmarks were evaluated with the TRT-LLM backend.
55
+
56
+ | Model | Hunyuan-Large | Qwen2.5-72B | Qwen3-A22B | Hunyuan-A13B |
57
+ |------------------|---------------|--------------|-------------|---------------|
58
+ | MMLU | 88.40 | 86.10 | 87.81 | 88.17 |
59
+ | MMLU-Pro | 60.20 | 58.10 | 68.18 | 67.23 |
60
+ | MMLU-Redux | 87.47 | 83.90 | 87.40 | 87.67 |
61
+ | BBH | 86.30 | 85.80 | 88.87 | 87.56 |
62
+ | SuperGPQA | 38.90 | 36.20 | 44.06 | 41.32 |
63
+ | EvalPlus | 75.69 | 65.93 | 77.60 | 78.64 |
64
+ | MultiPL-E | 59.13 | 60.50 | 65.94 | 69.33 |
65
+ | MBPP | 72.60 | 76.00 | 81.40 | 83.86 |
66
+ | CRUX-I | 57.00 | 57.63 | - | 70.13 |
67
+ | CRUX-O | 60.63 | 66.20 | 79.00 | 77.00 |
68
+ | MATH | 69.80 | 62.12 | 71.84 | 72.35 |
69
+ | CMATH | 91.30 | 84.80 | - | 91.17 |
70
+ | GSM8k | 92.80 | 91.50 | 94.39 | 91.83 |
71
+ | GPQA | 25.18 | 45.90 | 47.47 | 49.12 |
72
+
73
+
74
+
75
+
76
+ Hunyuan-A13B-Instruct has achieved highly competitive performance across multiple benchmarks, particularly in mathematics, science, and agent tasks. We compared it with several strong models; the results are shown below.
77
+
78
+ | Topic | Bench | OpenAI-o1-1217 | DeepSeek R1 | Qwen3-A22B | Hunyuan-A13B-Instruct |
79
+ |:-------------------:|:-----------------------------:|:-------------:|:------------:|:-----------:|:---------------------:|
80
+ | **Mathematics** | AIME 2024<br>AIME 2025<br>MATH | 74.3<br>79.2<br>96.4 | 79.8<br>70<br>94.9 | 85.7<br>81.5<br>94.0 | 87.3<br>76.8<br>94.3 |
81
+ | **Science** | GPQA-Diamond<br>OlympiadBench | 78<br>83.1 | 71.5<br>82.4 | 71.1<br>85.7 | 71.2<br>82.7 |
82
+ | **Coding** | Livecodebench<br>Fullstackbench<br>ArtifactsBench | 63.9<br>64.6<br>38.6 | 65.9<br>71.6<br>44.6 | 70.7<br>65.6<br>44.6 | 63.9<br>67.8<br>43 |
83
+ | **Reasoning** | BBH<br>DROP<br>ZebraLogic | 80.4<br>90.2<br>81 | 83.7<br>92.2<br>78.7 | 88.9<br>90.3<br>80.3 | 89.1<br>91.1<br>84.7 |
84
+ | **Instruction<br>Following** | IF-Eval<br>SysBench | 91.8<br>82.5 | 88.3<br>77.7 | 83.4<br>74.2 | 84.7<br>76.1 |
85
+ | **Text<br>Creation**| LengthCtrl<br>InsCtrl | 60.1<br>74.8 | 55.9<br>69 | 53.3<br>73.7 | 55.4<br>71.9 |
86
+ | **NLU** | ComplexNLU<br>Word-Task | 64.7<br>67.1 | 64.5<br>76.3 | 59.8<br>56.4 | 61.2<br>62.9 |
87
+ | **Agent** | BFCL v3<br> τ-Bench<br>ComplexFuncBench<br> C3-Bench | 67.8<br>60.4<br>47.6<br>58.8 | 56.9<br>43.8<br>41.1<br>55.3 | 70.8<br>44.6<br>40.6<br>51.7 | 78.3<br>54.7<br>61.2<br>63.5 |
88
+
89
+
90
+ &nbsp;
91
+
92
+ ## Use with transformers
93
+ The following code snippet shows how to use the transformers library to load and run the model. It also demonstrates how to enable and disable the reasoning mode, and how to parse the reasoning content along with the final output.
94
+
95
+ ```python
96
+ from transformers import AutoModelForCausalLM, AutoTokenizer
97
+ import os
98
+ import re
99
+
100
+ model_name_or_path = os.environ['MODEL_PATH']
101
+ # model_name_or_path = "tencent/Hunyuan-A13B-Instruct"
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
104
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", trust_remote_code=True)  # You may want to use bfloat16 and/or move to GPU here
105
+ messages = [
106
+ {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
107
+ ]
108
+ tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,return_tensors="pt",
109
+ enable_thinking=True # Toggle thinking mode (default: True)
110
+ )
111
+
112
+ outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=4096)
113
+
114
+ output_text = tokenizer.decode(outputs[0])
115
+
116
+ think_pattern = r'<think>(.*?)</think>'
117
+ think_matches = re.findall(think_pattern, output_text, re.DOTALL)
118
+
119
+ answer_pattern = r'<answer>(.*?)</answer>'
120
+ answer_matches = re.findall(answer_pattern, output_text, re.DOTALL)
121
+
122
+ think_content = [match.strip() for match in think_matches][0]
123
+ answer_content = [match.strip() for match in answer_matches][0]
124
+ print(f"thinking_content:{think_content}\n\n")
125
+ print(f"answer_content:{answer_content}\n\n")
126
+ ```
127
+
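+ To run the same request without the reasoning trace, pass `enable_thinking=False` when building the prompt. A minimal sketch reusing the `tokenizer`, `model`, and `messages` objects from the snippet above; the assumption that the `<think>` block is then omitted from the output follows from the toggle's description:
+ 
+ ```python
+ # Fast-thinking mode: build the prompt with the reasoning toggle disabled
+ tokenized_chat = tokenizer.apply_chat_template(
+     messages,
+     tokenize=True,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     enable_thinking=False,  # disable the slow-thinking / reasoning mode
+ )
+ 
+ outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=4096)
+ print(tokenizer.decode(outputs[0]))  # expected to contain only the final answer, without a <think> section
+ ```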
128
+ ## Quantization and Compression
129
+ We used our in-house `AngleSlim` compression tool to produce FP8 and INT4 quantized models. `AngleSlim` is expected to be open-sourced in early July and will support one-click quantization and compression of large models; in the meantime, you can download our quantized models directly for deployment and testing.
130
+
131
+ ### FP8 Quantization
132
+ We apply FP8 static quantization, which uses an 8-bit floating-point format and a small amount of calibration data (no training required) to pre-determine the quantization scales. Model weights and activation values are converted to FP8, improving inference efficiency and lowering the deployment threshold. You can quantize the model yourself with AngleSlim, or directly download our open-source quantized model [Hunyuan-A13B-Instruct-FP8](https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8).
133
+
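+ As a rough sketch of offline inference with the FP8 checkpoint via vLLM (assuming the pre-quantized repository `tencent/Hunyuan-A13B-Instruct-FP8` loads through vLLM's standard API; adjust `tensor_parallel_size` to your hardware):
+ 
+ ```python
+ from vllm import LLM, SamplingParams
+ 
+ # Load the pre-quantized FP8 checkpoint; the quantization settings are read
+ # from the configuration shipped with the model repository.
+ llm = LLM(
+     model="tencent/Hunyuan-A13B-Instruct-FP8",
+     trust_remote_code=True,
+     tensor_parallel_size=2,  # adjust to the number of available GPUs
+ )
+ 
+ sampling = SamplingParams(temperature=0.7, top_p=0.8, max_tokens=512)
+ outputs = llm.generate(["Write a short summary of the benefits of regular exercise"], sampling)
+ print(outputs[0].outputs[0].text)
+ ```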
134
+ #### FP8 Benchmark
135
+ This subsection reports benchmark metrics for the Hunyuan-A13B-Instruct-FP8 quantized model.
136
+
137
+ | Bench | Hunyuan-A13B-Instruct | Hunyuan-A13B-Instruct-FP8 |
138
+ |:---------:|:---------------------:|:-------------------------:|
139
+ | AIME 2024 | 87.3 | 86.7 |
140
+ | Gsm8k | 94.39 | 94.01 |
141
+ | BBH | 89.1 | 88.34 |
142
+ | DROP | 91.1 | 91.1 |
143
+
144
+ ### Int4 Quantization
145
+ We use the GPTQ algorithm to achieve W4A16 quantization. GPTQ processes the model weights layer by layer, using a small amount of calibration data to minimize the reconstruction error of the quantized weights, adjusting them through an optimization procedure that approximates the inverse Hessian. The process requires no retraining and only a small amount of calibration data, improving inference efficiency and lowering the deployment threshold. You can quantize the model yourself with `AngleSlim`, or directly download our open-source quantized model [Hunyuan-A13B-Instruct-Int4](https://huggingface.co/tencent/Hunyuan-A13B-Instruct-GPTQ-Int4).
146
+
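+ The GPTQ-Int4 checkpoint can be served the same way. A minimal sketch, assuming vLLM picks up the GPTQ quantization from the checkpoint's configuration and that the chat template shipped with the tokenizer is used to build the prompt:
+ 
+ ```python
+ from transformers import AutoTokenizer
+ from vllm import LLM, SamplingParams
+ 
+ model_id = "tencent/Hunyuan-A13B-Instruct-GPTQ-Int4"  # pre-quantized W4A16 checkpoint
+ 
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ llm = LLM(model=model_id, trust_remote_code=True, tensor_parallel_size=2)  # adjust TP to your GPUs
+ 
+ messages = [{"role": "user", "content": "Write a short summary of the benefits of regular exercise"}]
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ 
+ outputs = llm.generate([prompt], SamplingParams(temperature=0.7, top_p=0.8, max_tokens=512))
+ print(outputs[0].outputs[0].text)
+ ```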
147
+ #### Int4 Benchmark
148
+ This subsection reports benchmark metrics for the Hunyuan-A13B-Instruct-GPTQ-Int4 quantized model.
149
+
150
+ | Bench | Hunyuan-A13B-Instruct | Hunyuan-A13B-Instruct-GPTQ-Int4 |
151
+ |:--------------:|:---------------------:|:-------------------------------:|
152
+ | OlympiadBench | 82.7 | 84.0 |
153
+ | AIME 2024 | 87.3 | 86.7 |
154
+ | Gsm8k | 94.39 | 94.24 |
155
+ | BBH | 88.34 | 87.91 |
156
+ | DROP | 91.12 | 91.05 |
157
+
158
+
159
+ ## Deployment
160
+
161
+ For deployment, you can use frameworks such as **TensorRT-LLM**, **vLLM**, or **SGLang** to serve the model and create an OpenAI-compatible API endpoint.
162
+
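+ Once a server from any of the backends below is running, requests can be sent with any OpenAI-compatible client. A minimal sketch using the `openai` Python package; the port (8000 for TensorRT-LLM/vLLM in the examples below, 30000 for SGLang), the model name, and the placeholder API key are assumptions that must match how you start the server:
+ 
+ ```python
+ from openai import OpenAI
+ 
+ # Point the client at the locally served OpenAI-compatible endpoint.
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # the key is not checked locally
+ 
+ response = client.chat.completions.create(
+     model="tencent/Hunyuan-A13B-Instruct",  # must match the model name exposed by the server
+     messages=[{"role": "user", "content": "Write a short summary of the benefits of regular exercise"}],
+     max_tokens=1024,
+ )
+ print(response.choices[0].message.content)
+ ```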
163
+ Docker images: https://hub.docker.com/r/hunyuaninfer/hunyuan-a13b/tags
164
+
165
+
166
+ ### TensorRT-LLM
167
+
168
+ #### Docker Image
169
+
170
+ We provide a pre-built Docker image based on the latest version of TensorRT-LLM.
171
+
172
+ - To get started:
173
+
174
+ https://hub.docker.com/r/hunyuaninfer/hunyuan-large/tags
175
+
176
+ ```
177
+ docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-trtllm
178
+ ```
179
+
180
+ - Start the API server:
181
+
182
+ ```
183
+ docker run --name hunyuanLLM_infer --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-trtllm
184
+ ```
185
+ ```
186
+ trtllm-serve \
187
+ /path/to/HunYuan-moe-A13B \
188
+ --host localhost \
189
+ --port 8000 \
190
+ --backend pytorch \
191
+ --max_batch_size 128 \
192
+ --max_num_tokens 16384 \
193
+ --tp_size 2 \
194
+ --kv_cache_free_gpu_memory_fraction 0.95 \
195
+ --extra_llm_api_options /path/to/extra-llm-api-config.yml
196
+ ```
197
+
198
+
199
+ ### vLLM
200
+
201
+ #### Docker Image
202
+ We provide a pre-built Docker image containing vLLM 0.8.5 with full support for this model. Official vLLM support is currently under development. **Note: CUDA 12.8 is required for this Docker image.**
203
+
204
+ - To get started:
205
+
206
+ ```
207
+ docker pull docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-vllm
208
+ or
209
+ docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm
210
+ ```
211
+
212
+ - Download the model files:
213
+ - Hugging Face: downloaded automatically by vLLM.
214
+ - ModelScope: `modelscope download --model Tencent-Hunyuan/Hunyuan-A13B-Instruct`
215
+
216
+
217
+ - Start the API server:
218
+
219
+ Model downloaded from Hugging Face:
220
+ ```
221
+ docker run --privileged --user root --net=host --ipc=host \
222
+ -v ~/.cache:/root/.cache/ \
223
+ --gpus=all -it --entrypoint python hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm \
225
+ -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
226
+ --tensor-parallel-size 4 --model tencent/Hunyuan-A13B-Instruct --trust-remote-code
227
+
228
+ ```
229
+
230
+ Model downloaded from ModelScope:
231
+ ```
232
+ docker run --privileged --user root --net=host --ipc=host \
233
+ -v ~/.cache/modelscope:/root/.cache/modelscope \
234
+ --gpus=all -it --entrypoint python hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm \
235
+ -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --tensor-parallel-size 4 --port 8000 \
236
+ --model /root/.cache/modelscope/hub/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct/ --trust_remote_code
237
+ ```
238
+
239
+
240
+ ### SGLang
241
+
242
+ #### Docker Image
243
+
244
+ We also provide a pre-built Docker image based on the latest version of SGLang.
245
+
246
+ To get started:
247
+
248
+ - Pull the Docker image
249
+
250
+ ```
251
+ docker pull docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-sglang
252
+ or
253
+ docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-sglang
254
+ ```
255
+
256
+ - Start the API server:
257
+
258
+ ```
259
+ docker run --gpus all \
260
+ --shm-size 32g \
261
+ -p 30000:30000 \
262
+ --ipc=host \
263
+ docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-sglang \
264
+ -m sglang.launch_server --model-path hunyuan/huanyuan_A13B --tp 4 --trust-remote-code --host 0.0.0.0 --port 30000
265
+ ```
266
+
267
+
268
+ ## Contact Us
269
+
270
+ If you would like to leave a message for our R&D or product teams, you are welcome to contact our open-source team. You can also reach us via email ([email protected]).
README_CN.md ADDED
@@ -0,0 +1,456 @@
1
+ <p align="center">
2
+ <img src="https://dscache.tencent-cloud.cn/upload/uploader/hunyuan-64b418fd052c033b228e04bc77bbc4b54fd7f5bc.png" width="400"/> <br>
3
+ </p><p></p>
4
+
5
+ <p align="center">
6
+ 🫣&nbsp;<a href="https://huggingface.co/tencent/Hunyuan-A13B-Instruct"><b>Hugging Face</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
7
+ 🖥️&nbsp;<a href="https://llm.hunyuan.tencent.com/" style="color: red;"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
8
+ 🕖&nbsp;<a href="https://cloud.tencent.com/product/hunyuan"><b>HunyuanAPI</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
9
+ 🕹️&nbsp;<a href="https://hunyuan.tencent.com/?model=hunyuan-a13b"><b>Demo</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
10
+ <img src="https://avatars.githubusercontent.com/u/109945100?s=200&v=4" width="16"/>&nbsp;<a href="https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct"><b>ModelScope</b></a>
11
+ </p>
12
+
13
+ <p align="center">
14
+ <a href="https://github.com/Tencent/Hunyuan-A13B"><b>GITHUB</b></a>
15
+ </p>
16
+
17
+
18
+
19
+
20
+ ## Model Introduction
21
+
22
+ With the rapid advancement of artificial intelligence technology, large language models (LLMs) have achieved remarkable progress in natural language processing, computer vision, and scientific tasks. However, as model scales continue to grow, optimizing resource consumption while maintaining high performance has become a critical challenge. To address this, we explored Mixture-of-Experts (MoE) architectures. The newly released Hunyuan-A13B model has 80 billion total parameters and 13 billion active parameters. It not only delivers strong results but is also highly optimized in model size, successfully balancing performance and resource usage.
23
+
24
+
25
+ ### Key Features and Advantages
26
+ - **Compact yet Powerful**: With only 13 billion active parameters (80 billion in total), the model delivers performance competitive with much larger models across diverse benchmark tasks
27
+ - **Hybrid Inference Support**: Supports both fast-thinking and slow-thinking modes, letting users choose flexibly
28
+ - **Ultra-Long Context Understanding**: Natively supports a 256K context window and maintains stable performance on long-text tasks
29
+ - **Enhanced Agent Capabilities**: Optimized for agent tasks, leading on agent benchmarks such as BFCL-v3 and τ-Bench
30
+ - **Efficient Inference**: Uses Grouped Query Attention (GQA) and supports multiple quantization formats for highly efficient inference
31
+
32
+
33
+ ### Why Choose Hunyuan-A13B?
34
+ As a large model that combines strong performance with computational efficiency, Hunyuan-A13B is an ideal choice for researchers and developers pursuing high performance under resource constraints. Whether for academic research, cost-effective AI solution development, or exploring innovative applications, the model provides a solid foundation.
35
+
36
+
37
+ &nbsp;
38
+
39
+ ## News
40
+ <br>
41
+
42
+ * 2025.6.26 We open-sourced **Hunyuan-A13B-Instruct**, **Hunyuan-A13B-Pretrain**, **Hunyuan-A13B-Instruct-FP8**, and **Hunyuan-A13B-Instruct-GPTQ-Int4** on Hugging Face, and released a technical report and a training/inference operation guide detailing the model's capabilities and how to train and run it.
43
+
44
+ ## Model Architecture
45
+
46
+ Hunyuan-A13B adopts a fine-grained Mixture-of-Experts (fine-grained MoE) architecture with 80 billion total parameters and 13 billion active parameters, trained on more than 20T tokens. The model supports a 256K context length. Architecture details are as follows, with a configuration-inspection sketch after the list:
47
+ * Total parameters: 80B
48
+ * Active parameters: 13B
49
+ * Number of layers: 32
50
+ * Attention Heads: 32
51
+ * Shared experts: 1
52
+ * Non-shared experts: 64
53
+ * Routing strategy: Top-8
54
+ * Activation function: SwiGLU
55
+ * Hidden dimension: 4096
56
+ * Expert hidden dimension: 3072
57
+
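+ As a quick way to cross-check these hyperparameters against the released checkpoint, the configuration can be inspected directly. A sketch only; field names inside the custom config may differ from the names used above, so the whole object is simply printed:
+ 
+ ```python
+ from transformers import AutoConfig
+ 
+ # Load and print the model configuration shipped with the checkpoint;
+ # the MoE hyperparameters listed above should appear among its fields.
+ config = AutoConfig.from_pretrained("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
+ print(config)
+ ```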
58
+ ## Benchmark Results
59
+
60
+ **Hunyuan-A13B-Pretrain** outperforms Hunyuan-Large, the previous-generation Hunyuan MoE model with 52B active parameters, on 12 out of 14 tasks, confirming its strong pre-training capabilities. Compared with larger dense and MoE models in the industry, Hunyuan-A13B achieves the highest scores on several code and math tasks. On aggregate benchmarks such as MMLU and MMLU-Pro, Hunyuan-A13B reaches a level comparable to Qwen3-A22B, demonstrating excellent overall capability.
61
+
62
+ | Model | Hunyuan-Large | Qwen2.5-72B | Qwen3-A22B | Hunyuan-A13B |
63
+ |------------------|---------------|--------------|-------------|---------------|
64
+ | MMLU | 88.40 | 86.10 | 87.81 | 88.17 |
65
+ | MMLU-Pro | 60.20 | 58.10 | 68.18 | 67.23 |
66
+ | MMLU-Redux | 87.47 | 83.90 | 87.40 | 87.67 |
67
+ | BBH | 86.30 | 85.80 | 88.87 | 87.56 |
68
+ | SuperGPQA | 38.90 | 36.20 | 44.06 | 41.32 |
69
+ | EvalPlus | 75.69 | 65.93 | 77.60 | 78.64 |
70
+ | MultiPL-E | 59.13 | 60.50 | 65.94 | 69.33 |
71
+ | MBPP | 72.60 | 76.00 | 81.40 | 83.86 |
72
+ | CRUX-I | 57.00 | 57.63 | - | 70.13 |
73
+ | CRUX-O | 60.63 | 66.20 | 79.00 | 77.00 |
74
+ | MATH | 69.80 | 62.12 | 71.84 | 72.35 |
75
+ | CMATH | 91.30 | 84.80 | - | 91.17 |
76
+ | GSM8k | 92.80 | 91.50 | 94.39 | 91.83 |
77
+ | GPQA | 25.18 | 45.90 | 47.47 | 49.12 |
78
+
79
+ **Hunyuan-A13B-Instruct** achieves highly competitive results across a wide range of benchmarks, particularly in mathematics, science, and agent tasks. We compared it against several strong models; the results are shown below.
80
+
81
+ | Topic | Bench | OpenAI-o1-1217 | DeepSeek R1 | Qwen3-A22B | Hunyuan-A13B-Instruct |
82
+ |:-------------------:|:-----------------------------:|:-------------:|:------------:|:-----------:|:---------------------:|
83
+ | **Mathematics** | AIME 2024<br>AIME 2025<br>MATH | 74.3<br>79.2<br>96.4 | 79.8<br>70<br>94.9 | 85.7<br>81.5<br>94.0 | 87.3<br>76.8<br>94.3 |
84
+ | **Science** | GPQA-Diamond<br>OlympiadBench | 78<br>83.1 | 71.5<br>82.4 | 71.1<br>85.7 | 71.2<br>82.7 |
85
+ | **Coding** | Livecodebench<br>Fullstackbench<br>ArtifactsBench | 63.9<br>64.6<br>38.6 | 65.9<br>71.6<br>44.6 | 70.7<br>65.6<br>44.6 | 63.9<br>67.8<br>43 |
86
+ | **Reasoning** | BBH<br>DROP<br>ZebraLogic | 80.4<br>90.2<br>81 | 83.7<br>92.2<br>78.7 | 88.9<br>90.3<br>80.3 | 89.1<br>91.1<br>84.7 |
87
+ | **Instruction<br>Following** | IF-Eval<br>SysBench | 91.8<br>82.5 | 88.3<br>77.7 | 83.4<br>74.2 | 84.7<br>76.1 |
88
+ | **Text<br>Creation**| LengthCtrl<br>InsCtrl | 60.1<br>74.8 | 55.9<br>69 | 53.3<br>73.7 | 55.4<br>71.9 |
89
+ | **NLU** | ComplexNLU<br>Word-Task | 64.7<br>67.1 | 64.5<br>76.3 | 59.8<br>56.4 | 61.2<br>62.9 |
90
+ | **Agent** | BFCL v3<br> τ-Bench<br>ComplexFuncBench<br> $C^3$-Bench | 67.8<br>60.4<br>47.6<br>58.8 | 56.9<br>43.8<br>41.1<br>55.3 | 70.8<br>44.6<br>40.6<br>51.7 | 78.3<br>54.7<br>61.2<br>63.5 |
91
+
92
+
93
+ ## Inference and Deployment
94
+
95
+ HunyuanLLM can be deployed with vLLM, SGLang, or TensorRT-LLM. To simplify deployment, pre-built Docker images are provided.
96
+
97
+
98
+ ## Inference with TensorRT-LLM
99
+
100
+ ### BF16 Deployment
101
+
102
+ #### Step 1: Run Inference
103
+
104
+ #### Option 1: Command-line Inference
105
+
106
+ Below is a code snippet that uses `TensorRT-LLM` to quickly query the chat model.
107
+ Modify the following code in examples/pytorch/quickstart_advanced.py:
108
+
109
+
110
+ ```python
111
+ from tensorrt_llm import SamplingParams
112
+ from tensorrt_llm._torch import LLM
113
+ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
114
+ from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
115
+ MTPDecodingConfig)
116
+
117
+ prompt = "Write a short summary of the benefits of regular exercise"
118
+
119
+ def main():
120
+ args = parse_arguments()
121
+
122
+ llm, sampling_params = setup_llm(args)
123
+ new_prompts = []
124
+ if args.apply_chat_template:
125
+ messages = [{"role": "user", "content": f"{prompt}"}]
126
+ new_prompts.append(llm.tokenizer.apply_chat_template(
127
+ messages, tokenize=False, add_generation_prompt=True)
128
+ )
129
+
130
+ outputs = llm.generate(new_prompts, sampling_params)
131
+
132
+ for i, output in enumerate(outputs):
133
+ prompt = output.prompt
134
+ generated_text = output.outputs[0].text
135
+ print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
136
+ ```
137
+
138
+ Run it with:
139
+
140
+ ```shell
141
+ python3 quickstart_advanced.py --model_dir "<path-to-HunYuan-model>" --tp_size 4 --apply_chat_template
142
+ ```
143
+
144
+ #### Option 2: Serving
145
+
146
+ Below we show how to deploy the model as a service with `TensorRT-LLM` and send requests to it.
147
+
148
+ ```shell
149
+ model_path="<path-to-HunYuan-model>"
150
+ trtllm-serve <model_path> [--backend pytorch --tp_size <tp> --ep_size <ep> --host <host> --port <port>]
151
+ ```
152
+
153
+ Once the server has started successfully, run the request script:
154
+ ```python
155
+ ### OpenAI Chat Client
156
+
157
+ from openai import OpenAI
158
+
159
+ client = OpenAI(
160
+ base_url="http://localhost:8000/v1",
161
+ api_key="tensorrt_llm",
162
+ )
163
+
164
+ response = client.chat.completions.create(
165
+ model="default",
166
+ messages=[{
167
+ "role": "user",
168
+ "content": "Write a short summary of the benefits of regular exercise"
169
+ }],
170
+ max_tokens=4096,
171
+ )
172
+ print(response)
173
+ ```
174
+
175
+ #### FP8/Int4 Quantized Model Deployment
176
+ Support for FP8 and Int4 quantized models with TensorRT-LLM is currently in progress; stay tuned.
177
+
178
+
179
+ ## Inference with vLLM
180
+ ### Docker
181
+
182
+ To simplify deployment, HunyuanLLM provides a pre-built Docker image:
183
+
184
+ [hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm](https://hub.docker.com/r/hunyuaninfer/hunyuan-large/tags). Simply download the model files and start the Docker container with the commands below to begin running inference.
185
+ ```shell
186
+ # Pull the image
187
+ docker pull hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm
188
+ # Start the container
189
+ docker run --name hunyuanLLM_infer -itd --privileged --user root --net=host --ipc=host --gpus=8 hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm
190
+ ```
191
+
192
+ Note: Docker container privileges. The command above starts the container in privileged mode (--privileged), which grants it elevated permissions and increases the risk of data leakage and cluster security issues. Avoid privileged mode unless it is strictly necessary; where it is required, carry out a thorough security assessment and put appropriate monitoring and hardening measures in place.
193
+
194
+
195
+ ### BF16 Deployment
196
+
197
+ BF16 can be deployed on 2 GPUs, each with more than 80 GB of memory; for long-context workloads, TP4 is recommended. Proceed as follows:
198
+
199
+ Before running the commands, set the following environment variable:
200
+
201
+ ```shell
202
+ export MODEL_PATH=PATH_TO_MODEL
203
+ ```
204
+
205
+ #### Step 1: Run Inference
206
+
207
+ #### Option 1: Command-line Inference
208
+
209
+ Below is a code snippet that uses `vLLM` to quickly query the chat model.
210
+
211
+ Note: Remote code execution protection for the vLLM component. If the trust-remote-code option used below is enabled, vLLM may load and execute code from the remote model repository, which could lead to malicious code execution. Unless your use case explicitly requires it, keep this option disabled to reduce potential security risks.
212
+
213
+
214
+ ```python
215
+ import os
216
+ from typing import List, Optional
217
+ from vllm import LLM, SamplingParams
218
+ from vllm.inputs import PromptType
219
+ from transformers import AutoTokenizer
220
+
221
+ model_path=os.environ.get('MODEL_PATH')
222
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
223
+
224
+ llm = LLM(model=model_path,
225
+ tokenizer=model_path,
226
+ trust_remote_code=True,
227
+ dtype='bfloat16',
228
+ tensor_parallel_size=4,
229
+ gpu_memory_utilization=0.9)
230
+
231
+ sampling_params = SamplingParams(
232
+ temperature=0.7, top_p=0.8, max_tokens=4096, top_k=20, repetition_penalty=1.05)
233
+
234
+ messages = [
235
+ {
236
+ "role": "system",
237
+ "content": "You are a helpful assistant.",
238
+ },
239
+ {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
240
+ ]
241
+
242
+ tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
243
+
244
+ dummy_inputs: List[PromptType] = [{
245
+ "prompt_token_ids": batch
246
+ } for batch in tokenized_chat.numpy().tolist()]
247
+
248
+ outputs = llm.generate(dummy_inputs, sampling_params)
249
+
250
+ # Print the outputs.
251
+ for output in outputs:
252
+ prompt = output.prompt
253
+ generated_text = output.outputs[0].text
254
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
255
+ ```
256
+
257
+ #### Option 2: Serving-based inference
258
+
259
+ Below we show how to deploy the model as a service with `vLLM` and send requests to it.
260
+
261
+ On the master node, run:
262
+
263
+ ```shell
264
+ export VLLM_HOST_IP=${LOCAL_IP}
265
+ ```
266
+ Next, start the service by running:
267
+ ```shell
268
+ cd inference
269
+ sh run_server.sh
270
+ ```
271
+
272
+ After `run_server.sh` completes successfully, run the request script:
273
+ ```shell
274
+ sh openapi.sh
275
+ ```
276
+
277
+ Remember to replace `${LOCAL_IP}` and `${MODEL_PATH}` in `openapi.sh` with the values used by your service.
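+
+ For reference, the request in `openapi.sh` can also be issued from Python against the OpenAI-compatible endpoint exposed by vLLM. This is a minimal sketch, not the contents of `openapi.sh`: the port `8000`, the `EMPTY` API key and the model name are assumptions and should be adjusted to match your `run_server.sh` configuration.
+
+ ```python
+ import os
+
+ from openai import OpenAI
+
+ # Assumed endpoint; match host/port to the service started by run_server.sh.
+ client = OpenAI(
+     base_url=f"http://{os.environ.get('LOCAL_IP', 'localhost')}:8000/v1",
+     api_key="EMPTY",  # vLLM's OpenAI-compatible server does not verify the key by default
+ )
+
+ response = client.chat.completions.create(
+     model=os.environ.get("MODEL_PATH", "default"),  # assumed model name registered by the server
+     messages=[{"role": "user", "content": "Write a short summary of the benefits of regular exercise"}],
+     temperature=0.7,
+     top_p=0.8,
+     max_tokens=4096,
+ )
+ print(response.choices[0].message.content)
+ ```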
278
+
279
+
280
+ ### Quantized model deployment:
281
+
282
+ This section describes how to deploy quantized models with vLLM.
283
+
284
+ Image: use the same image as for the BF16 deployment.
285
+
286
+
287
+ #### Int8 quantized model deployment:
288
+ To deploy the Int8 weight-only version of the HunYuan-A13B model, simply set the environment variable in `run_server_int8.sh`:
289
+ ```shell
290
+ export MODEL_PATH=PATH_TO_BF16_MODEL
291
+ ```
292
+
293
+ Next, start the Int8 service by running:
294
+ ```shell
295
+ sh run_server_int8.sh
296
+ ```
297
+
298
+ After `run_server_int8.sh` completes successfully, run the request script:
299
+ ```shell
300
+ sh openapi.sh
301
+ ```
302
+
303
+ #### Int4 quantized model deployment:
304
+ To deploy the Int4 weight-only (GPTQ) version of the HunYuan-A13B model, simply set the environment variable in `run_server_int4.sh`:
305
+ ```shell
306
+ export MODEL_PATH=PATH_TO_INT4_MODEL
307
+ ```
308
+
309
+ Next, start the Int4 service by running:
310
+ ```shell
311
+ sh run_server_int4.sh
312
+ ```
313
+
314
+ After `run_server_int4.sh` completes successfully, run the request script:
315
+ ```shell
316
+ sh openapi.sh
317
+ ```
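+
+ As an alternative to the serving scripts, a GPTQ checkpoint can also be loaded with vLLM's offline `LLM` API. The sketch below is an illustration under the assumption that `MODEL_PATH` points to the Int4 (GPTQ) weights; it is not the content of `run_server_int4.sh`:
+
+ ```python
+ import os
+
+ from vllm import LLM, SamplingParams
+
+ model_path = os.environ.get("MODEL_PATH")  # assumed path to the GPTQ-quantized weights
+
+ llm = LLM(
+     model=model_path,
+     trust_remote_code=True,
+     quantization="gptq",        # load the checkpoint as GPTQ int4 weight-only
+     tensor_parallel_size=2,
+     gpu_memory_utilization=0.9,
+ )
+
+ sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, max_tokens=4096)
+ outputs = llm.generate(["Write a short summary of the benefits of regular exercise"], sampling_params)
+ print(outputs[0].outputs[0].text)
+ ```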
318
+
319
+ #### FP8 quantized model deployment:
320
+ To deploy the W8A8C8 (FP8) version of the HunYuan-A13B model, simply set the environment variable in `run_server_fp8.sh`:
321
+ ```shell
322
+ export MODEL_PATH=PATH_TO_FP8_MODEL
323
+ ```
324
+
325
+ Next, start the FP8 service by running:
326
+ ```shell
327
+ sh run_server_fp8.sh
328
+ ```
329
+
330
+ After `run_server_fp8.sh` completes successfully, run the request script:
331
+ ```shell
332
+ sh openapi.sh
333
+ ```
334
+
335
+ ### Performance evaluation:
336
+
337
+ This section reports throughput measurements for the models deployed with vLLM (original and quantized), i.e. inference speed (tokens/s) at different batch sizes. Test environment: Tencent Cloud, H80 (96G) GPU x number of GPUs:
338
+
339
+ Benchmark command:
340
+ ```shell
341
+ python3 benchmark_throughput.py --backend vllm \
342
+ --input-len 2048 \
343
+ --output-len 14336 \
344
+ --model $MODEL_PATH \
345
+ --tensor-parallel-size $TP \
346
+ --use-v2-block-manager \
347
+ --async-engine \
348
+ --trust-remote-code \
349
+ --num_prompts $BATCH_SIZE \
350
+ --max-num-seqs $BATCH_SIZE
351
+ ```
352
+
353
+ | Inference framework | Model | Number of GPUs | input_length | batch=1 | batch=16 | batch=32 |
354
+ |------|-----------------------------|-----------|-------------------------|---------------------|----------------------|----------------------|
355
+ | vLLM | Hunyuan-A13B-Instruct | 8 | 2048 | 190.84 | 1246.54 | 1981.99 |
356
+ | vLLM | Hunyuan-A13B-Instruct | 4 | 2048 | 158.90 | 779.10 | 1301.75 |
357
+ | vLLM | Hunyuan-A13B-Instruct | 2 | 2048 | 111.72 | 327.31 | 346.54 |
358
+ | vLLM | Hunyuan-A13B-Instruct(int8 weight only) | 2 | 2048 | 109.10 | 444.17 | 721.93 |
359
+ | vLLM | Hunyuan-A13B-Instruct(W8A8C8-FP8) | 2 | 2048 | 91.83 | 372.01 | 617.70 |
360
+ | vLLM | Hunyuan-A13B-Instruct(W8A8C8-FP8) | 1 | 2048 | 60.07 | 148.80 | 160.41 |
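+
+ The numbers above are aggregate throughput over the whole batch. As a rough back-of-the-envelope check (not part of the benchmark script), the per-request decode speed and the time to emit the 14336 output tokens used in the test can be estimated as follows:
+
+ ```python
+ def per_request_stats(total_tokens_per_s: float, batch_size: int, output_len: int = 14336):
+     """Split aggregate throughput evenly across the batch and estimate completion time."""
+     per_request = total_tokens_per_s / batch_size   # tokens/s seen by one request
+     seconds = output_len / per_request              # time for one request's full output
+     return per_request, seconds
+
+ # Example: BF16 model on 8 GPUs at batch=32 (1981.99 tokens/s aggregate).
+ speed, seconds = per_request_stats(1981.99, 32)
+ print(f"~{speed:.1f} tokens/s per request, ~{seconds / 60:.1f} minutes per request")
+ ```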
361
+
362
+
363
+ ## Inference with sglang
364
+
365
+ ### BF16 deployment
366
+
367
+ #### Step 1: Run inference
368
+
369
+ #### Option 1: Command-line inference
370
+
371
+ Below is a code snippet that uses `sglang` to quickly query the chat model:
372
+
373
+
374
+ ```python
375
+ import os
+ import sglang as sgl
376
+ from transformers import AutoTokenizer
377
+
378
+ model_path=os.environ.get('MODEL_PATH')
379
+
380
+
381
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
382
+
383
+ messages = [
384
+ {
385
+ "role": "system",
386
+ "content": "You are a helpful assistant.",
387
+ },
388
+ {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
389
+ ]
390
+ prompts = []
391
+ prompts.append(tokenizer.apply_chat_template(
392
+ messages,
393
+ tokenize=False,
394
+ add_generation_prompt=True
395
+ ))
396
+ print(prompts)
397
+
398
+ llm = sgl.Engine(
399
+ model_path=model_path,
400
+ tp_size=4,
401
+ trust_remote_code=True,
402
+ mem_fraction_static=0.7,
403
+ )
404
+
405
+ sampling_params = {"temperature": 0.7, "top_p": 0.8, "top_k": 20, "max_new_tokens": 4096}
406
+ outputs = llm.generate(prompts, sampling_params)
407
+ for prompt, output in zip(prompts, outputs):
408
+ print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
409
+ ```
410
+
411
+ #### Option 2: Serving-based inference
412
+
413
+ Below we show how to deploy the model as a service with `sglang` and send requests to it.
414
+
415
+ ```shell
416
+ model_path="path to the HunyuanLLM model"
417
+ python3 -u -m sglang.launch_server \
418
+ --model-path $model_path \
419
+ --tp 4 \
420
+ --trust-remote-code
421
+ ```
422
+
423
+ Once the service is up, run the request script:
424
+ ```python
425
+ import openai
426
+ client = openai.Client(
427
+ base_url="http://localhost:30000/v1", api_key="EMPTY")
428
+
429
+ response = client.chat.completions.create(
430
+ model="default",
431
+ messages= [
432
+ {"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
433
+ ],
434
+ temperature=0.7,
435
+ max_tokens=4096,
436
+ extra_body={"top_p": 0.8, "top_k": 20}
437
+ )
438
+ print(response)
439
+ ```
440
+
441
+ #### FP8/Int4 quantized model deployment:
442
+ FP8 and Int4 quantized models are not yet supported in sglang; support is in progress, stay tuned.
443
+
444
+ ## Interactive Web Demo
445
+ A web demo of Hunyuan-A13B is now available. Visit https://hunyuan.tencent.com/?model=hunyuan-a13b to try the model.
446
+
447
+ <br>
448
+
449
+ ## Citation
450
+ If you find our work helpful, feel free to cite our <a href="report/Hunyuan_A13B_Technical_Report.pdf">technical report</a>!
451
+
452
+ <br>
453
+
454
+
455
+ ## Contact Us
456
+ If you would like to leave a message for our R&D or product teams, you are welcome to contact the Tencent Hunyuan LLM team. You can reach us by email ([email protected]).
config.json ADDED
@@ -0,0 +1,203 @@
1
+ {
2
+ "add_classification_head": false,
3
+ "anyres_pooling_size": 2,
4
+ "anyres_vit_max_image_size": null,
5
+ "anyres_vit_two_views": false,
6
+ "architectures": [
7
+ "HunYuanMoEV1ForCausalLM"
8
+ ],
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.1,
11
+ "attention_head_dim": 128,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_hunyuan.HunYuanConfig",
14
+ "AutoModel": "hunyuan.HunYuanModel",
15
+ "AutoModelForCausalLM": "hunyuan.HunYuanMoEV1ForCausalLM"
16
+ },
17
+ "bos_token_id": 1,
18
+ "cla_share_factor": 2,
19
+ "class_num": 0,
20
+ "dense_list": [
21
+ 4096,
22
+ 0
23
+ ],
24
+ "eod_token_id": 127967,
25
+ "eos_token_id": 127960,
26
+ "group_limited_greedy": false,
27
+ "hidden_act": "silu",
28
+ "hidden_size": 4096,
29
+ "im_end_id": 6,
30
+ "im_newline_id": 12,
31
+ "im_start_id": 5,
32
+ "image_token_id": 9,
33
+ "initializer_range": 0.02,
34
+ "intermediate_size": 3072,
35
+ "kv_lora_rank": null,
36
+ "mask_init_id": 13,
37
+ "max_position_embeddings": 32768,
38
+ "mlp_bias": false,
39
+ "model_type": "hunyuan",
40
+ "moe_drop_tokens": false,
41
+ "moe_intermediate_size": [
42
+ 3072,
43
+ 3072,
44
+ 3072,
45
+ 3072,
46
+ 3072,
47
+ 3072,
48
+ 3072,
49
+ 3072,
50
+ 3072,
51
+ 3072,
52
+ 3072,
53
+ 3072,
54
+ 3072,
55
+ 3072,
56
+ 3072,
57
+ 3072,
58
+ 3072,
59
+ 3072,
60
+ 3072,
61
+ 3072,
62
+ 3072,
63
+ 3072,
64
+ 3072,
65
+ 3072,
66
+ 3072,
67
+ 3072,
68
+ 3072,
69
+ 3072,
70
+ 3072,
71
+ 3072,
72
+ 3072,
73
+ 3072
74
+ ],
75
+ "moe_layer_num_skipped": 0,
76
+ "moe_random_routing_dropped_token": false,
77
+ "moe_topk": [
78
+ 8,
79
+ 8,
80
+ 8,
81
+ 8,
82
+ 8,
83
+ 8,
84
+ 8,
85
+ 8,
86
+ 8,
87
+ 8,
88
+ 8,
89
+ 8,
90
+ 8,
91
+ 8,
92
+ 8,
93
+ 8,
94
+ 8,
95
+ 8,
96
+ 8,
97
+ 8,
98
+ 8,
99
+ 8,
100
+ 8,
101
+ 8,
102
+ 8,
103
+ 8,
104
+ 8,
105
+ 8,
106
+ 8,
107
+ 8,
108
+ 8,
109
+ 8
110
+ ],
111
+ "n_group": null,
112
+ "norm_topk_prob": true,
113
+ "norm_type": "rms",
114
+ "num_attention_heads": 32,
115
+ "num_experts": 64,
116
+ "num_hidden_layers": 32,
117
+ "num_key_value_heads": 8,
118
+ "num_media_embeds": 257,
119
+ "num_shared_expert": [
120
+ 1,
121
+ 1,
122
+ 1,
123
+ 1,
124
+ 1,
125
+ 1,
126
+ 1,
127
+ 1,
128
+ 1,
129
+ 1,
130
+ 1,
131
+ 1,
132
+ 1,
133
+ 1,
134
+ 1,
135
+ 1,
136
+ 1,
137
+ 1,
138
+ 1,
139
+ 1,
140
+ 1,
141
+ 1,
142
+ 1,
143
+ 1,
144
+ 1,
145
+ 1,
146
+ 1,
147
+ 1,
148
+ 1,
149
+ 1,
150
+ 1,
151
+ 1
152
+ ],
153
+ "org_vocab_size": 128167,
154
+ "pad_id": 127961,
155
+ "pad_token_id": 127961,
156
+ "pool_type": "last",
157
+ "position_embedding_xdrope": false,
158
+ "pretraining_tp": 1,
159
+ "q_lora_rank": null,
160
+ "qk_nope_head_dim": null,
161
+ "qk_rope_head_dim": null,
162
+ "rms_norm_eps": 1e-05,
163
+ "rope_scaling": {
164
+ "alpha": 1000.0,
165
+ "beta_fast": 32,
166
+ "beta_slow": 1,
167
+ "factor": 1.0,
168
+ "mscale": 1.0,
169
+ "mscale_all_dim": 1.0,
170
+ "type": "dynamic"
171
+ },
172
+ "rope_theta": 10000.0,
173
+ "routed_scaling_factor": 1.0,
174
+ "sep_token_id": 127962,
175
+ "skip_cls_token": false,
176
+ "text_end_id": 8,
177
+ "text_start_id": 7,
178
+ "tie_word_embeddings": true,
179
+ "topk_group": null,
180
+ "torch_dtype": "bfloat16",
181
+ "transformers_version": "4.41.2",
182
+ "use_cache": true,
183
+ "use_cla": false,
184
+ "use_mixed_mlp_moe": true,
185
+ "use_mla": false,
186
+ "use_qk_norm": true,
187
+ "use_rotary_pos_emb": true,
188
+ "v_head_dim": null,
189
+ "video_end_id": 11,
190
+ "video_start_id": 10,
191
+ "vit_add_patchemb_bias": false,
192
+ "vit_input_resolution": 224,
193
+ "vit_mapping_type": "resampler",
194
+ "vit_norm_type": "fused",
195
+ "vit_patch": 1,
196
+ "vit_path": null,
197
+ "vit_remove_prenorm": false,
198
+ "vit_token": 64,
199
+ "vit_type": null,
200
+ "vit_used_rms_norm": false,
201
+ "vocab_size": 128167,
202
+ "xdrope_section": null
203
+ }
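The configuration above can be inspected programmatically once the repository files are downloaded. A minimal sketch, assuming a local copy at `./Hunyuan-A13B` (the path is an assumption):

```python
from transformers import AutoConfig

# trust_remote_code is needed because config.json maps AutoConfig to configuration_hunyuan.HunYuanConfig.
config = AutoConfig.from_pretrained("./Hunyuan-A13B", trust_remote_code=True)

# A few of the MoE-related fields defined above.
print(config.num_hidden_layers)     # 32 decoder layers
print(config.num_experts)           # 64 routed experts per MoE layer
print(config.moe_topk[0])           # 8 experts activated per token
print(config.num_shared_expert[0])  # 1 shared expert per layer
```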
configuration_hunyuan.py ADDED
@@ -0,0 +1,319 @@
1
+ # coding=utf-8
2
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
3
+ """ HunYuan model configuration"""
4
+ from torch import nn
5
+ from transformers.configuration_utils import PretrainedConfig
6
+ from transformers.utils import logging
7
+ from typing import List, Union, Optional
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class HunYuanConfig(PretrainedConfig):
14
+ r"""
15
+ This is the configuration class to store the configuration of a [`HunYuanModel`]. It is used to instantiate an
16
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
17
+ with the defaults will yield a similar configuration to that of the HunYuan-7B.
18
+
19
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
20
+ documentation from [`PretrainedConfig`] for more information.
21
+
22
+
23
+ Args:
24
+ vocab_size (`int`, *optional*, defaults to 32000):
25
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
26
+ `inputs_ids` passed when calling [`HunYuanModel`]
27
+ hidden_size (`int`, *optional*, defaults to 4096):
28
+ Dimension of the hidden representations.
29
+ intermediate_size (`int`, *optional*, defaults to 11008):
30
+ Dimension of the MLP representations or shared MLP representations.
31
+ moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008):
32
+ Dimension of the MLP representations in MoE. Use a list if you want a different size per layer.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer decoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer decoder.
37
+ num_key_value_heads (`int`, *optional*):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details checkout [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
44
+ `num_attention_heads`.
45
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
46
+ The non-linear activation function (function or string) in the decoder.
47
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
48
+ The maximum sequence length that this model might ever be used with.
49
+ initializer_range (`float`, *optional*, defaults to 0.02):
50
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
52
+ The epsilon used by the rms normalization layers.
53
+ use_cache (`bool`, *optional*, defaults to `True`):
54
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
55
+ relevant if `config.is_decoder=True`.
56
+ pad_token_id (`int`, *optional*):
57
+ Padding token id.
58
+ bos_token_id (`int`, *optional*, defaults to 1):
59
+ Beginning of stream token id.
60
+ eos_token_id (`int`, *optional*, defaults to 2):
61
+ End of stream token id.
62
+ pretraining_tp (`int`, *optional*, defaults to 1):
63
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
64
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
65
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
66
+ issue](https://github.com/pytorch/pytorch/issues/76232).
67
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
68
+ Whether to tie weight embeddings
69
+ rope_theta (`float`, *optional*, defaults to 10000.0):
70
+ The base period of the RoPE embeddings.
71
+ rope_scaling (`Dict`, *optional*):
72
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
73
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
74
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
75
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
76
+ these scaling strategies behave:
77
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
78
+ experimental feature, subject to breaking API changes in future versions.
79
+ attention_bias (`bool`, *optional*, defaults to `False`):
80
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
81
+ attention_dropout (`float`, *optional*, defaults to 0.0):
82
+ The dropout ratio for the attention probabilities.
83
+ use_qk_norm (`bool`, *optional*, defaults to `False`):
84
+ Whether to apply normalization to the query and key states in attention
85
+ use_cla (`bool`, *optional*, defaults to `False`):
86
+ Whether to use CLA in attention
87
+ cla_share_factor (`int`, *optional*, defaults to 1):
88
+ The share factor of CLA
89
+ num_experts (`int` or `List`, *optional*, defaults to 1):
90
+ The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
91
+ num_shared_expert (`int` or `List`, *optional*, defaults to 1):
92
+ The number of shared experts for moe. If it is a list, it will be used as the number of shared experts for each layer.
93
+ moe_topk (`int` or `List`, *optional*, defaults to 1):
94
+ The topk value for moe. If it is a list, it will be used as the topk value for each layer.
95
+ capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0):
96
+ The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer.
97
+ moe_layer_num_skipped (`int`, *optional*, defaults to 0):
98
+ First moe_layer_num_skipped layers do not use MoE.
99
+ """
100
+
101
+ model_type = "hunyuan"
102
+ keys_to_ignore_at_inference = ["past_key_values"]
103
+
104
+ def __init__(
105
+ self,
106
+ vocab_size=290943,
107
+ org_vocab_size=290943,
108
+ hidden_size=4096,
109
+ intermediate_size: int=11008,
110
+ moe_intermediate_size: Union[int, List]=None,
111
+ num_hidden_layers=32,
112
+ num_attention_heads=32,
113
+ num_key_value_heads=None,
114
+ attention_head_dim=None,
115
+ hidden_act="silu",
116
+ max_position_embeddings=2048,
117
+ initializer_range=0.02,
118
+ rms_norm_eps=1e-5,
119
+ use_cache=True,
120
+ pad_token_id=0,
121
+ bos_token_id=1,
122
+ eos_token_id=2,
123
+ eod_token_id=3,
124
+ sep_token_id=4,
125
+ im_start_id=5,
126
+ im_end_id=6,
127
+ text_start_id=7,
128
+ text_end_id=8,
129
+ image_token_id=9,
130
+ video_start_id=10,
131
+ video_end_id=11,
132
+ im_newline_id=12,
133
+ mask_init_id=13,
134
+ pretraining_tp=1,
135
+ tie_word_embeddings=False,
136
+ rope_theta=10000.0,
137
+ rope_scaling=None,
138
+ attention_bias=False,
139
+ mlp_bias=False,
140
+ attention_dropout=0.0,
141
+ use_qk_norm=False,
142
+ use_rotary_pos_emb=True,
143
+ use_cla=False,
144
+ cla_share_factor=1,
145
+ norm_type="hf_rms",
146
+ num_experts: Union[int, List]=1,
147
+ use_mixed_mlp_moe=False,
148
+ num_shared_expert: Union[int, List]=1,
149
+ moe_topk: Union[int, List]=1,
150
+ # capacity_factor: Union[int, List]=1.0,
151
+ moe_drop_tokens=False,
152
+ moe_random_routing_dropped_token=False,
153
+ use_mla=False,
154
+ kv_lora_rank=512,
155
+ q_lora_rank=1536,
156
+ qk_rope_head_dim=64,
157
+ v_head_dim=128,
158
+ qk_nope_head_dim=128,
159
+ moe_layer_num_skipped=0,
160
+ norm_topk_prob=True,
161
+ routed_scaling_factor=1.0,
162
+ group_limited_greedy=False,
163
+ n_group=None,
164
+ topk_group=None,
165
+ vit_path=None,
166
+ num_media_embeds=257,
167
+ vit_type="AnyResVit",
168
+ vit_input_resolution=224,
169
+ vit_token=64,
170
+ vit_patch=1,
171
+ vit_mapping_type="simple_conv_mlp",
172
+ vit_norm_type="fused",
173
+ vit_used_rms_norm=True,
174
+ vit_remove_prenorm=True,
175
+ vit_add_patchemb_bias=True,
176
+ anyres_vit_max_image_size=2048,
177
+ anyres_pooling_size=2,
178
+ anyres_vit_two_views=False,
179
+ skip_cls_token=False,
180
+ position_embedding_xdrope=False,
181
+ xdrope_section=None,
182
+ add_classification_head=False,
183
+ class_num=0,
184
+ pool_type="last",
185
+ pad_id=-1,
186
+ **kwargs,
187
+ ):
188
+ self.vocab_size = vocab_size
189
+ self.org_vocab_size = org_vocab_size
190
+ self.max_position_embeddings = max_position_embeddings
191
+ self.hidden_size = hidden_size
192
+ self.intermediate_size = intermediate_size
193
+ self.moe_intermediate_size = moe_intermediate_size
194
+ self.num_hidden_layers = num_hidden_layers
195
+ self.num_attention_heads = num_attention_heads
196
+ self.num_experts = num_experts
197
+ self.use_mixed_mlp_moe = use_mixed_mlp_moe
198
+ self.num_shared_expert = num_shared_expert
199
+ self.moe_topk = moe_topk
200
+ # self.capacity_factor = capacity_factor
201
+ self.moe_drop_tokens = moe_drop_tokens
202
+ self.moe_random_routing_dropped_token = moe_random_routing_dropped_token
203
+
204
+ if attention_head_dim is not None:
205
+ self.attention_head_dim = attention_head_dim
206
+ else:
207
+ self.attention_head_dim = self.hidden_size // num_attention_heads
208
+
209
+ # for backward compatibility
210
+ if num_key_value_heads is None:
211
+ num_key_value_heads = num_attention_heads
212
+
213
+ self.num_key_value_heads = num_key_value_heads
214
+ self.hidden_act = hidden_act
215
+ self.initializer_range = initializer_range
216
+ self.rms_norm_eps = rms_norm_eps
217
+ self.pretraining_tp = pretraining_tp
218
+ self.use_cache = use_cache
219
+ self.rope_theta = rope_theta
220
+ self.rope_scaling = rope_scaling
221
+ # self._rope_scaling_validation() # TODO: Need validation?
222
+ self.attention_bias = attention_bias
223
+ self.mlp_bias = mlp_bias
224
+ self.attention_dropout = attention_dropout
225
+ self.use_qk_norm = use_qk_norm
226
+ self.use_rotary_pos_emb = use_rotary_pos_emb
227
+ self.use_cla = use_cla
228
+ self.cla_share_factor = cla_share_factor
229
+ self.norm_type = norm_type
230
+ # MLA args
231
+ self.use_mla = use_mla
232
+ self.kv_lora_rank = kv_lora_rank
233
+ self.q_lora_rank = q_lora_rank
234
+ self.qk_rope_head_dim = qk_rope_head_dim
235
+ self.qk_nope_head_dim = qk_nope_head_dim
236
+ self.v_head_dim = v_head_dim
237
+
238
+ # DeepSeek related args
239
+ self.moe_layer_num_skipped = moe_layer_num_skipped
240
+ self.norm_topk_prob = norm_topk_prob
241
+ self.routed_scaling_factor = routed_scaling_factor
242
+ self.group_limited_greedy = group_limited_greedy
243
+ self.n_group = n_group
244
+ self.topk_group = topk_group
245
+ self.add_classification_head = add_classification_head
246
+ self.class_num = class_num
247
+ self.pool_type = pool_type
248
+ self.pad_id = pad_id
249
+
250
+ if self.class_num is not None:
251
+ self.dense_list = [self.hidden_size, self.class_num]
252
+
253
+ # Vit args
254
+ self.vit_path = vit_path
255
+ self.num_media_embeds = num_media_embeds
256
+ self.vit_type = vit_type
257
+ self.vit_input_resolution = vit_input_resolution
258
+ self.vit_token = vit_token
259
+ self.vit_patch = vit_patch
260
+ self.vit_mapping_type = vit_mapping_type
261
+ self.vit_norm_type = vit_norm_type
262
+ self.vit_used_rms_norm = vit_used_rms_norm
263
+ self.vit_remove_prenorm = vit_remove_prenorm
264
+ self.vit_add_patchemb_bias = vit_add_patchemb_bias
265
+ self.anyres_vit_max_image_size = anyres_vit_max_image_size
266
+ self.anyres_pooling_size = anyres_pooling_size
267
+ self.anyres_vit_two_views = anyres_vit_two_views
268
+ self.skip_cls_token = skip_cls_token
269
+ self.position_embedding_xdrope = position_embedding_xdrope
270
+ self.xdrope_section = xdrope_section
271
+
272
+ # token id
273
+ self.eod_token_id = eod_token_id
274
+ self.im_start_id = im_start_id
275
+ self.im_end_id = im_end_id
276
+ self.text_start_id = text_start_id
277
+ self.text_end_id = text_end_id
278
+ self.image_token_id = image_token_id
279
+ self.video_start_id = video_start_id
280
+ self.video_end_id = video_end_id
281
+ self.im_newline_id = im_newline_id
282
+ self.mask_init_id = mask_init_id
283
+
284
+ super().__init__(
285
+ pad_token_id=pad_token_id,
286
+ bos_token_id=bos_token_id,
287
+ eos_token_id=eos_token_id,
288
+ sep_token_id=sep_token_id,
289
+ tie_word_embeddings=tie_word_embeddings,
290
+ **kwargs,
291
+ )
292
+
293
+ def _rope_scaling_validation(self):
294
+ """
295
+ Validate the `rope_scaling` configuration.
296
+ """
297
+ if self.rope_scaling is None:
298
+ return
299
+
300
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
301
+ raise ValueError(
302
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor` or `type` and `alpha`, "
303
+ f"got {self.rope_scaling}"
304
+ )
305
+ rope_scaling_type = self.rope_scaling.get("type", None)
306
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
307
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
308
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
309
+ raise ValueError(
310
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
311
+ )
312
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
313
+ raise ValueError("`rope_scaling`'s factor or alpha field must be have one, got both of none")
314
+ if rope_scaling_factor is not None:
315
+ if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
316
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
317
+ if rope_scaling_alpha is not None:
318
+ if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
319
+ raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "eos_token_id": [127960, 127967],
3
+ "pad_token_id": 127961,
4
+ "do_sample": true,
5
+ "top_k": 20,
6
+ "top_p": 0.8,
7
+ "repetition_penalty": 1.05,
8
+ "temperature": 0.7,
9
+ "transformers_version": "4.31.0"
10
+ }
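These sampling defaults (`do_sample`, `top_k`, `top_p`, `temperature`, `repetition_penalty`) are picked up automatically by `model.generate` when the checkpoint is loaded with Transformers. A minimal sketch, assuming a local copy at `./Hunyuan-A13B` and enough GPU memory (or offloading) for the bfloat16 MoE weights:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./Hunyuan-A13B"  # assumed local path to this repository

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)

messages = [{"role": "user", "content": "Write a short summary of the benefits of regular exercise"}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Sampling settings come from generation_config.json unless overridden here.
outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```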
hunyuan.py ADDED
@@ -0,0 +1,879 @@
1
+ # coding=utf-8
2
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
3
+ #
4
+ """ PyTorch HunYuan model."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16
+
17
+ from transformers.activations import ACT2FN
18
+ from transformers.cache_utils import Cache, DynamicCache
19
+ from transformers.modeling_attn_mask_utils import (
20
+ AttentionMaskConverter,
21
+ _prepare_4d_attention_mask,
22
+ _prepare_4d_causal_attention_mask,
23
+ _prepare_4d_causal_attention_mask_for_sdpa,
24
+ )
25
+ from transformers.modeling_outputs import (
26
+ BaseModelOutputWithPast,
27
+ CausalLMOutputWithPast,
28
+ SequenceClassifierOutputWithPast
29
+ )
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
32
+ from transformers.utils import (
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_flash_attn_2_available,
36
+ is_flash_attn_greater_or_equal_2_10,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers.utils.import_utils import is_torch_fx_available
41
+ from transformers.generation.utils import GenerateOutput
42
+ from .configuration_hunyuan import HunYuanConfig
43
+ from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
44
+ from .vit_model import NaVitForward, VitForward, Vit
45
+
46
+
47
+ if is_flash_attn_2_available():
48
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
49
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
50
+
51
+
52
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
53
+ # It means that the function will not be traced through and simply appear as a node in the graph.
54
+ if is_torch_fx_available():
55
+ if not is_torch_greater_or_equal_than_1_13:
56
+ import torch.fx
57
+
58
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
59
+
60
+
61
+
62
+ _CONFIG_FOR_DOC = "HunYuanConfig"
63
+
64
+
65
+ HUNYUAN_START_DOCSTRING = r"""
66
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
67
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
68
+ etc.)
69
+
70
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
71
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
72
+ and behavior.
73
+
74
+ Parameters:
75
+ config ([`HunYuanConfig`]):
76
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
77
+ load the weights associated with the model, only the configuration. Check out the
78
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
79
+ """
80
+
81
+
82
+ @add_start_docstrings(
83
+ "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
84
+ HUNYUAN_START_DOCSTRING,
85
+ )
86
+ class HunYuanPreTrainedModel(PreTrainedModel):
87
+ config_class = HunYuanConfig
88
+ base_model_prefix = "model"
89
+ supports_gradient_checkpointing = True
90
+ _no_split_modules = ["HunYuanDecoderLayer"]
91
+ _skip_keys_device_placement = "past_key_values"
92
+ _supports_flash_attn_2 = True
93
+ _supports_sdpa = True
94
+ _supports_cache_class = True
95
+
96
+ def _init_weights(self, module):
97
+ std = self.config.initializer_range
98
+ if isinstance(module, nn.Linear):
99
+ module.weight.data.normal_(mean=0.0, std=std)
100
+ if module.bias is not None:
101
+ module.bias.data.zero_()
102
+ elif isinstance(module, nn.Embedding):
103
+ module.weight.data.normal_(mean=0.0, std=std)
104
+ if module.padding_idx is not None:
105
+ module.weight.data[module.padding_idx].zero_()
106
+
107
+
108
+ HUNYUAN_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
111
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
112
+ it.
113
+
114
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
115
+ [`PreTrainedTokenizer.__call__`] for details.
116
+
117
+ [What are input IDs?](../glossary#input-ids)
118
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
119
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
120
+
121
+ - 1 for tokens that are **not masked**,
122
+ - 0 for tokens that are **masked**.
123
+
124
+ [What are attention masks?](../glossary#attention-mask)
125
+
126
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
127
+ [`PreTrainedTokenizer.__call__`] for details.
128
+
129
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
130
+ `past_key_values`).
131
+
132
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
133
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
134
+ information on the default strategy.
135
+
136
+ - 1 indicates the head is **not masked**,
137
+ - 0 indicates the head is **masked**.
138
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
139
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
140
+ config.n_positions - 1]`.
141
+
142
+ [What are position IDs?](../glossary#position-ids)
143
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
144
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
145
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
146
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
147
+
148
+ Two formats are allowed:
149
+ - a [`~cache_utils.Cache`] instance;
150
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
151
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
152
+ cache format.
153
+
154
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
155
+ legacy cache format will be returned.
156
+
157
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
158
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
159
+ of shape `(batch_size, sequence_length)`.
160
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
161
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
162
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
163
+ model's internal embedding lookup matrix.
164
+ use_cache (`bool`, *optional*):
165
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
166
+ `past_key_values`).
167
+ output_attentions (`bool`, *optional*):
168
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
169
+ tensors for more detail.
170
+ output_hidden_states (`bool`, *optional*):
171
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
172
+ more detail.
173
+ return_dict (`bool`, *optional*):
174
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
175
+ """
176
+
177
+
178
+ @add_start_docstrings(
179
+ "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
180
+ HUNYUAN_START_DOCSTRING,
181
+ )
182
+ class HunYuanModel(HunYuanPreTrainedModel):
183
+ """
184
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
185
+
186
+ Args:
187
+ config: HunYuanConfig
188
+ """
189
+
190
+ def __init__(self, config: HunYuanConfig):
191
+ super().__init__(config)
192
+ self.padding_idx = config.pad_token_id
193
+ self.vocab_size = config.vocab_size
194
+ self.add_classification_head = config.add_classification_head
195
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
196
+ self.layers = nn.ModuleList(
197
+ [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
198
+ )
199
+ self._use_sdpa = config._attn_implementation == "sdpa"
200
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
201
+ if not config.add_classification_head:
202
+ self.norm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
203
+
204
+ self.cla = config.use_cla
205
+ self.cla_share_factor = config.cla_share_factor
206
+
207
+ self.gradient_checkpointing = False
208
+ # Initialize weights and apply final processing
209
+ self.post_init()
210
+
211
+ def get_input_embeddings(self):
212
+ return self.embed_tokens
213
+
214
+ def set_input_embeddings(self, value):
215
+ self.embed_tokens = value
216
+
217
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
218
+ def forward(
219
+ self,
220
+ input_ids: torch.LongTensor = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ position_ids: Optional[torch.LongTensor] = None,
223
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
224
+ inputs_embeds: Optional[torch.FloatTensor] = None,
225
+ use_cache: Optional[bool] = None,
226
+ output_attentions: Optional[bool] = None,
227
+ output_hidden_states: Optional[bool] = None,
228
+ return_dict: Optional[bool] = None,
229
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
230
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
231
+ output_hidden_states = (
232
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
233
+ )
234
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
235
+
236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
237
+
238
+ # retrieve input_ids and inputs_embeds
239
+ # if input_ids is not None and inputs_embeds is not None:
240
+ # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
241
+ if input_ids is not None:
242
+ batch_size, seq_length = input_ids.shape[:2]
243
+ elif inputs_embeds is not None:
244
+ batch_size, seq_length = inputs_embeds.shape[:2]
245
+ else:
246
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
247
+
248
+ if self.gradient_checkpointing and self.training:
249
+ if use_cache:
250
+ logger.warning_once(
251
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
252
+ )
253
+ use_cache = False
254
+
255
+ past_key_values_length = 0
256
+ if use_cache:
257
+ use_legacy_cache = not isinstance(past_key_values, Cache)
258
+ if use_legacy_cache:
259
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
260
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
261
+
262
+ if position_ids is None:
263
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
264
+ position_ids = torch.arange(
265
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
266
+ )
267
+ position_ids = position_ids.unsqueeze(0)
268
+
269
+ if inputs_embeds is None:
270
+ inputs_embeds = self.embed_tokens(input_ids)
271
+
272
+ # Fix lora with gradient checkpointing training
273
+ if self.training and inputs_embeds.is_leaf:
274
+ inputs_embeds.requires_grad = True
275
+
276
+ if self._use_flash_attention_2:
277
+ # 2d mask is passed through the layers
278
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
279
+ elif self._use_sdpa and not output_attentions:
280
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
281
+ # the manual implementation that requires a 4D causal mask in all cases.
282
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
283
+ attention_mask,
284
+ (batch_size, seq_length),
285
+ inputs_embeds,
286
+ past_key_values_length,
287
+ )
288
+ else:
289
+ # 4d mask is passed through the layers
290
+ attention_mask = _prepare_4d_causal_attention_mask(
291
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
292
+ )
293
+
294
+ # embed positions
295
+ hidden_states = inputs_embeds
296
+
297
+ # decoder layers
298
+ all_hidden_states = () if output_hidden_states else None
299
+ all_self_attns = () if output_attentions else None
300
+ next_decoder_cache = None
301
+
302
+ prev_kv_states = None
303
+ for layer_idx, decoder_layer in enumerate(self.layers):
304
+ if output_hidden_states:
305
+ all_hidden_states += (hidden_states,)
306
+
307
+ if self.gradient_checkpointing and self.training:
308
+ layer_outputs = self._gradient_checkpointing_func(
309
+ decoder_layer.__call__,
310
+ hidden_states,
311
+ attention_mask,
312
+ position_ids,
313
+ past_key_values,
314
+ output_attentions,
315
+ use_cache,
316
+ prev_kv_states,
317
+ )
318
+ else:
319
+ layer_outputs = decoder_layer(
320
+ hidden_states,
321
+ attention_mask=attention_mask,
322
+ position_ids=position_ids,
323
+ past_key_value=past_key_values,
324
+ output_attentions=output_attentions,
325
+ use_cache=use_cache,
326
+ kv_states=prev_kv_states
327
+ )
328
+
329
+ hidden_states = layer_outputs[0]
330
+
331
+ if use_cache:
332
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
333
+
334
+ if output_attentions:
335
+ all_self_attns += (layer_outputs[1],)
336
+
337
+ kv_states = layer_outputs[-1]
338
+
339
+ if self.cla and layer_idx % self.cla_share_factor == 0:
340
+ prev_kv_states = kv_states
341
+ if not self.add_classification_head:
342
+ hidden_states = self.norm(hidden_states)
343
+
344
+ # add hidden states from the last decoder layer
345
+ if output_hidden_states:
346
+ all_hidden_states += (hidden_states,)
347
+
348
+ next_cache = None
349
+ if use_cache:
350
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
351
+ if not return_dict:
352
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
353
+ return BaseModelOutputWithPast(
354
+ last_hidden_state=hidden_states,
355
+ past_key_values=next_cache,
356
+ hidden_states=all_hidden_states,
357
+ attentions=all_self_attns,
358
+ )
359
+
360
+
361
+ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
362
+ _tied_weights_keys = ["lm_head.weight"]
363
+
364
+ def __init__(self, config: HunYuanConfig):
365
+ super().__init__(config)
366
+ if config.vit_path is not None:
367
+ if "-tp" in config.vit_type:
368
+ config.vit_type = config.vit_type.replace("-tp", "")
369
+ self.vit_type = config.vit_type
370
+ if self.vit_type not in ['NaVit', 'EvaVit']:
371
+ if config.vit_mapping_type == 'mlp':
372
+ self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
373
+ self.vit = Vit(config)
374
+ else:
375
+ self.vit = None
376
+ self.config = config
377
+ self.model = HunYuanModel(config)
378
+ self.add_classification_head = config.add_classification_head
379
+ self.pad_id = config.pad_id
380
+ self.vocab_size = config.vocab_size
381
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
382
+ if config.add_classification_head:
383
+ self.pool_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
384
+ self.pool_head2 = nn.Linear(config.hidden_size, config.class_num, bias=False)
385
+ # Initialize weights and apply final processing
386
+ self.post_init()
387
+
388
+ def get_input_embeddings(self):
389
+ return self.model.embed_tokens
390
+
391
+ def set_input_embeddings(self, value):
392
+ self.model.embed_tokens = value
393
+
394
+ def get_output_embeddings(self):
395
+ return self.lm_head
396
+
397
+ def set_output_embeddings(self, new_embeddings):
398
+ self.lm_head = new_embeddings
399
+
400
+ def set_decoder(self, decoder):
401
+ self.model = decoder
402
+
403
+ def get_decoder(self):
404
+ return self.model
405
+
406
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
407
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
408
+ def forward(
409
+ self,
410
+ input_ids: torch.LongTensor = None,
411
+ attention_mask: Optional[torch.Tensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
414
+ inputs_embeds: Optional[torch.FloatTensor] = None,
415
+ labels: Optional[torch.LongTensor] = None,
416
+ use_cache: Optional[bool] = None,
417
+ output_attentions: Optional[bool] = None,
418
+ output_hidden_states: Optional[bool] = None,
419
+ return_dict: Optional[bool] = None,
420
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
421
+ r"""
422
+ Args:
423
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
424
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
425
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
426
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
427
+
428
+ Returns:
429
+
430
+ Example:
431
+
432
+ ```python
433
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
434
+
435
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
436
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
437
+
438
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
439
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
440
+
441
+ >>> # Generate
442
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
443
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
444
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
445
+ ```"""
446
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
447
+ output_hidden_states = (
448
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
449
+ )
450
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
451
+
452
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
453
+ outputs = self.model(
454
+ input_ids=input_ids,
455
+ attention_mask=attention_mask,
456
+ position_ids=position_ids,
457
+ past_key_values=past_key_values,
458
+ inputs_embeds=inputs_embeds,
459
+ use_cache=use_cache,
460
+ output_attentions=output_attentions,
461
+ output_hidden_states=output_hidden_states,
462
+ return_dict=return_dict,
463
+ )
464
+
465
+ hidden_states = outputs[0]
466
+
467
+ if not self.add_classification_head:
468
+ if self.config.pretraining_tp > 1:
469
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
470
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
471
+ logits = torch.cat(logits, dim=-1)
472
+ else:
473
+ logits = self.lm_head(hidden_states)
474
+ logits = logits.float()
475
+ else:
476
+ logits = hidden_states
477
+ logits = logits.float()
478
+ pooled_output = self.pool_head(logits)
479
+ pooled_output = torch.tanh(pooled_output)
480
+ pooled_output = self.pool_head2(pooled_output).contiguous() # bs * class_num
481
+ if len(pooled_output.shape) < 2:
482
+ raise ValueError("pooled_output does not have enough dimensions for transpose")
483
+
484
+ if self.config.pool_type == "mean":
485
+ reward = pooled_output.mean(dim=1).squeeze(-1)
486
+ elif self.config.pool_type == "last":
487
+ # bs * hidden_size
488
+ seq_length = (input_ids != self.pad_id).long().sum(dim=1) - 1
489
+ batch_size = input_ids.size(0)
490
+ reward = pooled_output[torch.arange(batch_size, device=pooled_output.device), seq_length].squeeze(-1)
491
+ else:
492
+ reward = pooled_output[:, 0].squeeze(-1)
493
+
494
+ loss = None
495
+ if labels is not None:
496
+ # Shift so that tokens < n predict n
497
+ shift_logits = logits[..., :-1, :].contiguous()
498
+ shift_labels = labels[..., 1:].contiguous()
499
+ # Flatten the tokens
500
+ loss_fct = CrossEntropyLoss()
501
+ shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
502
+ shift_labels = shift_labels.reshape(-1)
503
+ # Enable model parallelism
504
+ shift_labels = shift_labels.to(shift_logits.device)
505
+ loss = loss_fct(shift_logits, shift_labels)
506
+
507
+ if not return_dict:
508
+ output = (logits,) + outputs[1:]
509
+ return (loss,) + output if loss is not None else output
510
+
511
+ output = CausalLMOutputWithPast(
512
+ loss=loss,
513
+ logits=logits,
514
+ past_key_values=outputs.past_key_values,
515
+ hidden_states=outputs.hidden_states,
516
+ attentions=outputs.attentions,
517
+ )
518
+ if self.add_classification_head:
519
+ output['reward'] = reward
520
+
521
+ return output
522
+
523
+ def prepare_inputs_for_generation(
524
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
525
+ ):
526
+ if past_key_values is not None:
527
+ if isinstance(past_key_values, Cache):
528
+ cache_length = past_key_values.get_seq_length()
529
+ past_length = past_key_values.seen_tokens
530
+ max_cache_length = past_key_values.get_max_cache_shape()
531
+ else:
532
+ cache_length = past_length = past_key_values[0][0].shape[2]
533
+ max_cache_length = None
534
+
535
+ # Keep only the unprocessed tokens:
536
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
537
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
538
+ # input)
539
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
540
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
541
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
542
+ # input_ids based on the past_length.
543
+ elif past_length < input_ids.shape[1]:
544
+ input_ids = input_ids[:, past_length:]
545
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
546
+
547
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
548
+ if (
549
+ max_cache_length is not None
550
+ and attention_mask is not None
551
+ and cache_length + input_ids.shape[1] > max_cache_length
552
+ ):
553
+ attention_mask = attention_mask[:, -max_cache_length:]
554
+
555
+ position_ids = kwargs.get("position_ids", None)
556
+ if attention_mask is not None and position_ids is None:
557
+ # create position_ids on the fly for batch generation
558
+ position_ids = attention_mask.long().cumsum(-1) - 1
559
+ position_ids.masked_fill_(attention_mask == 0, 1)
560
+ if past_key_values:
561
+ position_ids = position_ids[:, -input_ids.shape[1]:]
562
+
563
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
564
+ if inputs_embeds is not None and past_key_values is None:
565
+ model_inputs = {"inputs_embeds": inputs_embeds}
566
+ else:
567
+ model_inputs = {"input_ids": input_ids}
568
+
569
+ model_inputs.update(
570
+ {
571
+ "position_ids": position_ids,
572
+ "past_key_values": past_key_values,
573
+ "use_cache": kwargs.get("use_cache"),
574
+ "attention_mask": attention_mask,
575
+ }
576
+ )
577
+ return model_inputs
578
+
579
+ @staticmethod
580
+ def _reorder_cache(past_key_values, beam_idx):
581
+ reordered_past = ()
582
+ for layer_past in past_key_values:
583
+ reordered_past += (
584
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
585
+ )
586
+ return reordered_past
587
+
588
+
589
+ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
590
+ _tied_weights_keys = ["lm_head.weight"]
591
+
592
+ def __init__(self, config: HunYuanConfig):
593
+ super().__init__(config)
594
+
595
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
596
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
597
+ def forward(
598
+ self,
599
+ input_ids: torch.LongTensor = None,
600
+ attention_mask: Optional[torch.Tensor] = None,
601
+ position_ids: Optional[torch.LongTensor] = None,
602
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
603
+ inputs_embeds: Optional[torch.FloatTensor] = None,
604
+ labels: Optional[torch.LongTensor] = None,
605
+ imgs: Optional[List[torch.FloatTensor]] = None,
606
+ imgs_pos: Optional[List[int]] = None,
607
+ use_cache: Optional[bool] = None,
608
+ output_attentions: Optional[bool] = None,
609
+ output_hidden_states: Optional[bool] = None,
610
+ return_dict: Optional[bool] = None,
611
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
612
+ r"""
613
+ Args:
614
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
615
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
616
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
617
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
618
+
619
+ Returns:
620
+
621
+ Example:
622
+
623
+ ```python
624
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
625
+
626
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
627
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
628
+
629
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
630
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
631
+
632
+ >>> # Generate
633
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
634
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
635
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
636
+ ```"""
637
+ mask_init_id = self.config.mask_init_id
638
+ pad_id = self.config.pad_token_id
639
+ eod_id = self.config.eod_token_id
640
+ image_token_id = self.config.image_token_id
641
+ im_start_id = self.config.im_start_id
642
+ im_end_id = self.config.im_end_id
643
+ video_start_id = self.config.video_start_id
644
+ video_end_id = self.config.video_end_id
645
+
646
+ if self.vit is not None and imgs is not None:
647
+ encoder_input = self.model.embed_tokens(input_ids)
648
+ if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
649
+ inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
650
+ im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
651
+ else:
652
+ inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
653
+ self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
654
+
655
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
656
+ output_hidden_states = (
657
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
658
+ )
659
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
661
+
662
+ outputs = self.model(
663
+ input_ids=input_ids,
664
+ attention_mask=attention_mask,
665
+ position_ids=position_ids,
666
+ past_key_values=past_key_values,
667
+ inputs_embeds=inputs_embeds,
668
+ use_cache=use_cache,
669
+ output_attentions=output_attentions,
670
+ output_hidden_states=output_hidden_states,
671
+ return_dict=return_dict,
672
+ )
673
+
674
+ hidden_states = outputs[0]
675
+ if self.config.pretraining_tp > 1:
676
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
677
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
678
+ logits = torch.cat(logits, dim=-1)
679
+ else:
680
+ logits = self.lm_head(hidden_states)
681
+ logits = logits.float()
682
+
683
+ loss = None
684
+ if labels is not None:
685
+ labels = labels.to(logits.device)
686
+ # Shift so that tokens < n predict n
687
+ shift_logits = logits
688
+ shift_labels = labels
689
+ # Flatten the tokens
690
+ loss_fct = CrossEntropyLoss()
691
+ shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
692
+ shift_labels = shift_labels.reshape(-1)
693
+ shift_tokens = input_ids.reshape(-1)
694
+ # compute loss only on real text positions: drop labels for special/visual tokens and positions whose input token is padding or EOD
695
+ mask = (shift_labels < mask_init_id) & (shift_labels != pad_id) & (shift_labels != image_token_id) & (shift_labels != im_start_id) \
696
+ & (shift_labels != im_end_id) & (shift_labels != video_start_id) & (shift_labels != video_end_id) & (shift_tokens != pad_id) & (shift_tokens != eod_id)
697
+ shift_logits = shift_logits[mask, :]
698
+ shift_labels = shift_labels[mask]
699
+ loss = loss_fct(shift_logits, shift_labels)
700
+
701
+ if not return_dict:
702
+ output = (logits,) + outputs[1:]
703
+ return (loss,) + output if loss is not None else output
704
+
705
+ return CausalLMOutputWithPast(
706
+ loss=loss,
707
+ logits=logits,
708
+ past_key_values=outputs.past_key_values,
709
+ hidden_states=outputs.hidden_states,
710
+ attentions=outputs.attentions,
711
+ )
712
+
713
+ def prepare_inputs_for_generation(
714
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
715
+ ):
716
+ imgs = kwargs.pop("imgs", None)
717
+ imgs_pos = kwargs.pop("imgs_pos", None)
718
+ inputs = super().prepare_inputs_for_generation(
719
+ input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
720
+ )
721
+
722
+ if imgs is not None:
723
+ inputs['imgs'] = imgs
724
+ if imgs_pos is not None:
725
+ inputs['imgs_pos'] = imgs_pos
726
+ return inputs
727
+
728
+ @torch.no_grad()
729
+ def generate(
730
+ self,
731
+ inputs: Optional[torch.Tensor] = None,
732
+ attention_mask: Optional[torch.Tensor] = None,
733
+ position_ids: Optional[torch.LongTensor] = None,
734
+ imgs: Optional[List[torch.FloatTensor]] = None,
735
+ imgs_pos: Optional[List[int]] = None,
736
+ **kwargs,
737
+ ) -> Union[GenerateOutput, torch.LongTensor]:
738
+ if "inputs_embeds" in kwargs:
739
+ raise NotImplementedError("`inputs_embeds` is not supported")
740
+
741
+ input_ids, inputs_embeds = inputs, None  # fall back to text-only inputs when no vision tower is attached
+ if self.vit is not None:
742
+ encoder_input = self.model.embed_tokens(inputs)
743
+ if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
744
+ inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
745
+ self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
746
+ else:
747
+ inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
748
+ self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
749
+
750
+ return super().generate(
751
+ inputs=input_ids,
752
+ position_ids=position_ids,
753
+ attention_mask=attention_mask,
754
+ inputs_embeds=inputs_embeds,
755
+ eos_token_id=self.config.eod_token_id,
756
+ **kwargs
757
+ )
758
+
759
+
760
+ @add_start_docstrings(
761
+ """
762
+ The HunYuan Model transformer with a sequence classification head on top (linear layer).
763
+
764
+ [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
765
+ (e.g. GPT-2) do.
766
+
767
+ Since it does classification on the last token, it needs to know the position of the last token. If a
768
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
769
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
770
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
771
+ each row of the batch).
772
+ """,
773
+ HUNYUAN_START_DOCSTRING,
774
+ )
775
+ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
776
+ def __init__(self, config):
777
+ super().__init__(config)
778
+ self.num_labels = config.num_labels
779
+ self.model = HunYuanModel(config)
780
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
781
+
782
+ # Initialize weights and apply final processing
783
+ self.post_init()
784
+
785
+ def get_input_embeddings(self):
786
+ return self.model.embed_tokens
787
+
788
+ def set_input_embeddings(self, value):
789
+ self.model.embed_tokens = value
790
+
791
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
792
+ def forward(
793
+ self,
794
+ input_ids: torch.LongTensor = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ position_ids: Optional[torch.LongTensor] = None,
797
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
798
+ inputs_embeds: Optional[torch.FloatTensor] = None,
799
+ labels: Optional[torch.LongTensor] = None,
800
+ use_cache: Optional[bool] = None,
801
+ output_attentions: Optional[bool] = None,
802
+ output_hidden_states: Optional[bool] = None,
803
+ return_dict: Optional[bool] = None,
804
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
805
+ r"""
806
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
807
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
808
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
809
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
810
+ """
811
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
812
+
813
+ transformer_outputs = self.model(
814
+ input_ids,
815
+ attention_mask=attention_mask,
816
+ position_ids=position_ids,
817
+ past_key_values=past_key_values,
818
+ inputs_embeds=inputs_embeds,
819
+ use_cache=use_cache,
820
+ output_attentions=output_attentions,
821
+ output_hidden_states=output_hidden_states,
822
+ return_dict=return_dict,
823
+ )
824
+ hidden_states = transformer_outputs[0]
825
+ logits = self.score(hidden_states)
826
+
827
+ if input_ids is not None:
828
+ batch_size = input_ids.shape[0]
829
+ else:
830
+ batch_size = inputs_embeds.shape[0]
831
+
832
+ if self.config.pad_token_id is None and batch_size != 1:
833
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
834
+ if self.config.pad_token_id is None:
835
+ sequence_lengths = -1
836
+ else:
837
+ if input_ids is not None:
838
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
839
+ logits.device
840
+ )
841
+ else:
842
+ sequence_lengths = -1
843
+
844
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
845
+
846
+ loss = None
847
+ if labels is not None:
848
+ labels = labels.to(logits.device)
849
+ if self.config.problem_type is None:
850
+ if self.num_labels == 1:
851
+ self.config.problem_type = "regression"
852
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
853
+ self.config.problem_type = "single_label_classification"
854
+ else:
855
+ self.config.problem_type = "multi_label_classification"
856
+
857
+ if self.config.problem_type == "regression":
858
+ loss_fct = MSELoss()
859
+ if self.num_labels == 1:
860
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
861
+ else:
862
+ loss = loss_fct(pooled_logits, labels)
863
+ elif self.config.problem_type == "single_label_classification":
864
+ loss_fct = CrossEntropyLoss()
865
+ loss = loss_fct(pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1))
866
+ elif self.config.problem_type == "multi_label_classification":
867
+ loss_fct = BCEWithLogitsLoss()
868
+ loss = loss_fct(pooled_logits, labels)
869
+ if not return_dict:
870
+ output = (pooled_logits,) + transformer_outputs[1:]
871
+ return ((loss,) + output) if loss is not None else output
872
+
873
+ return SequenceClassifierOutputWithPast(
874
+ loss=loss,
875
+ logits=pooled_logits,
876
+ past_key_values=transformer_outputs.past_key_values,
877
+ hidden_states=transformer_outputs.hidden_states,
878
+ attentions=transformer_outputs.attentions,
879
+ )
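
For reference, a minimal sketch (not part of the checkpoint) of the last-token pooling that the `HunYuanForSequenceClassification` docstring above describes; the `pad_token_id` value and the tensors are made up for illustration:

```python
# Illustration only: how the last non-padding token is selected for pooling.
import torch

pad_token_id = 0                                   # assumed value, not taken from the config
input_ids = torch.tensor([[5, 7, 9, 0, 0],         # padded row
                          [3, 4, 6, 8, 2]])        # row without padding
logits = torch.randn(2, 5, 3)                      # (batch, seq_len, num_labels)

# index of the first pad token minus one == position of the last real token;
# a row with no padding yields argmax 0 -> -1, i.e. its last position is used
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths)       # tensor([ 2, -1])
print(pooled_logits.shape)    # torch.Size([2, 3])
```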
hy.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94e14569aa3cb80bc04f48d943f40cc11d6734680315189e9502f9b9c9bee038
3
+ size 40963281008
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8f4d6ac61f5d2cffcfdc54aa22a249bf003b3a72514f08befddf8cc516a47bc
3
+ size 39938503616
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e72f67b554a3105b4bfea7a3670c31d746aafa0c6bbba0da81ca29b2fc26354e
3
+ size 39938503688
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:228fa272d42ef720223fc1596bc31b35e46f394857f04d5182598eb715699aa2
3
+ size 39963677928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenization_hy.py ADDED
@@ -0,0 +1,298 @@
1
+ import base64
2
+ import logging
3
+ import os
4
+ import unicodedata
5
+ from typing import Collection, Dict, List, Set, Tuple, Union
6
+
7
+ import tiktoken
8
+ from transformers import PreTrainedTokenizer, AddedToken
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ VOCAB_FILES_NAMES = {"vocab_file": "hy.tiktoken"}
14
+
15
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
16
+ # PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
17
+ ENDOFTEXT = "<|endoftext|>"
18
+ STARTOFTEXT = "<|startoftext|>"
19
+ BOSTOKEN = "<|bos|>"
20
+ EOSTOKEN = "<|eos|>"
21
+ PADTOKEN = "<|pad|>"
22
+
23
+ # as the default behavior is changed to allow special tokens in
24
+ # regular texts, the surface forms of special tokens need to be
25
+ # as different as possible to minimize the impact
26
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
27
+ # changed to use actual index to avoid misconfiguration with vocabulary expansion
28
+
29
+
30
+ SPECIAL_START_ID = 127957
31
+
32
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
33
+ # with open(tiktoken_bpe_file, "rb", encoding="utf-8") as f:
34
+ # contents = f.read()
35
+ dic = {}
36
+ rank = 0
37
+ for line in open(tiktoken_bpe_file, "rb"):
38
+ if line:
39
+ token, _ = line.split()
40
+ if base64.b64decode(token) in dic:
41
+ continue
42
+ dic[base64.b64decode(token)] = int(rank)
43
+ rank += 1
44
+ global SPECIAL_START_ID
45
+ SPECIAL_START_ID=rank
46
+ return dic
47
+
48
+ # NOTE: run the commented-out lines below to check that `SPECIAL_START_ID` is correct; loading the vocab file updates it to the actual number of BPE ranks
49
+ # _load_tiktoken_bpe('/apdcephfs/share_1502809/shaneshu/tokenizer_exp/other_tokenizer_vocab/hy/' + VOCAB_FILES_NAMES['vocab_file'])
50
+ # print(SPECIAL_START_ID)
51
+
52
+ SPECIAL_TOKENS = tuple(
53
+ enumerate(
54
+ (
55
+ (
56
+ ENDOFTEXT,
57
+ STARTOFTEXT,
58
+ BOSTOKEN,
59
+ EOSTOKEN,
60
+ PADTOKEN,
61
+ )
62
+ + EXTRAS
63
+ ),
64
+ start=SPECIAL_START_ID,
65
+ )
66
+ )
67
+ # NOTE: Unused Token ID starts from 127962
68
+ SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
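+ # Illustration (assuming the default SPECIAL_START_ID of 127957): the five named special tokens occupy
+ # ids 127957-127961 and the 205 <|extra_i|> tokens occupy ids 127962-128166; loading the vocab file
+ # recomputes SPECIAL_START_ID from the actual number of BPE ranks, shifting this layout accordingly.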
69
+
70
+ class HYTokenizer(PreTrainedTokenizer):
71
+ """hunyuan tokenizer."""
72
+
73
+ vocab_files_names = VOCAB_FILES_NAMES
74
+
75
+ def __init__(
76
+ self,
77
+ vocab_file,
78
+ errors="replace",
79
+ extra_vocab_file=None,
80
+ **kwargs,
81
+ ):
82
+ super().__init__(**kwargs)
83
+
84
+ # how to handle errors in decoding UTF-8 byte sequences
85
+ # use ignore if you are in streaming inference
86
+ self.errors = errors
87
+
88
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
89
+ self.special_tokens = {
90
+ token: index
91
+ for index, token in SPECIAL_TOKENS
92
+ }
93
+
94
+ # try load extra vocab from file
95
+ if extra_vocab_file is not None:
96
+ used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
97
+ extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
98
+ for token, index in extra_mergeable_ranks.items():
99
+ if token in self.mergeable_ranks:
100
+ logger.info(f"extra token {token} exists, skipping")
101
+ continue
102
+ if index in used_ids:
103
+ logger.info(f'the index {index} for extra token {token} exists, skipping')
104
+ continue
105
+ self.mergeable_ranks[token] = index
106
+ # the index may be sparse after this, but don't worry: tiktoken.Encoding will handle it
107
+
108
+ enc = tiktoken.Encoding(
109
+ "HunYuan",
110
+ pat_str=PAT_STR,
111
+ mergeable_ranks=self.mergeable_ranks,
112
+ special_tokens=self.special_tokens,
113
+ )
114
+ assert (
115
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
116
+ ), f"{len(self.mergeable_ranks)} + {len(self.special_tokens)} != {enc.n_vocab} in encoding"
117
+
118
+ self.decoder = {
119
+ v: k for k, v in self.mergeable_ranks.items()
120
+ } # type: dict[int, bytes|str]
121
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
122
+
123
+ self.tokenizer = enc # type: tiktoken.Encoding
124
+
125
+ self.eod_id = self.tokenizer.eot_token
126
+ self.bod_id = self.special_tokens[STARTOFTEXT]
127
+ self.bos_id = self.special_tokens[BOSTOKEN]
128
+ self.eos_id = self.special_tokens[EOSTOKEN]
129
+ self.pad_id = self.special_tokens[PADTOKEN]
130
+
131
+ def __getstate__(self):
132
+ # for pickle lovers
133
+ state = self.__dict__.copy()
134
+ del state["tokenizer"]
135
+ return state
136
+
137
+ def __setstate__(self, state):
138
+ # tokenizer is not python native; don't pass it; rebuild it
139
+ self.__dict__.update(state)
140
+ enc = tiktoken.Encoding(
141
+ "HunYuan",
142
+ pat_str=PAT_STR,
143
+ mergeable_ranks=self.mergeable_ranks,
144
+ special_tokens=self.special_tokens,
145
+ )
146
+ self.tokenizer = enc
147
+
148
+ def __len__(self) -> int:
149
+ return self.tokenizer.n_vocab
150
+
151
+ def get_vocab(self) -> Dict[bytes, int]:
152
+ return self.mergeable_ranks
153
+
154
+ def convert_tokens_to_ids(
155
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
156
+ ) -> List[int]:
157
+ ids = []
158
+ if isinstance(tokens, (str, bytes)):
159
+ if tokens in self.special_tokens:
160
+ return self.special_tokens[tokens]
161
+ else:
162
+ return self.mergeable_ranks.get(tokens)
163
+ for token in tokens:
164
+ if token in self.special_tokens:
165
+ ids.append(self.special_tokens[token])
166
+ else:
167
+ ids.append(self.mergeable_ranks.get(token))
168
+ return ids
169
+
170
+ def _add_tokens(
171
+ self,
172
+ new_tokens: Union[List[str], List[AddedToken]],
173
+ special_tokens: bool = False,
174
+ ) -> int:
175
+ if not special_tokens and new_tokens:
176
+ raise ValueError("Adding regular tokens is not supported")
177
+ for token in new_tokens:
178
+ surface_form = token.content if isinstance(token, AddedToken) else token
179
+ if surface_form not in SPECIAL_TOKENS_SET:
180
+ raise ValueError("Adding unknown special tokens is not supported")
181
+ return 0
182
+
183
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
184
+ """
185
+ Save only the vocabulary of the tokenizer (the BPE ranks).
186
+ Returns:
187
+ `Tuple(str)`: Paths to the files saved.
188
+ """
189
+ file_path = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])  # keep the filename consistent with vocab_files_names so the saved tokenizer can be reloaded
190
+ with open(file_path, "w", encoding="utf-8") as w:
191
+ for k, v in self.mergeable_ranks.items():
192
+ line = base64.b64encode(k).decode("utf-8") + " " + str(v) + "\n"
193
+ w.write(line)
194
+ return (file_path,)
195
+
196
+ def tokenize(
197
+ self,
198
+ text: str,
199
+ allowed_special: Union[Set, str] = "all",
200
+ disallowed_special: Union[Collection, str] = (),
201
+ **kwargs,
202
+ ) -> List[Union[bytes, str]]:
203
+ """
204
+ Converts a string into a sequence of tokens.
205
+ Args:
206
+ text (`str`):
207
+ The sequence to be encoded.
208
+ allowed_special (`Literal["all"]` or `set`):
209
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
210
+ Default to "all".
211
+ disallowed_special (`Literal["all"]` or `Collection`):
212
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
213
+ Defaults to an empty tuple.
214
+ kwargs (additional keyword arguments, *optional*):
215
+ Will be passed to the underlying model specific encode method.
216
+ Returns:
217
+ `List[bytes|str]`: The list of tokens.
218
+ """
219
+ tokens = []
220
+ text = unicodedata.normalize("NFC", text)
221
+
222
+ # this implementation takes a detour: text -> token id -> token surface forms
223
+ for t in self.tokenizer.encode(
224
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
225
+ ):
226
+ tokens.append(self.decoder[t])
227
+ return tokens
228
+
229
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
230
+ """
231
+ Converts a sequence of tokens into a single string.
232
+ """
233
+ text = ""
234
+ temp = b""
235
+ for t in tokens:
236
+ if isinstance(t, str):
237
+ if temp:
238
+ text += temp.decode("utf-8", errors=self.errors)
239
+ temp = b""
240
+ text += t
241
+ elif isinstance(t, bytes):
242
+ temp += t
243
+ else:
244
+ raise TypeError("token should only be of type bytes or str")
245
+ if temp:
246
+ text += temp.decode("utf-8", errors=self.errors)
247
+ return text
248
+
249
+ @property
250
+ def vocab_size(self):
251
+ return self.tokenizer.n_vocab
252
+
253
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
254
+ """Converts an id to a token, special tokens included"""
255
+ if index in self.decoder:
256
+ return self.decoder[index]
257
+ raise ValueError("unknown ids")
258
+
259
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
260
+ """Converts a token to an id using the vocab, special tokens included"""
261
+ if token in self.special_tokens:
262
+ return self.special_tokens[token]
263
+ if token in self.mergeable_ranks:
264
+ return self.mergeable_ranks[token]
265
+ raise ValueError("unknown token")
266
+
267
+ def _tokenize(self, text: str, **kwargs):
268
+ """
269
+ Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
270
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
271
+ Do NOT take care of added tokens.
272
+ """
273
+ raise NotImplementedError
274
+
275
+ def _decode(
276
+ self,
277
+ token_ids: Union[int, List[int]],
278
+ skip_special_tokens: bool = False,
279
+ errors: str = None,
280
+ **kwargs,
281
+ ) -> str:
282
+ if isinstance(token_ids, int):
283
+ token_ids = [token_ids]
284
+ if skip_special_tokens:
285
+ token_ids = [i for i in token_ids if i < self.eod_id]
286
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
287
+
288
+ # tests
289
+ if __name__ == "__main__":
290
+ tokenizer = HYTokenizer.from_pretrained('./hy')
291
+ text = '你好,世界'
292
+ tokens = tokenizer.tokenize(text)
293
+ print(tokens)
294
+ ids = tokenizer.convert_tokens_to_ids(tokens)
295
+ print(ids)
296
+ text2 = tokenizer.convert_tokens_to_string(tokens)
297
+ print(text2)
298
+ ids2 = tokenizer.convert_tokens_to_ids(tokens)
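
A hedged usage sketch for `HYTokenizer` above, assuming it is loaded through the `auto_map` entry declared in `tokenizer_config.json`; `PATH_TO_CONVERTED_TOKENIZER` is a placeholder path:

```python
from transformers import AutoTokenizer

# trust_remote_code is required so that tokenization_hy.HYTokenizer is picked up via auto_map
tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True)

ids = tokenizer("你好,世界").input_ids        # tokenize + convert_tokens_to_ids under the hood
text = tokenizer.decode(ids)                   # byte-level pieces are re-joined and UTF-8 decoded
print(ids, text)
```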
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "architectures": [
3
+ "GPT2LMHeadModel"
4
+ ],
5
+ "model_max_length": 1048576,
6
+ "tokenizer_class": "HYTokenizer",
7
+ "auto_map": {
8
+ "AutoTokenizer": [
9
+ "tokenization_hy.HYTokenizer",
10
+ null
11
+ ]
12
+ },
13
+ "eos_token": "<|eos|>",
14
+ "model_type": "gpt2",
15
+ "additional_special_tokens": ["<|startoftext|>", "<|extra_0|>", "<|extra_4|>", "<|extra_5|>", "<|eos|>"],
16
+ "pad_token": "<|pad|>",
17
+ "chat_template": "{% set loop_messages = messages %}\n{% if tools %}\n {% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %}\n {% set weekday_cn = weekday_map[strftime_now('%A')] %}\n {% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %}\n {% set datetime_str = datetime_str + ' ' + weekday_cn %}\n {% for message in loop_messages %}\n {% if 'content' in message %}\n {% set content = message['content'] %}\n {% else %}\n {% set content = '' %}\n {% endif %}\n {% if loop.index0 == 0 %}\n {% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。\n如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。\n如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。\n如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。\n你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。\n' %}\n {% set content_tmp = content_tmp + '\n' + tools | tojson + '\n' %}\n {% if message['role'] == 'system' %}\n {% set content_tmp = content_tmp + '\n额外要求:\n' + content + '\n\n如果你决定返回函数调用,请将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>,不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。\n否则,请参考开头说的三种情况,以助手:开头进行回复。\n\n如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %}\n {% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %}\n {% elif message['role'] == 'user' %}\n {% set content_tmp = content_tmp + '\n如果你决定返回函数调用,请将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>,不得包含其他文本。\n否则,请参考开头说的三种情况,以助手:开头进行回复。\n\n当前时间:' + datetime_str %}\n {% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%}\n {% set content = content_tmp + '用户:' + content + '<|extra_0|>' %}\n {% endif %}\n {% else %}\n {% if message['role'] == 'user' %}\n {% set content = '用户:' + content + '<|extra_0|>' %}\n {% elif message['role'] == 'assistant' %}\n {% if 'tool_calls' in message %}\n {% set tool_calls = message['tool_calls'] %}\n {% set ns = namespace(tool_calls=\"[\") %}\n {% for tool_call in tool_calls %}\n {% set function = tool_call['function'] %}\n {% set name = function['name'] %}\n {% set ns.tool_calls = ns.tool_calls + '{\"name\": \"' + name + '\", '%}\n {% set arguments = function['arguments'] %}\n {% if arguments is not string %}\n {% set arguments = arguments | tojson %}\n {% endif %}\n {% set ns.tool_calls = ns.tool_calls + '\"arguments\": ' + arguments + '}' %}\n {% if not loop.last %}\n {% set ns.tool_calls = ns.tool_calls + ', '%}\n {% endif %}\n {% endfor %}\n {% set ns.tool_calls = ns.tool_calls + ']' %}\n {% set content = content + '<tool_calls>' + ns.tool_calls + '</tool_calls>' %}\n {% else %}\n {% set content = '助手:' + content %}\n {% endif %}\n {% set content = content + '<|eos|>' %}\n {% elif message['role'] == 'tool' %}\n {% if content is not string %}\n {set content = content | tojson }\n {% endif %}\n {% set content = '<tool_response>' + content + '</tool_response>' %}\n {% set content = content + '<|extra_0|>' %}\n {% endif %}\n {% endif %}\n {{- content -}}\n {% endfor %}\n{% else %}\n {% set context = {'has_head': true} %}\n {% for message in loop_messages %}\n {% if 'content' in message %}\n {% set content = message['content'] %}\n {% else %}\n {% set content = '' %}\n {% endif %}\n {% if loop.index0 == 0 %}\n {% if 
content == '' %}\n {% set _ = context.update({'has_head': false}) %}\n {% elif message['role'] == 'system' %}\n {% set content = '<|startoftext|>' + content + '<|extra_4|>' %}\n {% endif %}\n {% endif %}\n {% if message['role'] == 'user' %}\n {% if loop.index0 == 1 and not context.has_head %}\n {% set content = '<|startoftext|>' + content %}\n {% endif %}\n {% if loop.index0 == 1 and context.has_head %}\n {% set content = content + '<|extra_0|>' %}\n {% else %}\n {% set content = '<|startoftext|>' + content + '<|extra_0|>' %}\n {% endif %}\n {% elif message['role'] == 'assistant' %}\n {% set content = content + '<|eos|>' %}\n {% elif message['role'] == 'tool' %}\n {% set content = content + '<|extra_0|>' %}\n {% endif %}\n {{- content -}}\n {% endfor %}\n{% endif %}\n{%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n' }}\n{%- endif %}"
18
+ }
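
A hedged sketch of how the tool-calling prompt defined by the `chat_template` above could be rendered; it assumes a transformers version whose `apply_chat_template` supports `tools=` and the `strftime_now` helper used in the template, and `PATH_TO_CONVERTED_TOKENIZER` is again a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True)

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",                      # hypothetical tool, for illustration only
        "description": "Query the weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]
messages = [{"role": "user", "content": "What is the weather in Shenzhen?"}]

prompt = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False)
print(prompt)  # the rendered prompt ends with <|extra_0|>; the model is expected to answer either in
               # natural language or with <tool_calls>[{"name": ..., "arguments": {...}}]</tool_calls>
```

Note that the template has two branches: a tool-calling branch driven by `tools`, and a plain chat branch; setting `enable_thinking=False` appends an empty `<think>\n\n</think>\n` block to the prompt.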
vit_model.py ADDED
@@ -0,0 +1,1083 @@
1
+ import json
2
+ import types
3
+ import math
4
+ import torch
5
+ from torch import Tensor, nn
6
+ import torch.nn.functional as F
7
+ from typing import List, Tuple, Optional, Union
8
+ from contextlib import contextmanager
9
+ from transformers.modeling_attn_mask_utils import (
10
+ _prepare_4d_causal_attention_mask_for_sdpa,
11
12
+ _prepare_4d_causal_attention_mask,
13
+ )
14
+ from transformers.models.clip.configuration_clip import CLIPVisionConfig
15
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
16
+ from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
17
+ from .configuration_hunyuan import HunYuanConfig
18
+
19
+
20
+ def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
21
+ # input_ids: (B, L)
22
+ # encoder_input: (L, B, E)
23
+ # image_tensors [[Tensor],...,[Tensor]]
24
+ # image_pos [[Tensor],...,[Tensor]]
25
+ # tokenizer = get_tokenizer()
26
+ b = len(input_ids)
27
+ img_embs = None
28
+ all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
29
+ if all_nums != 0:
30
+ img_embs, img_batch_pos = vit(image_tensors)
31
+ else:
32
+ # when no input image, initialize a fake tensor
33
+ pad_nums = 1
34
+ image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
35
+ img_embs, img_batch_pos = vit(image_tensors)
36
+
37
+ encoder_input = encoder_input.clone()
38
+ if all_nums > 0:
39
+ assert len(images_pos) == len(img_batch_pos), \
40
+ (len(images_pos), len(img_batch_pos))
41
+ start_token_id = im_start_id
42
+ end_token_id = im_end_id
43
+ placeholder_id = image_token_id
44
+ for idx in range(len(images_pos)):
45
+ assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
46
+ (len(images_pos[idx]), len(img_batch_pos[idx]))
47
+ for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
48
+ # the positions to be filled [s_start, s_end)
49
+ s_idx, s_start, s_end = p_img_pos_in_batch
50
+ current_embs = img_embs[s_idx, s_start:s_end]
51
+ im_s, im_e = p_batch_img_pos
52
+ assert len(current_embs) == im_e - im_s, \
53
+ (img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
54
+ if not anyres_vit_two_views:
55
+ assert input_ids[idx, im_s - 1] == start_token_id, \
56
+ input_ids[idx, im_s - 1]
57
+ assert input_ids[idx, im_e] == end_token_id, \
58
+ input_ids[idx, im_e]
59
+ assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
60
+ f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
61
+ encoder_input[idx, im_s:im_e] = current_embs
62
+ else:
63
+ # when there is no input image, blend in the fake vit embeddings with zero weight (keeps the graph connected without changing values)
64
+ vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
65
+ current_embs = img_embs[0, :]
66
+ encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
67
+ return encoder_input, input_ids
68
+
69
+
70
+ def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
71
+ vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
72
+
73
+ b = len(input_ids)
74
+ if images_pos is None:
75
+ images_pos = torch.ones([len(input_ids), 1, 3])
76
+ images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
77
+ images_pos = images_pos.long()
78
+
79
+ real_image_nums = []
80
+ image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
81
+ real_images = []
82
+
83
+ all_nums = 0
84
+ img_index = []
85
+ for s in range(len(images_pos)):
86
+ real_image_num = 0
87
+ for (im_s, im_e,index) in images_pos[s]:
88
+ if im_s == -1:
89
+ break
90
+ real_image_num += 1
91
+ all_nums += 1
92
+ img_index.append(index)
93
+
94
+ real_image_nums.append(real_image_num)
95
+ real_images.append(image_tensors[s][:real_image_num])
96
+
97
+ if vit_patch == 1:
98
+ img_index = None
99
+
100
+ if all_nums == 0:
101
+ # when no input image, initialize a fake tensor
102
+ img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
103
+ img_embs = vit(img_input)
104
+ img_embs = vit_linear_encoder(img_embs)
105
+ else:
106
+ img_input = torch.cat(real_images)
107
+ img_embs = vit(img_input, img_index = img_index)
108
+ img_embs = vit_linear_encoder(img_embs)
109
+
110
+ encoder_input = encoder_input.clone()
111
+ start = 0
112
+ if all_nums > 0:
113
+ for s, real_image_len in enumerate(real_image_nums):
114
+ current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
115
+ for ss in range(current_embs.shape[0]):
116
+ im_s, im_e, index = images_pos[s, ss]
117
+ # sub-images contribute fewer features
118
+ if index > 0 and vit_patch_mlp:
119
+ encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
120
+ else:
121
+ encoder_input[s, im_s:im_e] = current_embs[ss, :]
122
+ start = start + real_image_len
123
+ else:
124
+ # when there is no input image, blend in the fake vit embeddings with zero weight
125
+ for s in range(b):
126
+ vit_mask = torch.zeros([vit_token, 1]).cuda()
127
+ current_embs = img_embs[:, start:start + 1]
128
+ encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
129
+ start = start + 1
130
+ return encoder_input, input_ids
131
+
132
+
133
+ def group_images_by_max_seq_len(
134
+ images: List[List[Tensor]], patch_size: int,
135
+ max_seq_len: int, adaptor_patch_size: int,
136
+ add_cls_token: bool = False) -> List[List[Tensor]]:
137
+
138
+ groups = []
139
+ group = []
140
+ pos_groups = []
141
+ seq_len = 0
142
+ num_images = 0
143
+ for image_list in images:
144
+ pos_group = []
145
+ for image in image_list:
146
+ num_images += 1
147
+ assert isinstance(image, Tensor)
148
+
149
+ image_dims = image.shape[-2:]
150
+ ph, pw = map(lambda t: t // patch_size, image_dims)
151
+
152
+ image_seq_len = (ph * pw)
153
+ new_image_seq_len = image_seq_len
154
+ grouped_len = seq_len + image_seq_len
155
+ if add_cls_token:
156
+ new_image_seq_len += 1
157
+ grouped_len += num_images
158
+
159
+ assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
160
+
161
+ if grouped_len > max_seq_len:
162
+ groups.append(group)
163
+ group = []
164
+ seq_len = 0
165
+ num_images = 1
166
+
167
+ group.append(image)
168
+ start = seq_len // (adaptor_patch_size * adaptor_patch_size)
169
+ end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
170
+ batch_idx = len(groups)
171
+ pos_group.append([batch_idx, start, end])
172
+ seq_len += image_seq_len
173
+ pos_groups.append(pos_group)
174
+
175
+ if len(group) > 0:
176
+ groups.append(group)
177
+
178
+ return groups, pos_groups
179
+
180
+
181
+ class AnyResCLIPVisionEmbeddings(nn.Module):
182
+ def __init__(self, config: CLIPVisionConfig):
183
+ super().__init__()
184
+
185
+ self.config = config
186
+ # self.sparse_attn_mask = args.sparse_attn_mask
187
+ # self.use_flash_attn = args.use_flash_attn
188
+ self.embed_dim = config.hidden_size
189
+ self.image_size = config.max_image_size
190
+ self.patch_size = config.patch_size
191
+ self.max_seq_len = config.max_vit_seq_len
192
+ self.adaptor_patch_size = config.adaptor_patch_size
193
+ self.anyres_vit_two_views = config.anyres_vit_two_views
194
+ self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
195
+ self.vit_remove_prenorm = config.vit_remove_prenorm
196
+
197
+ self.patch_embedding = nn.Conv2d(
198
+ in_channels=config.num_channels,
199
+ out_channels=self.embed_dim,
200
+ kernel_size=self.patch_size,
201
+ stride=self.patch_size,
202
+ bias=self.vit_add_patchemb_bias,
203
+ )
204
+
205
+ self.num_patches = (self.image_size // self.patch_size) ** 2
206
+ self.skip_cls_token = True
207
+
208
+ # add interpolate_pos_encoding
209
+ if self.anyres_vit_two_views:
210
+ self.num_positions = self.num_patches
211
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
212
+ else:
213
+ self.num_positions = self.num_patches + 1
214
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
215
+ # self.position_ids = torch.arange(self.num_positions).expand((1, -1))
216
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
217
+
218
+ if not self.vit_remove_prenorm:
219
+ self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
220
+
221
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
222
+ """
223
+ This method interpolates the pre-trained position encodings so that the model can be used on
224
+ higher-resolution images.
225
+
226
+ Source:
227
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
228
+ """
229
+ num_patches = embeddings.shape[1]
230
+ position_embeddings = self.position_embedding(self.position_ids)
231
+ patch_pos_embed = position_embeddings[:, 1:]
232
+ num_positions = position_embeddings.shape[1] - 1
233
+ if num_patches == num_positions and height == width:
234
+ return patch_pos_embed
235
+ # class_pos_embed = position_embeddings[:, 0]
236
+ dim = embeddings.shape[-1]
237
+ h0 = height // self.patch_size
238
+ w0 = width // self.patch_size
239
+ # we add a small number to avoid floating point error in the interpolation
240
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
241
+ h0, w0 = h0 + 0.1, w0 + 0.1
242
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
243
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
244
+ raw_type = patch_pos_embed.dtype
245
+ patch_pos_embed = nn.functional.interpolate(
246
+ patch_pos_embed.to(torch.float32, non_blocking=True),
247
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
248
+ mode="bilinear",
249
+ align_corners=False,
250
+ )
251
+ patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
252
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
253
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
254
+ return patch_pos_embed
255
+
256
+ def rescale_positional_embedding(self, out_size):
257
+ h, w = out_size
258
+ pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
259
+ if (h, w) == (pos_embed_shape, pos_embed_shape):
260
+ return self.position_embedding
261
+ rescaled_positional_embedding = \
262
+ self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
263
+ pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
264
+ pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
265
+ rescaled_positional_embedding[0] = pe_2d.T.contiguous()
266
+ return rescaled_positional_embedding
267
+
268
+ def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
269
+ if pixel_values.ndim == 3:
270
+ pixel_values = pixel_values[None]
271
+ batch_size, num_channels, height, width = pixel_values.shape
272
+
273
+ if self.anyres_vit_two_views:
274
+ # padding
275
+ pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
276
+ pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
277
+ pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
278
+
279
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
280
+ b, c, h, w = patch_embeds.shape
281
+
282
+ # (b, hw, c)
283
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
284
+ if self.anyres_vit_two_views:
285
+ embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
286
+ else:
287
+ embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
288
+ if not self.vit_remove_prenorm:
289
+ embeddings = self.pre_layernorm(embeddings)
290
+ return embeddings, (h, w)
291
+
292
+ def forward(self, images: List[List[Tensor]]):
293
+ '''
294
+ Input:
295
+ images: List[List[Tensor]]
296
+
297
+ Return:
298
+ embeddings: Tensor (B, L, E)
299
+ attn_mask: Tensor (B, 1, L, L)
300
+ pos_groups: List[List[(batch_idx, start, end)]]
301
+ '''
302
+ batched_images, pos_groups = group_images_by_max_seq_len(
303
+ images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
304
+ max_seq_len = self.max_seq_len
305
+
306
+ # batched_images is a list of a list
307
+ B = len(batched_images)
308
+ L = max_seq_len
309
+ E = self.embed_dim
310
+
311
+ embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
312
+ attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True marks positions that take part in attention
313
+ assert len(images) == len(pos_groups), (len(images), len(pos_groups))
314
+
315
+ batch_images = []
316
+ batch_pos = []
317
+ for images_i, pos_group in zip(images, pos_groups):
318
+ assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
319
+ for image, pos in zip(images_i, pos_group):
320
+ batch_idx, start, end = pos
321
+ a2 = self.adaptor_patch_size ** 2
322
+ # recover the real number of the input image tokens
323
+ start *= a2
324
+ end *= a2
325
+ emb, _ = self.forward_single(image)
326
+ assert emb.ndim == 3, '(B, L, E)'
327
+ embeddings[batch_idx, start:end] = emb
328
+ attn_mask[batch_idx, :, start:end, start:end] = True
329
+ return embeddings, attn_mask, pos_groups
330
+
331
+
332
+ class CLIPVisionEmbeddings(nn.Module):
333
+ def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
334
+ super().__init__()
335
+ self.config = config
336
+ self.embed_dim = config.hidden_size
337
+ self.image_size = config.image_size
338
+ self.image_size = config.vit_input_resolution
339
+ self.patch_size = config.patch_size
340
+
341
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
342
+
343
+ self.patch_embedding = nn.Conv2d(
344
+ in_channels=config.num_channels,
345
+ out_channels=self.embed_dim,
346
+ kernel_size=self.patch_size,
347
+ stride=self.patch_size,
348
+ bias=False,
349
+ )
350
+
351
+ self.num_patches = (self.image_size // self.patch_size) ** 2
352
+
353
+ self.skip_cls_token = skip_cls_token
354
+
355
+ self.num_positions = self.num_patches + 1
356
+
357
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
358
+ if vit_patch > 1:
359
+ self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
360
+ # vit_patch == 0 supports at most 16 images; this is hard-coded for now, and other limits would need an extra parameter
361
+ elif vit_patch == 0:
362
+ self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
363
+ else:
364
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
365
+
366
+ if add_pre_layernorm:
367
+ self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
368
+ else:
369
+ self.pre_layernorm = None
370
+
371
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
372
+ """
373
+ This method interpolates the pre-trained position encodings so that the model can be used on
374
+ higher-resolution images.
375
+
376
+ Source:
377
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
378
+ """
379
+ num_patches = embeddings.shape[1] - 1
380
+ position_embeddings = self.position_embedding(self.position_ids)
381
+ num_positions = position_embeddings.shape[1] - 1
382
+ if num_patches == num_positions and height == width:
383
+ return position_embeddings
384
+ class_pos_embed = position_embeddings[:, 0]
385
+ patch_pos_embed = position_embeddings[:, 1:]
386
+ dim = embeddings.shape[-1]
387
+ h0 = height // self.config.patch_size
388
+ w0 = width // self.config.patch_size
389
+ # we add a small number to avoid floating point error in the interpolation
390
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
391
+ h0, w0 = h0 + 0.1, w0 + 0.1
392
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
393
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
394
+ raw_type = patch_pos_embed.dtype
395
+ patch_pos_embed = nn.functional.interpolate(
396
+ patch_pos_embed.float(),
397
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
398
+ mode="bicubic",
399
+ align_corners=False,
400
+ )
401
+ # print(patch_pos_embed.shape)
402
+ patch_pos_embed = patch_pos_embed.to(raw_type)
403
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
404
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
405
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
406
+
407
+
408
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
409
+ batch_size, num_channels, height, width = pixel_values.shape
410
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
411
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
412
+ if self.skip_cls_token:
413
+ embeddings = patch_embeds
414
+ if img_index is None:
415
+ position_ids = self.position_ids[:,1:]
416
+ embeddings = embeddings + self.position_embedding(position_ids)
417
+ else:
418
+ position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
419
+ + self.position_ids.expand(batch_size, -1)[:, 1:]
420
+ embeddings = embeddings + self.position_embedding(position_ids)
421
+ else:
422
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
423
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
424
+ if interpolate_pos_encoding:
425
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
426
+ else:
427
+ if img_index is None:
428
+ embeddings = embeddings + self.position_embedding(self.position_ids)
429
+ else:
430
+ position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
431
+ new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
432
+ position_ids = torch.cat([position_ids,new_position],dim=1)
433
+ embeddings = embeddings + self.position_embedding(position_ids)
434
+ if self.pre_layernorm is not None:
435
+ embeddings = self.pre_layernorm(embeddings)
436
+ return embeddings
437
+
438
+
439
+ class NaVitTransformer(nn.Module):
440
+ def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
441
+ super().__init__()
442
+ self.config = config
443
+ self.vit_config = vit_config
444
+ with self.prepare_args(config, vit_config):
445
+ self._use_sdpa = config._attn_implementation == "sdpa"
446
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
447
+ self.layers = nn.ModuleList(
448
+ [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
449
+ )
450
+
451
+ @contextmanager
452
+ def prepare_args(self, config, vit_config):
453
+ hidden_act = config.hidden_act
454
+ hidden_size = config.hidden_size
455
+ ffn_hidden_size = config.intermediate_size
456
+ num_attention_heads = config.num_attention_heads
457
+ num_key_value_heads = config.num_key_value_heads
458
+ attention_head_dim = config.attention_head_dim
459
+ use_qk_norm = config.use_qk_norm
460
+ use_rotary_pos_emb = config.use_rotary_pos_emb
461
+ num_hidden_layers = config.num_hidden_layers
462
+ rms_norm_eps = config.rms_norm_eps
463
+ attention_dropout = config.attention_dropout
464
+ # hidden_dropout = config.hidden_dropout
465
+ norm_type = config.norm_type
466
+ attention_bias = config.attention_bias
467
+ mlp_bias = config.mlp_bias
468
+ use_mla = config.use_mla
469
+ num_experts = config.num_experts
470
+ _attn_implementation = config._attn_implementation
471
+
472
+ config.hidden_act = vit_config.hidden_act
473
+ config.hidden_size = vit_config.hidden_size
474
+ config.intermediate_size = vit_config.intermediate_size
475
+ config.num_attention_heads = vit_config.num_attention_heads
476
+ config.num_key_value_heads = None
477
+ config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
478
+ config.use_qk_norm = False
479
+ config.use_rotary_pos_emb = False
480
+ config.num_hidden_layers = vit_config.num_hidden_layers
481
+ config.rms_norm_eps = vit_config.layer_norm_eps
482
+ config.attention_dropout = vit_config.attention_dropout
483
+ # config.hidden_dropout = vit_config.hidden_dropout
484
+ config.norm_type = config.vit_norm_type
485
+ config.attention_bias = True
486
+ config.mlp_bias = True
487
+ config.use_mla = False
488
+ config.num_experts = 1
489
+ config._attn_implementation = "eager"
490
+
491
+ yield
492
+ config.hidden_act = hidden_act
493
+ config.hidden_size = hidden_size
494
+ config.intermediate_size = ffn_hidden_size
495
+ config.num_attention_heads = num_attention_heads
496
+ config.num_key_value_heads = num_key_value_heads
497
+ config.attention_head_dim = attention_head_dim
498
+ config.use_qk_norm = use_qk_norm
499
+ config.use_rotary_pos_emb = use_rotary_pos_emb
500
+ config.num_hidden_layers = num_hidden_layers
501
+ config.rms_norm_eps = rms_norm_eps
502
+ config.attention_dropout = attention_dropout
503
+ # config.hidden_dropout = hidden_dropout
504
+ config.attention_bias = attention_bias
505
+ config.mlp_bias = mlp_bias
506
+ config.norm_type = norm_type
507
+ config.use_mla = use_mla
508
+ config.num_experts = num_experts
509
+ config._attn_implementation = _attn_implementation
510
+
511
+ def forward(
512
+ self,
513
+ pixel_values: Optional[torch.FloatTensor] = None,
514
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
515
+
516
+ hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
517
+ attention_mask = attention_mask.int()
518
+ batch_size, seq_length, _ = hidden_states.shape
519
+ past_key_values_length = 0
520
+
521
+ if self._use_flash_attention_2:
522
+ # 2d mask is passed through the layers
523
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
524
+ elif self._use_sdpa:
525
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
526
+ # the manual implementation that requires a 4D causal mask in all cases.
527
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
528
+ attention_mask,
529
+ (batch_size, seq_length),
530
+ hidden_states,
531
+ past_key_values_length,
532
+ )
533
+ else:
534
+ attention_mask = _prepare_4d_causal_attention_mask(
535
+ attention_mask,
536
+ (batch_size, seq_length),
537
+ hidden_states,
538
+ past_key_values_length,
539
+ )
540
+
541
+ for layer_idx, decoder_layer in enumerate(self.layers):
542
+ layer_outputs = decoder_layer(
543
+ hidden_states,
544
+ attention_mask=attention_mask
545
+ )
546
+ hidden_states = layer_outputs[0]
547
+
548
+ return hidden_states, img_pos
549
+
550
+
551
+ class AnyResVitTransformer(NaVitTransformer):
552
+ def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
553
+ super().__init__(config, vit_config)
554
+ old_anyres_vit_max_image_size = vit_config.max_image_size
555
+ anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
556
+ vit_config.max_image_size = anyres_vit_max_image_size
557
+ vit_config.torch_dtype = config.torch_dtype
558
+ vit_config.anyres_vit_two_views = config.anyres_vit_two_views
559
+ vit_config.vit_remove_prenorm = config.vit_remove_prenorm
560
+ vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
561
+ self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
562
+ vit_config.max_image_size = old_anyres_vit_max_image_size
563
+
564
+ def fix_embeddings_fn(self, pixel_values):
565
+ # (B, L, E)
566
+ embeddings, hw = self.embeddings.forward_single(pixel_values)
567
+ embeddings = self.embeddings.pre_layernorm(embeddings)
568
+ return embeddings
569
+
570
+
571
+ class CLIPVisionTransformer(nn.Module):
572
+ def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
573
+ super().__init__()
574
+ embed_dim = vit_config.hidden_size
575
+
576
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
577
+ self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
578
+
579
+ with self.prepare_args(config, vit_config):
580
+ self.layers = nn.ModuleList(
581
+ [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
582
+ )
583
+
584
+ @contextmanager
585
+ def prepare_args(self, config, vit_config):
586
+ hidden_act = config.hidden_act
587
+ hidden_size = config.hidden_size
588
+ ffn_hidden_size = config.intermediate_size
589
+ num_attention_heads = config.num_attention_heads
590
+ num_key_value_heads = config.num_key_value_heads
591
+ attention_head_dim = config.attention_head_dim
592
+ use_qk_norm = config.use_qk_norm
593
+ use_rotary_pos_emb = config.use_rotary_pos_emb
594
+ num_hidden_layers = config.num_hidden_layers
595
+ rms_norm_eps = config.rms_norm_eps
596
+ attention_dropout = config.attention_dropout
597
+ # hidden_dropout = config.hidden_dropout
598
+ norm_type = config.norm_type
599
+ attention_bias = config.attention_bias
600
+ mlp_bias = config.mlp_bias
601
+ use_mla = config.use_mla
602
+ num_experts = config.num_experts
603
+ _attn_implementation = config._attn_implementation
604
+
605
+ config.hidden_act = vit_config.hidden_act
606
+ config.hidden_size = vit_config.hidden_size
607
+ config.intermediate_size = vit_config.intermediate_size
608
+ config.num_attention_heads = vit_config.num_attention_heads
609
+ config.num_key_value_heads = None
610
+ config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
611
+ config.use_qk_norm = False
612
+ config.use_rotary_pos_emb = False
613
+ config.num_hidden_layers = vit_config.num_hidden_layers
614
+ config.rms_norm_eps = vit_config.layer_norm_eps
615
+ config.attention_dropout = vit_config.attention_dropout
616
+ # config.hidden_dropout = 0.0
617
+ config.norm_type = "fused"
618
+ config.attention_bias = True
619
+ config.mlp_bias = True
620
+ config.use_mla = False
621
+ config.num_experts = 1
622
+ config._attn_implementation = "eager"
623
+
624
+ yield
625
+
626
+ config.hidden_act = hidden_act
627
+ config.hidden_size = hidden_size
628
+ config.intermediate_size = ffn_hidden_size
629
+ config.num_attention_heads = num_attention_heads
630
+ config.num_key_value_heads = num_key_value_heads
631
+ config.attention_head_dim = attention_head_dim
632
+ config.use_qk_norm = use_qk_norm
633
+ config.use_rotary_pos_emb = use_rotary_pos_emb
634
+ config.num_hidden_layers = num_hidden_layers
635
+ config.rms_norm_eps = rms_norm_eps
636
+ config.attention_dropout = attention_dropout
637
+ # config.hidden_dropout = hidden_dropout
638
+ config.norm_type = norm_type
639
+ config.attention_bias = attention_bias
640
+ config.mlp_bias = mlp_bias
641
+ config.use_mla = use_mla
642
+ config.num_experts = num_experts
643
+ config._attn_implementation = _attn_implementation
644
+
645
+ def forward(
646
+ self,
647
+ pixel_values: Optional[torch.FloatTensor] = None,
648
+ interpolate_pos_encoding: Optional[bool] = None,
649
+ img_index=None
650
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
651
+ r"""
652
+ Returns:
653
+
654
+ """
655
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
656
+ hidden_states = self.pre_layrnorm(hidden_states)
657
+ batch = hidden_states.shape[0]
658
+ seq_len = hidden_states.shape[1]
659
+ device = hidden_states.device
660
+ attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
661
+
662
+ for layer_idx, decoder_layer in enumerate(self.layers):
663
+ layer_outputs = decoder_layer(
664
+ hidden_states,
665
+ attention_mask=attention_mask
666
+ )
667
+ hidden_states = layer_outputs[0]
668
+
669
+ return hidden_states
670
+
671
+
672
+ class Vit(torch.nn.Module):
673
+ def __init__(self, config, resampler_token=64, pool_rate=2):
674
+ super().__init__()
675
+ self.config = config
676
+ self.vit_mapping_type = config.vit_mapping_type
677
+ self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
678
+ self.skip_cls_token = config.skip_cls_token
679
+ self.pool_rate = pool_rate
680
+ self.vit_type = self.config.vit_type
681
+ self.anyres_vit_two_views = self.config.anyres_vit_two_views
682
+ if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
683
+ self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
684
+ else:
685
+ raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
686
+
687
+ def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
688
+ if self.vit_type == 'AnyResVit':
689
+ vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
690
+ self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
691
+ self.vit_config.image_size = vit_input_resolution
692
+ self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
693
+ elif self.vit_type == 'Vit-g':
694
+ vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
695
+ self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
696
+ self.vit_config.vit_input_resolution = vit_input_resolution
697
+ self.vit = CLIPVisionTransformer(self.config, self.vit_config)
698
+ else:
699
+ assert False, "other vit_types are not supported"
700
+
701
+ if self.vit_mapping_type == 'simple_conv_mlp':
702
+ self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
703
+ self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
704
+ elif self.vit_mapping_type == 'oryx_mlp':
705
+ self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
706
+ elif self.vit_mapping_type == 'mlp':
707
+ self.mlp_depth = 2
708
+ # one mlp layer already in gpt_model.py
709
+ mlp_hidden_size = self.vit_config.hidden_size
710
+ if self.vit_type in ['NaVit', 'EvaVit']:
711
+ mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
712
+ if self.mlp_depth > 1:
713
+ mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
714
+ if self.vit_type in ['NaVit', 'EvaVit']:
715
+ for _ in range(1, self.mlp_depth):
716
+ mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
717
+ mlp_modules.append(torch.nn.GELU())
718
+ self.perceive = torch.nn.Sequential(*mlp_modules)
719
+ else:
720
+ assert False, "other vit_mapping_types are not supported"
721
+
722
+ self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
723
+ for name, param in self.named_parameters():
724
+ setattr(param, "is_vit_param", True)
725
+
726
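# Editor's note: every parameter of this module is tagged with `is_vit_param`, which lets
# downstream code treat the vision tower separately (e.g. freeze it or give it a smaller
# learning rate). Hypothetical illustration of consuming that tag:
import torch

_toy = torch.nn.Linear(4, 4)
for _p in _toy.parameters():
    _p.is_vit_param = True                              # same tagging as above
_vit_params = [p for p in _toy.parameters() if getattr(p, "is_vit_param", False)]
_optim = torch.optim.AdamW([{"params": _vit_params, "lr": 1e-5}])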
+ def forward(self, images, img_index=None):
727
+ if self.vit_type in ['AnyResVit']:
728
+ dtype = self.config.torch_dtype
729
+ device = torch.cuda.current_device()
730
+
731
+ images_size = []
732
+ for i in range(len(images)):
733
+ images_size.append([])
734
+ for j in range(len(images[i])):
735
+ images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
736
+
737
+ images_feats, img_batch_pos = self.vit(pixel_values=images)
738
+ a2 = self.vit_config.adaptor_patch_size ** 2
739
+
740
+ if self.anyres_vit_two_views:
741
+ step = 2
742
+ else:
743
+ step = 1
744
+ perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
745
+ images_list = []
746
+ images_fix_i = 0
747
+ num_img_batch_pos = len(img_batch_pos)
748
+ for i in range(num_img_batch_pos): # batch_id
749
+ for j in range(0, len(img_batch_pos[i]), step):
750
+ if self.anyres_vit_two_views:
751
+ lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
752
+ lower_begin = lower_begin * a2
753
+ lower_end = lower_end * a2
754
+ higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
755
+ higher_begin = higher_begin * a2
756
+ higher_end = higher_end * a2
757
+ lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
758
+ higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
759
+ lower_images_size = images_size[i][j]
760
+ higher_images_size = images_size[i][j + 1]
761
+ images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
762
+ else:
763
+ idx, begin, end = img_batch_pos[i][j]
764
+ begin = begin * a2
765
+ end = end * a2
766
+ is_video = getattr(images[i][j], '_is_video', False)
767
+ images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
768
+
769
+ images = torch.cat(images_list, dim=1)
770
+
771
+ new_batch_pos = []
772
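# Editor's note: `new_batch_pos` mirrors `img_batch_pos`, but indexes into the single
# concatenated adapter output: each entry is [row, start, end] (row is always 0 here,
# since everything is concatenated along the sequence dimension), so callers can slice
# each image's token span back out of `images`.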
+ k, cur_len = 0, 0
773
+ for i in range(len(images_size)):
774
+ new_batch_pos.append([])
775
+ for j in range(0, len(images_size[i]), step):
776
+ new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
777
+ cur_len += images_list[k].size(1)
778
+ k += 1
779
+ new_batch_pos[i].append(new_pos)
780
+ return images, new_batch_pos
781
+ elif self.vit_type == 'Vit-g':
782
+ images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
783
+ else:
784
+ assert False, "other vit_types are not supported"
785
+
786
+ if self.vit_mapping_type == 'mlp':
787
+ if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
788
+ images = images[:,1:,:]
789
+ b, v, d = images.shape
790
+ s = int(math.sqrt(v))
791
+ images = images.reshape(b, s, s, d)
792
+
793
+
794
+ if self.vit_patch_mlp and img_index is not None:
795
+ L_tensor = torch.tensor(img_index)
796
+ device = images.device
797
+ # indices of the sub-image (tile) entries
798
+ nonzero_indices = torch.nonzero(L_tensor).squeeze(-1).to(device)  # squeeze(-1) keeps a 1-D index even for a single match
799
+ # indices of the main-image entries
800
+ zero_indices = torch.nonzero(L_tensor == 0).squeeze(-1).to(device)
801
+
802
+
803
+ images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
804
+ images_zero = torch.index_select(images, 0, zero_indices).to(device)
805
+
806
+ # sub-images get one extra 2x pooling step
807
+ pool_rate = self.pool_rate * 2
808
+ images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
809
+ images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
810
+ pool_rate*pool_rate).mean(-1)
811
+
812
+ # pad sub-image features up to the main-image token count so both can share one batch tensor (a batching compromise)
813
+ images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
814
+ images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
815
+ images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
816
+ self.pool_rate*self.pool_rate).mean(-1)
817
+ # scatter both groups back into a single batch tensor
818
+ images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
819
+ images.index_copy_(0, nonzero_indices, images_nonzero)
820
+ images.index_copy_(0, zero_indices, images_zero)
821
+
822
+ if self.mlp_depth >= 2:
823
+ images = self.perceive(images)
824
+ else:
825
+ if s % self.pool_rate == 0:
826
+ images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
827
+ images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
828
+ if self.mlp_depth >= 2:
829
+ images = self.perceive(images)
830
+ else:
831
+ raise ValueError
832
+ return images
833
+
834
+
835
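# Editor's note: the reshape/permute/mean used in the 'mlp' branch above is just an
# average pool with stride `pool_rate` over the (s x s) token grid. A small equivalence
# check with illustrative sizes (not tied to any real config):
import torch
import torch.nn.functional as F

b, s, d, r = 2, 16, 64, 2
grid = torch.randn(b, s, s, d)                               # (B, H, W, C) token grid
pooled = grid.reshape(b, s // r, r, s // r, r, d) \
             .permute(0, 1, 3, 5, 2, 4) \
             .reshape(b, (s // r) * (s // r), d, r * r).mean(-1)
reference = F.avg_pool2d(grid.permute(0, 3, 1, 2), r).flatten(2).transpose(1, 2)
assert torch.allclose(pooled, reference, atol=1e-6)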
+ class SimpleConvMlp(nn.Module):
836
+ def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
837
+ super().__init__()
838
+
839
+ embed_std = 1 / math.sqrt(out_channels)
840
+ if poolmlp:
841
+ # if args.learnable_mlp_pooling_size is not None:
842
+ # in_channels *= args.learnable_mlp_pooling_size ** 2
843
+ self.proj = nn.Sequential(
844
+ nn.Linear(in_channels, out_channels),
845
+ nn.GELU()
846
+ )
847
+ self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
848
+ self.image_newline = nn.Parameter(
849
+ torch.randn(out_channels) * embed_std
850
+ )
851
+ else:
852
+ self.proj = nn.Sequential(
853
+ nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
854
+ nn.GELU(),
855
+ nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
856
+ )
857
+ self.mlp = nn.Linear(in_channels * 4, out_channels)
858
+ self.image_newline = nn.Parameter(
859
+ torch.randn(in_channels * 4) * embed_std
860
+ )
861
+ self.poolmlp = poolmlp
+ self.anyres_pooling_size = anyres_pooling_size  # stored so single_forward can use it on the poolmlp path
862
+
863
+ self.image_begin = nn.Parameter(
864
+ torch.randn(out_channels) * embed_std
865
+ )
866
+ self.image_end = nn.Parameter(
867
+ torch.randn(out_channels) * embed_std
868
+ )
869
+
870
+ if twoview:
871
+ self.image_sep = nn.Parameter(
872
+ torch.randn(out_channels) * embed_std
873
+ )
874
+
875
+ self.cat_extra_token = cat_extra_token
876
+ self.use_rms_norm = vit_used_rms_norm
877
+ if self.use_rms_norm:
878
+ self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
879
+ self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
880
+
881
+ def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
882
+ return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
883
+
884
+ def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
885
+ remove_vit_special_tokens = False
886
+ learnable_mlp_pooling_size = None
887
+ if self.use_rms_norm:
888
+ x = self.before_rms(x)
889
+ h, w = size
890
+ dtype = x.dtype
891
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
892
+ if self.poolmlp:
893
+ if learnable_mlp_pooling_size is None:
894
+ x = F.avg_pool2d(x, self.anyres_pooling_size)
895
+ x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
896
+ else:
897
+ x = x.permute(0, 2, 3, 1) # b, h, w, c
898
+ x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
899
+ w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
900
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
901
+ x = self.proj(x)
902
+ x = self.vit_linear_encoder(x)
903
+ b, h, w, c = x.shape
904
+ if not remove_vit_special_tokens:
905
+ x = torch.cat([
906
+ x,
907
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
908
+ ], dim=2)
909
+ x = x.reshape(b, -1, c)
910
+ else:
911
+ x = self.proj(x) #b,c,h,w
912
+ if is_video:
913
+ video_avgpool_size = 2
914
+ stride = 2
915
+ x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
916
+ b, c, h, w = x.shape
917
+ if not remove_vit_special_tokens:
918
+ x = torch.cat([
919
+ x,
920
+ self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
921
+ ], dim=-1)
922
+ x = x.reshape(b, c, -1).permute(0, 2, 1)
923
+ x = self.mlp(x)
924
+
925
+
926
+ if x2 is not None:
927
+ h2, w2 = size2
928
+ x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
929
+ if self.poolmlp:
930
+ x2 = F.avg_pool2d(x2, 2)
931
+ x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
932
+ x2 = self.vit_linear_encoder(x2)
933
+ b2, h2, w2, c2 = x2.shape
934
+ if not remove_vit_special_tokens:
935
+ x2 = torch.cat([
936
+ x2,
937
+ self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
938
+ ], dim=2)
939
+ x2 = x2.reshape(b2, -1, c2)
940
+ else:
941
+ x2 = self.proj(x2)
942
+ b2, c2, h2, w2 = x2.shape
943
+ if not remove_vit_special_tokens:
944
+ x2 = torch.cat([
945
+ x2,
946
+ self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
947
+ ], dim=-1)
948
+ x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
949
+ x2 = self.mlp(x2)
950
+
951
+ sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
952
+
953
+ x = torch.cat([x, sep, x2], dim=1)
954
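# Editor's note: at this point the sequence is laid out as
#   [low-res view tokens (+ one image_newline per row)] [image_sep] [high-res view tokens (+ newlines)]
# and image_begin / image_end are wrapped around the whole thing just below when
# cat_extra_token is set.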
+
955
+ if self.cat_extra_token:
956
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
957
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
958
+ x = torch.cat([begin, x, end], dim=1)
959
+
960
+ if self.use_rms_norm:
961
+ return self.after_rms(x)
962
+ else:
963
+ return x
964
+
965
+
966
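# Editor's note: a quick shape check for the conv path (poolmlp=False), which is how
# SimpleConvMlp is instantiated above. Sizes are illustrative only; the real hidden
# sizes come from the ViT and language-model configs.
import torch

_adapter = SimpleConvMlp(in_channels=1152, out_channels=4096, anyres_pooling_size=2,
                         vit_used_rms_norm=False, rms_norm_eps=1e-5,
                         poolmlp=False, twoview=True)
_feats = torch.randn(1, 16 * 16, 1152)          # one image, a 16x16 grid of patch tokens
_out = _adapter(_feats, size=(16, 16))          # -> (1, 74, 4096): begin + 8*(8+1) tokens + end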
+ class NormalizedDwPooler(nn.Module):
967
+ def __init__(self, dim):
968
+ super().__init__()
969
+ self.dim = dim
970
+ self.predictor = nn.Sequential(
971
+ nn.Linear(dim*2, dim),
972
+ nn.GELU(),
973
+ nn.Linear(dim, dim),
974
+ )
975
+
976
+ def forward(self, x, forward_type='2x'):
977
+ B, H, W, C = x.shape
978
+
979
+ if forward_type == '2x':
980
+ new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
981
+ pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
982
+ fused_x = torch.cat([new_x, pooled_x], dim=-1)
983
+ elif forward_type == '1x':
984
+ new_x = x.reshape(B, H, W, 1, C)
985
+ fused_x = torch.cat([new_x, new_x], dim=-1)
986
+ elif forward_type == '4x':
987
+ new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
988
+ pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
989
+ fused_x = torch.cat([new_x, pooled_x], dim=-1)
990
+
991
+ score = self.predictor(fused_x)
992
+ normalized_score = F.softmax(score, dim=-2)
993
+ new_x = (new_x * normalized_score).sum(dim=-2)
994
+ return new_x
995
+
996
+
997
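# Editor's note: NormalizedDwPooler replaces a plain window average with a learned,
# softmax-normalized weighted average over each 2x2 (or 4x4 / 1x1) window.
# Illustrative usage with made-up sizes:
import torch

_pooler = NormalizedDwPooler(dim=256)
_x = torch.randn(2, 16, 16, 256)                # (B, H, W, C), channels-last
_y = _pooler(_x, forward_type='2x')             # -> (2, 8, 8, 256)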
+ class OryxMLPv2(nn.Module):
998
+ def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
999
+ super().__init__()
1000
+
1001
+ self.proj1 = nn.Linear(in_channels, out_channels)
1002
+ self.proj2 = nn.Linear(out_channels, out_channels)
1003
+ self.act = nn.GELU()
1004
+ self.pooler = NormalizedDwPooler(out_channels)
1005
+ embed_std = 1 / math.sqrt(out_channels)
1006
+
1007
+ self.use_pe = use_pe
1008
+ if not use_pe:
1009
+ self.image_newline = nn.Parameter(
1010
+ torch.randn(out_channels) * embed_std
1011
+ )
1012
+ self.image_begin = nn.Parameter(
1013
+ torch.randn(out_channels) * embed_std
1014
+ )
1015
+ self.image_end = nn.Parameter(
1016
+ torch.randn(out_channels) * embed_std
1017
+ )
1018
+
1019
+ if twoview:
1020
+ self.image_sep = nn.Parameter(
1021
+ torch.randn(out_channels) * embed_std
1022
+ )
1023
+
1024
+ def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
1025
+ h, w = size
1026
+ dtype = x.dtype
1027
+ x = x.reshape(x.shape[0], h, w, -1)
1028
+ # x = self.pooler(x, forward_type=REGIONAL_POOL)
1029
+ # x = self.proj(x) #b,h,w, c
1030
+ x = self.proj1(x)
1031
+ x = self.pooler(x, forward_type='2x')
1032
+ x = self.act(x)
1033
+ x = self.proj2(x)
1034
+
1035
+
1036
+ b, h, w, c = x.shape
1037
+ if not self.use_pe:
1038
+ x = torch.cat([
1039
+ x,
1040
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
1041
+ ], dim=2)
1042
+ else:
1043
+ pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1044
+ pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1045
+ pe = torch.cat([pe_h, pe_w], dim=-1)
1046
+
1047
+ x = x.reshape(b, -1, c)
1048
+
1049
+ if x2 is not None:
1050
+ h2, w2 = size2
1051
+ x2 = x2.reshape(x2.shape[0], h2, w2, -1)
1052
+ # x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
1053
+ ## x2 = self.proj(x2) #b,h,w, c
1054
+ x2 = self.proj1(x2)
1055
+ x2 = self.pooler(x2, forward_type='2x')
1056
+ x2 = self.act(x2)
1057
+ x2 = self.proj2(x2)
1058
+
1059
+ b2, h2, w2, c2 = x2.shape
1060
+ if not self.use_pe:
1061
+ x2 = torch.cat([
1062
+ x2,
1063
+ self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype)
1064
+ ], dim=2)
1065
+ x2 = x2.reshape(b2, -1, c2)
1066
+ sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
1067
+ x = torch.cat([x, sep, x2], dim=1)
1068
+
1069
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1070
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1071
+ x = torch.cat([begin, x, end], dim=1)
1072
+ # print(x.shape, x2.shape, h, w, h2, w2)
1073
+ # print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
1074
+ if self.use_pe:
1075
+ zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
1076
+ pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
1077
+ assert pe.shape[1] == x.shape[1]
1078
+ return x, pe
1079
+ else:
1080
+ nseq = x.shape[1]
1081
+ fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
1082
+ return x #, fake_pe
1083
+
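# Editor's note: sketch of the two-view call used by the 'oryx_mlp' mapping path, with
# illustrative token-grid sizes (real ones come from the ViT patch grid and config).
import torch

_adapter = OryxMLPv2(in_channels=1152, out_channels=4096, twoview=True, use_pe=False)
_low = torch.randn(1, 16 * 16, 1152)            # low-res view, 16x16 tokens
_high = torch.randn(1, 32 * 32, 1152)           # high-res view, 32x32 tokens
_tokens = _adapter(_low, size=(16, 16), x2=_high, size2=(32, 32))
# -> (1, 347, 4096): [begin] 8x(8+1) low-res [sep] 16x(16+1) high-res [end]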