Surn committed
Commit 0e6c759 · 1 parent: 907a484

Fix prompt_tokens.shape[-1] issue with max_prompt_length

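The core of the fix: the previous asserts ("Prompt is longer than audio to generate") aborted generation whenever the melody/continuation prompt held more token frames than max_prompt_len allows; the new code truncates the prompt instead. A minimal sketch of that behavior follows; the helper name clip_prompt_tokens and the example shapes are illustrative only (frame rate and codebook count vary by model).

import torch

def clip_prompt_tokens(prompt_tokens: torch.Tensor, max_prompt_len: int) -> torch.Tensor:
    """Hypothetical helper mirroring the change below: instead of asserting that
    the prompt fits inside the audio to generate, clip it to the allowed length."""
    if prompt_tokens is not None and prompt_tokens.shape[-1] > max_prompt_len:
        # keep only the first max_prompt_len frames along the time axis
        prompt_tokens = prompt_tokens[..., :max_prompt_len]
    return prompt_tokens

# example: a (batch, codebooks, time) token tensor longer than the budget
tokens = torch.zeros(1, 4, 750, dtype=torch.long)    # e.g. ~15 s at 50 frames/s
clipped = clip_prompt_tokens(tokens, max_prompt_len=500)
print(clipped.shape)  # torch.Size([1, 4, 500])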
app.py CHANGED
@@ -281,7 +281,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         cfg_coef=cfg_coef,
         duration=segment_duration,
         two_step_cfg=False,
-        extend_stride=10,
+        extend_stride=2,
         rep_penalty=0.5,
         cfg_coef_beta=None,  # double CFG is only useful for text-and-style conditioning
     )
audiocraft/models/genmodel.py CHANGED
@@ -16,6 +16,7 @@ import typing as tp
 
 import omegaconf
 import torch
+import gradio as gr
 
 from .encodec import CompressionModel
 from .lm import LMModel
@@ -191,11 +192,11 @@ class BaseGenModel(ABC):
         return self.generate_audio(tokens)
 
     def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.
 
        Args:
-            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
+            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
@@ -207,20 +208,24 @@ class BaseGenModel(ABC):
 
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
             generated_tokens += current_gen_offset
+            generated_tokens /= ((tokens_to_generate) / self.duration)
+            tokens_to_generate /= ((tokens_to_generate) / self.duration)
             if self._progress_callback is not None:
                 # Note that total_gen_len might be quite wrong depending on the
                 # codebook pattern used, but with delay it is almost accurate.
-                self._progress_callback(generated_tokens, tokens_to_generate)
-            else:
-                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
+                self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress_callback is not None:
+                # Update Gradio progress bar
+                progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress:
+                print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
         if prompt_tokens is not None:
-            assert max_prompt_len >= prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
-        callback = None
-        if progress:
-            callback = _progress_callback
+        # callback = None
+        callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
@@ -240,6 +245,7 @@ class BaseGenModel(ABC):
                 prompt_length = prompt_tokens.shape[-1]
 
             stride_tokens = int(self.frame_rate * self.extend_stride)
+
             while current_gen_offset + prompt_length < total_gen_len:
                 time_offset = current_gen_offset / self.frame_rate
                 chunk_duration = min(self.duration - time_offset, self.max_duration)
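The reworked _progress_callback above reports progress in seconds rather than raw token counts: dividing both counters by tokens_to_generate / self.duration converts frames to seconds, and the resulting fraction drives the model's own callback or the new Gradio gr.Progress bar. A rough sketch of that arithmetic, assuming a 30-second request at 50 audio-token frames per second (both values illustrative):

# illustrative numbers only: a 30 s generation at ~50 token frames per second
duration = 30.0
tokens_to_generate = 1500          # total frames the LM will emit
generated_tokens = 375             # frames emitted so far (plus any prompt offset)

frames_per_second = tokens_to_generate / duration          # 50.0
generated_seconds = generated_tokens / frames_per_second   # 7.5
total_seconds = tokens_to_generate / frames_per_second     # 30.0

fraction = generated_seconds / total_seconds               # 0.25, fed to gr.Progress
print(f"Generated {generated_seconds: 6.2f}/{total_seconds: 6.2f} seconds ({fraction:.0%})")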
audiocraft/models/lm.py CHANGED
@@ -517,7 +517,7 @@ class LMModel(StreamingModule):
         B, K, T = prompt.shape
         start_offset = T
         print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
-        assert start_offset < max_gen_len
+        assert start_offset <= max_gen_len
 
         pattern = self.pattern_provider.get_pattern(max_gen_len)
         # this token is used as default value for codes that are not generated yet
audiocraft/models/musicgen.py CHANGED
@@ -453,8 +453,8 @@ class MusicGen:
                print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
        if prompt_tokens is not None:
-            assert max_prompt_len > prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
        # callback = None
        callback = _progress_callback