Add files using upload-large-folder tool
Browse files- LICENSE +77 -0
- Notice.txt +160 -0
- README.md +270 -0
- README_CN.md +456 -0
- config.json +203 -0
- configuration_hunyuan.py +319 -0
- generation_config.json +10 -0
- hunyuan.py +879 -0
- hy.tiktoken +0 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +0 -0
- tokenization_hy.py +298 -0
- tokenizer_config.json +18 -0
- vit_model.py +1083 -0
LICENSE
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
2 |
+
Tencent Hunyuan A13B Release Date: June 27, 2025
|
3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
5 |
+
1. DEFINITIONS.
|
6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan A13B released at [https://github.com/Tencent-Hunyuan/Hunyuan-A13B].
|
16 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
19 |
+
n. “including” shall mean including but not limited to.
|
20 |
+
2. GRANT OF RIGHTS.
|
21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
22 |
+
3. DISTRIBUTION.
|
23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
24 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
28 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
29 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
30 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
31 |
+
5. RULES OF USE.
|
32 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
33 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
34 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
35 |
+
6. INTELLECTUAL PROPERTY.
|
36 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
37 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
38 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
39 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
40 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
41 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
42 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
43 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
44 |
+
8. SURVIVAL AND TERMINATION.
|
45 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
46 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
47 |
+
9. GOVERNING LAW AND JURISDICTION.
|
48 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
49 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
50 |
+
|
51 |
+
EXHIBIT A
|
52 |
+
ACCEPTABLE USE POLICY
|
53 |
+
|
54 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
55 |
+
Last modified: November 5, 2024
|
56 |
+
|
57 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
58 |
+
1. Outside the Territory;
|
59 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
60 |
+
3. To harm Yourself or others;
|
61 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
62 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
63 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
64 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
65 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
66 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
67 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
68 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
69 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
70 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
71 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
72 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
73 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
74 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
75 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
76 |
+
19. For military purposes;
|
77 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
Notice.txt
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Usage and Legal Notices:
|
2 |
+
|
3 |
+
Tencent is pleased to support the open source community by making Tencent Hunyuan A13B available.
|
4 |
+
|
5 |
+
Copyright (C) Tencent. All rights reserved. The below software and/or models in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
6 |
+
|
7 |
+
Tencent Hunyuan A13B is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent Hunyuan A13B does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
8 |
+
|
9 |
+
For avoidance of doubts, Tencent Hunyuan A13B refers to the inference code, training code, parameters and the weights of Tencent Hunyuan A13B only, which are made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
10 |
+
|
11 |
+
|
12 |
+
Other dependencies and licenses:
|
13 |
+
|
14 |
+
|
15 |
+
Open Source Software Licensed under the Apache License Version 2.0:
|
16 |
+
The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
17 |
+
--------------------------------------------------------------------
|
18 |
+
1. pytorch
|
19 |
+
Copyright 2016-2017 TorchAPI
|
20 |
+
Copyright 2016-2017 Contributors
|
21 |
+
|
22 |
+
2. VLLM
|
23 |
+
Copyright (c) vllm original author and authors
|
24 |
+
Please note this software has been modified by Tencent in this distribution.
|
25 |
+
|
26 |
+
3. transformers
|
27 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
28 |
+
|
29 |
+
4. accelerate
|
30 |
+
Copyright (c) accelerate original author and authors
|
31 |
+
|
32 |
+
|
33 |
+
Terms of the Apache License Version 2.0:
|
34 |
+
--------------------------------------------------------------------
|
35 |
+
Apache License
|
36 |
+
|
37 |
+
Version 2.0, January 2004
|
38 |
+
|
39 |
+
http://www.apache.org/licenses/
|
40 |
+
|
41 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
42 |
+
1. Definitions.
|
43 |
+
|
44 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
45 |
+
|
46 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
47 |
+
|
48 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
49 |
+
|
50 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
51 |
+
|
52 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
53 |
+
|
54 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
55 |
+
|
56 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
57 |
+
|
58 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
59 |
+
|
60 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
63 |
+
|
64 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
65 |
+
|
66 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
67 |
+
|
68 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
69 |
+
|
70 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
71 |
+
|
72 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
73 |
+
|
74 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
75 |
+
|
76 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
77 |
+
|
78 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
79 |
+
|
80 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
81 |
+
|
82 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
83 |
+
|
84 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
85 |
+
|
86 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
87 |
+
|
88 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
89 |
+
|
90 |
+
END OF TERMS AND CONDITIONS
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
95 |
+
--------------------------------------------------------------------
|
96 |
+
1. pytorch
|
97 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
98 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
99 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
100 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
101 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
102 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
103 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
104 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
105 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
106 |
+
|
107 |
+
|
108 |
+
Terms of the BSD 3-Clause:
|
109 |
+
--------------------------------------------------------------------
|
110 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
111 |
+
|
112 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
113 |
+
|
114 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
115 |
+
|
116 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
117 |
+
|
118 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
119 |
+
|
120 |
+
For the license of other third party components, please refer to the following URL:
|
121 |
+
https://github.com/pytorch/pytorch/blob/v2.1.1/NOTICE
|
122 |
+
https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
|
123 |
+
|
124 |
+
|
125 |
+
Open Source Software Licensed under the BSD 3-Clause License:
|
126 |
+
--------------------------------------------------------------------
|
127 |
+
1. flash_attn
|
128 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
129 |
+
All rights reserved.
|
130 |
+
|
131 |
+
|
132 |
+
A copy of the BSD 3-Clause is included in this file.
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
137 |
+
The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
138 |
+
--------------------------------------------------------------------
|
139 |
+
1. sglang
|
140 |
+
Copyright 2023-2024 SGLang Team
|
141 |
+
|
142 |
+
|
143 |
+
A copy of the Apache 2.0 is included in this file.
|
144 |
+
|
145 |
+
For the license of other third party components, please refer to the following URL:
|
146 |
+
https://github.com/sgl-project/sglang/tree/v0.4.7/3rdparty/amd
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
151 |
+
The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
152 |
+
--------------------------------------------------------------------
|
153 |
+
1. TensorRT-LLM
|
154 |
+
Copyright (c) TensorRT-LLM original author and authors
|
155 |
+
|
156 |
+
|
157 |
+
A copy of the Apache 2.0 is included in this file.
|
158 |
+
|
159 |
+
For the license of other third party components, please refer to the following URL:
|
160 |
+
https://github.com/NVIDIA/TensorRT-LLM/tree/v0.20.0/3rdparty
|
README.md
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: other
|
3 |
+
license_name: tencent-hunyuan-a13b
|
4 |
+
license_link: https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/LICENSE
|
5 |
+
---
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<img src="https://dscache.tencent-cloud.cn/upload/uploader/hunyuan-64b418fd052c033b228e04bc77bbc4b54fd7f5bc.png" width="400"/> <br>
|
9 |
+
</p><p></p>
|
10 |
+
|
11 |
+
|
12 |
+
<p align="center">
|
13 |
+
🫣 <a href="https://huggingface.co/tencent/Hunyuan-A13B-Instruct"><b>Hugging Face</b></a> |
|
14 |
+
🖥️ <a href="https://llm.hunyuan.tencent.com/" style="color: red;"><b>Official Website</b></a> |
|
15 |
+
🕖 <a href="https://cloud.tencent.com/product/hunyuan"><b>HunyuanAPI</b></a> |
|
16 |
+
🕹️ <a href="https://hunyuan.tencent.com/?model=hunyuan-a13b"><b>Demo</b></a> |
|
17 |
+
</p>
|
18 |
+
|
19 |
+
|
20 |
+
<p align="center">
|
21 |
+
<a href="https://github.com/Tencent-Hunyuan/Hunyuan-A13B"><b>GITHUB</b></a> |
|
22 |
+
<a href="https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/LICENSE"><b>LICENSE</b></a>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
Welcome to the official repository of **Hunyuan-A13B**, an innovative and open-source large language model (LLM) built on a fine-grained Mixture-of-Experts (MoE) architecture. Designed for efficiency and scalability, Hunyuan-A13B delivers cutting-edge performance with minimal computational overhead, making it an ideal choice for advanced reasoning and general-purpose applications, especially in resource-constrained environments.
|
28 |
+
|
29 |
+
## Model Introduction
|
30 |
+
|
31 |
+
With the rapid advancement of artificial intelligence technology, large language models (LLMs) have achieved remarkable progress in natural language processing, computer vision, and scientific tasks. However, as model scales continue to expand, optimizing resource consumption while maintaining high performance has become a critical challenge. To address this, we have explored Mixture of Experts (MoE) architectures. The newly introduced Hunyuan-A13B model features a total of 80 billion parameters with 13 billion active parameters. It not only delivers high-performance results but also achieves optimal resource efficiency, successfully balancing computational power and resource utilization.
|
32 |
+
|
33 |
+
### Key Features and Advantages
|
34 |
+
|
35 |
+
- **Compact yet Powerful**: With only 13 billion active parameters (out of a total of 80 billion), the model delivers competitive performance on a wide range of benchmark tasks, rivaling much larger models.
|
36 |
+
- **Hybrid Inference Support**: Supports both fast and slow thinking modes, allowing users to flexibly choose according to their needs.
|
37 |
+
- **Ultra-Long Context Understanding**: Natively supports a 256K context window, maintaining stable performance on long-text tasks.
|
38 |
+
- **Enhanced Agent Capabilities**: Optimized for agent tasks, achieving leading results on benchmarks such as BFCL-v3 and τ-Bench.
|
39 |
+
- **Efficient Inference**: Utilizes Grouped Query Attention (GQA) and supports multiple quantization formats, enabling highly efficient inference.
|
40 |
+
|
41 |
+
### Why Choose Hunyuan-A13B?
|
42 |
+
|
43 |
+
As a powerful yet computationally efficient large model, Hunyuan-A13B is an ideal choice for researchers and developers seeking high performance under resource constraints. Whether for academic research, cost-effective AI solution development, or innovative application exploration, this model provides a robust foundation for advancement.
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
## Related News
|
48 |
+
* 2025.6.27 We have open-sourced **Hunyuan-A13B-Pretrain** , **Hunyuan-A13B-Instruct** , **Hunyuan-A13B-Instruct-FP8** , **Hunyuan-A13B-Instruct-GPTQ-Int4** on Hugging Face.
|
49 |
+
<br>
|
50 |
+
|
51 |
+
|
52 |
+
## Benchmark
|
53 |
+
|
54 |
+
Note: The following benchmarks are evaluated by TRT-LLM-backend
|
55 |
+
|
56 |
+
| Model | Hunyuan-Large | Qwen2.5-72B | Qwen3-A22B | Hunyuan-A13B |
|
57 |
+
|------------------|---------------|--------------|-------------|---------------|
|
58 |
+
| MMLU | 88.40 | 86.10 | 87.81 | 88.17 |
|
59 |
+
| MMLU-Pro | 60.20 | 58.10 | 68.18 | 67.23 |
|
60 |
+
| MMLU-Redux | 87.47 | 83.90 | 87.40 | 87.67 |
|
61 |
+
| BBH | 86.30 | 85.80 | 88.87 | 87.56 |
|
62 |
+
| SuperGPQA | 38.90 | 36.20 | 44.06 | 41.32 |
|
63 |
+
| EvalPlus | 75.69 | 65.93 | 77.60 | 78.64 |
|
64 |
+
| MultiPL-E | 59.13 | 60.50 | 65.94 | 69.33 |
|
65 |
+
| MBPP | 72.60 | 76.00 | 81.40 | 83.86 |
|
66 |
+
| CRUX-I | 57.00 | 57.63 | - | 70.13 |
|
67 |
+
| CRUX-O | 60.63 | 66.20 | 79.00 | 77.00 |
|
68 |
+
| MATH | 69.80 | 62.12 | 71.84 | 72.35 |
|
69 |
+
| CMATH | 91.30 | 84.80 | - | 91.17 |
|
70 |
+
| GSM8k | 92.80 | 91.50 | 94.39 | 91.83 |
|
71 |
+
| GPQA | 25.18 | 45.90 | 47.47 | 49.12 |
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
Hunyuan-A13B-Instruct has achieved highly competitive performance across multiple benchmarks, particularly in mathematics, science, agent domains, and more. We compared it with several powerful models, and the results are shown below.
|
77 |
+
|
78 |
+
| Topic | Bench | OpenAI-o1-1217 | DeepSeek R1 | Qwen3-A22B | Hunyuan-A13B-Instruct |
|
79 |
+
|:-------------------:|:-----------------------------:|:-------------:|:------------:|:-----------:|:---------------------:|
|
80 |
+
| **Mathematics** | AIME 2024<br>AIME 2025<br>MATH | 74.3<br>79.2<br>96.4 | 79.8<br>70<br>94.9 | 85.7<br>81.5<br>94.0 | 87.3<br>76.8<br>94.3 |
|
81 |
+
| **Science** | GPQA-Diamond<br>OlympiadBench | 78<br>83.1 | 71.5<br>82.4 | 71.1<br>85.7 | 71.2<br>82.7 |
|
82 |
+
| **Coding** | Livecodebench<br>Fullstackbench<br>ArtifactsBench | 63.9<br>64.6<br>38.6 | 65.9<br>71.6<br>44.6 | 70.7<br>65.6<br>44.6 | 63.9<br>67.8<br>43 |
|
83 |
+
| **Reasoning** | BBH<br>DROP<br>ZebraLogic | 80.4<br>90.2<br>81 | 83.7<br>92.2<br>78.7 | 88.9<br>90.3<br>80.3 | 89.1<br>91.1<br>84.7 |
|
84 |
+
| **Instruction<br>Following** | IF-Eval<br>SysBench | 91.8<br>82.5 | 88.3<br>77.7 | 83.4<br>74.2 | 84.7<br>76.1 |
|
85 |
+
| **Text<br>Creation**| LengthCtrl<br>InsCtrl | 60.1<br>74.8 | 55.9<br>69 | 53.3<br>73.7 | 55.4<br>71.9 |
|
86 |
+
| **NLU** | ComplexNLU<br>Word-Task | 64.7<br>67.1 | 64.5<br>76.3 | 59.8<br>56.4 | 61.2<br>62.9 |
|
87 |
+
| **Agent** | BDCL v3<br> τ-Bench<br>ComplexFuncBench<br> C3-Bench | 67.8<br>60.4<br>47.6<br>58.8 | 56.9<br>43.8<br>41.1<br>55.3 | 70.8<br>44.6<br>40.6<br>51.7 | 78.3<br>54.7<br>61.2<br>63.5 |
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
## Use with transformers
|
93 |
+
The following code snippet shows how to use the transformers library to load and apply the model. It also demonstrates how to enable and disable the reasoning mode , and how to parse the reasoning process along with the final output.
|
94 |
+
|
95 |
+
```python
|
96 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
97 |
+
import os
|
98 |
+
import re
|
99 |
+
|
100 |
+
model_name_or_path = os.environ['MODEL_PATH']
|
101 |
+
# model_name_or_path = "tencent/Hunyuan-A13B-Instruct"
|
102 |
+
|
103 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
104 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto",trust_remote_code=True) # You may want to use bfloat16 and/or move to GPU here
|
105 |
+
messages = [
|
106 |
+
{"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
|
107 |
+
]
|
108 |
+
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,return_tensors="pt",
|
109 |
+
enable_thinking=True # Toggle thinking mode (default: True)
|
110 |
+
)
|
111 |
+
|
112 |
+
outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=4096)
|
113 |
+
|
114 |
+
output_text = tokenizer.decode(outputs[0])
|
115 |
+
|
116 |
+
think_pattern = r'<think>(.*?)</think>'
|
117 |
+
think_matches = re.findall(think_pattern, output_text, re.DOTALL)
|
118 |
+
|
119 |
+
answer_pattern = r'<answer>(.*?)</answer>'
|
120 |
+
answer_matches = re.findall(answer_pattern, output_text, re.DOTALL)
|
121 |
+
|
122 |
+
think_content = [match.strip() for match in think_matches][0]
|
123 |
+
answer_content = [match.strip() for match in answer_matches][0]
|
124 |
+
print(f"thinking_content:{think_content}\n\n")
|
125 |
+
print(f"answer_content:{answer_content}\n\n")
|
126 |
+
```
|
127 |
+
|
128 |
+
## Quantitative Compression
|
129 |
+
We used our own `AngleSlim` compression tool to produce FP8 and INT4 quantization models. `AngleSlim` compression tool is expected to be open source in early July, which will support one-click quantization and compression of large models, please look forward to it, and you can download our quantization models directly for deployment testing now.
|
130 |
+
|
131 |
+
### FP8 Quantization
|
132 |
+
We use FP8-static quantization, FP8 quantization adopts 8-bit floating point format, through a small amount of calibration data (without training) to pre-determine the quantization scale, the model weights and activation values will be converted to FP8 format, to improve the inference efficiency and reduce the deployment threshold. We you can use AngleSlim quantization, you can also directly download our quantization completed open source model to use [Hunyuan-A13B-Instruct-FP8](https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8).
|
133 |
+
|
134 |
+
#### FP8 Benchmark
|
135 |
+
This subsection describes the Benchmark metrics for the Hunyuan-80B-A13B-Instruct-FP8 quantitative model.
|
136 |
+
|
137 |
+
| Bench | Hunyuan-A13B-Instruct | Hunyuan-A13B-Instruct-FP8 |
|
138 |
+
|:---------:|:---------------------:|:-------------------------:|
|
139 |
+
| AIME 2024 | 87.3 | 86.7 |
|
140 |
+
| Gsm8k | 94.39 | 94.01 |
|
141 |
+
| BBH | 89.1 | 88.34 |
|
142 |
+
| DROP | 91.1 | 91.1 |
|
143 |
+
|
144 |
+
### Int4 Quantization
|
145 |
+
We use the GPTQ algorithm to achieve W4A16 quantization, which processes the model weights layer by layer, uses a small amount of calibration data to minimize the reconfiguration error of the quantized weights, and adjusts the weights layer by layer by the optimization process of approximating the Hessian inverse matrix. The process eliminates the need to retrain the model and requires only a small amount of calibration data to quantize the weights, improving inference efficiency and lowering the deployment threshold. You can use `AngleSlim` quantization, you can also directly download our quantization completed open source model to use [Hunyuan-A13B-Instruct-Int4](https://huggingface.co/tencent/Hunyuan-A13B-Instruct-GPTQ-Int4).
|
146 |
+
|
147 |
+
#### Int4 Benchmark
|
148 |
+
This subsection describes the Benchmark metrics for the Hunyuan-80B-A13B-Instruct-GPTQ-Int4 quantitative model.
|
149 |
+
|
150 |
+
| Bench | Hunyuan-A13B-Instruct | Hunyuan-A13B-Instruct-GPTQ-Int4 |
|
151 |
+
|:--------------:|:---------------------:|:-------------------------------:|
|
152 |
+
| OlympiadBench | 82.7 | 84.0 |
|
153 |
+
| AIME 2024 | 87.3 | 86.7 |
|
154 |
+
| Gsm8k | 94.39 | 94.24 |
|
155 |
+
| BBH | 88.34 | 87.91 |
|
156 |
+
| DROP | 91.12 | 91.05 |
|
157 |
+
|
158 |
+
|
159 |
+
## Deployment
|
160 |
+
|
161 |
+
For deployment, you can use frameworks such as **TensorRT-LLM**, **vLLM**, or **SGLang** to serve the model and create an OpenAI-compatible API endpoint.
|
162 |
+
|
163 |
+
image: https://hub.docker.com/r/hunyuaninfer/hunyuan-a13b/tags
|
164 |
+
|
165 |
+
|
166 |
+
### TensorRT-LLM
|
167 |
+
|
168 |
+
#### Docker Image
|
169 |
+
|
170 |
+
We provide a pre-built Docker image based on the latest version of TensorRT-LLM.
|
171 |
+
|
172 |
+
- To get started:
|
173 |
+
|
174 |
+
https://hub.docker.com/r/hunyuaninfer/hunyuan-large/tags
|
175 |
+
|
176 |
+
```
|
177 |
+
docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-trtllm
|
178 |
+
```
|
179 |
+
|
180 |
+
- Start the API server:
|
181 |
+
|
182 |
+
```
|
183 |
+
docker run --name hunyuanLLM_infer --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-trtllm
|
184 |
+
```
|
185 |
+
```
|
186 |
+
trtllm-serve \
|
187 |
+
/path/to/HunYuan-moe-A13B \
|
188 |
+
--host localhost \
|
189 |
+
--port 8000 \
|
190 |
+
--backend pytorch \
|
191 |
+
--max_batch_size 128 \
|
192 |
+
--max_num_tokens 16384 \
|
193 |
+
--tp_size 2 \
|
194 |
+
--kv_cache_free_gpu_memory_fraction 0.95 \
|
195 |
+
--extra_llm_api_options /path/to/extra-llm-api-config.yml
|
196 |
+
```
|
197 |
+
|
198 |
+
|
199 |
+
### vllm
|
200 |
+
|
201 |
+
#### Docker Image
|
202 |
+
We provide a pre-built Docker image containing vLLM 0.8.5 with full support for this model. The official vllm release is currently under development, **note: cuda 12.8 is require for this docker**.
|
203 |
+
|
204 |
+
- To get started:
|
205 |
+
|
206 |
+
```
|
207 |
+
docker pull docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-vllm
|
208 |
+
or
|
209 |
+
docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm
|
210 |
+
```
|
211 |
+
|
212 |
+
- Download Model file:
|
213 |
+
- Huggingface: will download automicly by vllm.
|
214 |
+
- ModelScope: `modelscope download --model Tencent-Hunyuan/Hunyuan-A13B-Instruct`
|
215 |
+
|
216 |
+
|
217 |
+
- Start the API server:
|
218 |
+
|
219 |
+
model download by huggingface:
|
220 |
+
```
|
221 |
+
docker run --privileged --user root --net=host --ipc=host \
|
222 |
+
-v ~/.cache:/root/.cache/ \
|
223 |
+
--gpus=all -it --entrypoint python hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm
|
224 |
+
\
|
225 |
+
-m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
|
226 |
+
--tensor-parallel-size 4 --model tencent/Hunyuan-A13B-Instruct --trust-remote-code
|
227 |
+
|
228 |
+
```
|
229 |
+
|
230 |
+
model downloaded by modelscope:
|
231 |
+
```
|
232 |
+
docker run --privileged --user root --net=host --ipc=host \
|
233 |
+
-v ~/.cache/modelscope:/root/.cache/modelscope \
|
234 |
+
--gpus=all -it --entrypoint python hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-vllm \
|
235 |
+
-m vllm.entrypoints.openai.api_server --host 0.0.0.0 --tensor-parallel-size 4 --port 8000 \
|
236 |
+
--model /root/.cache/modelscope/hub/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct/ --trust_remote_code
|
237 |
+
```
|
238 |
+
|
239 |
+
|
240 |
+
### SGLang
|
241 |
+
|
242 |
+
#### Docker Image
|
243 |
+
|
244 |
+
We also provide a pre-built Docker image based on the latest version of SGLang.
|
245 |
+
|
246 |
+
To get started:
|
247 |
+
|
248 |
+
- Pull the Docker image
|
249 |
+
|
250 |
+
```
|
251 |
+
docker pull docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-sglang
|
252 |
+
or
|
253 |
+
docker pull hunyuaninfer/hunyuan-a13b:hunyuan-moe-A13B-sglang
|
254 |
+
```
|
255 |
+
|
256 |
+
- Start the API server:
|
257 |
+
|
258 |
+
```
|
259 |
+
docker run --gpus all \
|
260 |
+
--shm-size 32g \
|
261 |
+
-p 30000:30000 \
|
262 |
+
--ipc=host \
|
263 |
+
docker.cnb.cool/tencent/hunyuan/hunyuan-a13b:hunyuan-moe-A13B-sglang \
|
264 |
+
-m sglang.launch_server --model-path hunyuan/huanyuan_A13B --tp 4 --trust-remote-code --host 0.0.0.0 --port 30000
|
265 |
+
```
|
266 |
+
|
267 |
+
|
268 |
+
## Contact Us
|
269 |
+
|
270 |
+
If you would like to leave a message for our R&D and product teams, Welcome to contact our open-source team . You can also contact us via email ([email protected]).
|
README_CN.md
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="https://dscache.tencent-cloud.cn/upload/uploader/hunyuan-64b418fd052c033b228e04bc77bbc4b54fd7f5bc.png" width="400"/> <br>
|
3 |
+
</p><p></p>
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
🫣 <a href="https://huggingface.co/tencent/Hunyuan-A13B-Instruct"><b>Hugging Face</b></a> |
|
7 |
+
🖥️ <a href="https://llm.hunyuan.tencent.com/" style="color: red;"><b>Official Website</b></a> |
|
8 |
+
🕖 <a href="https://cloud.tencent.com/product/hunyuan"><b>HunyuanAPI</b></a> |
|
9 |
+
🕹️ <a href="https://hunyuan.tencent.com/?model=hunyuan-a13b"><b>Demo</b></a> |
|
10 |
+
<img src="https://avatars.githubusercontent.com/u/109945100?s=200&v=4" width="16"/> <a href="https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct"><b>ModelScope</b></a>
|
11 |
+
</p>
|
12 |
+
|
13 |
+
<p align="center">
|
14 |
+
<a href="https://github.com/Tencent/Hunyuan-A13B"><b>GITHUB</b></a>
|
15 |
+
</p>
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
## 模型介绍
|
21 |
+
|
22 |
+
随着人工智能技术的快速发展,大型语言模型(LLMs)在自然语言处理、计算机视觉和科学任务等领域取得了显著进展。然而,随着模型规模的扩大,如何在保持高性能的同时优化资源消耗成为一个关键挑战。为了应对这一挑战,我们研究了混合专家(MoE)模型,当前亮相的 Hunyuan-A13B 模型,拥有800亿总参数和130亿激活参数。不仅在效果上达到了高标准,而且在尺寸上也做到了极致的优化,成功平衡了模型性能与资源占用。
|
23 |
+
|
24 |
+
|
25 |
+
### 核心特性与优势
|
26 |
+
- **小参数量,高性能**:仅激活130亿参数(总参数量800亿),即可在多样化基准任务中媲美更大规模模型的竞争力表现
|
27 |
+
- **混合推理支持**:同时支持快思考和慢思考两种模式,支持用户灵活选择
|
28 |
+
- **超长上下文理解**:原生支持256K上下文窗口,在长文本任务中保持稳定性能
|
29 |
+
- **增强Agent能力**:优化Agent能力,在BFCL-v3、τ-Bench等智能体基准测试中领先
|
30 |
+
- **高效推理**:采用分组查询注意力(GQA)策略,支持多量化格式,实现高效推理
|
31 |
+
|
32 |
+
|
33 |
+
### 为何选择Hunyuan-A13B?
|
34 |
+
作为兼具强大性能与计算效率的大模型,Hunyuan-A13B是研究者与开发者在资源受限条件下追求高性能的理想选择。无论学术研究、高性价比AI解决方案开发,还是创新应用探索,本模型都能提供强大的基础支持。
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
## 新闻
|
40 |
+
<br>
|
41 |
+
|
42 |
+
* 2025.6.26 我们在Hugging Face开源了 **Hunyuan-A13B-Instruct**,**Hunyuan-A13B-Pretrain**, **Hunyuan-A13B-Instruct-FP8**, **Hunyuan-A13B-Instruct-GPTQ-Int4**。并发布了技术报告和训练推理操作手册,详细介绍了模型能力和训练与推理的操作。
|
43 |
+
|
44 |
+
## 模型结构
|
45 |
+
|
46 |
+
Hunyuan-A13B采用了细粒度混合专家(Fine-grained Mixture of Experts,Fine-grained MoE)架构,包含800亿参数和130亿激活参数,累计训练了超过 20T tokens。该模型支持 256K 的上下文长度,以下为模型结构细节:
|
47 |
+
* 总参数: 80B
|
48 |
+
* 激活参数: 13B
|
49 |
+
* 层数: 32
|
50 |
+
* Attention Heads: 32
|
51 |
+
* 共享专家数: 1
|
52 |
+
* 非共享专家数: 64
|
53 |
+
* 路由策略: Top-8
|
54 |
+
* 激活函数: SwiGLU
|
55 |
+
* 隐层维度: 4096
|
56 |
+
* 专家隐层维度: 3072
|
57 |
+
|
58 |
+
## Benchmark评估榜单
|
59 |
+
|
60 |
+
**Hunyuan-A13B-Pretrain** 在 12/14 个任务上超越了Hunyuan上一代52B激活参数的MoE模型Hunyuan-Large,证实了它在预训练任务上出色的能力。与业界更大参数量的Dense和MoE模型相比, Hunyuan-A13B在多个代码和数学任务上都取得了最高分数。在MMLU, MMLU-PRO等诸多众聚合任务上, Hunyuan-A13B达到了与Qwen3-A22B模型同等的水平,表现出优秀的综合能力。
|
61 |
+
|
62 |
+
| Model | Hunyuan-Large | Qwen2.5-72B | Qwen3-A22B | Hunyuan-A13B |
|
63 |
+
|------------------|---------------|--------------|-------------|---------------|
|
64 |
+
| MMLU | 88.40 | 86.10 | 87.81 | 88.17 |
|
65 |
+
| MMLU-Pro | 60.20 | 58.10 | 68.18 | 67.23 |
|
66 |
+
| MMLU-Redux | 87.47 | 83.90 | 87.40 | 87.67 |
|
67 |
+
| BBH | 86.30 | 85.80 | 88.87 | 87.56 |
|
68 |
+
| SuperGPQA | 38.90 | 36.20 | 44.06 | 41.32 |
|
69 |
+
| EvalPlus | 75.69 | 65.93 | 77.60 | 78.64 |
|
70 |
+
| MultiPL-E | 59.13 | 60.50 | 65.94 | 69.33 |
|
71 |
+
| MBPP | 72.60 | 76.00 | 81.40 | 83.86 |
|
72 |
+
| CRUX-I | 57.00 | 57.63 | - | 70.13 |
|
73 |
+
| CRUX-O | 60.63 | 66.20 | 79.00 | 77.00 |
|
74 |
+
| MATH | 69.80 | 62.12 | 71.84 | 72.35 |
|
75 |
+
| CMATH | 91.30 | 84.80 | - | 91.17 |
|
76 |
+
| GSM8k | 92.80 | 91.50 | 94.39 | 91.83 |
|
77 |
+
| GPQA | 25.18 | 45.90 | 47.47 | 49.12 |
|
78 |
+
|
79 |
+
**Hunyuan-A13B-Instruct** 在多项基准测试中取得了极具有竞争力的表现,尤其是在数学、科学、agent等领域。我们与一些强力模型进行了对比,结果如下所示。
|
80 |
+
|
81 |
+
| Topic | Bench | OpenAI-o1-1217 | DeepSeek R1 | Qwen3-A22B | Hunyuan-A13B-Instruct |
|
82 |
+
|:-------------------:|:-----------------------------:|:-------------:|:------------:|:-----------:|:---------------------:|
|
83 |
+
| **Mathematics** | AIME 2024<br>AIME 2025<br>MATH | 74.3<br>79.2<br>96.4 | 79.8<br>70<br>94.9 | 85.7<br>81.5<br>94.0 | 87.3<br>76.8<br>94.3 |
|
84 |
+
| **Science** | GPQA-Diamond<br>OlympiadBench | 78<br>83.1 | 71.5<br>82.4 | 71.1<br>85.7 | 71.2<br>82.7 |
|
85 |
+
| **Coding** | Livecodebench<br>Fullstackbench<br>ArtifactsBench | 63.9<br>64.6<br>38.6 | 65.9<br>71.6<br>44.6 | 70.7<br>65.6<br>44.6 | 63.9<br>67.8<br>43 |
|
86 |
+
| **Reasoning** | BBH<br>DROP<br>ZebraLogic | 80.4<br>90.2<br>81 | 83.7<br>92.2<br>78.7 | 88.9<br>90.3<br>80.3 | 89.1<br>91.1<br>84.7 |
|
87 |
+
| **Instruction<br>Following** | IF-Eval<br>SysBench | 91.8<br>82.5 | 88.3<br>77.7 | 83.4<br>74.2 | 84.7<br>76.1 |
|
88 |
+
| **Text<br>Creation**| LengthCtrl<br>InsCtrl | 60.1<br>74.8 | 55.9<br>69 | 53.3<br>73.7 | 55.4<br>71.9 |
|
89 |
+
| **NLU** | ComplexNLU<br>Word-Task | 64.7<br>67.1 | 64.5<br>76.3 | 59.8<br>56.4 | 61.2<br>62.9 |
|
90 |
+
| **Agent** | BDCL v3<br> τ-Bench<br>ComplexFuncBench<br> $C^3$-Bench | 67.8<br>60.4<br>47.6<br>58.8 | 56.9<br>43.8<br>41.1<br>55.3 | 70.8<br>44.6<br>40.6<br>51.7 | 78.3<br>54.7<br>61.2<br>63.5 |
|
91 |
+
|
92 |
+
|
93 |
+
## 推理和部署
|
94 |
+
|
95 |
+
HunyuanLLM可以采用vLLM,sglang或TensorRT-LLM部署。为了简化部署过程HunyuanLLM提供了预构建docker镜像。
|
96 |
+
|
97 |
+
|
98 |
+
## 使用TensorRT-LLM推理
|
99 |
+
|
100 |
+
### BF16部署
|
101 |
+
|
102 |
+
#### Step1:执行推理
|
103 |
+
|
104 |
+
#### 方式1:命令行推理
|
105 |
+
|
106 |
+
下面我们展示一个代码片段,采用`TensorRT-LLM`快速请求chat model:
|
107 |
+
修改 examples/pytorch/quickstart_advanced.py 中如下代码:
|
108 |
+
|
109 |
+
|
110 |
+
```python
|
111 |
+
from tensorrt_llm import SamplingParams
|
112 |
+
from tensorrt_llm._torch import LLM
|
113 |
+
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
114 |
+
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
|
115 |
+
MTPDecodingConfig)
|
116 |
+
|
117 |
+
prompt = "Write a short summary of the benefits of regular exercise"
|
118 |
+
|
119 |
+
def main():
|
120 |
+
args = parse_arguments()
|
121 |
+
|
122 |
+
llm, sampling_params = setup_llm(args)
|
123 |
+
new_prompts = []
|
124 |
+
if args.apply_chat_template:
|
125 |
+
messages = [{"role": "user", "content": f"{prompt}"}]
|
126 |
+
new_prompts.append(llm.tokenizer.apply_chat_template(
|
127 |
+
messages, tokenize=False, add_generation_prompt=True)
|
128 |
+
)
|
129 |
+
|
130 |
+
outputs = llm.generate(new_prompts, sampling_params)
|
131 |
+
|
132 |
+
for i, output in enumerate(outputs):
|
133 |
+
prompt = output.prompt
|
134 |
+
generated_text = output.outputs[0].text
|
135 |
+
print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
136 |
+
```
|
137 |
+
|
138 |
+
运行方式:
|
139 |
+
|
140 |
+
```shell
|
141 |
+
python3 quickstart_advanced.py --model_dir "HunyuanLLM模型路径" --tp_size 4 --apply_chat_template
|
142 |
+
```
|
143 |
+
|
144 |
+
#### 方式2:服务化推理
|
145 |
+
|
146 |
+
下面我们展示使用`TensorRT-LLM`服务化的方式部署模型和请求。
|
147 |
+
|
148 |
+
```shell
|
149 |
+
model_path="HunyuanLLM模型路径"
|
150 |
+
trtllm-serve <model_path> [--backend pytorch --tp_size <tp> --ep_size <ep> --host <host> --port <port>]
|
151 |
+
```
|
152 |
+
|
153 |
+
服务启动成功后, 运行请求脚本:
|
154 |
+
```python
|
155 |
+
### OpenAI Chat Client
|
156 |
+
|
157 |
+
from openai import OpenAI
|
158 |
+
|
159 |
+
client = OpenAI(
|
160 |
+
base_url="http://localhost:8000/v1",
|
161 |
+
api_key="tensorrt_llm",
|
162 |
+
)
|
163 |
+
|
164 |
+
response = client.chat.completions.create(
|
165 |
+
model="default",
|
166 |
+
messages=[{
|
167 |
+
"role": "user",
|
168 |
+
"content": "Write a short summary of the benefits of regular exercise"
|
169 |
+
}],
|
170 |
+
max_tokens=4096,
|
171 |
+
)
|
172 |
+
print(response)
|
173 |
+
```
|
174 |
+
|
175 |
+
#### FP8/Int4量化模型部署:
|
176 |
+
目前 TensorRT-LLM 的 fp8 和 int4 量化模型正在支持中,敬请期待。
|
177 |
+
|
178 |
+
|
179 |
+
## 使用vLLM推理
|
180 |
+
### Docker:
|
181 |
+
|
182 |
+
为了简化部署过程,HunyuanLLM提供了预构建docker镜像:
|
183 |
+
|
184 |
+
[hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm](https://hub.docker.com/r/hunyuaninfer/hunyuan-large/tags) 。您只需要下载模型文件并用下面代码启动docker即可开始推理模型。
|
185 |
+
```shell
|
186 |
+
# 拉取
|
187 |
+
docker pull hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm
|
188 |
+
# 起镜像
|
189 |
+
docker run --name hunyuanLLM_infer -itd --privileged --user root --net=host --ipc=host --gpus=8 hunyuaninfer/hunyuan-large:hunyuan-moe-A13B-vllm
|
190 |
+
```
|
191 |
+
|
192 |
+
注: Docker容器权限管理。以上代码采用特权模式(--privileged)启动Docker容器会赋予容器较高的权限,增加数据泄露和集群安全风险。建议在非必要情况下避免使用特权模式,以降低安全威胁。对于必须使用特权模式的场景,应进行严格的安全评估,并实施相应的安全监控、加固措施。
|
193 |
+
|
194 |
+
|
195 |
+
### BF16部署
|
196 |
+
|
197 |
+
BF16可以在2张显存超过80G的GPU卡上部署,如果长文推荐TP4。按如下步骤执行:
|
198 |
+
|
199 |
+
运行命令前请先设置如下环境变量:
|
200 |
+
|
201 |
+
```shell
|
202 |
+
export MODEL_PATH=PATH_TO_MODEL
|
203 |
+
```
|
204 |
+
|
205 |
+
#### Step1:执行推理
|
206 |
+
|
207 |
+
#### 方式1:命令行推理
|
208 |
+
|
209 |
+
下面我们展示一个代码片段,采用`vLLM`快速请求chat model:
|
210 |
+
|
211 |
+
注: vLLM组件远程代码执行防护。下列代码中vLLM组件的trust-remote-code配置项若被启用,将允许加载并执行来自远程模型仓库的代码,这可能导致恶意代码的执行。除非业务需求明确要求,否则建议该配置项处于禁用状态,以降低潜在的安全威胁。
|
212 |
+
|
213 |
+
|
214 |
+
```python
|
215 |
+
import os
|
216 |
+
from typing import List, Optional
|
217 |
+
from vllm import LLM, SamplingParams
|
218 |
+
from vllm.inputs import PromptType
|
219 |
+
from transformers import AutoTokenizer
|
220 |
+
|
221 |
+
model_path=os.environ.get('MODEL_PATH')
|
222 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
223 |
+
|
224 |
+
llm = LLM(model=model_path,
|
225 |
+
tokenizer=model_path,
|
226 |
+
trust_remote_code=True,
|
227 |
+
dtype='bfloat16',
|
228 |
+
tensor_parallel_size=4,
|
229 |
+
gpu_memory_utilization=0.9)
|
230 |
+
|
231 |
+
sampling_params = SamplingParams(
|
232 |
+
temperature=0.7, top_p=0.8, max_tokens=4096, top_k=20, repetition_penalty=1.05)
|
233 |
+
|
234 |
+
messages = [
|
235 |
+
{
|
236 |
+
"role": "system",
|
237 |
+
"content": "You are a helpful assistant.",
|
238 |
+
},
|
239 |
+
{"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
|
240 |
+
]
|
241 |
+
|
242 |
+
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
243 |
+
|
244 |
+
dummy_inputs: List[PromptType] = [{
|
245 |
+
"prompt_token_ids": batch
|
246 |
+
} for batch in tokenized_chat.numpy().tolist()]
|
247 |
+
|
248 |
+
outputs = llm.generate(dummy_inputs, sampling_params)
|
249 |
+
|
250 |
+
# Print the outputs.
|
251 |
+
for output in outputs:
|
252 |
+
prompt = output.prompt
|
253 |
+
generated_text = output.outputs[0].text
|
254 |
+
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
255 |
+
```
|
256 |
+
|
257 |
+
#### 方式2:服务化推理
|
258 |
+
|
259 |
+
下面我们展示使用`vLLM`服务化的方式部署模型并请求
|
260 |
+
|
261 |
+
在主节点上运行:
|
262 |
+
|
263 |
+
```shell
|
264 |
+
export VLLM_HOST_IP=${LOCAL_IP}
|
265 |
+
```
|
266 |
+
接着我们启动服务,运行 :
|
267 |
+
```shell
|
268 |
+
cd inference
|
269 |
+
sh run_server.sh
|
270 |
+
```
|
271 |
+
|
272 |
+
运行`run_server.sh`成功后, 运行请求脚本:
|
273 |
+
```shell
|
274 |
+
sh openapi.sh
|
275 |
+
```
|
276 |
+
|
277 |
+
注意修改`openapi.sh`中的`${LOCAL_IP}`和`${MODEL_PATH}`为服务对应值。
|
278 |
+
|
279 |
+
|
280 |
+
### 量化模型部署:
|
281 |
+
|
282 |
+
本部分介绍采用vLLM部署量化后模型的流程。
|
283 |
+
|
284 |
+
镜像:部署镜像同BF16。
|
285 |
+
|
286 |
+
|
287 |
+
#### Int8量化模型部署:
|
288 |
+
部署Int8-weight-only版本HunYuan-A13B模型只需设置`run_server_int8.sh`中的环境变量:
|
289 |
+
```SHELL
|
290 |
+
export MODEL_PATH=PATH_TO_BF16_MODEL
|
291 |
+
```
|
292 |
+
|
293 |
+
接着我们启动Int8服务。运行:
|
294 |
+
```shell
|
295 |
+
sh run_server_int8.sh
|
296 |
+
```
|
297 |
+
|
298 |
+
运行`run_server_int8.sh`成功后, 运行请求脚本:
|
299 |
+
```shell
|
300 |
+
sh openapi.sh
|
301 |
+
```
|
302 |
+
|
303 |
+
#### Int4量化模型部署:
|
304 |
+
部署Int4-weight-only版本HunYuan-A13B模型只需设置`run_server_int4.sh`中的环境变量,采用GPTQ方式:
|
305 |
+
```SHELL
|
306 |
+
export MODEL_PATH=PATH_TO_INT4_MODEL
|
307 |
+
```
|
308 |
+
|
309 |
+
接着我们启动Int4服务。运行:
|
310 |
+
```shell
|
311 |
+
sh run_server_int4.sh
|
312 |
+
```
|
313 |
+
|
314 |
+
运行`run_server_int4.sh`成功后, 运行请求脚本:
|
315 |
+
```shell
|
316 |
+
sh openapi.sh
|
317 |
+
```
|
318 |
+
|
319 |
+
#### FP8量化模型部署:
|
320 |
+
部署W8A8C8版本HunYuan-A13B模型只需设置`run_server_int8.sh`中的环境变量:
|
321 |
+
```shell
|
322 |
+
export MODEL_PATH=PATH_TO_FP8_MODEL
|
323 |
+
```
|
324 |
+
|
325 |
+
接着我们启动FP8服务。运行:
|
326 |
+
```shell
|
327 |
+
sh run_server_fp8.sh
|
328 |
+
```
|
329 |
+
|
330 |
+
运行`run_server_fp8.sh`成功后, 运行请求脚本:
|
331 |
+
```shell
|
332 |
+
sh openapi.sh
|
333 |
+
```
|
334 |
+
|
335 |
+
### 性能评估:
|
336 |
+
|
337 |
+
本部分介绍采用vLLM部署各个模型(原始模型和量化模型)的效率测试结果,包括不同Batchsize下的推理速度(tokens/s), 测试环境(腾讯云,H80(96G)GPU x 卡数):
|
338 |
+
|
339 |
+
测试命令:
|
340 |
+
```python
|
341 |
+
python3 benchmark_throughput.py --backend vllm \
|
342 |
+
--input-len 2048 \
|
343 |
+
--output-len 14336 \
|
344 |
+
--model $MODEL_PATH \
|
345 |
+
--tensor-parallel-size $TP \
|
346 |
+
--use-v2-block-manager \
|
347 |
+
--async-engine \
|
348 |
+
--trust-remote-code \
|
349 |
+
--num_prompts $BATCH_SIZE \
|
350 |
+
--max-num-seqs $BATCH_SIZE
|
351 |
+
```
|
352 |
+
|
353 |
+
| 推理框架 | 模型 | 部署卡数 | input_length | batch=1 | batch=16 | batch=32 |
|
354 |
+
|------|-----------------------------|-----------|-------------------------|---------------------|----------------------|----------------------|
|
355 |
+
| vLLM | Hunyuan-A13B-Instruct | 8 | 2048 | 190.84 | 1246.54 | 1981.99 |
|
356 |
+
| vLLM | Hunyuan-A13B-Instruct | 4 | 2048 | 158.90 | 779.10 | 1301.75 |
|
357 |
+
| vLLM | Hunyuan-A13B-Instruct | 2 | 2048 | 111.72 | 327.31 | 346.54 |
|
358 |
+
| vLLM | Hunyuan-A13B-Instruct(int8 weight only) | 2 | 2048 | 109.10 | 444.17 | 721.93 |
|
359 |
+
| vLLM | Hunyuan-A13B-Instruct(W8A8C8-FP8) | 2 | 2048 | 91.83 | 372.01 | 617.70 |
|
360 |
+
| vLLM | Hunyuan-A13B-Instruct(W8A8C8-FP8) | 1 | 2048 | 60.07 | 148.80 | 160.41 |
|
361 |
+
|
362 |
+
|
363 |
+
## 使用sglang推理
|
364 |
+
|
365 |
+
### BF16部署
|
366 |
+
|
367 |
+
#### Step1:执行推理
|
368 |
+
|
369 |
+
#### 方式1:命令行推理
|
370 |
+
|
371 |
+
下面我们展示一个代码片段,采用`sglang`快速请求chat model:
|
372 |
+
|
373 |
+
|
374 |
+
```python
|
375 |
+
import sglang as sgl
|
376 |
+
from transformers import AutoTokenizer
|
377 |
+
|
378 |
+
model_path=os.environ.get('MODEL_PATH')
|
379 |
+
|
380 |
+
|
381 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
382 |
+
|
383 |
+
messages = [
|
384 |
+
{
|
385 |
+
"role": "system",
|
386 |
+
"content": "You are a helpful assistant.",
|
387 |
+
},
|
388 |
+
{"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
|
389 |
+
]
|
390 |
+
prompts = []
|
391 |
+
prompts.append(tokenizer.apply_chat_template(
|
392 |
+
messages,
|
393 |
+
tokenize=False,
|
394 |
+
add_generation_prompt=True
|
395 |
+
))
|
396 |
+
print(prompts)
|
397 |
+
|
398 |
+
llm = sgl.Engine(
|
399 |
+
model_path=model_path,
|
400 |
+
tp_size=4,
|
401 |
+
trust_remote_code=True,
|
402 |
+
mem_fraction_static=0.7,
|
403 |
+
)
|
404 |
+
|
405 |
+
sampling_params = {"temperature": 0.7, "top_p": 0.8, "top_k": 20, "max_new_tokens": 4096}
|
406 |
+
outputs = llm.generate(prompts, sampling_params)
|
407 |
+
for prompt, output in zip(prompts, outputs):
|
408 |
+
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
|
409 |
+
```
|
410 |
+
|
411 |
+
#### 方式2:服务化推理
|
412 |
+
|
413 |
+
下面我们展示使用`sglang`服务化的方式部署模型和请求。
|
414 |
+
|
415 |
+
```shell
|
416 |
+
model_path="HunyuanLLM模型路径"
|
417 |
+
python3 -u -m sglang.launch_server \
|
418 |
+
--model-path $model_path \
|
419 |
+
--tp 4 \
|
420 |
+
--trust-remote-code \
|
421 |
+
```
|
422 |
+
|
423 |
+
服务启动成功后, 运行请求脚本:
|
424 |
+
```python
|
425 |
+
import openai
|
426 |
+
client = openai.Client(
|
427 |
+
base_url="http://localhost:30000/v1", api_key="EMPTY")
|
428 |
+
|
429 |
+
response = client.chat.completions.create(
|
430 |
+
model="default",
|
431 |
+
messages= [
|
432 |
+
{"role": "user", "content": "Write a short summary of the benefits of regular exercise"},
|
433 |
+
],
|
434 |
+
temperature=0.7,
|
435 |
+
max_tokens=4096,
|
436 |
+
extra_body={"top_p": 0.8, "top_k": 20}
|
437 |
+
)
|
438 |
+
print(response)
|
439 |
+
```
|
440 |
+
|
441 |
+
#### FP8/Int4量化模型部署:
|
442 |
+
目前 sglang 的 fp8 和 int4 量化模型正在支持中,敬请期待。
|
443 |
+
|
444 |
+
## 交互式Demo Web
|
445 |
+
hunyuan-A13B 现已开放网页demo。访问 https://hunyuan.tencent.com/?model=hunyuan-a13b 即可简单体验我们的模型。
|
446 |
+
|
447 |
+
<br>
|
448 |
+
|
449 |
+
## 引用
|
450 |
+
如果你觉得我们的工作对你有帮助,欢迎引用我们的<a href="report/Hunyuan_A13B_Technical_Report.pdf">技术报告</a>!
|
451 |
+
|
452 |
+
<br>
|
453 |
+
|
454 |
+
|
455 |
+
## 联系我们
|
456 |
+
如果你想给我们的研发和产品团队留言,欢迎联系我们腾讯混元LLM团队。你可以通过邮件([email protected])联系我们。
|
config.json
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_classification_head": false,
|
3 |
+
"anyres_pooling_size": 2,
|
4 |
+
"anyres_vit_max_image_size": null,
|
5 |
+
"anyres_vit_two_views": false,
|
6 |
+
"architectures": [
|
7 |
+
"HunYuanMoEV1ForCausalLM"
|
8 |
+
],
|
9 |
+
"attention_bias": false,
|
10 |
+
"attention_dropout": 0.1,
|
11 |
+
"attention_head_dim": 128,
|
12 |
+
"auto_map": {
|
13 |
+
"AutoConfig": "configuration_hunyuan.HunYuanConfig",
|
14 |
+
"AutoModel": "hunyuan.HunYuanModel",
|
15 |
+
"AutoModelForCausalLM": "hunyuan.HunYuanMoEV1ForCausalLM"
|
16 |
+
},
|
17 |
+
"bos_token_id": 1,
|
18 |
+
"cla_share_factor": 2,
|
19 |
+
"class_num": 0,
|
20 |
+
"dense_list": [
|
21 |
+
4096,
|
22 |
+
0
|
23 |
+
],
|
24 |
+
"eod_token_id": 127967,
|
25 |
+
"eos_token_id": 127960,
|
26 |
+
"group_limited_greedy": false,
|
27 |
+
"hidden_act": "silu",
|
28 |
+
"hidden_size": 4096,
|
29 |
+
"im_end_id": 6,
|
30 |
+
"im_newline_id": 12,
|
31 |
+
"im_start_id": 5,
|
32 |
+
"image_token_id": 9,
|
33 |
+
"initializer_range": 0.02,
|
34 |
+
"intermediate_size": 3072,
|
35 |
+
"kv_lora_rank": null,
|
36 |
+
"mask_init_id": 13,
|
37 |
+
"max_position_embeddings": 32768,
|
38 |
+
"mlp_bias": false,
|
39 |
+
"model_type": "hunyuan",
|
40 |
+
"moe_drop_tokens": false,
|
41 |
+
"moe_intermediate_size": [
|
42 |
+
3072,
|
43 |
+
3072,
|
44 |
+
3072,
|
45 |
+
3072,
|
46 |
+
3072,
|
47 |
+
3072,
|
48 |
+
3072,
|
49 |
+
3072,
|
50 |
+
3072,
|
51 |
+
3072,
|
52 |
+
3072,
|
53 |
+
3072,
|
54 |
+
3072,
|
55 |
+
3072,
|
56 |
+
3072,
|
57 |
+
3072,
|
58 |
+
3072,
|
59 |
+
3072,
|
60 |
+
3072,
|
61 |
+
3072,
|
62 |
+
3072,
|
63 |
+
3072,
|
64 |
+
3072,
|
65 |
+
3072,
|
66 |
+
3072,
|
67 |
+
3072,
|
68 |
+
3072,
|
69 |
+
3072,
|
70 |
+
3072,
|
71 |
+
3072,
|
72 |
+
3072,
|
73 |
+
3072
|
74 |
+
],
|
75 |
+
"moe_layer_num_skipped": 0,
|
76 |
+
"moe_random_routing_dropped_token": false,
|
77 |
+
"moe_topk": [
|
78 |
+
8,
|
79 |
+
8,
|
80 |
+
8,
|
81 |
+
8,
|
82 |
+
8,
|
83 |
+
8,
|
84 |
+
8,
|
85 |
+
8,
|
86 |
+
8,
|
87 |
+
8,
|
88 |
+
8,
|
89 |
+
8,
|
90 |
+
8,
|
91 |
+
8,
|
92 |
+
8,
|
93 |
+
8,
|
94 |
+
8,
|
95 |
+
8,
|
96 |
+
8,
|
97 |
+
8,
|
98 |
+
8,
|
99 |
+
8,
|
100 |
+
8,
|
101 |
+
8,
|
102 |
+
8,
|
103 |
+
8,
|
104 |
+
8,
|
105 |
+
8,
|
106 |
+
8,
|
107 |
+
8,
|
108 |
+
8,
|
109 |
+
8
|
110 |
+
],
|
111 |
+
"n_group": null,
|
112 |
+
"norm_topk_prob": true,
|
113 |
+
"norm_type": "rms",
|
114 |
+
"num_attention_heads": 32,
|
115 |
+
"num_experts": 64,
|
116 |
+
"num_hidden_layers": 32,
|
117 |
+
"num_key_value_heads": 8,
|
118 |
+
"num_media_embeds": 257,
|
119 |
+
"num_shared_expert": [
|
120 |
+
1,
|
121 |
+
1,
|
122 |
+
1,
|
123 |
+
1,
|
124 |
+
1,
|
125 |
+
1,
|
126 |
+
1,
|
127 |
+
1,
|
128 |
+
1,
|
129 |
+
1,
|
130 |
+
1,
|
131 |
+
1,
|
132 |
+
1,
|
133 |
+
1,
|
134 |
+
1,
|
135 |
+
1,
|
136 |
+
1,
|
137 |
+
1,
|
138 |
+
1,
|
139 |
+
1,
|
140 |
+
1,
|
141 |
+
1,
|
142 |
+
1,
|
143 |
+
1,
|
144 |
+
1,
|
145 |
+
1,
|
146 |
+
1,
|
147 |
+
1,
|
148 |
+
1,
|
149 |
+
1,
|
150 |
+
1,
|
151 |
+
1
|
152 |
+
],
|
153 |
+
"org_vocab_size": 128167,
|
154 |
+
"pad_id": 127961,
|
155 |
+
"pad_token_id": 127961,
|
156 |
+
"pool_type": "last",
|
157 |
+
"position_embedding_xdrope": false,
|
158 |
+
"pretraining_tp": 1,
|
159 |
+
"q_lora_rank": null,
|
160 |
+
"qk_nope_head_dim": null,
|
161 |
+
"qk_rope_head_dim": null,
|
162 |
+
"rms_norm_eps": 1e-05,
|
163 |
+
"rope_scaling": {
|
164 |
+
"alpha": 1000.0,
|
165 |
+
"beta_fast": 32,
|
166 |
+
"beta_slow": 1,
|
167 |
+
"factor": 1.0,
|
168 |
+
"mscale": 1.0,
|
169 |
+
"mscale_all_dim": 1.0,
|
170 |
+
"type": "dynamic"
|
171 |
+
},
|
172 |
+
"rope_theta": 10000.0,
|
173 |
+
"routed_scaling_factor": 1.0,
|
174 |
+
"sep_token_id": 127962,
|
175 |
+
"skip_cls_token": false,
|
176 |
+
"text_end_id": 8,
|
177 |
+
"text_start_id": 7,
|
178 |
+
"tie_word_embeddings": true,
|
179 |
+
"topk_group": null,
|
180 |
+
"torch_dtype": "bfloat16",
|
181 |
+
"transformers_version": "4.41.2",
|
182 |
+
"use_cache": true,
|
183 |
+
"use_cla": false,
|
184 |
+
"use_mixed_mlp_moe": true,
|
185 |
+
"use_mla": false,
|
186 |
+
"use_qk_norm": true,
|
187 |
+
"use_rotary_pos_emb": true,
|
188 |
+
"v_head_dim": null,
|
189 |
+
"video_end_id": 11,
|
190 |
+
"video_start_id": 10,
|
191 |
+
"vit_add_patchemb_bias": false,
|
192 |
+
"vit_input_resolution": 224,
|
193 |
+
"vit_mapping_type": "resampler",
|
194 |
+
"vit_norm_type": "fused",
|
195 |
+
"vit_patch": 1,
|
196 |
+
"vit_path": null,
|
197 |
+
"vit_remove_prenorm": false,
|
198 |
+
"vit_token": 64,
|
199 |
+
"vit_type": null,
|
200 |
+
"vit_used_rms_norm": false,
|
201 |
+
"vocab_size": 128167,
|
202 |
+
"xdrope_section": null
|
203 |
+
}
|
configuration_hunyuan.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
3 |
+
""" HunYuan model configuration"""
|
4 |
+
from torch import nn
|
5 |
+
from transformers.configuration_utils import PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
from typing import List, Union, Optional
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class HunYuanConfig(PretrainedConfig):
|
14 |
+
r"""
|
15 |
+
This is the configuration class to store the configuration of a [`HunYuanModel`]. It is used to instantiate an
|
16 |
+
HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
17 |
+
with the defaults will yield a similar configuration to that of the HunYuan-7B.
|
18 |
+
|
19 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
20 |
+
documentation from [`PretrainedConfig`] for more information.
|
21 |
+
|
22 |
+
|
23 |
+
Args:
|
24 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
25 |
+
Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
|
26 |
+
`inputs_ids` passed when calling [`HunYuanModel`]
|
27 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
28 |
+
Dimension of the hidden representations.
|
29 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
30 |
+
Dimension of the MLP representations or shared MLP representations.
|
31 |
+
moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008):
|
32 |
+
Dimension of the MLP representations in MoE. Use a list if you want a different size per layer.
|
33 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
34 |
+
Number of hidden layers in the Transformer decoder.
|
35 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
36 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
37 |
+
num_key_value_heads (`int`, *optional*):
|
38 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
39 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
40 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
41 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
42 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
43 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
44 |
+
`num_attention_heads`.
|
45 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
46 |
+
The non-linear activation function (function or string) in the decoder.
|
47 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
48 |
+
The maximum sequence length that this model might ever be used with.
|
49 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
50 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
51 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
52 |
+
The epsilon used by the rms normalization layers.
|
53 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
54 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
55 |
+
relevant if `config.is_decoder=True`.
|
56 |
+
pad_token_id (`int`, *optional*):
|
57 |
+
Padding token id.
|
58 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
59 |
+
Beginning of stream token id.
|
60 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
61 |
+
End of stream token id.
|
62 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
63 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
64 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
65 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
66 |
+
issue](https://github.com/pytorch/pytorch/issues/76232).
|
67 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
68 |
+
Whether to tie weight embeddings
|
69 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
70 |
+
The base period of the RoPE embeddings.
|
71 |
+
rope_scaling (`Dict`, *optional*):
|
72 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
73 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
74 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
75 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
76 |
+
these scaling strategies behave:
|
77 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
78 |
+
experimental feature, subject to breaking API changes in future versions.
|
79 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
80 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
81 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
82 |
+
The dropout ratio for the attention probabilities.
|
83 |
+
use_qk_norm (`bool`, *optional*, defaults to `False`):
|
84 |
+
Whether query and key in attention use norm
|
85 |
+
use_cla (`bool`, *optional*, defaults to `False`):
|
86 |
+
Whether to use CLA in attention
|
87 |
+
cla_share_factor (`int`, *optional*, defaults to 1):
|
88 |
+
The share factor of CLA
|
89 |
+
num_experts (`int` or `List`, *optional*, defaults to 1):
|
90 |
+
The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
|
91 |
+
num_shared_expert (`int` or `List`, *optional*, defaults to 1):
|
92 |
+
The number of shared experts for moe. If it is a list, it will be used as the number of shared experts for each layer.
|
93 |
+
moe_topk (`int` or `List`, *optional*, defaults to 1):
|
94 |
+
The topk value for moe. If it is a list, it will be used as the topk value for each layer.
|
95 |
+
capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0):
|
96 |
+
The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer.
|
97 |
+
moe_layer_num_skipped (`int`, *optional*, defaults to 0):
|
98 |
+
First moe_layer_num_skipped layers do not use MoE.
|
99 |
+
"""
|
100 |
+
|
101 |
+
model_type = "hunyuan"
|
102 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
vocab_size=290943,
|
107 |
+
org_vocab_size=290943,
|
108 |
+
hidden_size=4096,
|
109 |
+
intermediate_size: int=11008,
|
110 |
+
moe_intermediate_size: Union[int, List]=None,
|
111 |
+
num_hidden_layers=32,
|
112 |
+
num_attention_heads=32,
|
113 |
+
num_key_value_heads=None,
|
114 |
+
attention_head_dim=None,
|
115 |
+
hidden_act="silu",
|
116 |
+
max_position_embeddings=2048,
|
117 |
+
initializer_range=0.02,
|
118 |
+
rms_norm_eps=1e-5,
|
119 |
+
use_cache=True,
|
120 |
+
pad_token_id=0,
|
121 |
+
bos_token_id=1,
|
122 |
+
eos_token_id=2,
|
123 |
+
eod_token_id=3,
|
124 |
+
sep_token_id=4,
|
125 |
+
im_start_id=5,
|
126 |
+
im_end_id=6,
|
127 |
+
text_start_id=7,
|
128 |
+
text_end_id=8,
|
129 |
+
image_token_id=9,
|
130 |
+
video_start_id=10,
|
131 |
+
video_end_id=11,
|
132 |
+
im_newline_id=12,
|
133 |
+
mask_init_id=13,
|
134 |
+
pretraining_tp=1,
|
135 |
+
tie_word_embeddings=False,
|
136 |
+
rope_theta=10000.0,
|
137 |
+
rope_scaling=None,
|
138 |
+
attention_bias=False,
|
139 |
+
mlp_bias=False,
|
140 |
+
attention_dropout=0.0,
|
141 |
+
use_qk_norm=False,
|
142 |
+
use_rotary_pos_emb=True,
|
143 |
+
use_cla=False,
|
144 |
+
cla_share_factor=1,
|
145 |
+
norm_type="hf_rms",
|
146 |
+
num_experts: Union[int, List]=1,
|
147 |
+
use_mixed_mlp_moe=False,
|
148 |
+
num_shared_expert: Union[int, List]=1,
|
149 |
+
moe_topk: Union[int, List]=1,
|
150 |
+
# capacity_factor: Union[int, List]=1.0,
|
151 |
+
moe_drop_tokens=False,
|
152 |
+
moe_random_routing_dropped_token=False,
|
153 |
+
use_mla=False,
|
154 |
+
kv_lora_rank=512,
|
155 |
+
q_lora_rank=1536,
|
156 |
+
qk_rope_head_dim=64,
|
157 |
+
v_head_dim=128,
|
158 |
+
qk_nope_head_dim=128,
|
159 |
+
moe_layer_num_skipped=0,
|
160 |
+
norm_topk_prob=True,
|
161 |
+
routed_scaling_factor=1.0,
|
162 |
+
group_limited_greedy=False,
|
163 |
+
n_group=None,
|
164 |
+
topk_group=None,
|
165 |
+
vit_path=None,
|
166 |
+
num_media_embeds=257,
|
167 |
+
vit_type="AnyResVit",
|
168 |
+
vit_input_resolution=224,
|
169 |
+
vit_token=64,
|
170 |
+
vit_patch=1,
|
171 |
+
vit_mapping_type="simple_conv_mlp",
|
172 |
+
vit_norm_type="fused",
|
173 |
+
vit_used_rms_norm=True,
|
174 |
+
vit_remove_prenorm=True,
|
175 |
+
vit_add_patchemb_bias=True,
|
176 |
+
anyres_vit_max_image_size=2048,
|
177 |
+
anyres_pooling_size=2,
|
178 |
+
anyres_vit_two_views=False,
|
179 |
+
skip_cls_token=False,
|
180 |
+
position_embedding_xdrope=False,
|
181 |
+
xdrope_section=None,
|
182 |
+
add_classification_head=False,
|
183 |
+
class_num=0,
|
184 |
+
pool_type="last",
|
185 |
+
pad_id=-1,
|
186 |
+
**kwargs,
|
187 |
+
):
|
188 |
+
self.vocab_size = vocab_size
|
189 |
+
self.org_vocab_size = org_vocab_size
|
190 |
+
self.max_position_embeddings = max_position_embeddings
|
191 |
+
self.hidden_size = hidden_size
|
192 |
+
self.intermediate_size = intermediate_size
|
193 |
+
self.moe_intermediate_size = moe_intermediate_size
|
194 |
+
self.num_hidden_layers = num_hidden_layers
|
195 |
+
self.num_attention_heads = num_attention_heads
|
196 |
+
self.num_experts = num_experts
|
197 |
+
self.use_mixed_mlp_moe = use_mixed_mlp_moe
|
198 |
+
self.num_shared_expert = num_shared_expert
|
199 |
+
self.moe_topk = moe_topk
|
200 |
+
# self.capacity_factor = capacity_factor
|
201 |
+
self.moe_drop_tokens = moe_drop_tokens
|
202 |
+
self.moe_random_routing_dropped_token = moe_random_routing_dropped_token
|
203 |
+
|
204 |
+
if attention_head_dim is not None:
|
205 |
+
self.attention_head_dim = attention_head_dim
|
206 |
+
else:
|
207 |
+
self.attention_head_dim = self.hidden_size // num_attention_heads
|
208 |
+
|
209 |
+
# for backward compatibility
|
210 |
+
if num_key_value_heads is None:
|
211 |
+
num_key_value_heads = num_attention_heads
|
212 |
+
|
213 |
+
self.num_key_value_heads = num_key_value_heads
|
214 |
+
self.hidden_act = hidden_act
|
215 |
+
self.initializer_range = initializer_range
|
216 |
+
self.rms_norm_eps = rms_norm_eps
|
217 |
+
self.pretraining_tp = pretraining_tp
|
218 |
+
self.use_cache = use_cache
|
219 |
+
self.rope_theta = rope_theta
|
220 |
+
self.rope_scaling = rope_scaling
|
221 |
+
# self._rope_scaling_validation() # TODO: Need validation?
|
222 |
+
self.attention_bias = attention_bias
|
223 |
+
self.mlp_bias = mlp_bias
|
224 |
+
self.attention_dropout = attention_dropout
|
225 |
+
self.use_qk_norm = use_qk_norm
|
226 |
+
self.use_rotary_pos_emb = use_rotary_pos_emb
|
227 |
+
self.use_cla = use_cla
|
228 |
+
self.cla_share_factor = cla_share_factor
|
229 |
+
self.norm_type = norm_type
|
230 |
+
# MLA args
|
231 |
+
self.use_mla = use_mla
|
232 |
+
self.kv_lora_rank = kv_lora_rank
|
233 |
+
self.q_lora_rank = q_lora_rank
|
234 |
+
self.qk_rope_head_dim = qk_rope_head_dim
|
235 |
+
self.qk_nope_head_dim = qk_nope_head_dim
|
236 |
+
self.v_head_dim = v_head_dim
|
237 |
+
|
238 |
+
# DeepSeek related args
|
239 |
+
self.moe_layer_num_skipped = moe_layer_num_skipped
|
240 |
+
self.norm_topk_prob = norm_topk_prob
|
241 |
+
self.routed_scaling_factor = routed_scaling_factor
|
242 |
+
self.group_limited_greedy = group_limited_greedy
|
243 |
+
self.n_group = n_group
|
244 |
+
self.topk_group = topk_group
|
245 |
+
self.add_classification_head = add_classification_head
|
246 |
+
self.class_num = class_num
|
247 |
+
self.pool_type = pool_type
|
248 |
+
self.pad_id = pad_id
|
249 |
+
|
250 |
+
if self.class_num is not None:
|
251 |
+
self.dense_list = [self.hidden_size, self.class_num]
|
252 |
+
|
253 |
+
# Vit args
|
254 |
+
self.vit_path = vit_path
|
255 |
+
self.num_media_embeds = num_media_embeds
|
256 |
+
self.vit_type = vit_type
|
257 |
+
self.vit_input_resolution = vit_input_resolution
|
258 |
+
self.vit_token = vit_token
|
259 |
+
self.vit_patch = vit_patch
|
260 |
+
self.vit_mapping_type = vit_mapping_type
|
261 |
+
self.vit_norm_type = vit_norm_type
|
262 |
+
self.vit_used_rms_norm = vit_used_rms_norm
|
263 |
+
self.vit_remove_prenorm = vit_remove_prenorm
|
264 |
+
self.vit_add_patchemb_bias = vit_add_patchemb_bias
|
265 |
+
self.anyres_vit_max_image_size = anyres_vit_max_image_size
|
266 |
+
self.anyres_pooling_size = anyres_pooling_size
|
267 |
+
self.anyres_vit_two_views = anyres_vit_two_views
|
268 |
+
self.skip_cls_token = skip_cls_token
|
269 |
+
self.position_embedding_xdrope = position_embedding_xdrope
|
270 |
+
self.xdrope_section = xdrope_section
|
271 |
+
|
272 |
+
# token id
|
273 |
+
self.eod_token_id = eod_token_id
|
274 |
+
self.im_start_id = im_start_id
|
275 |
+
self.im_end_id = im_end_id
|
276 |
+
self.text_start_id = text_start_id
|
277 |
+
self.text_end_id = text_end_id
|
278 |
+
self.image_token_id = image_token_id
|
279 |
+
self.video_start_id = video_start_id
|
280 |
+
self.video_end_id = video_end_id
|
281 |
+
self.im_newline_id = im_newline_id
|
282 |
+
self.mask_init_id = mask_init_id
|
283 |
+
|
284 |
+
super().__init__(
|
285 |
+
pad_token_id=pad_token_id,
|
286 |
+
bos_token_id=bos_token_id,
|
287 |
+
eos_token_id=eos_token_id,
|
288 |
+
sep_token_id=sep_token_id,
|
289 |
+
tie_word_embeddings=tie_word_embeddings,
|
290 |
+
**kwargs,
|
291 |
+
)
|
292 |
+
|
293 |
+
def _rope_scaling_validation(self):
|
294 |
+
"""
|
295 |
+
Validate the `rope_scaling` configuration.
|
296 |
+
"""
|
297 |
+
if self.rope_scaling is None:
|
298 |
+
return
|
299 |
+
|
300 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
301 |
+
raise ValueError(
|
302 |
+
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor` or `type` and `alpha`, "
|
303 |
+
f"got {self.rope_scaling}"
|
304 |
+
)
|
305 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
306 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
307 |
+
rope_scaling_alpha = self.rope_scaling.get("alpha", None)
|
308 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
309 |
+
raise ValueError(
|
310 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
311 |
+
)
|
312 |
+
if rope_scaling_factor is None and rope_scaling_alpha is None:
|
313 |
+
raise ValueError("`rope_scaling`'s factor or alpha field must be have one, got both of none")
|
314 |
+
if rope_scaling_factor is not None:
|
315 |
+
if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
316 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
|
317 |
+
if rope_scaling_alpha is not None:
|
318 |
+
if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
|
319 |
+
raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
|
generation_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"eos_token_id": [127960, 127967],
|
3 |
+
"pad_token_id": 127961,
|
4 |
+
"do_sample": true,
|
5 |
+
"top_k": 20,
|
6 |
+
"top_p": 0.8,
|
7 |
+
"repetition_penalty": 1.05,
|
8 |
+
"temperature": 0.7,
|
9 |
+
"transformers_version": "4.31.0"
|
10 |
+
}
|
hunyuan.py
ADDED
@@ -0,0 +1,879 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
3 |
+
#
|
4 |
+
""" PyTorch HunYuan model."""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import warnings
|
8 |
+
from typing import List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.utils.checkpoint
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
16 |
+
|
17 |
+
from transformers.activations import ACT2FN
|
18 |
+
from transformers.cache_utils import Cache, DynamicCache
|
19 |
+
from transformers.modeling_attn_mask_utils import (
|
20 |
+
AttentionMaskConverter,
|
21 |
+
_prepare_4d_attention_mask,
|
22 |
+
_prepare_4d_causal_attention_mask,
|
23 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
24 |
+
)
|
25 |
+
from transformers.modeling_outputs import (
|
26 |
+
BaseModelOutputWithPast,
|
27 |
+
CausalLMOutputWithPast,
|
28 |
+
SequenceClassifierOutputWithPast
|
29 |
+
)
|
30 |
+
from transformers.modeling_utils import PreTrainedModel
|
31 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
|
32 |
+
from transformers.utils import (
|
33 |
+
add_start_docstrings,
|
34 |
+
add_start_docstrings_to_model_forward,
|
35 |
+
is_flash_attn_2_available,
|
36 |
+
is_flash_attn_greater_or_equal_2_10,
|
37 |
+
logging,
|
38 |
+
replace_return_docstrings,
|
39 |
+
)
|
40 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
41 |
+
from transformers.generation.utils import GenerateOutput
|
42 |
+
from .configuration_hunyuan import HunYuanConfig
|
43 |
+
from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
|
44 |
+
from .vit_model import NaVitForward, VitForward, Vit
|
45 |
+
|
46 |
+
|
47 |
+
if is_flash_attn_2_available():
|
48 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
49 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
50 |
+
|
51 |
+
|
52 |
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
53 |
+
# It means that the function will not be traced through and simply appear as a node in the graph.
|
54 |
+
if is_torch_fx_available():
|
55 |
+
if not is_torch_greater_or_equal_than_1_13:
|
56 |
+
import torch.fx
|
57 |
+
|
58 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
_CONFIG_FOR_DOC = "HunYuanConfig"
|
63 |
+
|
64 |
+
|
65 |
+
HUNYUAN_START_DOCSTRING = r"""
|
66 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
67 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
68 |
+
etc.)
|
69 |
+
|
70 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
71 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
72 |
+
and behavior.
|
73 |
+
|
74 |
+
Parameters:
|
75 |
+
config ([`HunYuanConfig`]):
|
76 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
77 |
+
load the weights associated with the model, only the configuration. Check out the
|
78 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
79 |
+
"""
|
80 |
+
|
81 |
+
|
82 |
+
@add_start_docstrings(
|
83 |
+
"The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
|
84 |
+
HUNYUAN_START_DOCSTRING,
|
85 |
+
)
|
86 |
+
class HunYuanPreTrainedModel(PreTrainedModel):
|
87 |
+
config_class = HunYuanConfig
|
88 |
+
base_model_prefix = "model"
|
89 |
+
supports_gradient_checkpointing = True
|
90 |
+
_no_split_modules = ["HunYuanDecoderLayer"]
|
91 |
+
_skip_keys_device_placement = "past_key_values"
|
92 |
+
_supports_flash_attn_2 = True
|
93 |
+
_supports_sdpa = True
|
94 |
+
_supports_cache_class = True
|
95 |
+
|
96 |
+
def _init_weights(self, module):
|
97 |
+
std = self.config.initializer_range
|
98 |
+
if isinstance(module, nn.Linear):
|
99 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
100 |
+
if module.bias is not None:
|
101 |
+
module.bias.data.zero_()
|
102 |
+
elif isinstance(module, nn.Embedding):
|
103 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
104 |
+
if module.padding_idx is not None:
|
105 |
+
module.weight.data[module.padding_idx].zero_()
|
106 |
+
|
107 |
+
|
108 |
+
HUNYUAN_INPUTS_DOCSTRING = r"""
|
109 |
+
Args:
|
110 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
111 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
112 |
+
it.
|
113 |
+
|
114 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
115 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
116 |
+
|
117 |
+
[What are input IDs?](../glossary#input-ids)
|
118 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
119 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
120 |
+
|
121 |
+
- 1 for tokens that are **not masked**,
|
122 |
+
- 0 for tokens that are **masked**.
|
123 |
+
|
124 |
+
[What are attention masks?](../glossary#attention-mask)
|
125 |
+
|
126 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
127 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
128 |
+
|
129 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
130 |
+
`past_key_values`).
|
131 |
+
|
132 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
133 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
134 |
+
information on the default strategy.
|
135 |
+
|
136 |
+
- 1 indicates the head is **not masked**,
|
137 |
+
- 0 indicates the head is **masked**.
|
138 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
139 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
140 |
+
config.n_positions - 1]`.
|
141 |
+
|
142 |
+
[What are position IDs?](../glossary#position-ids)
|
143 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
144 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
145 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
146 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
147 |
+
|
148 |
+
Two formats are allowed:
|
149 |
+
- a [`~cache_utils.Cache`] instance;
|
150 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
151 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
152 |
+
cache format.
|
153 |
+
|
154 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
155 |
+
legacy cache format will be returned.
|
156 |
+
|
157 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
158 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
159 |
+
of shape `(batch_size, sequence_length)`.
|
160 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
161 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
162 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
163 |
+
model's internal embedding lookup matrix.
|
164 |
+
use_cache (`bool`, *optional*):
|
165 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
166 |
+
`past_key_values`).
|
167 |
+
output_attentions (`bool`, *optional*):
|
168 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
169 |
+
tensors for more detail.
|
170 |
+
output_hidden_states (`bool`, *optional*):
|
171 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
172 |
+
more detail.
|
173 |
+
return_dict (`bool`, *optional*):
|
174 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
175 |
+
"""
|
176 |
+
|
177 |
+
|
178 |
+
@add_start_docstrings(
|
179 |
+
"The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
|
180 |
+
HUNYUAN_START_DOCSTRING,
|
181 |
+
)
|
182 |
+
class HunYuanModel(HunYuanPreTrainedModel):
|
183 |
+
"""
|
184 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
|
185 |
+
|
186 |
+
Args:
|
187 |
+
config: HunYuanConfig
|
188 |
+
"""
|
189 |
+
|
190 |
+
def __init__(self, config: HunYuanConfig):
|
191 |
+
super().__init__(config)
|
192 |
+
self.padding_idx = config.pad_token_id
|
193 |
+
self.vocab_size = config.vocab_size
|
194 |
+
self.add_classification_head = config.add_classification_head
|
195 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
196 |
+
self.layers = nn.ModuleList(
|
197 |
+
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
198 |
+
)
|
199 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
200 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
201 |
+
if not config.add_classification_head:
|
202 |
+
self.norm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
203 |
+
|
204 |
+
self.cla = config.use_cla
|
205 |
+
self.cla_share_factor = config.cla_share_factor
|
206 |
+
|
207 |
+
self.gradient_checkpointing = False
|
208 |
+
# Initialize weights and apply final processing
|
209 |
+
self.post_init()
|
210 |
+
|
211 |
+
def get_input_embeddings(self):
|
212 |
+
return self.embed_tokens
|
213 |
+
|
214 |
+
def set_input_embeddings(self, value):
|
215 |
+
self.embed_tokens = value
|
216 |
+
|
217 |
+
@add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
|
218 |
+
def forward(
|
219 |
+
self,
|
220 |
+
input_ids: torch.LongTensor = None,
|
221 |
+
attention_mask: Optional[torch.Tensor] = None,
|
222 |
+
position_ids: Optional[torch.LongTensor] = None,
|
223 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
224 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
225 |
+
use_cache: Optional[bool] = None,
|
226 |
+
output_attentions: Optional[bool] = None,
|
227 |
+
output_hidden_states: Optional[bool] = None,
|
228 |
+
return_dict: Optional[bool] = None,
|
229 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
230 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
231 |
+
output_hidden_states = (
|
232 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
233 |
+
)
|
234 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
235 |
+
|
236 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
237 |
+
|
238 |
+
# retrieve input_ids and inputs_embeds
|
239 |
+
# if input_ids is not None and inputs_embeds is not None:
|
240 |
+
# raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
241 |
+
if input_ids is not None:
|
242 |
+
batch_size, seq_length = input_ids.shape[:2]
|
243 |
+
elif inputs_embeds is not None:
|
244 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
245 |
+
else:
|
246 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
247 |
+
|
248 |
+
if self.gradient_checkpointing and self.training:
|
249 |
+
if use_cache:
|
250 |
+
logger.warning_once(
|
251 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
252 |
+
)
|
253 |
+
use_cache = False
|
254 |
+
|
255 |
+
past_key_values_length = 0
|
256 |
+
if use_cache:
|
257 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
258 |
+
if use_legacy_cache:
|
259 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
260 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
261 |
+
|
262 |
+
if position_ids is None:
|
263 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
264 |
+
position_ids = torch.arange(
|
265 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
266 |
+
)
|
267 |
+
position_ids = position_ids.unsqueeze(0)
|
268 |
+
|
269 |
+
if inputs_embeds is None:
|
270 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
271 |
+
|
272 |
+
# Fix lora with gradient checkpointing training
|
273 |
+
if self.training and inputs_embeds.is_leaf:
|
274 |
+
inputs_embeds.requires_grad = True
|
275 |
+
|
276 |
+
if self._use_flash_attention_2:
|
277 |
+
# 2d mask is passed through the layers
|
278 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
279 |
+
elif self._use_sdpa and not output_attentions:
|
280 |
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
281 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
282 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
283 |
+
attention_mask,
|
284 |
+
(batch_size, seq_length),
|
285 |
+
inputs_embeds,
|
286 |
+
past_key_values_length,
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
# 4d mask is passed through the layers
|
290 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
291 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
292 |
+
)
|
293 |
+
|
294 |
+
# embed positions
|
295 |
+
hidden_states = inputs_embeds
|
296 |
+
|
297 |
+
# decoder layers
|
298 |
+
all_hidden_states = () if output_hidden_states else None
|
299 |
+
all_self_attns = () if output_attentions else None
|
300 |
+
next_decoder_cache = None
|
301 |
+
|
302 |
+
prev_kv_states = None
|
303 |
+
for layer_idx, decoder_layer in enumerate(self.layers):
|
304 |
+
if output_hidden_states:
|
305 |
+
all_hidden_states += (hidden_states,)
|
306 |
+
|
307 |
+
if self.gradient_checkpointing and self.training:
|
308 |
+
layer_outputs = self._gradient_checkpointing_func(
|
309 |
+
decoder_layer.__call__,
|
310 |
+
hidden_states,
|
311 |
+
attention_mask,
|
312 |
+
position_ids,
|
313 |
+
past_key_values,
|
314 |
+
output_attentions,
|
315 |
+
use_cache,
|
316 |
+
prev_kv_states,
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
layer_outputs = decoder_layer(
|
320 |
+
hidden_states,
|
321 |
+
attention_mask=attention_mask,
|
322 |
+
position_ids=position_ids,
|
323 |
+
past_key_value=past_key_values,
|
324 |
+
output_attentions=output_attentions,
|
325 |
+
use_cache=use_cache,
|
326 |
+
kv_states=prev_kv_states
|
327 |
+
)
|
328 |
+
|
329 |
+
hidden_states = layer_outputs[0]
|
330 |
+
|
331 |
+
if use_cache:
|
332 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
333 |
+
|
334 |
+
if output_attentions:
|
335 |
+
all_self_attns += (layer_outputs[1],)
|
336 |
+
|
337 |
+
kv_states = layer_outputs[-1]
|
338 |
+
|
339 |
+
if self.cla and layer_idx % self.cla_share_factor == 0:
|
340 |
+
prev_kv_states = kv_states
|
341 |
+
if not self.add_classification_head:
|
342 |
+
hidden_states = self.norm(hidden_states)
|
343 |
+
|
344 |
+
# add hidden states from the last decoder layer
|
345 |
+
if output_hidden_states:
|
346 |
+
all_hidden_states += (hidden_states,)
|
347 |
+
|
348 |
+
next_cache = None
|
349 |
+
if use_cache:
|
350 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
351 |
+
if not return_dict:
|
352 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
353 |
+
return BaseModelOutputWithPast(
|
354 |
+
last_hidden_state=hidden_states,
|
355 |
+
past_key_values=next_cache,
|
356 |
+
hidden_states=all_hidden_states,
|
357 |
+
attentions=all_self_attns,
|
358 |
+
)
|
359 |
+
|
360 |
+
|
361 |
+
class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
|
362 |
+
_tied_weights_keys = ["lm_head.weight"]
|
363 |
+
|
364 |
+
def __init__(self, config: HunYuanConfig):
|
365 |
+
super().__init__(config)
|
366 |
+
if config.vit_path is not None:
|
367 |
+
if "-tp" in config.vit_type:
|
368 |
+
config.vit_type = config.vit_type.replace("-tp", "")
|
369 |
+
self.vit_type = config.vit_type
|
370 |
+
if self.vit_type not in ['NaVit', 'EvaVit']:
|
371 |
+
if config.vit_mapping_type == 'mlp':
|
372 |
+
self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
373 |
+
self.vit = Vit(config)
|
374 |
+
else:
|
375 |
+
self.vit = None
|
376 |
+
self.config = config
|
377 |
+
self.model = HunYuanModel(config)
|
378 |
+
self.add_classification_head = config.add_classification_head
|
379 |
+
self.pad_id = config.pad_id
|
380 |
+
self.vocab_size = config.vocab_size
|
381 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
382 |
+
if config.add_classification_head:
|
383 |
+
self.pool_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
384 |
+
self.pool_head2 = nn.Linear(config.hidden_size, config.class_num, bias=False)
|
385 |
+
# Initialize weights and apply final processing
|
386 |
+
self.post_init()
|
387 |
+
|
388 |
+
def get_input_embeddings(self):
|
389 |
+
return self.model.embed_tokens
|
390 |
+
|
391 |
+
def set_input_embeddings(self, value):
|
392 |
+
self.model.embed_tokens = value
|
393 |
+
|
394 |
+
def get_output_embeddings(self):
|
395 |
+
return self.lm_head
|
396 |
+
|
397 |
+
def set_output_embeddings(self, new_embeddings):
|
398 |
+
self.lm_head = new_embeddings
|
399 |
+
|
400 |
+
def set_decoder(self, decoder):
|
401 |
+
self.model = decoder
|
402 |
+
|
403 |
+
def get_decoder(self):
|
404 |
+
return self.model
|
405 |
+
|
406 |
+
@add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
|
407 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
408 |
+
def forward(
|
409 |
+
self,
|
410 |
+
input_ids: torch.LongTensor = None,
|
411 |
+
attention_mask: Optional[torch.Tensor] = None,
|
412 |
+
position_ids: Optional[torch.LongTensor] = None,
|
413 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
414 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
415 |
+
labels: Optional[torch.LongTensor] = None,
|
416 |
+
use_cache: Optional[bool] = None,
|
417 |
+
output_attentions: Optional[bool] = None,
|
418 |
+
output_hidden_states: Optional[bool] = None,
|
419 |
+
return_dict: Optional[bool] = None,
|
420 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
421 |
+
r"""
|
422 |
+
Args:
|
423 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
424 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
425 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
426 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
|
430 |
+
Example:
|
431 |
+
|
432 |
+
```python
|
433 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
434 |
+
|
435 |
+
>>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
436 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
437 |
+
|
438 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
439 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
440 |
+
|
441 |
+
>>> # Generate
|
442 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
443 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
444 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
445 |
+
```"""
|
446 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
447 |
+
output_hidden_states = (
|
448 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
449 |
+
)
|
450 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
451 |
+
|
452 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
453 |
+
outputs = self.model(
|
454 |
+
input_ids=input_ids,
|
455 |
+
attention_mask=attention_mask,
|
456 |
+
position_ids=position_ids,
|
457 |
+
past_key_values=past_key_values,
|
458 |
+
inputs_embeds=inputs_embeds,
|
459 |
+
use_cache=use_cache,
|
460 |
+
output_attentions=output_attentions,
|
461 |
+
output_hidden_states=output_hidden_states,
|
462 |
+
return_dict=return_dict,
|
463 |
+
)
|
464 |
+
|
465 |
+
hidden_states = outputs[0]
|
466 |
+
|
467 |
+
if not self.add_classification_head:
|
468 |
+
if self.config.pretraining_tp > 1:
|
469 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
470 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
471 |
+
logits = torch.cat(logits, dim=-1)
|
472 |
+
else:
|
473 |
+
logits = self.lm_head(hidden_states)
|
474 |
+
logits = logits.float()
|
475 |
+
else:
|
476 |
+
logits = hidden_states
|
477 |
+
logits = logits.float()
|
478 |
+
pooled_output = self.pool_head(logits)
|
479 |
+
pooled_output = torch.tanh(pooled_output)
|
480 |
+
pooled_output = self.pool_head2(pooled_output).contiguous() # bs * class_num
|
481 |
+
if len(pooled_output.shape) < 2:
|
482 |
+
raise ValueError("pooled_output does not have enough dimensions for transpose")
|
483 |
+
|
484 |
+
if self.config.pool_type == "mean":
|
485 |
+
reward = pooled_output.mean(dim=1).squeeze(-1)
|
486 |
+
elif self.config.pool_type == "last":
|
487 |
+
# bs * hidden_size
|
488 |
+
seq_length = (input_ids != self.pad_id).long().sum(dim=1) - 1
|
489 |
+
batch_size = input_ids.size(0)
|
490 |
+
reward = pooled_output[torch.arange(batch_size, device=pooled_output.device), seq_length].squeeze(-1)
|
491 |
+
else:
|
492 |
+
reward = pooled_output[:, 0].squeeze(-1)
|
493 |
+
|
494 |
+
loss = None
|
495 |
+
if labels is not None:
|
496 |
+
# Shift so that tokens < n predict n
|
497 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
498 |
+
shift_labels = labels[..., 1:].contiguous()
|
499 |
+
# Flatten the tokens
|
500 |
+
loss_fct = CrossEntropyLoss()
|
501 |
+
shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
|
502 |
+
shift_labels = shift_labels.reshape(-1)
|
503 |
+
# Enable model parallelism
|
504 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
505 |
+
loss = loss_fct(shift_logits, shift_labels)
|
506 |
+
|
507 |
+
if not return_dict:
|
508 |
+
output = (logits,) + outputs[1:]
|
509 |
+
return (loss,) + output if loss is not None else output
|
510 |
+
|
511 |
+
output = CausalLMOutputWithPast(
|
512 |
+
loss=loss,
|
513 |
+
logits=logits,
|
514 |
+
past_key_values=outputs.past_key_values,
|
515 |
+
hidden_states=outputs.hidden_states,
|
516 |
+
attentions=outputs.attentions,
|
517 |
+
)
|
518 |
+
if self.add_classification_head:
|
519 |
+
output['reward'] = reward
|
520 |
+
|
521 |
+
return output
|
522 |
+
|
523 |
+
def prepare_inputs_for_generation(
|
524 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
525 |
+
):
|
526 |
+
if past_key_values is not None:
|
527 |
+
if isinstance(past_key_values, Cache):
|
528 |
+
cache_length = past_key_values.get_seq_length()
|
529 |
+
past_length = past_key_values.seen_tokens
|
530 |
+
max_cache_length = past_key_values.get_max_cache_shape()
|
531 |
+
else:
|
532 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
533 |
+
max_cache_length = None
|
534 |
+
|
535 |
+
# Keep only the unprocessed tokens:
|
536 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
537 |
+
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
|
538 |
+
# input)
|
539 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
540 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
541 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
542 |
+
# input_ids based on the past_length.
|
543 |
+
elif past_length < input_ids.shape[1]:
|
544 |
+
input_ids = input_ids[:, past_length:]
|
545 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
546 |
+
|
547 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
548 |
+
if (
|
549 |
+
max_cache_length is not None
|
550 |
+
and attention_mask is not None
|
551 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
552 |
+
):
|
553 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
554 |
+
|
555 |
+
position_ids = kwargs.get("position_ids", None)
|
556 |
+
if attention_mask is not None and position_ids is None:
|
557 |
+
# create position_ids on the fly for batch generation
|
558 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
559 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
560 |
+
if past_key_values:
|
561 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
562 |
+
|
563 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
564 |
+
if inputs_embeds is not None and past_key_values is None:
|
565 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
566 |
+
else:
|
567 |
+
model_inputs = {"input_ids": input_ids}
|
568 |
+
|
569 |
+
model_inputs.update(
|
570 |
+
{
|
571 |
+
"position_ids": position_ids,
|
572 |
+
"past_key_values": past_key_values,
|
573 |
+
"use_cache": kwargs.get("use_cache"),
|
574 |
+
"attention_mask": attention_mask,
|
575 |
+
}
|
576 |
+
)
|
577 |
+
return model_inputs
|
578 |
+
|
579 |
+
@staticmethod
|
580 |
+
def _reorder_cache(past_key_values, beam_idx):
|
581 |
+
reordered_past = ()
|
582 |
+
for layer_past in past_key_values:
|
583 |
+
reordered_past += (
|
584 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
585 |
+
)
|
586 |
+
return reordered_past
|
587 |
+
|
588 |
+
|
589 |
+
class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
|
590 |
+
_tied_weights_keys = ["lm_head.weight"]
|
591 |
+
|
592 |
+
def __init__(self, config: HunYuanConfig):
|
593 |
+
super().__init__(config)
|
594 |
+
|
595 |
+
@add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
|
596 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
597 |
+
def forward(
|
598 |
+
self,
|
599 |
+
input_ids: torch.LongTensor = None,
|
600 |
+
attention_mask: Optional[torch.Tensor] = None,
|
601 |
+
position_ids: Optional[torch.LongTensor] = None,
|
602 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
603 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
604 |
+
labels: Optional[torch.LongTensor] = None,
|
605 |
+
imgs: Optional[List[torch.FloatTensor]] = None,
|
606 |
+
imgs_pos: Optional[List[int]] = None,
|
607 |
+
use_cache: Optional[bool] = None,
|
608 |
+
output_attentions: Optional[bool] = None,
|
609 |
+
output_hidden_states: Optional[bool] = None,
|
610 |
+
return_dict: Optional[bool] = None,
|
611 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
612 |
+
r"""
|
613 |
+
Args:
|
614 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
615 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
616 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
617 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
618 |
+
|
619 |
+
Returns:
|
620 |
+
|
621 |
+
Example:
|
622 |
+
|
623 |
+
```python
|
624 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
625 |
+
|
626 |
+
>>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
627 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
628 |
+
|
629 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
630 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
631 |
+
|
632 |
+
>>> # Generate
|
633 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
634 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
635 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
636 |
+
```"""
|
637 |
+
mask_init_id = self.config.mask_init_id
|
638 |
+
pad_id = self.config.pad_token_id
|
639 |
+
eod_id = self.config.eod_token_id
|
640 |
+
image_token_id = self.config.image_token_id
|
641 |
+
im_start_id = self.config.im_start_id
|
642 |
+
im_end_id = self.config.im_end_id
|
643 |
+
video_start_id = self.config.video_start_id
|
644 |
+
video_end_id = self.config.video_end_id
|
645 |
+
|
646 |
+
if self.vit is not None and imgs is not None:
|
647 |
+
encoder_input = self.model.embed_tokens(input_ids)
|
648 |
+
if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
|
649 |
+
inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
|
650 |
+
im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
|
651 |
+
else:
|
652 |
+
inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
|
653 |
+
self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
|
654 |
+
|
655 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
656 |
+
output_hidden_states = (
|
657 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
658 |
+
)
|
659 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
660 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
661 |
+
|
662 |
+
outputs = self.model(
|
663 |
+
input_ids=input_ids,
|
664 |
+
attention_mask=attention_mask,
|
665 |
+
position_ids=position_ids,
|
666 |
+
past_key_values=past_key_values,
|
667 |
+
inputs_embeds=inputs_embeds,
|
668 |
+
use_cache=use_cache,
|
669 |
+
output_attentions=output_attentions,
|
670 |
+
output_hidden_states=output_hidden_states,
|
671 |
+
return_dict=return_dict,
|
672 |
+
)
|
673 |
+
|
674 |
+
hidden_states = outputs[0]
|
675 |
+
if self.config.pretraining_tp > 1:
|
676 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
677 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
678 |
+
logits = torch.cat(logits, dim=-1)
|
679 |
+
else:
|
680 |
+
logits = self.lm_head(hidden_states)
|
681 |
+
logits = logits.float()
|
682 |
+
|
683 |
+
loss = None
|
684 |
+
if labels is not None:
|
685 |
+
labels = labels.to(logits.device)
|
686 |
+
# Shift so that tokens < n predict n
|
687 |
+
shift_logits = logits
|
688 |
+
shift_labels = labels
|
689 |
+
# Flatten the tokens
|
690 |
+
loss_fct = CrossEntropyLoss()
|
691 |
+
shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
|
692 |
+
shift_labels = shift_labels.reshape(-1)
|
693 |
+
shift_tokens = input_ids.reshape(-1)
|
694 |
+
# compute loss
|
695 |
+
mask = (shift_labels < mask_init_id) & (shift_labels != pad_id) & (shift_labels != image_token_id) & (shift_labels != im_start_id) \
|
696 |
+
& (shift_labels != im_end_id) & (shift_labels != video_start_id) & (shift_labels != video_end_id) & (shift_tokens != pad_id) & (shift_tokens != eod_id)
|
697 |
+
shift_logits = shift_logits[mask, :]
|
698 |
+
shift_labels = shift_labels[mask]
|
699 |
+
loss = loss_fct(shift_logits, shift_labels)
|
700 |
+
|
701 |
+
if not return_dict:
|
702 |
+
output = (logits,) + outputs[1:]
|
703 |
+
return (loss,) + output if loss is not None else output
|
704 |
+
|
705 |
+
return CausalLMOutputWithPast(
|
706 |
+
loss=loss,
|
707 |
+
logits=logits,
|
708 |
+
past_key_values=outputs.past_key_values,
|
709 |
+
hidden_states=outputs.hidden_states,
|
710 |
+
attentions=outputs.attentions,
|
711 |
+
)
|
712 |
+
|
713 |
+
def prepare_inputs_for_generation(
|
714 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
715 |
+
):
|
716 |
+
imgs = kwargs.pop("imgs", None)
|
717 |
+
imgs_pos = kwargs.pop("imgs_pos", None)
|
718 |
+
inputs = super().prepare_inputs_for_generation(
|
719 |
+
input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
|
720 |
+
)
|
721 |
+
|
722 |
+
if imgs is not None:
|
723 |
+
inputs['imgs'] = imgs
|
724 |
+
if imgs_pos is not None:
|
725 |
+
inputs['imgs_pos'] = imgs_pos
|
726 |
+
return inputs
|
727 |
+
|
728 |
+
@torch.no_grad()
|
729 |
+
def generate(
|
730 |
+
self,
|
731 |
+
inputs: Optional[torch.Tensor] = None,
|
732 |
+
attention_mask: Optional[torch.Tensor] = None,
|
733 |
+
position_ids: Optional[torch.LongTensor] = None,
|
734 |
+
imgs: Optional[List[torch.FloatTensor]] = None,
|
735 |
+
imgs_pos: Optional[List[int]] = None,
|
736 |
+
**kwargs,
|
737 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
738 |
+
if "inputs_embeds" in kwargs:
|
739 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
740 |
+
|
741 |
+
if self.vit is not None:
|
742 |
+
encoder_input = self.model.embed_tokens(inputs)
|
743 |
+
if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
|
744 |
+
inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
|
745 |
+
self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
|
746 |
+
else:
|
747 |
+
inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
|
748 |
+
self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
|
749 |
+
|
750 |
+
return super().generate(
|
751 |
+
inputs=input_ids,
|
752 |
+
position_ids=position_ids,
|
753 |
+
attention_mask=attention_mask,
|
754 |
+
inputs_embeds=inputs_embeds,
|
755 |
+
eos_token_id=self.config.eod_token_id,
|
756 |
+
**kwargs
|
757 |
+
)
|
758 |
+
|
759 |
+
|
760 |
+
@add_start_docstrings(
|
761 |
+
"""
|
762 |
+
The HunYuan Model transformer with a sequence classification head on top (linear layer).
|
763 |
+
|
764 |
+
[`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
765 |
+
(e.g. GPT-2) do.
|
766 |
+
|
767 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
768 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
769 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
770 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
771 |
+
each row of the batch).
|
772 |
+
""",
|
773 |
+
HUNYUAN_START_DOCSTRING,
|
774 |
+
)
|
775 |
+
class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
|
776 |
+
def __init__(self, config):
|
777 |
+
super().__init__(config)
|
778 |
+
self.num_labels = config.num_labels
|
779 |
+
self.model = HunYuanModel(config)
|
780 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
781 |
+
|
782 |
+
# Initialize weights and apply final processing
|
783 |
+
self.post_init()
|
784 |
+
|
785 |
+
def get_input_embeddings(self):
|
786 |
+
return self.model.embed_tokens
|
787 |
+
|
788 |
+
def set_input_embeddings(self, value):
|
789 |
+
self.model.embed_tokens = value
|
790 |
+
|
791 |
+
@add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
|
792 |
+
def forward(
|
793 |
+
self,
|
794 |
+
input_ids: torch.LongTensor = None,
|
795 |
+
attention_mask: Optional[torch.Tensor] = None,
|
796 |
+
position_ids: Optional[torch.LongTensor] = None,
|
797 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
798 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
799 |
+
labels: Optional[torch.LongTensor] = None,
|
800 |
+
use_cache: Optional[bool] = None,
|
801 |
+
output_attentions: Optional[bool] = None,
|
802 |
+
output_hidden_states: Optional[bool] = None,
|
803 |
+
return_dict: Optional[bool] = None,
|
804 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
805 |
+
r"""
|
806 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
807 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
808 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
809 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
810 |
+
"""
|
811 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
812 |
+
|
813 |
+
transformer_outputs = self.model(
|
814 |
+
input_ids,
|
815 |
+
attention_mask=attention_mask,
|
816 |
+
position_ids=position_ids,
|
817 |
+
past_key_values=past_key_values,
|
818 |
+
inputs_embeds=inputs_embeds,
|
819 |
+
use_cache=use_cache,
|
820 |
+
output_attentions=output_attentions,
|
821 |
+
output_hidden_states=output_hidden_states,
|
822 |
+
return_dict=return_dict,
|
823 |
+
)
|
824 |
+
hidden_states = transformer_outputs[0]
|
825 |
+
logits = self.score(hidden_states)
|
826 |
+
|
827 |
+
if input_ids is not None:
|
828 |
+
batch_size = input_ids.shape[0]
|
829 |
+
else:
|
830 |
+
batch_size = inputs_embeds.shape[0]
|
831 |
+
|
832 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
833 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
834 |
+
if self.config.pad_token_id is None:
|
835 |
+
sequence_lengths = -1
|
836 |
+
else:
|
837 |
+
if input_ids is not None:
|
838 |
+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
|
839 |
+
logits.device
|
840 |
+
)
|
841 |
+
else:
|
842 |
+
sequence_lengths = -1
|
843 |
+
|
844 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
845 |
+
|
846 |
+
loss = None
|
847 |
+
if labels is not None:
|
848 |
+
labels = labels.to(logits.device)
|
849 |
+
if self.config.problem_type is None:
|
850 |
+
if self.num_labels == 1:
|
851 |
+
self.config.problem_type = "regression"
|
852 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
853 |
+
self.config.problem_type = "single_label_classification"
|
854 |
+
else:
|
855 |
+
self.config.problem_type = "multi_label_classification"
|
856 |
+
|
857 |
+
if self.config.problem_type == "regression":
|
858 |
+
loss_fct = MSELoss()
|
859 |
+
if self.num_labels == 1:
|
860 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
861 |
+
else:
|
862 |
+
loss = loss_fct(pooled_logits, labels)
|
863 |
+
elif self.config.problem_type == "single_label_classification":
|
864 |
+
loss_fct = CrossEntropyLoss()
|
865 |
+
loss = loss_fct(pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1))
|
866 |
+
elif self.config.problem_type == "multi_label_classification":
|
867 |
+
loss_fct = BCEWithLogitsLoss()
|
868 |
+
loss = loss_fct(pooled_logits, labels)
|
869 |
+
if not return_dict:
|
870 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
871 |
+
return ((loss,) + output) if loss is not None else output
|
872 |
+
|
873 |
+
return SequenceClassifierOutputWithPast(
|
874 |
+
loss=loss,
|
875 |
+
logits=pooled_logits,
|
876 |
+
past_key_values=transformer_outputs.past_key_values,
|
877 |
+
hidden_states=transformer_outputs.hidden_states,
|
878 |
+
attentions=transformer_outputs.attentions,
|
879 |
+
)
|
hy.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:94e14569aa3cb80bc04f48d943f40cc11d6734680315189e9502f9b9c9bee038
|
3 |
+
size 40963281008
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8f4d6ac61f5d2cffcfdc54aa22a249bf003b3a72514f08befddf8cc516a47bc
|
3 |
+
size 39938503616
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e72f67b554a3105b4bfea7a3670c31d746aafa0c6bbba0da81ca29b2fc26354e
|
3 |
+
size 39938503688
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:228fa272d42ef720223fc1596bc31b35e46f394857f04d5182598eb715699aa2
|
3 |
+
size 39963677928
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenization_hy.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import unicodedata
|
5 |
+
from typing import Collection, Dict, List, Set, Tuple, Union
|
6 |
+
|
7 |
+
import tiktoken
|
8 |
+
from transformers import PreTrainedTokenizer, AddedToken
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
VOCAB_FILES_NAMES = {"vocab_file": "hy.tiktoken"}
|
14 |
+
|
15 |
+
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
16 |
+
# PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
17 |
+
ENDOFTEXT = "<|endoftext|>"
|
18 |
+
STARTOFTEXT = "<|startoftext|>"
|
19 |
+
BOSTOKEN = "<|bos|>"
|
20 |
+
EOSTOKEN = "<|eos|>"
|
21 |
+
PADTOKEN = "<|pad|>"
|
22 |
+
|
23 |
+
# as the default behavior is changed to allow special tokens in
|
24 |
+
# regular texts, the surface forms of special tokens need to be
|
25 |
+
# as different as possible to minimize the impact
|
26 |
+
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
|
27 |
+
# changed to use actual index to avoid misconfiguration with vocabulary expansion
|
28 |
+
|
29 |
+
|
30 |
+
SPECIAL_START_ID = 127957
|
31 |
+
|
32 |
+
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
|
33 |
+
# with open(tiktoken_bpe_file, "rb", encoding="utf-8") as f:
|
34 |
+
# contents = f.read()
|
35 |
+
dic = {}
|
36 |
+
rank = 0
|
37 |
+
for line in open(tiktoken_bpe_file, "rb"):
|
38 |
+
if line:
|
39 |
+
token, _ = line.split()
|
40 |
+
if base64.b64decode(token) in dic:
|
41 |
+
continue
|
42 |
+
dic[base64.b64decode(token)] = int(rank)
|
43 |
+
rank += 1
|
44 |
+
global SPECIAL_START_ID
|
45 |
+
SPECIAL_START_ID=rank
|
46 |
+
return dic
|
47 |
+
|
48 |
+
# NOTE: Please use the code line to check `SPECIAL_START_ID` right, this will affect the SPECIAL_START_ID
|
49 |
+
# _load_tiktoken_bpe('/apdcephfs/share_1502809/shaneshu/tokenizer_exp/other_tokenizer_vocab/hy/' + VOCAB_FILES_NAMES['vocab_file'])
|
50 |
+
# print(SPECIAL_START_ID)
|
51 |
+
|
52 |
+
SPECIAL_TOKENS = tuple(
|
53 |
+
enumerate(
|
54 |
+
(
|
55 |
+
(
|
56 |
+
ENDOFTEXT,
|
57 |
+
STARTOFTEXT,
|
58 |
+
BOSTOKEN,
|
59 |
+
EOSTOKEN,
|
60 |
+
PADTOKEN,
|
61 |
+
)
|
62 |
+
+ EXTRAS
|
63 |
+
),
|
64 |
+
start=SPECIAL_START_ID,
|
65 |
+
)
|
66 |
+
)
|
67 |
+
# NOTE: Unused Token ID starts from 127962
|
68 |
+
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
|
69 |
+
|
70 |
+
class HYTokenizer(PreTrainedTokenizer):
|
71 |
+
"""hunyuan tokenizer."""
|
72 |
+
|
73 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
vocab_file,
|
78 |
+
errors="replace",
|
79 |
+
extra_vocab_file=None,
|
80 |
+
**kwargs,
|
81 |
+
):
|
82 |
+
super().__init__(**kwargs)
|
83 |
+
|
84 |
+
# how to handle errors in decoding UTF-8 byte sequences
|
85 |
+
# use ignore if you are in streaming inference
|
86 |
+
self.errors = errors
|
87 |
+
|
88 |
+
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
|
89 |
+
self.special_tokens = {
|
90 |
+
token: index
|
91 |
+
for index, token in SPECIAL_TOKENS
|
92 |
+
}
|
93 |
+
|
94 |
+
# try load extra vocab from file
|
95 |
+
if extra_vocab_file is not None:
|
96 |
+
used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
|
97 |
+
extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
|
98 |
+
for token, index in extra_mergeable_ranks.items():
|
99 |
+
if token in self.mergeable_ranks:
|
100 |
+
logger.info(f"extra token {token} exists, skipping")
|
101 |
+
continue
|
102 |
+
if index in used_ids:
|
103 |
+
logger.info(f'the index {index} for extra token {token} exists, skipping')
|
104 |
+
continue
|
105 |
+
self.mergeable_ranks[token] = index
|
106 |
+
# the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
|
107 |
+
|
108 |
+
enc = tiktoken.Encoding(
|
109 |
+
"HunYuan",
|
110 |
+
pat_str=PAT_STR,
|
111 |
+
mergeable_ranks=self.mergeable_ranks,
|
112 |
+
special_tokens=self.special_tokens,
|
113 |
+
)
|
114 |
+
assert (
|
115 |
+
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
|
116 |
+
), f"{len(self.mergeable_ranks)} + {len(self.special_tokens)} != {enc.n_vocab} in encoding"
|
117 |
+
|
118 |
+
self.decoder = {
|
119 |
+
v: k for k, v in self.mergeable_ranks.items()
|
120 |
+
} # type: dict[int, bytes|str]
|
121 |
+
self.decoder.update({v: k for k, v in self.special_tokens.items()})
|
122 |
+
|
123 |
+
self.tokenizer = enc # type: tiktoken.Encoding
|
124 |
+
|
125 |
+
self.eod_id = self.tokenizer.eot_token
|
126 |
+
self.bod_id = self.special_tokens[STARTOFTEXT]
|
127 |
+
self.bos_id = self.special_tokens[BOSTOKEN]
|
128 |
+
self.eos_id = self.special_tokens[EOSTOKEN]
|
129 |
+
self.pad_id = self.special_tokens[PADTOKEN]
|
130 |
+
|
131 |
+
def __getstate__(self):
|
132 |
+
# for pickle lovers
|
133 |
+
state = self.__dict__.copy()
|
134 |
+
del state["tokenizer"]
|
135 |
+
return state
|
136 |
+
|
137 |
+
def __setstate__(self, state):
|
138 |
+
# tokenizer is not python native; don't pass it; rebuild it
|
139 |
+
self.__dict__.update(state)
|
140 |
+
enc = tiktoken.Encoding(
|
141 |
+
"HunYuan",
|
142 |
+
pat_str=PAT_STR,
|
143 |
+
mergeable_ranks=self.mergeable_ranks,
|
144 |
+
special_tokens=self.special_tokens,
|
145 |
+
)
|
146 |
+
self.tokenizer = enc
|
147 |
+
|
148 |
+
def __len__(self) -> int:
|
149 |
+
return self.tokenizer.n_vocab
|
150 |
+
|
151 |
+
def get_vocab(self) -> Dict[bytes, int]:
|
152 |
+
return self.mergeable_ranks
|
153 |
+
|
154 |
+
def convert_tokens_to_ids(
|
155 |
+
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
|
156 |
+
) -> List[int]:
|
157 |
+
ids = []
|
158 |
+
if isinstance(tokens, (str, bytes)):
|
159 |
+
if tokens in self.special_tokens:
|
160 |
+
return self.special_tokens[tokens]
|
161 |
+
else:
|
162 |
+
return self.mergeable_ranks.get(tokens)
|
163 |
+
for token in tokens:
|
164 |
+
if token in self.special_tokens:
|
165 |
+
ids.append(self.special_tokens[token])
|
166 |
+
else:
|
167 |
+
ids.append(self.mergeable_ranks.get(token))
|
168 |
+
return ids
|
169 |
+
|
170 |
+
def _add_tokens(
|
171 |
+
self,
|
172 |
+
new_tokens: Union[List[str], List[AddedToken]],
|
173 |
+
special_tokens: bool = False,
|
174 |
+
) -> int:
|
175 |
+
if not special_tokens and new_tokens:
|
176 |
+
raise ValueError("Adding regular tokens is not supported")
|
177 |
+
for token in new_tokens:
|
178 |
+
surface_form = token.content if isinstance(token, AddedToken) else token
|
179 |
+
if surface_form not in SPECIAL_TOKENS_SET:
|
180 |
+
raise ValueError("Adding unknown special tokens is not supported")
|
181 |
+
return 0
|
182 |
+
|
183 |
+
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
|
184 |
+
"""
|
185 |
+
Save only the vocabulary of the tokenizer (vocabulary).
|
186 |
+
Returns:
|
187 |
+
`Tuple(str)`: Paths to the files saved.
|
188 |
+
"""
|
189 |
+
file_path = os.path.join(save_directory, "hunyuan.tiktoken")
|
190 |
+
with open(file_path, "w", encoding="utf-8") as w:
|
191 |
+
for k, v in self.mergeable_ranks.items():
|
192 |
+
line = base64.b64encode(k).decode("utf-8") + " " + str(v) + "\n"
|
193 |
+
w.write(line)
|
194 |
+
return (file_path,)
|
195 |
+
|
196 |
+
def tokenize(
|
197 |
+
self,
|
198 |
+
text: str,
|
199 |
+
allowed_special: Union[Set, str] = "all",
|
200 |
+
disallowed_special: Union[Collection, str] = (),
|
201 |
+
**kwargs,
|
202 |
+
) -> List[Union[bytes, str]]:
|
203 |
+
"""
|
204 |
+
Converts a string in a sequence of tokens.
|
205 |
+
Args:
|
206 |
+
text (`str`):
|
207 |
+
The sequence to be encoded.
|
208 |
+
allowed_special (`Literal["all"]` or `set`):
|
209 |
+
The surface forms of the tokens to be encoded as special tokens in regular texts.
|
210 |
+
Default to "all".
|
211 |
+
disallowed_special (`Literal["all"]` or `Collection`):
|
212 |
+
The surface forms of the tokens that should not be in regular texts and trigger errors.
|
213 |
+
Default to an empty tuple.
|
214 |
+
kwargs (additional keyword arguments, *optional*):
|
215 |
+
Will be passed to the underlying model specific encode method.
|
216 |
+
Returns:
|
217 |
+
`List[bytes|str]`: The list of tokens.
|
218 |
+
"""
|
219 |
+
tokens = []
|
220 |
+
text = unicodedata.normalize("NFC", text)
|
221 |
+
|
222 |
+
# this implementation takes a detour: text -> token id -> token surface forms
|
223 |
+
for t in self.tokenizer.encode(
|
224 |
+
text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
225 |
+
):
|
226 |
+
tokens.append(self.decoder[t])
|
227 |
+
return tokens
|
228 |
+
|
229 |
+
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
|
230 |
+
"""
|
231 |
+
Converts a sequence of tokens in a single string.
|
232 |
+
"""
|
233 |
+
text = ""
|
234 |
+
temp = b""
|
235 |
+
for t in tokens:
|
236 |
+
if isinstance(t, str):
|
237 |
+
if temp:
|
238 |
+
text += temp.decode("utf-8", errors=self.errors)
|
239 |
+
temp = b""
|
240 |
+
text += t
|
241 |
+
elif isinstance(t, bytes):
|
242 |
+
temp += t
|
243 |
+
else:
|
244 |
+
raise TypeError("token should only be of type types or str")
|
245 |
+
if temp:
|
246 |
+
text += temp.decode("utf-8", errors=self.errors)
|
247 |
+
return text
|
248 |
+
|
249 |
+
@property
|
250 |
+
def vocab_size(self):
|
251 |
+
return self.tokenizer.n_vocab
|
252 |
+
|
253 |
+
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
|
254 |
+
"""Converts an id to a token, special tokens included"""
|
255 |
+
if index in self.decoder:
|
256 |
+
return self.decoder[index]
|
257 |
+
raise ValueError("unknown ids")
|
258 |
+
|
259 |
+
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
|
260 |
+
"""Converts a token to an id using the vocab, special tokens included"""
|
261 |
+
if token in self.special_tokens:
|
262 |
+
return self.special_tokens[token]
|
263 |
+
if token in self.mergeable_ranks:
|
264 |
+
return self.mergeable_ranks[token]
|
265 |
+
raise ValueError("unknown token")
|
266 |
+
|
267 |
+
def _tokenize(self, text: str, **kwargs):
|
268 |
+
"""
|
269 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
270 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
271 |
+
Do NOT take care of added tokens.
|
272 |
+
"""
|
273 |
+
raise NotImplementedError
|
274 |
+
|
275 |
+
def _decode(
|
276 |
+
self,
|
277 |
+
token_ids: Union[int, List[int]],
|
278 |
+
skip_special_tokens: bool = False,
|
279 |
+
errors: str = None,
|
280 |
+
**kwargs,
|
281 |
+
) -> str:
|
282 |
+
if isinstance(token_ids, int):
|
283 |
+
token_ids = [token_ids]
|
284 |
+
if skip_special_tokens:
|
285 |
+
token_ids = [i for i in token_ids if i < self.eod_id]
|
286 |
+
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
|
287 |
+
|
288 |
+
# tests
|
289 |
+
if __name__ == "__main__":
|
290 |
+
tokenizer = HYTokenizer.from_pretrained('./hy')
|
291 |
+
text = '你好,世界'
|
292 |
+
tokens = tokenizer.tokenize(text)
|
293 |
+
print(tokens)
|
294 |
+
ids = tokenizer.convert_tokens_to_ids(tokens)
|
295 |
+
print(ids)
|
296 |
+
text2 = tokenizer.convert_tokens_to_string(tokens)
|
297 |
+
print(text2)
|
298 |
+
ids2 = tokenizer.convert_tokens_to_ids(tokens)
|
tokenizer_config.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GPT2LMHeadModel"
|
4 |
+
],
|
5 |
+
"model_max_length": 1048576,
|
6 |
+
"tokenizer_class": "HYTokenizer",
|
7 |
+
"auto_map": {
|
8 |
+
"AutoTokenizer": [
|
9 |
+
"tokenization_hy.HYTokenizer",
|
10 |
+
null
|
11 |
+
]
|
12 |
+
},
|
13 |
+
"eos_token": "<|eos|>",
|
14 |
+
"model_type": "gpt2",
|
15 |
+
"additional_special_tokens": ["<|startoftext|>", "<|extra_0|>", "<|extra_4|>", "<|extra_5|>", "<|eos|>"],
|
16 |
+
"pad_token": "<|pad|>",
|
17 |
+
"chat_template": "{% set loop_messages = messages %}\n{% if tools %}\n {% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %}\n {% set weekday_cn = weekday_map[strftime_now('%A')] %}\n {% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %}\n {% set datetime_str = datetime_str + ' ' + weekday_cn %}\n {% for message in loop_messages %}\n {% if 'content' in message %}\n {% set content = message['content'] %}\n {% else %}\n {% set content = '' %}\n {% endif %}\n {% if loop.index0 == 0 %}\n {% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。\n如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。\n如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。\n如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。\n你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。\n' %}\n {% set content_tmp = content_tmp + '\n' + tools | tojson + '\n' %}\n {% if message['role'] == 'system' %}\n {% set content_tmp = content_tmp + '\n额外要求:\n' + content + '\n\n如果你决定返回函数调用,请将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>,不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。\n否则,请参考开头说的三种情况,以助手:开头进行回复。\n\n如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %}\n {% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %}\n {% elif message['role'] == 'user' %}\n {% set content_tmp = content_tmp + '\n如果你决定返回函数调用,请将其格式化为<tool_calls>[{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},...]</tool_calls>,不得包含其他文本。\n否则,请参考开头说的三种情况,以助手:开头进行回复。\n\n当前时间:' + datetime_str %}\n {% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%}\n {% set content = content_tmp + '用户:' + content + '<|extra_0|>' %}\n {% endif %}\n {% else %}\n {% if message['role'] == 'user' %}\n {% set content = '用户:' + content + '<|extra_0|>' %}\n {% elif message['role'] == 'assistant' %}\n {% if 'tool_calls' in message %}\n {% set tool_calls = message['tool_calls'] %}\n {% set ns = namespace(tool_calls=\"[\") %}\n {% for tool_call in tool_calls %}\n {% set function = tool_call['function'] %}\n {% set name = function['name'] %}\n {% set ns.tool_calls = ns.tool_calls + '{\"name\": \"' + name + '\", '%}\n {% set arguments = function['arguments'] %}\n {% if arguments is not string %}\n {% set arguments = arguments | tojson %}\n {% endif %}\n {% set ns.tool_calls = ns.tool_calls + '\"arguments\": ' + arguments + '}' %}\n {% if not loop.last %}\n {% set ns.tool_calls = ns.tool_calls + ', '%}\n {% endif %}\n {% endfor %}\n {% set ns.tool_calls = ns.tool_calls + ']' %}\n {% set content = content + '<tool_calls>' + ns.tool_calls + '</tool_calls>' %}\n {% else %}\n {% set content = '助手:' + content %}\n {% endif %}\n {% set content = content + '<|eos|>' %}\n {% elif message['role'] == 'tool' %}\n {% if content is not string %}\n {set content = content | tojson }\n {% endif %}\n {% set content = '<tool_response>' + content + '</tool_response>' %}\n {% set content = content + '<|extra_0|>' %}\n {% endif %}\n {% endif %}\n {{- content -}}\n {% endfor %}\n{% else %}\n {% set context = {'has_head': true} %}\n {% for message in loop_messages %}\n {% if 'content' in message %}\n {% set content = message['content'] %}\n {% else %}\n {% set content = '' %}\n {% endif %}\n {% if loop.index0 == 0 %}\n {% if content == '' %}\n {% set _ = context.update({'has_head': false}) %}\n {% elif message['role'] == 'system' %}\n {% set content = '<|startoftext|>' + content + '<|extra_4|>' %}\n {% endif %}\n {% endif %}\n {% if message['role'] == 'user' %}\n {% if loop.index0 == 1 and not context.has_head %}\n {% set content = '<|startoftext|>' + content %}\n {% endif %}\n {% if loop.index0 == 1 and context.has_head %}\n {% set content = content + '<|extra_0|>' %}\n {% else %}\n {% set content = '<|startoftext|>' + content + '<|extra_0|>' %}\n {% endif %}\n {% elif message['role'] == 'assistant' %}\n {% set content = content + '<|eos|>' %}\n {% elif message['role'] == 'tool' %}\n {% set content = content + '<|extra_0|>' %}\n {% endif %}\n {{- content -}}\n {% endfor %}\n{% endif %}\n{%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n' }}\n{%- endif %}"
|
18 |
+
}
|
vit_model.py
ADDED
@@ -0,0 +1,1083 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import types
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from torch import Tensor, nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from typing import List, Tuple, Optional, Union
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from transformers.modeling_attn_mask_utils import (
|
10 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
11 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
12 |
+
_prepare_4d_causal_attention_mask,
|
13 |
+
)
|
14 |
+
from transformers.models.clip.configuration_clip import CLIPVisionConfig
|
15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
16 |
+
from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
|
17 |
+
from .configuration_hunyuan import HunYuanConfig
|
18 |
+
|
19 |
+
|
20 |
+
def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
|
21 |
+
# input_ids: (B, L)
|
22 |
+
# encoder_input: (L, B, E)
|
23 |
+
# image_tensors [[Tensor],...,[Tensor]]
|
24 |
+
# image_pos [[Tensor],...,[Tensor]]
|
25 |
+
# tokenizer = get_tokenizer()
|
26 |
+
b = len(input_ids)
|
27 |
+
img_embs = None
|
28 |
+
all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
|
29 |
+
if all_nums != 0:
|
30 |
+
img_embs, img_batch_pos = vit(image_tensors)
|
31 |
+
else:
|
32 |
+
# when no input image, initialize a fake tensor
|
33 |
+
pad_nums = 1
|
34 |
+
image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
|
35 |
+
img_embs, img_batch_pos = vit(image_tensors)
|
36 |
+
|
37 |
+
encoder_input = encoder_input.clone()
|
38 |
+
if all_nums > 0:
|
39 |
+
assert len(images_pos) == len(img_batch_pos), \
|
40 |
+
(len(images_pos), len(img_batch_pos))
|
41 |
+
start_token_id = im_start_id
|
42 |
+
end_token_id = im_end_id
|
43 |
+
placeholder_id = image_token_id
|
44 |
+
for idx in range(len(images_pos)):
|
45 |
+
assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
|
46 |
+
(len(images_pos[idx]), len(img_batch_pos[idx]))
|
47 |
+
for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
|
48 |
+
# the positions to be filled [s_start, s_end)
|
49 |
+
s_idx, s_start, s_end = p_img_pos_in_batch
|
50 |
+
current_embs = img_embs[s_idx, s_start:s_end]
|
51 |
+
im_s, im_e = p_batch_img_pos
|
52 |
+
assert len(current_embs) == im_e - im_s, \
|
53 |
+
(img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
|
54 |
+
if not anyres_vit_two_views:
|
55 |
+
assert input_ids[idx, im_s - 1] == start_token_id, \
|
56 |
+
input_ids[idx, im_s - 1]
|
57 |
+
assert input_ids[idx, im_e] == end_token_id, \
|
58 |
+
input_ids[idx, im_e]
|
59 |
+
assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
|
60 |
+
f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
|
61 |
+
encoder_input[idx, im_s:im_e] = current_embs
|
62 |
+
else:
|
63 |
+
# when no input image, to mask vit value
|
64 |
+
vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
|
65 |
+
current_embs = img_embs[0, :]
|
66 |
+
encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
|
67 |
+
return encoder_input, input_ids
|
68 |
+
|
69 |
+
|
70 |
+
def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
|
71 |
+
vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
|
72 |
+
|
73 |
+
b = len(input_ids)
|
74 |
+
if images_pos is None:
|
75 |
+
images_pos = torch.ones([len(input_ids), 1, 3])
|
76 |
+
images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
|
77 |
+
images_pos = images_pos.long()
|
78 |
+
|
79 |
+
real_image_nums = []
|
80 |
+
image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
|
81 |
+
real_images = []
|
82 |
+
|
83 |
+
all_nums = 0
|
84 |
+
img_index = []
|
85 |
+
for s in range(len(images_pos)):
|
86 |
+
real_image_num = 0
|
87 |
+
for (im_s, im_e,index) in images_pos[s]:
|
88 |
+
if im_s == -1:
|
89 |
+
break
|
90 |
+
real_image_num += 1
|
91 |
+
all_nums += 1
|
92 |
+
img_index.append(index)
|
93 |
+
|
94 |
+
real_image_nums.append(real_image_num)
|
95 |
+
real_images.append(image_tensors[s][:real_image_num])
|
96 |
+
|
97 |
+
if vit_patch == 1:
|
98 |
+
img_index = None
|
99 |
+
|
100 |
+
if all_nums == 0:
|
101 |
+
# when no input image, initialize a fake tensor
|
102 |
+
img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
|
103 |
+
img_embs = vit(img_input)
|
104 |
+
img_embs = vit_linear_encoder(img_embs)
|
105 |
+
else:
|
106 |
+
img_input = torch.cat(real_images)
|
107 |
+
img_embs = vit(img_input, img_index = img_index)
|
108 |
+
img_embs = vit_linear_encoder(img_embs)
|
109 |
+
|
110 |
+
encoder_input = encoder_input.clone()
|
111 |
+
start = 0
|
112 |
+
if all_nums > 0:
|
113 |
+
for s, real_image_len in enumerate(real_image_nums):
|
114 |
+
current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
|
115 |
+
for ss in range(current_embs.shape[0]):
|
116 |
+
im_s, im_e, index = images_pos[s, ss]
|
117 |
+
# 子图特征更少
|
118 |
+
if index > 0 and vit_patch_mlp:
|
119 |
+
encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
|
120 |
+
else:
|
121 |
+
encoder_input[s, im_s:im_e] = current_embs[ss, :]
|
122 |
+
start = start + real_image_len
|
123 |
+
else:
|
124 |
+
# when no input image, to mask vit value
|
125 |
+
for s in range(b):
|
126 |
+
vit_mask = torch.zeros([vit_token, 1]).cuda()
|
127 |
+
current_embs = img_embs[:, start:start + 1]
|
128 |
+
encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
|
129 |
+
start = start + 1
|
130 |
+
return encoder_input, input_ids
|
131 |
+
|
132 |
+
|
133 |
+
def group_images_by_max_seq_len(
|
134 |
+
images: List[List[Tensor]], patch_size: int,
|
135 |
+
max_seq_len: int, adaptor_patch_size: int,
|
136 |
+
add_cls_token: bool = False) -> List[List[Tensor]]:
|
137 |
+
|
138 |
+
groups = []
|
139 |
+
group = []
|
140 |
+
pos_groups = []
|
141 |
+
seq_len = 0
|
142 |
+
num_images = 0
|
143 |
+
for image_list in images:
|
144 |
+
pos_group = []
|
145 |
+
for image in image_list:
|
146 |
+
num_images += 1
|
147 |
+
assert isinstance(image, Tensor)
|
148 |
+
|
149 |
+
image_dims = image.shape[-2:]
|
150 |
+
ph, pw = map(lambda t: t // patch_size, image_dims)
|
151 |
+
|
152 |
+
image_seq_len = (ph * pw)
|
153 |
+
new_image_seq_len = image_seq_len
|
154 |
+
grouped_len = seq_len + image_seq_len
|
155 |
+
if add_cls_token:
|
156 |
+
new_image_seq_len += 1
|
157 |
+
grouped_len += num_images
|
158 |
+
|
159 |
+
assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
|
160 |
+
|
161 |
+
if grouped_len > max_seq_len:
|
162 |
+
groups.append(group)
|
163 |
+
group = []
|
164 |
+
seq_len = 0
|
165 |
+
num_images = 1
|
166 |
+
|
167 |
+
group.append(image)
|
168 |
+
start = seq_len // (adaptor_patch_size * adaptor_patch_size)
|
169 |
+
end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
|
170 |
+
batch_idx = len(groups)
|
171 |
+
pos_group.append([batch_idx, start, end])
|
172 |
+
seq_len += image_seq_len
|
173 |
+
pos_groups.append(pos_group)
|
174 |
+
|
175 |
+
if len(group) > 0:
|
176 |
+
groups.append(group)
|
177 |
+
|
178 |
+
return groups, pos_groups
|
179 |
+
|
180 |
+
|
181 |
+
class AnyResCLIPVisionEmbeddings(nn.Module):
|
182 |
+
def __init__(self, config: CLIPVisionConfig):
|
183 |
+
super().__init__()
|
184 |
+
|
185 |
+
self.config = config
|
186 |
+
# self.sparse_attn_mask = args.sparse_attn_mask
|
187 |
+
# self.use_flash_attn = args.use_flash_attn
|
188 |
+
self.embed_dim = config.hidden_size
|
189 |
+
self.image_size = config.max_image_size
|
190 |
+
self.patch_size = config.patch_size
|
191 |
+
self.max_seq_len = config.max_vit_seq_len
|
192 |
+
self.adaptor_patch_size = config.adaptor_patch_size
|
193 |
+
self.anyres_vit_two_views = config.anyres_vit_two_views
|
194 |
+
self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
|
195 |
+
self.vit_remove_prenorm = config.vit_remove_prenorm
|
196 |
+
|
197 |
+
self.patch_embedding = nn.Conv2d(
|
198 |
+
in_channels=config.num_channels,
|
199 |
+
out_channels=self.embed_dim,
|
200 |
+
kernel_size=self.patch_size,
|
201 |
+
stride=self.patch_size,
|
202 |
+
bias=self.vit_add_patchemb_bias,
|
203 |
+
)
|
204 |
+
|
205 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
206 |
+
self.skip_cls_token = True
|
207 |
+
|
208 |
+
# add interpolate_pos_encoding
|
209 |
+
if self.anyres_vit_two_views:
|
210 |
+
self.num_positions = self.num_patches
|
211 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
|
212 |
+
else:
|
213 |
+
self.num_positions = self.num_patches + 1
|
214 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
215 |
+
# self.position_ids = torch.arange(self.num_positions).expand((1, -1))
|
216 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
217 |
+
|
218 |
+
if not self.vit_remove_prenorm:
|
219 |
+
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
220 |
+
|
221 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
222 |
+
"""
|
223 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
224 |
+
resolution images.
|
225 |
+
|
226 |
+
Source:
|
227 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
228 |
+
"""
|
229 |
+
num_patches = embeddings.shape[1]
|
230 |
+
position_embeddings = self.position_embedding(self.position_ids)
|
231 |
+
patch_pos_embed = position_embeddings[:, 1:]
|
232 |
+
num_positions = position_embeddings.shape[1] - 1
|
233 |
+
if num_patches == num_positions and height == width:
|
234 |
+
return patch_pos_embed
|
235 |
+
# class_pos_embed = position_embeddings[:, 0]
|
236 |
+
dim = embeddings.shape[-1]
|
237 |
+
h0 = height // self.patch_size
|
238 |
+
w0 = width // self.patch_size
|
239 |
+
# we add a small number to avoid floating point error in the interpolation
|
240 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
241 |
+
h0, w0 = h0 + 0.1, w0 + 0.1
|
242 |
+
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
243 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
244 |
+
raw_type = patch_pos_embed.dtype
|
245 |
+
patch_pos_embed = nn.functional.interpolate(
|
246 |
+
patch_pos_embed.to(torch.float32, non_blocking=True),
|
247 |
+
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
248 |
+
mode="bilinear",
|
249 |
+
align_corners=False,
|
250 |
+
)
|
251 |
+
patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
|
252 |
+
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
253 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
254 |
+
return patch_pos_embed
|
255 |
+
|
256 |
+
def rescale_positional_embedding(self, out_size):
|
257 |
+
h, w = out_size
|
258 |
+
pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
|
259 |
+
if (h, w) == (pos_embed_shape, pos_embed_shape):
|
260 |
+
return self.position_embedding
|
261 |
+
rescaled_positional_embedding = \
|
262 |
+
self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
|
263 |
+
pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
|
264 |
+
pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
|
265 |
+
rescaled_positional_embedding[0] = pe_2d.T.contiguous()
|
266 |
+
return rescaled_positional_embedding
|
267 |
+
|
268 |
+
def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
269 |
+
if pixel_values.ndim == 3:
|
270 |
+
pixel_values = pixel_values[None]
|
271 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
272 |
+
|
273 |
+
if self.anyres_vit_two_views:
|
274 |
+
# padding
|
275 |
+
pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
|
276 |
+
pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
|
277 |
+
pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
|
278 |
+
|
279 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
280 |
+
b, c, h, w = patch_embeds.shape
|
281 |
+
|
282 |
+
# (b, hw, c)
|
283 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
284 |
+
if self.anyres_vit_two_views:
|
285 |
+
embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
|
286 |
+
else:
|
287 |
+
embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
|
288 |
+
if not self.vit_remove_prenorm:
|
289 |
+
embeddings = self.pre_layernorm(embeddings)
|
290 |
+
return embeddings, (h, w)
|
291 |
+
|
292 |
+
def forward(self, images: List[List[Tensor]]):
|
293 |
+
'''
|
294 |
+
Input:
|
295 |
+
images: List[List[Tensor]]
|
296 |
+
|
297 |
+
Return:
|
298 |
+
embeddings: Tensor (B, L, E)
|
299 |
+
attn_mask: Tensor (B, L, 2)
|
300 |
+
pos_groups: List[List[(batch_idx, start, end)]]
|
301 |
+
'''
|
302 |
+
batched_images, pos_groups = group_images_by_max_seq_len(
|
303 |
+
images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
|
304 |
+
max_seq_len = self.max_seq_len
|
305 |
+
|
306 |
+
# batched_images is a list of a list
|
307 |
+
B = len(batched_images)
|
308 |
+
L = max_seq_len
|
309 |
+
E = self.embed_dim
|
310 |
+
|
311 |
+
embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
|
312 |
+
attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True presents compute
|
313 |
+
assert len(images) == len(pos_groups), (len(images), len(pos_groups))
|
314 |
+
|
315 |
+
batch_images = []
|
316 |
+
batch_pos = []
|
317 |
+
for images_i, pos_group in zip(images, pos_groups):
|
318 |
+
assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
|
319 |
+
for image, pos in zip(images_i, pos_group):
|
320 |
+
batch_idx, start, end = pos
|
321 |
+
a2 = self.adaptor_patch_size ** 2
|
322 |
+
# recover the real number of the input image tokens
|
323 |
+
start *= a2
|
324 |
+
end *= a2
|
325 |
+
emb, _ = self.forward_single(image)
|
326 |
+
assert emb.ndim == 3, '(B, L, E)'
|
327 |
+
embeddings[batch_idx, start:end] = emb
|
328 |
+
attn_mask[batch_idx, :, start:end, start:end] = True
|
329 |
+
return embeddings, attn_mask, pos_groups
|
330 |
+
|
331 |
+
|
332 |
+
class CLIPVisionEmbeddings(nn.Module):
|
333 |
+
def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
|
334 |
+
super().__init__()
|
335 |
+
self.config = config
|
336 |
+
self.embed_dim = config.hidden_size
|
337 |
+
self.image_size = config.image_size
|
338 |
+
self.image_size = config.vit_input_resolution
|
339 |
+
self.patch_size = config.patch_size
|
340 |
+
|
341 |
+
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
342 |
+
|
343 |
+
self.patch_embedding = nn.Conv2d(
|
344 |
+
in_channels=config.num_channels,
|
345 |
+
out_channels=self.embed_dim,
|
346 |
+
kernel_size=self.patch_size,
|
347 |
+
stride=self.patch_size,
|
348 |
+
bias=False,
|
349 |
+
)
|
350 |
+
|
351 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
352 |
+
|
353 |
+
self.skip_cls_token = skip_cls_token
|
354 |
+
|
355 |
+
self.num_positions = self.num_patches + 1
|
356 |
+
|
357 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
358 |
+
if vit_patch > 1:
|
359 |
+
self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
|
360 |
+
# 0 支持最大16张图,目前写死了,如需其他的需要额外定义参数
|
361 |
+
elif vit_patch == 0:
|
362 |
+
self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
|
363 |
+
else:
|
364 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
365 |
+
|
366 |
+
if add_pre_layernorm:
|
367 |
+
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
368 |
+
else:
|
369 |
+
self.pre_layernorm = None
|
370 |
+
|
371 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
372 |
+
"""
|
373 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
374 |
+
resolution images.
|
375 |
+
|
376 |
+
Source:
|
377 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
378 |
+
"""
|
379 |
+
num_patches = embeddings.shape[1] - 1
|
380 |
+
position_embeddings = self.position_embedding(self.position_ids)
|
381 |
+
num_positions = position_embeddings.shape[1] - 1
|
382 |
+
if num_patches == num_positions and height == width:
|
383 |
+
return position_embeddings
|
384 |
+
class_pos_embed = position_embeddings[:, 0]
|
385 |
+
patch_pos_embed = position_embeddings[:, 1:]
|
386 |
+
dim = embeddings.shape[-1]
|
387 |
+
h0 = height // self.config.patch_size
|
388 |
+
w0 = width // self.config.patch_size
|
389 |
+
# we add a small number to avoid floating point error in the interpolation
|
390 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
391 |
+
h0, w0 = h0 + 0.1, w0 + 0.1
|
392 |
+
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
393 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
394 |
+
raw_type = patch_pos_embed.dtype
|
395 |
+
patch_pos_embed = nn.functional.interpolate(
|
396 |
+
patch_pos_embed.float(),
|
397 |
+
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
398 |
+
mode="bicubic",
|
399 |
+
align_corners=False,
|
400 |
+
)
|
401 |
+
# print(patch_pos_embed.shape)
|
402 |
+
patch_pos_embed = patch_pos_embed.to(raw_type)
|
403 |
+
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
404 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
405 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
406 |
+
|
407 |
+
|
408 |
+
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
|
409 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
410 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
411 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
412 |
+
if self.skip_cls_token:
|
413 |
+
embeddings = patch_embeds
|
414 |
+
if img_index is None:
|
415 |
+
position_ids = self.position_ids[:,1:]
|
416 |
+
embeddings = embeddings + self.position_embedding(position_ids)
|
417 |
+
else:
|
418 |
+
position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
|
419 |
+
+ self.position_ids.expand(batch_size, -1)[:, 1:]
|
420 |
+
embeddings = embeddings + self.position_embedding(position_ids)
|
421 |
+
else:
|
422 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
423 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
424 |
+
if interpolate_pos_encoding:
|
425 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
426 |
+
else:
|
427 |
+
if img_index is None:
|
428 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
429 |
+
else:
|
430 |
+
position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
|
431 |
+
new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
|
432 |
+
position_ids = torch.cat([position_ids,new_position],dim=1)
|
433 |
+
embeddings = embeddings + self.position_embedding(position_ids)
|
434 |
+
if self.pre_layernorm is not None:
|
435 |
+
embeddings = self.pre_layernorm(embeddings)
|
436 |
+
return embeddings
|
437 |
+
|
438 |
+
|
439 |
+
class NaVitTransformer(nn.Module):
|
440 |
+
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
|
441 |
+
super().__init__()
|
442 |
+
self.config = config
|
443 |
+
self.vit_config = vit_config
|
444 |
+
with self.prepare_args(config, vit_config):
|
445 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
446 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
447 |
+
self.layers = nn.ModuleList(
|
448 |
+
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
449 |
+
)
|
450 |
+
|
451 |
+
@contextmanager
|
452 |
+
def prepare_args(self, config, vit_config):
|
453 |
+
hidden_act = config.hidden_act
|
454 |
+
hidden_size = config.hidden_size
|
455 |
+
ffn_hidden_size = config.intermediate_size
|
456 |
+
num_attention_heads = config.num_attention_heads
|
457 |
+
num_key_value_heads = config.num_key_value_heads
|
458 |
+
attention_head_dim = config.attention_head_dim
|
459 |
+
use_qk_norm = config.use_qk_norm
|
460 |
+
use_rotary_pos_emb = config.use_rotary_pos_emb
|
461 |
+
num_hidden_layers = config.num_hidden_layers
|
462 |
+
rms_norm_eps = config.rms_norm_eps
|
463 |
+
attention_dropout = config.attention_dropout
|
464 |
+
# hidden_dropout = config.hidden_dropout
|
465 |
+
norm_type = config.norm_type
|
466 |
+
attention_bias = config.attention_bias
|
467 |
+
mlp_bias = config.mlp_bias
|
468 |
+
use_mla = config.use_mla
|
469 |
+
num_experts = config.num_experts
|
470 |
+
_attn_implementation = config._attn_implementation
|
471 |
+
|
472 |
+
config.hidden_act = vit_config.hidden_act
|
473 |
+
config.hidden_size = vit_config.hidden_size
|
474 |
+
config.intermediate_size = vit_config.intermediate_size
|
475 |
+
config.num_attention_heads = vit_config.num_attention_heads
|
476 |
+
config.num_key_value_heads = None
|
477 |
+
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
|
478 |
+
config.use_qk_norm = False
|
479 |
+
config.use_rotary_pos_emb = False
|
480 |
+
config.num_hidden_layers = vit_config.num_hidden_layers
|
481 |
+
config.rms_norm_eps = vit_config.layer_norm_eps
|
482 |
+
config.attention_dropout = vit_config.attention_dropout
|
483 |
+
# config.hidden_dropout = vit_config.hidden_dropout
|
484 |
+
config.norm_type = config.vit_norm_type
|
485 |
+
config.attention_bias = True
|
486 |
+
config.mlp_bias = True
|
487 |
+
config.use_mla = False
|
488 |
+
config.num_experts = 1
|
489 |
+
config._attn_implementation = "eager"
|
490 |
+
|
491 |
+
yield
|
492 |
+
config.hidden_act = hidden_act
|
493 |
+
config.hidden_size = hidden_size
|
494 |
+
config.intermediate_size = ffn_hidden_size
|
495 |
+
config.num_attention_heads = num_attention_heads
|
496 |
+
config.num_key_value_heads = num_key_value_heads
|
497 |
+
config.attention_head_dim = attention_head_dim
|
498 |
+
config.use_qk_norm = use_qk_norm
|
499 |
+
config.use_rotary_pos_emb = use_rotary_pos_emb
|
500 |
+
config.num_hidden_layers = num_hidden_layers
|
501 |
+
config.rms_norm_eps = rms_norm_eps
|
502 |
+
config.attention_dropout = attention_dropout
|
503 |
+
# config.hidden_dropout = hidden_dropout
|
504 |
+
config.attention_bias = attention_bias
|
505 |
+
config.mlp_bias = mlp_bias
|
506 |
+
config.norm_type = norm_type
|
507 |
+
config.use_mla = use_mla
|
508 |
+
config.num_experts = num_experts
|
509 |
+
config._attn_implementation = _attn_implementation
|
510 |
+
|
511 |
+
def forward(
|
512 |
+
self,
|
513 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
514 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
515 |
+
|
516 |
+
hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
|
517 |
+
attention_mask = attention_mask.int()
|
518 |
+
batch_size, seq_length, _ = hidden_states.shape
|
519 |
+
past_key_values_length = 0
|
520 |
+
|
521 |
+
if self._use_flash_attention_2:
|
522 |
+
# 2d mask is passed through the layers
|
523 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
524 |
+
elif self._use_sdpa:
|
525 |
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
526 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
527 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
528 |
+
attention_mask,
|
529 |
+
(batch_size, seq_length),
|
530 |
+
hidden_states,
|
531 |
+
past_key_values_length,
|
532 |
+
)
|
533 |
+
else:
|
534 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
535 |
+
attention_mask,
|
536 |
+
(batch_size, seq_length),
|
537 |
+
hidden_states,
|
538 |
+
past_key_values_length,
|
539 |
+
)
|
540 |
+
|
541 |
+
for layer_idx, decoder_layer in enumerate(self.layers):
|
542 |
+
layer_outputs = decoder_layer(
|
543 |
+
hidden_states,
|
544 |
+
attention_mask=attention_mask
|
545 |
+
)
|
546 |
+
hidden_states = layer_outputs[0]
|
547 |
+
|
548 |
+
return hidden_states, img_pos
|
549 |
+
|
550 |
+
|
551 |
+
class AnyResVitTransformer(NaVitTransformer):
|
552 |
+
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
|
553 |
+
super().__init__(config, vit_config)
|
554 |
+
old_anyres_vit_max_image_size = vit_config.max_image_size
|
555 |
+
anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
|
556 |
+
vit_config.max_image_size = anyres_vit_max_image_size
|
557 |
+
vit_config.torch_dtype = config.torch_dtype
|
558 |
+
vit_config.anyres_vit_two_views = config.anyres_vit_two_views
|
559 |
+
vit_config.vit_remove_prenorm = config.vit_remove_prenorm
|
560 |
+
vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
|
561 |
+
self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
|
562 |
+
vit_config.max_image_size = old_anyres_vit_max_image_size
|
563 |
+
|
564 |
+
def fix_embeddings_fn(self, pixel_values):
|
565 |
+
# (B, L, E)
|
566 |
+
embeddings, hw = self.embeddings.forward_single(pixel_values)
|
567 |
+
embeddings = self.embeddings.pre_layernorm(embeddings)
|
568 |
+
return embeddings
|
569 |
+
|
570 |
+
|
571 |
+
class CLIPVisionTransformer(nn.Module):
|
572 |
+
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
|
573 |
+
super().__init__()
|
574 |
+
embed_dim = vit_config.hidden_size
|
575 |
+
|
576 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
|
577 |
+
self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
|
578 |
+
|
579 |
+
with self.prepare_args(config, vit_config):
|
580 |
+
self.layers = nn.ModuleList(
|
581 |
+
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
582 |
+
)
|
583 |
+
|
584 |
+
@contextmanager
|
585 |
+
def prepare_args(self, config, vit_config):
|
586 |
+
hidden_act = config.hidden_act
|
587 |
+
hidden_size = config.hidden_size
|
588 |
+
ffn_hidden_size = config.intermediate_size
|
589 |
+
num_attention_heads = config.num_attention_heads
|
590 |
+
num_key_value_heads = config.num_key_value_heads
|
591 |
+
attention_head_dim = config.attention_head_dim
|
592 |
+
use_qk_norm = config.use_qk_norm
|
593 |
+
use_rotary_pos_emb = config.use_rotary_pos_emb
|
594 |
+
num_hidden_layers = config.num_hidden_layers
|
595 |
+
rms_norm_eps = config.rms_norm_eps
|
596 |
+
attention_dropout = config.attention_dropout
|
597 |
+
# hidden_dropout = config.hidden_dropout
|
598 |
+
norm_type = config.norm_type
|
599 |
+
attention_bias = config.attention_bias
|
600 |
+
mlp_bias = config.mlp_bias
|
601 |
+
use_mla = config.use_mla
|
602 |
+
num_experts = config.num_experts
|
603 |
+
_attn_implementation = config._attn_implementation
|
604 |
+
|
605 |
+
config.hidden_act = vit_config.hidden_act
|
606 |
+
config.hidden_size = vit_config.hidden_size
|
607 |
+
config.intermediate_size = vit_config.intermediate_size
|
608 |
+
config.num_attention_heads = vit_config.num_attention_heads
|
609 |
+
config.num_key_value_heads = None
|
610 |
+
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
|
611 |
+
config.use_qk_norm = False
|
612 |
+
config.use_rotary_pos_emb = False
|
613 |
+
config.num_hidden_layers = vit_config.num_hidden_layers
|
614 |
+
config.rms_norm_eps = vit_config.layer_norm_eps
|
615 |
+
config.attention_dropout = vit_config.attention_dropout
|
616 |
+
# config.hidden_dropout = 0.0
|
617 |
+
config.norm_type = "fused"
|
618 |
+
config.attention_bias = True
|
619 |
+
config.mlp_bias = True
|
620 |
+
config.use_mla = False
|
621 |
+
config.num_experts = 1
|
622 |
+
config._attn_implementation = "eager"
|
623 |
+
|
624 |
+
yield
|
625 |
+
|
626 |
+
config.hidden_act = hidden_act
|
627 |
+
config.hidden_size = hidden_size
|
628 |
+
config.intermediate_size = ffn_hidden_size
|
629 |
+
config.num_attention_heads = num_attention_heads
|
630 |
+
config.num_key_value_heads = num_key_value_heads
|
631 |
+
config.attention_head_dim = attention_head_dim
|
632 |
+
config.use_qk_norm = use_qk_norm
|
633 |
+
config.use_rotary_pos_emb = use_rotary_pos_emb
|
634 |
+
config.num_hidden_layers = num_hidden_layers
|
635 |
+
config.rms_norm_eps = rms_norm_eps
|
636 |
+
config.attention_dropout = attention_dropout
|
637 |
+
# config.hidden_dropout = hidden_dropout
|
638 |
+
config.norm_type = norm_type
|
639 |
+
config.attention_bias = attention_bias
|
640 |
+
config.mlp_bias = mlp_bias
|
641 |
+
config.use_mla = use_mla
|
642 |
+
config.num_experts = num_experts
|
643 |
+
config._attn_implementation = _attn_implementation
|
644 |
+
|
645 |
+
def forward(
|
646 |
+
self,
|
647 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
648 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
649 |
+
img_index=None
|
650 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
651 |
+
r"""
|
652 |
+
Returns:
|
653 |
+
|
654 |
+
"""
|
655 |
+
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
|
656 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
657 |
+
batch = hidden_states.shape[0]
|
658 |
+
seq_len = hidden_states.shape[1]
|
659 |
+
device = hidden_states.device
|
660 |
+
attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
|
661 |
+
|
662 |
+
for layer_idx, decoder_layer in enumerate(self.layers):
|
663 |
+
layer_outputs = decoder_layer(
|
664 |
+
hidden_states,
|
665 |
+
attention_mask=attention_mask
|
666 |
+
)
|
667 |
+
hidden_states = layer_outputs[0]
|
668 |
+
|
669 |
+
return hidden_states
|
670 |
+
|
671 |
+
|
672 |
+
class Vit(torch.nn.Module):
|
673 |
+
def __init__(self, config, resampler_token=64, pool_rate=2):
|
674 |
+
super().__init__()
|
675 |
+
self.config = config
|
676 |
+
self.vit_mapping_type = config.vit_mapping_type
|
677 |
+
self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
|
678 |
+
self.skip_cls_token = config.skip_cls_token
|
679 |
+
self.pool_rate = pool_rate
|
680 |
+
self.vit_type = self.config.vit_type
|
681 |
+
self.anyres_vit_two_views = self.config.anyres_vit_two_views
|
682 |
+
if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
|
683 |
+
self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
|
684 |
+
else:
|
685 |
+
raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
|
686 |
+
|
687 |
+
def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
|
688 |
+
if self.vit_type == 'AnyResVit':
|
689 |
+
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
|
690 |
+
self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
|
691 |
+
self.vit_config.image_size = vit_input_resolution
|
692 |
+
self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
|
693 |
+
elif self.vit_type == 'Vit-g':
|
694 |
+
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
|
695 |
+
self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
|
696 |
+
self.vit_config.vit_input_resolution = vit_input_resolution
|
697 |
+
self.vit = CLIPVisionTransformer(self.config, self.vit_config)
|
698 |
+
else:
|
699 |
+
assert False, "other vit_types are not supported"
|
700 |
+
|
701 |
+
if self.vit_mapping_type == 'simple_conv_mlp':
|
702 |
+
self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
|
703 |
+
self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
|
704 |
+
elif self.vit_mapping_type == 'oryx_mlp':
|
705 |
+
self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
|
706 |
+
elif self.vit_mapping_type == 'mlp':
|
707 |
+
self.mlp_depth = 2
|
708 |
+
# one mlp layer already in gpt_model.py
|
709 |
+
mlp_hidden_size = self.vit_config.hidden_size
|
710 |
+
if self.vit_type in ['NaVit', 'EvaVit']:
|
711 |
+
mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
|
712 |
+
if self.mlp_depth > 1:
|
713 |
+
mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
|
714 |
+
if self.vit_type in ['NaVit', 'EvaVit']:
|
715 |
+
for _ in range(1, self.mlp_depth):
|
716 |
+
mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
|
717 |
+
mlp_modules.append(torch.nn.GELU())
|
718 |
+
self.perceive = torch.nn.Sequential(*mlp_modules)
|
719 |
+
else:
|
720 |
+
assert False, "other vit_mapping_types are not supported"
|
721 |
+
|
722 |
+
self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
|
723 |
+
for name, param in self.named_parameters():
|
724 |
+
setattr(param, "is_vit_param", True)
|
725 |
+
|
726 |
+
def forward(self, images, img_index=None):
|
727 |
+
if self.vit_type in ['AnyResVit']:
|
728 |
+
dtype = self.config.torch_dtype
|
729 |
+
device = torch.cuda.current_device()
|
730 |
+
|
731 |
+
images_size = []
|
732 |
+
for i in range(len(images)):
|
733 |
+
images_size.append([])
|
734 |
+
for j in range(len(images[i])):
|
735 |
+
images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
|
736 |
+
|
737 |
+
images_feats, img_batch_pos = self.vit(pixel_values=images)
|
738 |
+
a2 = self.vit_config.adaptor_patch_size ** 2
|
739 |
+
|
740 |
+
if self.anyres_vit_two_views:
|
741 |
+
step = 2
|
742 |
+
else:
|
743 |
+
step = 1
|
744 |
+
perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
|
745 |
+
images_list = []
|
746 |
+
images_fix_i = 0
|
747 |
+
num_img_batch_pos = len(img_batch_pos)
|
748 |
+
for i in range(num_img_batch_pos): # batch_id
|
749 |
+
for j in range(0, len(img_batch_pos[i]), step):
|
750 |
+
if self.anyres_vit_two_views:
|
751 |
+
lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
|
752 |
+
lower_begin = lower_begin * a2
|
753 |
+
lower_end = lower_end * a2
|
754 |
+
higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
|
755 |
+
higher_begin = higher_begin * a2
|
756 |
+
higher_end = higher_end * a2
|
757 |
+
lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
|
758 |
+
higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
|
759 |
+
lower_images_size = images_size[i][j]
|
760 |
+
higher_images_size = images_size[i][j + 1]
|
761 |
+
images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
|
762 |
+
else:
|
763 |
+
idx, begin, end = img_batch_pos[i][j]
|
764 |
+
begin = begin * a2
|
765 |
+
end = end * a2
|
766 |
+
is_video = hasattr(images[i][j],'_is_video') and images[i][j]._is_video
|
767 |
+
images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
|
768 |
+
|
769 |
+
images = torch.cat(images_list, dim=1)
|
770 |
+
|
771 |
+
new_batch_pos = []
|
772 |
+
k = 0; cur_len = 0
|
773 |
+
for i in range(len(images_size)):
|
774 |
+
new_batch_pos.append([])
|
775 |
+
for j in range(0, len(images_size[i]), step):
|
776 |
+
new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
|
777 |
+
cur_len += images_list[k].size(1)
|
778 |
+
k += 1
|
779 |
+
new_batch_pos[i].append(new_pos)
|
780 |
+
return images, new_batch_pos
|
781 |
+
elif self.vit_type == 'Vit-g':
|
782 |
+
images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
|
783 |
+
else:
|
784 |
+
assert False, "other vit_types are not supported"
|
785 |
+
|
786 |
+
if self.vit_mapping_type == 'mlp':
|
787 |
+
if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
|
788 |
+
images = images[:,1:,:]
|
789 |
+
b, v, d = images.shape
|
790 |
+
s = int(math.sqrt(v))
|
791 |
+
images = images.reshape(b, s, s, d)
|
792 |
+
|
793 |
+
|
794 |
+
if self.vit_patch_mlp and img_index is not None:
|
795 |
+
L_tensor = torch.tensor(img_index)
|
796 |
+
device = images.device
|
797 |
+
# 获取子图位置
|
798 |
+
nonzero_indices = torch.nonzero(L_tensor).squeeze().to(device)
|
799 |
+
# 获取主图位置
|
800 |
+
zero_indices = torch.nonzero(L_tensor == 0).squeeze().to(device)
|
801 |
+
|
802 |
+
|
803 |
+
images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
|
804 |
+
images_zero = torch.index_select(images, 0, zero_indices).to(device)
|
805 |
+
|
806 |
+
# 子图额外多pool一次
|
807 |
+
pool_rate = self.pool_rate * 2
|
808 |
+
images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
|
809 |
+
images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
|
810 |
+
pool_rate*pool_rate).mean(-1)
|
811 |
+
|
812 |
+
# 为了组batch折衷方案
|
813 |
+
images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
|
814 |
+
images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
|
815 |
+
images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
|
816 |
+
self.pool_rate*self.pool_rate).mean(-1)
|
817 |
+
# 组batch
|
818 |
+
images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
|
819 |
+
images.index_copy_(0, nonzero_indices, images_nonzero)
|
820 |
+
images.index_copy_(0, zero_indices, images_zero)
|
821 |
+
|
822 |
+
if self.mlp_depth >= 2:
|
823 |
+
images = self.perceive(images)
|
824 |
+
else:
|
825 |
+
if s % self.pool_rate == 0:
|
826 |
+
images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
|
827 |
+
images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
|
828 |
+
if self.mlp_depth >= 2:
|
829 |
+
images = self.perceive(images)
|
830 |
+
else:
|
831 |
+
raise ValueError
|
832 |
+
return images
|
833 |
+
|
834 |
+
|
835 |
+
class SimpleConvMlp(nn.Module):
|
836 |
+
def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
|
837 |
+
super().__init__()
|
838 |
+
|
839 |
+
embed_std = 1 / math.sqrt(out_channels)
|
840 |
+
if poolmlp:
|
841 |
+
# if args.learnable_mlp_pooling_size is not None:
|
842 |
+
# in_channels *= args.learnable_mlp_pooling_size ** 2
|
843 |
+
self.proj = nn.Sequential(
|
844 |
+
nn.Linear(in_channels, out_channels),
|
845 |
+
nn.GELU()
|
846 |
+
)
|
847 |
+
self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
|
848 |
+
self.image_newline = nn.Parameter(
|
849 |
+
torch.randn(out_channels) * embed_std
|
850 |
+
)
|
851 |
+
else:
|
852 |
+
self.proj = nn.Sequential(
|
853 |
+
nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
|
854 |
+
nn.GELU(),
|
855 |
+
nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
|
856 |
+
)
|
857 |
+
self.mlp = nn.Linear(in_channels * 4, out_channels)
|
858 |
+
self.image_newline = nn.Parameter(
|
859 |
+
torch.randn(in_channels * 4) * embed_std
|
860 |
+
)
|
861 |
+
self.poolmlp = poolmlp
|
862 |
+
|
863 |
+
self.image_begin = nn.Parameter(
|
864 |
+
torch.randn(out_channels) * embed_std
|
865 |
+
)
|
866 |
+
self.image_end = nn.Parameter(
|
867 |
+
torch.randn(out_channels) * embed_std
|
868 |
+
)
|
869 |
+
|
870 |
+
if twoview:
|
871 |
+
self.image_sep = nn.Parameter(
|
872 |
+
torch.randn(out_channels) * embed_std
|
873 |
+
)
|
874 |
+
|
875 |
+
self.cat_extra_token = cat_extra_token
|
876 |
+
self.use_rms_norm = vit_used_rms_norm
|
877 |
+
if self.use_rms_norm:
|
878 |
+
self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
|
879 |
+
self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
|
880 |
+
|
881 |
+
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
882 |
+
return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
|
883 |
+
|
884 |
+
def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
885 |
+
remove_vit_special_tokens = False
|
886 |
+
learnable_mlp_pooling_size = None
|
887 |
+
if self.use_rms_norm:
|
888 |
+
x = self.before_rms(x)
|
889 |
+
h, w = size
|
890 |
+
dtype = x.dtype
|
891 |
+
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
|
892 |
+
if self.poolmlp:
|
893 |
+
if learnable_mlp_pooling_size is None:
|
894 |
+
x = F.avg_pool2d(x, anyres_pooling_size)
|
895 |
+
x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
|
896 |
+
else:
|
897 |
+
x = x.permute(0, 2, 3, 1) # b, h, w, c
|
898 |
+
x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
|
899 |
+
w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
|
900 |
+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
|
901 |
+
x = self.proj(x)
|
902 |
+
x = self.vit_linear_encoder(x)
|
903 |
+
b, h, w, c = x.shape
|
904 |
+
if not remove_vit_special_tokens:
|
905 |
+
x = torch.cat([
|
906 |
+
x,
|
907 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
|
908 |
+
], dim=2)
|
909 |
+
x = x.reshape(b, -1, c)
|
910 |
+
else:
|
911 |
+
x = self.proj(x) #b,c,h,w
|
912 |
+
if is_video:
|
913 |
+
video_avgpool_size = 2
|
914 |
+
stride = 2
|
915 |
+
x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
|
916 |
+
b, c, h, w = x.shape
|
917 |
+
if not remove_vit_special_tokens:
|
918 |
+
x = torch.cat([
|
919 |
+
x,
|
920 |
+
self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
|
921 |
+
], dim=-1)
|
922 |
+
x = x.reshape(b, c, -1).permute(0, 2, 1)
|
923 |
+
x = self.mlp(x)
|
924 |
+
|
925 |
+
|
926 |
+
if x2 is not None:
|
927 |
+
h2, w2 = size2
|
928 |
+
x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
|
929 |
+
if self.poolmlp:
|
930 |
+
x2 = F.avg_pool2d(x2, 2)
|
931 |
+
x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
|
932 |
+
x2 = self.vit_linear_encoder(x2)
|
933 |
+
b2, h2, w2, c2 = x2.shape
|
934 |
+
if not remove_vit_special_tokens:
|
935 |
+
x2 = torch.cat([
|
936 |
+
x2,
|
937 |
+
self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
|
938 |
+
], dim=2)
|
939 |
+
x2 = x2.reshape(b2, -1, c2)
|
940 |
+
else:
|
941 |
+
x2 = self.proj(x2)
|
942 |
+
b2, c2, h2, w2 = x2.shape
|
943 |
+
if not remove_vit_special_tokens:
|
944 |
+
x2 = torch.cat([
|
945 |
+
x2,
|
946 |
+
self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
|
947 |
+
], dim=-1)
|
948 |
+
x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
|
949 |
+
x2 = self.mlp(x2)
|
950 |
+
|
951 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
|
952 |
+
|
953 |
+
x = torch.cat([x, sep, x2], dim=1)
|
954 |
+
|
955 |
+
if self.cat_extra_token:
|
956 |
+
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
|
957 |
+
end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
|
958 |
+
x = torch.cat([begin, x, end], dim=1)
|
959 |
+
|
960 |
+
if self.use_rms_norm:
|
961 |
+
return self.after_rms(x)
|
962 |
+
else:
|
963 |
+
return x
|
964 |
+
|
965 |
+
|
966 |
+
class NormalizedDwPooler(nn.Module):
|
967 |
+
def __init__(self, dim):
|
968 |
+
super().__init__()
|
969 |
+
self.dim = dim
|
970 |
+
self.predictor = nn.Sequential(
|
971 |
+
nn.Linear(dim*2, dim),
|
972 |
+
nn.GELU(),
|
973 |
+
nn.Linear(dim, dim),
|
974 |
+
)
|
975 |
+
|
976 |
+
def forward(self, x, forward_type='2x'):
|
977 |
+
B, H, W, C = x.shape
|
978 |
+
|
979 |
+
if forward_type == '2x':
|
980 |
+
new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
|
981 |
+
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
|
982 |
+
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
983 |
+
elif forward_type == '1x':
|
984 |
+
new_x = x.reshape(B, H, W, 1, C)
|
985 |
+
fused_x = torch.cat([new_x, new_x], dim=-1)
|
986 |
+
elif forward_type == '4x':
|
987 |
+
new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
|
988 |
+
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
|
989 |
+
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
990 |
+
|
991 |
+
score = self.predictor(fused_x)
|
992 |
+
normalized_score = F.softmax(score, dim=-2)
|
993 |
+
new_x = (new_x * normalized_score).sum(dim=-2)
|
994 |
+
return new_x
|
995 |
+
|
996 |
+
|
997 |
+
class OryxMLPv2(nn.Module):
|
998 |
+
def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
|
999 |
+
super().__init__()
|
1000 |
+
|
1001 |
+
self.proj1 = nn.Linear(in_channels, out_channels)
|
1002 |
+
self.proj2 = nn.Linear(out_channels, out_channels)
|
1003 |
+
self.act = nn.GELU()
|
1004 |
+
self.pooler = NormalizedDwPooler(out_channels)
|
1005 |
+
embed_std = 1 / math.sqrt(out_channels)
|
1006 |
+
|
1007 |
+
self.use_pe = use_pe
|
1008 |
+
if not use_pe:
|
1009 |
+
self.image_newline = nn.Parameter(
|
1010 |
+
torch.randn(out_channels) * embed_std
|
1011 |
+
)
|
1012 |
+
self.image_begin = nn.Parameter(
|
1013 |
+
torch.randn(out_channels) * embed_std
|
1014 |
+
)
|
1015 |
+
self.image_end = nn.Parameter(
|
1016 |
+
torch.randn(out_channels) * embed_std
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
if twoview:
|
1020 |
+
self.image_sep = nn.Parameter(
|
1021 |
+
torch.randn(out_channels) * embed_std
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
1025 |
+
h, w = size
|
1026 |
+
dtype = x.dtype
|
1027 |
+
x = x.reshape(x.shape[0], h, w, -1)
|
1028 |
+
# x = self.pooler(x, forward_type=REGIONAL_POOL)
|
1029 |
+
# x = self.proj(x) #b,h,w, c
|
1030 |
+
x = self.proj1(x)
|
1031 |
+
x = self.pooler(x, forward_type='2x')
|
1032 |
+
x = self.act(x)
|
1033 |
+
x = self.proj2(x)
|
1034 |
+
|
1035 |
+
|
1036 |
+
b, h, w, c = x.shape
|
1037 |
+
if not self.use_pe:
|
1038 |
+
x = torch.cat([
|
1039 |
+
x,
|
1040 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
|
1041 |
+
], dim=2)
|
1042 |
+
else:
|
1043 |
+
pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
|
1044 |
+
pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
|
1045 |
+
pe = torch.cat([pe_h, pe_w], dim=-1)
|
1046 |
+
|
1047 |
+
x = x.reshape(b, -1, c)
|
1048 |
+
|
1049 |
+
if x2 is not None:
|
1050 |
+
h2, w2 = size2
|
1051 |
+
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
1052 |
+
# x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
|
1053 |
+
## x2 = self.proj(x2) #b,h,w, c
|
1054 |
+
x2 = self.proj1(x2)
|
1055 |
+
x2 = self.pooler(x2, forward_type='2x')
|
1056 |
+
x2 = self.act(x2)
|
1057 |
+
x2 = self.proj2(x2)
|
1058 |
+
|
1059 |
+
b2, h2, w2, c2 = x2.shape
|
1060 |
+
if not self.use_pe:
|
1061 |
+
x2 = torch.cat([
|
1062 |
+
x2,
|
1063 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
|
1064 |
+
], dim=2)
|
1065 |
+
x2 = x2.reshape(b, -1, c)
|
1066 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
|
1067 |
+
x = torch.cat([x, sep, x2], dim=1)
|
1068 |
+
|
1069 |
+
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
1070 |
+
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
1071 |
+
x = torch.cat([begin, x, end], dim=1)
|
1072 |
+
# print(x.shape, x2.shape, h, w, h2, w2)
|
1073 |
+
# print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
|
1074 |
+
if self.use_pe:
|
1075 |
+
zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
|
1076 |
+
pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
|
1077 |
+
assert pe.shape[1] == x.shape[1]
|
1078 |
+
return x, pe
|
1079 |
+
else:
|
1080 |
+
nseq = x.shape[1]
|
1081 |
+
fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
|
1082 |
+
return x #, fake_pe
|
1083 |
+
|