upload pth weights
Browse files- README.md +6 -14
- generator.safetensors → generator.pth +2 -2
- model.py +2 -6
README.md
CHANGED
@@ -7,42 +7,34 @@ To load and initialize the `Generator` (based on CycleGAN with better cycles) mo
|
|
7 |
Ensure you have the necessary Python packages installed:
|
8 |
|
9 |
```bash
|
10 |
-
pip install torch==2.5.1 torchvision==0.20.1
|
11 |
```
|
12 |
|
13 |
## 2. Download Model Files
|
14 |
|
15 |
-
Retrieve the `
|
16 |
|
17 |
```python
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
20 |
repo_id = "Kiwinicki/sat2map-generator"
|
21 |
-
model_path = hf_hub_download(repo_id=repo_id, filename="
|
22 |
generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
|
23 |
```
|
24 |
|
25 |
## 3. Load the Model
|
26 |
|
27 |
-
Import the `Generator` class and load the model weights from the
|
28 |
|
29 |
```python
|
30 |
import torch
|
31 |
-
from safetensors.torch import load_file
|
32 |
from model import Generator, GeneratorConfig
|
33 |
|
34 |
-
# Initialize configuration with default values
|
35 |
-
cfg = GeneratorConfig(
|
36 |
-
channels=3,
|
37 |
-
num_features=64,
|
38 |
-
num_residuals=12,
|
39 |
-
depth=4
|
40 |
-
)
|
41 |
|
42 |
# Load the generator model
|
43 |
-
|
44 |
generator = Generator(cfg)
|
45 |
-
generator.load_state_dict(
|
46 |
generator.eval()
|
47 |
|
48 |
# Test the model
|
|
|
7 |
Ensure you have the necessary Python packages installed:
|
8 |
|
9 |
```bash
|
10 |
+
pip install torch==2.5.1 torchvision==0.20.1 huggingface_hub
|
11 |
```
|
12 |
|
13 |
## 2. Download Model Files
|
14 |
|
15 |
+
Retrieve the `generator.pth` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:
|
16 |
|
17 |
```python
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
20 |
repo_id = "Kiwinicki/sat2map-generator"
|
21 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
|
22 |
generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
|
23 |
```
|
24 |
|
25 |
## 3. Load the Model
|
26 |
|
27 |
+
Import the `Generator` class and load the model weights from the `.pth` file:
|
28 |
|
29 |
```python
|
30 |
import torch
|
|
|
31 |
from model import Generator, GeneratorConfig
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
# Load the generator model
|
35 |
+
cfg = GeneratorConfig()
|
36 |
generator = Generator(cfg)
|
37 |
+
generator.load_state_dict(torch.load('generator.pth'))
|
38 |
generator.eval()
|
39 |
|
40 |
# Test the model
|
generator.safetensors → generator.pth
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:323795e775e9a92e6f3d44cb87c3609911cee6201f586519c2f0f6fdb9361841
|
3 |
+
size 59701794
|
model.py
CHANGED
@@ -135,14 +135,10 @@ class ResidualBlock(nn.Module):
|
|
135 |
|
136 |
if __name__ == '__main__':
|
137 |
import torch
|
138 |
-
from safetensors.torch import load_file
|
139 |
-
|
140 |
cfg = GeneratorConfig()
|
141 |
-
state_dict = load_file('generator.safetensors')
|
142 |
generator = Generator(cfg)
|
143 |
-
generator.load_state_dict(
|
144 |
generator.eval()
|
145 |
|
146 |
-
|
147 |
-
out = generator(x)
|
148 |
print(out.shape)
|
|
|
135 |
|
136 |
if __name__ == '__main__':
|
137 |
import torch
|
|
|
|
|
138 |
cfg = GeneratorConfig()
|
|
|
139 |
generator = Generator(cfg)
|
140 |
+
generator.load_state_dict(torch.load('generator.pth'))
|
141 |
generator.eval()
|
142 |
|
143 |
+
out = generator(torch.randn([1, cfg.channels, 256, 256]))
|
|
|
144 |
print(out.shape)
|