Light-Dav committed
Commit a5ccb22 · verified · 1 parent: ce27c7c

Add files using upload-large-folder tool

Files changed (50)
  1. venv/Lib/site-packages/accelerate-1.6.0.dist-info/INSTALLER +1 -0
  2. venv/Lib/site-packages/accelerate-1.6.0.dist-info/LICENSE +201 -0
  3. venv/Lib/site-packages/accelerate-1.6.0.dist-info/METADATA +380 -0
  4. venv/Lib/site-packages/accelerate-1.6.0.dist-info/RECORD +177 -0
  5. venv/Lib/site-packages/accelerate-1.6.0.dist-info/REQUESTED +0 -0
  6. venv/Lib/site-packages/accelerate-1.6.0.dist-info/WHEEL +5 -0
  7. venv/Lib/site-packages/accelerate-1.6.0.dist-info/entry_points.txt +6 -0
  8. venv/Lib/site-packages/accelerate-1.6.0.dist-info/top_level.txt +1 -0
  9. venv/Lib/site-packages/accelerate/__init__.py +50 -0
  10. venv/Lib/site-packages/accelerate/accelerator.py +0 -0
  11. venv/Lib/site-packages/accelerate/big_modeling.py +637 -0
  12. venv/Lib/site-packages/accelerate/checkpointing.py +319 -0
  13. venv/Lib/site-packages/accelerate/data_loader.py +1429 -0
  14. venv/Lib/site-packages/accelerate/hooks.py +739 -0
  15. venv/Lib/site-packages/accelerate/inference.py +184 -0
  16. venv/Lib/site-packages/accelerate/launchers.py +301 -0
  17. venv/Lib/site-packages/accelerate/local_sgd.py +106 -0
  18. venv/Lib/site-packages/accelerate/logging.py +125 -0
  19. venv/Lib/site-packages/accelerate/memory_utils.py +22 -0
  20. venv/Lib/site-packages/accelerate/optimizer.py +212 -0
  21. venv/Lib/site-packages/accelerate/scheduler.py +98 -0
  22. venv/Lib/site-packages/accelerate/state.py +1330 -0
  23. venv/Lib/site-packages/accelerate/tracking.py +1089 -0
  24. venv/Lib/site-packages/adodbapi/__init__.py +82 -0
  25. venv/Lib/site-packages/adodbapi/ado_consts.py +283 -0
  26. venv/Lib/site-packages/adodbapi/adodbapi.py +1153 -0
  27. venv/Lib/site-packages/adodbapi/apibase.py +723 -0
  28. venv/Lib/site-packages/adodbapi/is64bit.py +34 -0
  29. venv/Lib/site-packages/adodbapi/license.txt +505 -0
  30. venv/Lib/site-packages/adodbapi/process_connect_string.py +137 -0
  31. venv/Lib/site-packages/adodbapi/readme.txt +88 -0
  32. venv/Lib/site-packages/adodbapi/schema_table.py +16 -0
  33. venv/Lib/site-packages/adodbapi/setup.py +68 -0
  34. venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/INSTALLER +1 -0
  35. venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/LICENSE +279 -0
  36. venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/METADATA +123 -0
  37. venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/RECORD +16 -0
  38. venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/WHEEL +4 -0
  39. venv/Lib/site-packages/aiohappyeyeballs/__init__.py +14 -0
  40. venv/Lib/site-packages/aiohappyeyeballs/_staggered.py +207 -0
  41. venv/Lib/site-packages/aiohappyeyeballs/impl.py +259 -0
  42. venv/Lib/site-packages/aiohappyeyeballs/py.typed +0 -0
  43. venv/Lib/site-packages/aiohappyeyeballs/types.py +17 -0
  44. venv/Lib/site-packages/aiohappyeyeballs/utils.py +97 -0
  45. venv/Lib/site-packages/aiohttp/abc.py +253 -0
  46. venv/Lib/site-packages/aiohttp/base_protocol.py +100 -0
  47. venv/Lib/site-packages/scipy-1.15.3-cp312-cp312-win_amd64.whl +0 -0
  48. venv/Lib/site-packages/six.py +1003 -0
  49. venv/Lib/site-packages/threadpoolctl.py +1292 -0
  50. venv/Lib/site-packages/typing_extensions.py +0 -0
venv/Lib/site-packages/accelerate-1.6.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+ pip
venv/Lib/site-packages/accelerate-1.6.0.dist-info/LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
venv/Lib/site-packages/accelerate-1.6.0.dist-info/METADATA ADDED
@@ -0,0 +1,380 @@
+ Metadata-Version: 2.1
+ Name: accelerate
+ Version: 1.6.0
+ Summary: Accelerate
+ Home-page: https://github.com/huggingface/accelerate
+ Author: The HuggingFace team
+ Author-email: [email protected]
+ License: Apache
+ Keywords: deep learning
+ Classifier: Development Status :: 5 - Production/Stable
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9.0
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy<3.0.0,>=1.17
+ Requires-Dist: packaging>=20.0
+ Requires-Dist: psutil
+ Requires-Dist: pyyaml
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: huggingface-hub>=0.21.0
+ Requires-Dist: safetensors>=0.4.3
+ Provides-Extra: deepspeed
+ Requires-Dist: deepspeed; extra == "deepspeed"
+ Provides-Extra: dev
+ Requires-Dist: black~=23.1; extra == "dev"
+ Requires-Dist: hf-doc-builder>=0.3.0; extra == "dev"
+ Requires-Dist: ruff~=0.11.2; extra == "dev"
+ Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "dev"
+ Requires-Dist: pytest-xdist; extra == "dev"
+ Requires-Dist: pytest-subtests; extra == "dev"
+ Requires-Dist: parameterized; extra == "dev"
+ Requires-Dist: pytest-order; extra == "dev"
+ Requires-Dist: datasets; extra == "dev"
+ Requires-Dist: diffusers; extra == "dev"
+ Requires-Dist: evaluate; extra == "dev"
+ Requires-Dist: torchdata>=0.8.0; extra == "dev"
+ Requires-Dist: torchpippy>=0.2.0; extra == "dev"
+ Requires-Dist: transformers; extra == "dev"
+ Requires-Dist: scipy; extra == "dev"
+ Requires-Dist: scikit-learn; extra == "dev"
+ Requires-Dist: tqdm; extra == "dev"
+ Requires-Dist: bitsandbytes; extra == "dev"
+ Requires-Dist: timm; extra == "dev"
+ Requires-Dist: rich; extra == "dev"
+ Provides-Extra: docs
+ Provides-Extra: quality
+ Requires-Dist: black~=23.1; extra == "quality"
+ Requires-Dist: hf-doc-builder>=0.3.0; extra == "quality"
+ Requires-Dist: ruff~=0.11.2; extra == "quality"
+ Provides-Extra: rich
+ Requires-Dist: rich; extra == "rich"
+ Provides-Extra: sagemaker
+ Requires-Dist: sagemaker; extra == "sagemaker"
+ Provides-Extra: test_dev
+ Requires-Dist: datasets; extra == "test-dev"
+ Requires-Dist: diffusers; extra == "test-dev"
+ Requires-Dist: evaluate; extra == "test-dev"
+ Requires-Dist: torchdata>=0.8.0; extra == "test-dev"
+ Requires-Dist: torchpippy>=0.2.0; extra == "test-dev"
+ Requires-Dist: transformers; extra == "test-dev"
+ Requires-Dist: scipy; extra == "test-dev"
+ Requires-Dist: scikit-learn; extra == "test-dev"
+ Requires-Dist: tqdm; extra == "test-dev"
+ Requires-Dist: bitsandbytes; extra == "test-dev"
+ Requires-Dist: timm; extra == "test-dev"
+ Provides-Extra: test_prod
+ Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "test-prod"
+ Requires-Dist: pytest-xdist; extra == "test-prod"
+ Requires-Dist: pytest-subtests; extra == "test-prod"
+ Requires-Dist: parameterized; extra == "test-prod"
+ Requires-Dist: pytest-order; extra == "test-prod"
+ Provides-Extra: test_trackers
+ Requires-Dist: wandb; extra == "test-trackers"
+ Requires-Dist: comet-ml; extra == "test-trackers"
+ Requires-Dist: tensorboard; extra == "test-trackers"
+ Requires-Dist: dvclive; extra == "test-trackers"
+ Requires-Dist: mlflow; extra == "test-trackers"
+ Requires-Dist: matplotlib; extra == "test-trackers"
+ Provides-Extra: testing
+ Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "testing"
+ Requires-Dist: pytest-xdist; extra == "testing"
+ Requires-Dist: pytest-subtests; extra == "testing"
+ Requires-Dist: parameterized; extra == "testing"
+ Requires-Dist: pytest-order; extra == "testing"
+ Requires-Dist: datasets; extra == "testing"
+ Requires-Dist: diffusers; extra == "testing"
+ Requires-Dist: evaluate; extra == "testing"
+ Requires-Dist: torchdata>=0.8.0; extra == "testing"
+ Requires-Dist: torchpippy>=0.2.0; extra == "testing"
+ Requires-Dist: transformers; extra == "testing"
+ Requires-Dist: scipy; extra == "testing"
+ Requires-Dist: scikit-learn; extra == "testing"
+ Requires-Dist: tqdm; extra == "testing"
+ Requires-Dist: bitsandbytes; extra == "testing"
+ Requires-Dist: timm; extra == "testing"
+
+ <!---
+ Copyright 2021 The HuggingFace Team. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+ <p align="center">
+ <br>
+ <img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/accelerate_logo.png" width="400"/>
+ <br>
+ </p>
+
+ <p align="center">
+ <!-- Uncomment when CircleCI is set up
+ <a href="https://circleci.com/gh/huggingface/accelerate"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/master"></a>
+ -->
+ <a href="https://github.com/huggingface/accelerate/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/accelerate.svg?color=blue"></a>
+ <a href="https://huggingface.co/docs/accelerate/index.html"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/accelerate/index.html.svg?down_color=red&down_message=offline&up_message=online"></a>
+ <a href="https://github.com/huggingface/accelerate/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/accelerate.svg"></a>
+ <a href="https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg"></a>
+ </p>
+
+ <h3 align="center">
+ <p>Run your *raw* PyTorch training script on any kind of device</p>
+ </h3>
+
+ <h3 align="center">
+ <a href="https://hf.co/course"><img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/course_banner.png"></a>
+ </h3>
+
+ ## Easy to integrate
+
+ 🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.
+
+ 🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.
+
+ Here is an example:
+
+ ```diff
+   import torch
+   import torch.nn.functional as F
+   from datasets import load_dataset
+ + from accelerate import Accelerator
+
+ + accelerator = Accelerator()
+ - device = 'cpu'
+ + device = accelerator.device
+
+   model = torch.nn.Transformer().to(device)
+   optimizer = torch.optim.Adam(model.parameters())
+
+   dataset = load_dataset('my_dataset')
+   data = torch.utils.data.DataLoader(dataset, shuffle=True)
+
+ + model, optimizer, data = accelerator.prepare(model, optimizer, data)
+
+   model.train()
+   for epoch in range(10):
+       for source, targets in data:
+           source = source.to(device)
+           targets = targets.to(device)
+
+           optimizer.zero_grad()
+
+           output = model(source)
+           loss = F.cross_entropy(output, targets)
+
+ -         loss.backward()
+ +         accelerator.backward(loss)
+
+           optimizer.step()
+ ```
+
+ As you can see in this example, by adding five lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16).
+
+ In particular, the same code can then be run without modification on your local machine for debugging or in your training environment.
+
+ 🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can simplify your training loop even further:
+
+ ```diff
+   import torch
+   import torch.nn.functional as F
+   from datasets import load_dataset
+ + from accelerate import Accelerator
+
+ - device = 'cpu'
+ + accelerator = Accelerator()
+
+ - model = torch.nn.Transformer().to(device)
+ + model = torch.nn.Transformer()
+   optimizer = torch.optim.Adam(model.parameters())
+
+   dataset = load_dataset('my_dataset')
+   data = torch.utils.data.DataLoader(dataset, shuffle=True)
+
+ + model, optimizer, data = accelerator.prepare(model, optimizer, data)
+
+   model.train()
+   for epoch in range(10):
+       for source, targets in data:
+ -         source = source.to(device)
+ -         targets = targets.to(device)
+
+           optimizer.zero_grad()
+
+           output = model(source)
+           loss = F.cross_entropy(output, targets)
+
+ -         loss.backward()
+ +         accelerator.backward(loss)
+
+           optimizer.step()
+ ```
+
+ Want to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples).
+
+ ## Launching script
+
+ 🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training!
+ On your machine(s) just run:
+
+ ```bash
+ accelerate config
+ ```
+
+ and answer the questions asked. This will generate a config file that will be used automatically to properly set the default options when doing
+
+ ```bash
+ accelerate launch my_script.py --args_to_my_script
+ ```
+
+ For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):
+
+ ```bash
+ accelerate launch examples/nlp_example.py
+ ```
+
+ This CLI tool is **optional**, and you can still use `python my_script.py` or `python -m torch.distributed.run my_script.py` at your convenience.
+
+ You can also directly pass in the arguments you would pass to `torchrun` as arguments to `accelerate launch` if you wish to not run `accelerate config`.
+
+ For example, here is how to launch on two GPUs:
+
+ ```bash
+ accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
+ ```
+
+ To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
+
+ Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/).
+
+ ## Launching multi-CPU run using MPI
+
+ 🤗 Here is another way to launch a multi-CPU run, using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
+ Once you have MPI set up on your cluster, just run:
+ ```bash
+ accelerate config
+ ```
+ Answer the questions that are asked, selecting to run using multi-CPU, and answer "yes" when asked if you want accelerate to launch mpirun.
+ Then, use `accelerate launch` with your script like:
+ ```bash
+ accelerate launch examples/nlp_example.py
+ ```
+ Alternatively, you can use mpirun directly, without using the CLI, like:
+ ```bash
+ mpirun -np 2 python examples/nlp_example.py
+ ```
+
+ ## Launching training using DeepSpeed
+
+ 🤗 Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. However, if you desire to tweak your DeepSpeed-related args from your Python script, we provide you the `DeepSpeedPlugin`.
+
+ ```python
+ from accelerate import Accelerator, DeepSpeedPlugin
+
+ # DeepSpeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it
+ # Remember you still need to do gradient accumulation by yourself, just like you would have done without DeepSpeed
+ deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
+ accelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin)
+
+ # How to save your 🤗 Transformer?
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
+ ```
+
+ Note: DeepSpeed support is experimental for now. In case you get into some problem, please open an issue.
+
+ ## Launching your training from a notebook
+
+ 🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function` then in your last cell, add:
+
+ ```python
+ from accelerate import notebook_launcher
+
+ notebook_launcher(training_function)
+ ```
+
+ An example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb). [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)
+
+ ## Why should I use 🤗 Accelerate?
+
+ You should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object.
+
+ ## Why shouldn't I use 🤗 Accelerate?
+
+ You shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that; 🤗 Accelerate is not one of them.
+
+ ## Frameworks using 🤗 Accelerate
+
+ If you like the simplicity of 🤗 Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of 🤗 Accelerate are listed below:
+
+ * [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development.
+ * [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common "breakpoints" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76).
+ * [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic.
+ * [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms.
+ * [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses.
+ * [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering an industry-leading WebUI and terminal usage support, and serving as the foundation for many commercial products.
+ * [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose to train and fine-tune the supported deep learning algorithms within the library.
+ * [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so.
+ * [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency, enabling users to understand exactly what is going on under the hood without having to write and maintain the boilerplate themselves!
+ * [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source, browser-based, easy-to-use interface for Stable Diffusion, based on the Gradio library.
+ * [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training PyTorch models in a Keras style; a dynamic and beautiful plot is provided in the notebook to monitor your loss or metrics.
+ * [transformers](https://github.com/huggingface/transformers) is a tool for helping train state-of-the-art machine learning models in PyTorch, TensorFlow, and JAX. (Accelerate is the backend for the PyTorch side.)
+
+
+ ## Installation
+
+ This repository is tested on Python 3.9+ and PyTorch 2.0.0+.
+
+ You should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
+
+ First, create a virtual environment with the version of Python you're going to use and activate it.
+
+ Then, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) regarding the specific install command for your platform. Then 🤗 Accelerate can be installed using pip as follows:
+
+ ```bash
+ pip install accelerate
+ ```
+
+ ## Supported integrations
+
+ - CPU only
+ - multi-CPU on one node (machine)
+ - multi-CPU on several nodes (machines)
+ - single GPU
+ - multi-GPU on one node (machine)
+ - multi-GPU on several nodes (machines)
+ - TPU
+ - FP16/BFloat16 mixed precision
+ - FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
+ - DeepSpeed support (Experimental)
+ - PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
+ - Megatron-LM support (Experimental)
+
+ ## Citing 🤗 Accelerate
+
+ If you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry.
+
+ ```bibtex
+ @Misc{accelerate,
+   title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
+   author = {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},
+   howpublished = {\url{https://github.com/huggingface/accelerate}},
+   year = {2022}
+ }
+ ```
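Editor's note: the README's `diff` blocks above use `load_dataset('my_dataset')` as a placeholder, so they are not runnable as written. A minimal self-contained sketch of the same pattern, assuming only `torch` and `accelerate` are installed and substituting a synthetic `TensorDataset` for the placeholder dataset:

```python
# Minimal, runnable version of the training-loop pattern shown above.
# The linear model and synthetic dataset are toy stand-ins, not part of accelerate.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()  # device and precision come from the accelerate config

# Synthetic classification data: 256 samples, 32 features, 4 classes.
features = torch.randn(256, 32)
labels = torch.randint(0, 4, (256,))
data = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

model = torch.nn.Linear(32, 4)
optimizer = torch.optim.Adam(model.parameters())

# prepare() wraps the model, optimizer and dataloader for the current setup
# (device placement, DDP wrapping, sharded sampling, ...).
model, optimizer, data = accelerator.prepare(model, optimizer, data)

model.train()
for epoch in range(3):
    for source, targets in data:
        optimizer.zero_grad()
        output = model(source)
        loss = F.cross_entropy(output, targets)
        accelerator.backward(loss)  # replaces loss.backward()
        optimizer.step()
```

Run it directly with `python script.py` for single-process execution, or through `accelerate launch script.py` to pick up distributed settings from `accelerate config`.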
venv/Lib/site-packages/accelerate-1.6.0.dist-info/RECORD ADDED
@@ -0,0 +1,177 @@
+ ../../Scripts/accelerate-config.exe,sha256=oMHvUIO20oc9e7mTWqdxwnwE2vt6jKFxV5U0295vjGQ,108433
+ ../../Scripts/accelerate-estimate-memory.exe,sha256=ptuggVnh4A7ZaLZFPsmF8CMf3_QPEi7MFOshFkZtg_E,108435
+ ../../Scripts/accelerate-launch.exe,sha256=n9Bd7LTGgPp2bVXE9A73FDCFI_brpmNFws404EoTC_g,108433
+ ../../Scripts/accelerate-merge-weights.exe,sha256=PSkH501EplMRrCYdZqTr8qK-VUHOeIpuD9UzcDuJ6oQ,108432
+ ../../Scripts/accelerate.exe,sha256=InYSaN6P9U5H6tX0Xjkm5K1FvyZwbIyQmXJY2WlrW6s,108441
+ accelerate-1.6.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+ accelerate-1.6.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ accelerate-1.6.0.dist-info/METADATA,sha256=zT5ADQHZZeLT4qEiGMNSG4cT7hCnQplwyshDyeDyZNo,19421
+ accelerate-1.6.0.dist-info/RECORD,,
+ accelerate-1.6.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ accelerate-1.6.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+ accelerate-1.6.0.dist-info/entry_points.txt,sha256=Vpy8gUGfZ-1VnM2229fb8CpJNLBdMH_wtJ9PQ7b_2tQ,296
+ accelerate-1.6.0.dist-info/top_level.txt,sha256=esVfdxTidsjQ90zsN_rPpjLFJ4ijRlx4mnLrG09hlt4,11
+ accelerate/__init__.py,sha256=r3I-pArsQK9ZrH3XgnjeCoXo4l-DEFOWQhjj3BguTZc,1504
+ accelerate/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/__pycache__/accelerator.cpython-312.pyc,,
+ accelerate/__pycache__/big_modeling.cpython-312.pyc,,
+ accelerate/__pycache__/checkpointing.cpython-312.pyc,,
+ accelerate/__pycache__/data_loader.cpython-312.pyc,,
+ accelerate/__pycache__/hooks.cpython-312.pyc,,
+ accelerate/__pycache__/inference.cpython-312.pyc,,
+ accelerate/__pycache__/launchers.cpython-312.pyc,,
+ accelerate/__pycache__/local_sgd.cpython-312.pyc,,
+ accelerate/__pycache__/logging.cpython-312.pyc,,
+ accelerate/__pycache__/memory_utils.cpython-312.pyc,,
+ accelerate/__pycache__/optimizer.cpython-312.pyc,,
+ accelerate/__pycache__/scheduler.cpython-312.pyc,,
+ accelerate/__pycache__/state.cpython-312.pyc,,
+ accelerate/__pycache__/tracking.cpython-312.pyc,,
+ accelerate/accelerator.py,sha256=G952noNHGPrl-poK6qAj1OY32kGjmN5S13v8zy7H63E,173175
+ accelerate/big_modeling.py,sha256=IMiAtiuZQpwSyk2jQsoYC2uWzfRUSpCg7FiThSvjfKw,29702
+ accelerate/checkpointing.py,sha256=BaDOrpQzRI2U1BvN2vK4lepTRNNqbxGd4QPa1zOShoc,13612
+ accelerate/commands/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
+ accelerate/commands/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc,,
+ accelerate/commands/__pycache__/env.cpython-312.pyc,,
+ accelerate/commands/__pycache__/estimate.cpython-312.pyc,,
+ accelerate/commands/__pycache__/launch.cpython-312.pyc,,
+ accelerate/commands/__pycache__/merge.cpython-312.pyc,,
+ accelerate/commands/__pycache__/test.cpython-312.pyc,,
+ accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc,,
+ accelerate/commands/__pycache__/tpu.cpython-312.pyc,,
+ accelerate/commands/__pycache__/utils.cpython-312.pyc,,
+ accelerate/commands/accelerate_cli.py,sha256=SkwFad6Z1ZsGjtm7TiXFq8je-akshp_0WxX_6rGSBw8,1972
+ accelerate/commands/config/__init__.py,sha256=iJK8dgj3pc5Vdr1E7UuGoFu-BlybyXLxYDoTg9gXngE,1645
+ accelerate/commands/config/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/cluster.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/config.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/config_args.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/default.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc,,
+ accelerate/commands/config/__pycache__/update.cpython-312.pyc,,
+ accelerate/commands/config/cluster.py,sha256=w0L3zTyZp4sjDpCrM3NxOjxZ0kyPJZmzi06pFZmbM2c,37472
+ accelerate/commands/config/config.py,sha256=FuRlQvOjgATEtyqOSsGD-KEtOCvACOHjs2C-krrtldk,3035
+ accelerate/commands/config/config_args.py,sha256=xn6M8iJnlFycosDlbM0BE86r9RxfdDwHtIlk-UUq7UM,10082
+ accelerate/commands/config/config_utils.py,sha256=mdvZE9fpllfD8S4Blhqk3nLqQ5m14WJ0jQ1yh768H10,3177
+ accelerate/commands/config/default.py,sha256=sPgQVt_0zk68KlupQFqt8B6JUoPMFPxXmXr7xFM-EN8,6212
+ accelerate/commands/config/sagemaker.py,sha256=GjHE2-h4tRr1P_PFtMF3miiAtJlzkbHbMb6kFXqn8eo,10341
+ accelerate/commands/config/update.py,sha256=NXW1J7GkUHpg71QlIXsmMB_0z8S8IZo2FWax5POwrhc,2395
+ accelerate/commands/env.py,sha256=-B3FPX4S705A-P_tyLKm_JzGpz-TeKqFNPdNWDAdGIM,4156
+ accelerate/commands/estimate.py,sha256=Qduq4xudVyIede37BMEe1rNhXf-rfW-MHV2KtwxdfEA,12585
+ accelerate/commands/launch.py,sha256=7DI42Uw4kf_peOpY5TUA1V2yz7cuSO3cYnLgiI5G1Vs,47496
+ accelerate/commands/menu/__init__.py,sha256=uqSlBM0TFHBwzdv3p3SXfpAk1lZFp4h1a7mbBdscPHs,645
+ accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc,,
+ accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc,,
+ accelerate/commands/menu/__pycache__/input.cpython-312.pyc,,
+ accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc,,
+ accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc,,
+ accelerate/commands/menu/cursor.py,sha256=-lmpJVAzvNc0c3EOtSuLoKB59zqylVCbYyWLPnrOmvQ,2028
+ accelerate/commands/menu/helpers.py,sha256=KrSB5fJjH4MUEUAQJ6bYaN16AYcnl9UalDrPD3DYeeg,1483
+ accelerate/commands/menu/input.py,sha256=T8Mdd-Y_OURgqfDV9qZh4Wf6hmT22AneNtJzj4JA1Rk,2512
+ accelerate/commands/menu/keymap.py,sha256=eXj-suyYs1m5dEHoUKN4mKAMLc8DWHnwhP6G6JSU0jQ,4086
+ accelerate/commands/menu/selection_menu.py,sha256=bxy-DHaKKC6SCToOlMBv5_z0MdUzylEg6Sio9OuV3GM,4921
+ accelerate/commands/merge.py,sha256=quDKckN3vKn9nsGjdwfoojnfTMFdKRRUkY1DYuuNNmc,2388
+ accelerate/commands/test.py,sha256=YrPYEaAACOGZ6btn2MV6NbMSEdBUcMWADLbQWaZSHtk,2149
+ accelerate/commands/to_fsdp2.py,sha256=gfbhoUT4qFB3LVDMNmckElgLG0yWm8aj_aofszeiJmM,5991
+ accelerate/commands/tpu.py,sha256=KyxDP7IuveidZrbW4rx2s8Ku3o_ptI6tzwr_R7ck0os,5548
+ accelerate/commands/utils.py,sha256=aT8xUCe2pCkFII7yZxcfaohEjgBAzMUM7WiD4UuWSOY,4150
+ accelerate/data_loader.py,sha256=yArisKhfuIJzDD7vuOgZAqEJNUC8tgl2L8ay92rgtfY,64551
+ accelerate/hooks.py,sha256=lYtYSIqEQnZOImgj2UMTngQPkcQDEHS2klwak1oHD6w,32248
+ accelerate/inference.py,sha256=NLANdzXm5PwmDWbPYkFmoRoQSLLvuhfvIG33xfpapT0,7668
+ accelerate/launchers.py,sha256=QIqUVkDc-oTmWf00L8kas7u2RBEwOYoRi8M2Our0DAs,13721
+ accelerate/local_sgd.py,sha256=aCj_yqXK_FhhZRWEpzXIkgXBERH6fC3HyrC3nsOj1uA,4160
+ accelerate/logging.py,sha256=4XcgY_BV7Qn_enh2tZ-8fNtuaE_3n-LsYJbgwhRx_PI,5042
+ accelerate/memory_utils.py,sha256=3R5LoeHl6GgTZ-IMPrDZMdaEehWarGdPqODushb-6pg,862
+ accelerate/optimizer.py,sha256=QfgCkQ5dA-fLSi_Z7CBPRCObXA1rL9zxHg4tyKCEg2A,8113
+ accelerate/scheduler.py,sha256=des_4M_Tt1W8gCYZZbLla0GHBEgJY3Wx2EGBQPTzeiY,4238
+ accelerate/state.py,sha256=YYpuPqXeNjz5_Y71h0zmCu13cBuDmQ8lw6fAmoSWUFk,55457
+ accelerate/test_utils/__init__.py,sha256=8xikmLMAM6_6CwVF6tsdsv4XzgWkHAk2tZBdV9DxIH8,1749
+ accelerate/test_utils/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/test_utils/__pycache__/examples.cpython-312.pyc,,
+ accelerate/test_utils/__pycache__/testing.cpython-312.pyc,,
+ accelerate/test_utils/__pycache__/training.cpython-312.pyc,,
+ accelerate/test_utils/examples.py,sha256=IN4n2lxA95hexE2rojsyyjhpXLbXnbmjTzd8UTws5_4,7257
+ accelerate/test_utils/scripts/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
+ accelerate/test_utils/scripts/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_cli.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_ddp_comm_hook.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_merge_weights.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_ops.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_script.cpython-312.pyc,,
+ accelerate/test_utils/scripts/__pycache__/test_sync.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
+ accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-312.pyc,,
+ accelerate/test_utils/scripts/external_deps/test_checkpointing.py,sha256=XHaNRmnrARd1izXFjWGi5UjYGas-4vqayW51jAHBPCA,10699
+ accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py,sha256=Cg4-h0B4UcOQ5CxXjIdrsPVR5fFsWCv24DqZGjXEwW8,13790
+ accelerate/test_utils/scripts/external_deps/test_metrics.py,sha256=Ev2XKaiwmznoxKujskAAuISGChW646MOiyf0CXEPb9Y,12168
+ accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py,sha256=9Yn9Rc7d-yWr1fU0RagASPG5l8vrKeHVYbuYABbA-fU,12498
+ accelerate/test_utils/scripts/external_deps/test_performance.py,sha256=Di6LT19bCBLlWmCBSu_jjdqR2EqngXpvUOGDBx8GfZE,10432
+ accelerate/test_utils/scripts/external_deps/test_pippy.py,sha256=ocZntbmAduln2ma4LeEA9o-S8hla3YXCJ_A8hEcWHgs,4762
+ accelerate/test_utils/scripts/external_deps/test_zero3_integration.py,sha256=P9alBOHZ9Lfqs5LoRP7bCbXl-tnsNrBkvJZGseibBeA,1665
+ accelerate/test_utils/scripts/test_cli.py,sha256=qfk1aYFtdvYFCYPkl05602SNGvk08QTv0xZVVcFVtzM,833
+ accelerate/test_utils/scripts/test_ddp_comm_hook.py,sha256=k_-2MBjLKNdMGIcneTbuGd84K05Wp1GEQX6DUVF9UBw,3566
+ accelerate/test_utils/scripts/test_distributed_data_loop.py,sha256=RUWTwd7DIpr2fl7JtKOsvTjMiJioTxO8FdSr2Lw_5uI,15137
+ accelerate/test_utils/scripts/test_merge_weights.py,sha256=dssMnAoZt291vNLbPhPOTQUooh0leg_0erQh0uZH6aU,6125
+ accelerate/test_utils/scripts/test_notebook.py,sha256=qfIy3IvH74-kGn8nadBn_k7qrviqvsxy5ijsnUhuY6o,3894
+ accelerate/test_utils/scripts/test_ops.py,sha256=Bcs-h8EMJwULTfbizlFN5qkv3JraWEpoSZWMn-HswiI,6265
+ accelerate/test_utils/scripts/test_script.py,sha256=8-53hIVQXD28HQT4h2Ijy6yGCHfTWDAf1-HOi4UtDng,34219
+ accelerate/test_utils/scripts/test_sync.py,sha256=PDe8sYZLCL2LKjj_L9b-Bh2BjAjeii9EZ8sZNfuYx5s,18817
+ accelerate/test_utils/testing.py,sha256=x9RK70VgAMyHlo5xut7P85j-9kdAnlfQe_4jwSPpMv4,27807
+ accelerate/test_utils/training.py,sha256=jO5YEIr34jAcnJ_9WNp_x3zuHzSam_I6IgMvmcGm7yI,6456
+ accelerate/tracking.py,sha256=ucpsoYAT3pVXgOfwDdXf6qTugY2-tk-EINvZtfmRitM,42756
+ accelerate/utils/__init__.py,sha256=wjpXyvFxS-ed3Stwm_IHIlBmsmP7KRyAljQ_Qss-OWw,7802
+ accelerate/utils/__pycache__/__init__.cpython-312.pyc,,
+ accelerate/utils/__pycache__/ao.cpython-312.pyc,,
+ accelerate/utils/__pycache__/bnb.cpython-312.pyc,,
+ accelerate/utils/__pycache__/constants.cpython-312.pyc,,
+ accelerate/utils/__pycache__/dataclasses.cpython-312.pyc,,
+ accelerate/utils/__pycache__/deepspeed.cpython-312.pyc,,
+ accelerate/utils/__pycache__/environment.cpython-312.pyc,,
+ accelerate/utils/__pycache__/fsdp_utils.cpython-312.pyc,,
+ accelerate/utils/__pycache__/imports.cpython-312.pyc,,
+ accelerate/utils/__pycache__/launch.cpython-312.pyc,,
+ accelerate/utils/__pycache__/megatron_lm.cpython-312.pyc,,
+ accelerate/utils/__pycache__/memory.cpython-312.pyc,,
+ accelerate/utils/__pycache__/modeling.cpython-312.pyc,,
+ accelerate/utils/__pycache__/offload.cpython-312.pyc,,
+ accelerate/utils/__pycache__/operations.cpython-312.pyc,,
+ accelerate/utils/__pycache__/other.cpython-312.pyc,,
+ accelerate/utils/__pycache__/random.cpython-312.pyc,,
+ accelerate/utils/__pycache__/rich.cpython-312.pyc,,
+ accelerate/utils/__pycache__/torch_xla.cpython-312.pyc,,
+ accelerate/utils/__pycache__/tqdm.cpython-312.pyc,,
+ accelerate/utils/__pycache__/transformer_engine.cpython-312.pyc,,
+ accelerate/utils/__pycache__/versions.cpython-312.pyc,,
+ accelerate/utils/ao.py,sha256=koMiji7AG1kJMRMkJnwSnpuycfx4lPY3CNnpNx2ZqzM,4736
+ accelerate/utils/bnb.py,sha256=KCbg6LUt4eXvPHVnKh7rSVcPwDnzxY_Ii7yYmK5bNGw,20737
+ accelerate/utils/constants.py,sha256=hc24V0pgxWdBQwS6SXxDKwuIni2pCnzdfvMOX1XI9Os,3264
+ accelerate/utils/dataclasses.py,sha256=E7CnCbfskpzxzSorst95Via_XE39t0NP_UGYgJUris0,131486
+ accelerate/utils/deepspeed.py,sha256=QYIXv5LwHXw7wBFFo-7a0t86MbwNAfieJkkBaLGA6wI,14064
+ accelerate/utils/environment.py,sha256=h0zacbBkAp9szltTf5-aTr5NcbVsQp7wl6DFWp8XNuI,15257
+ accelerate/utils/fsdp_utils.py,sha256=Q2tc9EakwBjuYlyXvQrBLV97r6cdReRft6KeS1P_Vb4,28938
+ accelerate/utils/imports.py,sha256=YI1ebPJAuxarclENTfzvDPPGf6jeEKnVQ42taFPuqh0,16759
+ accelerate/utils/launch.py,sha256=nN4ykAtnEL3oITLTejABltdpS3OivcE2COmX-BnWuY4,31195
+ accelerate/utils/megatron_lm.py,sha256=FnIF-niZjvdMk9ymafZWEPjDho_Q_P98C69qc9g5r_E,58059
+ accelerate/utils/memory.py,sha256=lDHqW7Ue_CPmw_DWgNxX_B3HY71_srAFdgR10XiVRSM,6960
+ accelerate/utils/modeling.py,sha256=_xSTiH7zSsffZULSTJuzcDK6IaWImEMOcbq1xqeI7GY,92319
+ accelerate/utils/offload.py,sha256=VFaL8oSJzqZ_47VuUQ69xZi9bF2heRSFoOSnnOxbGXc,7825
+ accelerate/utils/operations.py,sha256=VWPYvtrO4UGX5JmisanXzLLUbhAeL8kQk0yYc66bQ2M,31055
+ accelerate/utils/other.py,sha256=iiLZcKEAlK2Sj_wt03gAEGKrk7_NZFwbmy9cgEppRPw,13231
+ accelerate/utils/random.py,sha256=Xv_ZJm9eaC2Q7rgZy9OpOunKuTingMiDQCH00qhNVxE,6220
+ accelerate/utils/rich.py,sha256=8JZX_uGMQX-BufdXxJpdne7BWd1KyLHSgbiGxrDMYr8,847
+ accelerate/utils/torch_xla.py,sha256=Pq1tuqN0X_pWDVza6YgjfO45uoJdoRVRForLeLQzFus,1908
+ accelerate/utils/tqdm.py,sha256=k8e9JnieTEQHCCNBaiBys7hPxWlEbyRASdIma-qy_X8,1657
+ accelerate/utils/transformer_engine.py,sha256=498Y3z2BkbybYLtBiuF_TJgt8Iii943s4wgRAV8FDC4,6372
+ accelerate/utils/versions.py,sha256=UgmcbjBm--6CIx1ZamSAMjAK_B_2l48LbeaNygqej8M,2149
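Editor's note: each RECORD row above is `path,hash,size`. Per the wheel RECORD convention, the hash field is `sha256=` followed by the urlsafe-base64-encoded SHA-256 digest with `=` padding stripped, which is what pip checks at install time. A short sketch of recomputing one entry (the path below is just an example taken from this listing):

```python
# Sketch: recompute one RECORD hash. Assumes the standard wheel RECORD format:
# "sha256=" + urlsafe-base64(SHA-256 digest) with "=" padding stripped.
import base64
import hashlib

def record_hash(path: str) -> str:
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# Compare against the RECORD row for accelerate/__init__.py above:
expected = "sha256=r3I-pArsQK9ZrH3XgnjeCoXo4l-DEFOWQhjj3BguTZc"
print(record_hash("venv/Lib/site-packages/accelerate/__init__.py") == expected)
```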
venv/Lib/site-packages/accelerate-1.6.0.dist-info/REQUESTED ADDED
File without changes (REQUESTED is a zero-byte marker file)
venv/Lib/site-packages/accelerate-1.6.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (75.1.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
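Editor's note: `Tag: py3-none-any` marks the wheel as pure Python (any Python 3 interpreter, no ABI, any platform), which is why the same artifact works in this Windows venv. A sketch of checking tag compatibility with the `packaging` library (itself one of accelerate's declared dependencies in METADATA above):

```python
# Sketch: confirm the wheel tag from the WHEEL file is supported here.
from packaging.tags import Tag, sys_tags

wheel_tag = Tag("py3", "none", "any")  # from the WHEEL file above
print(wheel_tag in set(sys_tags()))    # True on any current CPython
```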
venv/Lib/site-packages/accelerate-1.6.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,6 @@
+ [console_scripts]
+ accelerate = accelerate.commands.accelerate_cli:main
+ accelerate-config = accelerate.commands.config:main
+ accelerate-estimate-memory = accelerate.commands.estimate:main
+ accelerate-launch = accelerate.commands.launch:main
+ accelerate-merge-weights = accelerate.commands.merge:main
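Editor's note: these `[console_scripts]` entries are what generate the `accelerate*.exe` launchers listed at the top of RECORD; each maps a command name to a `module:function` callable. They can be enumerated at runtime with the standard library, for example:

```python
# Sketch: list the console scripts registered by the accelerate distribution.
from importlib.metadata import distribution

dist = distribution("accelerate")
for ep in dist.entry_points:
    if ep.group == "console_scripts":
        # e.g. accelerate-launch -> accelerate.commands.launch:main
        print(f"{ep.name} -> {ep.value}")
```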
venv/Lib/site-packages/accelerate-1.6.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ accelerate
venv/Lib/site-packages/accelerate/__init__.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ __version__ = "1.6.0"
+
+ from .accelerator import Accelerator
+ from .big_modeling import (
+     cpu_offload,
+     cpu_offload_with_hook,
+     disk_offload,
+     dispatch_model,
+     init_empty_weights,
+     init_on_device,
+     load_checkpoint_and_dispatch,
+ )
+ from .data_loader import skip_first_batches
+ from .inference import prepare_pippy
+ from .launchers import debug_launcher, notebook_launcher
+ from .state import PartialState
+ from .utils import (
+     AutocastKwargs,
+     DataLoaderConfiguration,
+     DDPCommunicationHookType,
+     DeepSpeedPlugin,
+     DistributedDataParallelKwargs,
+     DistributedType,
+     FullyShardedDataParallelPlugin,
+     GradScalerKwargs,
+     InitProcessGroupKwargs,
+     ProfileKwargs,
+     find_executable_batch_size,
+     infer_auto_device_map,
+     is_rich_available,
+     load_checkpoint_in_model,
+     synchronize_rng_states,
+ )
+
+
+ if is_rich_available():
+     from .utils import rich
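Editor's note: the trailing `if is_rich_available():` guard is the usual optional-dependency pattern; the `rich` helpers are only exported when `rich` is importable. A generic sketch of the same idea (the standalone helper below is illustrative, not accelerate's implementation, which lives in `accelerate/utils/imports.py`):

```python
# Generic optional-dependency guard, mirroring the bottom of __init__.py.
# This helper is a standalone illustration; accelerate ships its own check.
import importlib.util

def is_rich_available() -> bool:
    # True if `rich` could be imported, without actually importing it yet.
    return importlib.util.find_spec("rich") is not None

if is_rich_available():
    import rich  # only reached when the optional package is installed
```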
venv/Lib/site-packages/accelerate/accelerator.py ADDED
The diff for this file is too large to render. See raw diff
 
venv/Lib/site-packages/accelerate/big_modeling.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+ from contextlib import contextmanager
+ from functools import wraps
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+
+ from .hooks import (
+     AlignDevicesHook,
+     CpuOffload,
+     UserCpuOffloadHook,
+     add_hook_to_module,
+     attach_align_device_hook,
+     attach_align_device_hook_on_blocks,
+ )
+ from .utils import (
+     OffloadedWeightsLoader,
+     check_cuda_p2p_ib_support,
+     check_device_map,
+     extract_submodules_state_dict,
+     find_tied_parameters,
+     get_balanced_memory,
+     infer_auto_device_map,
+     is_bnb_available,
+     is_mlu_available,
+     is_musa_available,
+     is_npu_available,
+     is_sdaa_available,
+     is_xpu_available,
+     load_checkpoint_in_model,
+     offload_state_dict,
+     parse_flag_from_env,
+     retie_parameters,
+ )
+ from .utils.other import recursive_getattr
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @contextmanager
+ def init_empty_weights(include_buffers: bool = None):
+     """
+     A context manager under which models are initialized with all parameters on the meta device, therefore creating an
+     empty model. Useful when just initializing the model would blow the available RAM.
+
+     Args:
+         include_buffers (`bool`, *optional*):
+             Whether or not to also put all buffers on the meta device while initializing.
+
+     Example:
+
+     ```python
+     import torch.nn as nn
+     from accelerate import init_empty_weights
+
+     # Initialize a model with 100 billion parameters in no time and without using any RAM.
+     with init_empty_weights():
+         tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
+     ```
+
+     <Tip warning={true}>
+
+     Any model created under this context manager has no weights. As such you can't do something like
+     `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
+     Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
+     called.
+
+     </Tip>
+     """
+     if include_buffers is None:
+         include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+     with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+         yield f
+
+
+ @contextmanager
+ def init_on_device(device: torch.device, include_buffers: bool = None):
+     """
+     A context manager under which models are initialized with all parameters on the specified device.
+
+     Args:
+         device (`torch.device`):
+             Device to initialize all parameters on.
+         include_buffers (`bool`, *optional*):
+             Whether or not to also put all buffers on the meta device while initializing.
+
+     Example:
+
+     ```python
+     import torch.nn as nn
+     from accelerate import init_on_device
+
+     with init_on_device(device=torch.device("cuda")):
+         tst = nn.Linear(100, 100)  # on `cuda` device
+     ```
+     """
+     if include_buffers is None:
+         include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+
+     if include_buffers:
+         with device:
+             yield
+         return
+
+     old_register_parameter = nn.Module.register_parameter
+     if include_buffers:
+         old_register_buffer = nn.Module.register_buffer
+
+     def register_empty_parameter(module, name, param):
+         old_register_parameter(module, name, param)
+         if param is not None:
+             param_cls = type(module._parameters[name])
+             kwargs = module._parameters[name].__dict__
+             kwargs["requires_grad"] = param.requires_grad
+             module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+     def register_empty_buffer(module, name, buffer, persistent=True):
+         old_register_buffer(module, name, buffer, persistent=persistent)
+         if buffer is not None:
+             module._buffers[name] = module._buffers[name].to(device)
+
+     # Patch tensor creation
+     if include_buffers:
+         tensor_constructors_to_patch = {
+             torch_function_name: getattr(torch, torch_function_name)
+             for torch_function_name in ["empty", "zeros", "ones", "full"]
+         }
+     else:
+         tensor_constructors_to_patch = {}
+
+     def patch_tensor_constructor(fn):
+         def wrapper(*args, **kwargs):
+             kwargs["device"] = device
+             return fn(*args, **kwargs)
+
+         return wrapper
+
+     try:
+         nn.Module.register_parameter = register_empty_parameter
+         if include_buffers:
+             nn.Module.register_buffer = register_empty_buffer
+         for torch_function_name in tensor_constructors_to_patch.keys():
+             setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+         yield
+     finally:
+         nn.Module.register_parameter = old_register_parameter
+         if include_buffers:
+             nn.Module.register_buffer = old_register_buffer
+         for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+             setattr(torch, torch_function_name, old_torch_function)
+
+
+ def cpu_offload(
+     model: nn.Module,
+     execution_device: Optional[torch.device] = None,
+     offload_buffers: bool = False,
+     state_dict: Optional[dict[str, torch.Tensor]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+ ):
+     """
+     Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
+     copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
+     state dict and put on the execution device passed as they are needed, then offloaded again.
+
+     Args:
+         model (`torch.nn.Module`):
+             The model to offload.
+         execution_device (`torch.device`, *optional*):
+             The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+             model's first parameter device.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to offload the buffers with the model parameters.
+         state_dict (`Dict[str, torch.Tensor]`, *optional*):
+             The state dict of the model that will be kept on CPU.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+     """
+     if execution_device is None:
+         execution_device = next(iter(model.parameters())).device
+     if state_dict is None:
+         state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
+
+     add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+     attach_align_device_hook(
+         model,
+         execution_device=execution_device,
+         offload=True,
+         offload_buffers=offload_buffers,
+         weights_map=state_dict,
+         preload_module_classes=preload_module_classes,
+     )
+
+     return model
+
+
+ def cpu_offload_with_hook(
+     model: torch.nn.Module,
+     execution_device: Optional[Union[int, str, torch.device]] = None,
+     prev_module_hook: Optional[UserCpuOffloadHook] = None,
+ ):
+     """
+     Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
+     [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
+     the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
+
+     Args:
+         model (`torch.nn.Module`):
+             The model to offload.
+         execution_device(`str`, `int` or `torch.device`, *optional*):
+             The device on which the model should be executed. Will default to the MPS device if it's available, then
+             GPU 0 if there is a GPU, and finally to the CPU.
+         prev_module_hook (`UserCpuOffloadHook`, *optional*):
+             The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
+             offload method will be called just before the forward of the model to which this hook is attached.
+
+     Example:
+
+     ```py
+     model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
+     model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
+     model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
+
+     hid_1 = model_1(input)
+     for i in range(50):
+         # model1 is offloaded to the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
+         hid_2 = model_2(hid_1)
+     # model2 is offloaded to the CPU just before this forward.
+     hid_3 = model_3(hid_2)
+
+     # For model3, you need to manually call the hook offload method.
+     hook_3.offload()
+     ```
+     """
+     hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
+     add_hook_to_module(model, hook, append=True)
+     user_hook = UserCpuOffloadHook(model, hook)
+     return model, user_hook
+
+
+ def disk_offload(
+     model: nn.Module,
+     offload_dir: Union[str, os.PathLike],
+     execution_device: Optional[torch.device] = None,
+     offload_buffers: bool = False,
+     preload_module_classes: Optional[list[str]] = None,
+ ):
+     """
+     Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
+     memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
+     put on the execution device passed as they are needed, then offloaded again.
+
+     Args:
+         model (`torch.nn.Module`): The model to offload.
+         offload_dir (`str` or `os.PathLike`):
+             The folder in which to offload the model weights (or where the model weights are already offloaded).
+         execution_device (`torch.device`, *optional*):
+             The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+             model's first parameter device.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to offload the buffers with the model parameters.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+     """
+     if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
+         offload_state_dict(offload_dir, model.state_dict())
+     if execution_device is None:
+         execution_device = next(iter(model.parameters())).device
+     weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
+
+     add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+     attach_align_device_hook(
+         model,
+         execution_device=execution_device,
+         offload=True,
+         offload_buffers=offload_buffers,
+         weights_map=weights_map,
+         preload_module_classes=preload_module_classes,
+     )
+
+     return model
+
+
+ def dispatch_model(
+     model: nn.Module,
+     device_map: dict[str, Union[str, int, torch.device]],
+     main_device: Optional[torch.device] = None,
+     state_dict: Optional[dict[str, torch.Tensor]] = None,
+     offload_dir: Optional[Union[str, os.PathLike]] = None,
+     offload_index: Optional[dict[str, str]] = None,
+     offload_buffers: bool = False,
+     skip_keys: Optional[Union[str, list[str]]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+     force_hooks: bool = False,
+ ):
+     """
+     Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
+     the CPU or even the disk.
+
+     Args:
+         model (`torch.nn.Module`):
+             The model to dispatch.
+         device_map (`Dict[str, Union[str, int, torch.device]]`):
+             A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
+             `"disk"` is accepted even if it's not a proper value for `torch.device`.
+         main_device (`str`, `int` or `torch.device`, *optional*):
+             The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
+             `"disk"`.
+         state_dict (`Dict[str, torch.Tensor]`, *optional*):
+             The state dict of the part of the model that will be kept on CPU.
+         offload_dir (`str` or `os.PathLike`):
+             The folder in which to offload the model weights (or where the model weights are already offloaded).
+         offload_index (`Dict`, *optional*):
+             A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
+             to the index saved in `save_folder`.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to offload the buffers with the model parameters.
+         skip_keys (`str` or `List[str]`, *optional*):
+             A list of keys to ignore when moving inputs or outputs between devices.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+         force_hooks (`bool`, *optional*, defaults to `False`):
+             Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+             single device.
+     """
+     # Error early if the device map is incomplete.
+     check_device_map(model, device_map)
+
+     # We need to force hook for quantized model that can't be moved with to()
+     if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
+         # since bnb 0.43.2, we can move 4-bit model
+         if getattr(model, "is_loaded_in_8bit", False) or (
+             getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
+         ):
+             force_hooks = True
+
+     # We attach hooks if the device_map has at least 2 different devices or if
+     # force_hooks is set to `True`. Otherwise, the model is already loaded
+     # on the unique device and the user can decide where to dispatch the model.
+     # If the model is quantized, we always force-dispatch the model
+     if (len(set(device_map.values())) > 1) or force_hooks:
+         if main_device is None:
+             if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+                 main_device = "cpu"
+             else:
+                 main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+         if main_device != "cpu":
+             cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
+             if state_dict is None and len(cpu_modules) > 0:
+                 state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
+
+         disk_modules = [name for name, device in device_map.items() if device == "disk"]
+         if offload_dir is None and offload_index is None and len(disk_modules) > 0:
+             raise ValueError(
+                 "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
+                 f"need to be offloaded: {', '.join(disk_modules)}."
+             )
+         if (
+             len(disk_modules) > 0
+             and offload_index is None
+             and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
+         ):
+             disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
+             offload_state_dict(offload_dir, disk_state_dict)
+
+         execution_device = {
+             name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
+         }
+         execution_device[""] = main_device
+         offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+         offload = {name: device in offloaded_devices for name, device in device_map.items()}
+         save_folder = offload_dir if len(disk_modules) > 0 else None
+         if state_dict is not None or save_folder is not None or offload_index is not None:
+             device = main_device if offload_index is not None else None
+             weights_map = OffloadedWeightsLoader(
+                 state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
+             )
+         else:
+             weights_map = None
+
+         # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
+         # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
+         # original pointer) on each device.
+         tied_params = find_tied_parameters(model)
+
+         tied_params_map = {}
+         for group in tied_params:
+             for param_name in group:
+                 # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
+                 # to care about views of tensors through storage_offset.
+                 data_ptr = recursive_getattr(model, param_name).data_ptr()
+                 tied_params_map[data_ptr] = {}
+
+         # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
+         # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+
+         attach_align_device_hook_on_blocks(
+             model,
+             execution_device=execution_device,
+             offload=offload,
+             offload_buffers=offload_buffers,
+             weights_map=weights_map,
+             skip_keys=skip_keys,
+             preload_module_classes=preload_module_classes,
+             tied_params_map=tied_params_map,
+         )
+
+         # warn if there is any params on the meta device
+         offloaded_devices_str = " and ".join(
+             [device for device in set(device_map.values()) if device in ("cpu", "disk")]
+         )
+         if len(offloaded_devices_str) > 0:
+             logger.warning(
+                 f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
+             )
+
+         # Attaching the hook may break tied weights, so we retie them
+         retie_parameters(model, tied_params)
+
+         # add warning to cuda and to method
+         def add_warning(fn, model):
+             @wraps(fn)
+             def wrapper(*args, **kwargs):
+                 warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
+                 if str(fn.__name__) == "to":
+                     to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
+                     if to_device is not None:
+                         logger.warning(warning_msg)
+                 else:
+                     logger.warning(warning_msg)
+                 for param in model.parameters():
+                     if param.device == torch.device("meta"):
+                         raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
+                 return fn(*args, **kwargs)
+
+             return wrapper
+
+         # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
+         model.to = add_warning(model.to, model)
+         if is_npu_available():
+             model.npu = add_warning(model.npu, model)
+         elif is_mlu_available():
+             model.mlu = add_warning(model.mlu, model)
+         elif is_sdaa_available():
+             model.sdaa = add_warning(model.sdaa, model)
+         elif is_musa_available():
+             model.musa = add_warning(model.musa, model)
+         elif is_xpu_available():
+             model.xpu = add_warning(model.xpu, model)
+         else:
+             model.cuda = add_warning(model.cuda, model)
+
+         # Check if we are using multi-gpus with RTX 4000 series
+         use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
+         if use_multi_gpu and not check_cuda_p2p_ib_support():
+             logger.warning(
+                 "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
+                 "This can affect the multi-gpu inference when using accelerate device_map. "
+                 "Please make sure to update your driver to the latest version which resolves this."
+             )
+     else:
+         device = list(device_map.values())[0]
+         # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+         if is_npu_available() and isinstance(device, int):
+             device = f"npu:{device}"
+         elif is_mlu_available() and isinstance(device, int):
+             device = f"mlu:{device}"
+         elif is_sdaa_available() and isinstance(device, int):
+             device = f"sdaa:{device}"
+         elif is_musa_available() and isinstance(device, int):
+             device = f"musa:{device}"
+         if device != "disk":
+             model.to(device)
+         else:
+             raise ValueError(
+                 "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
+             )
+     # Convert OrderedDict back to dict for easier usage
+     model.hf_device_map = dict(device_map)
+     return model
+
+
+ def load_checkpoint_and_dispatch(
+     model: nn.Module,
+     checkpoint: Union[str, os.PathLike],
+     device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,
+     max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
+     no_split_module_classes: Optional[list[str]] = None,
+     offload_folder: Optional[Union[str, os.PathLike]] = None,
+     offload_buffers: bool = False,
+     dtype: Optional[Union[str, torch.dtype]] = None,
+     offload_state_dict: Optional[bool] = None,
+     skip_keys: Optional[Union[str, list[str]]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+     force_hooks: bool = False,
+     strict: bool = False,
+ ):
+     """
+     Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
+     loaded and adds the various hooks that will make this model run properly (even if split across devices).
+
+     Args:
+         model (`torch.nn.Module`): The model in which we want to load a checkpoint.
+         checkpoint (`str` or `os.PathLike`):
+             The folder checkpoint to load. It can be:
+             - a path to a file containing a whole model state dict
+             - a path to a `.json` file containing the index to a sharded checkpoint
+             - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
+         device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
+             A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
+             name, once a given module name is inside, every submodule of it will be sent to the same device.
+
+             To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
+             information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
+             Defaults to None, which means [`dispatch_model`] will not be called.
+         max_memory (`Dict`, *optional*):
+             A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
+             and the available CPU RAM if unset.
+         no_split_module_classes (`List[str]`, *optional*):
+             A list of layer class names that should never be split across device (for instance any layer that has a
+             residual connection).
+         offload_folder (`str` or `os.PathLike`, *optional*):
+             If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
+             well as the parameters.
+         dtype (`str` or `torch.dtype`, *optional*):
+             If provided, the weights will be converted to that type when loaded.
+         offload_state_dict (`bool`, *optional*):
+             If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
+             the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
+             picked contains `"disk"` values.
+         skip_keys (`str` or `List[str]`, *optional*):
+             A list of keys to ignore when moving inputs or outputs between devices.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+         force_hooks (`bool`, *optional*, defaults to `False`):
+             Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+             single device.
+         strict (`bool`, *optional*, defaults to `False`):
+             Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
+             state_dict.
+
+     Example:
+
+     ```python
+     >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+     >>> from huggingface_hub import hf_hub_download
+     >>> from transformers import AutoConfig, AutoModelForCausalLM
+
+     >>> # Download the Weights
+     >>> checkpoint = "EleutherAI/gpt-j-6B"
+     >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+
+     >>> # Create a model and initialize it with empty weights
+     >>> config = AutoConfig.from_pretrained(checkpoint)
+     >>> with init_empty_weights():
+     ...     model = AutoModelForCausalLM.from_config(config)
+
+     >>> # Load the checkpoint and dispatch it to the right devices
+     >>> model = load_checkpoint_and_dispatch(
+     ...     model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
+     ... )
+     ```
+     """
+     if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+         raise ValueError(
+             "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'."
+         )
+     if isinstance(device_map, str):
+         if device_map != "sequential":
+             max_memory = get_balanced_memory(
+                 model,
+                 max_memory=max_memory,
+                 no_split_module_classes=no_split_module_classes,
+                 dtype=dtype,
+                 low_zero=(device_map == "balanced_low_0"),
+             )
+         device_map = infer_auto_device_map(
+             model,
+             max_memory=max_memory,
+             no_split_module_classes=no_split_module_classes,
+             dtype=dtype,
+             offload_buffers=offload_buffers,
+         )
+     if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
+         offload_state_dict = True
+     load_checkpoint_in_model(
+         model,
+         checkpoint,
+         device_map=device_map,
+         offload_folder=offload_folder,
+         dtype=dtype,
+         offload_state_dict=offload_state_dict,
+         offload_buffers=offload_buffers,
+         strict=strict,
+     )
+     if device_map is None:
+         return model
+     return dispatch_model(
+         model,
+         device_map=device_map,
+         offload_dir=offload_folder,
+         offload_buffers=offload_buffers,
+         skip_keys=skip_keys,
+         preload_module_classes=preload_module_classes,
+         force_hooks=force_hooks,
+     )
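
The docstrings above already show `load_checkpoint_and_dispatch` with `device_map="auto"`. As a complementary sketch (not part of the diff), a hand-written device map can pin specific submodules; anything mapped to `"cpu"` or `"disk"` is then served by the offload hooks. The two-layer model and the map below are made-up assumptions, and one CUDA GPU is assumed to be available:

```python
# A minimal sketch of dispatch_model with an explicit device_map.
# The toy model is a made-up stand-in; keys in device_map must match
# names from model.named_modules(). Assumes a CUDA GPU is present.
import torch.nn as nn

from accelerate import dispatch_model

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Layer "0" runs on GPU 0; layer "1" stays on the CPU, so the attached
# hooks move its weights and inputs at forward time.
model = dispatch_model(model, device_map={"0": 0, "1": "cpu"})
print(model.hf_device_map)  # {'0': 0, '1': 'cpu'}
```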
venv/Lib/site-packages/accelerate/checkpointing.py ADDED
@@ -0,0 +1,319 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from safetensors.torch import load_model
+ from torch.cuda.amp import GradScaler
+
+ from .utils import (
+     MODEL_NAME,
+     OPTIMIZER_NAME,
+     RNG_STATE_NAME,
+     SAFE_MODEL_NAME,
+     SAFE_WEIGHTS_NAME,
+     SAMPLER_NAME,
+     SCALER_NAME,
+     SCHEDULER_NAME,
+     WEIGHTS_NAME,
+     get_pretty_name,
+     is_cuda_available,
+     is_hpu_available,
+     is_mlu_available,
+     is_musa_available,
+     is_sdaa_available,
+     is_torch_xla_available,
+     is_xpu_available,
+     load,
+     save,
+ )
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+ from .logging import get_logger
+ from .state import PartialState
+
+
+ logger = get_logger(__name__)
+
+
+ def save_accelerator_state(
+     output_dir: str,
+     model_states: list[dict],
+     optimizers: list,
+     schedulers: list,
+     dataloaders: list,
+     process_index: int,
+     step: int,
+     scaler: GradScaler = None,
+     save_on_each_node: bool = False,
+     safe_serialization: bool = True,
+ ):
+     """
+     Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
+
+     <Tip>
+
+     If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
+     `pickle`.
+
+     </Tip>
+
+     Args:
+         output_dir (`str` or `os.PathLike`):
+             The name of the folder to save all relevant weights and states.
+         model_states (`List[torch.nn.Module]`):
+             A list of model states
+         optimizers (`List[torch.optim.Optimizer]`):
+             A list of optimizer instances
+         schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
+             A list of learning rate schedulers
+         dataloaders (`List[torch.utils.data.DataLoader]`):
+             A list of dataloader instances to save their sampler states
+         process_index (`int`):
+             The current process index in the Accelerator state
+         step (`int`):
+             The current step in the internal step tracker
+         scaler (`torch.amp.GradScaler`, *optional*):
+             An optional gradient scaler instance to save.
+         save_on_each_node (`bool`, *optional*):
+             Whether to save on every node, or only the main node.
+         safe_serialization (`bool`, *optional*, defaults to `True`):
+             Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+     """
+     output_dir = Path(output_dir)
+     # Model states
+     for i, state in enumerate(model_states):
+         weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
+         if i > 0:
+             weights_name = weights_name.replace(".", f"_{i}.")
+         output_model_file = output_dir.joinpath(weights_name)
+         save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
+         logger.info(f"Model weights saved in {output_model_file}")
+     # Optimizer states
+     for i, opt in enumerate(optimizers):
+         state = opt.state_dict()
+         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
+         output_optimizer_file = output_dir.joinpath(optimizer_name)
+         save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+         logger.info(f"Optimizer state saved in {output_optimizer_file}")
+     # Scheduler states
+     for i, scheduler in enumerate(schedulers):
+         state = scheduler.state_dict()
+         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
+         output_scheduler_file = output_dir.joinpath(scheduler_name)
+         save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+         logger.info(f"Scheduler state saved in {output_scheduler_file}")
+     # DataLoader states
+     for i, dataloader in enumerate(dataloaders):
+         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+         output_sampler_file = output_dir.joinpath(sampler_name)
+         # Only save if we have our custom sampler
+         from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+         if isinstance(dataloader.dataset, IterableDatasetShard):
+             sampler = dataloader.get_sampler()
+             if isinstance(sampler, SeedableRandomSampler):
+                 save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+         if getattr(dataloader, "use_stateful_dataloader", False):
+             dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+             output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
+             state_dict = dataloader.state_dict()
+             torch.save(state_dict, output_dataloader_state_dict_file)
+         logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
+
+     # GradScaler state
+     if scaler is not None:
+         state = scaler.state_dict()
+         output_scaler_file = output_dir.joinpath(SCALER_NAME)
+         torch.save(state, output_scaler_file)
+         logger.info(f"Gradient scaler state saved in {output_scaler_file}")
+     # Random number generator states
+     states = {}
+     states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
+     states["step"] = step
+     states["random_state"] = random.getstate()
+     states["numpy_random_seed"] = np.random.get_state()
+     states["torch_manual_seed"] = torch.get_rng_state()
+     if is_xpu_available():
+         states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+     if is_mlu_available():
+         states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
+     elif is_sdaa_available():
+         states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
+     elif is_musa_available():
+         states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
+     if is_hpu_available():
+         states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
+     if is_cuda_available():
+         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
+     if is_torch_xla_available():
+         states["xm_seed"] = xm.get_rng_state()
+     output_states_file = output_dir.joinpath(states_name)
+     torch.save(states, output_states_file)
+     logger.info(f"Random states saved in {output_states_file}")
+     return output_dir
+
+
+ def load_accelerator_state(
+     input_dir,
+     models,
+     optimizers,
+     schedulers,
+     dataloaders,
+     process_index,
+     scaler=None,
+     map_location=None,
+     **load_model_func_kwargs,
+ ):
+     """
+     Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
+
+     Args:
+         input_dir (`str` or `os.PathLike`):
+             The name of the folder to load all relevant weights and states.
+         models (`List[torch.nn.Module]`):
+             A list of model instances
+         optimizers (`List[torch.optim.Optimizer]`):
+             A list of optimizer instances
+         schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
+             A list of learning rate schedulers
+         dataloaders (`List[torch.utils.data.DataLoader]`):
+             A list of dataloader instances to restore their sampler states
+         process_index (`int`):
+             The current process index in the Accelerator state
+         scaler (`torch.amp.GradScaler`, *optional*):
+             An optional *GradScaler* instance to load
+         map_location (`str`, *optional*):
+             What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
+         load_model_func_kwargs (`dict`, *optional*):
+             Additional arguments that can be passed to the model's `load_state_dict` method.
+
+     Returns:
+         `dict`: Contains the `Accelerator` attributes to override while loading the state.
+     """
+     # stores the `Accelerator` attributes to override
+     override_attributes = dict()
+     if map_location not in [None, "cpu", "on_device"]:
+         raise TypeError(
+             "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
+         )
+     if map_location is None:
+         map_location = "cpu"
+     elif map_location == "on_device":
+         map_location = PartialState().device
+
+     input_dir = Path(input_dir)
+     # Model states
+     for i, model in enumerate(models):
+         ending = f"_{i}" if i > 0 else ""
+         input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
+         if input_model_file.exists():
+             load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
+         else:
+             # Load with torch
+             input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
+             state_dict = load(input_model_file, map_location=map_location)
+             model.load_state_dict(state_dict, **load_model_func_kwargs)
+     logger.info("All model weights loaded successfully")
+
+     # Optimizer states
+     for i, opt in enumerate(optimizers):
+         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
+         input_optimizer_file = input_dir.joinpath(optimizer_name)
+         optimizer_state = load(input_optimizer_file, map_location=map_location)
+         optimizers[i].load_state_dict(optimizer_state)
+     logger.info("All optimizer states loaded successfully")
+
+     # Scheduler states
+     for i, scheduler in enumerate(schedulers):
+         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
+         input_scheduler_file = input_dir.joinpath(scheduler_name)
+         scheduler_state = load(input_scheduler_file)
+         scheduler.load_state_dict(scheduler_state)
+     logger.info("All scheduler states loaded successfully")
+
+     for i, dataloader in enumerate(dataloaders):
+         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+         input_sampler_file = input_dir.joinpath(sampler_name)
+         # Only load if we have our custom sampler
+         from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+         if isinstance(dataloader.dataset, IterableDatasetShard):
+             sampler = dataloader.get_sampler()
+             if isinstance(sampler, SeedableRandomSampler):
+                 sampler = dataloader.set_sampler(load(input_sampler_file))
+         if getattr(dataloader, "use_stateful_dataloader", False):
+             dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+             input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
+             if input_dataloader_state_dict_file.exists():
+                 state_dict = load(input_dataloader_state_dict_file)
+                 dataloader.load_state_dict(state_dict)
+     logger.info("All dataloader sampler states loaded successfully")
+
+     # GradScaler state
+     if scaler is not None:
+         input_scaler_file = input_dir.joinpath(SCALER_NAME)
+         scaler_state = load(input_scaler_file)
+         scaler.load_state_dict(scaler_state)
+         logger.info("GradScaler state loaded successfully")
+
+     # Random states
+     try:
+         states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
+         if "step" in states:
+             override_attributes["step"] = states["step"]
+         random.setstate(states["random_state"])
+         np.random.set_state(states["numpy_random_seed"])
+         torch.set_rng_state(states["torch_manual_seed"])
+         if is_xpu_available():
+             torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+         if is_mlu_available():
+             torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
+         elif is_sdaa_available():
+             torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
+         elif is_musa_available():
+             torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
+         else:
+             torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
+         if is_torch_xla_available():
+             xm.set_rng_state(states["xm_seed"])
+         logger.info("All random states loaded successfully")
+     except Exception:
+         logger.info("Could not load random states")
+
+     return override_attributes
+
+
+ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
+     """
+     Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
+     """
+     # Should this be the right way to get a qual_name type value from `obj`?
+     save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
+     logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
+     save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
+
+
+ def load_custom_state(obj, path, index: int = 0):
+     """
+     Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
+     loading the state.
+     """
+     load_location = f"{path}/custom_checkpoint_{index}.pkl"
+     logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
+     obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
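
These helpers are what back the public checkpointing entry points on `Accelerator`. As a quick round-trip sketch (not part of the diff; the tiny model and the `"ckpt"` directory name are illustrative assumptions):

```python
# A minimal sketch of a checkpoint round trip via Accelerator.save_state /
# Accelerator.load_state, which route through the helpers above.
import torch
import torch.nn as nn

from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("ckpt")  # model/optimizer/RNG states land in ./ckpt
accelerator.load_state("ckpt")  # restores them, per-process RNG included
```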
venv/Lib/site-packages/accelerate/data_loader.py ADDED
@@ -0,0 +1,1429 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import math
17
+ from contextlib import suppress
18
+ from typing import Callable, Optional, Union
19
+
20
+ import torch
21
+ from packaging import version
22
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
23
+
24
+ from .logging import get_logger
25
+ from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
26
+ from .utils import (
27
+ RNGType,
28
+ broadcast,
29
+ broadcast_object_list,
30
+ compare_versions,
31
+ concatenate,
32
+ find_batch_size,
33
+ get_data_structure,
34
+ initialize_tensors,
35
+ is_torch_version,
36
+ is_torchdata_stateful_dataloader_available,
37
+ send_to_device,
38
+ slice_tensors,
39
+ synchronize_rng_states,
40
+ )
41
+
42
+
43
+ logger = get_logger(__name__)
44
+
45
+ # kwargs of the DataLoader in min version 2.0
46
+ _PYTORCH_DATALOADER_KWARGS = {
47
+ "batch_size": 1,
48
+ "shuffle": False,
49
+ "sampler": None,
50
+ "batch_sampler": None,
51
+ "num_workers": 0,
52
+ "collate_fn": None,
53
+ "pin_memory": False,
54
+ "drop_last": False,
55
+ "timeout": 0,
56
+ "worker_init_fn": None,
57
+ "multiprocessing_context": None,
58
+ "generator": None,
59
+ "prefetch_factor": 2,
60
+ "persistent_workers": False,
61
+ "pin_memory_device": "",
62
+ }
63
+
64
+ # kwargs added after by version
65
+ _PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}
66
+
67
+ for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
68
+ if is_torch_version(">=", v):
69
+ _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
70
+
71
+
72
+ class SeedableRandomSampler(RandomSampler):
73
+ """
74
+ Same as a random sampler, except that in `__iter__` a seed can be used.
75
+
76
+ Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
77
+ and be fully reproducable on multiple iterations.
78
+
79
+ If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
80
+ (stored in `self.epoch`).
81
+ """
82
+
83
+ def __init__(self, *args, **kwargs):
84
+ data_seed = kwargs.pop("data_seed", None)
85
+ super().__init__(*args, **kwargs)
86
+
87
+ self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
88
+ self.epoch = 0
89
+
90
+ def __iter__(self):
91
+ if self.generator is None:
92
+ self.generator = torch.Generator(
93
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
94
+ )
95
+ self.generator.manual_seed(self.initial_seed)
96
+
97
+ # Allow `self.epoch` to modify the seed of the generator
98
+ seed = self.epoch + self.initial_seed
99
+ # print("Setting seed at epoch", self.epoch, seed)
100
+ self.generator.manual_seed(seed)
101
+ yield from super().__iter__()
102
+ self.set_epoch(self.epoch + 1)
103
+
104
+ def set_epoch(self, epoch: int):
105
+ "Sets the current iteration of the sampler."
106
+ self.epoch = epoch
107
+
108
+
109
+ class BatchSamplerShard(BatchSampler):
110
+ """
111
+ Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
112
+ always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
113
+ Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
114
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
115
+
116
+ Args:
117
+ batch_sampler (`torch.utils.data.sampler.BatchSampler`):
118
+ The batch sampler to split in several shards.
119
+ num_processes (`int`, *optional*, defaults to 1):
120
+ The number of processes running concurrently.
121
+ process_index (`int`, *optional*, defaults to 0):
122
+ The index of the current process.
123
+ split_batches (`bool`, *optional*, defaults to `False`):
124
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
125
+ yielding different full batches on each process.
126
+
127
+ On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
128
+
129
+ - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
130
+ this argument is set to `False`.
131
+ - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
132
+ then `[6, 7]` if this argument is set to `True`.
133
+ even_batches (`bool`, *optional*, defaults to `True`):
134
+ Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
135
+ multiple of (original batch size / number of processes).
136
+
137
+ <Tip warning={true}>
138
+
139
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
140
+ equal to `False`
141
+
142
+ </Tip>"""
143
+
144
+ def __init__(
145
+ self,
146
+ batch_sampler: BatchSampler,
147
+ num_processes: int = 1,
148
+ process_index: int = 0,
149
+ split_batches: bool = False,
150
+ even_batches: bool = True,
151
+ ):
152
+ if split_batches and batch_sampler.batch_size % num_processes != 0:
153
+ raise ValueError(
154
+ f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
155
+ f"needs to be a round multiple of the number of processes ({num_processes})."
156
+ )
157
+ self.batch_sampler = batch_sampler
158
+ self.num_processes = num_processes
159
+ self.process_index = process_index
160
+ self.split_batches = split_batches
161
+ self.even_batches = even_batches
162
+ self.batch_size = getattr(batch_sampler, "batch_size", None)
163
+ self.drop_last = getattr(batch_sampler, "drop_last", False)
164
+ if self.batch_size is None and self.even_batches:
165
+ raise ValueError(
166
+ "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
167
+ "are not calling this method directly, set `accelerator.even_batches=False` instead."
168
+ )
169
+
170
+ @property
171
+ def total_length(self):
172
+ return len(self.batch_sampler)
173
+
174
+ def __len__(self):
175
+ if self.split_batches:
176
+ # Split batches does not change the length of the batch sampler
177
+ return len(self.batch_sampler)
178
+ if len(self.batch_sampler) % self.num_processes == 0:
179
+ # If the length is a round multiple of the number of processes, it's easy.
180
+ return len(self.batch_sampler) // self.num_processes
181
+ length = len(self.batch_sampler) // self.num_processes
182
+ if self.drop_last:
183
+ # Same if we drop the remainder.
184
+ return length
185
+ elif self.even_batches:
186
+ # When we even batches we always get +1
187
+ return length + 1
188
+ else:
189
+ # Otherwise it depends on the process index.
190
+ return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
191
+
192
+ def __iter__(self):
193
+ return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
194
+
195
+ def _iter_with_split(self):
196
+ initial_data = []
197
+ batch_length = self.batch_sampler.batch_size // self.num_processes
198
+ for idx, batch in enumerate(self.batch_sampler):
199
+ if idx == 0:
200
+ initial_data = batch
201
+ if len(batch) == self.batch_size:
202
+ # If the batch is full, we yield the part of it this process is responsible of.
203
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
204
+
205
+ # If drop_last is True of the last batch was full, iteration is over, otherwise...
206
+ if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
207
+ if not self.even_batches:
208
+ if len(batch) > batch_length * self.process_index:
209
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
210
+ else:
211
+ # For degenerate cases where the dataset has less than num_process * batch_size samples
212
+ while len(initial_data) < self.batch_size:
213
+ initial_data += initial_data
214
+ batch = batch + initial_data
215
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
216
+
217
+ def _iter_with_no_split(self):
218
+ initial_data = []
219
+ batch_to_yield = []
220
+ for idx, batch in enumerate(self.batch_sampler):
221
+ # We gather the initial indices in case we need to circle back at the end.
222
+ if not self.drop_last and idx < self.num_processes:
223
+ initial_data += batch
224
+ # We identify the batch to yield but wait until we ar sure every process gets a full batch before actually
225
+ # yielding it.
226
+ if idx % self.num_processes == self.process_index:
227
+ batch_to_yield = batch
228
+ if idx % self.num_processes == self.num_processes - 1 and (
229
+ self.batch_size is None or len(batch) == self.batch_size
230
+ ):
231
+ yield batch_to_yield
232
+ batch_to_yield = []
233
+
234
+ # If drop_last is True, iteration is over, otherwise...
235
+ if not self.drop_last and len(initial_data) > 0:
236
+ if not self.even_batches:
237
+ if len(batch_to_yield) > 0:
238
+ yield batch_to_yield
239
+ else:
240
+ # ... we yield the complete batch we had saved before if it has the proper length
241
+ if len(batch_to_yield) == self.batch_size:
242
+ yield batch_to_yield
243
+
244
+ # For degenerate cases where the dataset has less than num_process * batch_size samples
245
+ while len(initial_data) < self.num_processes * self.batch_size:
246
+ initial_data += initial_data
247
+
248
+ # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
249
+ if len(batch) == self.batch_size:
250
+ batch = []
251
+ idx += 1
252
+
253
+ # Make sure we yield a multiple of self.num_processes batches
254
+ cycle_index = 0
255
+ while idx % self.num_processes != 0 or len(batch) > 0:
256
+ end_index = cycle_index + self.batch_size - len(batch)
257
+ batch += initial_data[cycle_index:end_index]
258
+ if idx % self.num_processes == self.process_index:
259
+ yield batch
260
+ cycle_index = end_index
261
+ batch = []
262
+ idx += 1
263
+
264
+
265
+ class IterableDatasetShard(IterableDataset):
266
+ """
267
+ Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
268
+ always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
269
+ `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
270
+ `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
271
+ be too small or loop with indices from the beginning.
272
+
273
+ Args:
274
+ dataset (`torch.utils.data.dataset.IterableDataset`):
275
+ The iterable dataset to split in several shards.
276
+ batch_size (`int`, *optional*, defaults to 1):
277
+ The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
278
+ `split_batches=True`).
279
+ drop_last (`bool`, *optional*, defaults to `False`):
280
+ Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
281
+ beginning.
282
+ num_processes (`int`, *optional*, defaults to 1):
283
+ The number of processes running concurrently.
284
+ process_index (`int`, *optional*, defaults to 0):
285
+ The index of the current process.
286
+ split_batches (`bool`, *optional*, defaults to `False`):
287
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
288
+ yielding different full batches on each process.
289
+
290
+ On two processes with a batch size of 4 and an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:
291
+
292
+ - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
293
+ argument is set to `False`.
294
+ - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
295
+ this argument is set to `True`.
296
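+
+ Example (a minimal illustrative sketch, not part of the original documentation; `RangeDataset` is a
+ hypothetical toy dataset assumed here for demonstration):
+
+ ```python
+ >>> from torch.utils.data import IterableDataset
+
+ >>> class RangeDataset(IterableDataset):  # hypothetical toy dataset yielding 0..7
+ ...     def __iter__(self):
+ ...         yield from range(8)
+
+ >>> # With split_batches=False, process 0 receives the first half of each global batch of 8 samples.
+ >>> shard = IterableDatasetShard(RangeDataset(), batch_size=4, num_processes=2, process_index=0)
+ >>> list(shard)
+ [0, 1, 2, 3]
+ ```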
+ """
297
+
298
+ def __init__(
299
+ self,
300
+ dataset: IterableDataset,
301
+ batch_size: int = 1,
302
+ drop_last: bool = False,
303
+ num_processes: int = 1,
304
+ process_index: int = 0,
305
+ split_batches: bool = False,
306
+ ):
307
+ if split_batches and batch_size > 1 and batch_size % num_processes != 0:
308
+ raise ValueError(
309
+ f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
310
+ f"needs to be a round multiple of the number of processes ({num_processes})."
311
+ )
312
+ self.dataset = dataset
313
+ self.batch_size = batch_size
314
+ self.drop_last = drop_last
315
+ self.num_processes = num_processes
316
+ self.process_index = process_index
317
+ self.split_batches = split_batches
318
+
319
+ def set_epoch(self, epoch):
320
+ self.epoch = epoch
321
+ if hasattr(self.dataset, "set_epoch"):
322
+ self.dataset.set_epoch(epoch)
323
+
324
+ def __len__(self):
325
+ # We will just raise the downstream error if the underlying dataset is not sized
326
+ if self.drop_last:
327
+ return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
328
+ else:
329
+ return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
330
+
331
+ def __iter__(self):
332
+ if (
333
+ not hasattr(self.dataset, "set_epoch")
334
+ and hasattr(self.dataset, "generator")
335
+ and isinstance(self.dataset.generator, torch.Generator)
336
+ ):
337
+ self.dataset.generator.manual_seed(self.epoch)
338
+ real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
339
+ process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
340
+ process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
341
+
342
+ first_batch = None
343
+ current_batch = []
344
+ for element in self.dataset:
345
+ current_batch.append(element)
346
+ # Wait to have a full batch before yielding elements.
347
+ if len(current_batch) == real_batch_size:
348
+ for i in process_slice:
349
+ yield current_batch[i]
350
+ if first_batch is None:
351
+ first_batch = current_batch.copy()
352
+ current_batch = []
353
+
354
+ # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
355
+ if not self.drop_last and len(current_batch) > 0:
356
+ if first_batch is None:
357
+ first_batch = current_batch.copy()
358
+ while len(current_batch) < real_batch_size:
359
+ current_batch += first_batch
360
+ for i in process_slice:
361
+ yield current_batch[i]
362
+
363
+
364
+ class DataLoaderStateMixin:
365
+ """
366
+ Mixin class that adds state to a `DataLoader` to keep track of its status, such as whether the end of the
367
+ iteration has been reached and the number of items in the last batch relative to the batch size, plus other
368
+ useful information that might be needed.
369
+
370
+ **Available attributes:**
371
+
372
+ - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
373
+ - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
374
+ batch size
375
+
376
+ <Tip warning={true}>
377
+
378
+ Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
379
+ `self.gradient_state`.
380
+
381
+ </Tip>
382
+
383
+ """
384
+
385
+ def __init_subclass__(cls, **kwargs):
386
+ cls.end_of_dataloader = False
387
+ cls.remainder = -1
388
+
389
+ def reset(self):
390
+ self.end_of_dataloader = False
391
+ self.remainder = -1
392
+
393
+ def begin(self):
394
+ "Prepares the gradient state for the current dataloader"
395
+ self.reset()
396
+ with suppress(Exception):
397
+ if not self._drop_last:
398
+ length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
399
+ self.remainder = length % self.total_batch_size
400
+ self.gradient_state._add_dataloader(self)
401
+
402
+ def end(self):
403
+ "Cleans up the gradient state after exiting the dataloader"
404
+ self.gradient_state._remove_dataloader(self)
405
+
406
+
407
+ class DataLoaderAdapter:
408
+ """
409
+ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
410
+ compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
411
+ """
412
+
413
+ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
414
+ self.use_stateful_dataloader = use_stateful_dataloader
415
+ if is_torchdata_stateful_dataloader_available():
416
+ from torchdata.stateful_dataloader import StatefulDataLoader
417
+
418
+ if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
419
+ raise ImportError(
420
+ "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
421
+ )
422
+ if use_stateful_dataloader:
423
+ torchdata_version = version.parse(importlib.metadata.version("torchdata"))
424
+ if (
425
+ "in_order" in kwargs
426
+ and compare_versions(torchdata_version, "<", "0.11")
427
+ and is_torch_version(">=", "2.6.0")
428
+ ):
429
+ kwargs.pop("in_order")
430
+ self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
431
+ else:
432
+ self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
433
+
434
+ if hasattr(self.base_dataloader, "state_dict"):
435
+ self.dl_state_dict = self.base_dataloader.state_dict()
436
+
437
+ def __getattr__(self, name):
438
+ # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
439
+ if name == "base_dataloader":
440
+ raise AttributeError()
441
+ # Delegate attribute access to the internal dataloader
442
+ return getattr(self.base_dataloader, name)
443
+
444
+ def state_dict(self):
445
+ return self.dl_state_dict
446
+
447
+ def load_state_dict(self, state_dict):
448
+ self.base_dataloader.load_state_dict(state_dict)
449
+
450
+ @property
451
+ def __class__(self):
452
+ """
453
+ In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
454
+ returns `True`. This is because some downstream code assumes that the `DataLoader` is the base class of the
455
+ object.
456
+ """
457
+ return self.base_dataloader.__class__
458
+
459
+ def __len__(self):
460
+ return len(self.base_dataloader)
461
+
462
+ def adjust_state_dict_for_prefetch(self):
463
+ """
464
+ Adjusts the state dict for prefetching. By default, this will decrement the iteration-progress keys in
465
+ `self.dl_state_dict` by a factor of `num_processes - 1`; if a custom correction is needed, this can be
466
+ overridden.
467
+
468
+ This should modify `self.dl_state_dict` directly
469
+ """
470
+ # The state dict will be off by `n - 1` batches during DDP,
471
+ # so we need to adjust it here
472
+ if PartialState().distributed_type != DistributedType.NO:
473
+ factor = PartialState().num_processes - 1
474
+ if self.dl_state_dict["_sampler_iter_yielded"] > 0:
475
+ self.dl_state_dict["_sampler_iter_yielded"] -= factor
476
+ if self.dl_state_dict["_num_yielded"] > 0:
477
+ self.dl_state_dict["_num_yielded"] -= factor
478
+ if self.dl_state_dict["_index_sampler_state"] is not None:
479
+ if (
480
+ "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
481
+ and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
482
+ ):
483
+ self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
484
+
485
+ def _update_state_dict(self):
486
+ # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
487
+ # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
488
+ # what it wants to yield.
489
+ #
490
+ # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
491
+ if hasattr(self.base_dataloader, "state_dict"):
492
+ self.dl_state_dict = self.base_dataloader.state_dict()
493
+ # Potentially modify the state_dict to adjust for prefetching
494
+ self.adjust_state_dict_for_prefetch()
495
+ # Then tag if we are at the end of the dataloader
496
+ self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
497
+
498
+
499
+ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
500
+ """
501
+ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
502
+
503
+ Args:
504
+ dataset (`torch.utils.data.dataset.Dataset`):
505
+ The dataset to use to build this dataloader.
506
+ device (`torch.device`, *optional*):
507
+ If passed, the device to put all batches on.
508
+ rng_types (list of `str` or [`~utils.RNGType`]):
509
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
510
+ several of:
511
+
512
+ - `"torch"`: the base torch random number generator
513
+ - `"cuda"`: the CUDA random number generator (GPU only)
514
+ - `"xla"`: the XLA random number generator (TPU only)
515
+ - `"generator"`: an optional `torch.Generator`
516
+ synchronized_generator (`torch.Generator`, *optional*):
517
+ A random number generator to keep synchronized across processes.
518
+ skip_batches (`int`, *optional*, defaults to 0):
519
+ The number of batches to skip at the beginning.
520
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
521
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
522
+ **kwargs (additional keyword arguments, *optional*):
523
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
524
+
525
+ **Available attributes:**
526
+
527
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
528
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
529
+ number of processes
530
+
531
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
532
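+
+ Example (a minimal illustrative sketch, not part of the original documentation; single process, CPU
+ only):
+
+ ```python
+ >>> import torch
+ >>> from torch.utils.data import TensorDataset
+
+ >>> # Behaves like a regular DataLoader but can also place each batch on a target device.
+ >>> loader = DataLoaderShard(TensorDataset(torch.arange(8)), device=torch.device("cpu"), batch_size=4)
+ >>> [batch for (batch,) in loader]
+ [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7])]
+ ```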
+ """
533
+
534
+ def __init__(
535
+ self,
536
+ dataset,
537
+ device=None,
538
+ rng_types=None,
539
+ synchronized_generator=None,
540
+ skip_batches=0,
541
+ use_stateful_dataloader=False,
542
+ _drop_last: bool = False,
543
+ _non_blocking: bool = False,
544
+ torch_device_mesh=None,
545
+ **kwargs,
546
+ ):
547
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
548
+ self.device = device
549
+ self.rng_types = rng_types
550
+ self.synchronized_generator = synchronized_generator
551
+ self.skip_batches = skip_batches
552
+ self.gradient_state = GradientState()
553
+ self._drop_last = _drop_last
554
+ self._non_blocking = _non_blocking
555
+ self.iteration = 0
556
+
557
+ def __iter__(self):
558
+ if self.rng_types is not None:
559
+ synchronize_rng_states(self.rng_types, self.synchronized_generator)
560
+ self.begin()
561
+
562
+ self.set_epoch(self.iteration)
563
+ dataloader_iter = self.base_dataloader.__iter__()
564
+ # We iterate one batch ahead to check when we are at the end
565
+ try:
566
+ current_batch = next(dataloader_iter)
567
+ except StopIteration:
568
+ yield
569
+
570
+ batch_index = 0
571
+ while True:
572
+ try:
573
+ # But we still move it to the device so it is done before `StopIteration` is reached
574
+ if self.device is not None:
575
+ current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
576
+ self._update_state_dict()
577
+ next_batch = next(dataloader_iter)
578
+ if batch_index >= self.skip_batches:
579
+ yield current_batch
580
+ batch_index += 1
581
+ current_batch = next_batch
582
+ except StopIteration:
583
+ self.end_of_dataloader = True
584
+ self._update_state_dict()
585
+ if batch_index >= self.skip_batches:
586
+ yield current_batch
587
+ break
588
+
589
+ self.iteration += 1
590
+ self.end()
591
+
592
+ def __reduce__(self):
593
+ """
594
+ Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
595
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
596
+ `__class__` member.
597
+ """
598
+ args = super().__reduce__()
599
+ return (DataLoaderShard, *args[1:])
600
+
601
+ def set_epoch(self, epoch: int):
602
+ # In case it is manually passed in, the user can set it to what they like
603
+ if self.iteration != epoch:
604
+ self.iteration = epoch
605
+ if hasattr(self.batch_sampler, "set_epoch"):
606
+ self.batch_sampler.set_epoch(epoch)
607
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
608
+ self.batch_sampler.sampler.set_epoch(epoch)
609
+ # We support if a custom `Dataset` implementation has `set_epoch`
610
+ # or in general HF datasets `Datasets`
611
+ elif hasattr(self.dataset, "set_epoch"):
612
+ self.dataset.set_epoch(epoch)
613
+
614
+ @property
615
+ def total_batch_size(self):
616
+ batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
617
+ return (
618
+ batch_sampler.batch_size
619
+ if getattr(batch_sampler, "split_batches", False)
620
+ else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
621
+ )
622
+
623
+ @property
624
+ def total_dataset_length(self):
625
+ if hasattr(self.dataset, "total_length"):
626
+ return self.dataset.total_length
627
+ else:
628
+ return len(self.dataset)
629
+
630
+ def get_sampler(self):
631
+ return get_sampler(self)
632
+
633
+ def set_sampler(self, sampler):
634
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
635
+ if sampler_is_batch_sampler:
636
+ self.sampler.sampler = sampler
637
+ else:
638
+ self.batch_sampler.sampler = sampler
639
+ if hasattr(self.batch_sampler, "batch_sampler"):
640
+ self.batch_sampler.batch_sampler.sampler = sampler
641
+
642
+
643
+ if is_torch_xla_available():
644
+ import torch_xla.distributed.parallel_loader as xpl
645
+
646
+ class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
647
+ """
648
+ Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
649
+
650
+ XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
651
+ prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
652
+ thread only.
653
+
654
+ **Available attributes:**
655
+
656
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
657
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
658
+ number of processes
659
+
660
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
661
+ """
662
+
663
+ def __init__(self, dataloader: DataLoaderShard, device: torch.device):
664
+ super().__init__(dataloader, device)
665
+ self._rng_types = self._loader.rng_types
666
+ self._loader.rng_types = None
667
+ self.device = device
668
+
669
+ def __iter__(self):
670
+ if self._rng_types is not None:
671
+ synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)
672
+
673
+ return super().__iter__()
674
+
675
+ def set_epoch(self, epoch: int):
676
+ if hasattr(self.dataloader, "set_epoch"):
677
+ self.dataloader.set_epoch(epoch)
678
+
679
+ @property
680
+ def total_batch_size(self):
681
+ return self._loader.total_batch_size
682
+
683
+ @property
684
+ def total_dataset_length(self):
685
+ return self._loader.total_dataset_length
686
+
687
+ @property
688
+ def batch_sampler(self):
689
+ return self._loader.batch_sampler
690
+
691
+ @property
692
+ def dataloader(self):
693
+ return self._loader
694
+
695
+
696
+ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
697
+ """
698
+ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
699
+ their part of the batch.
700
+
701
+ Args:
702
+ split_batches (`bool`, *optional*, defaults to `False`):
703
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
704
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
705
+ `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
706
+ the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
707
+ `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
708
+ size of the `dataloader` is a round multiple of the number of processes.
709
+ skip_batches (`int`, *optional*, defaults to 0):
710
+ The number of batches to skip at the beginning of an iteration.
711
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
712
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
713
+
714
+ **Available attributes:**
715
+
716
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
717
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
718
+ number of processes
719
+
720
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
721
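+
+ Example (a minimal illustrative sketch, not part of the original documentation; on a single process the
+ dispatch is a no-op and this behaves like a regular dataloader):
+
+ ```python
+ >>> import torch
+ >>> from torch.utils.data import TensorDataset
+
+ >>> loader = DataLoaderDispatcher(TensorDataset(torch.arange(8)), batch_size=4)
+ >>> [batch for (batch,) in loader]
+ [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7])]
+ ```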
+ """
722
+
723
+ def __init__(
724
+ self,
725
+ dataset,
726
+ split_batches: bool = False,
727
+ skip_batches=0,
728
+ use_stateful_dataloader=False,
729
+ _drop_last: bool = False,
730
+ _non_blocking: bool = False,
731
+ slice_fn=None,
732
+ torch_device_mesh=None,
733
+ **kwargs,
734
+ ):
735
+ shuffle = False
736
+ from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
737
+
738
+ # We need to save the shuffling state of the DataPipe
739
+ if isinstance(dataset, ShufflerIterDataPipe):
740
+ shuffle = dataset._shuffle_enabled
741
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
742
+ self.split_batches = split_batches
743
+ if shuffle:
744
+ torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
745
+
746
+ self.gradient_state = GradientState()
747
+ self.state = PartialState()
748
+ self._drop_last = _drop_last
749
+ self._non_blocking = _non_blocking
750
+ self.skip_batches = skip_batches
751
+ self.torch_device_mesh = torch_device_mesh
752
+
753
+ self.slice_fn = slice_tensors if slice_fn is None else slice_fn
754
+ self.iteration = 0
755
+
756
+ # if a device mesh is provided extract each dimension (dp, fsdp, tp)
757
+ # device mesh may hold any number of dimensions, however,
758
+ # the code below is for targeted support for dp, fsdp and tp
759
+
760
+ # device mesh will be used only if there is tp involved
761
+ # or any multi-dimensional parallelism involving tp
762
+ # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
763
+ # otherwise the default behaviour of not using a device mesh should be sufficient
764
+ # since multi dimensional parallelism devoid of tp would anyway need
765
+ # different batches for each process irrespective of dp or fsdp
766
+ self.submesh_tp = None
767
+ self.submesh_dp = None
768
+ self.submesh_fsdp = None
769
+ if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
770
+ self.submesh_tp = self.torch_device_mesh["tp"]
771
+ if "dp" in self.torch_device_mesh.mesh_dim_names:
772
+ self.submesh_dp = self.torch_device_mesh["dp"]
773
+ if "fsdp" in self.torch_device_mesh.mesh_dim_names:
774
+ self.submesh_fsdp = self.torch_device_mesh["fsdp"]
775
+ if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
776
+ raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
777
+
778
+ def _fetch_batches(self, iterator):
779
+ batches, batch = None, None
780
+ # On process 0, we gather the batch to dispatch.
781
+ if self.state.process_index == 0:
782
+ # Procedure to support TP only is simpler
783
+ # since we want to dispatch the same batch of samples across all ranks
784
+ # this removes complexity of handling multiple tp rank groups when TP + DP
785
+ # combination is involved.
786
+
787
+ try:
788
+ # for TP case avoid using split_batches
789
+ # since it would mean that the dataloader would have to produce
790
+ # duplicate batches.
791
+ if self.split_batches:
792
+ # One batch of the main iterator is dispatched and split.
793
+ if self.submesh_tp:
794
+ logger.warning(
795
+ "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
796
+ "otherwise, use dispatch_batches=True instead."
797
+ )
798
+ self._update_state_dict()
799
+ batch = next(iterator)
800
+ else:
801
+ # num_processes batches of the main iterator are concatenated then dispatched and split.
802
+ # We add the batches one by one so we have the remainder available when drop_last=False.
803
+ batches = []
804
+ if self.submesh_tp:
805
+ # when tp, extract single batch and then replicate
806
+ self._update_state_dict()
807
+ batch = next(iterator)
808
+ batches = [batch] * self.state.num_processes
809
+ else:
810
+ for _ in range(self.state.num_processes):
811
+ self._update_state_dict()
812
+ batches.append(next(iterator))
813
+ try:
814
+ batch = concatenate(batches, dim=0)
815
+ except RuntimeError as e:
816
+ raise RuntimeError(
817
+ "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
818
+ "either pass `dispatch_batches=False` and have each process fetch its own batch "
819
+ " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
820
+ "slice it into `num_processes` batches for each process."
821
+ ) from e
822
+ # In both cases, we need to get the structure of the batch that we will broadcast on other
823
+ # processes to initialize the tensors with the right shape.
824
+ # data_structure, stop_iteration
825
+ batch_info = [get_data_structure(batch), False]
826
+ except StopIteration:
827
+ batch_info = [None, True]
828
+ else:
829
+ batch_info = [None, self._stop_iteration]
830
+ # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
831
+ broadcast_object_list(batch_info)
832
+ self._stop_iteration = batch_info[1]
833
+ if self._stop_iteration:
834
+ # If drop_last is False and split_batches is False, we may have a remainder to take care of.
835
+ if not self.split_batches and not self._drop_last:
836
+ if self.state.process_index == 0 and len(batches) > 0:
837
+ batch = concatenate(batches, dim=0)
838
+ batch_info = [get_data_structure(batch), False]
839
+ else:
840
+ batch_info = [None, True]
841
+ broadcast_object_list(batch_info)
842
+ return batch, batch_info
843
+
844
+ def __iter__(self):
845
+ self.begin()
846
+ self.set_epoch(self.iteration)
847
+ main_iterator = None
848
+ if is_torch_version(">=", "2.0.1"):
849
+ # NOTE: PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcast a
850
+ # shared seed to all dist processes. Thus, we need to create the iterator on all dist processes.
851
+ # But, we only iterate through the DataLoader on process 0.
852
+ main_iterator = self.base_dataloader.__iter__()
853
+ elif self.state.process_index == 0:
854
+ main_iterator = self.base_dataloader.__iter__()
855
+ stop_iteration = False
856
+ self._stop_iteration = False
857
+ first_batch = None
858
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
859
+ batch_index = 0
860
+ while not stop_iteration:
861
+ batch, batch_info = next_batch, next_batch_info
862
+
863
+ if self.state.process_index != 0:
864
+ # Initialize tensors on other processes than process 0.
865
+ batch = initialize_tensors(batch_info[0])
866
+ batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
867
+ # Broadcast the batch before splitting it.
868
+ batch = broadcast(batch, from_process=0)
869
+
870
+ if not self._drop_last and first_batch is None:
871
+ # We keep at least num processes elements of the first batch to be able to complete the last batch
872
+ first_batch = self.slice_fn(
873
+ batch,
874
+ slice(0, self.state.num_processes),
875
+ process_index=self.state.process_index,
876
+ num_processes=self.state.num_processes,
877
+ )
878
+
879
+ if batch is None:
880
+ raise ValueError(
881
+ f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
882
+ )
883
+
884
+ observed_batch_size = find_batch_size(batch)
885
+ batch_size = observed_batch_size // self.state.num_processes
886
+
887
+ stop_iteration = self._stop_iteration
888
+ if not stop_iteration:
889
+ # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
890
+ # the dataloader since the number of batches is a round multiple of the number of processes.
891
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
892
+ # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
893
+ if self._stop_iteration and next_batch_info[0] is None:
894
+ stop_iteration = True
895
+
896
+ if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
897
+ # If the last batch is not complete, let's add the first batch to it.
898
+ batch = concatenate([batch, first_batch], dim=0)
899
+ # Batch size computation above is wrong, it's off by 1 so we fix it.
900
+ batch_size += 1
901
+
902
+ data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
903
+ batch = self.slice_fn(
904
+ batch,
905
+ data_slice,
906
+ process_index=self.state.process_index,
907
+ num_processes=self.state.num_processes,
908
+ )
909
+
910
+ if stop_iteration:
911
+ self.end_of_dataloader = True
912
+ self._update_state_dict()
913
+ self.remainder = observed_batch_size
914
+ if batch_index >= self.skip_batches:
915
+ yield batch
916
+ batch_index += 1
917
+ self.iteration += 1
918
+ self.end()
919
+
920
+ def set_epoch(self, epoch: int):
921
+ # In case it is manually passed in, the user can set it to what they like
922
+ if self.iteration != epoch:
923
+ self.iteration = epoch
924
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
925
+ self.batch_sampler.sampler.set_epoch(epoch)
926
+ elif hasattr(self.dataset, "set_epoch"):
927
+ self.dataset.set_epoch(epoch)
928
+
929
+ def __len__(self):
930
+ whole_length = len(self.base_dataloader)
931
+ if self.split_batches:
932
+ return whole_length
933
+ elif self._drop_last:
934
+ return whole_length // self.state.num_processes
935
+ else:
936
+ return math.ceil(whole_length / self.state.num_processes)
937
+
938
+ def __reduce__(self):
939
+ """
940
+ Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
941
+ be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
942
+ `__class__` member.
943
+ """
944
+ args = super().__reduce__()
945
+ return (DataLoaderDispatcher, *args[1:])
946
+
947
+ @property
948
+ def total_batch_size(self):
949
+ return (
950
+ self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
951
+ )
952
+
953
+ @property
954
+ def total_dataset_length(self):
955
+ return len(self.dataset)
956
+
957
+ def get_sampler(self):
958
+ return get_sampler(self)
959
+
960
+ def set_sampler(self, sampler):
961
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
962
+ if sampler_is_batch_sampler:
963
+ self.sampler.sampler = sampler
964
+ else:
965
+ self.batch_sampler.sampler = sampler
966
+ if hasattr(self.batch_sampler, "batch_sampler"):
967
+ self.batch_sampler.batch_sampler.sampler = sampler
968
+
969
+
970
+ def get_sampler(dataloader):
971
+ """
972
+ Get the sampler associated to the dataloader
973
+
974
+ Args:
975
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
976
+ The data loader to split across several devices.
977
+ Returns:
978
+ `torch.utils.data.Sampler`: The sampler associated to the dataloader
979
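+
+ Example (illustrative sketch, not part of the original documentation):
+
+ ```python
+ >>> import torch
+ >>> from torch.utils.data import DataLoader, TensorDataset
+
+ >>> loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4, shuffle=True)
+ >>> type(get_sampler(loader)).__name__
+ 'RandomSampler'
+ ```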
+ """
980
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
981
+ if sampler_is_batch_sampler:
982
+ sampler = getattr(dataloader.sampler, "sampler", None)
983
+ else:
984
+ sampler = getattr(dataloader.batch_sampler, "sampler", None)
985
+ return sampler
986
+
987
+
988
+ def prepare_data_loader(
989
+ dataloader: DataLoader,
990
+ device: Optional[torch.device] = None,
991
+ num_processes: Optional[int] = None,
992
+ process_index: Optional[int] = None,
993
+ split_batches: bool = False,
994
+ put_on_device: bool = False,
995
+ rng_types: Optional[list[Union[str, RNGType]]] = None,
996
+ dispatch_batches: Optional[bool] = None,
997
+ even_batches: bool = True,
998
+ slice_fn_for_dispatch: Optional[Callable] = None,
999
+ use_seedable_sampler: bool = False,
1000
+ data_seed: Optional[int] = None,
1001
+ non_blocking: bool = False,
1002
+ use_stateful_dataloader: bool = False,
1003
+ torch_device_mesh=None,
1004
+ ) -> DataLoader:
1005
+ """
1006
+ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
1007
+
1008
+ Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
1009
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
1010
+
1011
+ Args:
1012
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
1013
+ The data loader to split across several devices.
1014
+ device (`torch.device`):
1015
+ The target device for the returned `DataLoader`.
1016
+ num_processes (`int`, *optional*):
1017
+ The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
1018
+ process_index (`int`, *optional*):
1019
+ The index of the current process. Will default to the value given by [`~state.PartialState`].
1020
+ split_batches (`bool`, *optional*, defaults to `False`):
1021
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
1022
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
1023
+ `num_processes` batches at each iteration).
1024
+
1025
+ Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
1026
+ this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
1027
+ otherwise.
1028
+
1029
+ Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
1030
+ the number of processes.
1031
+ put_on_device (`bool`, *optional*, defaults to `False`):
1032
+ Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
1033
+ dictionaries of tensors).
1034
+ rng_types (list of `str` or [`~utils.RNGType`]):
1035
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
1036
+ several of:
1037
+
1038
+ - `"torch"`: the base torch random number generator
1039
+ - `"cuda"`: the CUDA random number generator (GPU only)
1040
+ - `"xla"`: the XLA random number generator (TPU only)
1041
+ - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
1042
+ dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
1043
+
1044
+ dispatch_batches (`bool`, *optional*):
1045
+ If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
1046
+ are split and broadcast to each process. Will default to `True` when the underlying dataset is an
1047
+ `IterableDataset`, `False` otherwise.
1048
+ even_batches (`bool`, *optional*, defaults to `True`):
1049
+ If set to `True`, in cases where the total batch size across all processes does not exactly divide the
1050
+ dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
1051
+ all workers.
1052
+ slice_fn_for_dispatch (`Callable`, *optional*):
1053
+ If passed, this function will be used to slice tensors across `num_processes`. Will default to
1054
+ [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
1055
+ ignored otherwise.
1056
+ use_seedable_sampler (`bool`, *optional*, defaults to `False`):
1057
+ Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
1058
+ reproducibility. Comes at a cost of potentially different performance due to different shuffling
1059
+ algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
1060
+ `self.set_epoch`
1061
+ data_seed (`int`, *optional*, defaults to `None`):
1062
+ The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
1063
+ will use the current default seed from torch.
1064
+ non_blocking (`bool`, *optional*, defaults to `False`):
1065
+ If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
1066
+ `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
1067
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
1068
+ "If set to true, the dataloader prepared by the Accelerator will be backed by "
1069
+ "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
1070
+ This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
1071
+ torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
1072
+ PyTorch device mesh.
1073
+
1074
+
1075
+ Returns:
1076
+ `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches assigned to the current process.
1077
+
1078
+ <Tip warning={true}>
1079
+
1080
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
1081
+ equal to `False`
1082
+
1083
+ </Tip>
1084
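+
+ Example (a minimal illustrative sketch, not part of the original documentation; on a single process this
+ simply wraps the dataloader and moves batches to the requested device):
+
+ ```python
+ >>> import torch
+ >>> from torch.utils.data import DataLoader, TensorDataset
+
+ >>> loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)
+ >>> prepared = prepare_data_loader(loader, device=torch.device("cpu"), put_on_device=True)
+ >>> [batch for (batch,) in prepared]
+ [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7])]
+ ```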
+ """
1085
+ if dispatch_batches is None:
1086
+ if not put_on_device:
1087
+ dispatch_batches = False
1088
+ else:
1089
+ dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
1090
+
1091
+ if dispatch_batches and not put_on_device:
1092
+ raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
1093
+ # Grab defaults from PartialState
1094
+ state = PartialState()
1095
+ if num_processes is None:
1096
+ num_processes = state.num_processes
1097
+
1098
+ if process_index is None:
1099
+ process_index = state.process_index
1100
+
1101
+ if torch_device_mesh:
1102
+ if state.distributed_type == DistributedType.DEEPSPEED:
1103
+ # In DeepSpeed, the optimizer sharing level in DP is determined by the config file.
1104
+ # Only considers "dp" and "tp".
1105
+ # Given a device mesh (dp, tp) = (2, 3):
1106
+ # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
1107
+ # - Processes with the same DP rank will receive the same batch.
1108
+ if "tp" in torch_device_mesh.mesh_dim_names:
1109
+ submesh_tp_size = torch_device_mesh["tp"].size()
1110
+ process_index = process_index // submesh_tp_size
1111
+ num_processes = num_processes // submesh_tp_size
1112
+ else:
1113
+ # when a device mesh is used, specifically with TP,
1114
+ # process_index and num_processes need to be updated
1115
+ # so that the same batch is generated across TP ranks
1116
+ # and a different batch across FSDP and DP ranks.
1117
+ # Example:
1118
+ # if device mesh is (dp,fsdp,tp) = (2, 2, 3)
1119
+ # ranks would range from 0...11
1120
+ # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
1121
+ # processes with same ranks/ids would receive the same batch
1122
+ submesh_fsdp_size = 1
1123
+ submesh_dp_size = 1
1124
+ submesh_tp_size = 1
1125
+ if "tp" in torch_device_mesh.mesh_dim_names:
1126
+ submesh_tp_size = torch_device_mesh["tp"].size()
1127
+ if "dp" in torch_device_mesh.mesh_dim_names:
1128
+ submesh_dp_size = torch_device_mesh["dp"].size()
1129
+ if "fsdp" in torch_device_mesh.mesh_dim_names:
1130
+ submesh_fsdp_size = torch_device_mesh["fsdp"].size()
1131
+ process_index = process_index // submesh_tp_size
1132
+ num_processes = submesh_fsdp_size * submesh_dp_size
1133
+
1134
+ # Sanity check
1135
+ if split_batches:
1136
+ if dataloader.batch_size is not None:
1137
+ batch_size_for_check = dataloader.batch_size
1138
+ else:
1139
+ # For custom batch_sampler
1140
+ if hasattr(dataloader.batch_sampler, "batch_size"):
1141
+ batch_size_for_check = dataloader.batch_sampler.batch_size
1142
+ else:
1143
+ raise ValueError(
1144
+ "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
1145
+ "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
1146
+ "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
1147
+ f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
1148
+ )
1149
+
1150
+ if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
1151
+ raise ValueError(
1152
+ f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
1153
+ f"needs to be a round multiple of the number of processes ({num_processes})."
1154
+ )
1155
+
1156
+ new_dataset = dataloader.dataset
1157
+ # Iterable dataset doesn't like batch_sampler, but `DataLoader` creates a default one for it
1158
+ new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
1159
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1160
+ synchronized_generator = None
1161
+
1162
+ sampler = get_sampler(dataloader)
1163
+ if isinstance(sampler, RandomSampler) and use_seedable_sampler:
1164
+ # When iterating through the dataloader during distributed processes
1165
+ # we want to ensure that on each process we are iterating through the same
1166
+ # samples in the same order if a seed is set. This requires a tweak
1167
+ # to the `torch.utils.data.RandomSampler` class (if used).
1168
+ sampler = SeedableRandomSampler(
1169
+ data_source=sampler.data_source,
1170
+ replacement=sampler.replacement,
1171
+ num_samples=sampler._num_samples,
1172
+ generator=getattr(
1173
+ sampler,
1174
+ "generator",
1175
+ torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
1176
+ ),
1177
+ data_seed=data_seed,
1178
+ )
1179
+
1180
+ if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
1181
+ # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
1182
+ generator = torch.Generator(
1183
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1184
+ )
1185
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
1186
+ generator.manual_seed(seed)
1187
+ dataloader.generator = generator
1188
+ dataloader.sampler.generator = generator
1189
+ # No change if no multiprocess
1190
+ if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
1191
+ if isinstance(new_dataset, IterableDataset):
1192
+ if getattr(dataloader.dataset, "generator", None) is not None:
1193
+ synchronized_generator = dataloader.dataset.generator
1194
+ new_dataset = IterableDatasetShard(
1195
+ new_dataset,
1196
+ batch_size=dataloader.batch_size,
1197
+ drop_last=dataloader.drop_last,
1198
+ num_processes=num_processes,
1199
+ process_index=process_index,
1200
+ split_batches=split_batches,
1201
+ )
1202
+ else:
1203
+ if not use_seedable_sampler and hasattr(sampler, "generator"):
1204
+ if sampler.generator is None:
1205
+ sampler.generator = torch.Generator(
1206
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1207
+ )
1208
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
1209
+ sampler.generator.manual_seed(seed)
1210
+ synchronized_generator = sampler.generator
1211
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1212
+ new_batch_sampler = BatchSamplerShard(
1213
+ batch_sampler,
1214
+ num_processes=num_processes,
1215
+ process_index=process_index,
1216
+ split_batches=split_batches,
1217
+ even_batches=even_batches,
1218
+ )
1219
+
1220
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1221
+ ignore_kwargs = [
1222
+ "batch_size",
1223
+ "shuffle",
1224
+ "sampler",
1225
+ "batch_sampler",
1226
+ "drop_last",
1227
+ ]
1228
+
1229
+ if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
1230
+ rng_types.remove("generator")
1231
+
1232
+ kwargs = {
1233
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1234
+ for k in _PYTORCH_DATALOADER_KWARGS
1235
+ if k not in ignore_kwargs
1236
+ }
1237
+
1238
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1239
+ if new_batch_sampler is None:
1240
+ kwargs["drop_last"] = dataloader.drop_last
1241
+ kwargs["batch_size"] = (
1242
+ dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
1243
+ )
1244
+ if dispatch_batches:
1245
+ kwargs.pop("generator")
1246
+ dataloader = DataLoaderDispatcher(
1247
+ new_dataset,
1248
+ split_batches=split_batches,
1249
+ batch_sampler=new_batch_sampler,
1250
+ _drop_last=dataloader.drop_last,
1251
+ _non_blocking=non_blocking,
1252
+ slice_fn=slice_fn_for_dispatch,
1253
+ use_stateful_dataloader=use_stateful_dataloader,
1254
+ torch_device_mesh=torch_device_mesh,
1255
+ **kwargs,
1256
+ )
1257
+ elif sampler_is_batch_sampler:
1258
+ dataloader = DataLoaderShard(
1259
+ new_dataset,
1260
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1261
+ sampler=new_batch_sampler,
1262
+ batch_size=dataloader.batch_size,
1263
+ rng_types=rng_types,
1264
+ _drop_last=dataloader.drop_last,
1265
+ _non_blocking=non_blocking,
1266
+ synchronized_generator=synchronized_generator,
1267
+ use_stateful_dataloader=use_stateful_dataloader,
1268
+ **kwargs,
1269
+ )
1270
+ else:
1271
+ dataloader = DataLoaderShard(
1272
+ new_dataset,
1273
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1274
+ batch_sampler=new_batch_sampler,
1275
+ rng_types=rng_types,
1276
+ synchronized_generator=synchronized_generator,
1277
+ _drop_last=dataloader.drop_last,
1278
+ _non_blocking=non_blocking,
1279
+ use_stateful_dataloader=use_stateful_dataloader,
1280
+ **kwargs,
1281
+ )
1282
+
1283
+ if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
1284
+ dataloader.set_sampler(sampler)
1285
+ if state.distributed_type == DistributedType.XLA:
1286
+ return MpDeviceLoaderWrapper(dataloader, device)
1287
+ return dataloader
1288
+
1289
+
1290
+ class SkipBatchSampler(BatchSampler):
1291
+ """
1292
+ A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
1293
+ Should not be used if the original dataloader is a `StatefulDataLoader`.
1294
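+
+ Example (illustrative sketch, not part of the original documentation):
+
+ ```python
+ >>> from torch.utils.data import BatchSampler, SequentialSampler
+
+ >>> batch_sampler = BatchSampler(SequentialSampler(range(8)), batch_size=2, drop_last=False)
+ >>> list(SkipBatchSampler(batch_sampler, skip_batches=2))
+ [[4, 5], [6, 7]]
+ ```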
+ """
1295
+
1296
+ def __init__(self, batch_sampler, skip_batches=0):
1297
+ self.batch_sampler = batch_sampler
1298
+ self.skip_batches = skip_batches
1299
+
1300
+ def __iter__(self):
1301
+ for index, samples in enumerate(self.batch_sampler):
1302
+ if index >= self.skip_batches:
1303
+ yield samples
1304
+
1305
+ @property
1306
+ def total_length(self):
1307
+ return len(self.batch_sampler)
1308
+
1309
+ def __len__(self):
1310
+ return len(self.batch_sampler) - self.skip_batches
1311
+
1312
+
1313
+ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
1314
+ """
1315
+ Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
1316
+ `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
1317
+
1318
+ Args:
1319
+ dataset (`torch.utils.data.dataset.Dataset`):
1320
+ The dataset to use to build this dataloader.
1321
+ skip_batches (`int`, *optional*, defaults to 0):
1322
+ The number of batches to skip at the beginning.
1323
+ kwargs:
1324
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
1325
+ """
1326
+
1327
+ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
1328
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
1329
+ self.skip_batches = skip_batches
1330
+ self.gradient_state = GradientState()
1331
+
1332
+ def __iter__(self):
1333
+ self.begin()
1334
+ for index, batch in enumerate(self.base_dataloader.__iter__()):
1335
+ if index >= self.skip_batches:
1336
+ self._update_state_dict()
1337
+ yield batch
1338
+ self.end()
1339
+
1340
+ def __len__(self):
1341
+ return len(self.base_dataloader) - self.skip_batches
1342
+
1343
+ def __reduce__(self):
1344
+ """
1345
+ Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
1346
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
1347
+ `__class__` member.
1348
+ """
1349
+ args = super().__reduce__()
1350
+ return (SkipDataLoader, *args[1:])
1351
+
1352
+
1353
+ def skip_first_batches(dataloader, num_batches=0):
1354
+ """
1355
+ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
1356
+ the original dataloader is a `StatefulDataLoader`.
1357
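+
+ Example (a minimal illustrative sketch, not part of the original documentation):
+
+ ```python
+ >>> import torch
+ >>> from torch.utils.data import DataLoader, TensorDataset
+
+ >>> loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
+ >>> skipped = skip_first_batches(loader, num_batches=2)
+ >>> [batch for (batch,) in skipped]
+ [tensor([4, 5]), tensor([6, 7])]
+ ```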
+ """
1358
+ state = PartialState()
1359
+ if state.distributed_type == DistributedType.XLA:
1360
+ device = dataloader.device
1361
+ dataloader = dataloader.dataloader
1362
+
1363
+ dataset = dataloader.dataset
1364
+ sampler_is_batch_sampler = False
1365
+ if isinstance(dataset, IterableDataset):
1366
+ new_batch_sampler = None
1367
+ else:
1368
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1369
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1370
+ new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
1371
+
1372
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1373
+ ignore_kwargs = [
1374
+ "batch_size",
1375
+ "shuffle",
1376
+ "sampler",
1377
+ "batch_sampler",
1378
+ "drop_last",
1379
+ ]
1380
+
1381
+ kwargs = {
1382
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1383
+ for k in _PYTORCH_DATALOADER_KWARGS
1384
+ if k not in ignore_kwargs
1385
+ }
1386
+
1387
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1388
+ if new_batch_sampler is None:
1389
+ kwargs["drop_last"] = dataloader.drop_last
1390
+ kwargs["batch_size"] = dataloader.batch_size
1391
+
1392
+ if isinstance(dataloader, DataLoaderDispatcher):
1393
+ if new_batch_sampler is None:
1394
+ # Need to manually skip batches in the dataloader
1395
+ kwargs["skip_batches"] = num_batches
1396
+ dataloader = DataLoaderDispatcher(
1397
+ dataset,
1398
+ split_batches=dataloader.split_batches,
1399
+ batch_sampler=new_batch_sampler,
1400
+ _drop_last=dataloader._drop_last,
1401
+ **kwargs,
1402
+ )
1403
+ elif isinstance(dataloader, DataLoaderShard):
1404
+ if new_batch_sampler is None:
1405
+ # Need to manually skip batches in the dataloader
1406
+ kwargs["skip_batches"] = num_batches
1407
+ elif sampler_is_batch_sampler:
1408
+ kwargs["sampler"] = new_batch_sampler
1409
+ kwargs["batch_size"] = dataloader.batch_size
1410
+ else:
1411
+ kwargs["batch_sampler"] = new_batch_sampler
1412
+ dataloader = DataLoaderShard(
1413
+ dataset,
1414
+ device=dataloader.device,
1415
+ rng_types=dataloader.rng_types,
1416
+ synchronized_generator=dataloader.synchronized_generator,
1417
+ **kwargs,
1418
+ )
1419
+ else:
1420
+ if new_batch_sampler is None:
1421
+ # Need to manually skip batches in the dataloader
1422
+ dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
1423
+ else:
1424
+ dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
1425
+
1426
+ if state.distributed_type == DistributedType.XLA:
1427
+ dataloader = MpDeviceLoaderWrapper(dataloader, device)
1428
+
1429
+ return dataloader
venv/Lib/site-packages/accelerate/hooks.py ADDED
@@ -0,0 +1,739 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from collections.abc import Mapping
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from .state import PartialState
23
+ from .utils import (
24
+ PrefixedDataset,
25
+ find_device,
26
+ named_module_tensors,
27
+ send_to_device,
28
+ set_module_tensor_to_device,
29
+ )
30
+ from .utils.imports import (
31
+ is_mlu_available,
32
+ is_musa_available,
33
+ is_npu_available,
34
+ )
35
+ from .utils.memory import clear_device_cache
36
+ from .utils.modeling import get_non_persistent_buffers
37
+ from .utils.other import recursive_getattr
38
+
39
+
40
+ _accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]
41
+
42
+
43
+ class ModelHook:
44
+ """
45
+ A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
46
+ with existing PyTorch hooks is that they also get passed the kwargs.
47
+
48
+ Class attribute:
49
+ - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
50
+ the `torch.no_grad()` context manager.
51
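+
+ Example (a minimal illustrative sketch, not part of the original documentation; `ScaleOutputHook` is a
+ hypothetical hook that doubles the module output):
+
+ ```python
+ >>> class ScaleOutputHook(ModelHook):
+ ...     def post_forward(self, module, output):
+ ...         # Runs right after the wrapped module's forward pass.
+ ...         return output * 2
+ ```
+
+ Such a hook takes effect once attached with `add_hook_to_module` (see below).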
+ """
52
+
53
+ no_grad = False
54
+
55
+ def init_hook(self, module):
56
+ """
57
+ To be executed when the hook is attached to the module.
58
+
59
+ Args:
60
+ module (`torch.nn.Module`): The module attached to this hook.
61
+ """
62
+ return module
63
+
64
+ def pre_forward(self, module, *args, **kwargs):
65
+ """
66
+ To be executed just before the forward method of the model.
67
+
68
+ Args:
69
+ module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
70
+ args (`Tuple[Any]`): The positional arguments passed to the module.
71
+ kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
72
+
73
+ Returns:
74
+ `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
75
+ """
76
+ return args, kwargs
77
+
78
+ def post_forward(self, module, output):
79
+ """
80
+ To be executed just after the forward method of the model.
81
+
82
+ Args:
83
+ module (`torch.nn.Module`): The module whose forward pass has been executed just before this event.
84
+ output (`Any`): The output of the module.
85
+
86
+ Returns:
87
+ `Any`: The processed `output`.
88
+ """
89
+ return output
90
+
91
+ def detach_hook(self, module):
92
+ """
93
+ To be executed when the hook is detached from a module.
94
+
95
+ Args:
96
+ module (`torch.nn.Module`): The module detached from this hook.
97
+ """
98
+ return module
99
+
100
+
101
+ class SequentialHook(ModelHook):
102
+ """
103
+ A hook that can contain several hooks and iterates through them at each event.
104
+ """
105
+
106
+ def __init__(self, *hooks):
107
+ self.hooks = hooks
108
+
109
+ def init_hook(self, module):
110
+ for hook in self.hooks:
111
+ module = hook.init_hook(module)
112
+ return module
113
+
114
+ def pre_forward(self, module, *args, **kwargs):
115
+ for hook in self.hooks:
116
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
117
+ return args, kwargs
118
+
119
+ def post_forward(self, module, output):
120
+ for hook in self.hooks:
121
+ output = hook.post_forward(module, output)
122
+ return output
123
+
124
+ def detach_hook(self, module):
125
+ for hook in self.hooks:
126
+ module = hook.detach_hook(module)
127
+ return module
128
+
129
+
130
+ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
131
+ """
132
+ Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
133
+ this behavior and restore the original `forward` method, use `remove_hook_from_module`.
134
+
135
+ <Tip warning={true}>
136
+
137
+ By default, if the module already contains a hook, this will replace it with the new hook. To chain two hooks
138
+ together, pass `append=True`, which combines the current and new hooks into an instance of the `SequentialHook` class.
139
+
140
+ </Tip>
141
+
142
+ Args:
143
+ module (`torch.nn.Module`):
144
+ The module to attach a hook to.
145
+ hook (`ModelHook`):
146
+ The hook to attach.
147
+ append (`bool`, *optional*, defaults to `False`):
148
+ Whether the hook should be chained with an existing one (if module already contains a hook) or not.
149
+
150
+ Returns:
151
+ `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
152
+ be discarded).
153
+ """
154
+
155
+ if append and (getattr(module, "_hf_hook", None) is not None):
156
+ old_hook = module._hf_hook
157
+ remove_hook_from_module(module)
158
+ hook = SequentialHook(old_hook, hook)
159
+
160
+ if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
161
+ # If we already put some hook on this module, we replace it with the new one.
162
+ old_forward = module._old_forward
163
+ else:
164
+ old_forward = module.forward
165
+ module._old_forward = old_forward
166
+
167
+ module = hook.init_hook(module)
168
+ module._hf_hook = hook
169
+
170
+ def new_forward(module, *args, **kwargs):
171
+ args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
172
+ if module._hf_hook.no_grad:
173
+ with torch.no_grad():
174
+ output = module._old_forward(*args, **kwargs)
175
+ else:
176
+ output = module._old_forward(*args, **kwargs)
177
+ return module._hf_hook.post_forward(module, output)
178
+
179
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
180
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
181
+ if "GraphModuleImpl" in str(type(module)):
182
+ module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
183
+ else:
184
+ module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
185
+
186
+ return module
187
+
188
+
189
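A short, self-contained usage sketch (a toy `nn.Linear` is assumed) showing how `append=True` folds an existing hook and a new one into a `SequentialHook`:

```python
import torch
from torch import nn
from accelerate.hooks import ModelHook, SequentialHook, add_hook_to_module, remove_hook_from_module

layer = nn.Linear(2, 2)
add_hook_to_module(layer, ModelHook())               # wraps layer.forward
add_hook_to_module(layer, ModelHook(), append=True)  # chains both into a SequentialHook
assert isinstance(layer._hf_hook, SequentialHook)
layer(torch.randn(1, 2))                             # both hooks fire, in order
remove_hook_from_module(layer)                       # original forward restored
```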
+ def remove_hook_from_module(module: nn.Module, recurse=False):
+     """
+     Removes any hook attached to a module via `add_hook_to_module`.
+
+     Args:
+         module (`torch.nn.Module`): The module to remove the hook from.
+         recurse (`bool`, *optional*): Whether to remove the hooks recursively.
+
+     Returns:
+         `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
+         be discarded).
+     """
+
+     if hasattr(module, "_hf_hook"):
+         module._hf_hook.detach_hook(module)
+         delattr(module, "_hf_hook")
+
+     if hasattr(module, "_old_forward"):
+         # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
+         # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
+         if "GraphModuleImpl" in str(type(module)):
+             module.__class__.forward = module._old_forward
+         else:
+             module.forward = module._old_forward
+         delattr(module, "_old_forward")
+
+     # Remove accelerate added warning hooks from dispatch_model
+     for attr in _accelerate_added_attributes:
+         module.__dict__.pop(attr, None)
+
+     if recurse:
+         for child in module.children():
+             remove_hook_from_module(child, recurse)
+
+     return module
+
+
+ class AlignDevicesHook(ModelHook):
+     """
+     A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
+     associated module, potentially offloading the weights after the forward pass.
+
+     Args:
+         execution_device (`torch.device`, *optional*):
+             The device on which inputs and model weights should be placed before the forward pass.
+         offload (`bool`, *optional*, defaults to `False`):
+             Whether or not the weights should be offloaded after the forward pass.
+         io_same_device (`bool`, *optional*, defaults to `False`):
+             Whether or not the output should be placed on the same device as the input was.
+         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
+             When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to include the associated module's buffers when offloading.
+         place_submodules (`bool`, *optional*, defaults to `False`):
+             Whether to place the submodules on `execution_device` during the `init_hook` event.
+     """
+
+     def __init__(
+         self,
+         execution_device: Optional[Union[int, str, torch.device]] = None,
+         offload: bool = False,
+         io_same_device: bool = False,
+         weights_map: Optional[Mapping] = None,
+         offload_buffers: bool = False,
+         place_submodules: bool = False,
+         skip_keys: Optional[Union[str, list[str]]] = None,
+         tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+     ):
+         self.execution_device = execution_device
+         self.offload = offload
+         self.io_same_device = io_same_device
+         self.weights_map = weights_map
+         self.offload_buffers = offload_buffers
+         self.place_submodules = place_submodules
+         self.skip_keys = skip_keys
+
+         # Will contain the input device when `io_same_device=True`.
+         self.input_device = None
+         self.param_original_devices = {}
+         self.buffer_original_devices = {}
+         self.tied_params_names = set()
+
+         # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
+         # for tied weights already loaded on the target execution device.
+         self.tied_params_map = tied_params_map
+
+     def __repr__(self):
+         return (
+             f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
+             f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
+             f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
+         )
+
+     def init_hook(self, module):
+         # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
+         if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
+             self.tied_params_map = None
+
+         if not self.offload and self.execution_device is not None:
+             for name, _ in named_module_tensors(module, recurse=self.place_submodules):
+                 set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
+         elif self.offload:
+             self.original_devices = {
+                 name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
+             }
+             if self.weights_map is None:
+                 self.weights_map = {
+                     name: param.to("cpu")
+                     for name, param in named_module_tensors(
+                         module, include_buffers=self.offload_buffers, recurse=self.place_submodules
+                     )
+                 }
+             for name, _ in named_module_tensors(
+                 module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
+             ):
+                 # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
+                 # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+                 # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
+                 # to add on the fly pointers to `tied_params_map` in the pre_forward call.
+                 if (
+                     self.tied_params_map is not None
+                     and recursive_getattr(module, name).data_ptr() in self.tied_params_map
+                 ):
+                     self.tied_params_names.add(name)
+
+                 set_module_tensor_to_device(module, name, "meta")
+
+             if not self.offload_buffers and self.execution_device is not None:
+                 for name, _ in module.named_buffers(recurse=self.place_submodules):
+                     set_module_tensor_to_device(
+                         module, name, self.execution_device, tied_params_map=self.tied_params_map
+                     )
+             elif self.offload_buffers and self.execution_device is not None:
+                 for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
+                     set_module_tensor_to_device(
+                         module, name, self.execution_device, tied_params_map=self.tied_params_map
+                     )
+
+         return module
+
+     def pre_forward(self, module, *args, **kwargs):
+         if self.io_same_device:
+             self.input_device = find_device([args, kwargs])
+         if self.offload:
+             self.tied_pointers_to_remove = set()
+
+             for name, _ in named_module_tensors(
+                 module,
+                 include_buffers=self.offload_buffers,
+                 recurse=self.place_submodules,
+                 remove_non_persistent=True,
+             ):
+                 fp16_statistics = None
+                 value = self.weights_map[name]
+                 if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
+                     if value.dtype == torch.int8:
+                         fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
+
+                 # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
+                 # that are loaded on device at this point, as we will need to remove them as well from the dictionary
+                 # self.tied_params_map in order to allow to free memory.
+                 if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
+                     self.tied_params_map[value.data_ptr()] = {}
+
+                 if (
+                     value is not None
+                     and self.tied_params_map is not None
+                     and value.data_ptr() in self.tied_params_map
+                     and self.execution_device not in self.tied_params_map[value.data_ptr()]
+                 ):
+                     self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
+
+                 set_module_tensor_to_device(
+                     module,
+                     name,
+                     self.execution_device,
+                     value=value,
+                     fp16_statistics=fp16_statistics,
+                     tied_params_map=self.tied_params_map,
+                 )
+
+         return send_to_device(args, self.execution_device), send_to_device(
+             kwargs, self.execution_device, skip_keys=self.skip_keys
+         )
+
+     def post_forward(self, module, output):
+         if self.offload:
+             for name, _ in named_module_tensors(
+                 module,
+                 include_buffers=self.offload_buffers,
+                 recurse=self.place_submodules,
+                 remove_non_persistent=True,
+             ):
+                 set_module_tensor_to_device(module, name, "meta")
+             if type(module).__name__ == "Linear8bitLt":
+                 module.state.SCB = None
+                 module.state.CxB = None
+
+             # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
+             # this dictionary to allow the garbage collector to do its job.
+             for value_pointer, device in self.tied_pointers_to_remove:
+                 if isinstance(device, int):
+                     if is_npu_available():
+                         device = f"npu:{device}"
+                     elif is_mlu_available():
+                         device = f"mlu:{device}"
+                     elif is_musa_available():
+                         device = f"musa:{device}"
+                 if device in self.tied_params_map[value_pointer]:
+                     del self.tied_params_map[value_pointer][device]
+             self.tied_pointers_to_remove = set()
+         if self.io_same_device and self.input_device is not None:
+             output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
+
+         return output
+
+     def detach_hook(self, module):
+         if self.offload:
+             for name, device in self.original_devices.items():
+                 if device != torch.device("meta"):
+                     set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
+         return module
+
+
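As a hedged sketch of the simplest configuration (no offload, a single CUDA GPU assumed), attaching an `AlignDevicesHook` by hand moves the weights once at attach time and then keeps inputs and outputs aligned:

```python
import torch
from torch import nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

layer = nn.Linear(4, 4)
add_hook_to_module(layer, AlignDevicesHook(execution_device=0, io_same_device=True))
# CPU inputs are moved to GPU 0 for the forward; the output is moved back to
# the input's device because io_same_device=True.
out = layer(torch.randn(2, 4))
assert out.device.type == "cpu"
```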
+ def attach_execution_device_hook(
+     module: torch.nn.Module,
+     execution_device: Union[int, str, torch.device],
+     skip_keys: Optional[Union[str, list[str]]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+     tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+ ):
+     """
+     Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
+     execution device.
+
+     Args:
+         module (`torch.nn.Module`):
+             The module where we want to attach the hooks.
+         execution_device (`int`, `str` or `torch.device`):
+             The device on which inputs and model weights should be placed before the forward pass.
+         skip_keys (`str` or `List[str]`, *optional*):
+             A list of keys to ignore when moving inputs or outputs between devices.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+         tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+             A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+             device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+             instead of duplicating memory.
+     """
+     if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
+         add_hook_to_module(
+             module,
+             AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
+         )
+
+     # Break the recursion if we get to a preload module.
+     if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
+         return
+
+     for child in module.children():
+         attach_execution_device_hook(
+             child,
+             execution_device,
+             skip_keys=skip_keys,
+             preload_module_classes=preload_module_classes,
+             tied_params_map=tied_params_map,
+         )
+
+
+ def attach_align_device_hook(
+     module: torch.nn.Module,
+     execution_device: Optional[torch.device] = None,
+     offload: bool = False,
+     weights_map: Optional[Mapping] = None,
+     offload_buffers: bool = False,
+     module_name: str = "",
+     skip_keys: Optional[Union[str, list[str]]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+     tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+ ):
+     """
+     Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
+     buffers.
+
+     Args:
+         module (`torch.nn.Module`):
+             The module where we want to attach the hooks.
+         execution_device (`torch.device`, *optional*):
+             The device on which inputs and model weights should be placed before the forward pass.
+         offload (`bool`, *optional*, defaults to `False`):
+             Whether or not the weights should be offloaded after the forward pass.
+         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
+             When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to include the associated module's buffers when offloading.
+         module_name (`str`, *optional*, defaults to `""`):
+             The name of the module.
+         skip_keys (`str` or `List[str]`, *optional*):
+             A list of keys to ignore when moving inputs or outputs between devices.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+         tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+             A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+             device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+             instead of duplicating memory.
+     """
+     # Attach the hook on this module if it has any direct tensor.
+     directs = named_module_tensors(module)
+     full_offload = (
+         offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
+     )
+
+     if len(list(directs)) > 0 or full_offload:
+         if weights_map is not None:
+             prefix = f"{module_name}." if len(module_name) > 0 else ""
+             prefixed_weights_map = PrefixedDataset(weights_map, prefix)
+         else:
+             prefixed_weights_map = None
+         hook = AlignDevicesHook(
+             execution_device=execution_device,
+             offload=offload,
+             weights_map=prefixed_weights_map,
+             offload_buffers=offload_buffers,
+             place_submodules=full_offload,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+         add_hook_to_module(module, hook, append=True)
+
+     # We stop the recursion in case we hit the full offload.
+     if full_offload:
+         return
+
+     # Recurse on all children of the module.
+     for child_name, child in module.named_children():
+         child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
+         attach_align_device_hook(
+             child,
+             execution_device=execution_device,
+             offload=offload,
+             weights_map=weights_map,
+             offload_buffers=offload_buffers,
+             module_name=child_name,
+             preload_module_classes=preload_module_classes,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+
+
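For the offload path, a hedged sketch (a single CUDA GPU assumed): once attached with `offload=True` and no explicit `weights_map`, the hook snapshots the parameters to CPU, leaves the live parameters on the meta device, and streams them to the execution device on every forward:

```python
import torch
from torch import nn
from accelerate.hooks import attach_align_device_hook

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
attach_align_device_hook(model, execution_device=torch.device("cuda:0"), offload=True)
# Each submodule's weights now live on the meta device; for every forward they
# are loaded from the CPU weights_map onto cuda:0, then released afterwards.
out = model(torch.randn(1, 8))  # inputs are moved to cuda:0 automatically
```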
+ def remove_hook_from_submodules(module: nn.Module):
+     """
+     Recursively removes all hooks attached on the submodules of a given model.
+
+     Args:
+         module (`torch.nn.Module`): The module on which to remove all hooks.
+     """
+     remove_hook_from_module(module)
+     for child in module.children():
+         remove_hook_from_submodules(child)
+
+
+ def attach_align_device_hook_on_blocks(
+     module: nn.Module,
+     execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
+     offload: Union[bool, dict[str, bool]] = False,
+     weights_map: Mapping = None,
+     offload_buffers: bool = False,
+     module_name: str = "",
+     skip_keys: Optional[Union[str, list[str]]] = None,
+     preload_module_classes: Optional[list[str]] = None,
+     tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+ ):
+     """
+     Attaches `AlignDevicesHook` to all blocks of a given model as needed.
+
+     Args:
+         module (`torch.nn.Module`):
+             The module where we want to attach the hooks.
+         execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
+             The device on which inputs and model weights should be placed before the forward pass. It can be one device
+             for the whole module, or a dictionary mapping module name to device.
+         offload (`bool`, *optional*, defaults to `False`):
+             Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
+             module, or a dictionary mapping module name to boolean.
+         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
+             When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
+         offload_buffers (`bool`, *optional*, defaults to `False`):
+             Whether or not to include the associated module's buffers when offloading.
+         module_name (`str`, *optional*, defaults to `""`):
+             The name of the module.
+         skip_keys (`str` or `List[str]`, *optional*):
+             A list of keys to ignore when moving inputs or outputs between devices.
+         preload_module_classes (`List[str]`, *optional*):
+             A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+             of the forward. This should only be used for classes that have submodules which are registered but not
+             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+         tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+             A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+             device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+             instead of duplicating memory.
+     """
+     # If one device and one offload, we've got one hook.
+     if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
+         if not offload:
+             hook = AlignDevicesHook(
+                 execution_device=execution_device,
+                 io_same_device=True,
+                 skip_keys=skip_keys,
+                 place_submodules=True,
+                 tied_params_map=tied_params_map,
+             )
+             add_hook_to_module(module, hook)
+         else:
+             attach_align_device_hook(
+                 module,
+                 execution_device=execution_device,
+                 offload=True,
+                 weights_map=weights_map,
+                 offload_buffers=offload_buffers,
+                 module_name=module_name,
+                 skip_keys=skip_keys,
+                 tied_params_map=tied_params_map,
+             )
+         return
+
+     if not isinstance(execution_device, Mapping):
+         execution_device = {key: execution_device for key in offload.keys()}
+     if not isinstance(offload, Mapping):
+         offload = {key: offload for key in execution_device.keys()}
+
+     if module_name in execution_device and module_name in offload and not offload[module_name]:
+         hook = AlignDevicesHook(
+             execution_device=execution_device[module_name],
+             offload_buffers=offload_buffers,
+             io_same_device=(module_name == ""),
+             place_submodules=True,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+         add_hook_to_module(module, hook)
+         attach_execution_device_hook(
+             module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
+         )
+     elif module_name in execution_device and module_name in offload:
+         attach_align_device_hook(
+             module,
+             execution_device=execution_device[module_name],
+             offload=True,
+             weights_map=weights_map,
+             offload_buffers=offload_buffers,
+             module_name=module_name,
+             skip_keys=skip_keys,
+             preload_module_classes=preload_module_classes,
+             tied_params_map=tied_params_map,
+         )
+         if not hasattr(module, "_hf_hook"):
+             hook = AlignDevicesHook(
+                 execution_device=execution_device[module_name],
+                 io_same_device=(module_name == ""),
+                 skip_keys=skip_keys,
+                 tied_params_map=tied_params_map,
+             )
+             add_hook_to_module(module, hook)
+         attach_execution_device_hook(
+             module,
+             execution_device[module_name],
+             preload_module_classes=preload_module_classes,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+     elif module_name == "":
+         hook = AlignDevicesHook(
+             execution_device=execution_device.get(""),
+             io_same_device=True,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+         add_hook_to_module(module, hook)
+
+     for child_name, child in module.named_children():
+         child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
+         attach_align_device_hook_on_blocks(
+             child,
+             execution_device=execution_device,
+             offload=offload,
+             weights_map=weights_map,
+             offload_buffers=offload_buffers,
+             module_name=child_name,
+             preload_module_classes=preload_module_classes,
+             skip_keys=skip_keys,
+             tied_params_map=tied_params_map,
+         )
+
+
+ class CpuOffload(ModelHook):
+     """
+     Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
+     the forward; the user needs to call the `init_hook` method again for this.
+
+     Args:
+         execution_device(`str`, `int` or `torch.device`, *optional*):
+             The device on which the model should be executed. Will default to the MPS device if it's available, then
+             GPU 0 if there is a GPU, and finally to the CPU.
+         prev_module_hook (`UserCpuOffloadHook`, *optional*):
+             The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
+             passed, its offload method will be called just before the forward of the model to which this hook is
+             attached.
+     """
+
+     def __init__(
+         self,
+         execution_device: Optional[Union[str, int, torch.device]] = None,
+         prev_module_hook: Optional["UserCpuOffloadHook"] = None,
+     ):
+         self.prev_module_hook = prev_module_hook
+
+         self.execution_device = execution_device if execution_device is not None else PartialState().default_device
+
+     def init_hook(self, module):
+         return module.to("cpu")
+
+     def pre_forward(self, module, *args, **kwargs):
+         if self.prev_module_hook is not None:
+             self.prev_module_hook.offload()
+             clear_device_cache()
+         module.to(self.execution_device)
+         return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
+
+
+ class UserCpuOffloadHook:
+     """
+     A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook or
+     remove it entirely.
+     """
+
+     def __init__(self, model, hook):
+         self.model = model
+         self.hook = hook
+
+     def offload(self):
+         self.hook.init_hook(self.model)
+
+     def remove(self):
+         remove_hook_from_module(self.model)
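These two classes are what the [`cpu_offload_with_hook`] helper (defined in `big_modeling.py`) wires together. A hedged pipeline sketch, with `model_1` and `model_2` standing in for user-provided models:

```python
from accelerate import cpu_offload_with_hook

# model_1 and model_2 are placeholders for any torch.nn.Module instances.
model_1, hook_1 = cpu_offload_with_hook(model_1)
model_2, hook_2 = cpu_offload_with_hook(model_2, prev_module_hook=hook_1)

# Running model_1 moves it onto the execution device; the first forward of
# model_2 calls hook_1.offload(), pushing model_1 back to the CPU first.
```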
venv/Lib/site-packages/accelerate/inference.py ADDED
@@ -0,0 +1,184 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+ from types import MethodType
+ from typing import Any, Optional, Union
+
+ from .state import PartialState
+ from .utils import (
+     calculate_maximum_sizes,
+     convert_bytes,
+     copy_tensor_to_devices,
+     ignorant_find_batch_size,
+     infer_auto_device_map,
+     is_pippy_available,
+     pad_input_tensors,
+     send_to_device,
+ )
+
+
+ def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
+     """
+     Calculates the device map for `model` with an offset for PiPPy.
+     """
+     if num_processes == 1:
+         return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
+     if max_memory is None:
+         model_size, shared = calculate_maximum_sizes(model)
+
+         # Split into `n` chunks for each GPU
+         memory = (model_size + shared[0]) / num_processes
+         memory = convert_bytes(memory)
+         value, ending = memory.split(" ")
+
+         # Add a chunk to deal with potential extra shared memory instances
+         memory = math.ceil(float(value)) * 1.1
+         memory = f"{memory} {ending}"
+         max_memory = {i: memory for i in range(num_processes)}
+     device_map = infer_auto_device_map(
+         model,
+         max_memory=max_memory,
+         no_split_module_classes=no_split_module_classes,
+         clean_result=False,
+     )
+     return device_map
+
+
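A hedged illustration of the even split this produces for a toy model (the exact mapping depends on the layer sizes; `generate_device_map` is internal, not part of the public API):

```python
from torch import nn
from accelerate.inference import generate_device_map

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 2))
device_map = generate_device_map(model, num_processes=2)
# e.g. {"0": 0, "1": 1, "2": 1}: module names mapped to process/rank indices
```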
+ def find_pippy_batch_size(args, kwargs):
+     found_batch_size = None
+     if args is not None:
+         for arg in args:
+             found_batch_size = ignorant_find_batch_size(arg)
+             if found_batch_size is not None:
+                 break
+     if kwargs is not None and found_batch_size is None:
+         for kwarg in kwargs.values():
+             found_batch_size = ignorant_find_batch_size(kwarg)
+             if found_batch_size is not None:
+                 break
+     return found_batch_size
+
+
+ def build_pipeline(model, split_points, args, kwargs, num_chunks):
+     """
+     Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
+     in the needed `args` and `kwargs` the model expects, placed on the CPU.
+
+     Users can pass in a custom `num_chunks` as an optional hyper-parameter. By default it will use
+     `AcceleratorState.num_processes`.
+     """
+     # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+     from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
+
+     # We need to annotate the split points in the model for PiPPy
+     state = PartialState()
+     split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
+     pipe = pipeline(
+         model,
+         mb_args=args,
+         mb_kwargs=kwargs,
+         split_spec=split_spec,
+     )
+     stage = pipe.build_stage(state.local_process_index, device=state.device)
+     schedule = ScheduleGPipe(stage, num_chunks)
+
+     return schedule
+
+
+ def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
+     state = PartialState()
+     output = None
+
+     if state.num_processes == 1:
+         output = forward(*args, **kwargs)
+     elif state.is_local_main_process:
+         found_batch_size = find_pippy_batch_size(args, kwargs)
+         if found_batch_size is None:
+             raise ValueError("Could not find batch size from args or kwargs")
+         else:
+             if found_batch_size != num_chunks:
+                 args = pad_input_tensors(args, found_batch_size, num_chunks)
+                 kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
+             forward(*args, **kwargs)
+     elif state.is_last_process:
+         output = forward()
+     else:
+         forward()
+     if gather_output:
+         # Each node will get a copy of the full output which is only on the last GPU
+         output = copy_tensor_to_devices(output)
+     return output
+
+
+ def prepare_pippy(
+     model,
+     split_points: Optional[Union[str, list[str]]] = "auto",
+     no_split_module_classes: Optional[list[str]] = None,
+     example_args: Optional[tuple[Any]] = (),
+     example_kwargs: Optional[dict[str, Any]] = None,
+     num_chunks: Optional[int] = None,
+     gather_output: Optional[bool] = False,
+ ):
+     """
+     Wraps `model` for pipeline parallel inference.
+
+     Args:
+         model (`torch.nn.Module`):
+             A model we want to split for pipeline-parallel inference.
+         split_points (`str` or `List[str]`, defaults to 'auto'):
+             How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
+             split given any model. Should be a list of layer names in the model to split by otherwise.
+         no_split_module_classes (`List[str]`):
+             A list of class names for layers we don't want to be split.
+         example_args (tuple of model inputs):
+             The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
+             this method if possible.
+         example_kwargs (dict of model inputs):
+             The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
+             *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
+             recommended unless the prior condition is true for all cases.
+         num_chunks (`int`, defaults to the number of available GPUs):
+             The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
+             this can be tuned and played with. In general one should have num_chunks >= num_gpus.
+         gather_output (`bool`, defaults to `False`):
+             If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
+     """
+     if not is_pippy_available():
+         raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
+     state = PartialState()
+     example_args = send_to_device(example_args, "cpu")
+     example_kwargs = send_to_device(example_kwargs, "cpu")
+     if num_chunks is None:
+         num_chunks = state.num_processes
+     if split_points == "auto":
+         device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
+         split_points = []
+         for i in range(1, num_chunks):
+             split_points.append(next(k for k, v in device_map.items() if v == i))
+     model.hf_split_points = split_points
+     stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
+     model._original_forward = model.forward
+     model._original_call = model.__call__
+     model.pippy_stage = stage
+     model.hf_split_points = split_points
+
+     def forward(*args, **kwargs):
+         return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
+
+     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
+     # Note: creates an infinite recursion loop with `generate`
+     model_forward = MethodType(forward, model)
+     forward.__wrapped__ = model_forward
+     model.forward = forward
+     return model
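A hedged end-to-end sketch, meant to be run under `accelerate launch` on two or more GPUs; `MyModel` and the input shape are placeholders for a real model and its inputs:

```python
import torch
from accelerate import prepare_pippy

model = MyModel()  # placeholder: any torch.nn.Module
example_input = torch.randn(2, 128)
model = prepare_pippy(model, example_args=(example_input,), gather_output=True)
with torch.no_grad():
    output = model(example_input)  # every rank receives the output with gather_output=True
```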
venv/Lib/site-packages/accelerate/launchers.py ADDED
@@ -0,0 +1,301 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import sys
+ import tempfile
+
+ import torch
+
+ from .state import AcceleratorState, PartialState
+ from .utils import (
+     PrecisionType,
+     PrepareForLaunch,
+     are_libraries_initialized,
+     check_cuda_p2p_ib_support,
+     get_gpu_info,
+     is_mps_available,
+     is_torch_version,
+     patch_environment,
+ )
+ from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION
+
+
+ def test_launch():
+     "Verify a `PartialState` can be initialized."
+     _ = PartialState()
+
+
+ def notebook_launcher(
+     function,
+     args=(),
+     num_processes=None,
+     mixed_precision="no",
+     use_port="29500",
+     master_addr="127.0.0.1",
+     node_rank=0,
+     num_nodes=1,
+     rdzv_backend="static",
+     rdzv_endpoint="",
+     rdzv_conf=None,
+     rdzv_id="none",
+     max_restarts=0,
+     monitor_interval=0.1,
+     log_line_prefix_template=None,
+ ):
+     """
+     Launches a training function, using several processes or multiple nodes if it's possible in the current environment
+     (TPU with multiple cores for instance).
+
+     <Tip warning={true}>
+
+     To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
+     any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.
+
+     Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
+     of those calls have been made.
+
+     </Tip>
+
+     Args:
+         function (`Callable`):
+             The training function to execute. If it accepts arguments, the first argument should be the index of the
+             process run.
+         args (`Tuple`):
+             Tuple of arguments to pass to the function (it will receive `*args`).
+         num_processes (`int`, *optional*):
+             The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
+             the number of GPUs available otherwise.
+         mixed_precision (`str`, *optional*, defaults to `"no"`):
+             If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
+         use_port (`str`, *optional*, defaults to `"29500"`):
+             The port to use to communicate between processes when launching a multi-GPU training.
+         master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
+             The address to use for communication between processes.
+         node_rank (`int`, *optional*, defaults to 0):
+             The rank of the current node.
+         num_nodes (`int`, *optional*, defaults to 1):
+             The number of nodes to use for training.
+         rdzv_backend (`str`, *optional*, defaults to `"static"`):
+             The rendezvous method to use, such as 'static' (the default) or 'c10d'.
+         rdzv_endpoint (`str`, *optional*, defaults to `""`):
+             The endpoint of the rdzv sync. storage.
+         rdzv_conf (`Dict`, *optional*, defaults to `None`):
+             Additional rendezvous configuration.
+         rdzv_id (`str`, *optional*, defaults to `"none"`):
+             The unique run id of the job.
+         max_restarts (`int`, *optional*, defaults to 0):
+             The maximum amount of restarts that elastic agent will conduct on workers before failure.
+         monitor_interval (`float`, *optional*, defaults to 0.1):
+             The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
+         log_line_prefix_template (`str`, *optional*, defaults to `None`):
+             The prefix template for elastic launch logging. Available from PyTorch 2.2.0.
+
+     Example:
+
+     ```python
+     # Assume this is defined in a Jupyter Notebook on an instance with two GPUs
+     from accelerate import notebook_launcher
+
+
+     def train(*args):
+         # Your training function here
+         ...
+
+
+     notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
+     ```
+     """
+     # Are we in a google colab or a Kaggle Kernel?
+     in_colab = False
+     in_kaggle = False
+     if any(key.startswith("KAGGLE") for key in os.environ.keys()):
+         in_kaggle = True
+     elif "IPython" in sys.modules:
+         in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())
+
+     try:
+         mixed_precision = PrecisionType(mixed_precision.lower())
+     except ValueError:
+         raise ValueError(
+             f"Unknown mixed_precision mode: {mixed_precision.lower()}. Choose between {PrecisionType.list()}."
+         )
+
+     if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
+         # TPU launch
+         import torch_xla.distributed.xla_multiprocessing as xmp
+         from torch_xla import device_count
+
+         if len(AcceleratorState._shared_state) > 0:
+             raise ValueError(
+                 "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
+                 "your training function. Restart your notebook and make sure no cells initializes an "
+                 "`Accelerator`."
+             )
+
+         launcher = PrepareForLaunch(function, distributed_type="XLA")
+         print(f"Launching a training on {device_count()} TPU cores.")
+         xmp.spawn(launcher, args=args, start_method="fork")
+     elif in_colab and get_gpu_info()[1] < 2:
+         # No need for a distributed launch otherwise as it's either CPU or one GPU.
+         if torch.cuda.is_available():
+             print("Launching training on one GPU.")
+         else:
+             print("Launching training on one CPU.")
+         function(*args)
+     else:
+         if num_processes is None:
+             raise ValueError(
+                 "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
+             )
+         if node_rank >= num_nodes:
+             raise ValueError("The node_rank must be less than the number of nodes.")
+         if num_processes > 1:
+             # Multi-GPU launch
+             from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+             from torch.multiprocessing import start_processes
+             from torch.multiprocessing.spawn import ProcessRaisedException
+
+             if len(AcceleratorState._shared_state) > 0:
+                 raise ValueError(
+                     "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
+                     "inside your training function. Restart your notebook and make sure no cells initializes an "
+                     "`Accelerator`."
+                 )
+             # Check for specific libraries known to initialize CUDA that users constantly use
+             problematic_imports = are_libraries_initialized("bitsandbytes")
+             if len(problematic_imports) > 0:
+                 err = (
+                     "Could not start distributed process. Libraries known to initialize CUDA upon import have been "
+                     "imported already. Please keep these imports inside your training function to try and help with this:"
+                 )
+                 for lib_name in problematic_imports:
+                     err += f"\n\t* `{lib_name}`"
+                 raise RuntimeError(err)
+
+             patched_env = dict(
+                 nproc=num_processes,
+                 node_rank=node_rank,
+                 world_size=num_nodes * num_processes,
+                 master_addr=master_addr,
+                 master_port=use_port,
+                 mixed_precision=mixed_precision,
+             )
+
+             # Check for CUDA P2P and IB issues
+             if not check_cuda_p2p_ib_support():
+                 patched_env["nccl_p2p_disable"] = "1"
+                 patched_env["nccl_ib_disable"] = "1"
+
+             # torch.distributed will expect a few environment variables to be here. We set the ones common to each
+             # process here (the other ones will be set by the launcher).
+             with patch_environment(**patched_env):
+                 # First dummy launch
+                 if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
+                     launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
+                     try:
+                         start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
+                     except ProcessRaisedException as e:
+                         err = "An issue was found when verifying a stable environment for the notebook launcher."
+                         if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
+                             raise RuntimeError(
+                                 f"{err} "
+                                 "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
+                                 "Please review your imports and test them when running the `notebook_launcher()` to identify "
+                                 "which one is problematic and causing CUDA to be initialized."
+                             ) from e
+                         else:
+                             raise RuntimeError(f"{err} The following error was raised: {e}") from e
+                 # Now the actual launch
+                 launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
+                 print(f"Launching training on {num_processes} GPUs.")
+                 try:
+                     if rdzv_conf is None:
+                         rdzv_conf = {}
+                     if rdzv_backend == "static":
+                         rdzv_conf["rank"] = node_rank
+                         if not rdzv_endpoint:
+                             rdzv_endpoint = f"{master_addr}:{use_port}"
+                     launch_config_kwargs = dict(
+                         min_nodes=num_nodes,
+                         max_nodes=num_nodes,
+                         nproc_per_node=num_processes,
+                         run_id=rdzv_id,
+                         rdzv_endpoint=rdzv_endpoint,
+                         rdzv_backend=rdzv_backend,
+                         rdzv_configs=rdzv_conf,
+                         max_restarts=max_restarts,
+                         monitor_interval=monitor_interval,
+                         start_method="fork",
+                     )
+                     if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
+                         launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
+                     elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
+                 except ProcessRaisedException as e:
+                     if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
+                         raise RuntimeError(
+                             "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
+                             "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
+                             "Please review your imports and test them when running the `notebook_launcher()` to identify "
+                             "which one is problematic and causing CUDA to be initialized."
+                         ) from e
+                     else:
+                         raise RuntimeError(f"An issue was found when launching the training: {e}") from e
+
+         else:
+             # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
+             if is_mps_available():
+                 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+                 print("Launching training on MPS.")
+             elif torch.cuda.is_available():
+                 print("Launching training on one GPU.")
+             else:
+                 print("Launching training on CPU.")
+             function(*args)
+
+
+ def debug_launcher(function, args=(), num_processes=2):
+     """
+     Launches a training function using several processes on CPU for debugging purposes.
+
+     <Tip warning={true}>
+
+     This function is provided for internal testing and debugging, but it's not intended for real trainings. It will
+     only use the CPU.
+
+     </Tip>
+
+     Args:
+         function (`Callable`):
+             The training function to execute.
+         args (`Tuple`):
+             Tuple of arguments to pass to the function (it will receive `*args`).
+         num_processes (`int`, *optional*, defaults to 2):
+             The number of processes to use for training.
+     """
+     from torch.multiprocessing import start_processes
+
+     with tempfile.NamedTemporaryFile() as tmp_file:
+         # torch.distributed will expect a few environment variables to be here. We set the ones common to each
+         # process here (the other ones will be set by the launcher).
+         with patch_environment(
+             world_size=num_processes,
+             master_addr="127.0.0.1",
+             master_port="29500",
+             accelerate_mixed_precision="no",
+             accelerate_debug_rdv_file=tmp_file.name,
+             accelerate_use_cpu="yes",
+         ):
+             launcher = PrepareForLaunch(function, debug=True)
+             start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
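A minimal sketch of using `debug_launcher` to smoke-test a distributed code path on CPU only:

```python
from accelerate import Accelerator, debug_launcher


def check():
    # Runs once per spawned CPU process.
    accelerator = Accelerator()
    print(f"rank {accelerator.process_index} of {accelerator.num_processes}")


debug_launcher(check, num_processes=2)
```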
venv/Lib/site-packages/accelerate/local_sgd.py ADDED
@@ -0,0 +1,106 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+
+ from accelerate import Accelerator, DistributedType
+
+
+ class LocalSGD:
+     """
+     A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
+     on each device, and averages model weights every K steps.
+
+     It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
+     this is a simple implementation that cannot support scenarios such as model parallelism.
+
+     Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
+     back to at least:
+
+     Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
+     arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)
+
+     We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of):
+
+     Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
+     Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
+     """
+
+     def __enter__(self):
+         if self.enabled:
+             self.model_sync_obj = self.model.no_sync()
+             self.model_sync_obj.__enter__()
+
+         return self
+
+     def __exit__(self, type, value, tb):
+         if self.enabled:
+             # Average all models on exit
+             self._sync_and_avg_model_params()
+             self.model_sync_obj.__exit__(type, value, tb)
+
+     def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
+         """
+         Constructor.
+
+         Args:
+             model (`torch.nn.Module`):
+                 The model whose parameters we need to average.
+             accelerator (`Accelerator`):
+                 Accelerator object.
+             local_sgd_steps (`int`):
+                 A number of local SGD steps (before model parameters are synchronized).
+             enabled (`bool`):
+                 Local SGD is disabled if this parameter is set to `False`.
+         """
+         if accelerator.distributed_type not in [
+             DistributedType.NO,
+             DistributedType.MULTI_CPU,
+             DistributedType.MULTI_GPU,
+             DistributedType.MULTI_XPU,
+             DistributedType.MULTI_MLU,
+             DistributedType.MULTI_HPU,
+             DistributedType.MULTI_SDAA,
+             DistributedType.MULTI_MUSA,
+             DistributedType.MULTI_NPU,
+         ]:
+             raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
+         self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
+         self.num_steps = 0
+         if self.enabled:
+             self.accelerator = accelerator
+             self.model = model
+             self.local_sgd_steps = local_sgd_steps
+
+     def step(self):
+         """
+         This function makes a "step" and synchronizes model parameters if necessary.
+         """
+         self.num_steps += 1
+         if not self.enabled:
+             return
+
+         if self.num_steps % self.local_sgd_steps == 0:
+             self._sync_and_avg_model_params()
+
+     def _sync_and_avg_model_params(self):
+         """
+         Synchronize + Average model parameters across all GPUs
+         """
+
+         self.accelerator.wait_for_everyone()
+         with self.accelerator.autocast():
+             for param in self.model.parameters():
+                 param.data = self.accelerator.reduce(param.data, reduction="mean")
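A hedged training-loop sketch (the model, optimizer, dataloader and the batch structure are placeholders), synchronizing parameters every 8 steps:

```python
import torch
from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in dataloader:
        inputs, targets = batch  # placeholder batch structure
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        local_sgd.step()  # every 8th call averages parameters across ranks
```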
venv/Lib/site-packages/accelerate/logging.py ADDED
@@ -0,0 +1,125 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import logging
+ import os
+
+ from .state import PartialState
+
+
+ class MultiProcessAdapter(logging.LoggerAdapter):
+     """
+     An adapter to assist with logging in multiprocess.
+
+     `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
+     or only the main executed one. Default is `main_process_only=True`.
+
+     Does not require an `Accelerator` object to be created first.
+     """
+
+     @staticmethod
+     def _should_log(main_process_only):
+         "Check if log should be performed"
+         state = PartialState()
+         return not main_process_only or (main_process_only and state.is_main_process)
+
+     def log(self, level, msg, *args, **kwargs):
+         """
+         Delegates logger call after checking if we should log.
+
+         Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
+         or only the main executed one. Default is `True` if not passed.
+
+         Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
+         read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
+         break with the previous behavior.
+
+         `in_order` is ignored if `main_process_only` is passed.
+         """
+         if PartialState._shared_state == {}:
+             raise RuntimeError(
+                 "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
+             )
+         main_process_only = kwargs.pop("main_process_only", True)
+         in_order = kwargs.pop("in_order", False)
+         # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
+         kwargs.setdefault("stacklevel", 2)
+
+         if self.isEnabledFor(level):
+             if self._should_log(main_process_only):
+                 msg, kwargs = self.process(msg, kwargs)
+                 self.logger.log(level, msg, *args, **kwargs)
+
+             elif in_order:
+                 state = PartialState()
+                 for i in range(state.num_processes):
+                     if i == state.process_index:
+                         msg, kwargs = self.process(msg, kwargs)
+                         self.logger.log(level, msg, *args, **kwargs)
+                     state.wait_for_everyone()
+
+     @functools.lru_cache(None)
+     def warning_once(self, *args, **kwargs):
+         """
+         This method is identical to `logger.warning()`, but will emit the warning with the same message only once.
+
+         Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
+         cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
+         switch to another type of cache that includes the caller frame information in the hashing function.
+         """
+         self.warning(*args, **kwargs)
+
+
+ def get_logger(name: str, log_level: str = None):
+     """
+     Returns a `logging.Logger` for `name` that can handle multiprocessing.
+
+     If a log should be called on all processes, pass `main_process_only=False`. If a log should be called on all
+     processes and in order, also pass `in_order=True`.
+
+     Args:
+         name (`str`):
+             The name for the logger, such as `__file__`
+         log_level (`str`, *optional*):
+             The log level to use. If not passed, will default to the `ACCELERATE_LOG_LEVEL` environment variable, or
+             `INFO` if not set.
+
+     Example:
+
+     ```python
+     >>> from accelerate.logging import get_logger
+     >>> from accelerate import Accelerator
+
+     >>> logger = get_logger(__name__)
+
+     >>> accelerator = Accelerator()
+     >>> logger.info("My log", main_process_only=False)
+     >>> logger.debug("My log", main_process_only=True)
+
+     >>> logger = get_logger(__name__, log_level="DEBUG")
+     >>> logger.info("My log")
+     >>> logger.debug("My second log")
+
+     >>> array = ["a", "b", "c", "d"]
+     >>> letter_at_rank = array[accelerator.process_index]
+     >>> logger.info(letter_at_rank, in_order=True)
+     ```
+     """
+     if log_level is None:
+         log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
+     logger = logging.getLogger(name)
+     if log_level is not None:
+         logger.setLevel(log_level.upper())
+         logger.root.setLevel(log_level.upper())
+     return MultiProcessAdapter(logger, {})
venv/Lib/site-packages/accelerate/memory_utils.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+
17
+
18
+ warnings.warn(
19
+ "memory_utils has been reorganized to utils.memory. Import `find_executable_batchsize` from the main `__init__`: "
20
+ "`from accelerate import find_executable_batch_size` to avoid this warning.",
21
+ FutureWarning,
22
+ )
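Editor's note: the deprecation shim above points users at the relocated helper. For context, a hedged sketch of the recommended import in action; the body of `train` is a placeholder, and the retry-with-halved-batch-size behavior is the helper's documented contract.

```python
# Sketch of the import path the warning above recommends. On a CUDA
# out-of-memory error, the decorator retries the wrapped function with the
# batch size halved, until it either succeeds or reaches zero.
from accelerate import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=64)
def train(batch_size):
    # Placeholder body: build the dataloader/model using `batch_size` here.
    # If this raises OOM, the decorator re-invokes train() with batch_size // 2.
    print(f"trying batch_size={batch_size}")


train()  # called with no arguments; the decorator injects `batch_size`
```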
venv/Lib/site-packages/accelerate/optimizer.py ADDED
@@ -0,0 +1,212 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+
+ import torch
+
+ from .state import AcceleratorState, GradientState
+ from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+
+ def move_to_device(state, device):
+     if isinstance(state, (list, tuple)):
+         return honor_type(state, (move_to_device(t, device) for t in state))
+     elif isinstance(state, dict):
+         return type(state)({k: move_to_device(v, device) for k, v in state.items()})
+     elif isinstance(state, torch.Tensor):
+         return state.to(device)
+     return state
+
+
+ class AcceleratedOptimizer(torch.optim.Optimizer):
+     """
+     Internal wrapper around a torch optimizer.
+
+     Will conditionally perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
+     accumulation.
+
+     Args:
+         optimizer (`torch.optim.optimizer.Optimizer`):
+             The optimizer to wrap.
+         device_placement (`bool`, *optional*, defaults to `True`):
+             Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
+             `optimizer` on the right device.
+         scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
+             The scaler to use in the step function if training with mixed precision.
+     """
+
+     def __init__(self, optimizer, device_placement=True, scaler=None):
+         self.optimizer = optimizer
+         self.scaler = scaler
+         self.accelerator_state = AcceleratorState()
+         self.gradient_state = GradientState()
+         self.device_placement = device_placement
+         self._is_overflow = False
+
+         if self.scaler is not None:
+             self._accelerate_step_called = False
+             self._optimizer_original_step_method = self.optimizer.step
+             self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+         # Handle device placement
+         if device_placement:
+             state_dict = self.optimizer.state_dict()
+             if self.accelerator_state.distributed_type == DistributedType.XLA:
+                 xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+             else:
+                 state_dict = move_to_device(state_dict, self.accelerator_state.device)
+             self.optimizer.load_state_dict(state_dict)
+
+     @property
+     def state(self):
+         return self.optimizer.state
+
+     @state.setter
+     def state(self, state):
+         self.optimizer.state = state
+
+     @property
+     def param_groups(self):
+         return self.optimizer.param_groups
+
+     @param_groups.setter
+     def param_groups(self, param_groups):
+         self.optimizer.param_groups = param_groups
+
+     @property
+     def defaults(self):
+         return self.optimizer.defaults
+
+     @defaults.setter
+     def defaults(self, defaults):
+         self.optimizer.defaults = defaults
+
+     def add_param_group(self, param_group):
+         self.optimizer.add_param_group(param_group)
+
+     def load_state_dict(self, state_dict):
+         if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
+             xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+         self.optimizer.load_state_dict(state_dict)
+
+     def state_dict(self):
+         return self.optimizer.state_dict()
+
+     def zero_grad(self, set_to_none=None):
+         if self.gradient_state.sync_gradients:
+             accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
+             if accept_arg:
+                 if set_to_none is None:
+                     set_to_none = True
+                 self.optimizer.zero_grad(set_to_none=set_to_none)
+             else:
+                 if set_to_none is not None:
+                     raise ValueError("`set_to_none` for `Optimizer.zero_grad` is not supported by this optimizer.")
+                 self.optimizer.zero_grad()
+
+     def train(self):
+         """
+         Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`.
+         """
+         if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+             self.optimizer.train()
+         elif (
+             hasattr(self.optimizer, "optimizer")
+             and hasattr(self.optimizer.optimizer, "train")
+             and callable(self.optimizer.optimizer.train)
+         ):
+             # the deepspeed optimizer further wraps the optimizer
+             self.optimizer.optimizer.train()
+
+     def eval(self):
+         """
+         Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`.
+         """
+         if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+             self.optimizer.eval()
+
+     def step(self, closure=None):
+         if is_lomo_available():
+             from lomo_optim import AdaLomo, Lomo
+
+         if (
+             not self.gradient_state.is_xla_gradients_synced
+             and self.accelerator_state.distributed_type == DistributedType.XLA
+         ):
+             gradients = xm._fetch_gradients(self.optimizer)
+             xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
+             self.gradient_state.is_xla_gradients_synced = True
+
+         if is_lomo_available():
+             # `step` should be a no-op for LOMO optimizers.
+             if isinstance(self.optimizer, (Lomo, AdaLomo)):
+                 return
+
+         if self.gradient_state.sync_gradients:
+             if self.scaler is not None:
+                 self.optimizer.step = self._optimizer_patched_step_method
+
+                 self.scaler.step(self.optimizer, closure)
+                 self.scaler.update()
+
+                 if not self._accelerate_step_called:
+                     # If the optimizer step was skipped, gradient overflow was detected.
+                     self._is_overflow = True
+                 else:
+                     self._is_overflow = False
+                 # Reset the step method to the original one
+                 self.optimizer.step = self._optimizer_original_step_method
+                 # Reset the indicator
+                 self._accelerate_step_called = False
+             else:
+                 self.optimizer.step(closure)
+             if self.accelerator_state.distributed_type == DistributedType.XLA:
+                 self.gradient_state.is_xla_gradients_synced = False
+
+     def _switch_parameters(self, parameters_map):
+         for param_group in self.optimizer.param_groups:
+             param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
+
+     @property
+     def step_was_skipped(self):
+         """Whether or not the optimizer step was skipped."""
+         return self._is_overflow
+
+     def __getstate__(self):
+         _ignored_keys = [
+             "_accelerate_step_called",
+             "_optimizer_original_step_method",
+             "_optimizer_patched_step_method",
+         ]
+         return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+         if self.scaler is not None:
+             self._accelerate_step_called = False
+             self._optimizer_original_step_method = self.optimizer.step
+             self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+
+ def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
+     def patched_step(*args, **kwargs):
+         accelerated_optimizer._accelerate_step_called = True
+         return method(*args, **kwargs)
+
+     return patched_step
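Editor's note: a hedged sketch of how the scaler patching above plays out in a training loop. `mixed_precision="fp16"` assumes a CUDA device; with `"no"` the code still runs, but no scaler is created and `step_was_skipped` stays `False`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # creates a GradScaler on CUDA
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)  # scales the loss when a scaler is active
    optimizer.step()  # routed through `scaler.step` via the patched method above
    if optimizer.step_was_skipped:
        # `_accelerate_step_called` stayed False: the scaler saw inf/nan grads.
        print("optimizer step skipped due to gradient overflow")
```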
venv/Lib/site-packages/accelerate/scheduler.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
+
+ import warnings
+
+ from .state import AcceleratorState, GradientState
+
+
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
+
+
+ class AcceleratedScheduler:
+     """
+     A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
+     to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
+     precision training).
+
+     When performing gradient accumulation, scheduler lengths should not be changed accordingly; Accelerate will
+     always step the scheduler to account for it.
+
+     Args:
+         scheduler (`torch.optim.lr_scheduler._LRScheduler`):
+             The scheduler to wrap.
+         optimizers (one or a list of `torch.optim.Optimizer`):
+             The optimizers used.
+         step_with_optimizer (`bool`, *optional*, defaults to `True`):
+             Whether or not the scheduler should be stepped at each optimizer step.
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether or not the dataloaders split one batch across the different processes (so batch size is the same
+             regardless of the number of processes) or create batches on each process (so batch size is the original
+             batch size multiplied by the number of processes).
+     """
+
+     def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
+         self.scheduler = scheduler
+         self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
+         self.split_batches = split_batches
+         self.step_with_optimizer = step_with_optimizer
+         self.gradient_state = GradientState()
+
+     def step(self, *args, **kwargs):
+         if not self.step_with_optimizer:
+             # No link between scheduler and optimizer -> just step
+             self.scheduler.step(*args, **kwargs)
+             return
+
+         # Otherwise, first make sure the optimizer was stepped.
+         if not self.gradient_state.sync_gradients:
+             if self.gradient_state.adjust_scheduler:
+                 self.scheduler._step_count += 1
+             return
+
+         for opt in self.optimizers:
+             if opt.step_was_skipped:
+                 return
+         if self.split_batches:
+             # Split batches -> the training dataloader batch size is not changed, so do one step per training step
+             self.scheduler.step(*args, **kwargs)
+         else:
+             # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
+             # num_processes steps per training step
+             num_processes = AcceleratorState().num_processes
+             for _ in range(num_processes):
+                 # Special case when using OneCycle and `drop_last` was not used
+                 if hasattr(self.scheduler, "total_steps"):
+                     if self.scheduler._step_count <= self.scheduler.total_steps:
+                         self.scheduler.step(*args, **kwargs)
+                 else:
+                     self.scheduler.step(*args, **kwargs)
+
+     # Passthroughs
+     def get_last_lr(self):
+         return self.scheduler.get_last_lr()
+
+     def state_dict(self):
+         return self.scheduler.state_dict()
+
+     def load_state_dict(self, state_dict):
+         self.scheduler.load_state_dict(state_dict)
+
+     def get_lr(self):
+         return self.scheduler.get_lr()
+
+     def print_lr(self, *args, **kwargs):
+         return self.scheduler.print_lr(*args, **kwargs)
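Editor's note: a hedged sketch of the stepping semantics implemented above, assuming the wrapper is obtained through `Accelerator.prepare` with the default `split_batches=False`.

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

loss = model(torch.randn(1, 2)).sum()
accelerator.backward(loss)
optimizer.step()
# With the default `split_batches=False`, this advances the underlying
# scheduler `num_processes` times (one step per process), and is a no-op
# if the optimizer step was skipped because of a gradient overflow.
scheduler.step()
```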
venv/Lib/site-packages/accelerate/state.py ADDED
@@ -0,0 +1,1330 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import threading
+ import warnings
+ import weakref
+ from contextlib import contextmanager
+ from functools import partial
+ from typing import Any, Callable
+
+ import torch
+
+ from .utils import (
+     DistributedType,
+     DynamoBackend,
+     GradientAccumulationPlugin,
+     check_cuda_fp8_capability,
+     check_cuda_p2p_ib_support,
+     deepspeed_required,
+     get_ccl_version,
+     get_cpu_distributed_information,
+     get_int_from_env,
+     is_ccl_available,
+     is_datasets_available,
+     is_deepspeed_available,
+     is_fp8_available,
+     is_habana_gaudi1,
+     is_hpu_available,
+     is_ipex_available,
+     is_mlu_available,
+     is_mps_available,
+     is_musa_available,
+     is_npu_available,
+     is_sdaa_available,
+     is_torch_xla_available,
+     is_xccl_available,
+     is_xpu_available,
+     parse_choice_from_env,
+     parse_flag_from_env,
+     set_numa_affinity,
+ )
+ from .utils.dataclasses import SageMakerDistributedType
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+ if is_mlu_available(check_device=False):
+     import torch_mlu  # noqa: F401
+
+ if is_sdaa_available(check_device=False):
+     import torch_sdaa  # noqa: F401
+
+ if is_musa_available(check_device=False):
+     import torch_musa  # noqa: F401
+
+ if is_npu_available(check_device=False):
+     import torch_npu  # noqa: F401
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def is_initialized() -> bool:
+     """
+     Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
+     but works as a module method.
+     """
+     return AcceleratorState._shared_state != {}
+
+
+ # Function that does nothing
+ def do_nothing(*args, **kwargs):
+     return None
+
+
+ class ThreadLocalSharedDict(threading.local):
+     """
+     Descriptor that holds a dict shared between instances of a class in the same thread.
+
+     Note: Descriptors have slightly different semantics than just a dict field on its own.
+     `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
+     underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
+     the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
+     object with a dict instead. Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).
+
+     See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html
+
+     This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
+
+     See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
+     """
+
+     def __init__(self, thread_local: bool = False):
+         self._storage = {}
+
+     def __get__(self, obj, objtype=None):
+         return self._storage
+
+     def __set__(self, obj, value):
+         self._storage = value
+
+
+ # Prefer global shared dictionary, except when using TPU.
+ SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
+
+
+ # Inspired by Alex Martelli's 'Borg'.
+ class PartialState:
+     """
+     Singleton class that has information about the current training environment and functions to help with process
+     control. Designed to be used when only process control and device execution states are needed. Does *not* need to
+     be initialized from `Accelerator`.
+
+     Args:
+         cpu (`bool`, *optional*):
+             Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
+             `True` and force the execution on the CPU.
+         kwargs (additional keyword arguments, *optional*):
+             Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
+             found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
+
+     **Available attributes:**
+
+     - **device** (`torch.device`) -- The device to use.
+     - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
+       in use.
+     - **local_process_index** (`int`) -- The index of the current process on the current server.
+     - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
+       of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16', or 'fp8'.)
+     - **num_processes** (`int`) -- The number of processes currently launched in parallel.
+     - **process_index** (`int`) -- The index of the current process.
+     - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
+     - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
+     - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
+     - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
+
+     Example:
+     ```python
+     from accelerate.utils import InitProcessGroupKwargs
+
+     # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
+     kwargs = InitProcessGroupKwargs(...).to_kwargs()
+     state = PartialState(**kwargs)
+     ```
+     """
+
+     _shared_state = SharedDict()
+     _known_attrs = [
+         "_cpu",
+         "_mixed_precision",
+         "_shared_state",
+         "backend",
+         "debug",
+         "device",
+         "distributed_type",
+         "fork_launched",
+         "local_process_index",
+         "num_processes",
+         "process_index",
+     ]
+
+     def __init__(self, cpu: bool = False, **kwargs):
+         self.__dict__ = self._shared_state
+         if not self.initialized:
+             self._cpu = cpu
+             self.backend = None
+             env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
+             self.device = torch.device(env_device) if env_device is not None else None
+             self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
+             use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
+             dist_information = None
+             if use_sagemaker_dp is None:
+                 use_sagemaker_dp = (
+                     os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
+                     and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
+                 )
+
+             # Sets up self.backend + imports
+             original_backend = kwargs.pop("backend", None)
+             backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
+             if original_backend is not None and backend != original_backend:
+                 raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
+             self.backend = backend
+             self.distributed_type = distributed_type
+             use_deepspeed = False
+             if not cpu and self.backend != "xla":
+                 if int(os.environ.get("LOCAL_RANK", -1)) != -1:
+                     # Deal with spawning deepspeed
+                     if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+                         if not is_deepspeed_available():
+                             raise ImportError(
+                                 "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
+                             )
+                         from deepspeed import comm as dist
+
+                         if not dist.is_initialized():
+                             if self.backend == "tccl":
+                                 local_rank = os.environ.get("LOCAL_RANK", -1)
+                                 torch.sdaa.set_device(f"sdaa:{local_rank}")
+                             dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
+                         # We need to flag `use_deepspeed` to be True to override `distributed_type` later
+                         use_deepspeed = True
+                     # Deal with all other backends but XPU and CPU, which get handled specially later
+                     elif (
+                         self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
+                         and not torch.distributed.is_initialized()
+                     ):
+                         if self.backend == "tccl":
+                             local_rank = os.environ.get("LOCAL_RANK", -1)
+                             torch.sdaa.set_device(f"sdaa:{local_rank}")
+                         torch.distributed.init_process_group(backend=self.backend, **kwargs)
+
+             # XPU and CPU require special env configs to be set
+             if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
+                 dist_information = get_cpu_distributed_information()
+                 os.environ["RANK"] = str(dist_information.rank)
+                 os.environ["WORLD_SIZE"] = str(dist_information.world_size)
+                 os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
+                 os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
+                 if not os.environ.get("MASTER_PORT", None):
+                     os.environ["MASTER_PORT"] = "29500"
+                 if (
+                     not os.environ.get("MASTER_ADDR", None)
+                     and dist_information.local_world_size != dist_information.world_size
+                     and self.backend != "mpi"
+                 ):
+                     raise ValueError(
+                         "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
+                         "please try exporting rank 0's hostname as `MASTER_ADDR`"
+                     )
+                 kwargs["rank"] = dist_information.rank
+                 kwargs["world_size"] = dist_information.world_size
+
+                 if (
+                     self.distributed_type == DistributedType.MULTI_CPU
+                     and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
+                 ):
+                     import psutil
+
+                     num_cpu_threads_per_process = int(
+                         psutil.cpu_count(logical=False) / dist_information.local_world_size
+                     )
+                     if num_cpu_threads_per_process == 0:
+                         num_cpu_threads_per_process = 1
+                     torch.set_num_threads(num_cpu_threads_per_process)
+                     warnings.warn(
+                         f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve out-of-the-box"
+                         " performance."
+                     )
+
+                 if not torch.distributed.is_initialized():
+                     torch.distributed.init_process_group(backend=self.backend, **kwargs)
+
+             # No backend == no distributed training
+             if self.backend is None:
+                 self.distributed_type = DistributedType.NO
+                 self.num_processes = 1
+                 self.process_index = 0
+                 self.local_process_index = 0
+             elif self.backend == "xla":
+                 # XLA needs device setting first for `set_replication`
+                 self.set_device()
+                 xm.set_replication(self.device, xm.get_xla_supported_devices())
+                 self.num_processes = xm.xrt_world_size()
+                 self.process_index = xm.get_ordinal()
+                 if is_torch_xla_available(check_is_tpu=True):
+                     self.local_process_index = xm.get_local_ordinal()
+                 else:
+                     self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
+             else:
+                 self.num_processes = torch.distributed.get_world_size()
+                 self.process_index = torch.distributed.get_rank()
+                 self.local_process_index = (
+                     int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
+                 )
+                 self.set_device()
+                 # Now we can change to deepspeed
+                 if use_deepspeed:
+                     self.distributed_type = DistributedType.DEEPSPEED
+
+             # Set CPU affinity if enabled
+             if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
+                 set_numa_affinity(self.local_process_index)
+
+             # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
+             if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
+                 if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
+                     raise NotImplementedError(
+                         "The RTX 4000 series doesn't support faster communication via P2P or IB. "
+                         'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
+                         "will do this automatically."
+                     )
+
+         # Important: This should be the *only* code outside of the `self.initialized` check!
+         self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
+
+     def __repr__(self) -> str:
+         return (
+             f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
+             f"Num processes: {self.num_processes}\n"
+             f"Process index: {self.process_index}\n"
+             f"Local process index: {self.local_process_index}\n"
+             f"Device: {self.device}\n"
+         )
+
+     @staticmethod
+     def _reset_state():
+         "Resets `_shared_state`, is used internally and should not be called"
+         PartialState._shared_state.clear()
+
+     @property
+     def initialized(self) -> bool:
+         "Returns whether the `PartialState` has been initialized"
+         return self._shared_state != {}
+
+     @property
+     def use_distributed(self):
+         """
+         Whether the Accelerator is configured for distributed training
+         """
+         return self.distributed_type != DistributedType.NO and self.num_processes > 1
+
+     @property
+     def is_last_process(self) -> bool:
+         "Returns whether the current process is the last one"
+         return self.process_index == self.num_processes - 1
+
+     @property
+     def is_main_process(self) -> bool:
+         "Returns whether the current process is the main process"
+         return (
+             self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
+         )
+
+     @property
+     def is_local_main_process(self) -> bool:
+         "Returns whether the current process is the main process on the local node"
+         return (
+             self.local_process_index == 0
+             if self.distributed_type != DistributedType.MEGATRON_LM
+             else self.is_last_process
+         )
+
+     def wait_for_everyone(self):
+         """
+         Will stop the execution of the current process until every other process has reached that point (so this does
+         nothing when the script is only run in one process). Useful to do before saving a model.
+
+         Example:
+
+         ```python
+         >>> # Assuming two GPU processes
+         >>> import time
+         >>> from accelerate.state import PartialState
+
+         >>> state = PartialState()
+         >>> if state.is_main_process:
+         ...     time.sleep(2)
+         ... else:
+         ...     print("I'm waiting for the main process to finish its sleep...")
+         >>> state.wait_for_everyone()
+         >>> # Should print on every process at the same time
+         >>> print("Everyone is here")
+         ```
+         """
+         if self.distributed_type in (
+             DistributedType.MULTI_GPU,
+             DistributedType.MULTI_MLU,
+             DistributedType.MULTI_SDAA,
+             DistributedType.MULTI_MUSA,
+             DistributedType.MULTI_NPU,
+             DistributedType.MULTI_XPU,
+             DistributedType.MULTI_CPU,
+             DistributedType.MULTI_HPU,
+             DistributedType.DEEPSPEED,
+             DistributedType.FSDP,
+         ):
+             torch.distributed.barrier()
+         elif self.distributed_type == DistributedType.XLA:
+             xm.rendezvous("accelerate.utils.wait_for_everyone")
+
+     def _goes_first(self, is_main: bool):
+         if not is_main:
+             self.wait_for_everyone()
+
+         yield
+
+         if is_main:
+             self.wait_for_everyone()
+
+     @contextmanager
+     def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
+         """
+         Splits `inputs` between `self.num_processes` quickly, so the result can then be used on that process. Useful
+         when doing distributed inference, such as with different prompts.
+
+         Note that when using a `dict`, all keys need to have the same number of elements.
+
+         Args:
+             inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
+                 The input to split between processes.
+             apply_padding (`bool`, `optional`, defaults to `False`):
+                 Whether to apply padding by repeating the last element of the input so that all processes have the same
+                 number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
+                 in fewer inputs than there are processes. If so, just remember to drop the padded elements afterwards.
+
+
+         Example:
+
+         ```python
+         # Assume there are two processes
+         from accelerate import PartialState
+
+         state = PartialState()
+         with state.split_between_processes(["A", "B", "C"]) as inputs:
+             print(inputs)
+         # Process 0
+         ["A", "B"]
+         # Process 1
+         ["C"]
+
+         with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
+             print(inputs)
+         # Process 0
+         ["A", "B"]
+         # Process 1
+         ["C", "C"]
+         ```
+         """
+         if self.num_processes == 1:
+             yield inputs
+             return
+         length = len(inputs)
+         # Nested dictionary of any types
+         if isinstance(inputs, dict):
+             length = len(inputs[list(inputs.keys())[0]])
+             if not all(len(v) == length for v in inputs.values()):
+                 raise ValueError("All values in the dictionary must have the same length")
+         num_samples_per_process, num_extras = divmod(length, self.num_processes)
+         start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
+         end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
+
+         def _split_values(inputs, start_index, end_index):
+             if isinstance(inputs, (list, tuple, torch.Tensor)):
+                 if start_index >= len(inputs):
+                     result = inputs[-1:]
+                 else:
+                     result = inputs[start_index:end_index]
+                 if apply_padding:
+                     if isinstance(result, torch.Tensor):
+                         from accelerate.utils import pad_across_processes, send_to_device
+
+                         # The tensor needs to be on the device before we can pad it
+                         tensorized_result = send_to_device(result, self.device)
+                         result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
+                     else:
+                         result += [result[-1]] * (num_samples_per_process + 1 - len(result))
+                 return result
+             elif isinstance(inputs, dict):
+                 for key in inputs.keys():
+                     inputs[key] = _split_values(inputs[key], start_index, end_index)
+                 return inputs
+             else:
+                 if is_datasets_available():
+                     from datasets import Dataset
+
+                     if isinstance(inputs, Dataset):
+                         if start_index >= len(inputs):
+                             start_index = len(inputs) - 1
+                         if end_index > len(inputs):
+                             end_index = len(inputs)
+                         result_idcs = list(range(start_index, end_index))
+                         if apply_padding:
+                             result_idcs += [end_index - 1] * (num_samples_per_process + 1 - len(result_idcs))
+                         return inputs.select(result_idcs)
+                 return inputs
+
+         yield _split_values(inputs, start_index, end_index)
+
+     @contextmanager
+     def main_process_first(self):
+         """
+         Lets the main process go first inside a with block.
+
+         The other processes will enter the with block after the main process exits.
+
+         Example:
+
+         ```python
+         >>> from accelerate import Accelerator
+
+         >>> accelerator = Accelerator()
+         >>> with accelerator.main_process_first():
+         ...     # This will be printed first by process 0 then in a seemingly
+         ...     # random order by the other processes.
+         ...     print(f"This will be printed by process {accelerator.process_index}")
+         ```
+         """
+         yield from self._goes_first(self.is_main_process)
+
+     @contextmanager
+     def local_main_process_first(self):
+         """
+         Lets the local main process go first inside a with block.
+
+         The other processes will enter the with block after the local main process exits.
+
+         Example:
+
+         ```python
+         >>> from accelerate.state import PartialState
+
+         >>> state = PartialState()
+         >>> with state.local_main_process_first():
+         ...     # This will be printed first by local process 0 then in a seemingly
+         ...     # random order by the other processes.
+         ...     print(f"This will be printed by process {state.local_process_index}")
+         ```
+         """
+         yield from self._goes_first(self.is_local_main_process)
+
+     def on_main_process(self, function: Callable[..., Any] = None):
+         """
+         Decorator that only runs the decorated function on the main process.
+
+         Args:
+             function (`Callable`): The function to decorate.
+
+         Example:
+
+         ```python
+         >>> from accelerate.state import PartialState
+
+         >>> state = PartialState()
+
+
+         >>> @state.on_main_process
+         ... def print_something():
+         ...     print("This will be printed by process 0 only.")
+
+
+         >>> print_something()
+         "This will be printed by process 0 only"
+         ```
+         """
+         if not self.initialized:
+             raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
+         if self.is_main_process or not self.use_distributed:
+             return function
+         return do_nothing
+
+     def on_local_main_process(self, function: Callable[..., Any] = None):
+         """
+         Decorator that only runs the decorated function on the local main process.
+
+         Args:
+             function (`Callable`): The function to decorate.
+
+         Example:
+         ```python
+         # Assume we have 2 servers with 4 processes each.
+         from accelerate.state import PartialState
+
+         state = PartialState()
+
+
+         @state.on_local_main_process
+         def print_something():
+             print("This will be printed by process 0 only on each server.")
+
+
+         print_something()
+         # On server 1:
+         "This will be printed by process 0 only"
+         # On server 2:
+         "This will be printed by process 0 only"
+         ```
+         """
+         if self.is_local_main_process or not self.use_distributed:
+             return function
+         return do_nothing
+
+     def on_last_process(self, function: Callable[..., Any]):
+         """
+         Decorator that only runs the decorated function on the last process.
+
+         Args:
+             function (`Callable`): The function to decorate.
+
+         Example:
+         ```python
+         # Assume we have 4 processes.
+         from accelerate.state import PartialState
+
+         state = PartialState()
+
+
+         @state.on_last_process
+         def print_something():
+             print(f"Printed on process {state.process_index}")
+
+
+         print_something()
+         "Printed on process 3"
+         ```
+         """
+         if self.is_last_process or not self.use_distributed:
+             return function
+         return do_nothing
+
+     def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
+         """
+         Decorator that only runs the decorated function on the process with the given index.
+
+         Args:
+             function (`Callable`, `optional`):
+                 The function to decorate.
+             process_index (`int`, `optional`):
+                 The index of the process on which to run the function.
+
+         Example:
+         ```python
+         # Assume we have 4 processes.
+         from accelerate.state import PartialState
+
+         state = PartialState()
+
+
+         @state.on_process(process_index=2)
+         def print_something():
+             print(f"Printed on process {state.process_index}")
+
+
+         print_something()
+         "Printed on process 2"
+         ```
+         """
+         if function is None:
+             return partial(self.on_process, process_index=process_index)
+         if (self.process_index == process_index) or (not self.use_distributed):
+             return function
+         return do_nothing
+
+     def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
+         """
+         Decorator that only runs the decorated function on the process with the given index on the current node.
+
+         Args:
+             function (`Callable`, *optional*):
+                 The function to decorate.
+             local_process_index (`int`, *optional*):
+                 The index of the local process on which to run the function.
+
+         Example:
+         ```python
+         # Assume we have 2 servers with 4 processes each.
+         from accelerate import Accelerator
+
+         accelerator = Accelerator()
+
+
+         @accelerator.on_local_process(local_process_index=2)
+         def print_something():
+             print(f"Printed on process {accelerator.local_process_index}")
+
+
+         print_something()
+         # On server 1:
+         "Printed on process 2"
+         # On server 2:
+         "Printed on process 2"
+         ```
+         """
+         if function is None:
+             return partial(self.on_local_process, local_process_index=local_process_index)
+         if (self.local_process_index == local_process_index) or (not self.use_distributed):
+             return function
+         return do_nothing
+
+     def print(self, *args, **kwargs):
+         if self.is_local_main_process:
+             print(*args, **kwargs)
+
+     @property
+     def default_device(self) -> torch.device:
+         """
+         Returns the default device which is:
+         - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
+         - CUDA if `torch.cuda.is_available()`
+         - MLU if `is_mlu_available()`
+         - SDAA if `is_sdaa_available()`
+         - MUSA if `is_musa_available()`
+         - NPU if `is_npu_available()`
+         - HPU if `is_hpu_available()`
+         - CPU otherwise
+         """
+         if is_mps_available():
+             os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+             return torch.device("mps")
+         elif is_mlu_available():
+             return torch.device("mlu")
+         elif is_sdaa_available():
+             return torch.device("sdaa")
+         elif is_musa_available():
+             return torch.device("musa")
+         # NPU should be checked before CUDA when using `transfer_to_npu`
+         # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
+         elif is_npu_available():
+             return torch.device("npu")
+         elif is_hpu_available():
+             return torch.device("hpu")
+         elif torch.cuda.is_available():
+             return torch.device("cuda")
+         elif is_xpu_available():
+             return torch.device("xpu")
+         else:
+             return torch.device("cpu")
+
+     def _prepare_backend(
+         self, cpu: bool = False, sagemaker_dp=False, backend: str = None
+     ) -> tuple[str, DistributedType]:
+         "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
+         distributed_type = None
+         if sagemaker_dp:
+             import smdistributed.dataparallel.torch.torch_smddp  # noqa
+
+             backend = "smddp"
+             distributed_type = DistributedType.MULTI_GPU
+         elif is_torch_xla_available():
+             backend = "xla"
+             distributed_type = DistributedType.XLA
+
+         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
+             if is_mlu_available():
+                 backend = "cncl"
+                 distributed_type = DistributedType.MULTI_MLU
+             elif is_sdaa_available():
+                 backend = "tccl"
+                 distributed_type = DistributedType.MULTI_SDAA
+             elif is_musa_available():
+                 backend = "mccl"
+                 distributed_type = DistributedType.MULTI_MUSA
+             # NPU should be checked before CUDA when using `transfer_to_npu`
+             # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
+             elif is_npu_available():
+                 backend = "hccl"
+                 distributed_type = DistributedType.MULTI_NPU
+             elif is_hpu_available(init_hccl=True):
+                 if backend is None:
+                     backend = "hccl"
+                 distributed_type = DistributedType.MULTI_HPU
+             elif torch.cuda.is_available():
+                 if backend is None:
+                     backend = "nccl"
+                 distributed_type = DistributedType.MULTI_GPU
+             elif is_xpu_available() and is_xccl_available():
+                 if backend is None:
+                     backend = "xccl"
+                 distributed_type = DistributedType.MULTI_XPU
+
+         if distributed_type is None and (
+             int(os.environ.get("LOCAL_RANK", -1)) != -1
+             or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
+         ):
+             if not cpu and is_xpu_available():
+                 distributed_type = DistributedType.MULTI_XPU
+             else:
+                 distributed_type = DistributedType.MULTI_CPU
+
+             if (
+                 backend in (None, "ccl")
+                 and is_ccl_available()
+                 and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
+             ):
+                 if get_ccl_version() >= "1.12":
+                     import oneccl_bindings_for_pytorch  # noqa: F401
+                 else:
+                     import torch_ccl  # noqa: F401
+
+                 backend = "ccl"
+             elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
+                 backend = "mpi"
+             else:
+                 backend = "gloo"
+         if distributed_type is None:
+             distributed_type = DistributedType.NO
+
+         return backend, distributed_type
+
+     def set_device(self):
+         """
+         Sets the device in `self.device` to the current distributed environment.
+         """
+         if self.device is not None:
+             return
+         if self.distributed_type == DistributedType.NO:
+             self.device = torch.device("cpu") if self._cpu else self.default_device
+             return
+         device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
+         if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa"):
+             raise ValueError(
+                 f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
+             )
+         if device == "xla":
+             self.device = xm.xla_device()
+         elif device == "hpu":
+             self.device = torch.device("hpu", torch.hpu.current_device())
+         else:
+             if device == "gpu":
+                 device = "cuda"
+             device_module = getattr(torch, device)
+             device_index = self.local_process_index % device_module.device_count()
+             self.device = torch.device(device, device_index)
+             device_module.set_device(self.device)
+
+     def destroy_process_group(self, group=None):
+         """
+         Destroys the process group. If one is not specified, the default process group is destroyed.
+         """
+         if self.fork_launched and group is None:
+             return
+         # needed when using torch.distributed.init_process_group
+         if torch.distributed.is_initialized():
+             torch.distributed.destroy_process_group(group)
+
+     def __getattr__(self, name: str):
+         # By this point we know that no attributes of `self` contain `name`,
+         # so we just modify the error message
+         if name in self._known_attrs:
+             raise AttributeError(
+                 f"`PartialState` object has no attribute `{name}`. "
+                 "This happens if `PartialState._reset_state()` was called and "
+                 "an `Accelerator` or `PartialState` was not reinitialized."
+             )
+         # Raise a typical AttributeError
+         raise AttributeError(f"'PartialState' object has no attribute '{name}'")
+
+
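Editor's note: a hedged sketch exercising the `PartialState` process-control utilities above; assumes the script is launched with e.g. `accelerate launch --num_processes 2 script.py`, and the prompt strings are illustrative placeholders.

```python
# Small distributed-inference sketch using the class defined above.
from accelerate import PartialState

state = PartialState()
prompts = ["a", "b", "c", "d", "e"]

with state.split_between_processes(prompts, apply_padding=True) as shard:
    # Each rank sees only its slice; `apply_padding` repeats the last element
    # so every rank holds the same count (drop the padded results afterwards).
    results = [p.upper() for p in shard]

state.wait_for_everyone()  # barrier across all ranks
if state.is_main_process:
    print(f"rank 0 of {state.num_processes} done; local results: {results}")
```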
855
+ class AcceleratorState:
856
+ """
857
+ Singleton class that has information about the current training environment.
858
+
859
+ **Available attributes:**
860
+
861
+ - **device** (`torch.device`) -- The device to use.
862
+ - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
863
+ in use.
864
+ - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
865
+ - **local_process_index** (`int`) -- The index of the current process on the current server.
866
+ - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
867
+ of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
868
+ - **num_processes** (`int`) -- The number of processes currently launched in parallel.
869
+ - **process_index** (`int`) -- The index of the current process.
870
+ - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
871
+ - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
872
+ - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
873
+ - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
874
+ """
875
+
876
+ _shared_state = SharedDict()
877
+ _known_attrs = PartialState._known_attrs + [
878
+ "deepspeed_plugin",
879
+ "use_ipex",
880
+ "fsdp_plugin",
881
+ "megatron_lm_plugin",
882
+ "dynamo_plugin",
883
+ ]
884
+
885
+ def __init__(
886
+ self,
887
+ mixed_precision: str = None,
888
+ cpu: bool = False,
889
+ dynamo_plugin=None,
890
+ deepspeed_plugin=None,
891
+ fsdp_plugin=None,
892
+ torch_tp_plugin=None,
893
+ megatron_lm_plugin=None,
894
+ _from_accelerator: bool = False,
895
+ **kwargs,
896
+ ):
897
+ self.__dict__ = self._shared_state
898
+ if parse_flag_from_env("ACCELERATE_USE_CPU"):
899
+ cpu = True
900
+ if PartialState._shared_state == {}:
901
+ PartialState(cpu, **kwargs)
902
+ self.__dict__.update(PartialState._shared_state)
903
+ self._check_initialized(mixed_precision, cpu)
904
+ if not self.initialized:
905
+ self.deepspeed_plugins = None
906
+ self.use_ipex = None
907
+ self.torch_tp_plugin = torch_tp_plugin
908
+ mixed_precision = (
909
+ parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
910
+ if mixed_precision is None
911
+ else mixed_precision.lower()
912
+ )
913
+ if mixed_precision == "fp8":
914
+ # this is confusing, why is is_fp8_available only checks for library availability ?
915
+ if not is_fp8_available():
916
+ raise ValueError(
917
+ "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
918
+ )
919
+ elif torch.cuda.is_available() and not check_cuda_fp8_capability():
920
+ logger.warning(
921
+ f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
922
+ "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
923
+ "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
924
+ )
925
+ mixed_precision = "fp16"
926
+ elif is_habana_gaudi1():
927
+ logger.warning(
928
+ "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
929
+ "Gaudi2 or higher). Will use BF16 instead."
930
+ )
931
+ mixed_precision = "bf16"
932
+
933
+ self.dynamo_plugin = dynamo_plugin
934
+ if not _from_accelerator:
935
+ raise ValueError(
936
+ "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
937
+ "before using any functionality from the `accelerate` library."
938
+ )
939
+ # deepspeed handles mixed_precision using deepspeed_config
940
+ self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
941
+ if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
942
+ if mixed_precision == "bf16":
943
+ if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
944
+ os.environ["XLA_USE_BF16"] = str(0)
945
+ os.environ["XLA_DOWNCAST_BF16"] = str(1)
946
+ self.downcast_bfloat = True
947
+ else:
948
+ os.environ["XLA_USE_BF16"] = str(1)
949
+ os.environ["XLA_DOWNCAST_BF16"] = str(0)
950
+ self.downcast_bfloat = False
951
+ elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
952
+ self.deepspeed_plugins = deepspeed_plugin
953
+ self.distributed_type = DistributedType.DEEPSPEED
954
+ elif self.distributed_type in [
955
+ DistributedType.MULTI_GPU,
956
+ DistributedType.MULTI_MLU,
957
+ DistributedType.MULTI_SDAA,
958
+ DistributedType.MULTI_MUSA,
959
+ DistributedType.MULTI_NPU,
960
+ DistributedType.MULTI_XPU,
961
+ DistributedType.MULTI_HPU,
962
+ ]:
963
+ if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
964
+ self.distributed_type = DistributedType.FSDP
965
+ if self._mixed_precision != "no":
966
+ fsdp_plugin.set_mixed_precision(self._mixed_precision)
967
+ self.fsdp_plugin = fsdp_plugin
968
+ if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
969
+ DistributedType.MULTI_XPU,
970
+ ]:
971
+ self.distributed_type = DistributedType.MEGATRON_LM
972
+ megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
973
+ self.megatron_lm_plugin = megatron_lm_plugin
974
+ if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
975
+ self.distributed_type = DistributedType.TP
976
+ elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
977
+ if is_ipex_available():
978
+ # check if user disables it explicitly
979
+ self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
980
+ else:
981
+ self.use_ipex = False
982
+ if (
983
+ self.dynamo_plugin.backend != DynamoBackend.NO
984
+ and self._mixed_precision == "no"
985
+ and self.device.type == "cuda"
986
+ ):
987
+ torch.backends.cuda.matmul.allow_tf32 = True
988
+ if (
989
+ self.dynamo_plugin.backend != DynamoBackend.NO
990
+ and self._mixed_precision == "no"
991
+ and self.device.type == "musa"
992
+ ):
993
+ torch.backends.musa.matmul.allow_tf32 = True
994
+ PartialState._shared_state["distributed_type"] = self.distributed_type
995
+
996
+ @property
997
+ def initialized(self) -> bool:
998
+ return self._shared_state != PartialState._shared_state
999
+
1000
+ def __repr__(self):
1001
+ repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
1002
+ if self.distributed_type == DistributedType.DEEPSPEED:
1003
+ repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
1004
+ return repr
1005
+
1006
+ def _check_initialized(self, mixed_precision=None, cpu=None):
1007
+ "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
1008
+ if self.initialized:
1009
+ err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
1010
+ if cpu and self.device.type != "cpu":
1011
+ raise ValueError(err.format(flag="cpu=True"))
1012
+ if (
1013
+ mixed_precision is not None
1014
+ and mixed_precision != self._mixed_precision
1015
+ and self.distributed_type != DistributedType.DEEPSPEED
1016
+ ):
1017
+ raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))
1018
+
1019
+ @property
1020
+ def mixed_precision(self):
1021
+ if self.distributed_type == DistributedType.DEEPSPEED:
1022
+ config = self.deepspeed_plugin.deepspeed_config
1023
+ if config.get("fp16", {}).get("enabled", False):
1024
+ mixed_precision = "fp16"
1025
+ elif config.get("bf16", {}).get("enabled", False):
1026
+ mixed_precision = "bf16"
1027
+ else:
1028
+ mixed_precision = "no"
1029
+ else:
1030
+ mixed_precision = self._mixed_precision
1031
+ return mixed_precision
1032
+
1033
+ @staticmethod
1034
+ def _reset_state(reset_partial_state: bool = False):
1035
+ "Resets `_shared_state`, is used internally and should not be called"
1036
+ AcceleratorState._shared_state.clear()
1037
+ if reset_partial_state:
1038
+ PartialState._reset_state()
1039
+
1040
+ def destroy_process_group(self, group=None):
1041
+ """
1042
+ Destroys the process group. If one is not specified, the default process group is destroyed.
1043
+
1044
+ If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
1045
+ """
1046
+ PartialState().destroy_process_group(group)
1047
+
1048
+ @property
1049
+ def fork_launched(self):
1050
+ return PartialState().fork_launched
1051
+
1052
+ @property
1053
+ def use_distributed(self):
1054
+ """
1055
+ Whether the Accelerator is configured for distributed training
1056
+ """
1057
+ return PartialState().use_distributed
1058
+
1059
+ @property
1060
+ def is_fsdp2(self) -> bool:
1061
+ return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2
1062
+
1063
+ @property
1064
+ def is_last_process(self) -> bool:
1065
+ "Returns whether the current process is the last one"
1066
+ return PartialState().is_last_process
1067
+
1068
+ @property
1069
+ def is_main_process(self) -> bool:
1070
+ "Returns whether the current process is the main process"
1071
+ return PartialState().is_main_process
1072
+
1073
+ @property
1074
+ def is_local_main_process(self) -> bool:
1075
+ "Returns whether the current process is the main process on the local node"
1076
+ return PartialState().is_local_main_process
1077
+
1078
+ def wait_for_everyone(self):
1079
+ PartialState().wait_for_everyone()
1080
+
1081
+ @contextmanager
1082
+ def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
1083
+ """
1084
+ Splits `inputs` between `self.num_processes` quickly, so the result can then be used on that process. Useful when doing
1085
+ distributed inference, such as with different prompts.
1086
+
1087
+ Note that when using a `dict`, all keys need to have the same number of elements.
1088
+
1089
+ Args:
1090
+ inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
1091
+ The input to split between processes.
1092
+ apply_padding (`bool`, `optional`, defaults to `False`):
1093
+ Whether to apply padding by repeating the last element of the input so that all processes have the same
1094
+ number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
1095
+ in fewer inputs than there are processes. If so, just remember to drop the padded elements afterwards.
1096
+
1097
+
1098
+ Example:
1099
+
1100
+ ```python
1101
+ # Assume there are two processes
1102
+ from accelerate.state import AcceleratorState
1103
+
1104
+ state = AcceleratorState()
1105
+ with state.split_between_processes(["A", "B", "C"]) as inputs:
1106
+ print(inputs)
1107
+ # Process 0
1108
+ ["A", "B"]
1109
+ # Process 1
1110
+ ["C"]
1111
+
1112
+ with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
1113
+ print(inputs)
1114
+ # Process 0
1115
+ ["A", "B"]
1116
+ # Process 1
1117
+ ["C", "C"]
1118
+ ```
1119
+ """
1120
+ with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
1121
+ yield inputs
1122
+
1123
+ @contextmanager
1124
+ def main_process_first(self):
1125
+ """
1126
+ Lets the main process go first inside a with block.
1127
+
1128
+ The other processes will enter the with block after the main process exits.
1129
+ """
1130
+ with PartialState().main_process_first():
1131
+ yield
1132
+
1133
+ @contextmanager
1134
+ def local_main_process_first(self):
1135
+ """
1136
+ Lets the local main process go first inside a with block.
1137
+
1138
+ The other processes will enter the with block after the local main process exits.
1139
+ """
1140
+ with PartialState().local_main_process_first():
1141
+ yield
1142
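+ # A minimal usage sketch for the two context managers above (illustrative only; the helper
+ # `prepare_dataset` is hypothetical): the main process runs a one-time action while the other
+ # ranks wait, then everyone proceeds together.
+ #
+ #     state = AcceleratorState()
+ #     with state.main_process_first():
+ #         dataset = prepare_dataset()  # hypothetical helper; executes on rank 0 before the rest
+ #
+ # `local_main_process_first()` is the per-node analogue, useful for per-machine caches.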
+
1143
+ @property
1144
+ def deepspeed_plugin(self):
1145
+ """
1146
+ Returns the currently active DeepSpeedPlugin.
1147
+
1148
+ If not using deepspeed, returns `None`.
1149
+ """
1150
+ # To maintain original behavior, return None if not using deepspeed.
1151
+ if self.distributed_type != DistributedType.DEEPSPEED:
1152
+ return None
1153
+ from accelerate.utils.deepspeed import get_active_deepspeed_plugin
1154
+
1155
+ return get_active_deepspeed_plugin(self)
1156
+
1157
+ @deepspeed_required
1158
+ def get_deepspeed_plugin(self, name: str):
1159
+ """
1160
+ Returns the `DeepSpeedPlugin` with the given `name`.
1161
+ """
1162
+ return self.deepspeed_plugins[name]
1163
+
1164
+ @deepspeed_required
1165
+ def select_deepspeed_plugin(self, name: str = None):
1166
+ """
1167
+ Activates the `DeepSpeedPlugin` with the given `name`, disabling all other plugins.
1168
+ """
1169
+ for key, plugin in self.deepspeed_plugins.items():
1170
+ if key != name:
1171
+ plugin._unselect()
1172
+ self.deepspeed_plugins[name].select(_from_accelerator_state=True)
1173
+
1174
+ def print(self, *args, **kwargs):
1175
+ PartialState().print(*args, **kwargs)
1176
+
1177
+ def __getattr__(self, name: str):
1178
+ # By this point we know that no attributes of `self` contain `name`,
1179
+ # so we just modify the error message
1180
+ if name in self._known_attrs:
1181
+ raise AttributeError(
1182
+ f"`AcceleratorState` object has no attribute `{name}`. "
1183
+ "This happens if `AcceleratorState._reset_state()` was called and "
1184
+ "an `Accelerator` or `PartialState` was not reinitialized."
1185
+ )
1186
+ # Raise a typical AttributeError
1187
+ raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
1188
+
1189
+
1190
+ class GradientState:
1191
+ """
1192
+ Singleton class that has information related to gradient synchronization for gradient accumulation
1193
+
1194
+ **Available attributes:**
1195
+
1196
+ - **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
1197
+ - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
1198
+ - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
1199
+ - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
1200
+ - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
1201
+ being iterated over
1202
+ - **num_steps** (`int`) -- The number of steps to accumulate over
1203
+ - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
1204
+ accumulation
1205
+ - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
1206
+ iteration and the number of total steps reset
1207
+ - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
1208
+ as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
1209
+ after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
1210
+ is_xla_gradients_synced is always true.
1211
+ """
1212
+
1213
+ _shared_state = SharedDict()
1214
+
1215
+ def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
1216
+ self.__dict__ = self._shared_state
1217
+ if not self.initialized:
1218
+ self.sync_gradients = True
1219
+ self._dataloader_references_ref = [None]
1220
+ self.plugin_kwargs = (
1221
+ gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
1222
+ )
1223
+ self._is_xla_gradients_synced = False
1224
+
1225
+ # Plugin args are different and can be updated
1226
+ if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
1227
+ self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
1228
+
1229
+ @property
1230
+ def num_steps(self) -> int:
1231
+ "Returns the number of steps to accumulate over"
1232
+ return self.plugin_kwargs.get("num_steps", 1)
1233
+
1234
+ @property
1235
+ def adjust_scheduler(self) -> bool:
1236
+ "Returns whether the scheduler should be adjusted"
1237
+ return self.plugin_kwargs.get("adjust_scheduler", False)
1238
+
1239
+ @property
1240
+ def sync_with_dataloader(self) -> bool:
1241
+ "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
1242
+ return self.plugin_kwargs.get("sync_with_dataloader", True)
1243
+
1244
+ @property
1245
+ def initialized(self) -> bool:
1246
+ "Returns whether the `GradientState` has been initialized"
1247
+ return GradientState._shared_state != {}
1248
+
1249
+ @property
1250
+ def end_of_dataloader(self) -> bool:
1251
+ "Returns whether we have reached the end of the current dataloader"
1252
+ if not self.in_dataloader:
1253
+ return False
1254
+ return self.active_dataloader.end_of_dataloader
1255
+
1256
+ @property
1257
+ def remainder(self) -> int:
1258
+ "Returns the number of extra samples that were added from padding the dataloader"
1259
+ if not self.in_dataloader:
1260
+ return -1
1261
+ return self.active_dataloader.remainder
1262
+
1263
+ def __repr__(self):
1264
+ return (
1265
+ f"Sync Gradients: {self.sync_gradients}\n"
1266
+ f"At end of current dataloader: {self.end_of_dataloader}\n"
1267
+ f"Extra samples added: {self.remainder}\n"
1268
+ f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
1269
+ )
1270
+
1271
+ @property
1272
+ def is_xla_gradients_synced(self):
1273
+ "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
1274
+ if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
1275
+ return True
1276
+ return self._is_xla_gradients_synced
1277
+
1278
+ @is_xla_gradients_synced.setter
1279
+ def is_xla_gradients_synced(self, is_synced):
1280
+ "Set the _is_xla_gradients_synced attribute."
1281
+ self._is_xla_gradients_synced = is_synced
1282
+
1283
+ def _set_sync_gradients(self, sync_gradients):
1284
+ "Private function that sets whether gradients should be synchronized. Users should not have to call this."
1285
+ self.sync_gradients = sync_gradients
1286
+ # Allow grad-sync to automatically work on TPUs
1287
+ if (
1288
+ self.sync_gradients
1289
+ and is_torch_xla_available(check_is_tpu=True)
1290
+ and PartialState().distributed_type == DistributedType.XLA
1291
+ ):
1292
+ xm.mark_step()
1293
+
1294
+ def _add_dataloader(self, dataloader):
1295
+ "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
1296
+ # We explicitly use assignment to ensure that the property setter is triggered, which is required for garbage collection.
1297
+ # Avoid using self.dataloader_references.append as it will not trigger the setter.
1298
+ self.dataloader_references += [dataloader]
1299
+
1300
+ def _remove_dataloader(self, dataloader):
1301
+ "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
1302
+ # We explicitly use assignment to ensure that the property setter is triggered.
1303
+ self.dataloader_references = [
1304
+ dataloader_ref for dataloader_ref in self.dataloader_references if dataloader_ref != dataloader
1305
+ ]
1306
+
1307
+ @property
1308
+ def active_dataloader(self):
1309
+ return self.dataloader_references[-1]
1310
+
1311
+ @property
1312
+ def dataloader_references(self):
1313
+ # We use a property getter and setter with weakrefs to avoid circular references that prevent garbage collection
1314
+ return [reference() if reference is not None else reference for reference in self._dataloader_references_ref]
1315
+
1316
+ @dataloader_references.setter
1317
+ def dataloader_references(self, references):
1318
+ self._dataloader_references_ref = [
1319
+ weakref.ref(dataloader) if dataloader is not None else dataloader for dataloader in references
1320
+ ]
1321
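+ # The getter/setter pair above stores weak references so the state object never keeps a
+ # dataloader alive on its own. A standalone sketch of the same pattern (illustrative only):
+ #
+ #     import weakref
+ #
+ #     class Registry:
+ #         def __init__(self):
+ #             self._refs = []
+ #
+ #         def add(self, obj):
+ #             self._refs.append(weakref.ref(obj))  # hold the object weakly, not strongly
+ #
+ #         def alive(self):
+ #             return [r() for r in self._refs if r() is not None]  # drop collected entries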
+
1322
+ @property
1323
+ def in_dataloader(self) -> bool:
1324
+ "Returns whether the current process is in a dataloader"
1325
+ return self.active_dataloader is not None
1326
+
1327
+ @staticmethod
1328
+ def _reset_state():
1329
+ "Resets `_shared_state`, is used internally and should not be called"
1330
+ GradientState._shared_state.clear()
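+
+ # A brief usage sketch for `GradientState` (illustrative only; `optimizer` comes from the
+ # surrounding training loop). Because `__dict__` is bound to the shared dict, every instance
+ # observes the same state, so a loop can consult it to find accumulation boundaries:
+ #
+ #     state = GradientState()
+ #     if state.sync_gradients:  # True on the accumulation boundary
+ #         optimizer.step()
+ #         optimizer.zero_grad()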
venv/Lib/site-packages/accelerate/tracking.py ADDED
@@ -0,0 +1,1089 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Expectation:
16
+ # Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
17
+
18
+ import json
19
+ import os
20
+ import time
21
+ from functools import wraps
22
+ from typing import Any, Optional, Union
23
+
24
+ import yaml
25
+ from packaging import version
26
+
27
+ from .logging import get_logger
28
+ from .state import PartialState
29
+ from .utils import (
30
+ LoggerType,
31
+ compare_versions,
32
+ is_aim_available,
33
+ is_clearml_available,
34
+ is_comet_ml_available,
35
+ is_dvclive_available,
36
+ is_mlflow_available,
37
+ is_tensorboard_available,
38
+ is_wandb_available,
39
+ listify,
40
+ )
41
+
42
+
43
+ _available_trackers = []
44
+
45
+ if is_tensorboard_available():
46
+ _available_trackers.append(LoggerType.TENSORBOARD)
47
+
48
+ if is_wandb_available():
49
+ _available_trackers.append(LoggerType.WANDB)
50
+
51
+ if is_comet_ml_available():
52
+ _available_trackers.append(LoggerType.COMETML)
53
+
54
+ if is_aim_available():
55
+ _available_trackers.append(LoggerType.AIM)
56
+
57
+ if is_mlflow_available():
58
+ _available_trackers.append(LoggerType.MLFLOW)
59
+
60
+ if is_clearml_available():
61
+ _available_trackers.append(LoggerType.CLEARML)
62
+
63
+ if is_dvclive_available():
64
+ _available_trackers.append(LoggerType.DVCLIVE)
65
+
66
+ logger = get_logger(__name__)
67
+
68
+
69
+ def on_main_process(function):
70
+ """
71
+ Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
72
+ attribute in a class.
73
+
74
+ Checks at function execution rather than initialization time, not triggering the initialization of the
75
+ `PartialState`.
76
+ """
77
+
78
+ @wraps(function)
79
+ def execute_on_main_process(self, *args, **kwargs):
80
+ if getattr(self, "main_process_only", False):
81
+ return PartialState().on_main_process(function)(self, *args, **kwargs)
82
+ else:
83
+ return function(self, *args, **kwargs)
84
+
85
+ return execute_on_main_process
86
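+ # How the decorator above behaves, sketched (illustrative only, not additional library code):
+ #
+ #     class PrintTracker:
+ #         main_process_only = True
+ #
+ #         @on_main_process
+ #         def log(self, values):
+ #             print(values)  # executes on the main process only; other ranks return None
+ #
+ # With `main_process_only = False`, `log` runs unchanged on every process.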
+
87
+
88
+ def get_available_trackers():
89
+ "Returns a list of all supported available trackers in the system"
90
+ return _available_trackers
91
+
92
+
93
+ class GeneralTracker:
94
+ """
95
+ A base Tracker class to be used for all logging integration implementations.
96
+
97
+ Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
98
+ [`Accelerator`].
99
+
100
+ Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:
101
+
102
+ - `name` (`str`): String representation of the tracker class name, such as "TensorBoard"
103
+ - `requires_logging_directory` (`bool`): Whether the logger requires a directory to store its logs
104
+ - `tracker` (`object`): Should return the internal tracking mechanism used by the tracker class (such as the `run` for wandb)
105
+
106
+ Implementations can also include a `main_process_only` (`bool`) attribute to toggle whether the relevant logging, init,
107
+ and other functions should occur only on the main process or across all processes (defaults to `True`)
108
+ """
109
+
110
+ main_process_only = True
111
+
112
+ def __init__(self, _blank=False):
113
+ if not _blank:
114
+ err = ""
115
+ if not hasattr(self, "name"):
116
+ err += "`name`"
117
+ if not hasattr(self, "requires_logging_directory"):
118
+ if len(err) > 0:
119
+ err += ", "
120
+ err += "`requires_logging_directory`"
121
+
122
+ # as tracker is a @property that relies on post-init
123
+ if "tracker" not in dir(self):
124
+ if len(err) > 0:
125
+ err += ", "
126
+ err += "`tracker`"
127
+ if len(err) > 0:
128
+ raise NotImplementedError(
129
+ f"The implementation for this tracker class is missing the following "
130
+ f"required attributes. Please define them in the class definition: "
131
+ f"{err}"
132
+ )
133
+
134
+ def store_init_configuration(self, values: dict):
135
+ """
136
+ Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
137
+ functionality of a tracking API.
138
+
139
+ Args:
140
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
141
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
142
+ `str`, `float`, `int`, or `None`.
143
+ """
144
+ pass
145
+
146
+ def log(self, values: dict, step: Optional[int], **kwargs):
147
+ """
148
+ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
149
+ special behavior for the `step` parameter.
150
+
151
+ Args:
152
+ values (Dictionary `str` to `str`, `float`, or `int`):
153
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
154
+ step (`int`, *optional*):
155
+ The run step. If included, the log will be affiliated with this step.
156
+ """
157
+ pass
158
+
159
+ def finish(self):
160
+ """
161
+ Should run any finalizing functions within the tracking API. If the API doesn't have one, just don't
162
+ override this method.
163
+ """
164
+ pass
165
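+ # A minimal custom tracker satisfying the contract above (a sketch, not an official
+ # implementation; `StdoutTracker` is a made-up name):
+ #
+ #     class StdoutTracker(GeneralTracker):
+ #         name = "stdout"
+ #         requires_logging_directory = False
+ #
+ #         def __init__(self, run_name: str):
+ #             super().__init__()
+ #             self.run_name = run_name
+ #
+ #         @property
+ #         def tracker(self):
+ #             return self  # no underlying client; the object itself is the mechanism
+ #
+ #         def log(self, values: dict, step: Optional[int] = None, **kwargs):
+ #             print(f"[{self.run_name}] step={step} {values}")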
+
166
+
167
+ class TensorBoardTracker(GeneralTracker):
168
+ """
169
+ A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.
170
+
171
+ Args:
172
+ run_name (`str`):
173
+ The name of the experiment run
174
+ logging_dir (`str`, `os.PathLike`):
175
+ Location for TensorBoard logs to be stored.
176
+ **kwargs (additional keyword arguments, *optional*):
177
+ Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
178
+ """
179
+
180
+ name = "tensorboard"
181
+ requires_logging_directory = True
182
+
183
+ @on_main_process
184
+ def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
185
+ try:
186
+ from torch.utils import tensorboard
187
+ except ModuleNotFoundError:
188
+ import tensorboardX as tensorboard
189
+ super().__init__()
190
+ self.run_name = run_name
191
+ self.logging_dir = os.path.join(logging_dir, run_name)
192
+ self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
193
+ logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
194
+ logger.debug(
195
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
196
+ )
197
+
198
+ @property
199
+ def tracker(self):
200
+ return self.writer
201
+
202
+ @on_main_process
203
+ def store_init_configuration(self, values: dict):
204
+ """
205
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
206
+ hyperparameters in a yaml file for future use.
207
+
208
+ Args:
209
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
210
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
211
+ `str`, `float`, `int`, or `None`.
212
+ """
213
+ self.writer.add_hparams(values, metric_dict={})
214
+ self.writer.flush()
215
+ project_run_name = time.time()
216
+ dir_name = os.path.join(self.logging_dir, str(project_run_name))
217
+ os.makedirs(dir_name, exist_ok=True)
218
+ with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
219
+ try:
220
+ yaml.dump(values, outfile)
221
+ except yaml.representer.RepresenterError:
222
+ logger.error("Serialization to store hyperparameters failed")
223
+ raise
224
+ logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")
225
+
226
+ @on_main_process
227
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
228
+ """
229
+ Logs `values` to the current run.
230
+
231
+ Args:
232
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
233
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
234
+ `str` to `float`/`int`.
235
+ step (`int`, *optional*):
236
+ The run step. If included, the log will be affiliated with this step.
237
+ kwargs:
238
+ Additional key word arguments passed along to either the `SummaryWriter.add_scalar`,
239
+ `SummaryWriter.add_text`, or `SummaryWriter.add_scalars` method based on the contents of `values`.
240
+ """
241
+ values = listify(values)
242
+ for k, v in values.items():
243
+ if isinstance(v, (int, float)):
244
+ self.writer.add_scalar(k, v, global_step=step, **kwargs)
245
+ elif isinstance(v, str):
246
+ self.writer.add_text(k, v, global_step=step, **kwargs)
247
+ elif isinstance(v, dict):
248
+ self.writer.add_scalars(k, v, global_step=step, **kwargs)
249
+ self.writer.flush()
250
+ logger.debug("Successfully logged to TensorBoard")
251
+
252
+ @on_main_process
253
+ def log_images(self, values: dict, step: Optional[int], **kwargs):
254
+ """
255
+ Logs `images` to the current run.
256
+
257
+ Args:
258
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
259
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
260
+ step (`int`, *optional*):
261
+ The run step. If included, the log will be affiliated with this step.
262
+ kwargs:
263
+ Additional key word arguments passed along to the `SummaryWriter.add_image` method.
264
+ """
265
+ for k, v in values.items():
266
+ self.writer.add_images(k, v, global_step=step, **kwargs)
267
+ logger.debug("Successfully logged images to TensorBoard")
268
+
269
+ @on_main_process
270
+ def finish(self):
271
+ """
272
+ Closes `TensorBoard` writer
273
+ """
274
+ self.writer.close()
275
+ logger.debug("TensorBoard writer closed")
276
+
277
+
278
+ class WandBTracker(GeneralTracker):
279
+ """
280
+ A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.
281
+
282
+ Args:
283
+ run_name (`str`):
284
+ The name of the experiment run.
285
+ **kwargs (additional keyword arguments, *optional*):
286
+ Additional key word arguments passed along to the `wandb.init` method.
287
+ """
288
+
289
+ name = "wandb"
290
+ requires_logging_directory = False
291
+ main_process_only = False
292
+
293
+ @on_main_process
294
+ def __init__(self, run_name: str, **kwargs):
295
+ super().__init__()
296
+ self.run_name = run_name
297
+
298
+ import wandb
299
+
300
+ self.run = wandb.init(project=self.run_name, **kwargs)
301
+ logger.debug(f"Initialized WandB project {self.run_name}")
302
+ logger.debug(
303
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
304
+ )
305
+
306
+ @property
307
+ def tracker(self):
308
+ return self.run
309
+
310
+ @on_main_process
311
+ def store_init_configuration(self, values: dict):
312
+ """
313
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
314
+
315
+ Args:
316
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
317
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
318
+ `str`, `float`, `int`, or `None`.
319
+ """
320
+ import wandb
321
+
322
+ wandb.config.update(values, allow_val_change=True)
323
+ logger.debug("Stored initial configuration hyperparameters to WandB")
324
+
325
+ @on_main_process
326
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
327
+ """
328
+ Logs `values` to the current run.
329
+
330
+ Args:
331
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
332
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
333
+ `str` to `float`/`int`.
334
+ step (`int`, *optional*):
335
+ The run step. If included, the log will be affiliated with this step.
336
+ kwargs:
337
+ Additional key word arguments passed along to the `wandb.log` method.
338
+ """
339
+ self.run.log(values, step=step, **kwargs)
340
+ logger.debug("Successfully logged to WandB")
341
+
342
+ @on_main_process
343
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
344
+ """
345
+ Logs `images` to the current run.
346
+
347
+ Args:
348
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
349
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
350
+ step (`int`, *optional*):
351
+ The run step. If included, the log will be affiliated with this step.
352
+ kwargs:
353
+ Additional key word arguments passed along to the `wandb.log` method.
354
+ """
355
+ import wandb
356
+
357
+ for k, v in values.items():
358
+ self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)
359
+ logger.debug("Successfully logged images to WandB")
360
+
361
+ @on_main_process
362
+ def log_table(
363
+ self,
364
+ table_name: str,
365
+ columns: list[str] = None,
366
+ data: list[list[Any]] = None,
367
+ dataframe: Any = None,
368
+ step: Optional[int] = None,
369
+ **kwargs,
370
+ ):
371
+ """
372
+ Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
373
+ with `columns` and `data` or with `dataframe`.
374
+
375
+ Args:
376
+ table_name (`str`):
377
+ The name to give to the logged table on the wandb workspace
378
+ columns (list of `str`, *optional*):
379
+ The name of the columns on the table
380
+ data (List of List of Any data type, *optional*):
381
+ The data to be logged in the table
382
+ dataframe (Any data type, *optional*):
383
+ The data to be logged in the table
384
+ step (`int`, *optional*):
385
+ The run step. If included, the log will be affiliated with this step.
386
+ """
387
+ import wandb
388
+
389
+ values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
390
+ self.log(values, step=step, **kwargs)
391
+
392
+ @on_main_process
393
+ def finish(self):
394
+ """
395
+ Closes `wandb` writer
396
+ """
397
+ self.run.finish()
398
+ logger.debug("WandB run closed")
399
+
400
+
401
+ class CometMLTracker(GeneralTracker):
402
+ """
403
+ A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
404
+
405
+ API keys must be stored in a Comet config file.
406
+
407
+ Note:
408
+ For `comet_ml` versions < 3.41.0, additional keyword arguments are passed to `comet_ml.Experiment` instead:
409
+ https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#comet_ml.Experiment.__init__
410
+
411
+ Args:
412
+ run_name (`str`):
413
+ The name of the experiment run.
414
+ **kwargs (additional keyword arguments, *optional*):
415
+ Additional key word arguments passed along to the `comet_ml.start` method:
416
+ https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/
417
+ """
418
+
419
+ name = "comet_ml"
420
+ requires_logging_directory = False
421
+
422
+ @on_main_process
423
+ def __init__(self, run_name: str, **kwargs):
424
+ super().__init__()
425
+ self.run_name = run_name
426
+
427
+ import comet_ml
428
+
429
+ comet_version = version.parse(comet_ml.__version__)
430
+ if compare_versions(comet_version, ">=", "3.41.0"):
431
+ self.writer = comet_ml.start(project_name=run_name, **kwargs)
432
+ else:
433
+ logger.info("Update `comet_ml` (>=3.41.0) for experiment reuse and offline support.")
434
+ self.writer = comet_ml.Experiment(project_name=run_name, **kwargs)
435
+
436
+ logger.debug(f"Initialized CometML project {self.run_name}")
437
+ logger.debug(
438
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
439
+ )
440
+
441
+ @property
442
+ def tracker(self):
443
+ return self.writer
444
+
445
+ @on_main_process
446
+ def store_init_configuration(self, values: dict):
447
+ """
448
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
449
+
450
+ Args:
451
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
452
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
453
+ `str`, `float`, `int`, or `None`.
454
+ """
455
+ self.writer.log_parameters(values)
456
+ logger.debug("Stored initial configuration hyperparameters to Comet")
457
+
458
+ @on_main_process
459
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
460
+ """
461
+ Logs `values` to the current run.
462
+
463
+ Args:
464
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
465
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
466
+ `str` to `float`/`int`.
467
+ step (`int`, *optional*):
468
+ The run step. If included, the log will be affiliated with this step.
469
+ kwargs:
470
+ Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
471
+ or `Experiment.log_metrics` method based on the contents of `values`.
472
+ """
473
+ if step is not None:
474
+ self.writer.set_step(step)
475
+ for k, v in values.items():
476
+ if isinstance(v, (int, float)):
477
+ self.writer.log_metric(k, v, step=step, **kwargs)
478
+ elif isinstance(v, str):
479
+ self.writer.log_other(k, v, **kwargs)
480
+ elif isinstance(v, dict):
481
+ self.writer.log_metrics(v, step=step, **kwargs)
482
+ logger.debug("Successfully logged to Comet")
483
+
484
+ @on_main_process
485
+ def finish(self):
486
+ """
487
+ Flush `comet-ml` writer
488
+ """
489
+ self.writer.end()
490
+ logger.debug("Comet run flushed")
491
+
492
+
493
+ class AimTracker(GeneralTracker):
494
+ """
495
+ A `Tracker` class that supports `aim`. Should be initialized at the start of your script.
496
+
497
+ Args:
498
+ run_name (`str`):
499
+ The name of the experiment run.
500
+ **kwargs (additional keyword arguments, *optional*):
501
+ Additional key word arguments passed along to the `Run.__init__` method.
502
+ """
503
+
504
+ name = "aim"
505
+ requires_logging_directory = True
506
+
507
+ @on_main_process
508
+ def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
509
+ self.run_name = run_name
510
+
511
+ from aim import Run
512
+
513
+ self.writer = Run(repo=logging_dir, **kwargs)
514
+ self.writer.name = self.run_name
515
+ logger.debug(f"Initialized Aim project {self.run_name}")
516
+ logger.debug(
517
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
518
+ )
519
+
520
+ @property
521
+ def tracker(self):
522
+ return self.writer
523
+
524
+ @on_main_process
525
+ def store_init_configuration(self, values: dict):
526
+ """
527
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
528
+
529
+ Args:
530
+ values (`dict`):
531
+ Values to be stored as initial hyperparameters as key-value pairs.
532
+ """
533
+ self.writer["hparams"] = values
534
+
535
+ @on_main_process
536
+ def log(self, values: dict, step: Optional[int], **kwargs):
537
+ """
538
+ Logs `values` to the current run.
539
+
540
+ Args:
541
+ values (`dict`):
542
+ Values to be logged as key-value pairs.
543
+ step (`int`, *optional*):
544
+ The run step. If included, the log will be affiliated with this step.
545
+ kwargs:
546
+ Additional key word arguments passed along to the `Run.track` method.
547
+ """
548
+ # Note: replace this with the dictionary support when merged
549
+ for key, value in values.items():
550
+ self.writer.track(value, name=key, step=step, **kwargs)
551
+
552
+ @on_main_process
553
+ def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[dict[str, dict]] = None):
554
+ """
555
+ Logs `images` to the current run.
556
+
557
+ Args:
558
+ values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
559
+ Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
560
+ tuple is provided, the first element should be the image and the second element should be the caption.
561
+ step (`int`, *optional*):
562
+ The run step. If included, the log will be affiliated with this step.
563
+ kwargs (`Dict[str, dict]`):
564
+ Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the
565
+ keys `aim_image` and `track`, respectively.
566
+ """
567
+ import aim
568
+
569
+ aim_image_kw = {}
570
+ track_kw = {}
571
+
572
+ if kwargs is not None:
573
+ aim_image_kw = kwargs.get("aim_image", {})
574
+ track_kw = kwargs.get("track", {})
575
+
576
+ for key, value in values.items():
577
+ if isinstance(value, tuple):
578
+ img, caption = value
579
+ else:
580
+ img, caption = value, ""
581
+ aim_image = aim.Image(img, caption=caption, **aim_image_kw)
582
+ self.writer.track(aim_image, name=key, step=step, **track_kw)
583
+
584
+ @on_main_process
585
+ def finish(self):
586
+ """
587
+ Closes `aim` writer
588
+ """
589
+ self.writer.close()
590
+
591
+
592
+ class MLflowTracker(GeneralTracker):
593
+ """
594
+ A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.
595
+
596
+ Args:
597
+ experiment_name (`str`, *optional*):
598
+ Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
599
+ logging_dir (`str` or `os.PathLike`, defaults to `"."`):
600
+ Location for mlflow logs to be stored.
601
+ run_id (`str`, *optional*):
602
+ If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
603
+ end time is unset and its status is set to running, but the run’s other attributes (source_version,
604
+ source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
605
+ tags (`Dict[str, str]`, *optional*):
606
+ An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
607
+ run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
608
+ set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
609
+ nested_run (`bool`, *optional*, defaults to `False`):
610
+ Controls whether run is nested in parent run. True creates a nested run. Environment variable
611
+ MLFLOW_NESTED_RUN has priority over this argument.
612
+ run_name (`str`, *optional*):
613
+ Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
614
+ description (`str`, *optional*):
615
+ An optional string that populates the description box of the run. If a run is being resumed, the
616
+ description is set on the resumed run. If a new run is being created, the description is set on the new
617
+ run.
618
+ """
619
+
620
+ name = "mlflow"
621
+ requires_logging_directory = False
622
+
623
+ @on_main_process
624
+ def __init__(
625
+ self,
626
+ experiment_name: str = None,
627
+ logging_dir: Optional[Union[str, os.PathLike]] = None,
628
+ run_id: Optional[str] = None,
629
+ tags: Optional[Union[dict[str, Any], str]] = None,
630
+ nested_run: Optional[bool] = False,
631
+ run_name: Optional[str] = None,
632
+ description: Optional[str] = None,
633
+ ):
634
+ experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name)
635
+ run_id = os.environ.get("MLFLOW_RUN_ID", run_id)
636
+ tags = os.environ.get("MLFLOW_TAGS", tags)
637
+ if isinstance(tags, str):
638
+ tags = json.loads(tags)
639
+
640
+ nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run)
641
+
642
+ import mlflow
643
+
644
+ exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'")
645
+ if len(exps) > 0:
646
+ if len(exps) > 1:
647
+ logger.warning("Multiple experiments with the same name found. Using first one.")
648
+ experiment_id = exps[0].experiment_id
649
+ else:
650
+ experiment_id = mlflow.create_experiment(
651
+ name=experiment_name,
652
+ artifact_location=logging_dir,
653
+ tags=tags,
654
+ )
655
+
656
+ self.active_run = mlflow.start_run(
657
+ run_id=run_id,
658
+ experiment_id=experiment_id,
659
+ run_name=run_name,
660
+ nested=nested_run,
661
+ tags=tags,
662
+ description=description,
663
+ )
664
+
665
+ logger.debug(f"Initialized mlflow experiment {experiment_name}")
666
+ logger.debug(
667
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
668
+ )
669
+
670
+ @property
671
+ def tracker(self):
672
+ return self.active_run
673
+
674
+ @on_main_process
675
+ def store_init_configuration(self, values: dict):
676
+ """
677
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
678
+
679
+ Args:
680
+ values (`dict`):
681
+ Values to be stored as initial hyperparameters as key-value pairs.
682
+ """
683
+ import mlflow
684
+
685
+ for name, value in list(values.items()):
686
+ # internally, all values are converted to str in MLflow
687
+ if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
688
+ logger.warning_once(
689
+ f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
690
+ f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
691
+ )
692
+ del values[name]
693
+
694
+ values_list = list(values.items())
695
+
696
+ # MLflow cannot log more than 100 values in one go, so we have to split it
697
+ for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
698
+ mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))
699
+
700
+ logger.debug("Stored initial configuration hyperparameters to MLflow")
701
+
702
+ @on_main_process
703
+ def log(self, values: dict, step: Optional[int]):
704
+ """
705
+ Logs `values` to the current run.
706
+
707
+ Args:
708
+ values (`dict`):
709
+ Values to be logged as key-value pairs.
710
+ step (`int`, *optional*):
711
+ The run step. If included, the log will be affiliated with this step.
712
+ """
713
+ metrics = {}
714
+ for k, v in values.items():
715
+ if isinstance(v, (int, float)):
716
+ metrics[k] = v
717
+ else:
718
+ logger.warning_once(
719
+ f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
720
+ "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
721
+ )
722
+ import mlflow
723
+
724
+ mlflow.log_metrics(metrics, step=step)
725
+ logger.debug("Successfully logged to mlflow")
726
+
727
+ @on_main_process
728
+ def log_figure(self, figure: Any, artifact_file: str, **save_kwargs):
729
+ """
730
+ Logs a figure to the current run.
731
+
732
+ Args:
733
+ figure (Any):
734
+ The figure to be logged.
735
+ artifact_file (`str`, *optional*):
736
+ The run-relative artifact file path in posixpath format to which the image is saved.
737
+ If not provided, the image is saved to a default location.
738
+ **save_kwargs:
739
+ Additional keyword arguments passed to the underlying `mlflow.log_figure` function.
740
+ """
741
+ import mlflow
742
+
743
+ mlflow.log_figure(figure=figure, artifact_file=artifact_file, **save_kwargs)
744
+ logger.debug("Successfully logged figure to mlflow")
745
+
746
+ @on_main_process
747
+ def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):
748
+ """
749
+ Logs artifacts (all contents of a directory) to the current run.
750
+
751
+ local_dir (`str`):
752
+ Path to the directory to be logged as an artifact.
753
+ artifact_path (`str`, *optional*):
754
+ Directory within the run's artifact directory where the artifacts will be logged. If omitted, the
755
+ artifacts will be logged to the root of the run's artifact directory.
757
+ """
758
+ import mlflow
759
+
760
+ mlflow.log_artifacts(local_dir=local_dir, artifact_path=artifact_path)
761
+ logger.debug("Successfully logged artifacts to mlflow")
762
+
763
+ @on_main_process
764
+ def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
765
+ """
766
+ Logs an artifact (file) to the current run.
767
+
768
+ local_path (`str`):
769
+ Path to the file to be logged as an artifact.
770
+ artifact_path (`str`, *optional*):
771
+ Directory within the run's artifact directory where the artifact will be logged. If omitted, the
772
+ artifact will be logged to the root of the run's artifact directory.
774
+ """
775
+ import mlflow
776
+
777
+ mlflow.log_artifact(local_path=local_path, artifact_path=artifact_path)
778
+ logger.debug("Successfully logged artifact to mlflow")
779
+
780
+ @on_main_process
781
+ def finish(self):
782
+ """
783
+ End the active MLflow run.
784
+ """
785
+ import mlflow
786
+
787
+ mlflow.end_run()
788
+
789
+
790
+ class ClearMLTracker(GeneralTracker):
791
+ """
792
+ A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.
793
+
794
+ Args:
795
+ run_name (`str`, *optional*):
796
+ Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
797
+ argument.
798
+ **kwargs (additional keyword arguments, *optional*):
799
+ Kwargs passed along to the `Task.__init__` method.
800
+ """
801
+
802
+ name = "clearml"
803
+ requires_logging_directory = False
804
+
805
+ @on_main_process
806
+ def __init__(self, run_name: str = None, **kwargs):
807
+ from clearml import Task
808
+
809
+ current_task = Task.current_task()
810
+ self._initialized_externally = False
811
+ if current_task:
812
+ self._initialized_externally = True
813
+ self.task = current_task
814
+ return
815
+
816
+ kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name))
817
+ kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name))
818
+ self.task = Task.init(**kwargs)
819
+
820
+ @property
821
+ def tracker(self):
822
+ return self.task
823
+
824
+ @on_main_process
825
+ def store_init_configuration(self, values: dict):
826
+ """
827
+ Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.
828
+
829
+ Args:
830
+ values (`dict`):
831
+ Values to be stored as initial hyperparameters as key-value pairs.
832
+ """
833
+ return self.task.connect_configuration(values)
834
+
835
+ @on_main_process
836
+ def log(self, values: dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
837
+ """
838
+ Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
839
+ ints or floats
840
+
841
+ Args:
842
+ values (`Dict[str, Union[int, float]]`):
843
+ Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
844
+ be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
845
+ Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
846
+ step (`int`, *optional*):
847
+ If specified, the values will be reported as scalars, with the iteration number equal to `step`.
848
+ Otherwise they will be reported as single values.
849
+ kwargs:
850
+ Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
851
+ `clearml.Logger.report_scalar` methods.
852
+ """
853
+ clearml_logger = self.task.get_logger()
854
+ for k, v in values.items():
855
+ if not isinstance(v, (int, float)):
856
+ logger.warning_once(
857
+ "Accelerator is attempting to log a value of "
858
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
859
+ "This invocation of ClearML logger's report_scalar() "
860
+ "is incorrect so we dropped this attribute."
861
+ )
862
+ continue
863
+ if step is None:
864
+ clearml_logger.report_single_value(name=k, value=v, **kwargs)
865
+ continue
866
+ title, series = ClearMLTracker._get_title_series(k)
867
+ clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
868
+
869
+ @on_main_process
870
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
871
+ """
872
+ Logs `images` to the current run.
873
+
874
+ Args:
875
+ values (`Dict[str, List[Union[np.ndarray, PIL.Image]]]`):
876
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
877
+ step (`int`, *optional*):
878
+ The run step. If included, the log will be affiliated with this step.
879
+ kwargs:
880
+ Additional key word arguments passed along to the `clearml.Logger.report_image` method.
881
+ """
882
+ clearml_logger = self.task.get_logger()
883
+ for k, v in values.items():
884
+ title, series = ClearMLTracker._get_title_series(k)
885
+ clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
886
+
887
+ @on_main_process
888
+ def log_table(
889
+ self,
890
+ table_name: str,
891
+ columns: list[str] = None,
892
+ data: list[list[Any]] = None,
893
+ dataframe: Any = None,
894
+ step: Optional[int] = None,
895
+ **kwargs,
896
+ ):
897
+ """
898
+ Log a Table to the task. Can be defined either with `columns` and `data` or with `dataframe`.
899
+
900
+ Args:
901
+ table_name (`str`):
902
+ The name of the table
903
+ columns (list of `str`, *optional*):
904
+ The name of the columns on the table
905
+ data (List of List of Any data type, *optional*):
906
+ The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
907
+ the name of the columns of the table
908
+ dataframe (Any data type, *optional*):
909
+ The data to be logged in the table
910
+ step (`int`, *optional*):
911
+ The run step. If included, the log will be affiliated with this step.
912
+ kwargs:
913
+ Additional key word arguments passed along to the `clearml.Logger.report_table` method.
914
+ """
915
+ to_report = dataframe
916
+ if dataframe is None:
917
+ if data is None:
918
+ raise ValueError(
919
+ "`ClearMLTracker.log_table` requires that `data` be supplied if `dataframe` is `None`"
920
+ )
921
+ to_report = [columns] + data if columns else data
922
+ title, series = ClearMLTracker._get_title_series(table_name)
923
+ self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
924
+
925
+ @on_main_process
926
+ def finish(self):
927
+ """
928
+ Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
929
+ function is a no-op
930
+ """
931
+ if self.task and not self._initialized_externally:
932
+ self.task.close()
933
+
934
+ @staticmethod
935
+ def _get_title_series(name):
936
+ for prefix in ["eval", "test", "train"]:
937
+ if name.startswith(prefix + "_"):
938
+ return name[len(prefix) + 1 :], prefix
939
+ return name, "train"
940
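+ # For example, _get_title_series("eval_loss") returns ("loss", "eval"), while an unprefixed
+ # key such as "lr" returns ("lr", "train").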
+
941
+
942
+ class DVCLiveTracker(GeneralTracker):
943
+ """
944
+ A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.
945
+
946
+ Args:
947
+ run_name (`str`, *optional*):
948
+ Ignored for dvclive. See `kwargs` instead.
949
+ kwargs:
950
+ Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).
951
+
952
+ Example:
953
+
954
+ ```py
955
+ from accelerate import Accelerator
956
+
957
+ accelerator = Accelerator(log_with="dvclive")
958
+ accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
959
+ ```
960
+ """
961
+
962
+ name = "dvclive"
963
+ requires_logging_directory = False
964
+
965
+ @on_main_process
966
+ def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
967
+ from dvclive import Live
968
+
969
+ super().__init__()
970
+ self.live = live if live is not None else Live(**kwargs)
971
+
972
+ @property
973
+ def tracker(self):
974
+ return self.live
975
+
976
+ @on_main_process
977
+ def store_init_configuration(self, values: dict):
978
+ """
979
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
980
+ hyperparameters in a yaml file for future use.
981
+
982
+ Args:
983
+ values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
984
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
985
+ `str`, `float`, or `int`.
986
+ """
987
+ self.live.log_params(values)
988
+
989
+ @on_main_process
990
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
991
+ """
992
+ Logs `values` to the current run.
993
+
994
+ Args:
995
+ values (Dictionary `str` to `str`, `float`, or `int`):
996
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
997
+ step (`int`, *optional*):
998
+ The run step. If included, the log will be affiliated with this step.
999
+ kwargs:
1000
+ Additional key word arguments passed along to `dvclive.Live.log_metric()`.
1001
+ """
1002
+ from dvclive.plots import Metric
1003
+
1004
+ if step is not None:
1005
+ self.live.step = step
1006
+ for k, v in values.items():
1007
+ if Metric.could_log(v):
1008
+ self.live.log_metric(k, v, **kwargs)
1009
+ else:
1010
+ logger.warning_once(
1011
+ "Accelerator attempted to log a value of "
1012
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
1013
+ "This invocation of DVCLive's Live.log_metric() "
1014
+ "is incorrect so we dropped this attribute."
1015
+ )
1016
+ self.live.next_step()
1017
+
1018
+ @on_main_process
1019
+ def finish(self):
1020
+ """
1021
+ Closes `dvclive.Live()`.
1022
+ """
1023
+ self.live.end()
1024
+
1025
+
1026
+ LOGGER_TYPE_TO_CLASS = {
1027
+ "aim": AimTracker,
1028
+ "comet_ml": CometMLTracker,
1029
+ "mlflow": MLflowTracker,
1030
+ "tensorboard": TensorBoardTracker,
1031
+ "wandb": WandBTracker,
1032
+ "clearml": ClearMLTracker,
1033
+ "dvclive": DVCLiveTracker,
1034
+ }
1035
+
1036
+
1037
+ def filter_trackers(
1038
+ log_with: list[Union[str, LoggerType, GeneralTracker]],
1039
+ logging_dir: Union[str, os.PathLike] = None,
1040
+ ):
1041
+ """
1042
+ Takes in a list of potential tracker types and:
1043
+ - Checks that each requested tracker is available in the environment
1044
+ - Filters out repeated tracker types
1045
+ - Returns all available trackers in the environment if `all` is in `log_with`
1046
+ - Ensures that `logging_dir` is not `None` for any tracker that requires a `logging_dir`
1047
+
1048
+ Args:
1049
+ log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
1050
+ A list of loggers to be setup for experiment tracking. Should be one or several of:
1051
+
1052
+ - `"all"`
1053
+ - `"tensorboard"`
1054
+ - `"wandb"`
1055
+ - `"comet_ml"`
1056
+ - `"mlflow"`
1057
+ - `"dvclive"`
1058
+ If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
1059
+ also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
1060
+ logging_dir (`str`, `os.PathLike`, *optional*):
1061
+ A path to a directory for storing logs of locally-compatible loggers.
1062
+ """
1063
+ loggers = []
1064
+ if log_with is not None:
1065
+ if not isinstance(log_with, (list, tuple)):
1066
+ log_with = [log_with]
1067
+ if "all" in log_with or LoggerType.ALL in log_with:
1068
+ loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()
1069
+ else:
1070
+ for log_type in log_with:
1071
+ if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):
1072
+ raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")
1073
+ if issubclass(type(log_type), GeneralTracker):
1074
+ loggers.append(log_type)
1075
+ else:
1076
+ log_type = LoggerType(log_type)
1077
+ if log_type not in loggers:
1078
+ if log_type in get_available_trackers():
1079
+ tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
1080
+ if tracker_init.requires_logging_directory:
1081
+ if logging_dir is None:
1082
+ raise ValueError(
1083
+ f"Logging with `{log_type}` requires a `logging_dir` to be passed in."
1084
+ )
1085
+ loggers.append(log_type)
1086
+ else:
1087
+ logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
1088
+
1089
+ return loggers
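+
+
+ # A short usage sketch for `filter_trackers` (illustrative only): with tensorboard installed,
+ # the call below returns [LoggerType.TENSORBOARD]; it raises a ValueError instead if a
+ # requested tracker needs a logging directory and none is supplied.
+ #
+ #     trackers = filter_trackers(["tensorboard"], logging_dir="runs")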
venv/Lib/site-packages/adodbapi/__init__.py ADDED
@@ -0,0 +1,82 @@
1
+ # nopycln: file # undecidable cases due to explicit re-exports https://github.com/hadialqattan/pycln/issues/205
2
+ """adodbapi - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
3
+
4
+ Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
5
+ * https://sourceforge.net/projects/adodbapi
6
+ """
7
+
8
+ import time
9
+
10
+ # Re-exports to keep backward compatibility with existing code
11
+ from .adodbapi import (
12
+ Connection as Connection,
13
+ Cursor as Cursor,
14
+ __version__,
15
+ connect as connect,
16
+ dateconverter,
17
+ )
18
+ from .apibase import (
19
+ BINARY as BINARY,
20
+ DATETIME as DATETIME,
21
+ NUMBER as NUMBER,
22
+ ROWID as ROWID,
23
+ STRING as STRING,
24
+ DatabaseError as DatabaseError,
25
+ DataError as DataError,
26
+ Error as Error,
27
+ FetchFailedError as FetchFailedError,
28
+ IntegrityError as IntegrityError,
29
+ InterfaceError as InterfaceError,
30
+ InternalError as InternalError,
31
+ NotSupportedError as NotSupportedError,
32
+ OperationalError as OperationalError,
33
+ ProgrammingError as ProgrammingError,
34
+ Warning as Warning,
35
+ apilevel as apilevel,
36
+ paramstyle as paramstyle,
37
+ threadsafety as threadsafety,
38
+ )
39
+
40
+
41
+ def Binary(aString):
42
+ """This function constructs an object capable of holding a binary (long) string value."""
43
+ return bytes(aString)
44
+
45
+
46
+ def Date(year, month, day):
47
+ "This function constructs an object holding a date value."
48
+ return dateconverter.Date(year, month, day)
49
+
50
+
51
+ def Time(hour, minute, second):
52
+ "This function constructs an object holding a time value."
53
+ return dateconverter.Time(hour, minute, second)
54
+
55
+
56
+ def Timestamp(year, month, day, hour, minute, second):
57
+ "This function constructs an object holding a time stamp value."
58
+ return dateconverter.Timestamp(year, month, day, hour, minute, second)
59
+
60
+
61
+ def DateFromTicks(ticks):
62
+ """This function constructs an object holding a date value from the given ticks value
63
+ (number of seconds since the epoch; see the documentation of the standard Python time module for details).
64
+ """
65
+ return Date(*time.gmtime(ticks)[:3])
66
+
67
+
68
+ def TimeFromTicks(ticks):
69
+ """This function constructs an object holding a time value from the given ticks value
70
+ (number of seconds since the epoch; see the documentation of the standard Python time module for details).
71
+ """
72
+ return Time(*time.gmtime(ticks)[3:6])
73
+
74
+
75
+ def TimestampFromTicks(ticks):
76
+ """This function constructs an object holding a time stamp value from the given
77
+ ticks value (number of seconds since the epoch;
78
+ see the documentation of the standard Python time module for details)."""
79
+ return Timestamp(*time.gmtime(ticks)[:6])
80
+
81
+
82
+ version = "adodbapi v" + __version__
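For orientation, the module-level helpers above are the standard PEP 249 type constructors. A minimal sketch of how client code might call them; the concrete return types depend on the active dateconverter, and the values shown are illustrative assumptions:

    import adodbapi

    d = adodbapi.Date(2024, 1, 31)        # date object built via dateconverter
    t = adodbapi.Time(13, 45, 30)         # time object
    ts = adodbapi.TimestampFromTicks(0)   # epoch seconds -> 1970-01-01 00:00:00 (UTC, per time.gmtime)
    blob = adodbapi.Binary(b"\x00\x01")   # bytes wrapper for binary columns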
venv/Lib/site-packages/adodbapi/ado_consts.py ADDED
@@ -0,0 +1,283 @@
+# ADO enumerated constants documented on MSDN:
+# https://learn.microsoft.com/en-us/sql/ado/reference/ado-api/ado-enumerated-constants
+
+# IsolationLevelEnum
+adXactUnspecified = -1
+adXactBrowse = 0x100
+adXactChaos = 0x10
+adXactCursorStability = 0x1000
+adXactIsolated = 0x100000
+adXactReadCommitted = 0x1000
+adXactReadUncommitted = 0x100
+adXactRepeatableRead = 0x10000
+adXactSerializable = 0x100000
+
+# CursorLocationEnum
+adUseClient = 3
+adUseServer = 2
+
+# CursorTypeEnum
+adOpenDynamic = 2
+adOpenForwardOnly = 0
+adOpenKeyset = 1
+adOpenStatic = 3
+adOpenUnspecified = -1
+
+# CommandTypeEnum
+adCmdText = 1
+adCmdStoredProc = 4
+adSchemaTables = 20
+
+# ParameterDirectionEnum
+adParamInput = 1
+adParamInputOutput = 3
+adParamOutput = 2
+adParamReturnValue = 4
+adParamUnknown = 0
+directions = {
+    0: "Unknown",
+    1: "Input",
+    2: "Output",
+    3: "InputOutput",
+    4: "Return",
+}
+
+
+def ado_direction_name(ado_dir):
+    try:
+        return "adParam" + directions[ado_dir]
+    except:
+        return f"unknown direction ({ado_dir})"
+
+
+# ObjectStateEnum
+adStateClosed = 0
+adStateOpen = 1
+adStateConnecting = 2
+adStateExecuting = 4
+adStateFetching = 8
+
+# FieldAttributeEnum
+adFldMayBeNull = 0x40
+
+# ConnectModeEnum
+adModeUnknown = 0
+adModeRead = 1
+adModeWrite = 2
+adModeReadWrite = 3
+adModeShareDenyRead = 4
+adModeShareDenyWrite = 8
+adModeShareExclusive = 12
+adModeShareDenyNone = 16
+adModeRecursive = 0x400000
+
+# XactAttributeEnum
+adXactCommitRetaining = 131072
+adXactAbortRetaining = 262144
+
+ado_error_TIMEOUT = -2147217871
+
+# DataTypeEnum - ADO Data types documented at:
+# https://learn.microsoft.com/en-us/sql/ado/reference/ado-api/datatypeenum
+adArray = 0x2000
+adEmpty = 0x0
+adBSTR = 0x8
+adBigInt = 0x14
+adBinary = 0x80
+adBoolean = 0xB
+adChapter = 0x88
+adChar = 0x81
+adCurrency = 0x6
+adDBDate = 0x85
+adDBTime = 0x86
+adDBTimeStamp = 0x87
+adDate = 0x7
+adDecimal = 0xE
+adDouble = 0x5
+adError = 0xA
+adFileTime = 0x40
+adGUID = 0x48
+adIDispatch = 0x9
+adIUnknown = 0xD
+adInteger = 0x3
+adLongVarBinary = 0xCD
+adLongVarChar = 0xC9
+adLongVarWChar = 0xCB
+adNumeric = 0x83
+adPropVariant = 0x8A
+adSingle = 0x4
+adSmallInt = 0x2
+adTinyInt = 0x10
+adUnsignedBigInt = 0x15
+adUnsignedInt = 0x13
+adUnsignedSmallInt = 0x12
+adUnsignedTinyInt = 0x11
+adUserDefined = 0x84
+adVarBinary = 0xCC
+adVarChar = 0xC8
+adVarNumeric = 0x8B
+adVarWChar = 0xCA
+adVariant = 0xC
+adWChar = 0x82
+# Additional constants used by introspection but not ADO itself
+AUTO_FIELD_MARKER = -1000
+
+adTypeNames = {
+    adBSTR: "adBSTR",
+    adBigInt: "adBigInt",
+    adBinary: "adBinary",
+    adBoolean: "adBoolean",
+    adChapter: "adChapter",
+    adChar: "adChar",
+    adCurrency: "adCurrency",
+    adDBDate: "adDBDate",
+    adDBTime: "adDBTime",
+    adDBTimeStamp: "adDBTimeStamp",
+    adDate: "adDate",
+    adDecimal: "adDecimal",
+    adDouble: "adDouble",
+    adEmpty: "adEmpty",
+    adError: "adError",
+    adFileTime: "adFileTime",
+    adGUID: "adGUID",
+    adIDispatch: "adIDispatch",
+    adIUnknown: "adIUnknown",
+    adInteger: "adInteger",
+    adLongVarBinary: "adLongVarBinary",
+    adLongVarChar: "adLongVarChar",
+    adLongVarWChar: "adLongVarWChar",
+    adNumeric: "adNumeric",
+    adPropVariant: "adPropVariant",
+    adSingle: "adSingle",
+    adSmallInt: "adSmallInt",
+    adTinyInt: "adTinyInt",
+    adUnsignedBigInt: "adUnsignedBigInt",
+    adUnsignedInt: "adUnsignedInt",
+    adUnsignedSmallInt: "adUnsignedSmallInt",
+    adUnsignedTinyInt: "adUnsignedTinyInt",
+    adUserDefined: "adUserDefined",
+    adVarBinary: "adVarBinary",
+    adVarChar: "adVarChar",
+    adVarNumeric: "adVarNumeric",
+    adVarWChar: "adVarWChar",
+    adVariant: "adVariant",
+    adWChar: "adWChar",
+}
+
+
+def ado_type_name(ado_type):
+    return adTypeNames.get(ado_type, f"unknown type ({ado_type})")
+
+
+# here in decimal, sorted by value
+# adEmpty 0 Specifies no value (DBTYPE_EMPTY).
+# adSmallInt 2 Indicates a two-byte signed integer (DBTYPE_I2).
+# adInteger 3 Indicates a four-byte signed integer (DBTYPE_I4).
+# adSingle 4 Indicates a single-precision floating-point value (DBTYPE_R4).
+# adDouble 5 Indicates a double-precision floating-point value (DBTYPE_R8).
+# adCurrency 6 Indicates a currency value (DBTYPE_CY). Currency is a fixed-point number
+#   with four digits to the right of the decimal point. It is stored in an eight-byte signed integer scaled by 10,000.
+# adDate 7 Indicates a date value (DBTYPE_DATE). A date is stored as a double, the whole part of which is
+#   the number of days since December 30, 1899, and the fractional part of which is the fraction of a day.
+# adBSTR 8 Indicates a null-terminated character string (Unicode) (DBTYPE_BSTR).
+# adIDispatch 9 Indicates a pointer to an IDispatch interface on a COM object (DBTYPE_IDISPATCH).
+# adError 10 Indicates a 32-bit error code (DBTYPE_ERROR).
+# adBoolean 11 Indicates a boolean value (DBTYPE_BOOL).
+# adVariant 12 Indicates an Automation Variant (DBTYPE_VARIANT).
+# adIUnknown 13 Indicates a pointer to an IUnknown interface on a COM object (DBTYPE_IUNKNOWN).
+# adDecimal 14 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_DECIMAL).
+# adTinyInt 16 Indicates a one-byte signed integer (DBTYPE_I1).
+# adUnsignedTinyInt 17 Indicates a one-byte unsigned integer (DBTYPE_UI1).
+# adUnsignedSmallInt 18 Indicates a two-byte unsigned integer (DBTYPE_UI2).
+# adUnsignedInt 19 Indicates a four-byte unsigned integer (DBTYPE_UI4).
+# adBigInt 20 Indicates an eight-byte signed integer (DBTYPE_I8).
+# adUnsignedBigInt 21 Indicates an eight-byte unsigned integer (DBTYPE_UI8).
+# adFileTime 64 Indicates a 64-bit value representing the number of 100-nanosecond intervals since
+#   January 1, 1601 (DBTYPE_FILETIME).
+# adGUID 72 Indicates a globally unique identifier (GUID) (DBTYPE_GUID).
+# adBinary 128 Indicates a binary value (DBTYPE_BYTES).
+# adChar 129 Indicates a string value (DBTYPE_STR).
+# adWChar 130 Indicates a null-terminated Unicode character string (DBTYPE_WSTR).
+# adNumeric 131 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_NUMERIC).
+# adUserDefined 132 Indicates a user-defined variable (DBTYPE_UDT).
+# adDBDate 133 Indicates a date value (yyyymmdd) (DBTYPE_DBDATE).
+# adDBTime 134 Indicates a time value (hhmmss) (DBTYPE_DBTIME).
+# adDBTimeStamp 135 Indicates a date/time stamp (yyyymmddhhmmss plus a fraction in billionths) (DBTYPE_DBTIMESTAMP).
+# adChapter 136 Indicates a four-byte chapter value that identifies rows in a child rowset (DBTYPE_HCHAPTER).
+# adPropVariant 138 Indicates an Automation PROPVARIANT (DBTYPE_PROP_VARIANT).
+# adVarNumeric 139 Indicates a numeric value (Parameter object only).
+# adVarChar 200 Indicates a string value (Parameter object only).
+# adLongVarChar 201 Indicates a long string value (Parameter object only).
+# adVarWChar 202 Indicates a null-terminated Unicode character string (Parameter object only).
+# adLongVarWChar 203 Indicates a long null-terminated Unicode string value (Parameter object only).
+# adVarBinary 204 Indicates a binary value (Parameter object only).
+# adLongVarBinary 205 Indicates a long binary value (Parameter object only).
+# adArray (Does not apply to ADOX.) 0x2000 A flag value, always combined with another data type constant,
+#   that indicates an array of that other data type.
+
+# Error codes to names
+adoErrors = {
+    0xE7B: "adErrBoundToCommand",
+    0xE94: "adErrCannotComplete",
+    0xEA4: "adErrCantChangeConnection",
+    0xC94: "adErrCantChangeProvider",
+    0xE8C: "adErrCantConvertvalue",
+    0xE8D: "adErrCantCreate",
+    0xEA3: "adErrCatalogNotSet",
+    0xE8E: "adErrColumnNotOnThisRow",
+    0xD5D: "adErrDataConversion",
+    0xE89: "adErrDataOverflow",
+    0xE9A: "adErrDelResOutOfScope",
+    0xEA6: "adErrDenyNotSupported",
+    0xEA7: "adErrDenyTypeNotSupported",
+    0xCB3: "adErrFeatureNotAvailable",
+    0xEA5: "adErrFieldsUpdateFailed",
+    0xC93: "adErrIllegalOperation",
+    0xCAE: "adErrInTransaction",
+    0xE87: "adErrIntegrityViolation",
+    0xBB9: "adErrInvalidArgument",
+    0xE7D: "adErrInvalidConnection",
+    0xE7C: "adErrInvalidParamInfo",
+    0xE82: "adErrInvalidTransaction",
+    0xE91: "adErrInvalidURL",
+    0xCC1: "adErrItemNotFound",
+    0xBCD: "adErrNoCurrentRecord",
+    0xE83: "adErrNotExecuting",
+    0xE7E: "adErrNotReentrant",
+    0xE78: "adErrObjectClosed",
+    0xD27: "adErrObjectInCollection",
+    0xD5C: "adErrObjectNotSet",
+    0xE79: "adErrObjectOpen",
+    0xBBA: "adErrOpeningFile",
+    0xE80: "adErrOperationCancelled",
+    0xE96: "adErrOutOfSpace",
+    0xE88: "adErrPermissionDenied",
+    0xE9E: "adErrPropConflicting",
+    0xE9B: "adErrPropInvalidColumn",
+    0xE9C: "adErrPropInvalidOption",
+    0xE9D: "adErrPropInvalidValue",
+    0xE9F: "adErrPropNotAllSettable",
+    0xEA0: "adErrPropNotSet",
+    0xEA1: "adErrPropNotSettable",
+    0xEA2: "adErrPropNotSupported",
+    0xBB8: "adErrProviderFailed",
+    0xE7A: "adErrProviderNotFound",
+    0xBBB: "adErrReadFile",
+    0xE93: "adErrResourceExists",
+    0xE92: "adErrResourceLocked",
+    0xE97: "adErrResourceOutOfScope",
+    0xE8A: "adErrSchemaViolation",
+    0xE8B: "adErrSignMismatch",
+    0xE81: "adErrStillConnecting",
+    0xE7F: "adErrStillExecuting",
+    0xE90: "adErrTreePermissionDenied",
+    0xE8F: "adErrURLDoesNotExist",
+    0xE99: "adErrURLNamedRowDoesNotExist",
+    0xE98: "adErrUnavailable",
+    0xE84: "adErrUnsafeOperation",
+    0xE95: "adErrVolumeNotFound",
+    0xBBC: "adErrWriteFile",
+}
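A quick sanity check of the two lookup helpers defined above; this needs no COM and runs anywhere the package is importable (a sketch, assuming the venv layout shown in this commit):

    from adodbapi import ado_consts as adc

    print(adc.ado_type_name(adc.adVarWChar))          # -> 'adVarWChar'
    print(adc.ado_type_name(0x9999))                  # -> 'unknown type (39321)'
    print(adc.ado_direction_name(adc.adParamInput))   # -> 'adParamInput'
    print(adc.adoErrors.get(0xE78))                   # -> 'adErrObjectClosed'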
venv/Lib/site-packages/adodbapi/adodbapi.py ADDED
@@ -0,0 +1,1153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """adodbapi - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
2
+
3
+ Copyright (C) 2002 Henrik Ekelund, versions 2.1 and later by Vernon Cole
4
+ * https://sourceforge.net/projects/pywin32
5
+ * https://github.com/mhammond/pywin32
6
+ * https://sourceforge.net/projects/adodbapi
7
+
8
+ This library is free software; you can redistribute it and/or
9
+ modify it under the terms of the GNU Lesser General Public
10
+ License as published by the Free Software Foundation; either
11
+ version 2.1 of the License, or (at your option) any later version.
12
+
13
+ This library is distributed in the hope that it will be useful,
14
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
15
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16
+ Lesser General Public License for more details.
17
+
18
+ You should have received a copy of the GNU Lesser General Public
19
+ License along with this library; if not, write to the Free Software
20
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
21
+
22
+ django adaptations and refactoring by Adam Vandenberg
23
+
24
+ DB-API 2.0 specification: https://peps.python.org/pep-0249/
25
+
26
+ This module source should run correctly in CPython versions 2.7 and later,
27
+ or CPython 3.4 or later.
28
+ """
29
+
30
+ __version__ = "2.6.2.0"
31
+ version = "adodbapi v" + __version__
32
+
33
+ import copy
34
+ import decimal
35
+ import os
36
+ import sys
37
+ import weakref
38
+
39
+ from . import ado_consts as adc, apibase as api, process_connect_string
40
+
41
+ try:
42
+ verbose = int(os.environ["ADODBAPI_VERBOSE"])
43
+ except:
44
+ verbose = False
45
+ if verbose:
46
+ print(version)
47
+
48
+ try:
49
+ import pythoncom
50
+ import pywintypes
51
+ from win32com.client import Dispatch
52
+ except ImportError:
53
+ import warnings
54
+
55
+ warnings.warn("pywin32 package required for adodbapi.", ImportWarning)
56
+
57
+
58
+ def getIndexedValue(obj, index):
59
+ return obj(index)
60
+
61
+
62
+ from collections.abc import Mapping
63
+
64
+
65
+ # ----------------- The .connect method -----------------
66
+ def make_COM_connecter():
67
+ try:
68
+ pythoncom.CoInitialize() # v2.1 Paj
69
+ c = Dispatch("ADODB.Connection") # connect _after_ CoInitialize v2.1.1 adamvan
70
+ except:
71
+ raise api.InterfaceError(
72
+ "Windows COM Error: Dispatch('ADODB.Connection') failed."
73
+ )
74
+ return c
75
+
76
+
77
+ def connect(*args, **kwargs): # --> a db-api connection object
78
+ """Connect to a database.
79
+
80
+ call using:
81
+ :connection_string -- An ADODB formatted connection string, see:
82
+ * https://www.connectionstrings.com
83
+ * https://www.codeguru.com/dotnet/whats-in-an-ado-connection-string/
84
+ * https://learn.microsoft.com/en-us/dotnet/framework/data/adonet/connection-strings
85
+ :timeout -- A command timeout value, in seconds (default 30 seconds)
86
+ """
87
+ co = Connection() # make an empty connection object
88
+
89
+ kwargs = process_connect_string.process(args, kwargs, True)
90
+
91
+ try: # connect to the database, using the connection information in kwargs
92
+ co.connect(kwargs)
93
+ return co
94
+ except Exception as e:
95
+ message = 'Error opening connection to "%s"' % co.connection_string
96
+ raise api.OperationalError(e, message)
97
+
98
+
99
+ # so you could use something like:
100
+ # myConnection.paramstyle = 'named'
101
+ # The programmer may also change the default.
102
+ # For example, if I were using django, I would say:
103
+ # import adodbapi as Database
104
+ # Database.adodbapi.paramstyle = 'format'
105
+
106
+ # ------- other module level defaults --------
107
+ defaultIsolationLevel = adc.adXactReadCommitted
108
+ # Set defaultIsolationLevel on module level before creating the connection.
109
+ # For example:
110
+ # import adodbapi, ado_consts
111
+ # adodbapi.adodbapi.defaultIsolationLevel=ado_consts.adXactBrowse"
112
+ #
113
+ # Set defaultCursorLocation on module level before creating the connection.
114
+ # It may be one of the "adUse..." consts.
115
+ defaultCursorLocation = adc.adUseClient # changed from adUseServer as of v 2.3.0
116
+
117
+ dateconverter = api.pythonDateTimeConverter() # default
118
+
119
+
120
+ def format_parameters(ADOparameters, show_value=False):
121
+ """Format a collection of ADO Command Parameters.
122
+
123
+ Used by error reporting in _execute_command.
124
+ """
125
+ try:
126
+ if show_value:
127
+ desc = [
128
+ 'Name: %s, Dir.: %s, Type: %s, Size: %s, Value: "%s", Precision: %s, NumericScale: %s'
129
+ % (
130
+ p.Name,
131
+ adc.directions[p.Direction],
132
+ adc.adTypeNames.get(p.Type, str(p.Type) + " (unknown type)"),
133
+ p.Size,
134
+ p.Value,
135
+ p.Precision,
136
+ p.NumericScale,
137
+ )
138
+ for p in ADOparameters
139
+ ]
140
+ else:
141
+ desc = [
142
+ "Name: %s, Dir.: %s, Type: %s, Size: %s, Precision: %s, NumericScale: %s"
143
+ % (
144
+ p.Name,
145
+ adc.directions[p.Direction],
146
+ adc.adTypeNames.get(p.Type, str(p.Type) + " (unknown type)"),
147
+ p.Size,
148
+ p.Precision,
149
+ p.NumericScale,
150
+ )
151
+ for p in ADOparameters
152
+ ]
153
+ return "[" + "\n".join(desc) + "]"
154
+ except:
155
+ return "[]"
156
+
157
+
158
+ def _configure_parameter(p, value, adotype, settings_known):
159
+ """Configure the given ADO Parameter 'p' with the Python 'value'."""
160
+
161
+ if adotype in api.adoBinaryTypes:
162
+ p.Size = len(value)
163
+ p.AppendChunk(value)
164
+
165
+ elif isinstance(value, str): # v2.1 Jevon
166
+ length = len(value)
167
+ if adotype in api.adoStringTypes: # v2.2.1 Cole
168
+ if settings_known:
169
+ length = min(length, p.Size) # v2.1 Cole limit data to defined size
170
+ p.Value = value[:length] # v2.1 Jevon & v2.1 Cole
171
+ else:
172
+ p.Value = value # don't limit if db column is numeric
173
+ if length > 0: # v2.1 Cole something does not like p.Size as Zero
174
+ p.Size = length # v2.1 Jevon
175
+
176
+ elif isinstance(value, decimal.Decimal):
177
+ p.Value = value
178
+ exponent = value.as_tuple()[2]
179
+ digit_count = len(value.as_tuple()[1])
180
+ p.Precision = digit_count
181
+ if exponent == 0:
182
+ p.NumericScale = 0
183
+ elif exponent < 0:
184
+ p.NumericScale = -exponent
185
+ if p.Precision < p.NumericScale:
186
+ p.Precision = p.NumericScale
187
+ else: # exponent > 0:
188
+ p.NumericScale = 0
189
+ p.Precision = digit_count + exponent
190
+
191
+ elif type(value) in dateconverter.types:
192
+ if settings_known and adotype in api.adoDateTimeTypes:
193
+ p.Value = dateconverter.COMDate(value)
194
+ else: # probably a string
195
+ # provide the date as a string in the format 'YYYY-MM-dd'
196
+ s = dateconverter.DateObjectToIsoFormatString(value)
197
+ p.Value = s
198
+ p.Size = len(s)
199
+
200
+ elif adotype == adc.adEmpty: # ADO will not let you specify a null column
201
+ p.Type = (
202
+ adc.adInteger
203
+ ) # so we will fake it to be an integer (just to have something)
204
+ p.Value = None # and pass in a Null *value*
205
+
206
+ # For any other type, set the value and let pythoncom do the right thing.
207
+ else:
208
+ p.Value = value
209
+
210
+
211
+ # # # # # ----- the Class that defines a connection ----- # # # # #
212
+ class Connection:
213
+ # include connection attributes as class attributes required by api definition.
214
+ Warning = api.Warning
215
+ Error = api.Error
216
+ InterfaceError = api.InterfaceError
217
+ DataError = api.DataError
218
+ DatabaseError = api.DatabaseError
219
+ OperationalError = api.OperationalError
220
+ IntegrityError = api.IntegrityError
221
+ InternalError = api.InternalError
222
+ NotSupportedError = api.NotSupportedError
223
+ ProgrammingError = api.ProgrammingError
224
+ FetchFailedError = api.FetchFailedError # (special for django)
225
+ # ...class attributes... (can be overridden by instance attributes)
226
+ verbose = api.verbose
227
+
228
+ @property
229
+ def dbapi(self): # a proposed db-api version 3 extension.
230
+ "Return a reference to the DBAPI module for this Connection."
231
+ return api
232
+
233
+ def __init__(self): # now define the instance attributes
234
+ self.connector = None
235
+ self.paramstyle = api.paramstyle
236
+ self.supportsTransactions = False
237
+ self.connection_string = ""
238
+ self.cursors = weakref.WeakValueDictionary[int, Cursor]()
239
+ self.dbms_name = ""
240
+ self.dbms_version = ""
241
+ self.errorhandler = None # use the standard error handler for this instance
242
+ self.transaction_level = 0 # 0 == Not in a transaction, at the top level
243
+ self._autocommit = False
244
+
245
+ def connect(self, kwargs, connection_maker=make_COM_connecter):
246
+ if verbose > 9:
247
+ print(f"kwargs={kwargs!r}")
248
+ try:
249
+ self.connection_string = (
250
+ kwargs["connection_string"] % kwargs
251
+ ) # insert keyword arguments
252
+ except Exception as e:
253
+ self._raiseConnectionError(
254
+ KeyError, "Python string format error in connection string->"
255
+ )
256
+ self.timeout = kwargs.get("timeout", 30)
257
+ self.mode = kwargs.get("mode", adc.adModeUnknown)
258
+ self.kwargs = kwargs
259
+ if verbose:
260
+ print('%s attempting: "%s"' % (version, self.connection_string))
261
+ self.connector = connection_maker()
262
+ self.connector.ConnectionTimeout = self.timeout
263
+ self.connector.ConnectionString = self.connection_string
264
+ self.connector.Mode = self.mode
265
+
266
+ try:
267
+ self.connector.Open() # Open the ADO connection
268
+ except api.Error:
269
+ self._raiseConnectionError(
270
+ api.DatabaseError,
271
+ "ADO error trying to Open=%s" % self.connection_string,
272
+ )
273
+
274
+ try: # Stefan Fuchs; support WINCCOLEDBProvider
275
+ if getIndexedValue(self.connector.Properties, "Transaction DDL").Value != 0:
276
+ self.supportsTransactions = True
277
+ except pywintypes.com_error:
278
+ pass # Stefan Fuchs
279
+ self.dbms_name = getIndexedValue(self.connector.Properties, "DBMS Name").Value
280
+ try: # Stefan Fuchs
281
+ self.dbms_version = getIndexedValue(
282
+ self.connector.Properties, "DBMS Version"
283
+ ).Value
284
+ except pywintypes.com_error:
285
+ pass # Stefan Fuchs
286
+ self.connector.CursorLocation = defaultCursorLocation # v2.1 Rose
287
+ if self.supportsTransactions:
288
+ self.connector.IsolationLevel = defaultIsolationLevel
289
+ self._autocommit = bool(kwargs.get("autocommit", False))
290
+ if not self._autocommit:
291
+ self.transaction_level = (
292
+ self.connector.BeginTrans()
293
+ ) # Disables autocommit & inits transaction_level
294
+ else:
295
+ self._autocommit = True
296
+ if "paramstyle" in kwargs:
297
+ self.paramstyle = kwargs["paramstyle"] # let setattr do the error checking
298
+ self.messages = []
299
+ if verbose:
300
+ print("adodbapi New connection at %X" % id(self))
301
+
302
+ def _raiseConnectionError(self, errorclass, errorvalue):
303
+ eh = self.errorhandler
304
+ if eh is None:
305
+ eh = api.standardErrorHandler
306
+ eh(self, None, errorclass, errorvalue)
307
+
308
+ def _closeAdoConnection(self): # all v2.1 Rose
309
+ """close the underlying ADO Connection object,
310
+ rolling it back first if it supports transactions."""
311
+ if self.connector is None:
312
+ return
313
+ if not self._autocommit:
314
+ if self.transaction_level:
315
+ try:
316
+ self.connector.RollbackTrans()
317
+ except:
318
+ pass
319
+ self.connector.Close()
320
+ if verbose:
321
+ print("adodbapi Closed connection at %X" % id(self))
322
+
323
+ def close(self):
324
+ """Close the connection now (rather than whenever __del__ is called).
325
+
326
+ The connection will be unusable from this point forward;
327
+ an Error (or subclass) exception will be raised if any operation is attempted with the connection.
328
+ The same applies to all cursor objects trying to use the connection.
329
+ """
330
+ for crsr in list(self.cursors.values())[
331
+ :
332
+ ]: # copy the list, then close each one
333
+ crsr.close(dont_tell_me=True) # close without back-link clearing
334
+ self.messages = []
335
+ try:
336
+ self._closeAdoConnection() # v2.1 Rose
337
+ except Exception as e:
338
+ self._raiseConnectionError(sys.exc_info()[0], sys.exc_info()[1])
339
+
340
+ self.connector = None # v2.4.2.2 fix subtle timeout bug
341
+ # per M.Hammond: "I expect the benefits of uninitializing are probably fairly small,
342
+ # so never uninitializing will probably not cause any problems."
343
+
344
+ def commit(self):
345
+ """Commit any pending transaction to the database.
346
+
347
+ Note that if the database supports an auto-commit feature,
348
+ this must be initially off. An interface method may be provided to turn it back on.
349
+ Database modules that do not support transactions should implement this method with void functionality.
350
+ """
351
+ self.messages = []
352
+ if not self.supportsTransactions:
353
+ return
354
+
355
+ try:
356
+ self.transaction_level = self.connector.CommitTrans()
357
+ if verbose > 1:
358
+ print("commit done on connection at %X" % id(self))
359
+ if not (
360
+ self._autocommit
361
+ or (self.connector.Attributes & adc.adXactAbortRetaining)
362
+ ):
363
+ # If attributes has adXactCommitRetaining it performs retaining commits that is,
364
+ # calling CommitTrans automatically starts a new transaction. Not all providers support this.
365
+ # If not, we will have to start a new transaction by this command:
366
+ self.transaction_level = self.connector.BeginTrans()
367
+ except Exception as e:
368
+ self._raiseConnectionError(api.ProgrammingError, e)
369
+
370
+ def _rollback(self):
371
+ """In case a database does provide transactions this method causes the the database to roll back to
372
+ the start of any pending transaction. Closing a connection without committing the changes first will
373
+ cause an implicit rollback to be performed.
374
+
375
+ If the database does not support the functionality required by the method, the interface should
376
+ throw an exception in case the method is used.
377
+ The preferred approach is to not implement the method and thus have Python generate
378
+ an AttributeError in case the method is requested. This allows the programmer to check for database
379
+ capabilities using the standard hasattr() function.
380
+
381
+ For some dynamically configured interfaces it may not be appropriate to require dynamically making
382
+ the method available. These interfaces should then raise a NotSupportedError to indicate the
383
+ non-ability to perform the roll back when the method is invoked.
384
+ """
385
+ self.messages = []
386
+ if (
387
+ self.transaction_level
388
+ ): # trying to roll back with no open transaction causes an error
389
+ try:
390
+ self.transaction_level = self.connector.RollbackTrans()
391
+ if verbose > 1:
392
+ print("rollback done on connection at %X" % id(self))
393
+ if not self._autocommit and not (
394
+ self.connector.Attributes & adc.adXactAbortRetaining
395
+ ):
396
+ # If attributes has adXactAbortRetaining it performs retaining aborts that is,
397
+ # calling RollbackTrans automatically starts a new transaction. Not all providers support this.
398
+ # If not, we will have to start a new transaction by this command:
399
+ if not self.transaction_level:
400
+ self.transaction_level = self.connector.BeginTrans()
401
+ except Exception as e:
402
+ self._raiseConnectionError(api.ProgrammingError, e)
403
+
404
+ def __setattr__(self, name, value):
405
+ if name == "autocommit": # extension: allow user to turn autocommit on or off
406
+ if self.supportsTransactions:
407
+ object.__setattr__(self, "_autocommit", bool(value))
408
+ try:
409
+ self._rollback() # must clear any outstanding transactions
410
+ except:
411
+ pass
412
+ return
413
+ elif name == "paramstyle":
414
+ if value not in api.accepted_paramstyles:
415
+ self._raiseConnectionError(
416
+ api.NotSupportedError,
417
+ f"paramstyle={value!r} not in:{api.accepted_paramstyles!r}",
418
+ )
419
+ elif name == "variantConversions":
420
+ # make a new copy -- no changes in the default, please
421
+ value = copy.copy(value)
422
+ object.__setattr__(self, name, value)
423
+
424
+ def __getattr__(self, item):
425
+ if (
426
+ item == "rollback"
427
+ ): # the rollback method only appears if the database supports transactions
428
+ if self.supportsTransactions:
429
+ return (
430
+ self._rollback
431
+ ) # return the rollback method so the caller can execute it.
432
+ else:
433
+ raise AttributeError("this data provider does not support Rollback")
434
+ elif item == "autocommit":
435
+ return self._autocommit
436
+ else:
437
+ raise AttributeError(
438
+ 'no such attribute in ADO connection object as="%s"' % item
439
+ )
440
+
441
+ def cursor(self):
442
+ "Return a new Cursor Object using the connection."
443
+ self.messages = []
444
+ c = Cursor(self)
445
+ return c
446
+
447
+ def _i_am_here(self, crsr):
448
+ "message from a new cursor proclaiming its existence"
449
+ oid = id(crsr)
450
+ self.cursors[oid] = crsr
451
+
452
+ def _i_am_closing(self, crsr):
453
+ "message from a cursor giving connection a chance to clean up"
454
+ try:
455
+ del self.cursors[id(crsr)]
456
+ except:
457
+ pass
458
+
459
+ def printADOerrors(self):
460
+ j = self.connector.Errors.Count
461
+ if j:
462
+ print("ADO Errors:(%i)" % j)
463
+ for e in self.connector.Errors:
464
+ print("Description: %s" % e.Description)
465
+ print("Error: %s %s " % (e.Number, adc.adoErrors.get(e.Number, "unknown")))
466
+ if e.Number == adc.ado_error_TIMEOUT:
467
+ print(
468
+ "Timeout Error: Try using adodbpi.connect(constr,timeout=Nseconds)"
469
+ )
470
+ print("Source: %s" % e.Source)
471
+ print("NativeError: %s" % e.NativeError)
472
+ print("SQL State: %s" % e.SQLState)
473
+
474
+ def _suggest_error_class(self):
475
+ """Introspect the current ADO Errors and determine an appropriate error class.
476
+
477
+ Error.SQLState is a SQL-defined error condition, per the SQL specification:
478
+ https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt
479
+
480
+ The 23000 class of errors are integrity errors.
481
+ Error 40002 is a transactional integrity error.
482
+ """
483
+ if self.connector is not None:
484
+ for e in self.connector.Errors:
485
+ state = str(e.SQLState)
486
+ if state.startswith("23") or state == "40002":
487
+ return api.IntegrityError
488
+ return api.DatabaseError
489
+
490
+ def __del__(self):
491
+ try:
492
+ self._closeAdoConnection() # v2.1 Rose
493
+ except:
494
+ pass
495
+ self.connector = None
496
+
497
+ def __enter__(self): # Connections are context managers
498
+ return self
499
+
500
+ def __exit__(self, exc_type, exc_val, exc_tb):
501
+ if exc_type:
502
+ self._rollback() # automatic rollback on errors
503
+ else:
504
+ self.commit()
505
+
506
+ def get_table_names(self):
507
+ schema = self.connector.OpenSchema(20) # constant = adSchemaTables
508
+
509
+ tables = []
510
+ while not schema.EOF:
511
+ name = getIndexedValue(schema.Fields, "TABLE_NAME").Value
512
+ tables.append(name)
513
+ schema.MoveNext()
514
+ del schema
515
+ return tables
516
+
517
+
518
+ # # # # # ----- the Class that defines a cursor ----- # # # # #
519
+ class Cursor:
520
+ ## ** api required attributes:
521
+ ## description...
522
+ ## This read-only attribute is a sequence of 7-item sequences.
523
+ ## Each of these sequences contains information describing one result column:
524
+ ## (name, type_code, display_size, internal_size, precision, scale, null_ok).
525
+ ## This attribute will be None for operations that do not return rows or if the
526
+ ## cursor has not had an operation invoked via the executeXXX() method yet.
527
+ ## The type_code can be interpreted by comparing it to the Type Objects specified in the section below.
528
+ ## rowcount...
529
+ ## This read-only attribute specifies the number of rows that the last executeXXX() produced
530
+ ## (for DQL statements like select) or affected (for DML statements like update or insert).
531
+ ## The attribute is -1 in case no executeXXX() has been performed on the cursor or
532
+ ## the rowcount of the last operation is not determinable by the interface.[7]
533
+ ## arraysize...
534
+ ## This read/write attribute specifies the number of rows to fetch at a time with fetchmany().
535
+ ## It defaults to 1 meaning to fetch a single row at a time.
536
+ ## Implementations must observe this value with respect to the fetchmany() method,
537
+ ## but are free to interact with the database a single row at a time.
538
+ ## It may also be used in the implementation of executemany().
539
+ ## ** extension attributes:
540
+ ## paramstyle...
541
+ ## allows the programmer to override the connection's default paramstyle
542
+ ## errorhandler...
543
+ ## allows the programmer to override the connection's default error handler
544
+
545
+ def __init__(self, connection):
546
+ self.command = None
547
+ self._ado_prepared = False
548
+ self.messages = []
549
+ self.connection = connection
550
+ self.paramstyle = connection.paramstyle # used for overriding the paramstyle
551
+ self._parameter_names = []
552
+ self.recordset_is_remote = False
553
+ self.rs = None # the ADO recordset for this cursor
554
+ self.converters = [] # conversion function for each column
555
+ self.columnNames = {} # names of columns {lowercase name : number,...}
556
+ self.numberOfColumns = 0
557
+ self._description = None
558
+ self.rowcount = -1
559
+ self.errorhandler = connection.errorhandler
560
+ self.arraysize = 1
561
+ connection._i_am_here(self)
562
+ if verbose:
563
+ print(
564
+ "%s New cursor at %X on conn %X"
565
+ % (version, id(self), id(self.connection))
566
+ )
567
+
568
+ def __iter__(self): # [2.1 Zamarev]
569
+ return iter(self.fetchone, None) # [2.1 Zamarev]
570
+
571
+ def prepare(self, operation):
572
+ self.command = operation
573
+ self._description = None
574
+ self._ado_prepared = "setup"
575
+
576
+ def __next__(self):
577
+ r = self.fetchone()
578
+ if r:
579
+ return r
580
+ raise StopIteration
581
+
582
+ def __enter__(self):
583
+ "Allow database cursors to be used with context managers."
584
+ return self
585
+
586
+ def __exit__(self, exc_type, exc_val, exc_tb):
587
+ "Allow database cursors to be used with context managers."
588
+ self.close()
589
+
590
+ def _raiseCursorError(self, errorclass, errorvalue):
591
+ eh = self.errorhandler
592
+ if eh is None:
593
+ eh = api.standardErrorHandler
594
+ eh(self.connection, self, errorclass, errorvalue)
595
+
596
+ def build_column_info(self, recordset):
597
+ self.converters = [] # conversion function for each column
598
+ self.columnNames = {} # names of columns {lowercase name : number,...}
599
+ self._description = None
600
+
601
+ # if EOF and BOF are true at the same time, there are no records in the recordset
602
+ if (recordset is None) or (recordset.State == adc.adStateClosed):
603
+ self.rs = None
604
+ self.numberOfColumns = 0
605
+ return
606
+ self.rs = recordset # v2.1.1 bkline
607
+ self.recordset_format = api.RS_WIN_32
608
+ self.numberOfColumns = recordset.Fields.Count
609
+ try:
610
+ varCon = self.connection.variantConversions
611
+ except AttributeError:
612
+ varCon = api.variantConversions
613
+ for i in range(self.numberOfColumns):
614
+ f = getIndexedValue(self.rs.Fields, i)
615
+ try:
616
+ self.converters.append(
617
+ varCon[f.Type]
618
+ ) # conversion function for this column
619
+ except KeyError:
620
+ self._raiseCursorError(
621
+ api.InternalError, "Data column of Unknown ADO type=%s" % f.Type
622
+ )
623
+ self.columnNames[f.Name.lower()] = i # columnNames lookup
624
+
625
+ def _makeDescriptionFromRS(self):
626
+ # Abort if closed or no recordset.
627
+ if self.rs is None:
628
+ self._description = None
629
+ return
630
+ desc = []
631
+ for i in range(self.numberOfColumns):
632
+ f = getIndexedValue(self.rs.Fields, i)
633
+ if self.rs.EOF or self.rs.BOF:
634
+ display_size = None
635
+ else:
636
+ # TODO: Is this the correct defintion according to the DB API 2 Spec ?
637
+ display_size = f.ActualSize
638
+ null_ok = bool(f.Attributes & adc.adFldMayBeNull) # v2.1 Cole
639
+ desc.append(
640
+ (
641
+ f.Name,
642
+ f.Type,
643
+ display_size,
644
+ f.DefinedSize,
645
+ f.Precision,
646
+ f.NumericScale,
647
+ null_ok,
648
+ )
649
+ )
650
+ self._description = desc
651
+
652
+ def get_description(self):
653
+ if not self._description:
654
+ self._makeDescriptionFromRS()
655
+ return self._description
656
+
657
+ def __getattr__(self, item):
658
+ if item == "description":
659
+ return self.get_description()
660
+ object.__getattribute__(
661
+ self, item
662
+ ) # may get here on Remote attribute calls for existing attributes
663
+
664
+ def format_description(self, d):
665
+ """Format db_api description tuple for printing."""
666
+ if self.description is None:
667
+ self._makeDescriptionFromRS()
668
+ if isinstance(d, int):
669
+ d = self.description[d]
670
+ desc = (
671
+ "Name= %s, Type= %s, DispSize= %s, IntSize= %s, Precision= %s, Scale= %s NullOK=%s"
672
+ % (
673
+ d[0],
674
+ adc.adTypeNames.get(d[1], str(d[1]) + " (unknown type)"),
675
+ d[2],
676
+ d[3],
677
+ d[4],
678
+ d[5],
679
+ d[6],
680
+ )
681
+ )
682
+ return desc
683
+
684
+ def close(self, dont_tell_me=False):
685
+ """Close the cursor now (rather than whenever __del__ is called).
686
+ The cursor will be unusable from this point forward; an Error (or subclass)
687
+ exception will be raised if any operation is attempted with the cursor.
688
+ """
689
+ if self.connection is None:
690
+ return
691
+ self.messages = []
692
+ if (
693
+ self.rs and self.rs.State != adc.adStateClosed
694
+ ): # rs exists and is open #v2.1 Rose
695
+ self.rs.Close() # v2.1 Rose
696
+ self.rs = None # let go of the recordset so ADO will let it be disposed #v2.1 Rose
697
+ if not dont_tell_me:
698
+ self.connection._i_am_closing(
699
+ self
700
+ ) # take me off the connection's cursors list
701
+ self.connection = (
702
+ None # this will make all future method calls on me throw an exception
703
+ )
704
+ if verbose:
705
+ print("adodbapi Closed cursor at %X" % id(self))
706
+
707
+ def __del__(self):
708
+ try:
709
+ self.close()
710
+ except:
711
+ pass
712
+
713
+ def _new_command(self, command_type=adc.adCmdText):
714
+ self.cmd = None
715
+ self.messages = []
716
+
717
+ if self.connection is None:
718
+ self._raiseCursorError(api.InterfaceError, None)
719
+ return
720
+ try:
721
+ self.cmd = Dispatch("ADODB.Command")
722
+ self.cmd.ActiveConnection = self.connection.connector
723
+ self.cmd.CommandTimeout = self.connection.timeout
724
+ self.cmd.CommandType = command_type
725
+ self.cmd.CommandText = self.commandText
726
+ self.cmd.Prepared = bool(self._ado_prepared)
727
+ except:
728
+ self._raiseCursorError(
729
+ api.DatabaseError,
730
+ f"Error creating new ADODB.Command object for {self.commandText!r}",
731
+ )
732
+
733
+ def _execute_command(self):
734
+ # Stored procedures may have an integer return value
735
+ self.return_value = None
736
+ recordset = None
737
+ count = -1 # default value
738
+ if verbose:
739
+ print('Executing command="%s"' % self.commandText)
740
+ try:
741
+ # ----- the actual SQL is executed here ---
742
+ recordset, count = self.cmd.Execute()
743
+ # ----- ------------------------------- ---
744
+ except Exception as e:
745
+ _message = ""
746
+ if hasattr(e, "args"):
747
+ _message += str(e.args) + "\n"
748
+ _message += "Command:\n%s\nParameters:\n%s" % (
749
+ self.commandText,
750
+ format_parameters(self.cmd.Parameters, True),
751
+ )
752
+ klass = self.connection._suggest_error_class()
753
+ self._raiseCursorError(klass, _message)
754
+ try:
755
+ self.rowcount = recordset.RecordCount
756
+ except:
757
+ self.rowcount = count
758
+ self.build_column_info(recordset)
759
+
760
+ # The ADO documentation hints that obtaining the recordcount may be timeconsuming
761
+ # "If the Recordset object does not support approximate positioning, this property
762
+ # may be a significant drain on resources # [ekelund]
763
+ # Therefore, COM will not return rowcount for server-side cursors. [Cole]
764
+ # Client-side cursors (the default since v2.8) will force a static
765
+ # cursor, and rowcount will then be set accurately [Cole]
766
+
767
+ def get_rowcount(self):
768
+ return self.rowcount
769
+
770
+ def get_returned_parameters(self):
771
+ """with some providers, returned parameters and the .return_value are not available until
772
+ after the last recordset has been read. In that case, you must coll nextset() until it
773
+ returns None, then call this method to get your returned information."""
774
+
775
+ # store procedures may return altered parameters, including an added "return value" item
776
+ retLst = []
777
+ for p in tuple(self.cmd.Parameters):
778
+ if verbose > 2:
779
+ print(
780
+ 'Returned=Name: %s, Dir.: %s, Type: %s, Size: %s, Value: "%s",'
781
+ " Precision: %s, NumericScale: %s"
782
+ % (
783
+ p.Name,
784
+ adc.directions[p.Direction],
785
+ adc.adTypeNames.get(p.Type, str(p.Type) + " (unknown type)"),
786
+ p.Size,
787
+ p.Value,
788
+ p.Precision,
789
+ p.NumericScale,
790
+ )
791
+ )
792
+ pyObject = api.convert_to_python(p.Value, api.variantConversions[p.Type])
793
+ if p.Direction == adc.adParamReturnValue:
794
+ self.returnValue = (
795
+ pyObject # also load the undocumented attribute (Vernon's Error!)
796
+ )
797
+ self.return_value = pyObject
798
+ else:
799
+ retLst.append(pyObject)
800
+ return retLst # return the parameter list to the caller
801
+
802
+ def callproc(self, procname, parameters=None):
803
+ """Call a stored database procedure with the given name.
804
+ The sequence of parameters must contain one entry for each
805
+ argument that the sproc expects. The result of the
806
+ call is returned as modified copy of the input
807
+ sequence. Input parameters are left untouched, output and
808
+ input/output parameters replaced with possibly new values.
809
+
810
+ The sproc may also provide a result set as output,
811
+ which is available through the standard .fetch*() methods.
812
+ Extension: A "return_value" property may be set on the
813
+ cursor if the sproc defines an integer return value.
814
+ """
815
+ self._parameter_names = []
816
+ self.commandText = procname
817
+ self._new_command(command_type=adc.adCmdStoredProc)
818
+ self._buildADOparameterList(parameters, sproc=True)
819
+ if verbose > 2:
820
+ print(
821
+ "Calling Stored Proc with Params=",
822
+ format_parameters(self.cmd.Parameters, True),
823
+ )
824
+ self._execute_command()
825
+ return self.get_returned_parameters()
826
+
827
+ def _reformat_operation(self, operation, parameters):
828
+ if self.paramstyle in ("format", "pyformat"): # convert %s to ?
829
+ operation, self._parameter_names = api.changeFormatToQmark(operation)
830
+ elif self.paramstyle == "named" or (
831
+ self.paramstyle == "dynamic" and isinstance(parameters, Mapping)
832
+ ):
833
+ operation, self._parameter_names = api.changeNamedToQmark(
834
+ operation
835
+ ) # convert :name to ?
836
+ return operation
837
+
838
+ def _buildADOparameterList(self, parameters, sproc=False):
839
+ self.parameters = parameters
840
+ if parameters is None:
841
+ parameters = []
842
+
843
+ # Note: ADO does not preserve the parameter list, even if "Prepared" is True, so we must build every time.
844
+ parameters_known = False
845
+ if sproc: # needed only if we are calling a stored procedure
846
+ try: # attempt to use ADO's parameter list
847
+ self.cmd.Parameters.Refresh()
848
+ if verbose > 2:
849
+ print(
850
+ "ADO detected Params=",
851
+ format_parameters(self.cmd.Parameters, True),
852
+ )
853
+ print(f"Program Parameters={parameters!r}")
854
+ parameters_known = True
855
+ except api.Error:
856
+ if verbose:
857
+ print("ADO Parameter Refresh failed")
858
+ pass
859
+ else:
860
+ if len(parameters) != self.cmd.Parameters.Count - 1:
861
+ raise api.ProgrammingError(
862
+ "You must supply %d parameters for this stored procedure"
863
+ % (self.cmd.Parameters.Count - 1)
864
+ )
865
+ if sproc or parameters != []:
866
+ i = 0
867
+ if parameters_known: # use ado parameter list
868
+ if self._parameter_names: # named parameters
869
+ for i, pm_name in enumerate(self._parameter_names):
870
+ p = getIndexedValue(self.cmd.Parameters, i)
871
+ try:
872
+ _configure_parameter(
873
+ p, parameters[pm_name], p.Type, parameters_known
874
+ )
875
+ except Exception as e:
876
+ _message = "Error Converting Parameter {}: {}, {} <- {!r}\n".format(
877
+ p.Name,
878
+ adc.ado_type_name(p.Type),
879
+ p.Value,
880
+ parameters[pm_name],
881
+ )
882
+ self._raiseCursorError(
883
+ api.DataError, f"{_message}->{e.args!r}"
884
+ )
885
+ else: # regular sequence of parameters
886
+ for value in parameters:
887
+ p = getIndexedValue(self.cmd.Parameters, i)
888
+ if (
889
+ p.Direction == adc.adParamReturnValue
890
+ ): # this is an extra parameter added by ADO
891
+ i += 1 # skip the extra
892
+ p = getIndexedValue(self.cmd.Parameters, i)
893
+ try:
894
+ _configure_parameter(p, value, p.Type, parameters_known)
895
+ except Exception as e:
896
+ _message = "Error Converting Parameter {}: {}, {} <- {!r}\n".format(
897
+ p.Name,
898
+ adc.ado_type_name(p.Type),
899
+ p.Value,
900
+ value,
901
+ )
902
+ self._raiseCursorError(
903
+ api.DataError, f"{_message}->{e.args!r}"
904
+ )
905
+ i += 1
906
+ else: # -- build own parameter list
907
+ # we expect a dictionary of parameters, this is the list of expected names
908
+ if self._parameter_names:
909
+ for parm_name in self._parameter_names:
910
+ elem = parameters[parm_name]
911
+ adotype = api.pyTypeToADOType(elem)
912
+ p = self.cmd.CreateParameter(
913
+ parm_name, adotype, adc.adParamInput
914
+ )
915
+ _configure_parameter(p, elem, adotype, parameters_known)
916
+ try:
917
+ self.cmd.Parameters.Append(p)
918
+ except Exception as e:
919
+ _message = (
920
+ "Error Building Parameter {}: {}, {} <- {!r}\n".format(
921
+ p.Name,
922
+ adc.ado_type_name(p.Type),
923
+ p.Value,
924
+ elem,
925
+ )
926
+ )
927
+ self._raiseCursorError(
928
+ api.DataError, f"{_message}->{e.args!r}"
929
+ )
930
+ else: # expecting the usual sequence of parameters
931
+ if sproc:
932
+ p = self.cmd.CreateParameter(
933
+ "@RETURN_VALUE", adc.adInteger, adc.adParamReturnValue
934
+ )
935
+ self.cmd.Parameters.Append(p)
936
+
937
+ for elem in parameters:
938
+ name = "p%i" % i
939
+ adotype = api.pyTypeToADOType(elem)
940
+ p = self.cmd.CreateParameter(
941
+ name, adotype, adc.adParamInput
942
+ ) # Name, Type, Direction, Size, Value
943
+ _configure_parameter(p, elem, adotype, parameters_known)
944
+ try:
945
+ self.cmd.Parameters.Append(p)
946
+ except Exception as e:
947
+ _message = (
948
+ "Error Building Parameter {}: {}, {} <- {!r}\n".format(
949
+ p.Name,
950
+ adc.ado_type_name(p.Type),
951
+ p.Value,
952
+ elem,
953
+ )
954
+ )
955
+ self._raiseCursorError(
956
+ api.DataError, f"{_message}->{e.args!r}"
957
+ )
958
+ i += 1
959
+ if self._ado_prepared == "setup":
960
+ self._ado_prepared = (
961
+ True # parameters will be "known" by ADO next loop
962
+ )
963
+
964
+ def execute(self, operation, parameters=None):
965
+ """Prepare and execute a database operation (query or command).
966
+
967
+ Parameters may be provided as sequence or mapping and will be bound to variables in the operation.
968
+ Variables are specified in a database-specific notation
969
+ (see the module's paramstyle attribute for details). [5]
970
+ A reference to the operation will be retained by the cursor.
971
+ If the same operation object is passed in again, then the cursor
972
+ can optimize its behavior. This is most effective for algorithms
973
+ where the same operation is used, but different parameters are bound to it (many times).
974
+
975
+ For maximum efficiency when reusing an operation, it is best to use
976
+ the setinputsizes() method to specify the parameter types and sizes ahead of time.
977
+ It is legal for a parameter to not match the predefined information;
978
+ the implementation should compensate, possibly with a loss of efficiency.
979
+
980
+ The parameters may also be specified as list of tuples to e.g. insert multiple rows in
981
+ a single operation, but this kind of usage is depreciated: executemany() should be used instead.
982
+
983
+ Return value is not defined.
984
+
985
+ [5] The module will use the __getitem__ method of the parameters object to map either positions
986
+ (integers) or names (strings) to parameter values. This allows for both sequences and mappings
987
+ to be used as input.
988
+ The term "bound" refers to the process of binding an input value to a database execution buffer.
989
+ In practical terms, this means that the input value is directly used as a value in the operation.
990
+ The client should not be required to "escape" the value so that it can be used -- the value
991
+ should be equal to the actual database value."""
992
+ if (
993
+ self.command is not operation
994
+ or self._ado_prepared == "setup"
995
+ or not hasattr(self, "commandText")
996
+ ):
997
+ if self.command is not operation:
998
+ self._ado_prepared = False
999
+ self.command = operation
1000
+ self._parameter_names = []
1001
+ self.commandText = (
1002
+ operation
1003
+ if (self.paramstyle == "qmark" or not parameters)
1004
+ else self._reformat_operation(operation, parameters)
1005
+ )
1006
+ self._new_command()
1007
+ self._buildADOparameterList(parameters)
1008
+ if verbose > 3:
1009
+ print("Params=", format_parameters(self.cmd.Parameters, True))
1010
+ self._execute_command()
1011
+
1012
+ def executemany(self, operation, seq_of_parameters):
1013
+ """Prepare a database operation (query or command)
1014
+ and then execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
1015
+
1016
+ Return values are not defined.
1017
+ """
1018
+ self.messages = list()
1019
+ total_recordcount = 0
1020
+
1021
+ self.prepare(operation)
1022
+ for params in seq_of_parameters:
1023
+ self.execute(self.command, params)
1024
+ if self.rowcount == -1:
1025
+ total_recordcount = -1
1026
+ if total_recordcount != -1:
1027
+ total_recordcount += self.rowcount
1028
+ self.rowcount = total_recordcount
1029
+
1030
+ def _fetch(self, limit=None):
1031
+ """Fetch rows from the current recordset.
1032
+
1033
+ limit -- Number of rows to fetch, or None (default) to fetch all rows.
1034
+ """
1035
+ if self.connection is None or self.rs is None:
1036
+ self._raiseCursorError(
1037
+ api.FetchFailedError, "fetch() on closed connection or empty query set"
1038
+ )
1039
+ return
1040
+
1041
+ if self.rs.State == adc.adStateClosed or self.rs.BOF or self.rs.EOF:
1042
+ return list()
1043
+ if limit: # limit number of rows retrieved
1044
+ ado_results = self.rs.GetRows(limit)
1045
+ else: # get all rows
1046
+ ado_results = self.rs.GetRows()
1047
+ if (
1048
+ self.recordset_format == api.RS_ARRAY
1049
+ ): # result of GetRows is a two-dimension array
1050
+ length = (
1051
+ len(ado_results) // self.numberOfColumns
1052
+ ) # length of first dimension
1053
+ else: # pywin32
1054
+ length = len(ado_results[0]) # result of GetRows is tuples in a tuple
1055
+ fetchObject = api.SQLrows(
1056
+ ado_results, length, self
1057
+ ) # new object to hold the results of the fetch
1058
+ return fetchObject
1059
+
1060
+ def fetchone(self):
1061
+ """Fetch the next row of a query result set, returning a single sequence,
1062
+ or None when no more data is available.
1063
+
1064
+ An Error (or subclass) exception is raised if the previous call to executeXXX()
1065
+ did not produce any result set or no call was issued yet.
1066
+ """
1067
+ self.messages = []
1068
+ result = self._fetch(1)
1069
+ if result: # return record (not list of records)
1070
+ return result[0]
1071
+ return None
1072
+
1073
+ def fetchmany(self, size=None):
1074
+ """Fetch the next set of rows of a query result, returning a list of tuples. An empty sequence is returned when no more rows are available.
1075
+
1076
+ The number of rows to fetch per call is specified by the parameter.
1077
+ If it is not given, the cursor's arraysize determines the number of rows to be fetched.
1078
+ The method should try to fetch as many rows as indicated by the size parameter.
1079
+ If this is not possible due to the specified number of rows not being available,
1080
+ fewer rows may be returned.
1081
+
1082
+ An Error (or subclass) exception is raised if the previous call to executeXXX()
1083
+ did not produce any result set or no call was issued yet.
1084
+
1085
+ Note there are performance considerations involved with the size parameter.
1086
+ For optimal performance, it is usually best to use the arraysize attribute.
1087
+ If the size parameter is used, then it is best for it to retain the same value from
1088
+ one fetchmany() call to the next.
1089
+ """
1090
+ self.messages = []
1091
+ if size is None:
1092
+ size = self.arraysize
1093
+ return self._fetch(size)
1094
+
1095
+ def fetchall(self):
1096
+ """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences (e.g. a list of tuples).
1097
+
1098
+ Note that the cursor's arraysize attribute
1099
+ can affect the performance of this operation.
1100
+ An Error (or subclass) exception is raised if the previous call to executeXXX()
1101
+ did not produce any result set or no call was issued yet.
1102
+ """
1103
+ self.messages = []
1104
+ return self._fetch()
1105
+
1106
+ def nextset(self):
1107
+ """Skip to the next available recordset, discarding any remaining rows from the current recordset.
1108
+
1109
+ If there are no more sets, the method returns None. Otherwise, it returns a true
1110
+ value and subsequent calls to the fetch methods will return rows from the next result set.
1111
+
1112
+ An Error (or subclass) exception is raised if the previous call to executeXXX()
1113
+ did not produce any result set or no call was issued yet.
1114
+ """
1115
+ self.messages = []
1116
+ if self.connection is None or self.rs is None:
1117
+ self._raiseCursorError(
1118
+ api.OperationalError,
1119
+ ("nextset() on closed connection or empty query set"),
1120
+ )
1121
+ return None
1122
+
1123
+ try: # [begin 2.1 ekelund]
1124
+ rsTuple = self.rs.NextRecordset() #
1125
+ except pywintypes.com_error as exc: # return appropriate error
1126
+ self._raiseCursorError(api.NotSupportedError, exc.args) # [end 2.1 ekelund]
1127
+ recordset = rsTuple[0]
1128
+ if recordset is None:
1129
+ return None
1130
+ self.build_column_info(recordset)
1131
+ return True
1132
+
1133
+ def setinputsizes(self, sizes):
1134
+ pass
1135
+
1136
+ def setoutputsize(self, size, column=None):
1137
+ pass
1138
+
1139
+ def _last_query(self): # let the programmer see what query we actually used
1140
+ try:
1141
+ if self.parameters is None:
1142
+ ret = self.commandText
1143
+ else:
1144
+ ret = f"{self.commandText},parameters={self.parameters!r}"
1145
+ except:
1146
+ ret = None
1147
+ return ret
1148
+
1149
+ query = property(_last_query, None, None, "returns the last query executed")
1150
+
1151
+
1152
+ if __name__ == "__main__":
1153
+ raise api.ProgrammingError(version + " cannot be run as a main program.")
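
The fetch methods above implement the standard DB API 2.0 retrieval contract: fetchone() returns a single row or None, fetchmany() honors cursor.arraysize when no size is given, and fetchall() drains whatever remains. A minimal usage sketch, assuming pywin32 is installed and a working OLE DB provider is reachable — the connection string and table below are hypothetical:

    import adodbapi

    # hypothetical connection string -- substitute your own provider/catalog
    conn = adodbapi.connect(
        "Provider=SQLOLEDB; Data Source=.; Initial Catalog=test; Integrated Security=SSPI"
    )
    cur = conn.cursor()
    cur.execute("SELECT name, age FROM people")  # hypothetical table
    first = cur.fetchone()   # one row, or None when the set is exhausted
    cur.arraysize = 50       # batch size used when fetchmany() gets no size
    batch = cur.fetchmany()  # up to 50 rows
    rest = cur.fetchall()    # every remaining row
    conn.close()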
venv/Lib/site-packages/adodbapi/apibase.py ADDED
@@ -0,0 +1,723 @@
1
+ """adodbapi.apibase - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
2
+
3
+ Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
4
+ * https://sourceforge.net/projects/pywin32
5
+ * https://sourceforge.net/projects/adodbapi
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import datetime
11
+ import decimal
12
+ import numbers
13
+ import sys
14
+ import time
15
+ from collections.abc import Callable, Iterable, Mapping
16
+
17
+ # noinspection PyUnresolvedReferences
18
+ from . import ado_consts as adc
19
+
20
+ verbose = False # debugging flag
21
+
22
+
23
+ # ------- Error handlers ------
24
+ def standardErrorHandler(connection, cursor, errorclass, errorvalue):
25
+ err = (errorclass, errorvalue)
26
+ try:
27
+ connection.messages.append(err)
28
+ except:
29
+ pass
30
+ if cursor is not None:
31
+ try:
32
+ cursor.messages.append(err)
33
+ except:
34
+ pass
35
+ raise errorclass(errorvalue)
36
+
37
+
38
+ class Error(Exception):
39
+ pass # Exception that is the base class of all other error
40
+ # exceptions. You can use this to catch all errors with one
41
+ # single 'except' statement. Warnings are not considered
42
+ # errors and thus should not use this class as base. It must
43
+ # be a subclass of the built-in Python Exception class
44
+ # (Python 2's StandardError no longer exists in Python 3).
45
+
46
+
47
+ class Warning(Exception):
48
+ pass
49
+
50
+
51
+ class InterfaceError(Error):
52
+ pass
53
+
54
+
55
+ class DatabaseError(Error):
56
+ pass
57
+
58
+
59
+ class InternalError(DatabaseError):
60
+ pass
61
+
62
+
63
+ class OperationalError(DatabaseError):
64
+ pass
65
+
66
+
67
+ class ProgrammingError(DatabaseError):
68
+ pass
69
+
70
+
71
+ class IntegrityError(DatabaseError):
72
+ pass
73
+
74
+
75
+ class DataError(DatabaseError):
76
+ pass
77
+
78
+
79
+ class NotSupportedError(DatabaseError):
80
+ pass
81
+
82
+
83
+ class FetchFailedError(OperationalError):
84
+ """
85
+ Error is used by RawStoredProcedureQuerySet to determine when a fetch
86
+ failed due to a connection being closed or there is no record set
87
+ returned. (Non-standard, added especially for django)
88
+ """
89
+
90
+ pass
91
+
92
+
93
+ # # # # # ----- Type Objects and Constructors ----- # # # # #
94
+ # Many databases need to have the input in a particular format for binding to an operation's input parameters.
95
+ # For example, if an input is destined for a DATE column, then it must be bound to the database in a particular
96
+ # string format. Similar problems exist for "Row ID" columns or large binary items (e.g. blobs or RAW columns).
97
+ # This presents problems for Python since the parameters to the executeXXX() method are untyped.
98
+ # When the database module sees a Python string object, it doesn't know if it should be bound as a simple CHAR
99
+ # column, as a raw BINARY item, or as a DATE.
100
+ #
101
+ # To overcome this problem, a module must provide the constructors defined below to create objects that can
102
+ # hold special values. When passed to the cursor methods, the module can then detect the proper type of
103
+ # the input parameter and bind it accordingly.
104
+
105
+ # A Cursor Object's description attribute returns information about each of the result columns of a query.
106
+ # The type_code must compare equal to one of Type Objects defined below. Type Objects may be equal to more than
107
+ # one type code (e.g. DATETIME could be equal to the type codes for date, time and timestamp columns;
108
+ # see the Implementation Hints below for details).
109
+
110
+ # SQL NULL values are represented by the Python None singleton on input and output.
111
+
112
+ # Note: Usage of Unix ticks for database interfacing can cause troubles because of the limited date range they cover.
113
+
114
+
115
+ # def Date(year,month,day):
116
+ # "This function constructs an object holding a date value. "
117
+ # return dateconverter.date(year,month,day) #dateconverter.Date(year,month,day)
118
+ #
119
+ # def Time(hour,minute,second):
120
+ # "This function constructs an object holding a time value. "
121
+ # return dateconverter.time(hour, minute, second) # dateconverter.Time(hour,minute,second)
122
+ #
123
+ # def Timestamp(year,month,day,hour,minute,second):
124
+ # "This function constructs an object holding a time stamp value. "
125
+ # return dateconverter.datetime(year,month,day,hour,minute,second)
126
+ #
127
+ # def DateFromTicks(ticks):
128
+ # """This function constructs an object holding a date value from the given ticks value
129
+ # (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
130
+ # return Date(*time.gmtime(ticks)[:3])
131
+ #
132
+ # def TimeFromTicks(ticks):
133
+ # """This function constructs an object holding a time value from the given ticks value
134
+ # (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
135
+ # return Time(*time.gmtime(ticks)[3:6])
136
+ #
137
+ # def TimestampFromTicks(ticks):
138
+ # """This function constructs an object holding a time stamp value from the given
139
+ # ticks value (number of seconds since the epoch;
140
+ # see the documentation of the standard Python time module for details). """
141
+ # return Timestamp(*time.gmtime(ticks)[:6])
142
+ #
143
+ # def Binary(aString):
144
+ # """This function constructs an object capable of holding a binary (long) string value. """
145
+ # b = bytes(aString)
146
+ # return b
147
+ # ----- Time converters ----------------------------------------------
148
+ class TimeConverter: # this is a generic time converter skeleton
149
+ def __init__(self): # the details will be filled in by instances
150
+ self._ordinal_1899_12_31 = datetime.date(1899, 12, 31).toordinal() - 1
151
+ # Use cls.types to compare if an input parameter is a datetime
152
+ self.types = {
153
+ # Dynamically get the types as the methods may be overriden
154
+ type(self.Date(2000, 1, 1)),
155
+ type(self.Time(12, 1, 1)),
156
+ type(self.Timestamp(2000, 1, 1, 12, 1, 1)),
157
+ datetime.datetime,
158
+ datetime.time,
159
+ datetime.date,
160
+ }
161
+
162
+ def COMDate(self, obj):
163
+ """Returns a ComDate from a date-time"""
164
+ try: # most likely a datetime
165
+ tt = obj.timetuple()
166
+
167
+ try:
168
+ ms = obj.microsecond
169
+ except:
170
+ ms = 0
171
+ return self.ComDateFromTuple(tt, ms)
172
+ except: # might be a tuple
173
+ try:
174
+ return self.ComDateFromTuple(obj)
175
+ except:
176
+ raise ValueError(f'Cannot convert "{obj!r}" to COMdate.')
177
+
178
+ def ComDateFromTuple(self, t, microseconds=0):
179
+ d = datetime.date(t[0], t[1], t[2])
180
+ integerPart = d.toordinal() - self._ordinal_1899_12_31
181
+ ms = (t[3] * 3600 + t[4] * 60 + t[5]) * 1000000 + microseconds
182
+ fractPart = float(ms) / 86400000000.0
183
+ return integerPart + fractPart
184
+
185
+ def DateObjectFromCOMDate(self, comDate):
186
+ "Returns an object of the wanted type from a ComDate"
187
+ raise NotImplementedError # "Abstract class"
188
+
189
+ def Date(self, year, month, day):
190
+ "This function constructs an object holding a date value."
191
+ raise NotImplementedError # "Abstract class"
192
+
193
+ def Time(self, hour, minute, second):
194
+ "This function constructs an object holding a time value."
195
+ raise NotImplementedError # "Abstract class"
196
+
197
+ def Timestamp(self, year, month, day, hour, minute, second):
198
+ "This function constructs an object holding a time stamp value."
199
+ raise NotImplementedError # "Abstract class"
200
+ # all purpose date to ISO format converter
201
+
202
+ def DateObjectToIsoFormatString(self, obj):
203
+ "This function should return a string in the format 'YYYY-MM-dd HH:MM:SS:ms' (ms optional)"
204
+ try: # most likely, a datetime.datetime
205
+ s = obj.isoformat(" ")
206
+ except (TypeError, AttributeError):
207
+ if isinstance(obj, datetime.date):
208
+ s = obj.isoformat() + " 00:00:00" # return exact midnight
209
+ else:
210
+ try: # but may be time.struct_time
211
+ s = time.strftime("%Y-%m-%d %H:%M:%S", obj)
212
+ except:
213
+ raise ValueError(f'Cannot convert "{obj!r}" to isoformat')
214
+ return s
215
+
216
+
217
+ class pythonDateTimeConverter(TimeConverter): # standard since Python 2.3
218
+ def __init__(self):
219
+ TimeConverter.__init__(self)
220
+
221
+ def DateObjectFromCOMDate(self, comDate):
222
+ if isinstance(comDate, datetime.datetime):
223
+ odn = comDate.toordinal()
224
+ tim = comDate.time()
225
+ new = datetime.datetime.combine(datetime.datetime.fromordinal(odn), tim)
226
+ return new
227
+ # return comDate.replace(tzinfo=None) # make non aware
228
+ else:
229
+ fComDate = float(comDate) # ComDate is number of days since 1899-12-31
230
+ integerPart = int(fComDate)
231
+ floatpart = fComDate - integerPart
232
+ ##if floatpart == 0.0:
233
+ ## return datetime.date.fromordinal(integerPart + self._ordinal_1899_12_31)
234
+ dte = datetime.datetime.fromordinal(
235
+ integerPart + self._ordinal_1899_12_31
236
+ ) + datetime.timedelta(milliseconds=floatpart * 86400000)
237
+ # millisecondsperday=86400000 # 24*60*60*1000
238
+ return dte
239
+
240
+ def Date(self, year, month, day):
241
+ return datetime.date(year, month, day)
242
+
243
+ def Time(self, hour, minute, second):
244
+ return datetime.time(hour, minute, second)
245
+
246
+ def Timestamp(self, year, month, day, hour, minute, second):
247
+ return datetime.datetime(year, month, day, hour, minute, second)
248
+
249
+
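
A worked instance of the COM-date arithmetic above (a day count relative to the 1899-12-30 OLE epoch plus a fractional time of day), assuming only that the package imports cleanly (it pulls in pywin32 on import):

    import datetime
    from adodbapi.apibase import pythonDateTimeConverter

    conv = pythonDateTimeConverter()
    # 2000-01-01 is serial day 36526; noon contributes the 0.5 fraction
    assert conv.COMDate(datetime.datetime(2000, 1, 1, 12, 0, 0)) == 36526.5
    # the inverse conversion recovers the same naive datetime
    assert conv.DateObjectFromCOMDate(36526.5) == datetime.datetime(2000, 1, 1, 12, 0)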
250
+ class pythonTimeConverter(TimeConverter): # the old, *nix-style date and time
251
+ def __init__(self): # caution: this class gets confused by timezones and DST
252
+ TimeConverter.__init__(self)
253
+ self.types.add(time.struct_time)
254
+
255
+ def DateObjectFromCOMDate(self, comDate):
256
+ "Returns ticks since 1970"
257
+ if isinstance(comDate, datetime.datetime):
258
+ return comDate.timetuple()
259
+ else:
260
+ fcomDate = float(comDate)
261
+ secondsperday = 86400 # 24*60*60
262
+ # ComDate is number of days since 1899-12-31, gmtime epoch is 1970-1-1 = 25569 days
263
+ t = time.gmtime(secondsperday * (fcomDate - 25569.0))
264
+ return t # year,month,day,hour,minute,second,weekday,julianday,daylightsaving=t
265
+
266
+ def Date(self, year, month, day):
267
+ return self.Timestamp(year, month, day, 0, 0, 0)
268
+
269
+ def Time(self, hour, minute, second):
270
+ return time.gmtime((hour * 60 + minute) * 60 + second)
271
+
272
+ def Timestamp(self, year, month, day, hour, minute, second):
273
+ return time.localtime(
274
+ time.mktime((year, month, day, hour, minute, second, 0, 0, -1))
275
+ )
276
+
277
+
278
+ base_dateconverter = pythonDateTimeConverter()
279
+
280
+ # ------ DB API required module attributes ---------------------
281
+ threadsafety = 1 # TODO -- find out whether this module is actually BETTER than 1.
282
+
283
+ apilevel = "2.0" # String constant stating the supported DB API level.
284
+
285
+ paramstyle = "qmark" # the default parameter style
286
+
287
+ # ------ control for an extension which may become part of DB API 3.0 ---
288
+ accepted_paramstyles = ("qmark", "named", "format", "pyformat", "dynamic")
289
+
290
+ # ------------------------------------------------------------------------------------------
291
+ # define similar types for generic conversion routines
292
+ adoIntegerTypes = (
293
+ adc.adInteger,
294
+ adc.adSmallInt,
295
+ adc.adTinyInt,
296
+ adc.adUnsignedInt,
297
+ adc.adUnsignedSmallInt,
298
+ adc.adUnsignedTinyInt,
299
+ adc.adBoolean,
300
+ adc.adError,
301
+ ) # max 32 bits
302
+ adoRowIdTypes = (adc.adChapter,) # v2.1 Rose
303
+ adoLongTypes = (adc.adBigInt, adc.adFileTime, adc.adUnsignedBigInt)
304
+ adoExactNumericTypes = (
305
+ adc.adDecimal,
306
+ adc.adNumeric,
307
+ adc.adVarNumeric,
308
+ adc.adCurrency,
309
+ ) # v2.3 Cole
310
+ adoApproximateNumericTypes = (adc.adDouble, adc.adSingle) # v2.1 Cole
311
+ adoStringTypes = (
312
+ adc.adBSTR,
313
+ adc.adChar,
314
+ adc.adLongVarChar,
315
+ adc.adLongVarWChar,
316
+ adc.adVarChar,
317
+ adc.adVarWChar,
318
+ adc.adWChar,
319
+ )
320
+ adoBinaryTypes = (adc.adBinary, adc.adLongVarBinary, adc.adVarBinary)
321
+ adoDateTimeTypes = (adc.adDBTime, adc.adDBTimeStamp, adc.adDate, adc.adDBDate)
322
+ adoRemainingTypes = (
323
+ adc.adEmpty,
324
+ adc.adIDispatch,
325
+ adc.adIUnknown,
326
+ adc.adPropVariant,
327
+ adc.adArray,
328
+ adc.adUserDefined,
329
+ adc.adVariant,
330
+ adc.adGUID,
331
+ )
332
+
333
+
334
+ # this class is a trick to determine whether a type is a member of a related group of types. see PEP notes
335
+ class DBAPITypeObject:
336
+ def __init__(self, valuesTuple):
337
+ self.values = frozenset(valuesTuple)
338
+
339
+ def __eq__(self, other):
340
+ return other in self.values
341
+
342
+ def __ne__(self, other):
343
+ return other not in self.values
344
+
345
+
346
+ """This type object is used to describe columns in a database that are string-based (e.g. CHAR). """
347
+ STRING = DBAPITypeObject(adoStringTypes)
348
+
349
+ """This type object is used to describe (long) binary columns in a database (e.g. LONG, RAW, BLOBs). """
350
+ BINARY = DBAPITypeObject(adoBinaryTypes)
351
+
352
+ """This type object is used to describe numeric columns in a database. """
353
+ NUMBER = DBAPITypeObject(
354
+ adoIntegerTypes + adoLongTypes + adoExactNumericTypes + adoApproximateNumericTypes
355
+ )
356
+
357
+ """This type object is used to describe date/time columns in a database. """
358
+
359
+ DATETIME = DBAPITypeObject(adoDateTimeTypes)
360
+ """This type object is used to describe the "Row ID" column in a database. """
361
+ ROWID = DBAPITypeObject(adoRowIdTypes)
362
+
363
+ OTHER = DBAPITypeObject(adoRemainingTypes)
364
+
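
The DBAPITypeObject trick above lets one sentinel compare equal to every ADO type code in its group, which is how the type_code column of cursor.description is meant to be tested. A short sketch (same pywin32 import caveat as elsewhere):

    from adodbapi import ado_consts as adc
    from adodbapi.apibase import DATETIME, NUMBER, STRING

    assert STRING == adc.adVarWChar      # any string code matches STRING
    assert NUMBER == adc.adInteger       # integer codes are NUMBER...
    assert NUMBER != adc.adDate          # ...but date codes are not
    assert DATETIME == adc.adDBTimeStamp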
365
+ # ------- utilities for translating python data types to ADO data types ---------------------------------
366
+ typeMap = {
367
+ memoryview: adc.adVarBinary,
368
+ float: adc.adDouble,
369
+ type(None): adc.adEmpty,
370
+ str: adc.adBSTR,
371
+ bool: adc.adBoolean, # v2.1 Cole
372
+ decimal.Decimal: adc.adDecimal,
373
+ int: adc.adBigInt,
374
+ bytes: adc.adVarBinary,
375
+ }
376
+
377
+
378
+ def pyTypeToADOType(d):
379
+ tp = type(d)
380
+ try:
381
+ return typeMap[tp]
382
+ except KeyError: # The type was not defined in the pre-computed Type table
383
+ from . import dateconverter
384
+
385
+ # maybe it is one of our supported Date/Time types
386
+ if tp in dateconverter.types:
387
+ return adc.adDate
388
+ # otherwise, attempt to discern the type by probing the data object itself -- to handle duck typing
389
+ if isinstance(d, str):
390
+ return adc.adBSTR
391
+ if isinstance(d, numbers.Integral):
392
+ return adc.adBigInt
393
+ if isinstance(d, numbers.Real):
394
+ return adc.adDouble
395
+ raise DataError(f'cannot convert "{d!r}" (type={tp}) to ADO')
396
+
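
pyTypeToADOType() consults the pre-computed typeMap first and only then falls back to duck-typed probing, so subclasses of the built-in numeric types still map sensibly. A sketch under the same import assumptions:

    import decimal
    from adodbapi import ado_consts as adc
    from adodbapi.apibase import pyTypeToADOType

    assert pyTypeToADOType(3.14) == adc.adDouble             # direct typeMap hit
    assert pyTypeToADOType(decimal.Decimal("1.25")) == adc.adDecimal

    class MyInt(int):   # exact type is not in typeMap...
        pass

    assert pyTypeToADOType(MyInt(7)) == adc.adBigInt         # ...caught by numbers.Integral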
397
+
398
+ # # # # # # # # # # # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
399
+ # functions to convert database values to Python objects
400
+ # ------------------------------------------------------------------------
401
+ # variant type : function converting variant to Python value
402
+ def variantConvertDate(v):
403
+ from . import dateconverter # this function only called when adodbapi is running
404
+
405
+ return dateconverter.DateObjectFromCOMDate(v)
406
+
407
+
408
+ def cvtString(variant): # use to get old action of adodbapi v1 if desired
409
+ return str(variant)
410
+
411
+
412
+ def cvtDecimal(variant): # better name
413
+ return _convertNumberWithCulture(variant, decimal.Decimal)
414
+
415
+
416
+ def cvtNumeric(variant): # older name - don't break old code
417
+ return cvtDecimal(variant)
418
+
419
+
420
+ def cvtFloat(variant):
421
+ return _convertNumberWithCulture(variant, float)
422
+
423
+
424
+ def _convertNumberWithCulture(variant, f):
425
+ try:
426
+ return f(variant)
427
+ except (ValueError, TypeError, decimal.InvalidOperation):
428
+ try:
429
+ europeVsUS = str(variant).replace(",", ".")
430
+ return f(europeVsUS)
431
+ except (ValueError, TypeError, decimal.InvalidOperation):
432
+ pass
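+ # note: falls through and returns None when the value cannot be parsed in either culture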
433
+
434
+
435
+ def cvtInt(variant):
436
+ return int(variant)
437
+
438
+
439
+ def cvtLong(variant): # only important in old versions where long and int differ
440
+ return int(variant)
441
+
442
+
443
+ def cvtBuffer(variant):
444
+ return bytes(variant)
445
+
446
+
447
+ def cvtUnicode(variant):
448
+ return str(variant)
449
+
450
+
451
+ def identity(x):
452
+ return x
453
+
454
+
455
+ def cvtUnusual(variant):
456
+ if verbose > 1:
457
+ sys.stderr.write(f"Conversion called for Unusual data={variant!r}\n")
458
+ return variant # cannot find conversion function -- just give the data to the user
459
+
460
+
461
+ def convert_to_python(variant, func): # convert DB value into Python value
462
+ if variant is None:
463
+ return None
464
+ return func(variant) # call the appropriate conversion function
465
+
466
+
467
+ class MultiMap(dict[int, Callable[[object], object]]):
468
+ # builds a dictionary from {(iterable,of,keys) : function}
469
+ """A dictionary of ado.type : function
470
+ -- but you can set multiple items by passing an iterable of keys"""
471
+
472
+ # useful for defining conversion functions for groups of similar data types.
473
+ def __init__(self, aDict: Mapping[Iterable[int] | int, Callable[[object], object]]):
474
+ for k, v in aDict.items():
475
+ self[k] = v # we must call __setitem__
476
+
477
+ def __setitem__(
478
+ self, adoType: Iterable[int] | int, cvtFn: Callable[[object], object]
479
+ ):
480
+ "set a single item, or a whole iterable of items"
481
+ if isinstance(adoType, Iterable):
482
+ # user passed us an iterable, set them individually
483
+ for typ in adoType: # "typ" avoids shadowing the built-in type()
484
+ dict.__setitem__(self, typ, cvtFn)
485
+ else:
486
+ dict.__setitem__(self, adoType, cvtFn)
487
+
488
+
489
+ # initialize variantConversions dictionary used to convert SQL to Python
490
+ # this is the dictionary of default conversion functions, built by the class above.
491
+ # this becomes a class attribute for the Connection, and that attribute is used
492
+ # to build the list of column conversion functions for the Cursor
493
+ variantConversions = MultiMap(
494
+ {
495
+ adoDateTimeTypes: variantConvertDate,
496
+ adoApproximateNumericTypes: cvtFloat,
497
+ adoExactNumericTypes: cvtDecimal, # use to force decimal rather than unicode
498
+ adoLongTypes: cvtLong,
499
+ adoIntegerTypes: cvtInt,
500
+ adoRowIdTypes: cvtInt,
501
+ adoStringTypes: identity,
502
+ adoBinaryTypes: cvtBuffer,
503
+ adoRemainingTypes: cvtUnusual,
504
+ }
505
+ )
506
+
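
Because MultiMap expands each iterable key into one entry per ADO type code, the table above can be read — or retargeted — per column type. For example (note that real code should work on a copy, since this dictionary becomes the default for every Connection):

    from adodbapi import ado_consts as adc
    from adodbapi.apibase import adoExactNumericTypes, cvtDecimal, cvtFloat, variantConversions

    assert variantConversions[adc.adCurrency] is cvtDecimal  # exact numerics -> Decimal
    assert variantConversions[adc.adDouble] is cvtFloat      # approximate -> float
    # one __setitem__ call retargets a whole group of type codes:
    variantConversions[adoExactNumericTypes] = cvtFloat
    assert variantConversions[adc.adCurrency] is cvtFloat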
507
+ # # # # # classes to emulate the result of cursor.fetchxxx() as a sequence of sequences # # # # #
508
+ # "an ENUM of how my low level records are laid out"
509
+ RS_WIN_32, RS_ARRAY, RS_REMOTE = list(range(1, 4))
510
+
511
+
512
+ class SQLrow: # a single database row
513
+ # class to emulate a sequence, so that a column may be retrieved by either number or name
514
+ def __init__(self, rows, index): # "rows" is an _SQLrows object, index is which row
515
+ self.rows = rows # parent 'fetch' container object
516
+ self.index = index # my row number within parent
517
+
518
+ def __getattr__(self, name): # used for row.columnName type of value access
519
+ try:
520
+ return self._getValue(self.rows.columnNames[name.lower()])
521
+ except KeyError:
522
+ raise AttributeError('Unknown column name "{}"'.format(name))
523
+
524
+ def _getValue(self, key): # key must be an integer
525
+ if (
526
+ self.rows.recordset_format == RS_ARRAY
527
+ ): # retrieve from two-dimensional array
528
+ v = self.rows.ado_results[key, self.index]
529
+ elif self.rows.recordset_format == RS_REMOTE:
530
+ v = self.rows.ado_results[self.index][key]
531
+ else: # pywin32 - retrieve from tuple of tuples
532
+ v = self.rows.ado_results[key][self.index]
533
+ if self.rows.converters is NotImplemented:
534
+ return v
535
+ return convert_to_python(v, self.rows.converters[key])
536
+
537
+ def __len__(self):
538
+ return self.rows.numberOfColumns
539
+
540
+ def __getitem__(self, key): # used for row[key] type of value access
541
+ if isinstance(key, int): # normal row[1] designation
542
+ try:
543
+ return self._getValue(key)
544
+ except IndexError:
545
+ raise
546
+ if isinstance(key, slice):
547
+ indices = key.indices(self.rows.numberOfColumns)
548
+ vl = [self._getValue(i) for i in range(*indices)]
549
+ return tuple(vl)
550
+ try:
551
+ return self._getValue(
552
+ self.rows.columnNames[key.lower()]
553
+ ) # extension row[columnName] designation
554
+ except (KeyError, TypeError):
555
+ er, st, tr = sys.exc_info()
556
+ raise er(f'No such key as "{key!r}" in {self!r}').with_traceback(tr)
557
+
558
+ def __iter__(self):
559
+ return iter(self.__next__())
560
+
561
+ def __next__(self):
562
+ for n in range(self.rows.numberOfColumns):
563
+ yield self._getValue(n)
564
+
565
+ def __repr__(self): # create a human readable representation
566
+ taglist = sorted(list(self.rows.columnNames.items()), key=lambda x: x[1])
567
+ s = "<SQLrow={"
568
+ for name, i in taglist:
569
+ s += f"{name}:{self._getValue(i)!r}, "
570
+ return s[:-2] + "}>"
571
+
572
+ def __str__(self): # create a pretty human readable representation
573
+ return str(
574
+ tuple(str(self._getValue(i)) for i in range(self.rows.numberOfColumns))
575
+ )
576
+
577
+ # TO-DO implement pickling an SQLrow directly
578
+ # def __getstate__(self): return self.__dict__
579
+ # def __setstate__(self, d): self.__dict__.update(d)
580
+ # which basically tell pickle to treat your class just like a normal one,
581
+ # taking self.__dict__ as representing the whole of the instance state,
582
+ # despite the existence of the __getattr__.
583
+ # # # #
584
+
585
+
586
+ class SQLrows:
587
+ # class to emulate a sequence for multiple rows using a container object
588
+ def __init__(self, ado_results, numberOfRows, cursor):
589
+ self.ado_results = ado_results # raw result of SQL get
590
+ try:
591
+ self.recordset_format = cursor.recordset_format
592
+ self.numberOfColumns = cursor.numberOfColumns
593
+ self.converters = cursor.converters
594
+ self.columnNames = cursor.columnNames
595
+ except AttributeError:
596
+ self.recordset_format = RS_ARRAY
597
+ self.numberOfColumns = 0
598
+ self.converters = []
599
+ self.columnNames = {}
600
+ self.numberOfRows = numberOfRows
601
+
602
+ def __len__(self):
603
+ return self.numberOfRows
604
+
605
+ def __getitem__(self, item): # used for row or row,column access
606
+ if not self.ado_results:
607
+ return []
608
+ if isinstance(item, slice): # will return a list of row objects
609
+ indices = item.indices(self.numberOfRows)
610
+ return [SQLrow(self, k) for k in range(*indices)]
611
+ elif isinstance(item, tuple) and len(item) == 2:
612
+ # d = some_rowsObject[i,j] will return a datum from a two-dimension address
613
+ i, j = item
614
+ if not isinstance(j, int):
615
+ try:
616
+ j = self.columnNames[j.lower()] # convert named column to numeric
617
+ except KeyError:
618
+ raise KeyError(f"adodbapi: no such column name as {j!r}")
619
+ if self.recordset_format == RS_ARRAY: # retrieve from two-dimensional array
620
+ v = self.ado_results[j, i]
621
+ elif self.recordset_format == RS_REMOTE:
622
+ v = self.ado_results[i][j]
623
+ else: # pywin32 - retrieve from tuple of tuples
624
+ v = self.ado_results[j][i]
625
+ if self.converters is NotImplemented:
626
+ return v
627
+ return convert_to_python(v, self.converters[j])
628
+ else:
629
+ row = SQLrow(self, item) # new row descriptor
630
+ return row
631
+
632
+ def __iter__(self):
633
+ return iter(self.__next__())
634
+
635
+ def __next__(self):
636
+ for n in range(self.numberOfRows):
637
+ row = SQLrow(self, n)
638
+ yield row
639
+ # # # # #
640
+
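
A sketch of the row-emulation classes above, driving SQLrows with a stand-in cursor (SimpleNamespace is only a stub carrying the four attributes the constructor reads; a real Cursor supplies them):

    from types import SimpleNamespace
    from adodbapi.apibase import RS_REMOTE, SQLrows, identity

    stub_cursor = SimpleNamespace(
        recordset_format=RS_REMOTE,         # rows laid out as result[row][column]
        numberOfColumns=2,
        converters=[identity, identity],    # no per-column conversion
        columnNames={"name": 0, "age": 1},  # lower-case column name -> index
    )
    rows = SQLrows([("alice", 30), ("bob", 25)], 2, stub_cursor)

    assert rows[0].name == "alice"   # attribute-style column access
    assert rows[1, "age"] == 25      # (row, column-name) addressing
    assert len(rows) == 2 and len(rows[0]) == 2
    assert [tuple(r) for r in rows] == [("alice", 30), ("bob", 25)]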
641
+ # # # # # functions to re-format SQL requests to other paramstyle requirements # # # # # # # # # #
642
+
643
+
644
+ def changeNamedToQmark(
645
+ op,
646
+ ): # convert from 'named' paramstyle to ADO required '?'mark parameters
647
+ outOp = ""
648
+ outparms = []
649
+ chunks = op.split(
650
+ "'"
651
+ ) # quote all literals -- odd numbered list results are literals.
652
+ inQuotes = False
653
+ for chunk in chunks:
654
+ if inQuotes: # this is inside a quote
655
+ if chunk == "": # double apostrophe to quote one apostrophe
656
+ outOp = outOp[:-1] # so take one away
657
+ else:
658
+ outOp += "'" + chunk + "'" # else pass the quoted string as is.
659
+ else: # is SQL code -- look for a :namedParameter
660
+ while chunk: # some SQL string remains
661
+ sp = chunk.split(":", 1)
662
+ outOp += sp[0] # concat the part up to the :
663
+ s = ""
664
+ try:
665
+ chunk = sp[1]
666
+ except IndexError:
667
+ chunk = None
668
+ if chunk: # there was a parameter - parse it out
669
+ i = 0
670
+ c = chunk[0]
671
+ while c.isalnum() or c == "_":
672
+ i += 1
673
+ try:
674
+ c = chunk[i]
675
+ except IndexError:
676
+ break
677
+ s = chunk[:i]
678
+ chunk = chunk[i:]
679
+ if s:
680
+ outparms.append(s) # list the parameters in order
681
+ outOp += "?" # put in the Qmark
682
+ inQuotes = not inQuotes
683
+ return outOp, outparms
684
+
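
A quick check of the named-to-qmark rewriting above; changeFormatToQmark below behaves analogously for %s and %(name)s placeholders:

    from adodbapi.apibase import changeNamedToQmark

    sql, params = changeNamedToQmark(
        "select * from people where name = :nm and age > :age"
    )
    assert sql == "select * from people where name = ? and age > ?"
    assert params == ["nm", "age"]

    # quoted literals pass through untouched, placeholders inside them included:
    sql, params = changeNamedToQmark("update t set note = 'it''s :x' where id = :id")
    assert sql == "update t set note = 'it''s :x' where id = ?"
    assert params == ["id"]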
685
+
686
+ def changeFormatToQmark(
687
+ op,
688
+ ): # convert from 'format' paramstyle to ADO required '?'mark parameters
689
+ outOp = ""
690
+ outparams = []
691
+ chunks = op.split(
692
+ "'"
693
+ ) # quote all literals -- odd numbered list results are literals.
694
+ inQuotes = False
695
+ for chunk in chunks:
696
+ if inQuotes:
697
+ if (
698
+ outOp != "" and chunk == ""
699
+ ): # he used a double apostrophe to quote one apostrophe
700
+ outOp = outOp[:-1] # so take one away
701
+ else:
702
+ outOp += "'" + chunk + "'" # else pass the quoted string as is.
703
+ else: # is SQL code -- look for a %s parameter
704
+ if "%(" in chunk: # ugh! pyformat!
705
+ while chunk: # some SQL string remains
706
+ sp = chunk.split("%(", 1)
707
+ outOp += sp[0] # concat the part up to the %
708
+ if len(sp) > 1:
709
+ try:
710
+ s, chunk = sp[1].split(")s", 1) # find the ')s'
711
+ except ValueError:
712
+ raise ProgrammingError(
713
+ 'Pyformat SQL has incorrect format near "%s"' % chunk
714
+ )
715
+ outparams.append(s)
716
+ outOp += "?" # put in the Qmark
717
+ else:
718
+ chunk = None
719
+ else: # proper '%s' format
720
+ sp = chunk.split("%s") # make each %s
721
+ outOp += "?".join(sp) # into ?
722
+ inQuotes = not inQuotes # every other chunk is a quoted string
723
+ return outOp, outparams
venv/Lib/site-packages/adodbapi/is64bit.py ADDED
@@ -0,0 +1,34 @@
1
+ """is64bit.Python() --> boolean value of detected Python word size. is64bit.os() --> os build version"""
2
+
3
+ import sys
4
+
5
+
6
+ def Python():
7
+ return sys.maxsize > 2147483647
8
+
9
+
10
+ def os():
11
+ import platform
12
+
13
+ pm = platform.machine()
14
+ if pm != ".." and pm.endswith("64"): # recent 64 bit Python
15
+ return True
16
+ else:
17
+ import os
18
+
19
+ if "PROCESSOR_ARCHITEW6432" in os.environ:
20
+ return True # 32 bit program running on 64 bit Windows
21
+ try:
22
+ return os.environ["PROCESSOR_ARCHITECTURE"].endswith(
23
+ "64"
24
+ ) # 64 bit Windows 64 bit program
25
+ except (IndexError, KeyError):
26
+ pass # not Windows
27
+ try:
28
+ return "64" in platform.architecture()[0] # this often works in Linux
29
+ except:
30
+ return False # is an older version of Python, assume also an older os (best we can guess)
31
+
32
+
33
+ if __name__ == "__main__":
34
+ print("is64bit.Python() =", Python(), "is64bit.os() =", os())
venv/Lib/site-packages/adodbapi/license.txt ADDED
@@ -0,0 +1,505 @@
1
+ GNU LESSER GENERAL PUBLIC LICENSE
2
+ Version 2.1, February 1999
3
+
4
+ Copyright (C) 1991, 1999 Free Software Foundation, Inc.
5
+ 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
6
+ Everyone is permitted to copy and distribute verbatim copies
7
+ of this license document, but changing it is not allowed.
8
+
9
+ [This is the first released version of the Lesser GPL. It also counts
10
+ as the successor of the GNU Library Public License, version 2, hence
11
+ the version number 2.1.]
12
+
13
+ Preamble
14
+
15
+ The licenses for most software are designed to take away your
16
+ freedom to share and change it. By contrast, the GNU General Public
17
+ Licenses are intended to guarantee your freedom to share and change
18
+ free software--to make sure the software is free for all its users.
19
+
20
+ This license, the Lesser General Public License, applies to some
21
+ specially designated software packages--typically libraries--of the
22
+ Free Software Foundation and other authors who decide to use it. You
23
+ can use it too, but we suggest you first think carefully about whether
24
+ this license or the ordinary General Public License is the better
25
+ strategy to use in any particular case, based on the explanations below.
26
+
27
+ When we speak of free software, we are referring to freedom of use,
28
+ not price. Our General Public Licenses are designed to make sure that
29
+ you have the freedom to distribute copies of free software (and charge
30
+ for this service if you wish); that you receive source code or can get
31
+ it if you want it; that you can change the software and use pieces of
32
+ it in new free programs; and that you are informed that you can do
33
+ these things.
34
+
35
+ To protect your rights, we need to make restrictions that forbid
36
+ distributors to deny you these rights or to ask you to surrender these
37
+ rights. These restrictions translate to certain responsibilities for
38
+ you if you distribute copies of the library or if you modify it.
39
+
40
+ For example, if you distribute copies of the library, whether gratis
41
+ or for a fee, you must give the recipients all the rights that we gave
42
+ you. You must make sure that they, too, receive or can get the source
43
+ code. If you link other code with the library, you must provide
44
+ complete object files to the recipients, so that they can relink them
45
+ with the library after making changes to the library and recompiling
46
+ it. And you must show them these terms so they know their rights.
47
+
48
+ We protect your rights with a two-step method: (1) we copyright the
49
+ library, and (2) we offer you this license, which gives you legal
50
+ permission to copy, distribute and/or modify the library.
51
+
52
+ To protect each distributor, we want to make it very clear that
53
+ there is no warranty for the free library. Also, if the library is
54
+ modified by someone else and passed on, the recipients should know
55
+ that what they have is not the original version, so that the original
56
+ author's reputation will not be affected by problems that might be
57
+ introduced by others.
58
+
59
+
60
+
61
+ Finally, software patents pose a constant threat to the existence of
62
+ any free program. We wish to make sure that a company cannot
63
+ effectively restrict the users of a free program by obtaining a
64
+ restrictive license from a patent holder. Therefore, we insist that
65
+ any patent license obtained for a version of the library must be
66
+ consistent with the full freedom of use specified in this license.
67
+
68
+ Most GNU software, including some libraries, is covered by the
69
+ ordinary GNU General Public License. This license, the GNU Lesser
70
+ General Public License, applies to certain designated libraries, and
71
+ is quite different from the ordinary General Public License. We use
72
+ this license for certain libraries in order to permit linking those
73
+ libraries into non-free programs.
74
+
75
+ When a program is linked with a library, whether statically or using
76
+ a shared library, the combination of the two is legally speaking a
77
+ combined work, a derivative of the original library. The ordinary
78
+ General Public License therefore permits such linking only if the
79
+ entire combination fits its criteria of freedom. The Lesser General
80
+ Public License permits more lax criteria for linking other code with
81
+ the library.
82
+
83
+ We call this license the "Lesser" General Public License because it
84
+ does Less to protect the user's freedom than the ordinary General
85
+ Public License. It also provides other free software developers Less
86
+ of an advantage over competing non-free programs. These disadvantages
87
+ are the reason we use the ordinary General Public License for many
88
+ libraries. However, the Lesser license provides advantages in certain
89
+ special circumstances.
90
+
91
+ For example, on rare occasions, there may be a special need to
92
+ encourage the widest possible use of a certain library, so that it becomes
93
+ a de-facto standard. To achieve this, non-free programs must be
94
+ allowed to use the library. A more frequent case is that a free
95
+ library does the same job as widely used non-free libraries. In this
96
+ case, there is little to gain by limiting the free library to free
97
+ software only, so we use the Lesser General Public License.
98
+
99
+ In other cases, permission to use a particular library in non-free
100
+ programs enables a greater number of people to use a large body of
101
+ free software. For example, permission to use the GNU C Library in
102
+ non-free programs enables many more people to use the whole GNU
103
+ operating system, as well as its variant, the GNU/Linux operating
104
+ system.
105
+
106
+ Although the Lesser General Public License is Less protective of the
107
+ users' freedom, it does ensure that the user of a program that is
108
+ linked with the Library has the freedom and the wherewithal to run
109
+ that program using a modified version of the Library.
110
+
111
+ The precise terms and conditions for copying, distribution and
112
+ modification follow. Pay close attention to the difference between a
113
+ "work based on the library" and a "work that uses the library". The
114
+ former contains code derived from the library, whereas the latter must
115
+ be combined with the library in order to run.
116
+
117
+
118
+
119
+ GNU LESSER GENERAL PUBLIC LICENSE
120
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
121
+
122
+ 0. This License Agreement applies to any software library or other
123
+ program which contains a notice placed by the copyright holder or
124
+ other authorized party saying it may be distributed under the terms of
125
+ this Lesser General Public License (also called "this License").
126
+ Each licensee is addressed as "you".
127
+
128
+ A "library" means a collection of software functions and/or data
129
+ prepared so as to be conveniently linked with application programs
130
+ (which use some of those functions and data) to form executables.
131
+
132
+ The "Library", below, refers to any such software library or work
133
+ which has been distributed under these terms. A "work based on the
134
+ Library" means either the Library or any derivative work under
135
+ copyright law: that is to say, a work containing the Library or a
136
+ portion of it, either verbatim or with modifications and/or translated
137
+ straightforwardly into another language. (Hereinafter, translation is
138
+ included without limitation in the term "modification".)
139
+
140
+ "Source code" for a work means the preferred form of the work for
141
+ making modifications to it. For a library, complete source code means
142
+ all the source code for all modules it contains, plus any associated
143
+ interface definition files, plus the scripts used to control compilation
144
+ and installation of the library.
145
+
146
+ Activities other than copying, distribution and modification are not
147
+ covered by this License; they are outside its scope. The act of
148
+ running a program using the Library is not restricted, and output from
149
+ such a program is covered only if its contents constitute a work based
150
+ on the Library (independent of the use of the Library in a tool for
151
+ writing it). Whether that is true depends on what the Library does
152
+ and what the program that uses the Library does.
153
+
154
+ 1. You may copy and distribute verbatim copies of the Library's
155
+ complete source code as you receive it, in any medium, provided that
156
+ you conspicuously and appropriately publish on each copy an
157
+ appropriate copyright notice and disclaimer of warranty; keep intact
158
+ all the notices that refer to this License and to the absence of any
159
+ warranty; and distribute a copy of this License along with the
160
+ Library.
161
+ You may charge a fee for the physical act of transferring a copy,
162
+ and you may at your option offer warranty protection in exchange for a
163
+ fee.
164
+
165
+ 2. You may modify your copy or copies of the Library or any portion
166
+ of it, thus forming a work based on the Library, and copy and
167
+ distribute such modifications or work under the terms of Section 1
168
+ above, provided that you also meet all of these conditions:
169
+
170
+ a) The modified work must itself be a software library.
171
+
172
+ b) You must cause the files modified to carry prominent notices
173
+ stating that you changed the files and the date of any change.
174
+
175
+ c) You must cause the whole of the work to be licensed at no
176
+ charge to all third parties under the terms of this License.
177
+
178
+ d) If a facility in the modified Library refers to a function or a
179
+ table of data to be supplied by an application program that uses
180
+ the facility, other than as an argument passed when the facility
181
+ is invoked, then you must make a good faith effort to ensure that,
182
+ in the event an application does not supply such function or
183
+ table, the facility still operates, and performs whatever part of
184
+ its purpose remains meaningful.
185
+
186
+ (For example, a function in a library to compute square roots has
187
+ a purpose that is entirely well-defined independent of the
188
+ application. Therefore, Subsection 2d requires that any
189
+ application-supplied function or table used by this function must
190
+ be optional: if the application does not supply it, the square
191
+ root function must still compute square roots.)
192
+
193
+ These requirements apply to the modified work as a whole. If
194
+ identifiable sections of that work are not derived from the Library,
195
+ and can be reasonably considered independent and separate works in
196
+ themselves, then this License, and its terms, do not apply to those
197
+ sections when you distribute them as separate works. But when you
198
+ distribute the same sections as part of a whole which is a work based
199
+ on the Library, the distribution of the whole must be on the terms of
200
+ this License, whose permissions for other licensees extend to the
201
+ entire whole, and thus to each and every part regardless of who wrote
202
+ it.
203
+
204
+ Thus, it is not the intent of this section to claim rights or contest
205
+ your rights to work written entirely by you; rather, the intent is to
206
+ exercise the right to control the distribution of derivative or
207
+ collective works based on the Library.
208
+
209
+ In addition, mere aggregation of another work not based on the Library
210
+ with the Library (or with a work based on the Library) on a volume of
211
+ a storage or distribution medium does not bring the other work under
212
+ the scope of this License.
213
+
214
+ 3. You may opt to apply the terms of the ordinary GNU General Public
215
+ License instead of this License to a given copy of the Library. To do
216
+ this, you must alter all the notices that refer to this License, so
217
+ that they refer to the ordinary GNU General Public License, version 2,
218
+ instead of to this License. (If a newer version than version 2 of the
219
+ ordinary GNU General Public License has appeared, then you can specify
220
+ that version instead if you wish.) Do not make any other change in
221
+ these notices.
222
+
223
+ Once this change is made in a given copy, it is irreversible for
224
+ that copy, so the ordinary GNU General Public License applies to all
225
+ subsequent copies and derivative works made from that copy.
226
+
227
+ This option is useful when you wish to copy part of the code of
228
+ the Library into a program that is not a library.
229
+
230
+ 4. You may copy and distribute the Library (or a portion or
231
+ derivative of it, under Section 2) in object code or executable form
232
+ under the terms of Sections 1 and 2 above provided that you accompany
233
+ it with the complete corresponding machine-readable source code, which
234
+ must be distributed under the terms of Sections 1 and 2 above on a
235
+ medium customarily used for software interchange.
236
+
237
+ If distribution of object code is made by offering access to copy
238
+ from a designated place, then offering equivalent access to copy the
239
+ source code from the same place satisfies the requirement to
240
+ distribute the source code, even though third parties are not
241
+ compelled to copy the source along with the object code.
242
+
243
+ 5. A program that contains no derivative of any portion of the
244
+ Library, but is designed to work with the Library by being compiled or
245
+ linked with it, is called a "work that uses the Library". Such a
246
+ work, in isolation, is not a derivative work of the Library, and
247
+ therefore falls outside the scope of this License.
248
+
249
+ However, linking a "work that uses the Library" with the Library
250
+ creates an executable that is a derivative of the Library (because it
251
+ contains portions of the Library), rather than a "work that uses the
252
+ library". The executable is therefore covered by this License.
253
+ Section 6 states terms for distribution of such executables.
254
+
255
+ When a "work that uses the Library" uses material from a header file
256
+ that is part of the Library, the object code for the work may be a
257
+ derivative work of the Library even though the source code is not.
258
+ Whether this is true is especially significant if the work can be
259
+ linked without the Library, or if the work is itself a library. The
260
+ threshold for this to be true is not precisely defined by law.
261
+
262
+ If such an object file uses only numerical parameters, data
263
+ structure layouts and accessors, and small macros and small inline
264
+ functions (ten lines or less in length), then the use of the object
265
+ file is unrestricted, regardless of whether it is legally a derivative
266
+ work. (Executables containing this object code plus portions of the
267
+ Library will still fall under Section 6.)
268
+
269
+ Otherwise, if the work is a derivative of the Library, you may
270
+ distribute the object code for the work under the terms of Section 6.
271
+ Any executables containing that work also fall under Section 6,
272
+ whether or not they are linked directly with the Library itself.
273
+
274
+ 6. As an exception to the Sections above, you may also combine or
275
+ link a "work that uses the Library" with the Library to produce a
276
+ work containing portions of the Library, and distribute that work
277
+ under terms of your choice, provided that the terms permit
278
+ modification of the work for the customer's own use and reverse
279
+ engineering for debugging such modifications.
280
+
281
+ You must give prominent notice with each copy of the work that the
282
+ Library is used in it and that the Library and its use are covered by
283
+ this License. You must supply a copy of this License. If the work
284
+ during execution displays copyright notices, you must include the
285
+ copyright notice for the Library among them, as well as a reference
286
+ directing the user to the copy of this License. Also, you must do one
287
+ of these things:
288
+
289
+ a) Accompany the work with the complete corresponding
290
+ machine-readable source code for the Library including whatever
291
+ changes were used in the work (which must be distributed under
292
+ Sections 1 and 2 above); and, if the work is an executable linked
293
+ with the Library, with the complete machine-readable "work that
294
+ uses the Library", as object code and/or source code, so that the
295
+ user can modify the Library and then relink to produce a modified
296
+ executable containing the modified Library. (It is understood
297
+ that the user who changes the contents of definitions files in the
298
+ Library will not necessarily be able to recompile the application
299
+ to use the modified definitions.)
300
+
301
+ b) Use a suitable shared library mechanism for linking with the
302
+ Library. A suitable mechanism is one that (1) uses at run time a
303
+ copy of the library already present on the user's computer system,
304
+ rather than copying library functions into the executable, and (2)
305
+ will operate properly with a modified version of the library, if
306
+ the user installs one, as long as the modified version is
307
+ interface-compatible with the version that the work was made with.
308
+
309
+ c) Accompany the work with a written offer, valid for at
310
+ least three years, to give the same user the materials
311
+ specified in Subsection 6a, above, for a charge no more
312
+ than the cost of performing this distribution.
313
+
314
+ d) If distribution of the work is made by offering access to copy
315
+ from a designated place, offer equivalent access to copy the above
316
+ specified materials from the same place.
317
+
318
+ e) Verify that the user has already received a copy of these
319
+ materials or that you have already sent this user a copy.
320
+
321
+ For an executable, the required form of the "work that uses the
322
+ Library" must include any data and utility programs needed for
323
+ reproducing the executable from it. However, as a special exception,
324
+ the materials to be distributed need not include anything that is
325
+ normally distributed (in either source or binary form) with the major
326
+ components (compiler, kernel, and so on) of the operating system on
327
+ which the executable runs, unless that component itself accompanies
328
+ the executable.
329
+
330
+ It may happen that this requirement contradicts the license
331
+ restrictions of other proprietary libraries that do not normally
332
+ accompany the operating system. Such a contradiction means you cannot
333
+ use both them and the Library together in an executable that you
334
+ distribute.
335
+
336
+ 7. You may place library facilities that are a work based on the
337
+ Library side-by-side in a single library together with other library
338
+ facilities not covered by this License, and distribute such a combined
339
+ library, provided that the separate distribution of the work based on
340
+ the Library and of the other library facilities is otherwise
341
+ permitted, and provided that you do these two things:
342
+
343
+ a) Accompany the combined library with a copy of the same work
344
+ based on the Library, uncombined with any other library
345
+ facilities. This must be distributed under the terms of the
346
+ Sections above.
347
+
348
+ b) Give prominent notice with the combined library of the fact
349
+ that part of it is a work based on the Library, and explaining
350
+ where to find the accompanying uncombined form of the same work.
351
+
352
+ 8. You may not copy, modify, sublicense, link with, or distribute
353
+ the Library except as expressly provided under this License. Any
354
+ attempt otherwise to copy, modify, sublicense, link with, or
355
+ distribute the Library is void, and will automatically terminate your
356
+ rights under this License. However, parties who have received copies,
357
+ or rights, from you under this License will not have their licenses
358
+ terminated so long as such parties remain in full compliance.
359
+
360
+ 9. You are not required to accept this License, since you have not
361
+ signed it. However, nothing else grants you permission to modify or
362
+ distribute the Library or its derivative works. These actions are
363
+ prohibited by law if you do not accept this License. Therefore, by
364
+ modifying or distributing the Library (or any work based on the
365
+ Library), you indicate your acceptance of this License to do so, and
366
+ all its terms and conditions for copying, distributing or modifying
367
+ the Library or works based on it.
368
+
369
+ 10. Each time you redistribute the Library (or any work based on the
370
+ Library), the recipient automatically receives a license from the
371
+ original licensor to copy, distribute, link with or modify the Library
372
+ subject to these terms and conditions. You may not impose any further
373
+ restrictions on the recipients' exercise of the rights granted herein.
374
+ You are not responsible for enforcing compliance by third parties with
375
+ this License.
376
+
377
+ 11. If, as a consequence of a court judgment or allegation of patent
378
+ infringement or for any other reason (not limited to patent issues),
379
+ conditions are imposed on you (whether by court order, agreement or
380
+ otherwise) that contradict the conditions of this License, they do not
381
+ excuse you from the conditions of this License. If you cannot
382
+ distribute so as to satisfy simultaneously your obligations under this
383
+ License and any other pertinent obligations, then as a consequence you
384
+ may not distribute the Library at all. For example, if a patent
385
+ license would not permit royalty-free redistribution of the Library by
386
+ all those who receive copies directly or indirectly through you, then
387
+ the only way you could satisfy both it and this License would be to
388
+ refrain entirely from distribution of the Library.
389
+
390
+ If any portion of this section is held invalid or unenforceable under any
391
+ particular circumstance, the balance of the section is intended to apply,
392
+ and the section as a whole is intended to apply in other circumstances.
393
+
394
+ It is not the purpose of this section to induce you to infringe any
395
+ patents or other property right claims or to contest validity of any
396
+ such claims; this section has the sole purpose of protecting the
397
+ integrity of the free software distribution system which is
398
+ implemented by public license practices. Many people have made
399
+ generous contributions to the wide range of software distributed
400
+ through that system in reliance on consistent application of that
+ system; it is up to the author/donor to decide if he or she is willing
+ to distribute software through any other system and a licensee cannot
+ impose that choice.
+
+ This section is intended to make thoroughly clear what is believed to
+ be a consequence of the rest of this License.
+
+ 12. If the distribution and/or use of the Library is restricted in
+ certain countries either by patents or by copyrighted interfaces, the
+ original copyright holder who places the Library under this License may add
+ an explicit geographical distribution limitation excluding those countries,
+ so that distribution is permitted only in or among countries not thus
+ excluded. In such case, this License incorporates the limitation as if
+ written in the body of this License.
+
+ 13. The Free Software Foundation may publish revised and/or new
+ versions of the Lesser General Public License from time to time.
+ Such new versions will be similar in spirit to the present version,
+ but may differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the Library
+ specifies a version number of this License which applies to it and
+ "any later version", you have the option of following the terms and
+ conditions either of that version or of any later version published by
+ the Free Software Foundation. If the Library does not specify a
+ license version number, you may choose any version ever published by
+ the Free Software Foundation.
+
+ 14. If you wish to incorporate parts of the Library into other free
+ programs whose distribution conditions are incompatible with these,
+ write to the author to ask for permission. For software which is
+ copyrighted by the Free Software Foundation, write to the Free
+ Software Foundation; we sometimes make exceptions for this. Our
+ decision will be guided by the two goals of preserving the free status
+ of all derivatives of our free software and of promoting the sharing
+ and reuse of software generally.
+
+ NO WARRANTY
+
+ 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO
+ WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW.
+ EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR
+ OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY
+ KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE
+ LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME
+ THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN
+ WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
+ AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU
+ FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR
+ CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE
+ LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING
+ RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A
+ FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF
+ SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
+ DAMAGES.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Libraries
+
+ If you develop a new library, and you want it to be of the greatest
+ possible use to the public, we recommend making it free software that
+ everyone can redistribute and change. You can do so by permitting
+ redistribution under these terms (or, alternatively, under the terms of the
+ ordinary General Public License).
+
+ To apply these terms, attach the following notices to the library. It is
+ safest to attach them to the start of each source file to most effectively
+ convey the exclusion of warranty; and each file should have at least the
+ "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the library's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ You should also get your employer (if you work as a programmer) or your
+ school, if any, to sign a "copyright disclaimer" for the library, if
+ necessary. Here is a sample; alter the names:
+
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the
+ library `Frob' (a library for tweaking knobs) written by James Random Hacker.
+
+ <signature of Ty Coon>, 1 April 1990
+ Ty Coon, President of Vice
+
+ That's all there is to it!
venv/Lib/site-packages/adodbapi/process_connect_string.py ADDED
@@ -0,0 +1,137 @@
+ """a clumsy attempt at a macro language to let the programmer execute code on the server (ex: determine 64bit)"""
+
+ from . import is64bit
+
+
+ def macro_call(macro_name, args, kwargs):
+     """allow the programmer to perform limited processing on the server by passing macro names and args
+
+     :new_key - the key name the macro will create
+     :args[0] - macro name
+     :args[1:] - any arguments
+     :code - the value of the keyword item
+     :kwargs - the connection keyword dictionary. ??key has been removed
+     --> the value to put in for kwargs['name'] = value
+     """
+     if isinstance(args, str):
+         args = [args]  # the user forgot to pass a sequence, so make the string args[0]
+     new_key = args[0]
+     try:
+         if macro_name == "is64bit":
+             if is64bit.Python():  # if on 64 bit Python
+                 return new_key, args[1]  # return first argument
+             else:
+                 try:
+                     return new_key, args[2]  # else return second argument (if defined)
+                 except IndexError:
+                     return new_key, ""  # else return blank
+
+         elif macro_name == "getuser":  # get the name of the user the server is logged in under
+             if new_key not in kwargs:
+                 import getpass
+
+                 return new_key, getpass.getuser()
+
+         elif macro_name == "getnode":  # get the name of the computer running the server
+             import platform
+
+             try:
+                 return new_key, args[1] % platform.node()
+             except IndexError:
+                 return new_key, platform.node()
+
+         elif macro_name == "getenv":  # expand the server's environment variable args[1]
+             import os
+
+             try:
+                 dflt = args[2]  # if not found, default from args[2]
+             except IndexError:  # or blank
+                 dflt = ""
+             return new_key, os.environ.get(args[1], dflt)
+
+         elif macro_name == "auto_security":
+             if "user" not in kwargs or not kwargs["user"]:  # missing, blank, or Null username
+                 return new_key, "Integrated Security=SSPI"
+             return new_key, "User ID=%(user)s; Password=%(password)s" % kwargs
+
+         elif macro_name == "find_temp_test_path":  # helper function for testing ado operation -- undocumented
+             import os
+             import tempfile
+
+             return new_key, os.path.join(tempfile.gettempdir(), "adodbapi_test", args[1])
+
+         raise ValueError(f"Unknown connect string macro={macro_name}")
+     except Exception as e:
+         raise ValueError(f"Error in macro processing {macro_name} {args!r}") from e
+
+
+ def process(args, kwargs, expand_macros=False):  # --> connection string with keyword arguments processed.
+     """attempts to inject arguments into a connection string using Python "%" operator for strings
+
+     args: positional parameters from the .connect() call
+     kwargs: keyword arguments from the .connect() call
+     """
+     try:
+         dsn = args[0]
+     except IndexError:
+         dsn = None
+     # as a convenience the first argument may be django settings
+     if isinstance(dsn, dict):
+         kwargs.update(dsn)
+     # the connection string is passed to the connection as part of the keyword dictionary
+     elif dsn:
+         kwargs["connection_string"] = dsn
+     try:
+         a1 = args[1]
+     except IndexError:
+         a1 = None
+     # historically, the second positional argument might be a timeout value
+     if isinstance(a1, int):
+         kwargs["timeout"] = a1
+     # if the second positional argument is a string, then it is user
+     elif isinstance(a1, str):
+         kwargs["user"] = a1
+     # if the second positional argument is a dictionary, use it as keyword arguments, too
+     elif isinstance(a1, dict):
+         kwargs.update(a1)
+     try:
+         kwargs["password"] = args[2]  # the third positional argument is password
+         kwargs["host"] = args[3]  # the fourth positional argument is host name
+         kwargs["database"] = args[4]  # the fifth positional argument is database name
+     except IndexError:
+         pass
+
+     # make sure connection string is defined somehow
+     if "connection_string" not in kwargs:
+         try:  # perhaps 'dsn' was defined
+             kwargs["connection_string"] = kwargs["dsn"]
+         except KeyError:
+             try:  # as a last effort, use the "host" keyword
+                 kwargs["connection_string"] = kwargs["host"]
+             except KeyError:
+                 raise TypeError("Must define 'connection_string' for ado connections")
+     if expand_macros:
+         for kwarg in list(kwargs.keys()):
+             if kwarg.startswith("macro_"):  # if a key defines a macro
+                 macro_name = kwarg[6:]  # the name without the "macro_" prefix
+                 macro_code = kwargs.pop(kwarg)  # remove the macro_ key and get the code to execute
+                 new_key, rslt = macro_call(macro_name, macro_code, kwargs)  # run the code in the local context
+                 kwargs[new_key] = rslt  # put the result back in the keywords dict
+     return kwargs
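
The macro mechanism above is easiest to see end to end. A minimal sketch of calling `process()` directly follows; this is for illustration only (adodbapi normally calls `process()` internally from `connect()`), and importing the package in practice requires Windows with pywin32 installed:

```python
# Hypothetical direct invocation of process(); normally connect() does this.
from adodbapi.process_connect_string import process

kwargs = process(
    ("Provider=SQLOLEDB; Data Source=localhost;",),  # positional dsn
    {"macro_auto_security": "security"},  # note: no "user" key present
    expand_macros=True,
)
# The macro_ key was popped and replaced by its computed result:
# kwargs["security"] == "Integrated Security=SSPI"
# kwargs["connection_string"] == "Provider=SQLOLEDB; Data Source=localhost;"
print(kwargs)
```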
venv/Lib/site-packages/adodbapi/readme.txt ADDED
@@ -0,0 +1,88 @@
+ Project
+ -------
+ adodbapi
+
+ A Python DB-API 2.0 (PEP-249) module that makes it easy to use Microsoft ADO
+ for connecting with databases and other data sources using CPython.
+
+ Home page: <https://sourceforge.net/projects/adodbapi>
+
+ Features:
+ * 100% DB-API 2.0 (PEP-249) compliant (including most extensions and recommendations).
+ * Includes pyunit testcases that describe how to use the module.
+ * Fully implemented in Python. -- runs in current versions of Python 3
+ * Licensed under the LGPL license, which means that it can be used freely even in commercial programs subject to certain restrictions.
+ * The user can choose between paramstyles: 'qmark' 'named' 'format' 'pyformat' 'dynamic'
+ * Supports data retrieval by column name e.g.:
+     for row in myCursor.execute("select name,age from students"):
+         print("Student", row.name, "is", row.age, "years old.")
+ * Supports user-definable system-to-Python data conversion functions (selected by ADO data type, or by column)
+
+ Prerequisites:
+ * C Python 3.6 or higher
+   and pywin32 (Mark Hammond's python for windows extensions.)
+
+ Installation:
+ * (C-Python on Windows): Install pywin32 (`python -m pip install pywin32`) which includes adodbapi.
+ * (IronPython on Windows): Download adodbapi from https://sourceforge.net/projects/adodbapi/ . Unpack the zip.
+
+ NOTE: ...........
+ If you do not like the new default operation of returning Numeric columns as decimal.Decimal,
+ you can select other options by the user defined conversion feature.
+ Try:
+     adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtString
+ or:
+     adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtFloat
+ or:
+     adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = write_your_own_conversion_function
+ ............
+ notes for 2.6.2:
+     The definitive source has been moved to https://github.com/mhammond/pywin32/tree/main/adodbapi.
+     Remote has proven too hard to configure and test with Pyro4. I am moving it to unsupported status
+     until I can change to a different connection method.
+ what's new in version 2.6
+     A cursor.prepare() method and support for prepared SQL statements.
+     Lots of refactoring, especially of the Remote and Server modules (still to be treated as Beta code).
+     The quick start document 'quick_reference.odt' will export as a nice-looking pdf.
+     Added paramstyles 'pyformat' and 'dynamic'. If your 'paramstyle' is 'named' you _must_ pass a dictionary of
+     parameters to your .execute() method. If your 'paramstyle' is 'format' 'pyformat' or 'dynamic', you _may_
+     pass a dictionary of parameters -- provided your SQL operation string is formatted correctly.
+
+ what's new in version 2.5
+     Remote module: (works on Linux!) allows a Windows computer to serve ADO databases via PyRO
+     Server module: PyRO server for ADO. Run using a command like= C:>python -m adodbapi.server
+     (server has simple connection string macros: is64bit, getuser, sql_provider, auto_security)
+     Brief documentation included. See adodbapi/examples folder adodbapi.rtf
+     New connection method conn.get_table_names() --> list of names of tables in database
+
+     Vastly refactored. Data conversion things have been moved to the new adodbapi.apibase module.
+     Many former module-level attributes are now class attributes. (Should be more thread-safe)
+     Connection objects are now context managers for transactions and will commit or rollback.
+     Cursor objects are context managers and will automatically close themselves.
+     Autocommit can be switched on and off.
+     Keyword and positional arguments on the connect() method work as documented in PEP 249.
+     Keyword arguments from the connect call can be formatted into the connection string.
+     New keyword arguments defined, such as: autocommit, paramstyle, remote_proxy, remote_port.
+     *** Breaking change: variantConversion lookups are simplified: the following will raise KeyError:
+         oldconverter = adodbapi.variantConversions[adodbapi.adoStringTypes]
+     Refactor as: oldconverter = adodbapi.variantConversions[adodbapi.adoStringTypes[0]]
+
+ License
+ -------
+ LGPL, see https://opensource.org/license/lgpl-2-1
+
+ Documentation
+ -------------
+
+ Look at:
+ - `adodbapi/quick_reference.md`
+ - https://wiki.python.org/moin/DatabaseProgramming#The_DB-API
+ - the examples in adodbapi/examples
+ - and the test cases in the `adodbapi/test` directory
+
+ Mailing lists
+ -------------
+ The adodbapi mailing lists have been deactivated. Submit comments to the
+ pywin32 mailing lists.
+ -- the bug tracker on sourceforge.net/projects/adodbapi may be checked (infrequently).
+ -- please use: https://github.com/mhammond/pywin32/issues
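
The paramstyle rules in the 2.6 notes above are worth one concrete illustration. A minimal sketch, assuming a reachable SQL Server instance and that `paramstyle` is accepted as a connect() keyword as the 2.5 notes state; the connection string and the students table are hypothetical:

```python
import adodbapi

con = adodbapi.connect(
    "Provider=SQLOLEDB; Data Source=localhost; "
    "Initial Catalog=test; Integrated Security=SSPI",
    paramstyle="named",
)
cur = con.cursor()
# With 'named' paramstyle the parameters MUST be a dictionary.
cur.execute(
    "SELECT name, age FROM students WHERE age >= :minage",
    {"minage": 18},
)
for row in cur:
    print("Student", row.name, "is", row.age, "years old.")  # column-name access
cur.close()
con.close()
```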
venv/Lib/site-packages/adodbapi/schema_table.py ADDED
@@ -0,0 +1,16 @@
+ """call using an open ADO connection --> list of table names"""
+
+ from . import adodbapi
+
+
+ def names(connection_object):
+     ado = connection_object.adoConn
+     schema = ado.OpenSchema(20)  # constant = adSchemaTables
+
+     tables = []
+     while not schema.EOF:
+         name = adodbapi.getIndexedValue(schema.Fields, "TABLE_NAME").Value
+         tables.append(name)
+         schema.MoveNext()
+     del schema
+     return tables
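
A short usage sketch for `names()`; the connection string is hypothetical, and on recent versions the readme's `conn.get_table_names()` is the equivalent public method:

```python
import adodbapi
from adodbapi.schema_table import names

# Hypothetical connection string, for illustration only.
con = adodbapi.connect("Provider=SQLOLEDB; Data Source=localhost; "
                       "Initial Catalog=test; Integrated Security=SSPI")
print(names(con))  # e.g. ['customers', 'orders']
con.close()
```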
venv/Lib/site-packages/adodbapi/setup.py ADDED
@@ -0,0 +1,68 @@
+ """adodbapi -- a pure Python PEP 249 DB-API package using Microsoft ADO
+
+ Adodbapi can be run on CPython 3.5 and later.
+ """
+
+ NAME = "adodbapi"
+ MAINTAINER = "Vernon Cole"
+ MAINTAINER_EMAIL = "[email protected]"
+ DESCRIPTION = (
+     """A pure Python package implementing PEP 249 DB-API using Microsoft ADO."""
+ )
+ URL = "https://sourceforge.net/projects/adodbapi"
+ LICENSE = "LGPL"
+ CLASSIFIERS = [
+     "Development Status :: 5 - Production/Stable",
+     "Intended Audience :: Developers",
+     "License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)",
+     "Operating System :: Microsoft :: Windows",
+     "Operating System :: POSIX :: Linux",
+     "Programming Language :: Python",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: SQL",
+     "Topic :: Software Development",
+     "Topic :: Software Development :: Libraries :: Python Modules",
+     "Topic :: Database",
+ ]
+ AUTHOR = "Henrik Ekelund, Vernon Cole, et al."
+ AUTHOR_EMAIL = "[email protected]"
+ PLATFORMS = ["Windows", "Linux"]
+
+ VERSION = None  # in case searching for version fails
+ with open("adodbapi.py") as a:  # find the version string in the source code
+     for line in a:
+         if "__version__" in line:
+             VERSION = line.split("'")[1]  # pyright: ignore[reportConstantRedefinition]
+             print('adodbapi version="%s"' % VERSION)
+             break
+
+
+ def setup_package():
+     from setuptools import setup
+     from setuptools.command.build_py import build_py
+
+     setup(
+         cmdclass={"build_py": build_py},
+         name=NAME,
+         maintainer=MAINTAINER,
+         maintainer_email=MAINTAINER_EMAIL,
+         description=DESCRIPTION,
+         url=URL,
+         keywords="database ado odbc dbapi db-api Microsoft SQL",
+         ## download_url=DOWNLOAD_URL,
+         long_description=open("README.txt").read(),
+         license=LICENSE,
+         classifiers=CLASSIFIERS,
+         author=AUTHOR,
+         author_email=AUTHOR_EMAIL,
+         platforms=PLATFORMS,
+         version=VERSION,
+         package_dir={"adodbapi": ""},
+         packages=["adodbapi"],
+     )
+     return
+
+
+ if __name__ == "__main__":
+     setup_package()
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+ pip
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/LICENSE ADDED
@@ -0,0 +1,279 @@
+ A. HISTORY OF THE SOFTWARE
+ ==========================
+
+ Python was created in the early 1990s by Guido van Rossum at Stichting
+ Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands
+ as a successor of a language called ABC. Guido remains Python's
+ principal author, although it includes many contributions from others.
+
+ In 1995, Guido continued his work on Python at the Corporation for
+ National Research Initiatives (CNRI, see https://www.cnri.reston.va.us)
+ in Reston, Virginia where he released several versions of the
+ software.
+
+ In May 2000, Guido and the Python core development team moved to
+ BeOpen.com to form the BeOpen PythonLabs team. In October of the same
+ year, the PythonLabs team moved to Digital Creations, which became
+ Zope Corporation. In 2001, the Python Software Foundation (PSF, see
+ https://www.python.org/psf/) was formed, a non-profit organization
+ created specifically to own Python-related Intellectual Property.
+ Zope Corporation was a sponsoring member of the PSF.
+
+ All Python releases are Open Source (see https://opensource.org for
+ the Open Source Definition). Historically, most, but not all, Python
+ releases have also been GPL-compatible; the table below summarizes
+ the various releases.
+
+     Release         Derived     Year        Owner       GPL-
+                     from                                compatible? (1)
+
+     0.9.0 thru 1.2              1991-1995   CWI         yes
+     1.3 thru 1.5.2  1.2         1995-1999   CNRI        yes
+     1.6             1.5.2       2000        CNRI        no
+     2.0             1.6         2000        BeOpen.com  no
+     1.6.1           1.6         2001        CNRI        yes (2)
+     2.1             2.0+1.6.1   2001        PSF         no
+     2.0.1           2.0+1.6.1   2001        PSF         yes
+     2.1.1           2.1+2.0.1   2001        PSF         yes
+     2.1.2           2.1.1       2002        PSF         yes
+     2.1.3           2.1.2       2002        PSF         yes
+     2.2 and above   2.1.1       2001-now    PSF         yes
+
+ Footnotes:
+
+ (1) GPL-compatible doesn't mean that we're distributing Python under
+     the GPL. All Python licenses, unlike the GPL, let you distribute
+     a modified version without making your changes open source. The
+     GPL-compatible licenses make it possible to combine Python with
+     other software that is released under the GPL; the others don't.
+
+ (2) According to Richard Stallman, 1.6.1 is not GPL-compatible,
+     because its license has a choice of law clause. According to
+     CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1
+     is "not incompatible" with the GPL.
+
+ Thanks to the many outside volunteers who have worked under Guido's
+ direction to make these releases possible.
+
+
+ B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON
+ ===============================================================
+
+ Python software and documentation are licensed under the
+ Python Software Foundation License Version 2.
+
+ Starting with Python 3.8.6, examples, recipes, and other code in
+ the documentation are dual licensed under the PSF License Version 2
+ and the Zero-Clause BSD license.
+
+ Some software incorporated into Python is under different licenses.
+ The licenses are listed with code falling under that license.
+
+
+ PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+ --------------------------------------------
+
+ 1. This LICENSE AGREEMENT is between the Python Software Foundation
+ ("PSF"), and the Individual or Organization ("Licensee") accessing and
+ otherwise using this software ("Python") in source or binary form and
+ its associated documentation.
+
+ 2. Subject to the terms and conditions of this License Agreement, PSF hereby
+ grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+ analyze, test, perform and/or display publicly, prepare derivative works,
+ distribute, and otherwise use Python alone or in any derivative version,
+ provided, however, that PSF's License Agreement and PSF's notice of copyright,
+ i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+ 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation;
+ All Rights Reserved" are retained in Python alone or in any derivative version
+ prepared by Licensee.
+
+ 3. In the event Licensee prepares a derivative work that is based on
+ or incorporates Python or any part thereof, and wants to make
+ the derivative work available to others as provided herein, then
+ Licensee hereby agrees to include in any such work a brief summary of
+ the changes made to Python.
+
+ 4. PSF is making Python available to Licensee on an "AS IS"
+ basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+ IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+ DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+ FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+ INFRINGE ANY THIRD PARTY RIGHTS.
+
+ 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+ FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+ A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+ OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+ 6. This License Agreement will automatically terminate upon a material
+ breach of its terms and conditions.
+
+ 7. Nothing in this License Agreement shall be deemed to create any
+ relationship of agency, partnership, or joint venture between PSF and
+ Licensee. This License Agreement does not grant permission to use PSF
+ trademarks or trade name in a trademark sense to endorse or promote
+ products or services of Licensee, or any third party.
+
+ 8. By copying, installing or otherwise using Python, Licensee
+ agrees to be bound by the terms and conditions of this License
+ Agreement.
+
+
+ BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0
+ -------------------------------------------
+
+ BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1
+
+ 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an
+ office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the
+ Individual or Organization ("Licensee") accessing and otherwise using
+ this software in source or binary form and its associated
+ documentation ("the Software").
+
+ 2. Subject to the terms and conditions of this BeOpen Python License
+ Agreement, BeOpen hereby grants Licensee a non-exclusive,
+ royalty-free, world-wide license to reproduce, analyze, test, perform
+ and/or display publicly, prepare derivative works, distribute, and
+ otherwise use the Software alone or in any derivative version,
+ provided, however, that the BeOpen Python License is retained in the
+ Software, alone or in any derivative version prepared by Licensee.
+
+ 3. BeOpen is making the Software available to Licensee on an "AS IS"
+ basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+ IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND
+ DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+ FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT
+ INFRINGE ANY THIRD PARTY RIGHTS.
+
+ 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE
+ SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS
+ AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY
+ DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+ 5. This License Agreement will automatically terminate upon a material
+ breach of its terms and conditions.
+
+ 6. This License Agreement shall be governed by and interpreted in all
+ respects by the law of the State of California, excluding conflict of
+ law provisions. Nothing in this License Agreement shall be deemed to
+ create any relationship of agency, partnership, or joint venture
+ between BeOpen and Licensee. This License Agreement does not grant
+ permission to use BeOpen trademarks or trade names in a trademark
+ sense to endorse or promote products or services of Licensee, or any
+ third party. As an exception, the "BeOpen Python" logos available at
+ http://www.pythonlabs.com/logos.html may be used according to the
+ permissions granted on that web page.
+
+ 7. By copying, installing or otherwise using the software, Licensee
+ agrees to be bound by the terms and conditions of this License
+ Agreement.
+
+
+ CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1
+ ---------------------------------------
+
+ 1. This LICENSE AGREEMENT is between the Corporation for National
+ Research Initiatives, having an office at 1895 Preston White Drive,
+ Reston, VA 20191 ("CNRI"), and the Individual or Organization
+ ("Licensee") accessing and otherwise using Python 1.6.1 software in
+ source or binary form and its associated documentation.
+
+ 2. Subject to the terms and conditions of this License Agreement, CNRI
+ hereby grants Licensee a nonexclusive, royalty-free, world-wide
+ license to reproduce, analyze, test, perform and/or display publicly,
+ prepare derivative works, distribute, and otherwise use Python 1.6.1
+ alone or in any derivative version, provided, however, that CNRI's
+ License Agreement and CNRI's notice of copyright, i.e., "Copyright (c)
+ 1995-2001 Corporation for National Research Initiatives; All Rights
+ Reserved" are retained in Python 1.6.1 alone or in any derivative
+ version prepared by Licensee. Alternately, in lieu of CNRI's License
+ Agreement, Licensee may substitute the following text (omitting the
+ quotes): "Python 1.6.1 is made available subject to the terms and
+ conditions in CNRI's License Agreement. This Agreement together with
+ Python 1.6.1 may be located on the internet using the following
+ unique, persistent identifier (known as a handle): 1895.22/1013. This
+ Agreement may also be obtained from a proxy server on the internet
+ using the following URL: http://hdl.handle.net/1895.22/1013".
+
+ 3. In the event Licensee prepares a derivative work that is based on
+ or incorporates Python 1.6.1 or any part thereof, and wants to make
+ the derivative work available to others as provided herein, then
+ Licensee hereby agrees to include in any such work a brief summary of
+ the changes made to Python 1.6.1.
+
+ 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS"
+ basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+ IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND
+ DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+ FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT
+ INFRINGE ANY THIRD PARTY RIGHTS.
+
+ 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+ 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+ A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1,
+ OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+ 6. This License Agreement will automatically terminate upon a material
+ breach of its terms and conditions.
+
+ 7. This License Agreement shall be governed by the federal
+ intellectual property law of the United States, including without
+ limitation the federal copyright law, and, to the extent such
+ U.S. federal law does not apply, by the law of the Commonwealth of
+ Virginia, excluding Virginia's conflict of law provisions.
+ Notwithstanding the foregoing, with regard to derivative works based
+ on Python 1.6.1 that incorporate non-separable material that was
+ previously distributed under the GNU General Public License (GPL), the
+ law of the Commonwealth of Virginia shall govern this License
+ Agreement only as to issues arising under or with respect to
+ Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this
+ License Agreement shall be deemed to create any relationship of
+ agency, partnership, or joint venture between CNRI and Licensee. This
+ License Agreement does not grant permission to use CNRI trademarks or
+ trade name in a trademark sense to endorse or promote products or
+ services of Licensee, or any third party.
+
+ 8. By clicking on the "ACCEPT" button where indicated, or by copying,
+ installing or otherwise using Python 1.6.1, Licensee agrees to be
+ bound by the terms and conditions of this License Agreement.
+
+ ACCEPT
+
+
+ CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2
+ --------------------------------------------------
+
+ Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam,
+ The Netherlands. All rights reserved.
+
+ Permission to use, copy, modify, and distribute this software and its
+ documentation for any purpose and without fee is hereby granted,
+ provided that the above copyright notice appear in all copies and that
+ both that copyright notice and this permission notice appear in
+ supporting documentation, and that the name of Stichting Mathematisch
+ Centrum or CWI not be used in advertising or publicity pertaining to
+ distribution of the software without specific, written prior
+ permission.
+
+ STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO
+ THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE
+ FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+ OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+ ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION
+ ----------------------------------------------------------------------
+
+ Permission to use, copy, modify, and/or distribute this software for any
+ purpose with or without fee is hereby granted.
+
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
+ REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+ AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
+ INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+ LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+ OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+ PERFORMANCE OF THIS SOFTWARE.
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/METADATA ADDED
@@ -0,0 +1,123 @@
+ Metadata-Version: 2.3
+ Name: aiohappyeyeballs
+ Version: 2.6.1
+ Summary: Happy Eyeballs for asyncio
+ License: PSF-2.0
+ Author: J. Nick Koston
+ Author-email: [email protected]
+ Requires-Python: >=3.9
+ Classifier: Development Status :: 5 - Production/Stable
+ Classifier: Intended Audience :: Developers
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: License :: OSI Approved :: Python Software Foundation License
+ Project-URL: Bug Tracker, https://github.com/aio-libs/aiohappyeyeballs/issues
+ Project-URL: Changelog, https://github.com/aio-libs/aiohappyeyeballs/blob/main/CHANGELOG.md
+ Project-URL: Documentation, https://aiohappyeyeballs.readthedocs.io
+ Project-URL: Repository, https://github.com/aio-libs/aiohappyeyeballs
+ Description-Content-Type: text/markdown
+
+ # aiohappyeyeballs
+
+ <p align="center">
+   <a href="https://github.com/aio-libs/aiohappyeyeballs/actions/workflows/ci.yml?query=branch%3Amain">
+     <img src="https://img.shields.io/github/actions/workflow/status/aio-libs/aiohappyeyeballs/ci-cd.yml?branch=main&label=CI&logo=github&style=flat-square" alt="CI Status" >
+   </a>
+   <a href="https://aiohappyeyeballs.readthedocs.io">
+     <img src="https://img.shields.io/readthedocs/aiohappyeyeballs.svg?logo=read-the-docs&logoColor=fff&style=flat-square" alt="Documentation Status">
+   </a>
+   <a href="https://codecov.io/gh/aio-libs/aiohappyeyeballs">
+     <img src="https://img.shields.io/codecov/c/github/aio-libs/aiohappyeyeballs.svg?logo=codecov&logoColor=fff&style=flat-square" alt="Test coverage percentage">
+   </a>
+ </p>
+ <p align="center">
+   <a href="https://python-poetry.org/">
+     <img src="https://img.shields.io/badge/packaging-poetry-299bd7?style=flat-square&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA4AAAASCAYAAABrXO8xAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAJJSURBVHgBfZLPa1NBEMe/s7tNXoxW1KJQKaUHkXhQvHgW6UHQQ09CBS/6V3hKc/AP8CqCrUcpmop3Cx48eDB4yEECjVQrlZb80CRN8t6OM/teagVxYZi38+Yz853dJbzoMV3MM8cJUcLMSUKIE8AzQ2PieZzFxEJOHMOgMQQ+dUgSAckNXhapU/NMhDSWLs1B24A8sO1xrN4NECkcAC9ASkiIJc6k5TRiUDPhnyMMdhKc+Zx19l6SgyeW76BEONY9exVQMzKExGKwwPsCzza7KGSSWRWEQhyEaDXp6ZHEr416ygbiKYOd7TEWvvcQIeusHYMJGhTwF9y7sGnSwaWyFAiyoxzqW0PM/RjghPxF2pWReAowTEXnDh0xgcLs8l2YQmOrj3N7ByiqEoH0cARs4u78WgAVkoEDIDoOi3AkcLOHU60RIg5wC4ZuTC7FaHKQm8Hq1fQuSOBvX/sodmNJSB5geaF5CPIkUeecdMxieoRO5jz9bheL6/tXjrwCyX/UYBUcjCaWHljx1xiX6z9xEjkYAzbGVnB8pvLmyXm9ep+W8CmsSHQQY77Zx1zboxAV0w7ybMhQmfqdmmw3nEp1I0Z+FGO6M8LZdoyZnuzzBdjISicKRnpxzI9fPb+0oYXsNdyi+d3h9bm9MWYHFtPeIZfLwzmFDKy1ai3p+PDls1Llz4yyFpferxjnyjJDSEy9CaCx5m2cJPerq6Xm34eTrZt3PqxYO1XOwDYZrFlH1fWnpU38Y9HRze3lj0vOujZcXKuuXm3jP+s3KbZVra7y2EAAAAAASUVORK5CYII=" alt="Poetry">
+   </a>
+   <a href="https://github.com/astral-sh/ruff">
+     <img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff">
+   </a>
+   <a href="https://github.com/pre-commit/pre-commit">
+     <img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=flat-square" alt="pre-commit">
+   </a>
+ </p>
+ <p align="center">
+   <a href="https://pypi.org/project/aiohappyeyeballs/">
+     <img src="https://img.shields.io/pypi/v/aiohappyeyeballs.svg?logo=python&logoColor=fff&style=flat-square" alt="PyPI Version">
+   </a>
+   <img src="https://img.shields.io/pypi/pyversions/aiohappyeyeballs.svg?style=flat-square&logo=python&amp;logoColor=fff" alt="Supported Python versions">
+   <img src="https://img.shields.io/pypi/l/aiohappyeyeballs.svg?style=flat-square" alt="License">
+ </p>
+
+ ---
+
+ **Documentation**: <a href="https://aiohappyeyeballs.readthedocs.io" target="_blank">https://aiohappyeyeballs.readthedocs.io </a>
+
+ **Source Code**: <a href="https://github.com/aio-libs/aiohappyeyeballs" target="_blank">https://github.com/aio-libs/aiohappyeyeballs </a>
+
+ ---
+
+ [Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
+ ([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
+
+ ## Use case
+
+ This library exists to allow connecting with
+ [Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
+ ([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
+ when you
+ already have a list of addrinfo and not a DNS name.
+
+ The stdlib version of `loop.create_connection()`
+ will only work when you pass in an unresolved name which
+ is not a good fit when using DNS caching or resolving
+ names via another method such as `zeroconf`.
+
+ ## Installation
+
+ Install this via pip (or your favourite package manager):
+
+ `pip install aiohappyeyeballs`
+
+ ## License
+
+ [aiohappyeyeballs is licensed under the same terms as cpython itself.](https://github.com/python/cpython/blob/main/LICENSE)
+
+ ## Example usage
+
+ ```python
+
+ addr_infos = await loop.getaddrinfo("example.org", 80)
+
+ socket = await start_connection(addr_infos)
+ socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, happy_eyeballs_delay=0.2)
+
+ transport, protocol = await loop.create_connection(
+     MyProtocol, sock=socket, ...)
+
+ # Remove the first address for each family from addr_info
+ pop_addr_infos_interleave(addr_info, 1)
+
+ # Remove all matching address from addr_info
+ remove_addr_infos(addr_info, "dead::beef::")
+
+ # Convert a local_addr to local_addr_infos
+ local_addr_infos = addr_to_addr_infos(("127.0.0.1",0))
+ ```
+
+ ## Credits
+
+ This package contains code from cpython and is licensed under the same terms as cpython itself.
+
+ This package was created with
+ [Copier](https://copier.readthedocs.io/) and the
+ [browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template)
+ project template.
+
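
To make the README fragments above concrete, here is a self-contained sketch of the advertised use case: resolve once, reuse the cached `addr_infos`, and let `start_connection()` race the candidates. The host and the 0.25 s delay are arbitrary choices for illustration:

```python
import asyncio
import socket

from aiohappyeyeballs import start_connection


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Resolve once; the result can be cached and reused across connects.
    addr_infos = await loop.getaddrinfo(
        "example.org", 80, type=socket.SOCK_STREAM
    )
    # Race the candidate addresses, staggered 250 ms apart (RFC 8305).
    sock = await start_connection(addr_infos, happy_eyeballs_delay=0.25)
    print("connected via", sock.getpeername())
    sock.close()


asyncio.run(main())
```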
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+ aiohappyeyeballs-2.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+ aiohappyeyeballs-2.6.1.dist-info/LICENSE,sha256=Oy-B_iHRgcSZxZolbI4ZaEVdZonSaaqFNzv7avQdo78,13936
+ aiohappyeyeballs-2.6.1.dist-info/METADATA,sha256=NSXlhJwAfi380eEjAo7BQ4P_TVal9xi0qkyZWibMsVM,5915
+ aiohappyeyeballs-2.6.1.dist-info/RECORD,,
+ aiohappyeyeballs-2.6.1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+ aiohappyeyeballs/__init__.py,sha256=x7kktHEtaD9quBcWDJPuLeKyjuVAI-Jj14S9B_5hcTs,361
+ aiohappyeyeballs/__pycache__/__init__.cpython-312.pyc,,
+ aiohappyeyeballs/__pycache__/_staggered.cpython-312.pyc,,
+ aiohappyeyeballs/__pycache__/impl.cpython-312.pyc,,
+ aiohappyeyeballs/__pycache__/types.cpython-312.pyc,,
+ aiohappyeyeballs/__pycache__/utils.cpython-312.pyc,,
+ aiohappyeyeballs/_staggered.py,sha256=edfVowFx-P-ywJjIEF3MdPtEMVODujV6CeMYr65otac,6900
+ aiohappyeyeballs/impl.py,sha256=Dlcm2mTJ28ucrGnxkb_fo9CZzLAkOOBizOt7dreBbXE,9681
+ aiohappyeyeballs/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aiohappyeyeballs/types.py,sha256=YZJIAnyoV4Dz0WFtlaf_OyE4EW7Xus1z7aIfNI6tDDQ,425
+ aiohappyeyeballs/utils.py,sha256=on9GxIR0LhEfZu8P6Twi9hepX9zDanuZM20MWsb3xlQ,3028
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: poetry-core 2.1.1
+ Root-Is-Purelib: true
+ Tag: py3-none-any
venv/Lib/site-packages/aiohappyeyeballs/__init__.py ADDED
@@ -0,0 +1,14 @@
+ __version__ = "2.6.1"
+
+ from .impl import start_connection
+ from .types import AddrInfoType, SocketFactoryType
+ from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos
+
+ __all__ = (
+     "AddrInfoType",
+     "SocketFactoryType",
+     "addr_to_addr_infos",
+     "pop_addr_infos_interleave",
+     "remove_addr_infos",
+     "start_connection",
+ )
venv/Lib/site-packages/aiohappyeyeballs/_staggered.py ADDED
@@ -0,0 +1,207 @@
+ import asyncio
+ import contextlib
+
+ # PY3.9: Import Callable from typing until we drop Python 3.9 support
+ # https://github.com/python/cpython/issues/87131
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Awaitable,
+     Callable,
+     Iterable,
+     List,
+     Optional,
+     Set,
+     Tuple,
+     TypeVar,
+     Union,
+ )
+
+ _T = TypeVar("_T")
+
+ RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt)
+
+
+ def _set_result(wait_next: "asyncio.Future[None]") -> None:
+     """Set the result of a future if it is not already done."""
+     if not wait_next.done():
+         wait_next.set_result(None)
+
+
+ async def _wait_one(
+     futures: "Iterable[asyncio.Future[Any]]",
+     loop: asyncio.AbstractEventLoop,
+ ) -> _T:
+     """Wait for the first future to complete."""
+     wait_next = loop.create_future()
+
+     def _on_completion(fut: "asyncio.Future[Any]") -> None:
+         if not wait_next.done():
+             wait_next.set_result(fut)
+
+     for f in futures:
+         f.add_done_callback(_on_completion)
+
+     try:
+         return await wait_next
+     finally:
+         for f in futures:
+             f.remove_done_callback(_on_completion)
+
+
+ async def staggered_race(
+     coro_fns: Iterable[Callable[[], Awaitable[_T]]],
+     delay: Optional[float],
+     *,
+     loop: Optional[asyncio.AbstractEventLoop] = None,
+ ) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
+     """
+     Run coroutines with staggered start times and take the first to finish.
+
+     This method takes an iterable of coroutine functions. The first one is
+     started immediately. From then on, whenever the immediately preceding one
+     fails (raises an exception), or when *delay* seconds have passed, the next
+     coroutine is started. This continues until one of the coroutines completes
+     successfully, in which case all others are cancelled, or until all
+     coroutines fail.
+
+     The coroutines provided should be well-behaved in the following way:
+
+     * They should only ``return`` if completed successfully.
+
+     * They should always raise an exception if they did not complete
+       successfully. In particular, if they handle cancellation, they should
+       probably reraise, like this::
+
+         try:
+             # do work
+         except asyncio.CancelledError:
+             # undo partially completed work
+             raise
+
+     Args:
+     ----
+         coro_fns: an iterable of coroutine functions, i.e. callables that
+             return a coroutine object when called. Use ``functools.partial`` or
+             lambdas to pass arguments.
+
+         delay: amount of time, in seconds, between starting coroutines. If
+             ``None``, the coroutines will run sequentially.
+
+         loop: the event loop to use. If ``None``, the running loop is used.
+
+     Returns:
+     -------
+         tuple *(winner_result, winner_index, exceptions)* where
+
+         - *winner_result*: the result of the winning coroutine, or ``None``
+           if no coroutines won.
+
+         - *winner_index*: the index of the winning coroutine in
+           ``coro_fns``, or ``None`` if no coroutines won. If the winning
+           coroutine may return None on success, *winner_index* can be used
+           to definitively determine whether any coroutine won.
+
+         - *exceptions*: list of exceptions returned by the coroutines.
+           ``len(exceptions)`` is equal to the number of coroutines actually
+           started, and the order is the same as in ``coro_fns``. The winning
+           coroutine's entry is ``None``.
+
+     """
+     loop = loop or asyncio.get_running_loop()
+     exceptions: List[Optional[BaseException]] = []
+     tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
+
+     async def run_one_coro(
+         coro_fn: Callable[[], Awaitable[_T]],
+         this_index: int,
+         start_next: "asyncio.Future[None]",
+     ) -> Optional[Tuple[_T, int]]:
+         """
+         Run a single coroutine.
+
+         If the coroutine fails, set the exception in the exceptions list and
+         start the next coroutine by setting the result of start_next.
+
+         If the coroutine succeeds, return the result and the index of the
+         coroutine in the coro_fns list.
+
+         If SystemExit or KeyboardInterrupt is raised, re-raise it.
+         """
+         try:
+             result = await coro_fn()
+         except RE_RAISE_EXCEPTIONS:
+             raise
+         except BaseException as e:
+             exceptions[this_index] = e
+             _set_result(start_next)  # Kickstart the next coroutine
+             return None
+
+         return result, this_index
+
+     start_next_timer: Optional[asyncio.TimerHandle] = None
+     start_next: Optional[asyncio.Future[None]]
+     task: asyncio.Task[Optional[Tuple[_T, int]]]
+     done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
+     coro_iter = iter(coro_fns)
+     this_index = -1
+     try:
+         while True:
+             if coro_fn := next(coro_iter, None):
+                 this_index += 1
+                 exceptions.append(None)
+                 start_next = loop.create_future()
+                 task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
+                 tasks.add(task)
+                 start_next_timer = (
+                     loop.call_later(delay, _set_result, start_next) if delay else None
+                 )
+             elif not tasks:
+                 # We exhausted the coro_fns list and no tasks are running
+                 # so we have no winner and all coroutines failed.
+                 break
+
+             while tasks or start_next:
+                 done = await _wait_one(
+                     (*tasks, start_next) if start_next else tasks, loop
+                 )
+                 if done is start_next:
+                     # The current task has failed or the timer has expired
+                     # so we need to start the next task.
+                     start_next = None
+                     if start_next_timer:
+                         start_next_timer.cancel()
+                         start_next_timer = None
+
+                     # Break out of the task waiting loop to start the next
+                     # task.
+                     break
+
+                 if TYPE_CHECKING:
+                     assert isinstance(done, asyncio.Task)
+
+                 tasks.remove(done)
+                 if winner := done.result():
+                     return *winner, exceptions
+     finally:
+         # We either have:
+         # - a winner
+         # - all tasks failed
+         # - a KeyboardInterrupt or SystemExit.
+
+         #
+         # If the timer is still running, cancel it.
+         #
+         if start_next_timer:
+             start_next_timer.cancel()
+
+         #
+         # If there are any tasks left, cancel them and then
+         # await them so they fill in the exceptions list.
+         #
+         for task in tasks:
+             task.cancel()
+             with contextlib.suppress(asyncio.CancelledError):
+                 await task
+
+     return None, None, exceptions
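
A small sketch of `staggered_race()` in isolation, with stand-ins for connection attempts: the first coroutine fails immediately, which kick-starts the second without waiting out the 0.3 s stagger:

```python
import asyncio

from aiohappyeyeballs._staggered import staggered_race


async def failing() -> str:
    raise OSError("connection refused")


async def succeeding() -> str:
    await asyncio.sleep(0.01)
    return "connected"


async def main() -> None:
    result, index, excs = await staggered_race(
        [failing, succeeding],  # callables that return coroutine objects
        delay=0.3,
    )
    # result == "connected", index == 1,
    # excs == [OSError("connection refused"), None]
    print(result, index, excs)


asyncio.run(main())
```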
venv/Lib/site-packages/aiohappyeyeballs/impl.py ADDED
@@ -0,0 +1,259 @@
+ """Base implementation."""
+
+ import asyncio
+ import collections
+ import contextlib
+ import functools
+ import itertools
+ import socket
+ from typing import List, Optional, Sequence, Set, Union
+
+ from . import _staggered
+ from .types import AddrInfoType, SocketFactoryType
+
+
+ async def start_connection(
+     addr_infos: Sequence[AddrInfoType],
+     *,
+     local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
+     happy_eyeballs_delay: Optional[float] = None,
+     interleave: Optional[int] = None,
+     loop: Optional[asyncio.AbstractEventLoop] = None,
+     socket_factory: Optional[SocketFactoryType] = None,
+ ) -> socket.socket:
+     """
+     Connect to a TCP server.
+
+     Create a socket connection to a specified destination. The
+     destination is specified as a list of AddrInfoType tuples as
+     returned from getaddrinfo().
+
+     Each AddrInfoType tuple contains, in order:
+
+     * ``family``: the address family, e.g. ``socket.AF_INET`` or
+       ``socket.AF_INET6``.
+     * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
+       ``socket.SOCK_DGRAM``.
+     * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
+       ``socket.IPPROTO_UDP``.
+     * ``canonname``: the canonical name of the address, e.g.
+       ``"www.python.org"``.
+     * ``sockaddr``: the socket address
+
+     This method is a coroutine which will try to establish the connection
+     in the background. When successful, the coroutine returns a
+     socket.
+
+     The expected use case is to use this method in conjunction with
+     loop.create_connection() to establish a connection to a server::
+
+         socket = await start_connection(addr_infos)
+         transport, protocol = await loop.create_connection(
+             MyProtocol, sock=socket, ...)
+     """
+     if not (current_loop := loop):
+         current_loop = asyncio.get_running_loop()
+
+     single_addr_info = len(addr_infos) == 1
+
+     if happy_eyeballs_delay is not None and interleave is None:
+         # If using happy eyeballs, default to interleave addresses by family
+         interleave = 1
+
+     if interleave and not single_addr_info:
+         addr_infos = _interleave_addrinfos(addr_infos, interleave)
+
+     sock: Optional[socket.socket] = None
+     # uvloop can raise RuntimeError instead of OSError
+     exceptions: List[List[Union[OSError, RuntimeError]]] = []
+     if happy_eyeballs_delay is None or single_addr_info:
+         # not using happy eyeballs
+         for addrinfo in addr_infos:
+             try:
+                 sock = await _connect_sock(
+                     current_loop,
+                     exceptions,
+                     addrinfo,
+                     local_addr_infos,
+                     None,
+                     socket_factory,
+                 )
+                 break
+             except (RuntimeError, OSError):
+                 continue
+     else:  # using happy eyeballs
+         open_sockets: Set[socket.socket] = set()
+         try:
+             sock, _, _ = await _staggered.staggered_race(
+                 (
+                     functools.partial(
+                         _connect_sock,
+                         current_loop,
+                         exceptions,
+                         addrinfo,
+                         local_addr_infos,
+                         open_sockets,
+                         socket_factory,
+                     )
+                     for addrinfo in addr_infos
+                 ),
+                 happy_eyeballs_delay,
+             )
+         finally:
+             # If we have a winner, staggered_race will
+             # cancel the other tasks, however there is a
+             # small race window where any of the other tasks
+             # can be done before they are cancelled which
+             # will leave the socket open. To avoid this problem
+             # we pass a set to _connect_sock to keep track of
+             # the open sockets and close them here if there
+             # are any "runner up" sockets.
+             for s in open_sockets:
+                 if s is not sock:
+                     with contextlib.suppress(OSError):
+                         s.close()
+             open_sockets = None  # type: ignore[assignment]
+
+     if sock is None:
+         all_exceptions = [exc for sub in exceptions for exc in sub]
+         try:
+             first_exception = all_exceptions[0]
+             if len(all_exceptions) == 1:
+                 raise first_exception
+             else:
+                 # If they all have the same str(), raise one.
+                 model = str(first_exception)
+                 if all(str(exc) == model for exc in all_exceptions):
+                     raise first_exception
+                 # Raise a combined exception so the user can see all
+                 # the various error messages.
+                 msg = "Multiple exceptions: {}".format(
+                     ", ".join(str(exc) for exc in all_exceptions)
+                 )
+                 # If the errno is the same for all exceptions, raise
+                 # an OSError with that errno.
+                 if isinstance(first_exception, OSError):
+                     first_errno = first_exception.errno
+                     if all(
+                         isinstance(exc, OSError) and exc.errno == first_errno
+                         for exc in all_exceptions
+                     ):
+                         raise OSError(first_errno, msg)
+                 elif isinstance(first_exception, RuntimeError) and all(
+                     isinstance(exc, RuntimeError) for exc in all_exceptions
+                 ):
+                     raise RuntimeError(msg)
+                 # We have a mix of OSError and RuntimeError
+                 # so we have to pick which one to raise,
+                 # and we raise OSError for compatibility
+                 raise OSError(msg)
+         finally:
+             all_exceptions = None  # type: ignore[assignment]
+             exceptions = None  # type: ignore[assignment]
+
+     return sock
+
+
+ async def _connect_sock(
+     loop: asyncio.AbstractEventLoop,
+     exceptions: List[List[Union[OSError, RuntimeError]]],
+     addr_info: AddrInfoType,
+     local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
+     open_sockets: Optional[Set[socket.socket]] = None,
+     socket_factory: Optional[SocketFactoryType] = None,
+ ) -> socket.socket:
+     """
+     Create, bind and connect one socket.
+
+     If open_sockets is passed, add the socket to the set of open sockets.
+     Any failure caught here will remove the socket from the set and close it.
+
+     Callers can use this set to close any sockets that are not the winner
+     of the staggered tasks, in the event there are runner-up sockets, i.e.
+     multiple winners.
+     """
+     my_exceptions: List[Union[OSError, RuntimeError]] = []
+     exceptions.append(my_exceptions)
+     family, type_, proto, _, address = addr_info
+     sock = None
+     try:
+         if socket_factory is not None:
+             sock = socket_factory(addr_info)
+         else:
+             sock = socket.socket(family=family, type=type_, proto=proto)
+         if open_sockets is not None:
+             open_sockets.add(sock)
+         sock.setblocking(False)
+         if local_addr_infos is not None:
+             for lfamily, _, _, _, laddr in local_addr_infos:
+                 # skip local addresses of different family
+                 if lfamily != family:
+                     continue
+                 try:
+                     sock.bind(laddr)
+                     break
+                 except OSError as exc:
+                     msg = (
+                         f"error while attempting to bind on "
+                         f"address {laddr!r}: "
+                         f"{(exc.strerror or '').lower()}"
+                     )
+                     exc = OSError(exc.errno, msg)
+                     my_exceptions.append(exc)
+             else:  # all bind attempts failed
+                 if my_exceptions:
+                     raise my_exceptions.pop()
+                 else:
+                     raise OSError(f"no matching local address with {family=} found")
+         await loop.sock_connect(sock, address)
+         return sock
+     except (RuntimeError, OSError) as exc:
+         my_exceptions.append(exc)
+         if sock is not None:
+             if open_sockets is not None:
+                 open_sockets.remove(sock)
+             try:
+                 sock.close()
+             except OSError as e:
+                 my_exceptions.append(e)
+                 raise
+         raise
+     except BaseException:
+         if sock is not None:
+             if open_sockets is not None:
+                 open_sockets.remove(sock)
+             try:
+                 sock.close()
+             except OSError as e:
+                 my_exceptions.append(e)
+                 raise
+         raise
+     finally:
+         exceptions = my_exceptions = None  # type: ignore[assignment]
+
+
+ def _interleave_addrinfos(
+     addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
+ ) -> List[AddrInfoType]:
+     """Interleave list of addrinfo tuples by family."""
+     # Group addresses by family
+     addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
+         collections.OrderedDict()
+     )
+     for addr in addrinfos:
+         family = addr[0]
+         if family not in addrinfos_by_family:
+             addrinfos_by_family[family] = []
+         addrinfos_by_family[family].append(addr)
+     addrinfos_lists = list(addrinfos_by_family.values())
+
+     reordered: List[AddrInfoType] = []
+     if first_address_family_count > 1:
+         reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
+         del addrinfos_lists[0][: first_address_family_count - 1]
+     reordered.extend(
+         a
+         for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
+         if a is not None
+     )
+     return reordered
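
The `socket_factory` parameter (typed as `SocketFactoryType` in `types.py` below) lets callers customize socket creation for each candidate address. A sketch that sets `TCP_NODELAY` on every candidate socket before the race starts; `_connect_sock` still makes the socket non-blocking afterwards:

```python
import socket

from aiohappyeyeballs import start_connection
from aiohappyeyeballs.types import AddrInfoType


def nodelay_factory(addr_info: AddrInfoType) -> socket.socket:
    family, type_, proto, _, _ = addr_info
    sock = socket.socket(family=family, type=type_, proto=proto)
    # Disable Nagle's algorithm on every candidate socket.
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return sock


# Inside a coroutine, with addr_infos from loop.getaddrinfo():
#     sock = await start_connection(addr_infos, socket_factory=nodelay_factory)
```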
venv/Lib/site-packages/aiohappyeyeballs/py.typed ADDED
File without changes
venv/Lib/site-packages/aiohappyeyeballs/types.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Types for aiohappyeyeballs."""
2
+
3
+ import socket
4
+
5
+ # PY3.9: Import Callable from typing until we drop Python 3.9 support
6
+ # https://github.com/python/cpython/issues/87131
7
+ from typing import Callable, Tuple, Union
8
+
9
+ AddrInfoType = Tuple[
10
+ Union[int, socket.AddressFamily],
11
+ Union[int, socket.SocketKind],
12
+ int,
13
+ str,
14
+ Tuple, # type: ignore[type-arg]
15
+ ]
16
+
17
+ SocketFactoryType = Callable[[AddrInfoType], socket.socket]
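`SocketFactoryType` is the hook that lets a caller take over socket construction. A minimal sketch of a conforming factory (the function name and the keep-alive option are our own choices, not part of the library):

import socket

from aiohappyeyeballs.types import AddrInfoType

def keepalive_socket_factory(addr_info: AddrInfoType) -> socket.socket:
    family, type_, proto, _, _ = addr_info
    sock = socket.socket(family=family, type=type_, proto=proto)
    # Example of why a custom factory is useful: set options up front.
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    return sock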
venv/Lib/site-packages/aiohappyeyeballs/utils.py ADDED
@@ -0,0 +1,97 @@
1
+ """Utility functions for aiohappyeyeballs."""
2
+
3
+ import ipaddress
4
+ import socket
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ from .types import AddrInfoType
8
+
9
+
10
+ def addr_to_addr_infos(
11
+ addr: Optional[
12
+ Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]]
13
+ ],
14
+ ) -> Optional[List[AddrInfoType]]:
15
+ """Convert an address tuple to a list of addr_info tuples."""
16
+ if addr is None:
17
+ return None
18
+ host = addr[0]
19
+ port = addr[1]
20
+ is_ipv6 = ":" in host
21
+ if is_ipv6:
22
+ flowinfo = 0
23
+ scopeid = 0
24
+ addr_len = len(addr)
25
+ if addr_len >= 4:
26
+ scopeid = addr[3] # type: ignore[misc]
27
+ if addr_len >= 3:
28
+ flowinfo = addr[2] # type: ignore[misc]
29
+ addr = (host, port, flowinfo, scopeid)
30
+ family = socket.AF_INET6
31
+ else:
32
+ addr = (host, port)
33
+ family = socket.AF_INET
34
+ return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)]
35
+
36
+
37
+ def pop_addr_infos_interleave(
38
+ addr_infos: List[AddrInfoType], interleave: Optional[int] = None
39
+ ) -> None:
40
+ """
41
+ Pop addr_info from the list of addr_infos by family up to interleave times.
42
+
43
+ The interleave parameter specifies how many addr_infos for
44
+ each family should be popped off the top of the list.
45
+ """
46
+ seen: Dict[int, int] = {}
47
+ if interleave is None:
48
+ interleave = 1
49
+ to_remove: List[AddrInfoType] = []
50
+ for addr_info in addr_infos:
51
+ family = addr_info[0]
52
+ if family not in seen:
53
+ seen[family] = 0
54
+ if seen[family] < interleave:
55
+ to_remove.append(addr_info)
56
+ seen[family] += 1
57
+ for addr_info in to_remove:
58
+ addr_infos.remove(addr_info)
59
+
60
+
61
+ def _addr_tuple_to_ip_address(
62
+ addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
63
+ ) -> Union[
64
+ Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
65
+ ]:
66
+ """Convert an address tuple to an IPv4Address."""
67
+ return (ipaddress.ip_address(addr[0]), *addr[1:])
68
+
69
+
70
+ def remove_addr_infos(
71
+ addr_infos: List[AddrInfoType],
72
+ addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
73
+ ) -> None:
74
+ """
75
+ Remove an address from the list of addr_infos.
76
+
77
+ The addr value is typically the return value of
78
+ sock.getpeername().
79
+ """
80
+ bad_addrs_infos: List[AddrInfoType] = []
81
+ for addr_info in addr_infos:
82
+ if addr_info[-1] == addr:
83
+ bad_addrs_infos.append(addr_info)
84
+ if bad_addrs_infos:
85
+ for bad_addr_info in bad_addrs_infos:
86
+ addr_infos.remove(bad_addr_info)
87
+ return
88
+ # Slow path in case addr is formatted differently
89
+ match_addr = _addr_tuple_to_ip_address(addr)
90
+ for addr_info in addr_infos:
91
+ if match_addr == _addr_tuple_to_ip_address(addr_info[-1]):
92
+ bad_addrs_infos.append(addr_info)
93
+ if bad_addrs_infos:
94
+ for bad_addr_info in bad_addrs_infos:
95
+ addr_infos.remove(bad_addr_info)
96
+ return
97
+ raise ValueError(f"Address {addr} not found in addr_infos")
venv/Lib/site-packages/aiohttp/abc.py ADDED
@@ -0,0 +1,253 @@
1
+ import asyncio
2
+ import logging
3
+ import socket
4
+ import zlib
5
+ from abc import ABC, abstractmethod
6
+ from collections.abc import Sized
7
+ from http.cookies import BaseCookie, Morsel
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ Dict,
14
+ Generator,
15
+ Iterable,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ TypedDict,
20
+ Union,
21
+ )
22
+
23
+ from multidict import CIMultiDict
24
+ from yarl import URL
25
+
26
+ from .typedefs import LooseCookies
27
+
28
+ if TYPE_CHECKING:
29
+ from .web_app import Application
30
+ from .web_exceptions import HTTPException
31
+ from .web_request import BaseRequest, Request
32
+ from .web_response import StreamResponse
33
+ else:
34
+ BaseRequest = Request = Application = StreamResponse = None
35
+ HTTPException = None
36
+
37
+
38
+ class AbstractRouter(ABC):
39
+ def __init__(self) -> None:
40
+ self._frozen = False
41
+
42
+ def post_init(self, app: Application) -> None:
43
+ """Post init stage.
44
+
45
+ Not an abstract method for the sake of backward compatibility,
46
+ but if the router wants to be aware of the application
47
+ it can override this.
48
+ """
49
+
50
+ @property
51
+ def frozen(self) -> bool:
52
+ return self._frozen
53
+
54
+ def freeze(self) -> None:
55
+ """Freeze router."""
56
+ self._frozen = True
57
+
58
+ @abstractmethod
59
+ async def resolve(self, request: Request) -> "AbstractMatchInfo":
60
+ """Return MATCH_INFO for given request"""
61
+
62
+
63
+ class AbstractMatchInfo(ABC):
64
+
65
+ __slots__ = ()
66
+
67
+ @property # pragma: no branch
68
+ @abstractmethod
69
+ def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
70
+ """Execute matched request handler"""
71
+
72
+ @property
73
+ @abstractmethod
74
+ def expect_handler(
75
+ self,
76
+ ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
77
+ """Expect handler for 100-continue processing"""
78
+
79
+ @property # pragma: no branch
80
+ @abstractmethod
81
+ def http_exception(self) -> Optional[HTTPException]:
82
+ """HTTPException instance raised on router's resolving, or None"""
83
+
84
+ @abstractmethod # pragma: no branch
85
+ def get_info(self) -> Dict[str, Any]:
86
+ """Return a dict with additional info useful for introspection"""
87
+
88
+ @property # pragma: no branch
89
+ @abstractmethod
90
+ def apps(self) -> Tuple[Application, ...]:
91
+ """Stack of nested applications.
92
+
93
+ Top level application is left-most element.
94
+
95
+ """
96
+
97
+ @abstractmethod
98
+ def add_app(self, app: Application) -> None:
99
+ """Add application to the nested apps stack."""
100
+
101
+ @abstractmethod
102
+ def freeze(self) -> None:
103
+ """Freeze the match info.
104
+
105
+ The method is called after route resolution.
106
+
107
+ After the call .add_app() is forbidden.
108
+
109
+ """
110
+
111
+
112
+ class AbstractView(ABC):
113
+ """Abstract class based view."""
114
+
115
+ def __init__(self, request: Request) -> None:
116
+ self._request = request
117
+
118
+ @property
119
+ def request(self) -> Request:
120
+ """Request instance."""
121
+ return self._request
122
+
123
+ @abstractmethod
124
+ def __await__(self) -> Generator[Any, None, StreamResponse]:
125
+ """Execute the view handler."""
126
+
127
+
128
+ class ResolveResult(TypedDict):
129
+ """Resolve result.
130
+
131
+ This is the result returned from an AbstractResolver's
132
+ resolve method.
133
+
134
+ :param hostname: The hostname that was provided.
135
+ :param host: The IP address that was resolved.
136
+ :param port: The port that was resolved.
137
+ :param family: The address family that was resolved.
138
+ :param proto: The protocol that was resolved.
139
+ :param flags: The flags that were resolved.
140
+ """
141
+
142
+ hostname: str
143
+ host: str
144
+ port: int
145
+ family: int
146
+ proto: int
147
+ flags: int
148
+
149
+
150
+ class AbstractResolver(ABC):
151
+ """Abstract DNS resolver."""
152
+
153
+ @abstractmethod
154
+ async def resolve(
155
+ self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
156
+ ) -> List[ResolveResult]:
157
+ """Return IP address for given hostname"""
158
+
159
+ @abstractmethod
160
+ async def close(self) -> None:
161
+ """Release resolver"""
162
+
163
+
164
+ if TYPE_CHECKING:
165
+ IterableBase = Iterable[Morsel[str]]
166
+ else:
167
+ IterableBase = Iterable
168
+
169
+
170
+ ClearCookiePredicate = Callable[["Morsel[str]"], bool]
171
+
172
+
173
+ class AbstractCookieJar(Sized, IterableBase):
174
+ """Abstract Cookie Jar."""
175
+
176
+ def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
177
+ self._loop = loop or asyncio.get_running_loop()
178
+
179
+ @property
180
+ @abstractmethod
181
+ def quote_cookie(self) -> bool:
182
+ """Return True if cookies should be quoted."""
183
+
184
+ @abstractmethod
185
+ def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
186
+ """Clear all cookies if no predicate is passed."""
187
+
188
+ @abstractmethod
189
+ def clear_domain(self, domain: str) -> None:
190
+ """Clear all cookies for domain and all subdomains."""
191
+
192
+ @abstractmethod
193
+ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
194
+ """Update cookies."""
195
+
196
+ @abstractmethod
197
+ def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
198
+ """Return the jar's cookies filtered by their attributes."""
199
+
200
+
201
+ class AbstractStreamWriter(ABC):
202
+ """Abstract stream writer."""
203
+
204
+ buffer_size: int = 0
205
+ output_size: int = 0
206
+ length: Optional[int] = 0
207
+
208
+ @abstractmethod
209
+ async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
210
+ """Write chunk into stream."""
211
+
212
+ @abstractmethod
213
+ async def write_eof(self, chunk: bytes = b"") -> None:
214
+ """Write last chunk."""
215
+
216
+ @abstractmethod
217
+ async def drain(self) -> None:
218
+ """Flush the write buffer."""
219
+
220
+ @abstractmethod
221
+ def enable_compression(
222
+ self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
223
+ ) -> None:
224
+ """Enable HTTP body compression"""
225
+
226
+ @abstractmethod
227
+ def enable_chunking(self) -> None:
228
+ """Enable HTTP chunked mode"""
229
+
230
+ @abstractmethod
231
+ async def write_headers(
232
+ self, status_line: str, headers: "CIMultiDict[str]"
233
+ ) -> None:
234
+ """Write HTTP headers"""
235
+
236
+
237
+ class AbstractAccessLogger(ABC):
238
+ """Abstract writer to access log."""
239
+
240
+ __slots__ = ("logger", "log_format")
241
+
242
+ def __init__(self, logger: logging.Logger, log_format: str) -> None:
243
+ self.logger = logger
244
+ self.log_format = log_format
245
+
246
+ @abstractmethod
247
+ def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
248
+ """Emit log to logger."""
249
+
250
+ @property
251
+ def enabled(self) -> bool:
252
+ """Check if logger is enabled."""
253
+ return True
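A minimal sketch of a concrete access logger built on the ABC above (the class name and format string are ours; `request.method`, `request.path` and `response.status` are standard aiohttp request/response attributes):

from aiohttp.abc import AbstractAccessLogger

class PlainAccessLogger(AbstractAccessLogger):
    def log(self, request, response, time):
        # self.logger and self.log_format are set by the base __init__.
        self.logger.info(
            "%s %s -> %s in %.3fs",
            request.method,
            request.path,
            response.status,
            time,
        )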
venv/Lib/site-packages/aiohttp/base_protocol.py ADDED
@@ -0,0 +1,100 @@
1
+ import asyncio
2
+ from typing import Optional, cast
3
+
4
+ from .client_exceptions import ClientConnectionResetError
5
+ from .helpers import set_exception
6
+ from .tcp_helpers import tcp_nodelay
7
+
8
+
9
+ class BaseProtocol(asyncio.Protocol):
10
+ __slots__ = (
11
+ "_loop",
12
+ "_paused",
13
+ "_drain_waiter",
14
+ "_connection_lost",
15
+ "_reading_paused",
16
+ "transport",
17
+ )
18
+
19
+ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
20
+ self._loop: asyncio.AbstractEventLoop = loop
21
+ self._paused = False
22
+ self._drain_waiter: Optional[asyncio.Future[None]] = None
23
+ self._reading_paused = False
24
+
25
+ self.transport: Optional[asyncio.Transport] = None
26
+
27
+ @property
28
+ def connected(self) -> bool:
29
+ """Return True if the connection is open."""
30
+ return self.transport is not None
31
+
32
+ @property
33
+ def writing_paused(self) -> bool:
34
+ return self._paused
35
+
36
+ def pause_writing(self) -> None:
37
+ assert not self._paused
38
+ self._paused = True
39
+
40
+ def resume_writing(self) -> None:
41
+ assert self._paused
42
+ self._paused = False
43
+
44
+ waiter = self._drain_waiter
45
+ if waiter is not None:
46
+ self._drain_waiter = None
47
+ if not waiter.done():
48
+ waiter.set_result(None)
49
+
50
+ def pause_reading(self) -> None:
51
+ if not self._reading_paused and self.transport is not None:
52
+ try:
53
+ self.transport.pause_reading()
54
+ except (AttributeError, NotImplementedError, RuntimeError):
55
+ pass
56
+ self._reading_paused = True
57
+
58
+ def resume_reading(self) -> None:
59
+ if self._reading_paused and self.transport is not None:
60
+ try:
61
+ self.transport.resume_reading()
62
+ except (AttributeError, NotImplementedError, RuntimeError):
63
+ pass
64
+ self._reading_paused = False
65
+
66
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
67
+ tr = cast(asyncio.Transport, transport)
68
+ tcp_nodelay(tr, True)
69
+ self.transport = tr
70
+
71
+ def connection_lost(self, exc: Optional[BaseException]) -> None:
72
+ # Wake up the writer if currently paused.
73
+ self.transport = None
74
+ if not self._paused:
75
+ return
76
+ waiter = self._drain_waiter
77
+ if waiter is None:
78
+ return
79
+ self._drain_waiter = None
80
+ if waiter.done():
81
+ return
82
+ if exc is None:
83
+ waiter.set_result(None)
84
+ else:
85
+ set_exception(
86
+ waiter,
87
+ ConnectionError("Connection lost"),
88
+ exc,
89
+ )
90
+
91
+ async def _drain_helper(self) -> None:
92
+ if self.transport is None:
93
+ raise ClientConnectionResetError("Connection lost")
94
+ if not self._paused:
95
+ return
96
+ waiter = self._drain_waiter
97
+ if waiter is None:
98
+ waiter = self._loop.create_future()
99
+ self._drain_waiter = waiter
100
+ await asyncio.shield(waiter)
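`_drain_helper` is the flow-control rendezvous: `pause_writing` sets `_paused`, a writer that awaits the helper parks on `_drain_waiter`, and `resume_writing` (or `connection_lost`) resolves the future. A hedged sketch of the calling pattern (`write_and_drain` is our own name, and `_drain_helper` is private API, shown for illustration only):

from aiohttp.base_protocol import BaseProtocol

async def write_and_drain(proto: BaseProtocol, data: bytes) -> None:
    assert proto.transport is not None
    proto.transport.write(data)
    # Suspends only while the transport buffer is over the high-water mark.
    await proto._drain_helper()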
venv/Lib/site-packages/scipy-1.15.3-cp312-cp312-win_amd64.whl ADDED
File without changes
venv/Lib/site-packages/six.py ADDED
@@ -0,0 +1,1003 @@
1
+ # Copyright (c) 2010-2024 Benjamin Peterson
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ """Utilities for writing code that runs on Python 2 and 3"""
22
+
23
+ from __future__ import absolute_import
24
+
25
+ import functools
26
+ import itertools
27
+ import operator
28
+ import sys
29
+ import types
30
+
31
+ __author__ = "Benjamin Peterson <[email protected]>"
32
+ __version__ = "1.17.0"
33
+
34
+
35
+ # Useful for very coarse version differentiation.
36
+ PY2 = sys.version_info[0] == 2
37
+ PY3 = sys.version_info[0] == 3
38
+ PY34 = sys.version_info[0:2] >= (3, 4)
39
+
40
+ if PY3:
41
+ string_types = str,
42
+ integer_types = int,
43
+ class_types = type,
44
+ text_type = str
45
+ binary_type = bytes
46
+
47
+ MAXSIZE = sys.maxsize
48
+ else:
49
+ string_types = basestring,
50
+ integer_types = (int, long)
51
+ class_types = (type, types.ClassType)
52
+ text_type = unicode
53
+ binary_type = str
54
+
55
+ if sys.platform.startswith("java"):
56
+ # Jython always uses 32 bits.
57
+ MAXSIZE = int((1 << 31) - 1)
58
+ else:
59
+ # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
60
+ class X(object):
61
+
62
+ def __len__(self):
63
+ return 1 << 31
64
+ try:
65
+ len(X())
66
+ except OverflowError:
67
+ # 32-bit
68
+ MAXSIZE = int((1 << 31) - 1)
69
+ else:
70
+ # 64-bit
71
+ MAXSIZE = int((1 << 63) - 1)
72
+ del X
73
+
74
+ if PY34:
75
+ from importlib.util import spec_from_loader
76
+ else:
77
+ spec_from_loader = None
78
+
79
+
80
+ def _add_doc(func, doc):
81
+ """Add documentation to a function."""
82
+ func.__doc__ = doc
83
+
84
+
85
+ def _import_module(name):
86
+ """Import module, returning the module after the last dot."""
87
+ __import__(name)
88
+ return sys.modules[name]
89
+
90
+
91
+ class _LazyDescr(object):
92
+
93
+ def __init__(self, name):
94
+ self.name = name
95
+
96
+ def __get__(self, obj, tp):
97
+ result = self._resolve()
98
+ setattr(obj, self.name, result) # Invokes __set__.
99
+ try:
100
+ # This is a bit ugly, but it avoids running this again by
101
+ # removing this descriptor.
102
+ delattr(obj.__class__, self.name)
103
+ except AttributeError:
104
+ pass
105
+ return result
106
+
107
+
108
+ class MovedModule(_LazyDescr):
109
+
110
+ def __init__(self, name, old, new=None):
111
+ super(MovedModule, self).__init__(name)
112
+ if PY3:
113
+ if new is None:
114
+ new = name
115
+ self.mod = new
116
+ else:
117
+ self.mod = old
118
+
119
+ def _resolve(self):
120
+ return _import_module(self.mod)
121
+
122
+ def __getattr__(self, attr):
123
+ _module = self._resolve()
124
+ value = getattr(_module, attr)
125
+ setattr(self, attr, value)
126
+ return value
127
+
128
+
129
+ class _LazyModule(types.ModuleType):
130
+
131
+ def __init__(self, name):
132
+ super(_LazyModule, self).__init__(name)
133
+ self.__doc__ = self.__class__.__doc__
134
+
135
+ def __dir__(self):
136
+ attrs = ["__doc__", "__name__"]
137
+ attrs += [attr.name for attr in self._moved_attributes]
138
+ return attrs
139
+
140
+ # Subclasses should override this
141
+ _moved_attributes = []
142
+
143
+
144
+ class MovedAttribute(_LazyDescr):
145
+
146
+ def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
147
+ super(MovedAttribute, self).__init__(name)
148
+ if PY3:
149
+ if new_mod is None:
150
+ new_mod = name
151
+ self.mod = new_mod
152
+ if new_attr is None:
153
+ if old_attr is None:
154
+ new_attr = name
155
+ else:
156
+ new_attr = old_attr
157
+ self.attr = new_attr
158
+ else:
159
+ self.mod = old_mod
160
+ if old_attr is None:
161
+ old_attr = name
162
+ self.attr = old_attr
163
+
164
+ def _resolve(self):
165
+ module = _import_module(self.mod)
166
+ return getattr(module, self.attr)
167
+
168
+
169
+ class _SixMetaPathImporter(object):
170
+
171
+ """
172
+ A meta path importer to import six.moves and its submodules.
173
+
174
+ This class implements a PEP302 finder and loader. It should be compatible
175
+ with Python 2.5 and all existing versions of Python3
176
+ """
177
+
178
+ def __init__(self, six_module_name):
179
+ self.name = six_module_name
180
+ self.known_modules = {}
181
+
182
+ def _add_module(self, mod, *fullnames):
183
+ for fullname in fullnames:
184
+ self.known_modules[self.name + "." + fullname] = mod
185
+
186
+ def _get_module(self, fullname):
187
+ return self.known_modules[self.name + "." + fullname]
188
+
189
+ def find_module(self, fullname, path=None):
190
+ if fullname in self.known_modules:
191
+ return self
192
+ return None
193
+
194
+ def find_spec(self, fullname, path, target=None):
195
+ if fullname in self.known_modules:
196
+ return spec_from_loader(fullname, self)
197
+ return None
198
+
199
+ def __get_module(self, fullname):
200
+ try:
201
+ return self.known_modules[fullname]
202
+ except KeyError:
203
+ raise ImportError("This loader does not know module " + fullname)
204
+
205
+ def load_module(self, fullname):
206
+ try:
207
+ # in case of a reload
208
+ return sys.modules[fullname]
209
+ except KeyError:
210
+ pass
211
+ mod = self.__get_module(fullname)
212
+ if isinstance(mod, MovedModule):
213
+ mod = mod._resolve()
214
+ else:
215
+ mod.__loader__ = self
216
+ sys.modules[fullname] = mod
217
+ return mod
218
+
219
+ def is_package(self, fullname):
220
+ """
221
+ Return true, if the named module is a package.
222
+
223
+ We need this method to get correct spec objects with
224
+ Python 3.4 (see PEP451)
225
+ """
226
+ return hasattr(self.__get_module(fullname), "__path__")
227
+
228
+ def get_code(self, fullname):
229
+ """Return None
230
+
231
+ Required, if is_package is implemented"""
232
+ self.__get_module(fullname) # eventually raises ImportError
233
+ return None
234
+ get_source = get_code # same as get_code
235
+
236
+ def create_module(self, spec):
237
+ return self.load_module(spec.name)
238
+
239
+ def exec_module(self, module):
240
+ pass
241
+
242
+ _importer = _SixMetaPathImporter(__name__)
243
+
244
+
245
+ class _MovedItems(_LazyModule):
246
+
247
+ """Lazy loading of moved objects"""
248
+ __path__ = [] # mark as package
249
+
250
+
251
+ _moved_attributes = [
252
+ MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
253
+ MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
254
+ MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
255
+ MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
256
+ MovedAttribute("intern", "__builtin__", "sys"),
257
+ MovedAttribute("map", "itertools", "builtins", "imap", "map"),
258
+ MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
259
+ MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
260
+ MovedAttribute("getoutput", "commands", "subprocess"),
261
+ MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
262
+ MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
263
+ MovedAttribute("reduce", "__builtin__", "functools"),
264
+ MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
265
+ MovedAttribute("StringIO", "StringIO", "io"),
266
+ MovedAttribute("UserDict", "UserDict", "collections", "IterableUserDict", "UserDict"),
267
+ MovedAttribute("UserList", "UserList", "collections"),
268
+ MovedAttribute("UserString", "UserString", "collections"),
269
+ MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
270
+ MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
271
+ MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
272
+ MovedModule("builtins", "__builtin__"),
273
+ MovedModule("configparser", "ConfigParser"),
274
+ MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"),
275
+ MovedModule("copyreg", "copy_reg"),
276
+ MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
277
+ MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"),
278
+ MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"),
279
+ MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
280
+ MovedModule("http_cookies", "Cookie", "http.cookies"),
281
+ MovedModule("html_entities", "htmlentitydefs", "html.entities"),
282
+ MovedModule("html_parser", "HTMLParser", "html.parser"),
283
+ MovedModule("http_client", "httplib", "http.client"),
284
+ MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
285
+ MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
286
+ MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
287
+ MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
288
+ MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
289
+ MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
290
+ MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
291
+ MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
292
+ MovedModule("cPickle", "cPickle", "pickle"),
293
+ MovedModule("queue", "Queue"),
294
+ MovedModule("reprlib", "repr"),
295
+ MovedModule("socketserver", "SocketServer"),
296
+ MovedModule("_thread", "thread", "_thread"),
297
+ MovedModule("tkinter", "Tkinter"),
298
+ MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
299
+ MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
300
+ MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
301
+ MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
302
+ MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
303
+ MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
304
+ MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
305
+ MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
306
+ MovedModule("tkinter_colorchooser", "tkColorChooser",
307
+ "tkinter.colorchooser"),
308
+ MovedModule("tkinter_commondialog", "tkCommonDialog",
309
+ "tkinter.commondialog"),
310
+ MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
311
+ MovedModule("tkinter_font", "tkFont", "tkinter.font"),
312
+ MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
313
+ MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
314
+ "tkinter.simpledialog"),
315
+ MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
316
+ MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
317
+ MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
318
+ MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
319
+ MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
320
+ MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
321
+ ]
322
+ # Add windows specific modules.
323
+ if sys.platform == "win32":
324
+ _moved_attributes += [
325
+ MovedModule("winreg", "_winreg"),
326
+ ]
327
+
328
+ for attr in _moved_attributes:
329
+ setattr(_MovedItems, attr.name, attr)
330
+ if isinstance(attr, MovedModule):
331
+ _importer._add_module(attr, "moves." + attr.name)
332
+ del attr
333
+
334
+ _MovedItems._moved_attributes = _moved_attributes
335
+
336
+ moves = _MovedItems(__name__ + ".moves")
337
+ _importer._add_module(moves, "moves")
338
+
339
+
340
+ class Module_six_moves_urllib_parse(_LazyModule):
341
+
342
+ """Lazy loading of moved objects in six.moves.urllib_parse"""
343
+
344
+
345
+ _urllib_parse_moved_attributes = [
346
+ MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
347
+ MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
348
+ MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
349
+ MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
350
+ MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
351
+ MovedAttribute("urljoin", "urlparse", "urllib.parse"),
352
+ MovedAttribute("urlparse", "urlparse", "urllib.parse"),
353
+ MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
354
+ MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
355
+ MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
356
+ MovedAttribute("quote", "urllib", "urllib.parse"),
357
+ MovedAttribute("quote_plus", "urllib", "urllib.parse"),
358
+ MovedAttribute("unquote", "urllib", "urllib.parse"),
359
+ MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
360
+ MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
361
+ MovedAttribute("urlencode", "urllib", "urllib.parse"),
362
+ MovedAttribute("splitquery", "urllib", "urllib.parse"),
363
+ MovedAttribute("splittag", "urllib", "urllib.parse"),
364
+ MovedAttribute("splituser", "urllib", "urllib.parse"),
365
+ MovedAttribute("splitvalue", "urllib", "urllib.parse"),
366
+ MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
367
+ MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
368
+ MovedAttribute("uses_params", "urlparse", "urllib.parse"),
369
+ MovedAttribute("uses_query", "urlparse", "urllib.parse"),
370
+ MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
371
+ ]
372
+ for attr in _urllib_parse_moved_attributes:
373
+ setattr(Module_six_moves_urllib_parse, attr.name, attr)
374
+ del attr
375
+
376
+ Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
377
+
378
+ _importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
379
+ "moves.urllib_parse", "moves.urllib.parse")
380
+
381
+
382
+ class Module_six_moves_urllib_error(_LazyModule):
383
+
384
+ """Lazy loading of moved objects in six.moves.urllib_error"""
385
+
386
+
387
+ _urllib_error_moved_attributes = [
388
+ MovedAttribute("URLError", "urllib2", "urllib.error"),
389
+ MovedAttribute("HTTPError", "urllib2", "urllib.error"),
390
+ MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
391
+ ]
392
+ for attr in _urllib_error_moved_attributes:
393
+ setattr(Module_six_moves_urllib_error, attr.name, attr)
394
+ del attr
395
+
396
+ Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
397
+
398
+ _importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
399
+ "moves.urllib_error", "moves.urllib.error")
400
+
401
+
402
+ class Module_six_moves_urllib_request(_LazyModule):
403
+
404
+ """Lazy loading of moved objects in six.moves.urllib_request"""
405
+
406
+
407
+ _urllib_request_moved_attributes = [
408
+ MovedAttribute("urlopen", "urllib2", "urllib.request"),
409
+ MovedAttribute("install_opener", "urllib2", "urllib.request"),
410
+ MovedAttribute("build_opener", "urllib2", "urllib.request"),
411
+ MovedAttribute("pathname2url", "urllib", "urllib.request"),
412
+ MovedAttribute("url2pathname", "urllib", "urllib.request"),
413
+ MovedAttribute("getproxies", "urllib", "urllib.request"),
414
+ MovedAttribute("Request", "urllib2", "urllib.request"),
415
+ MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
416
+ MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
417
+ MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
418
+ MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
419
+ MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
420
+ MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
421
+ MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
422
+ MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
423
+ MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
424
+ MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
425
+ MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
426
+ MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
427
+ MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
428
+ MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
429
+ MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
430
+ MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
431
+ MovedAttribute("FileHandler", "urllib2", "urllib.request"),
432
+ MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
433
+ MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
434
+ MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
435
+ MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
436
+ MovedAttribute("urlretrieve", "urllib", "urllib.request"),
437
+ MovedAttribute("urlcleanup", "urllib", "urllib.request"),
438
+ MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
439
+ MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
440
+ MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
441
+ ]
442
+ if sys.version_info[:2] < (3, 14):
443
+ _urllib_request_moved_attributes.extend(
444
+ [
445
+ MovedAttribute("URLopener", "urllib", "urllib.request"),
446
+ MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
447
+ ]
448
+ )
449
+ for attr in _urllib_request_moved_attributes:
450
+ setattr(Module_six_moves_urllib_request, attr.name, attr)
451
+ del attr
452
+
453
+ Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
454
+
455
+ _importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
456
+ "moves.urllib_request", "moves.urllib.request")
457
+
458
+
459
+ class Module_six_moves_urllib_response(_LazyModule):
460
+
461
+ """Lazy loading of moved objects in six.moves.urllib_response"""
462
+
463
+
464
+ _urllib_response_moved_attributes = [
465
+ MovedAttribute("addbase", "urllib", "urllib.response"),
466
+ MovedAttribute("addclosehook", "urllib", "urllib.response"),
467
+ MovedAttribute("addinfo", "urllib", "urllib.response"),
468
+ MovedAttribute("addinfourl", "urllib", "urllib.response"),
469
+ ]
470
+ for attr in _urllib_response_moved_attributes:
471
+ setattr(Module_six_moves_urllib_response, attr.name, attr)
472
+ del attr
473
+
474
+ Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
475
+
476
+ _importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
477
+ "moves.urllib_response", "moves.urllib.response")
478
+
479
+
480
+ class Module_six_moves_urllib_robotparser(_LazyModule):
481
+
482
+ """Lazy loading of moved objects in six.moves.urllib_robotparser"""
483
+
484
+
485
+ _urllib_robotparser_moved_attributes = [
486
+ MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
487
+ ]
488
+ for attr in _urllib_robotparser_moved_attributes:
489
+ setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
490
+ del attr
491
+
492
+ Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
493
+
494
+ _importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
495
+ "moves.urllib_robotparser", "moves.urllib.robotparser")
496
+
497
+
498
+ class Module_six_moves_urllib(types.ModuleType):
499
+
500
+ """Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
501
+ __path__ = [] # mark as package
502
+ parse = _importer._get_module("moves.urllib_parse")
503
+ error = _importer._get_module("moves.urllib_error")
504
+ request = _importer._get_module("moves.urllib_request")
505
+ response = _importer._get_module("moves.urllib_response")
506
+ robotparser = _importer._get_module("moves.urllib_robotparser")
507
+
508
+ def __dir__(self):
509
+ return ['parse', 'error', 'request', 'response', 'robotparser']
510
+
511
+ _importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
512
+ "moves.urllib")
513
+
514
+
515
+ def add_move(move):
516
+ """Add an item to six.moves."""
517
+ setattr(_MovedItems, move.name, move)
518
+
519
+
520
+ def remove_move(name):
521
+ """Remove item from six.moves."""
522
+ try:
523
+ delattr(_MovedItems, name)
524
+ except AttributeError:
525
+ try:
526
+ del moves.__dict__[name]
527
+ except KeyError:
528
+ raise AttributeError("no such move, %r" % (name,))
529
+
530
+
531
+ if PY3:
532
+ _meth_func = "__func__"
533
+ _meth_self = "__self__"
534
+
535
+ _func_closure = "__closure__"
536
+ _func_code = "__code__"
537
+ _func_defaults = "__defaults__"
538
+ _func_globals = "__globals__"
539
+ else:
540
+ _meth_func = "im_func"
541
+ _meth_self = "im_self"
542
+
543
+ _func_closure = "func_closure"
544
+ _func_code = "func_code"
545
+ _func_defaults = "func_defaults"
546
+ _func_globals = "func_globals"
547
+
548
+
549
+ try:
550
+ advance_iterator = next
551
+ except NameError:
552
+ def advance_iterator(it):
553
+ return it.next()
554
+ next = advance_iterator
555
+
556
+
557
+ try:
558
+ callable = callable
559
+ except NameError:
560
+ def callable(obj):
561
+ return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
562
+
563
+
564
+ if PY3:
565
+ def get_unbound_function(unbound):
566
+ return unbound
567
+
568
+ create_bound_method = types.MethodType
569
+
570
+ def create_unbound_method(func, cls):
571
+ return func
572
+
573
+ Iterator = object
574
+ else:
575
+ def get_unbound_function(unbound):
576
+ return unbound.im_func
577
+
578
+ def create_bound_method(func, obj):
579
+ return types.MethodType(func, obj, obj.__class__)
580
+
581
+ def create_unbound_method(func, cls):
582
+ return types.MethodType(func, None, cls)
583
+
584
+ class Iterator(object):
585
+
586
+ def next(self):
587
+ return type(self).__next__(self)
588
+
589
+ callable = callable
590
+ _add_doc(get_unbound_function,
591
+ """Get the function out of a possibly unbound function""")
592
+
593
+
594
+ get_method_function = operator.attrgetter(_meth_func)
595
+ get_method_self = operator.attrgetter(_meth_self)
596
+ get_function_closure = operator.attrgetter(_func_closure)
597
+ get_function_code = operator.attrgetter(_func_code)
598
+ get_function_defaults = operator.attrgetter(_func_defaults)
599
+ get_function_globals = operator.attrgetter(_func_globals)
600
+
601
+
602
+ if PY3:
603
+ def iterkeys(d, **kw):
604
+ return iter(d.keys(**kw))
605
+
606
+ def itervalues(d, **kw):
607
+ return iter(d.values(**kw))
608
+
609
+ def iteritems(d, **kw):
610
+ return iter(d.items(**kw))
611
+
612
+ def iterlists(d, **kw):
613
+ return iter(d.lists(**kw))
614
+
615
+ viewkeys = operator.methodcaller("keys")
616
+
617
+ viewvalues = operator.methodcaller("values")
618
+
619
+ viewitems = operator.methodcaller("items")
620
+ else:
621
+ def iterkeys(d, **kw):
622
+ return d.iterkeys(**kw)
623
+
624
+ def itervalues(d, **kw):
625
+ return d.itervalues(**kw)
626
+
627
+ def iteritems(d, **kw):
628
+ return d.iteritems(**kw)
629
+
630
+ def iterlists(d, **kw):
631
+ return d.iterlists(**kw)
632
+
633
+ viewkeys = operator.methodcaller("viewkeys")
634
+
635
+ viewvalues = operator.methodcaller("viewvalues")
636
+
637
+ viewitems = operator.methodcaller("viewitems")
638
+
639
+ _add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
640
+ _add_doc(itervalues, "Return an iterator over the values of a dictionary.")
641
+ _add_doc(iteritems,
642
+ "Return an iterator over the (key, value) pairs of a dictionary.")
643
+ _add_doc(iterlists,
644
+ "Return an iterator over the (key, [values]) pairs of a dictionary.")
645
+
646
+
647
+ if PY3:
648
+ def b(s):
649
+ return s.encode("latin-1")
650
+
651
+ def u(s):
652
+ return s
653
+ unichr = chr
654
+ import struct
655
+ int2byte = struct.Struct(">B").pack
656
+ del struct
657
+ byte2int = operator.itemgetter(0)
658
+ indexbytes = operator.getitem
659
+ iterbytes = iter
660
+ import io
661
+ StringIO = io.StringIO
662
+ BytesIO = io.BytesIO
663
+ del io
664
+ _assertCountEqual = "assertCountEqual"
665
+ if sys.version_info[1] <= 1:
666
+ _assertRaisesRegex = "assertRaisesRegexp"
667
+ _assertRegex = "assertRegexpMatches"
668
+ _assertNotRegex = "assertNotRegexpMatches"
669
+ else:
670
+ _assertRaisesRegex = "assertRaisesRegex"
671
+ _assertRegex = "assertRegex"
672
+ _assertNotRegex = "assertNotRegex"
673
+ else:
674
+ def b(s):
675
+ return s
676
+ # Workaround for standalone backslash
677
+
678
+ def u(s):
679
+ return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
680
+ unichr = unichr
681
+ int2byte = chr
682
+
683
+ def byte2int(bs):
684
+ return ord(bs[0])
685
+
686
+ def indexbytes(buf, i):
687
+ return ord(buf[i])
688
+ iterbytes = functools.partial(itertools.imap, ord)
689
+ import StringIO
690
+ StringIO = BytesIO = StringIO.StringIO
691
+ _assertCountEqual = "assertItemsEqual"
692
+ _assertRaisesRegex = "assertRaisesRegexp"
693
+ _assertRegex = "assertRegexpMatches"
694
+ _assertNotRegex = "assertNotRegexpMatches"
695
+ _add_doc(b, """Byte literal""")
696
+ _add_doc(u, """Text literal""")
697
+
698
+
699
+ def assertCountEqual(self, *args, **kwargs):
700
+ return getattr(self, _assertCountEqual)(*args, **kwargs)
701
+
702
+
703
+ def assertRaisesRegex(self, *args, **kwargs):
704
+ return getattr(self, _assertRaisesRegex)(*args, **kwargs)
705
+
706
+
707
+ def assertRegex(self, *args, **kwargs):
708
+ return getattr(self, _assertRegex)(*args, **kwargs)
709
+
710
+
711
+ def assertNotRegex(self, *args, **kwargs):
712
+ return getattr(self, _assertNotRegex)(*args, **kwargs)
713
+
714
+
715
+ if PY3:
716
+ exec_ = getattr(moves.builtins, "exec")
717
+
718
+ def reraise(tp, value, tb=None):
719
+ try:
720
+ if value is None:
721
+ value = tp()
722
+ if value.__traceback__ is not tb:
723
+ raise value.with_traceback(tb)
724
+ raise value
725
+ finally:
726
+ value = None
727
+ tb = None
728
+
729
+ else:
730
+ def exec_(_code_, _globs_=None, _locs_=None):
731
+ """Execute code in a namespace."""
732
+ if _globs_ is None:
733
+ frame = sys._getframe(1)
734
+ _globs_ = frame.f_globals
735
+ if _locs_ is None:
736
+ _locs_ = frame.f_locals
737
+ del frame
738
+ elif _locs_ is None:
739
+ _locs_ = _globs_
740
+ exec("""exec _code_ in _globs_, _locs_""")
741
+
742
+ exec_("""def reraise(tp, value, tb=None):
743
+ try:
744
+ raise tp, value, tb
745
+ finally:
746
+ tb = None
747
+ """)
748
+
749
+
750
+ if sys.version_info[:2] > (3,):
751
+ exec_("""def raise_from(value, from_value):
752
+ try:
753
+ raise value from from_value
754
+ finally:
755
+ value = None
756
+ """)
757
+ else:
758
+ def raise_from(value, from_value):
759
+ raise value
760
+
761
+
762
+ print_ = getattr(moves.builtins, "print", None)
763
+ if print_ is None:
764
+ def print_(*args, **kwargs):
765
+ """The new-style print function for Python 2.4 and 2.5."""
766
+ fp = kwargs.pop("file", sys.stdout)
767
+ if fp is None:
768
+ return
769
+
770
+ def write(data):
771
+ if not isinstance(data, basestring):
772
+ data = str(data)
773
+ # If the file has an encoding, encode unicode with it.
774
+ if (isinstance(fp, file) and
775
+ isinstance(data, unicode) and
776
+ fp.encoding is not None):
777
+ errors = getattr(fp, "errors", None)
778
+ if errors is None:
779
+ errors = "strict"
780
+ data = data.encode(fp.encoding, errors)
781
+ fp.write(data)
782
+ want_unicode = False
783
+ sep = kwargs.pop("sep", None)
784
+ if sep is not None:
785
+ if isinstance(sep, unicode):
786
+ want_unicode = True
787
+ elif not isinstance(sep, str):
788
+ raise TypeError("sep must be None or a string")
789
+ end = kwargs.pop("end", None)
790
+ if end is not None:
791
+ if isinstance(end, unicode):
792
+ want_unicode = True
793
+ elif not isinstance(end, str):
794
+ raise TypeError("end must be None or a string")
795
+ if kwargs:
796
+ raise TypeError("invalid keyword arguments to print()")
797
+ if not want_unicode:
798
+ for arg in args:
799
+ if isinstance(arg, unicode):
800
+ want_unicode = True
801
+ break
802
+ if want_unicode:
803
+ newline = unicode("\n")
804
+ space = unicode(" ")
805
+ else:
806
+ newline = "\n"
807
+ space = " "
808
+ if sep is None:
809
+ sep = space
810
+ if end is None:
811
+ end = newline
812
+ for i, arg in enumerate(args):
813
+ if i:
814
+ write(sep)
815
+ write(arg)
816
+ write(end)
817
+ if sys.version_info[:2] < (3, 3):
818
+ _print = print_
819
+
820
+ def print_(*args, **kwargs):
821
+ fp = kwargs.get("file", sys.stdout)
822
+ flush = kwargs.pop("flush", False)
823
+ _print(*args, **kwargs)
824
+ if flush and fp is not None:
825
+ fp.flush()
826
+
827
+ _add_doc(reraise, """Reraise an exception.""")
828
+
829
+ if sys.version_info[0:2] < (3, 4):
830
+ # This does exactly the same what the :func:`py3:functools.update_wrapper`
831
+ # function does on Python versions after 3.2. It sets the ``__wrapped__``
832
+ # attribute on ``wrapper`` object and it doesn't raise an error if any of
833
+ # the attributes mentioned in ``assigned`` and ``updated`` are missing on
834
+ # ``wrapped`` object.
835
+ def _update_wrapper(wrapper, wrapped,
836
+ assigned=functools.WRAPPER_ASSIGNMENTS,
837
+ updated=functools.WRAPPER_UPDATES):
838
+ for attr in assigned:
839
+ try:
840
+ value = getattr(wrapped, attr)
841
+ except AttributeError:
842
+ continue
843
+ else:
844
+ setattr(wrapper, attr, value)
845
+ for attr in updated:
846
+ getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
847
+ wrapper.__wrapped__ = wrapped
848
+ return wrapper
849
+ _update_wrapper.__doc__ = functools.update_wrapper.__doc__
850
+
851
+ def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
852
+ updated=functools.WRAPPER_UPDATES):
853
+ return functools.partial(_update_wrapper, wrapped=wrapped,
854
+ assigned=assigned, updated=updated)
855
+ wraps.__doc__ = functools.wraps.__doc__
856
+
857
+ else:
858
+ wraps = functools.wraps
859
+
860
+
861
+ def with_metaclass(meta, *bases):
862
+ """Create a base class with a metaclass."""
863
+ # This requires a bit of explanation: the basic idea is to make a dummy
864
+ # metaclass for one level of class instantiation that replaces itself with
865
+ # the actual metaclass.
866
+ class metaclass(type):
867
+
868
+ def __new__(cls, name, this_bases, d):
869
+ if sys.version_info[:2] >= (3, 7):
870
+ # This version introduced PEP 560 that requires a bit
871
+ # of extra care (we mimic what is done by __build_class__).
872
+ resolved_bases = types.resolve_bases(bases)
873
+ if resolved_bases is not bases:
874
+ d['__orig_bases__'] = bases
875
+ else:
876
+ resolved_bases = bases
877
+ return meta(name, resolved_bases, d)
878
+
879
+ @classmethod
880
+ def __prepare__(cls, name, this_bases):
881
+ return meta.__prepare__(name, bases)
882
+ return type.__new__(metaclass, 'temporary_class', (), {})
883
+
884
+
885
+ def add_metaclass(metaclass):
886
+ """Class decorator for creating a class with a metaclass."""
887
+ def wrapper(cls):
888
+ orig_vars = cls.__dict__.copy()
889
+ slots = orig_vars.get('__slots__')
890
+ if slots is not None:
891
+ if isinstance(slots, str):
892
+ slots = [slots]
893
+ for slots_var in slots:
894
+ orig_vars.pop(slots_var)
895
+ orig_vars.pop('__dict__', None)
896
+ orig_vars.pop('__weakref__', None)
897
+ if hasattr(cls, '__qualname__'):
898
+ orig_vars['__qualname__'] = cls.__qualname__
899
+ return metaclass(cls.__name__, cls.__bases__, orig_vars)
900
+ return wrapper
901
+
902
+
903
+ def ensure_binary(s, encoding='utf-8', errors='strict'):
904
+ """Coerce **s** to six.binary_type.
905
+
906
+ For Python 2:
907
+ - `unicode` -> encoded to `str`
908
+ - `str` -> `str`
909
+
910
+ For Python 3:
911
+ - `str` -> encoded to `bytes`
912
+ - `bytes` -> `bytes`
913
+ """
914
+ if isinstance(s, binary_type):
915
+ return s
916
+ if isinstance(s, text_type):
917
+ return s.encode(encoding, errors)
918
+ raise TypeError("not expecting type '%s'" % type(s))
919
+
920
+
921
+ def ensure_str(s, encoding='utf-8', errors='strict'):
922
+ """Coerce *s* to `str`.
923
+
924
+ For Python 2:
925
+ - `unicode` -> encoded to `str`
926
+ - `str` -> `str`
927
+
928
+ For Python 3:
929
+ - `str` -> `str`
930
+ - `bytes` -> decoded to `str`
931
+ """
932
+ # Optimization: Fast return for the common case.
933
+ if type(s) is str:
934
+ return s
935
+ if PY2 and isinstance(s, text_type):
936
+ return s.encode(encoding, errors)
937
+ elif PY3 and isinstance(s, binary_type):
938
+ return s.decode(encoding, errors)
939
+ elif not isinstance(s, (text_type, binary_type)):
940
+ raise TypeError("not expecting type '%s'" % type(s))
941
+ return s
942
+
943
+
944
+ def ensure_text(s, encoding='utf-8', errors='strict'):
945
+ """Coerce *s* to six.text_type.
946
+
947
+ For Python 2:
948
+ - `unicode` -> `unicode`
949
+ - `str` -> `unicode`
950
+
951
+ For Python 3:
952
+ - `str` -> `str`
953
+ - `bytes` -> decoded to `str`
954
+ """
955
+ if isinstance(s, binary_type):
956
+ return s.decode(encoding, errors)
957
+ elif isinstance(s, text_type):
958
+ return s
959
+ else:
960
+ raise TypeError("not expecting type '%s'" % type(s))
961
+
962
+
963
+ def python_2_unicode_compatible(klass):
964
+ """
965
+ A class decorator that defines __unicode__ and __str__ methods under Python 2.
966
+ Under Python 3 it does nothing.
967
+
968
+ To support Python 2 and 3 with a single code base, define a __str__ method
969
+ returning text and apply this decorator to the class.
970
+ """
971
+ if PY2:
972
+ if '__str__' not in klass.__dict__:
973
+ raise ValueError("@python_2_unicode_compatible cannot be applied "
974
+ "to %s because it doesn't define __str__()." %
975
+ klass.__name__)
976
+ klass.__unicode__ = klass.__str__
977
+ klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
978
+ return klass
979
+
980
+
981
+ # Complete the moves implementation.
982
+ # This code is at the end of this module to speed up module loading.
983
+ # Turn this module into a package.
984
+ __path__ = [] # required for PEP 302 and PEP 451
985
+ __package__ = __name__ # see PEP 366 @ReservedAssignment
986
+ if globals().get("__spec__") is not None:
987
+ __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable
988
+ # Remove other six meta path importers, since they cause problems. This can
989
+ # happen if six is removed from sys.modules and then reloaded. (Setuptools does
990
+ # this for some reason.)
991
+ if sys.meta_path:
992
+ for i, importer in enumerate(sys.meta_path):
993
+ # Here's some real nastiness: Another "instance" of the six module might
994
+ # be floating around. Therefore, we can't use isinstance() to check for
995
+ # the six meta path importer, since the other six instance will have
996
+ # inserted an importer with different class.
997
+ if (type(importer).__name__ == "_SixMetaPathImporter" and
998
+ importer.name == __name__):
999
+ del sys.meta_path[i]
1000
+ break
1001
+ del i, importer
1002
+ # Finally, add the importer to the meta path import hook.
1003
+ sys.meta_path.append(_importer)
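Typical usage of the public API defined above, combining a `six.moves` import (resolved lazily through `_SixMetaPathImporter`) with one of the `ensure_*` coercers:

import six
from six.moves.urllib.parse import urlparse

assert six.ensure_text(b"caf\xc3\xa9") == u"caf\xe9"  # bytes decoded to text
assert urlparse("https://example.com/a?b=1").query == "b=1"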
venv/Lib/site-packages/threadpoolctl.py ADDED
@@ -0,0 +1,1292 @@
1
+ """threadpoolctl
2
+
3
+ This module provides utilities to introspect native libraries that rely on
4
+ thread pools (notably BLAS and OpenMP implementations) and dynamically set the
5
+ maximal number of threads they can use.
6
+ """
7
+
8
+ # License: BSD 3-Clause
9
+
10
+ # The code to introspect dynamically loaded libraries on POSIX systems is
11
+ # adapted from code by Intel developer @anton-malakhov available at
12
+ # https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
13
+ # and also published under the BSD 3-Clause license
14
+ import os
15
+ import re
16
+ import sys
17
+ import ctypes
18
+ import itertools
19
+ import textwrap
20
+ from typing import final
21
+ import warnings
22
+ from ctypes.util import find_library
23
+ from abc import ABC, abstractmethod
24
+ from functools import lru_cache
25
+ from contextlib import ContextDecorator
26
+
27
+ __version__ = "3.6.0"
28
+ __all__ = [
29
+ "threadpool_limits",
30
+ "threadpool_info",
31
+ "ThreadpoolController",
32
+ "LibController",
33
+ "register",
34
+ ]
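# Hedged usage sketch (ours, not part of the module) of the two main entry
# points exported above:
#
#     from threadpoolctl import threadpool_limits, threadpool_info
#
#     with threadpool_limits(limits=1, user_api="blas"):
#         ...  # BLAS calls in this block are capped at one thread
#     print(threadpool_info())  # one info dict per detected library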
35
+
36
+
37
+ # One can get runtime errors or even segfaults due to multiple OpenMP libraries
38
+ # loaded simultaneously which can happen easily in Python when importing and
39
+ # using compiled extensions built with different compilers and therefore
40
+ # different OpenMP runtimes in the same program. In particular libiomp (used by
41
+ # Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for
42
+ # instance when calling BLAS inside a prange. Setting the following environment
43
+ # variable allows multiple OpenMP libraries to be loaded. It should not degrade
44
+ performance since we manually take care of potential over-subscription
45
+ # performance issues, in sections of the code where nested OpenMP loops can
46
+ # happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
47
+ # disable it while under the scope of the outer OpenMP parallel section.
48
+ os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")
49
+
50
+ # Structure used to cast the info on a dynamically loaded library. See
51
+ # https://linux.die.net/man/3/dl_iterate_phdr for more details.
52
+ _SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32
53
+ _SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16
54
+
55
+
56
+ class _dl_phdr_info(ctypes.Structure):
57
+ _fields_ = [
58
+ ("dlpi_addr", _SYSTEM_UINT), # Base address of object
59
+ ("dlpi_name", ctypes.c_char_p), # path to the library
60
+ ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers
61
+ ("dlpi_phnum", _SYSTEM_UINT_HALF), # number of elements in dlpi_phdr
62
+ ]
63
+
64
+
65
+ # The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows.
66
+ try:
67
+ _RTLD_NOLOAD = os.RTLD_NOLOAD
68
+ except AttributeError:
69
+ _RTLD_NOLOAD = ctypes.DEFAULT_MODE
70
+
71
+
72
+ class LibController(ABC):
73
+ """Abstract base class for the individual library controllers
74
+
75
+ A library controller must expose the following class attributes:
76
+ - user_api : str
77
+ Usually the name of the library or generic specification the library
78
+ implements, e.g. "blas" is a specification with different implementations.
79
+ - internal_api : str
80
+ Usually the name of the library or concrete implementation of some
81
+ specification, e.g. "openblas" is an implementation of the "blas"
82
+ specification.
83
+ - filename_prefixes : tuple
84
+ Possible prefixes of the shared library's filename that allow to
85
+ identify the library. e.g. "libopenblas" for libopenblas.so.
86
+
87
+ and implement the following methods: `get_num_threads`, `set_num_threads` and
88
+ `get_version`.
89
+
90
+ Threadpoolctl loops through all the loaded shared libraries and tries to match
91
+ the filename of each library with the `filename_prefixes`. If a match is found, a
92
+ controller is instantiated and a handler to the library is stored in the `dynlib`
93
+ attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
94
+ of the shared library to implement the above methods.
95
+
96
+ The following information will be exposed in the info dictionary:
97
+ - user_api : standardized API, if any, or a copy of internal_api.
98
+ - internal_api : implementation-specific API.
99
+ - num_threads : the current thread limit.
100
+ - prefix : prefix of the shared library's filename.
101
+ - filepath : path to the loaded shared library.
102
+ - version : version of the library (if available).
103
+
104
+ In addition, each library controller may expose internal API specific entries. They
105
+ must be set as attributes in the `set_additional_attributes` method.
106
+ """
107
+
108
+ @final
109
+ def __init__(self, *, filepath=None, prefix=None, parent=None):
110
+ """This is not meant to be overriden by subclasses."""
111
+ self.parent = parent
112
+ self.prefix = prefix
113
+ self.filepath = filepath
114
+ self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
115
+ self._symbol_prefix, self._symbol_suffix = self._find_affixes()
116
+ self.version = self.get_version()
117
+ self.set_additional_attributes()
118
+
119
+ def info(self):
120
+ """Return relevant info wrapped in a dict"""
121
+ hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
122
+ return {
123
+ "user_api": self.user_api,
124
+ "internal_api": self.internal_api,
125
+ "num_threads": self.num_threads,
126
+ **{k: v for k, v in vars(self).items() if k not in hidden_attrs},
127
+ }
128
+
129
+ def set_additional_attributes(self):
130
+ """Set additional attributes meant to be exposed in the info dict"""
131
+
132
+ @property
133
+ def num_threads(self):
134
+ """Exposes the current thread limit as a dynamic property
135
+
136
+ This is not meant to be used or overriden by subclasses.
137
+ """
138
+ return self.get_num_threads()
139
+
140
+ @abstractmethod
141
+ def get_num_threads(self):
142
+ """Return the maximum number of threads available to use"""
143
+
144
+ @abstractmethod
145
+ def set_num_threads(self, num_threads):
146
+ """Set the maximum number of threads to use"""
147
+
148
+ @abstractmethod
149
+ def get_version(self):
150
+ """Return the version of the shared library"""
151
+
152
+ def _find_affixes(self):
153
+ """Return the affixes for the symbols of the shared library"""
154
+ return "", ""
155
+
156
+ def _get_symbol(self, name):
157
+ """Return the symbol of the shared library accounding for the affixes"""
158
+ return getattr(
159
+ self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
160
+ )
161
+
162
+
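
# ---------------------------------------------------------------------------
# Editor's note: a minimal sketch of a custom controller following the
# `LibController` contract described above. Everything below (the "foo"
# library, its symbol names and `_ExampleFooController`) is hypothetical and
# only illustrates the expected shape; it is not part of the upstream
# threadpoolctl API.
# ---------------------------------------------------------------------------
class _ExampleFooController(LibController):
    """Illustrative controller for a hypothetical "libfoo" thread pool."""

    user_api = "foo"  # generic specification name
    internal_api = "foo"  # concrete implementation name
    filename_prefixes = ("libfoo",)  # matches e.g. libfoo.so / libfoo.dll

    def get_num_threads(self):
        # `foo_get_max_threads` is an assumed symbol of the hypothetical library.
        get_func = getattr(self.dynlib, "foo_get_max_threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        set_func = getattr(self.dynlib, "foo_set_num_threads", lambda n: None)
        return set_func(num_threads)

    def get_version(self):
        # Assume the hypothetical library does not expose its version.
        return None


def _example_register_foo_controller():
    # `register` is defined later in this module; calling it makes
    # ThreadpoolController look for "libfoo" among the loaded libraries.
    register(_ExampleFooController)
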

class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

    _symbol_prefixes = ("", "scipy_")
    _symbol_suffixes = ("", "64_", "_64")

    # All variations of "openblas_get_num_threads", accounting for the affixes
    check_symbols = tuple(
        f"{prefix}openblas_get_num_threads{suffix}"
        for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
    )

    def _find_affixes(self):
        for prefix, suffix in itertools.product(
            self._symbol_prefixes, self._symbol_suffixes
        ):
            if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
                return prefix, suffix

    def set_additional_attributes(self):
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        get_num_threads_func = self._get_symbol("openblas_get_num_threads")
        if get_num_threads_func is not None:
            return get_num_threads_func()
        return None

    def set_num_threads(self, num_threads):
        set_num_threads_func = self._get_symbol("openblas_set_num_threads")
        if set_num_threads_func is not None:
            return set_num_threads_func(num_threads)
        return None

    def get_version(self):
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_version_func = self._get_symbol("openblas_get_config")
        if get_version_func is not None:
            get_version_func.restype = ctypes.c_char_p
            config = get_version_func().split()
            if config[0] == b"OpenBLAS":
                return config[1].decode("utf-8")
            return None
        return None

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_threading_layer_func = self._get_symbol("openblas_get_parallel")
        if get_threading_layer_func is not None:
            threading_layer = get_threading_layer_func()
            if threading_layer == 2:
                return "openmp"
            elif threading_layer == 1:
                return "pthreads"
            return "disabled"
        return "unknown"

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_architecture_func = self._get_symbol("openblas_get_corename")
        if get_architecture_func is not None:
            get_architecture_func.restype = ctypes.c_char_p
            return get_architecture_func().decode("utf-8")
        return None


class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = (
        "bli_thread_get_num_threads",
        "bli_thread_set_num_threads",
        "bli_info_get_version_str",
        "bli_info_get_enable_openmp",
        "bli_info_get_enable_pthreads",
        "bli_arch_query_id",
        "bli_arch_string",
    )

    def set_additional_attributes(self):
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        num_threads = get_func()
        # By default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        set_func = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None)
        if get_version_ is None:
            return None

        get_version_.restype = ctypes.c_char_p
        return get_version_().decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of BLIS"""
        if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)():
            return "openmp"
        elif getattr(self.dynlib, "bli_info_get_enable_pthreads", lambda: False)():
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        bli_arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if bli_arch_query_id is None or bli_arch_string is None:
            return None

        # The true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        bli_arch_query_id.restype = ctypes.c_int
        bli_arch_string.restype = ctypes.c_char_p
        return bli_arch_string(bli_arch_query_id()).decode("utf-8")


class FlexiBLASController(LibController):
    """Controller class for FlexiBLAS"""

    user_api = "blas"
    internal_api = "flexiblas"
    filename_prefixes = ("libflexiblas",)
    check_symbols = (
        "flexiblas_get_num_threads",
        "flexiblas_set_num_threads",
        "flexiblas_get_version",
        "flexiblas_list",
        "flexiblas_list_loaded",
        "flexiblas_current_backend",
    )

    @property
    def loaded_backends(self):
        return self._get_backend_list(loaded=True)

    @property
    def current_backend(self):
        return self._get_current_backend()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # We override the info method because the loaded and current backends
        # are dynamic properties.
        exposed_attrs = super().info()
        exposed_attrs["loaded_backends"] = self.loaded_backends
        exposed_attrs["current_backend"] = self.current_backend

        return exposed_attrs

    def set_additional_attributes(self):
        self.available_backends = self._get_backend_list(loaded=False)

    def get_num_threads(self):
        get_func = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None)
        num_threads = get_func()
        # flexiblas_get_num_threads returns -1 when single-threaded. We map it
        # to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        set_func = getattr(
            self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        get_version_ = getattr(self.dynlib, "flexiblas_get_version", None)
        if get_version_ is None:
            return None

        major = ctypes.c_int()
        minor = ctypes.c_int()
        patch = ctypes.c_int()
        get_version_(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
        return f"{major.value}.{minor.value}.{patch.value}"

    def _get_backend_list(self, loaded=False):
        """Return the list of available backends for FlexiBLAS.

        If loaded is False, return the list of available backends from the FlexiBLAS
        configuration. If loaded is True, return the list of actually loaded backends.
        """
        func_name = f"flexiblas_list{'_loaded' if loaded else ''}"
        get_backend_list_ = getattr(self.dynlib, func_name, None)
        if get_backend_list_ is None:
            return None

        n_backends = get_backend_list_(None, 0, 0)

        backends = []
        for i in range(n_backends):
            backend_name = ctypes.create_string_buffer(1024)
            get_backend_list_(backend_name, 1024, i)
            if backend_name.value.decode("utf-8") != "__FALLBACK__":
                # We don't know when to expect __FALLBACK__ but it is not a real
                # backend and does not show up when running flexiblas list.
                backends.append(backend_name.value.decode("utf-8"))
        return backends

    def _get_current_backend(self):
        """Return the current backend of FlexiBLAS"""
        get_backend_ = getattr(self.dynlib, "flexiblas_current_backend", None)
        if get_backend_ is None:
            return None

        backend = ctypes.create_string_buffer(1024)
        get_backend_(backend, ctypes.sizeof(backend))
        return backend.value.decode("utf-8")

    def switch_backend(self, backend):
        """Switch the backend of FlexiBLAS

        Parameters
        ----------
        backend : str
            The name or the path to the shared library of the backend to switch to. If
            the backend is not already loaded, it will be loaded first.
        """
        if backend not in self.loaded_backends:
            if backend in self.available_backends:
                load_func = getattr(self.dynlib, "flexiblas_load_backend", lambda _: -1)
            else:  # assume backend is a path to a shared library
                load_func = getattr(
                    self.dynlib, "flexiblas_load_backend_library", lambda _: -1
                )
            res = load_func(str(backend).encode("utf-8"))
            if res == -1:
                raise RuntimeError(
                    f"Failed to load backend {backend!r}. It must either be the name of"
                    " a backend available in the FlexiBLAS configuration "
                    f"{self.available_backends} or the path to a valid shared library."
                )

            # Trigger a new search of loaded shared libraries since loading a new
            # backend caused a dlopen.
            self.parent._load_libraries()

        switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1)
        idx = self.loaded_backends.index(backend)
        res = switch_func(idx)
        if res == -1:
            raise RuntimeError(f"Failed to switch to backend {backend!r}.")

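
# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch of `FlexiBLASController.switch_backend`
# (not part of upstream threadpoolctl). The backend name "NETLIB" is an
# assumption: valid names depend on the local FlexiBLAS configuration and are
# listed in the controller's `available_backends` attribute.
# ---------------------------------------------------------------------------
def _example_switch_flexiblas_backend():
    controllers = ThreadpoolController().select(internal_api="flexiblas")
    for flexiblas in controllers.lib_controllers:
        print("available:", flexiblas.available_backends)
        flexiblas.switch_backend("NETLIB")
        print("now using:", flexiblas.current_backend)
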

class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = (
        "MKL_Get_Max_Threads",
        "MKL_Set_Num_Threads",
        "MKL_Get_Version_String",
        "MKL_Set_Threading_Layer",
    )

    def set_additional_attributes(self):
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        get_func = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        set_func = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        res = ctypes.create_string_buffer(200)
        self.dynlib.MKL_Get_Version_String(res, 200)

        version = res.value.decode("utf-8")
        group = re.search(r"Version ([^ ]+) ", version)
        if group is not None:
            version = group.groups()[0]
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer.
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        layer_map = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        return layer_map[set_threading_layer(-1)]


class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
    check_symbols = (
        "omp_get_max_threads",
        "omp_get_num_threads",
    )

    def get_num_threads(self):
        get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        set_func = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        # There is no way to get the version number programmatically in OpenMP.
        return None


# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
    FlexiBLASController,
]

# Helpers for the doc and test names
_ALL_USER_APIS = list(set(lib.user_api for lib in _ALL_CONTROLLERS))
_ALL_INTERNAL_APIS = [lib.internal_api for lib in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    set(prefix for lib in _ALL_CONTROLLERS for prefix in lib.filename_prefixes)
)
_ALL_BLAS_LIBRARIES = [
    lib.internal_api for lib in _ALL_CONTROLLERS if lib.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes


def register(controller):
    """Register a new controller"""
    _ALL_CONTROLLERS.append(controller)
    _ALL_USER_APIS.append(controller.user_api)
    _ALL_INTERNAL_APIS.append(controller.internal_api)
    _ALL_PREFIXES.extend(controller.filename_prefixes)


def _format_docstring(*args, **kwargs):
    def decorator(o):
        if o.__doc__ is not None:
            o.__doc__ = o.__doc__.format(*args, **kwargs)
        return o

    return decorator


@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Small caching wrapper around os.path.realpath to limit system calls"""
    return os.path.realpath(filepath)


@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    Return a list with all the supported libraries that have been found. Each
    library is represented by a dict with the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    return ThreadpoolController().info()


class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be used as
    a callable. Setting the limits can be delayed by using the `wrap` class method such
    that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        return _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        for lib_controller, original_info in zip(
            self._controller.lib_controllers, self._original_info
        ):
            lib_controller.set_num_threads(original_info["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`.
        """
        num_threads = {}
        warning_apis = []

        for user_api in self._user_api:
            limits = [
                lib_info["num_threads"]
                for lib_info in self._original_info
                if lib_info["user_api"] == user_api
            ]
            limits = set(limits)
            n_limits = len(limits)

            if n_limits == 1:
                limit = limits.pop()
            elif n_limits == 0:
                limit = None
            else:
                limit = min(limits)
                warning_apis.append(user_api)

            num_threads[user_api] = limit

        if warning_apis:
            warnings.warn(
                "Multiple values possible for the following user apis: "
                + ", ".join(warning_apis)
                + ". Returning the minimum."
            )

        return num_threads

    def _check_params(self, limits, user_api):
        """Return suitable values for the _limits, _user_api and _prefixes attributes"""

        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            (
                limits,
                user_api,
            ) = self._controller._get_params_for_sequential_blas_under_openmp().values()

        if limits is None or isinstance(limits, int):
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                limits = {api: limits for api in user_api}
            prefixes = []
        else:
            if isinstance(limits, list):
                # This should be a list of dicts of library info, for
                # compatibility with the result from threadpool_info.
                limits = {
                    lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
                }
            elif isinstance(limits, ThreadpoolController):
                # To set the limits from the library controllers of a
                # ThreadpoolController object.
                limits = {
                    lib_controller.prefix: lib_controller.num_threads
                    for lib_controller in limits.lib_controllers
                }

            if not isinstance(limits, dict):
                raise TypeError(
                    "limits must either be an int, a list, a dict, or "
                    f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
                )

            # With a dictionary, one can set both a specific limit for given
            # libraries and a global limit for user_api. Fetch each separately.
            prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
            user_api = [api for api in limits if api in _ALL_USER_APIS]

        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Return a list with all the supported libraries that have been found
        matching `self._prefixes` and `self._user_api`.
        """
        if self._limits is None:
            return

        for lib_controller in self._controller.lib_controllers:
            # self._limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib_controller.prefix in self._limits:
                num_threads = self._limits[lib_controller.prefix]
            elif lib_controller.user_api in self._limits:
                num_threads = self._limits[lib_controller.user_api]
            else:
                continue

            if num_threads is not None:
                lib_controller.set_num_threads(num_threads)


class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        self._controller = controller

    def __enter__(self):
        # We need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self

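
# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch of the decorator form provided by the
# `wrap` class method of the class below (not part of upstream threadpoolctl).
# Unlike the plain limiter, the limits are applied when the decorated function
# is called, not when the decorator is created.
# ---------------------------------------------------------------------------
def _example_decorator_usage():
    @threadpool_limits.wrap(limits=1, user_api="blas")
    def run_sequential_blas_step():
        pass  # BLAS calls in here see at most 1 thread

    run_sequential_blas_step()
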

@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limit`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        super().__init__(ThreadpoolController(), limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api)

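
# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch of `threadpool_limits` (not part of
# upstream threadpoolctl). It shows the dict form of `limits`, mixing a
# per-prefix limit ("libomp") with a per-user_api limit ("blas").
# ---------------------------------------------------------------------------
def _example_threadpool_limits_dict():
    # Cap any loaded libomp-based OpenMP pool at 2 threads and every BLAS
    # implementation at 1 thread; previous limits are restored on exit.
    with threadpool_limits(limits={"libomp": 2, "blas": 1}):
        pass  # run the code that should not over-subscribe here
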

class ThreadpoolController:
    """Collection of LibController objects for all loaded supported libraries

    Attributes
    ----------
    lib_controllers : list of `LibController` objects
        The list of library controllers of all loaded supported libraries.
    """

    # Cache for libc under POSIX and a few system libraries under Windows.
    # We use a class level cache instead of an instance level cache because
    # it's very unlikely that a shared library will be unloaded and reloaded
    # during the lifetime of a program.
    _system_libraries = dict()

    def __init__(self):
        self.lib_controllers = []
        self._load_libraries()
        self._warn_if_incompatible_openmp()

    @classmethod
    def _from_controllers(cls, lib_controllers):
        new_controller = cls.__new__(cls)
        new_controller.lib_controllers = lib_controllers
        return new_controller

    def info(self):
        """Return lib_controllers info as a list of dicts"""
        return [lib_controller.info() for lib_controller in self.lib_controllers]

    def select(self, **kwargs):
        """Return a ThreadpoolController containing a subset of its current
        library controllers

        It will select all libraries matching at least one pair (key, value) from kwargs
        where key is an entry of the library info dict (like "user_api", "internal_api",
        "prefix", ...) and value is the value or a list of acceptable values for that
        entry.

        For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
        will select all library controllers whose internal_api is either "blis" or
        "openblas".
        """
        for key, vals in kwargs.items():
            kwargs[key] = [vals] if not isinstance(vals, list) else vals

        lib_controllers = [
            lib_controller
            for lib_controller in self.lib_controllers
            if any(
                getattr(lib_controller, key, None) in vals
                for key, vals in kwargs.items()
            )
        ]

        return ThreadpoolController._from_controllers(lib_controllers)

    def _get_params_for_sequential_blas_under_openmp(self):
        """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

        This function takes into account the unexpected behavior of OpenBLAS with the
        OpenMP threading layer.
        """
        if self.select(
            internal_api="openblas", threading_layer="openmp"
        ).lib_controllers:
            return {"limits": None, "user_api": None}
        return {"limits": 1, "user_api": "blas"}

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def limit(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used either as a callable (the
        construction of this object limits the number of threads) or as a context
        manager, in a `with` block to automatically restore the original state of the
        controlled libraries when exiting the block.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        This effect is global and impacts the whole Python process. There is no thread
        level isolation as these libraries do not offer thread-local APIs to configure
        the number of threads to use in nested parallel calls.

        Parameters
        ----------
        limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If 'sequential_blas_under_openmp', it will choose the appropriate
              `limits` and `user_api` parameters for the specific use case of
              sequential BLAS calls within an OpenMP parallel region. The `user_api`
              parameter is ignored.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter(self, limits=limits, user_api=user_api)

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def wrap(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used as a decorator.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        Parameters
        ----------
        limits : int, dict or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)

    def __len__(self):
        return len(self.lib_controllers)

    def _load_libraries(self):
        """Loop through loaded shared libraries and store the supported ones"""
        if sys.platform == "darwin":
            self._find_libraries_with_dyld()
        elif sys.platform == "win32":
            self._find_libraries_with_enum_process_module_ex()
        elif "pyodide" in sys.modules:
            self._find_libraries_pyodide()
        else:
            self._find_libraries_with_dl_iterate_phdr()

    def _find_libraries_with_dl_iterate_phdr(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on POSIX systems only.
        This code is adapted from code by Intel developer @anton-malakhov
        available at https://github.com/IntelPython/smp

        Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
        license
        """
        libc = self._get_libc()
        if not hasattr(libc, "dl_iterate_phdr"):  # pragma: no cover
            warnings.warn(
                "Could not find dl_iterate_phdr in the C standard library.",
                RuntimeWarning,
            )
            return []

        # Callback function for `dl_iterate_phdr` which is called for every
        # library loaded in the current process until it returns 1.
        def match_library_callback(info, size, data):
            # Get the path of the current library
            filepath = info.contents.dlpi_name
            if filepath:
                filepath = filepath.decode("utf-8")

                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
            return 0

        c_func_signature = ctypes.CFUNCTYPE(
            ctypes.c_int,  # Return type
            ctypes.POINTER(_dl_phdr_info),
            ctypes.c_size_t,
            ctypes.c_char_p,
        )
        c_match_library_callback = c_func_signature(match_library_callback)

        data = ctypes.c_char_p(b"")
        libc.dl_iterate_phdr(c_match_library_callback, data)

    def _find_libraries_with_dyld(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on macOS systems only.
        """
        libc = self._get_libc()
        if not hasattr(libc, "_dyld_image_count"):  # pragma: no cover
            warnings.warn(
                "Could not find _dyld_image_count in the C standard library.",
                RuntimeWarning,
            )
            return []

        n_dyld = libc._dyld_image_count()
        libc._dyld_get_image_name.restype = ctypes.c_char_p

        for i in range(n_dyld):
            filepath = ctypes.string_at(libc._dyld_get_image_name(i))
            filepath = filepath.decode("utf-8")

            # Store the library controller if it is supported and selected
            self._make_controller_from_path(filepath)

    def _find_libraries_with_enum_process_module_ex(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on Windows systems only.
        This code is adapted from code by Philipp Hagemeister @phihag available
        at https://stackoverflow.com/questions/17474574
        """
        from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

        PROCESS_QUERY_INFORMATION = 0x0400
        PROCESS_VM_READ = 0x0010

        LIST_LIBRARIES_ALL = 0x03

        ps_api = self._get_windll("Psapi")
        kernel_32 = self._get_windll("kernel32")

        h_process = kernel_32.OpenProcess(
            PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
        )
        if not h_process:  # pragma: no cover
            raise OSError(f"Could not open PID {os.getpid()}")

        try:
            buf_count = 256
            needed = DWORD()
            # Grow the buffer until it becomes large enough to hold all the
            # module headers
            while True:
                buf = (HMODULE * buf_count)()
                buf_size = ctypes.sizeof(buf)
                if not ps_api.EnumProcessModulesEx(
                    h_process,
                    ctypes.byref(buf),
                    buf_size,
                    ctypes.byref(needed),
                    LIST_LIBRARIES_ALL,
                ):
                    raise OSError("EnumProcessModulesEx failed")
                if buf_size >= needed.value:
                    break
                buf_count = needed.value // (buf_size // buf_count)

            count = needed.value // (buf_size // buf_count)
            h_modules = map(HMODULE, buf[:count])

            # Loop through all the module headers and get the library path.
            # Allocate a buffer for the path 10 times the size of MAX_PATH to take
            # into account long path names.
            max_path = 10 * MAX_PATH
            buf = ctypes.create_unicode_buffer(max_path)
            n_size = DWORD()
            for h_module in h_modules:
                # Get the path of the current module
                if not ps_api.GetModuleFileNameExW(
                    h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size)
                ):
                    raise OSError("GetModuleFileNameEx failed")
                filepath = buf.value

                if len(filepath) == max_path:  # pragma: no cover
                    warnings.warn(
                        "Could not get the full path of a dynamic library (path too "
                        "long). This library will be ignored and threadpoolctl might "
                        "not be able to control or display information about all "
                        f"loaded libraries. Here's the truncated path: {filepath!r}",
                        RuntimeWarning,
                    )
                else:
                    # Store the library controller if it is supported and selected
                    self._make_controller_from_path(filepath)
        finally:
            kernel_32.CloseHandle(h_process)

    def _find_libraries_pyodide(self):
        """Pyodide specific implementation for finding loaded libraries.

        Adapted from the suggestion in
        https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449.

        One day, we may have a simpler solution. libc dl_iterate_phdr needs to
        be implemented in Emscripten and exposed in Pyodide, see
        https://github.com/emscripten-core/emscripten/issues/21354 for more
        details.
        """
        try:
            from pyodide_js._module import LDSO
        except ImportError:
            warnings.warn(
                "Unable to import LDSO from pyodide_js._module. This should never "
                "happen."
            )
            return

        for filepath in LDSO.loadedLibsByName.as_object_map():
            # Some libraries are duplicated by Pyodide and do not exist in the
            # filesystem, so we first check for the existence of the file. For
            # more details, see
            # https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1947946728
            if os.path.exists(filepath):
                self._make_controller_from_path(filepath)

    def _make_controller_from_path(self, filepath):
        """Store a library controller if it is supported and selected"""
        # Required to resolve symlinks
        filepath = _realpath(filepath)
        # `lower` required to take account of OpenMP dll case on Windows
        # (vcomp, VCOMP, Vcomp, ...)
        filename = os.path.basename(filepath).lower()

        # Loop through supported libraries to find if this filename corresponds
        # to a supported one.
        for controller_class in _ALL_CONTROLLERS:
            # Check if filename matches a supported prefix
            prefix = self._check_prefix(filename, controller_class.filename_prefixes)

            # filename does not match any of the prefixes of the candidate
            # library. Move to the next library.
            if prefix is None:
                continue

            # Workaround for BLAS libraries packaged by conda-forge on Windows, which
            # are all renamed "libblas.dll". We thus have to check to which BLAS
            # implementation it actually corresponds by looking for implementation
            # specific symbols.
            if prefix == "libblas":
                if filename.endswith(".dll"):
                    libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
                    if not any(
                        hasattr(libblas, func)
                        for func in controller_class.check_symbols
                    ):
                        continue
                else:
                    # We ignore libblas on other platforms than Windows because there
                    # might be a libblas dso coming with openblas for instance that
                    # can't be used to instantiate a pertinent LibController (many
                    # symbols are missing) and would create confusion by making a
                    # duplicate entry in threadpool_info.
                    continue

            # filename matches a prefix. Now we check if the library has the symbols we
            # are looking for. If none of the symbols exists, it's very likely not the
            # expected library (e.g. a library having a common prefix with one of
            # our supported libraries). Otherwise, create and store the library
            # controller.
            lib_controller = controller_class(
                filepath=filepath, prefix=prefix, parent=self
            )

            if filepath in (lib.filepath for lib in self.lib_controllers):
                # We already have a controller for this library.
                continue

            if not hasattr(controller_class, "check_symbols") or any(
                hasattr(lib_controller.dynlib, func)
                for func in controller_class.check_symbols
            ):
                self.lib_controllers.append(lib_controller)

    def _check_prefix(self, library_basename, filename_prefixes):
        """Return the prefix library_basename starts with

        Return None if none matches.
        """
        for prefix in filename_prefixes:
            if library_basename.startswith(prefix):
                return prefix
        return None

    def _warn_if_incompatible_openmp(self):
        """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
        prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
        msg = textwrap.dedent(
            """
            Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
            the same time. Both libraries are known to be incompatible and this
            can cause random crashes or deadlocks on Linux when loaded in the
            same Python program.
            Using threadpoolctl may cause crashes or deadlocks. For more
            information and possible workarounds, please see
            https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
            """
        )
        if "libomp" in prefixes and "libiomp" in prefixes:
            warnings.warn(msg, RuntimeWarning)

    @classmethod
    def _get_libc(cls):
        """Load libc for UNIX systems."""
        libc = cls._system_libraries.get("libc")
        if libc is None:
            # Remark: If libc is statically linked or if Python is linked against an
            # alternative implementation of libc like musl, find_library will return
            # None and CDLL will load the main program itself, which should contain the
            # libc symbols. We still name it libc for convenience.
            # If the main program does not contain the libc symbols, it's ok because
            # we check their presence later anyway.
            libc = ctypes.CDLL(find_library("c"), mode=_RTLD_NOLOAD)
            cls._system_libraries["libc"] = libc
        return libc

    @classmethod
    def _get_windll(cls, dll_name):
        """Load a Windows DLL"""
        dll = cls._system_libraries.get(dll_name)
        if dll is None:
            dll = ctypes.WinDLL(f"{dll_name}.dll")
            cls._system_libraries[dll_name] = dll
        return dll

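
# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch combining `select` and `limit` from
# the class above (not part of upstream threadpoolctl).
# ---------------------------------------------------------------------------
def _example_controller_usage():
    controller = ThreadpoolController()

    # Restrict only the OpenBLAS/BLIS pools, leaving OpenMP runtimes alone.
    with controller.select(internal_api=["openblas", "blis"]).limit(limits=1):
        pass  # sequential BLAS section

    # Inspect the state after the block: the original limits are restored.
    print(controller.info())
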

def _main():
    """Command-line interface to display thread-pool information and exit."""
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
        description="Display thread-pool information and exit.",
    )
    parser.add_argument(
        "-i",
        "--import",
        dest="modules",
        nargs="*",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    options = parser.parse_args(sys.argv[1:])
    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            print("WARNING: could not import", module, file=sys.stderr)

    if options.command:
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))


if __name__ == "__main__":
    _main()
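
# Editor's note: the command-line interface above can be exercised as, e.g.:
#
#     python -m threadpoolctl -i numpy scipy.linalg
#
# which imports the listed modules first and then prints the JSON produced by
# `threadpool_info()`.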
venv/Lib/site-packages/typing_extensions.py ADDED
The diff for this file is too large to render. See raw diff