Kiwinicki commited on
Commit
5851f14
·
1 Parent(s): 5c23324

upload pth weights

Browse files
Files changed (3) hide show
  1. README.md +6 -14
  2. generator.safetensors → generator.pth +2 -2
  3. model.py +2 -6
README.md CHANGED
@@ -7,42 +7,34 @@ To load and initialize the `Generator` (based on CycleGAN with better cycles) mo
7
  Ensure you have the necessary Python packages installed:
8
 
9
  ```bash
10
- pip install torch==2.5.1 torchvision==0.20.1 safetensors huggingface_hub
11
  ```
12
 
13
  ## 2. Download Model Files
14
 
15
- Retrieve the `pytorch_model.safetensors` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:
16
 
17
  ```python
18
  from huggingface_hub import hf_hub_download
19
 
20
  repo_id = "Kiwinicki/sat2map-generator"
21
- model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.safetensors")
22
  generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
23
  ```
24
 
25
  ## 3. Load the Model
26
 
27
- Import the `Generator` class and load the model weights from the safetensors file:
28
 
29
  ```python
30
  import torch
31
- from safetensors.torch import load_file
32
  from model import Generator, GeneratorConfig
33
 
34
- # Initialize configuration with default values
35
- cfg = GeneratorConfig(
36
- channels=3,
37
- num_features=64,
38
- num_residuals=12,
39
- depth=4
40
- )
41
 
42
  # Load the generator model
43
- state_dict = load_file(model_path)
44
  generator = Generator(cfg)
45
- generator.load_state_dict(state_dict)
46
  generator.eval()
47
 
48
  # Test the model
 
7
  Ensure you have the necessary Python packages installed:
8
 
9
  ```bash
10
+ pip install torch==2.5.1 torchvision==0.20.1 huggingface_hub
11
  ```
12
 
13
  ## 2. Download Model Files
14
 
15
+ Retrieve the `generator.pth` and `model.py` files from the Hugging Face repository using the `huggingface_hub` library:
16
 
17
  ```python
18
  from huggingface_hub import hf_hub_download
19
 
20
  repo_id = "Kiwinicki/sat2map-generator"
21
+ model_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
22
  generator_code_path = hf_hub_download(repo_id=repo_id, filename="model.py")
23
  ```
24
 
25
  ## 3. Load the Model
26
 
27
+ Import the `Generator` class and load the model weights from the `.pth` file:
28
 
29
  ```python
30
  import torch
 
31
  from model import Generator, GeneratorConfig
32
 
 
 
 
 
 
 
 
33
 
34
  # Load the generator model
35
+ cfg = GeneratorConfig()
36
  generator = Generator(cfg)
37
+ generator.load_state_dict(torch.load('generator.pth'))
38
  generator.eval()
39
 
40
  # Test the model
generator.safetensors → generator.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:662ab517fdeb62d15349409d45f0df28870fcd9f1c3a71f5a4efb8d8d830f144
3
- size 59680580
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:323795e775e9a92e6f3d44cb87c3609911cee6201f586519c2f0f6fdb9361841
3
+ size 59701794
model.py CHANGED
@@ -135,14 +135,10 @@ class ResidualBlock(nn.Module):
135
 
136
  if __name__ == '__main__':
137
  import torch
138
- from safetensors.torch import load_file
139
-
140
  cfg = GeneratorConfig()
141
- state_dict = load_file('generator.safetensors')
142
  generator = Generator(cfg)
143
- generator.load_state_dict(state_dict)
144
  generator.eval()
145
 
146
- x = torch.randn([1, cfg.channels, 256, 256])
147
- out = generator(x)
148
  print(out.shape)
 
135
 
136
  if __name__ == '__main__':
137
  import torch
 
 
138
  cfg = GeneratorConfig()
 
139
  generator = Generator(cfg)
140
+ generator.load_state_dict(torch.load('generator.pth'))
141
  generator.eval()
142
 
143
+ out = generator(torch.randn([1, cfg.channels, 256, 256]))
 
144
  print(out.shape)