Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -68,6 +68,57 @@ PREVIEW_LENGTH = 120 # in tokens
|
|
68 |
|
69 |
#==================================================================================
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def load_model(model_selector):
|
72 |
|
73 |
return [[], []]
|
@@ -176,55 +227,6 @@ def generate_music(prime,
|
|
176 |
else:
|
177 |
inputs = prime[-num_mem_tokens:]
|
178 |
|
179 |
-
print('=' * 70)
|
180 |
-
print('Instantiating model...')
|
181 |
-
|
182 |
-
device_type = 'cuda'
|
183 |
-
dtype = 'bfloat16'
|
184 |
-
|
185 |
-
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
186 |
-
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
187 |
-
|
188 |
-
SEQ_LEN = 4096
|
189 |
-
|
190 |
-
if model_selector == 'with velocity - 3 epochs':
|
191 |
-
PAD_IDX = 512
|
192 |
-
|
193 |
-
else:
|
194 |
-
PAD_IDX = 384
|
195 |
-
|
196 |
-
model = TransformerWrapper(
|
197 |
-
num_tokens = PAD_IDX+1,
|
198 |
-
max_seq_len = SEQ_LEN,
|
199 |
-
attn_layers = Decoder(dim = 2048,
|
200 |
-
depth = 8,
|
201 |
-
heads = 32,
|
202 |
-
rotary_pos_emb = True,
|
203 |
-
attn_flash = True
|
204 |
-
)
|
205 |
-
)
|
206 |
-
|
207 |
-
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
|
208 |
-
|
209 |
-
print('=' * 70)
|
210 |
-
print('Loading model checkpoint...')
|
211 |
-
|
212 |
-
model_checkpoint = hf_hub_download(repo_id='asigalov61/Godzilla-Piano-Transformer',
|
213 |
-
filename='Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14903_steps_0.4874_loss_0.8571_acc.pth')
|
214 |
-
|
215 |
-
model.load_state_dict(torch.load(model_checkpoint, map_location='cuda', weights_only=True))
|
216 |
-
|
217 |
-
model = torch.compile(model, mode='max-autotune')
|
218 |
-
|
219 |
-
print('=' * 70)
|
220 |
-
print('Done!')
|
221 |
-
print('=' * 70)
|
222 |
-
print('Model will use', dtype, 'precision...')
|
223 |
-
print('=' * 70)
|
224 |
-
|
225 |
-
model.cuda()
|
226 |
-
model.eval()
|
227 |
-
|
228 |
print('Generating...')
|
229 |
|
230 |
inp = [inputs] * num_gen_batches
|
|
|
68 |
|
69 |
#==================================================================================
|
70 |
|
71 |
+
print('=' * 70)
|
72 |
+
print('Instantiating model...')
|
73 |
+
|
74 |
+
device_type = 'cuda'
|
75 |
+
dtype = 'bfloat16'
|
76 |
+
|
77 |
+
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
78 |
+
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
79 |
+
|
80 |
+
SEQ_LEN = 4096
|
81 |
+
|
82 |
+
if model_selector == 'with velocity - 3 epochs':
|
83 |
+
PAD_IDX = 512
|
84 |
+
|
85 |
+
else:
|
86 |
+
PAD_IDX = 384
|
87 |
+
|
88 |
+
model = TransformerWrapper(
|
89 |
+
num_tokens = PAD_IDX+1,
|
90 |
+
max_seq_len = SEQ_LEN,
|
91 |
+
attn_layers = Decoder(dim = 2048,
|
92 |
+
depth = 8,
|
93 |
+
heads = 32,
|
94 |
+
rotary_pos_emb = True,
|
95 |
+
attn_flash = True
|
96 |
+
)
|
97 |
+
)
|
98 |
+
|
99 |
+
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
|
100 |
+
|
101 |
+
print('=' * 70)
|
102 |
+
print('Loading model checkpoint...')
|
103 |
+
|
104 |
+
model_checkpoint = hf_hub_download(repo_id='asigalov61/Godzilla-Piano-Transformer',
|
105 |
+
filename='Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14903_steps_0.4874_loss_0.8571_acc.pth')
|
106 |
+
|
107 |
+
model.load_state_dict(torch.load(model_checkpoint, map_location='cuda', weights_only=True))
|
108 |
+
|
109 |
+
model = torch.compile(model, mode='max-autotune')
|
110 |
+
|
111 |
+
print('=' * 70)
|
112 |
+
print('Done!')
|
113 |
+
print('=' * 70)
|
114 |
+
print('Model will use', dtype, 'precision...')
|
115 |
+
print('=' * 70)
|
116 |
+
|
117 |
+
model.cuda()
|
118 |
+
model.eval()
|
119 |
+
|
120 |
+
#==================================================================================
|
121 |
+
|
122 |
def load_model(model_selector):
|
123 |
|
124 |
return [[], []]
|
|
|
227 |
else:
|
228 |
inputs = prime[-num_mem_tokens:]
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
print('Generating...')
|
231 |
|
232 |
inp = [inputs] * num_gen_batches
|