Fix prompt_tokens.shape[-1] issue with max_prompt_length
Files changed:
- app.py +1 -1
- audiocraft/models/genmodel.py +16 -10
- audiocraft/models/lm.py +1 -1
- audiocraft/models/musicgen.py +2 -2
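The common thread across the diffs below is that an audio prompt longer than the model's maximum prompt length is now clipped instead of tripping an assertion. A minimal standalone sketch of that behaviour, reusing the `prompt_tokens` / `max_prompt_len` names from the diffs (the example shapes and token rate are made up):

import torch

def clamp_prompt(prompt_tokens: torch.Tensor, max_prompt_len: int) -> torch.Tensor:
    # Keep only the first max_prompt_len time steps of the [B, K, T] token prompt,
    # mirroring the truncation added in genmodel.py and musicgen.py below.
    if prompt_tokens.shape[-1] > max_prompt_len:
        prompt_tokens = prompt_tokens[..., :max_prompt_len]
    return prompt_tokens

# Example: a prompt 100 steps too long is cut back to 1500 steps.
prompt = torch.zeros(1, 4, 1600, dtype=torch.long)
print(clamp_prompt(prompt, max_prompt_len=1500).shape)  # torch.Size([1, 4, 1500])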
app.py CHANGED
@@ -281,7 +281,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             cfg_coef=cfg_coef,
             duration=segment_duration,
             two_step_cfg=False,
-            extend_stride=
+            extend_stride=2,
             rep_penalty=0.5,
             cfg_coef_beta=None,  # double CFG is only useful for text-and-style conditioning
         )
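For context on the app.py change: `extend_stride` is the window advance, in seconds, used when the requested duration exceeds the model's max_duration, and it is converted to tokens via `stride_tokens = int(self.frame_rate * self.extend_stride)` in the genmodel.py hunk further down. A quick back-of-the-envelope check, assuming the 50 Hz token frame rate commonly used by MusicGen's compression model:

# Assumed values: 50 Hz frame rate, extend_stride=2 as set in app.py above.
frame_rate = 50
extend_stride = 2                              # seconds per extension step
stride_tokens = int(frame_rate * extend_stride)
print(stride_tokens)                           # 100 tokens of fresh audio per sliding-window step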
audiocraft/models/genmodel.py CHANGED
@@ -16,6 +16,7 @@ import typing as tp
 
 import omegaconf
 import torch
+import gradio as gr
 
 from .encodec import CompressionModel
 from .lm import LMModel
@@ -191,11 +192,11 @@ class BaseGenModel(ABC):
         return self.generate_audio(tokens)
 
     def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
         """Generate discrete audio tokens given audio prompt and/or conditions.
 
         Args:
-            attributes (list of ConditioningAttributes): Conditions used for generation (
+            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
         Returns:
@@ -207,20 +208,24 @@ class BaseGenModel(ABC):
 
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
             generated_tokens += current_gen_offset
+            generated_tokens /= ((tokens_to_generate) / self.duration)
+            tokens_to_generate /= ((tokens_to_generate) / self.duration)
             if self._progress_callback is not None:
                 # Note that total_gen_len might be quite wrong depending on the
                 # codebook pattern used, but with delay it is almost accurate.
-                self._progress_callback(generated_tokens, tokens_to_generate)
-
-
+                self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress_callback is not None:
+                # Update Gradio progress bar
+                progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress:
+                print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
         if prompt_tokens is not None:
-
-
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
-        callback = None
-
-            callback = _progress_callback
+        # callback = None
+        callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
@@ -240,6 +245,7 @@ class BaseGenModel(ABC):
             prompt_length = prompt_tokens.shape[-1]
 
             stride_tokens = int(self.frame_rate * self.extend_stride)
+
            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
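The two `/=` lines added to `_progress_callback` in the genmodel.py hunk above rescale raw token counts into seconds, so both the console output and the Gradio progress bar can report "Generated X/Y seconds". A standalone sketch of that arithmetic with made-up numbers (750 of 1500 tokens for a 30-second request):

def tokens_to_seconds(generated_tokens: float, tokens_to_generate: float, duration: float):
    # Same arithmetic as the diff: dividing both counts by tokens-per-second
    # turns (tokens done, tokens total) into (seconds done, seconds total).
    tokens_per_second = tokens_to_generate / duration
    return generated_tokens / tokens_per_second, tokens_to_generate / tokens_per_second

done, total = tokens_to_seconds(generated_tokens=750, tokens_to_generate=1500, duration=30.0)
print(f"Generated {done: 6.2f}/{total: 6.2f} seconds")  # Generated  15.00/ 30.00 seconds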
audiocraft/models/lm.py CHANGED
@@ -517,7 +517,7 @@ class LMModel(StreamingModule):
         B, K, T = prompt.shape
         start_offset = T
         print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
-        assert start_offset
+        assert start_offset <= max_gen_len
 
         pattern = self.pattern_provider.get_pattern(max_gen_len)
         # this token is used as default value for codes that are not generated yet
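The lm.py change relaxes the sanity check so that a prompt occupying the entire generation window no longer fails: with the truncation above, `start_offset` (the prompt length T) can now legitimately equal `max_gen_len`. A tiny illustration with assumed numbers:

# Assumed numbers: a 30 s prompt at 50 tokens/s filling a 30 s generation window.
start_offset = 30 * 50   # token length T of the prompt
max_gen_len = 30 * 50    # total tokens requested from the LM

assert start_offset <= max_gen_len   # passes at the boundary; a stricter check (e.g. '<') would reject it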
audiocraft/models/musicgen.py CHANGED
@@ -453,8 +453,8 @@ class MusicGen:
                 print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
         if prompt_tokens is not None:
-
-
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
         # callback = None
         callback = _progress_callback