arjunbahuguna committed on
Commit
b48e9ef
·
verified ·
1 Parent(s): b6a5faf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -2,26 +2,24 @@
2
 
3
  from queue import Queue
4
  from threading import Thread
5
- from typing import Optional
6
  import numpy as np
7
  import torch
8
  from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
9
- from transformers.generation.streamers import BaseStreamer
10
  import gradio as gr
11
  import spaces
12
 
13
  model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
14
  processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
15
-
16
  title = "AI Radio"
17
 
18
- class MusicgenStreamer(BaseStreamer):
19
  def __init__(self, model, device=None, play_steps=10, stride=None, timeout=None):
20
  self.decoder, self.audio_encoder, self.generation_config = model.decoder, model.audio_encoder, model.generation_config
21
- self.device = device if device else model.device
22
  self.play_steps = play_steps
23
- self.stride = stride if stride else np.prod(self.audio_encoder.config.upsampling_ratios) * (play_steps - self.decoder.num_codebooks) // 6
24
- self.token_cache, self.to_yield, self.audio_queue, self.stop_signal, self.timeout = None, 0, Queue(), None, timeout
 
25
 
26
  def apply_delay_pattern_mask(self, input_ids):
27
  _, mask = self.decoder.build_delay_pattern_mask(input_ids[:, :1], pad_token_id=self.generation_config.decoder_start_token_id, max_length=input_ids.shape[-1])
@@ -32,15 +30,15 @@ class MusicgenStreamer(BaseStreamer):
32
  def put(self, value):
33
  if value.shape[0] // self.decoder.num_codebooks > 1:
34
  raise ValueError("MusicgenStreamer only supports batch size 1")
35
- self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1) if self.token_cache else value
36
  if self.token_cache.shape[-1] % self.play_steps == 0:
37
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
38
- self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
39
  self.to_yield += len(audio_values) - self.to_yield - self.stride
40
 
41
  def end(self):
42
  audio_values = self.apply_delay_pattern_mask(self.token_cache) if self.token_cache else np.zeros(self.to_yield)
43
- self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
44
 
45
  def on_finalized_audio(self, audio, stream_end=False):
46
  self.audio_queue.put(audio, timeout=self.timeout)
@@ -52,29 +50,26 @@ class MusicgenStreamer(BaseStreamer):
52
 
53
  def __next__(self):
54
  value = self.audio_queue.get(timeout=self.timeout)
55
- if not isinstance(value, np.ndarray) and value == self.stop_signal:
56
  raise StopIteration()
57
  return value
58
 
59
- sampling_rate, frame_rate = model.audio_encoder.config.sampling_rate, model.audio_encoder.config.frame_rate
60
-
61
  @spaces.GPU()
62
  def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
63
- max_new_tokens, play_steps = int(frame_rate * audio_length_in_s), int(frame_rate * play_steps_in_s)
64
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
65
  if device != model.device:
66
  model.to(device)
67
  if device == "cuda:0":
68
  model.half()
 
 
69
  inputs = processor(text=text_prompt, padding=True, return_tensors="pt")
70
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
71
- generation_kwargs = dict(**inputs.to(device), streamer=streamer, max_new_tokens=max_new_tokens)
72
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
73
- thread.start()
74
  set_seed(seed)
75
  for new_audio in streamer:
76
- print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
77
- yield sampling_rate, new_audio
78
 
79
  demo = gr.Interface(
80
  fn=generate_audio,
 
2
 
3
  from queue import Queue
4
  from threading import Thread
 
5
  import numpy as np
6
  import torch
7
  from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
 
8
  import gradio as gr
9
  import spaces
10
 
11
  model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
12
  processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
 
13
  title = "AI Radio"
14
 
15
+ class MusicgenStreamer:
16
  def __init__(self, model, device=None, play_steps=10, stride=None, timeout=None):
17
  self.decoder, self.audio_encoder, self.generation_config = model.decoder, model.audio_encoder, model.generation_config
18
+ self.device = device or model.device
19
  self.play_steps = play_steps
20
+ self.stride = stride or np.prod(self.audio_encoder.config.upsampling_ratios) * (play_steps - self.decoder.num_codebooks) // 6
21
+ self.token_cache, self.to_yield, self.audio_queue, self.timeout = None, 0, Queue(), timeout
22
+ self.stop_signal = object()
23
 
24
  def apply_delay_pattern_mask(self, input_ids):
25
  _, mask = self.decoder.build_delay_pattern_mask(input_ids[:, :1], pad_token_id=self.generation_config.decoder_start_token_id, max_length=input_ids.shape[-1])
 
30
  def put(self, value):
31
  if value.shape[0] // self.decoder.num_codebooks > 1:
32
  raise ValueError("MusicgenStreamer only supports batch size 1")
33
+ self.token_cache = torch.cat([self.token_cache, value[:, None]], dim=-1) if self.token_cache else value
34
  if self.token_cache.shape[-1] % self.play_steps == 0:
35
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
36
+ self.on_finalized_audio(audio_values[self.to_yield:-self.stride])
37
  self.to_yield += len(audio_values) - self.to_yield - self.stride
38
 
39
  def end(self):
40
  audio_values = self.apply_delay_pattern_mask(self.token_cache) if self.token_cache else np.zeros(self.to_yield)
41
+ self.on_finalized_audio(audio_values[self.to_yield:], stream_end=True)
42
 
43
  def on_finalized_audio(self, audio, stream_end=False):
44
  self.audio_queue.put(audio, timeout=self.timeout)
 
50
 
51
  def __next__(self):
52
  value = self.audio_queue.get(timeout=self.timeout)
53
+ if value is self.stop_signal:
54
  raise StopIteration()
55
  return value
56
 
 
 
57
  @spaces.GPU()
58
  def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
 
59
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
60
  if device != model.device:
61
  model.to(device)
62
  if device == "cuda:0":
63
  model.half()
64
+ max_new_tokens = int(model.audio_encoder.config.frame_rate * audio_length_in_s)
65
+ play_steps = int(model.audio_encoder.config.frame_rate * play_steps_in_s)
66
  inputs = processor(text=text_prompt, padding=True, return_tensors="pt")
67
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
68
+ Thread(target=model.generate, kwargs=dict(**inputs.to(device), streamer=streamer, max_new_tokens=max_new_tokens)).start()
 
 
69
  set_seed(seed)
70
  for new_audio in streamer:
71
+ print(f"Sample of length: {round(new_audio.shape[0] / model.audio_encoder.config.sampling_rate, 2)} seconds")
72
+ yield model.audio_encoder.config.sampling_rate, new_audio
73
 
74
  demo = gr.Interface(
75
  fn=generate_audio,