jiuface commited on
Commit
f3ff64c
·
verified ·
1 Parent(s): 890ef8f

Create aoti.py

Browse files
Files changed (1) hide show
  1. aoti.py +32 -0
aoti.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import cast
2
+
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
6
+ from spaces.zero.torch.aoti import ZeroGPUWeights
7
+ from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
8
+
9
+
10
+ def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
11
+ clone = object.__new__(module.__class__)
12
+ clone.__dict__ = module.__dict__.copy()
13
+ clone._parameters = module._parameters.copy()
14
+ clone._buffers = module._buffers.copy()
15
+ clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
16
+ return clone
17
+
18
+
19
+ def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
20
+ repeated_blocks = cast(list[str], module._repeated_blocks)
21
+ aoti_files = {name: hf_hub_download(
22
+ repo_id=repo_id,
23
+ filename='package.pt2',
24
+ subfolder=name if variant is None else f'{name}.{variant}',
25
+ ) for name in repeated_blocks}
26
+ for block_name, aoti_file in aoti_files.items():
27
+ for block in module.modules():
28
+ if block.__class__.__name__ == block_name:
29
+ block_ = _shallow_clone_module(block)
30
+ unwrap_tensor_subclass_parameters(block_)
31
+ weights = ZeroGPUWeights(block_.state_dict())
32
+ block.forward = ZeroGPUCompiledModel(aoti_file, weights)