asigalov61 commited on
Commit
018d705
·
verified ·
1 Parent(s): da0ed0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -49
app.py CHANGED
@@ -68,6 +68,57 @@ PREVIEW_LENGTH = 120 # in tokens
68
 
69
  #==================================================================================
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def load_model(model_selector):
72
 
73
  return [[], []]
@@ -176,55 +227,6 @@ def generate_music(prime,
176
  else:
177
  inputs = prime[-num_mem_tokens:]
178
 
179
- print('=' * 70)
180
- print('Instantiating model...')
181
-
182
- device_type = 'cuda'
183
- dtype = 'bfloat16'
184
-
185
- ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
186
- ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
187
-
188
- SEQ_LEN = 4096
189
-
190
- if model_selector == 'with velocity - 3 epochs':
191
- PAD_IDX = 512
192
-
193
- else:
194
- PAD_IDX = 384
195
-
196
- model = TransformerWrapper(
197
- num_tokens = PAD_IDX+1,
198
- max_seq_len = SEQ_LEN,
199
- attn_layers = Decoder(dim = 2048,
200
- depth = 8,
201
- heads = 32,
202
- rotary_pos_emb = True,
203
- attn_flash = True
204
- )
205
- )
206
-
207
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
208
-
209
- print('=' * 70)
210
- print('Loading model checkpoint...')
211
-
212
- model_checkpoint = hf_hub_download(repo_id='asigalov61/Godzilla-Piano-Transformer',
213
- filename='Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14903_steps_0.4874_loss_0.8571_acc.pth')
214
-
215
- model.load_state_dict(torch.load(model_checkpoint, map_location='cuda', weights_only=True))
216
-
217
- model = torch.compile(model, mode='max-autotune')
218
-
219
- print('=' * 70)
220
- print('Done!')
221
- print('=' * 70)
222
- print('Model will use', dtype, 'precision...')
223
- print('=' * 70)
224
-
225
- model.cuda()
226
- model.eval()
227
-
228
  print('Generating...')
229
 
230
  inp = [inputs] * num_gen_batches
 
68
 
69
  #==================================================================================
70
 
71
+ print('=' * 70)
72
+ print('Instantiating model...')
73
+
74
+ device_type = 'cuda'
75
+ dtype = 'bfloat16'
76
+
77
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
78
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
79
+
80
+ SEQ_LEN = 4096
81
+
82
+ if model_selector == 'with velocity - 3 epochs':
83
+ PAD_IDX = 512
84
+
85
+ else:
86
+ PAD_IDX = 384
87
+
88
+ model = TransformerWrapper(
89
+ num_tokens = PAD_IDX+1,
90
+ max_seq_len = SEQ_LEN,
91
+ attn_layers = Decoder(dim = 2048,
92
+ depth = 8,
93
+ heads = 32,
94
+ rotary_pos_emb = True,
95
+ attn_flash = True
96
+ )
97
+ )
98
+
99
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
100
+
101
+ print('=' * 70)
102
+ print('Loading model checkpoint...')
103
+
104
+ model_checkpoint = hf_hub_download(repo_id='asigalov61/Godzilla-Piano-Transformer',
105
+ filename='Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14903_steps_0.4874_loss_0.8571_acc.pth')
106
+
107
+ model.load_state_dict(torch.load(model_checkpoint, map_location='cuda', weights_only=True))
108
+
109
+ model = torch.compile(model, mode='max-autotune')
110
+
111
+ print('=' * 70)
112
+ print('Done!')
113
+ print('=' * 70)
114
+ print('Model will use', dtype, 'precision...')
115
+ print('=' * 70)
116
+
117
+ model.cuda()
118
+ model.eval()
119
+
120
+ #==================================================================================
121
+
122
  def load_model(model_selector):
123
 
124
  return [[], []]
 
227
  else:
228
  inputs = prime[-num_mem_tokens:]
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  print('Generating...')
231
 
232
  inp = [inputs] * num_gen_batches