fffiloni committed (verified)
Commit 160741c · 1 Parent(s): d92bb4d

Update app.py

Files changed (1)
  1. app.py +124 -20
app.py CHANGED
@@ -1,10 +1,19 @@
import gradio as gr
-
+ import torch, os
+ import numpy as np
+ from PIL import Image
+ from huggingface_hub import snapshot_download
import soundfile as sf
from auffusion_pipeline import AuffusionPipeline

pipeline = AuffusionPipeline.from_pretrained("auffusion/auffusion")

+ # ——
+
+ from diffusers import StableDiffusionImg2ImgPipeline
+ from converter import load_wav, mel_spectrogram, normalize_spectrogram, denormalize_spectrogram, Generator, get_mel_spectrogram_from_audio
+ from utils import pad_spec, image_add_color, torch_to_pil, normalize, denormalize
+
def infer(prompt, progress=gr.Progress(track_tqdm=True)):

    prompt = prompt
@@ -14,6 +23,80 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):

    return f"{prompt}.wav"

+ def infer_img2img(prompt, audio_path):
+
+     pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"
+     dtype = torch.float16
+     device = "cuda"
+
+     vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder="vocoder")
+     vocoder = vocoder.to(device=device, dtype=dtype)
+
+     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=dtype)
+     pipe = pipe.to(device)
+
+     width_start, width = 0, 160
+     strength_list = [0.0, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7]
+     prompt = "ambulance siren"  # NOTE: overrides the user-supplied prompt argument
+     seed = 42
+
+     # Loading
+     audio, sampling_rate = load_wav(audio_path)
+     audio, spec = get_mel_spectrogram_from_audio(audio)
+     norm_spec = normalize_spectrogram(spec)
+     norm_spec = norm_spec[:, :, width_start:width_start + width]
+     norm_spec = pad_spec(norm_spec, 1024)
+     norm_spec = normalize(norm_spec)  # normalize to [-1, 1]; the pipeline does not normalize torch.Tensor inputs
+
+     raw_image = image_add_color(torch_to_pil(norm_spec[:, :, :width]))
+
+     # Generation for different strength
+     image_list = []
+     audio_list = []
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     for strength in strength_list:
+         with torch.autocast("cuda"):
+             output_spec = pipe(
+                 prompt=prompt, image=norm_spec, num_inference_steps=100, generator=generator, output_type="pt", strength=strength, guidance_scale=7.5
+             ).images[0]
+
+         # add to image_list
+         output_spec = output_spec[:, :, :width]
+         output_spec_image = torch_to_pil(output_spec)
+         color_output_spec_image = image_add_color(output_spec_image)
+         image_list.append(color_output_spec_image)
+
+         # add to audio_list
+         denorm_spec = denormalize_spectrogram(output_spec)
+         denorm_spec_audio = vocoder.inference(denorm_spec)
+         audio_list.append(denorm_spec_audio)
+
+     # Display
+
+     # Concat image with different strength & add interval between images with black color
+     concat_image_list = []
+     for i in range(len(image_list)):
+         if i == len(image_list) - 1:
+             concat_image_list.append(np.array(image_list[i]))
+         else:
+             concat_image_list.append(np.concatenate([np.array(image_list[i]), np.ones((256, 20, 3)) * 0], axis=1))
+
+     concat_image = np.concatenate(concat_image_list, axis=1)
+     concat_image = Image.fromarray(np.uint8(concat_image))
+
+     ### Concat audio
+     concat_audio_list = [np.concatenate([audio, np.zeros((1, 16000))], axis=1) for audio in audio_list]
+     concat_audio = np.concatenate(concat_audio_list, axis=1)
+
+     print("audio_path:", audio_path)
+     print("width_start:", width_start, "width:", width)
+     print("text prompt:", prompt)
+     print("strength_list:", strength_list)
+
+     return concat_audio
+
css="""
div#col-container{
    margin: 0 auto;
@@ -40,24 +123,45 @@ with gr.Blocks(css=css) as demo:
        </a>
    </div>
    """)
-
-     prompt = gr.Textbox(label="Prompt")
-     submit_btn = gr.Button("Submit")
-     audio_out = gr.Audio(label="Audio Result")
-
-     gr.Examples(
-         examples = [
-             "Rolling thunder with lightning strikes",
-             "Two gunshots followed by birds chirping",
-             "A train whistle blowing in the distance"
-         ],
-         inputs = [prompt]
-     )
-
-     submit_btn.click(
-         fn = infer,
-         inputs = [prompt],
-         outputs = [audio_out]
-     )
+     with gr.Tab("Text-to-Audio"):
+         prompt = gr.Textbox(label="Prompt")
+         submit_btn = gr.Button("Submit")
+         audio_out = gr.Audio(label="Audio Result")
+
+         gr.Examples(
+             examples = [
+                 "Rolling thunder with lightning strikes",
+                 "Two gunshots followed by birds chirping",
+                 "A train whistle blowing in the distance"
+             ],
+             inputs = [prompt]
+         )
+
+         submit_btn.click(
+             fn = infer,
+             inputs = [prompt],
+             outputs = [audio_out]
+         )
+
+     with gr.Tab("Audio-to-Audio"):
+         prompt_img2img = gr.Textbox(label="Prompt")
+         audio_in_img2img = gr.Audio(label="Audio Reference", type="filepath")
+         submit_btn_img2img = gr.Button("Submit")
+         audio_out_img2img = gr.Audio(label="Audio Result")
+
+         gr.Examples(
+             examples = [
+                 "Rolling thunder with lightning strikes",
+                 "Two gunshots followed by birds chirping",
+                 "A train whistle blowing in the distance"
+             ],
+             inputs = [prompt_img2img]
+         )
+
+         submit_btn_img2img.click(
+             fn = infer_img2img,
+             inputs = [prompt_img2img, audio_in_img2img],
+             outputs = [audio_out_img2img]
+         )

demo.queue().launch(show_api=False, show_error=True)
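
The infer_img2img function added above returns concat_audio, a NumPy array of shape (1, N), directly to a gr.Audio output component, which accepts either a file path or a (sample_rate, samples) tuple. The sketch below shows one way the return value could be adapted; the wrapper name and the 16 kHz rate are assumptions (the rate is inferred from the one-second np.zeros((1, 16000)) silence gaps inserted between clips), not part of the commit.

import numpy as np

# Hypothetical wrapper, not part of the commit: converts infer_img2img's raw
# array into the (sample_rate, samples) tuple that a gr.Audio output accepts.
VOCODER_SR = 16000  # assumed vocoder output rate, inferred from the 16000-sample gaps above

def infer_img2img_for_gradio(prompt, audio_path):
    concat_audio = infer_img2img(prompt, audio_path)          # shape (1, N)
    samples = np.asarray(concat_audio, dtype=np.float32)[0]   # flatten to (N,)
    return (VOCODER_SR, samples)

If such a wrapper were used, the Audio-to-Audio tab would register fn=infer_img2img_for_gradio with the same inputs=[prompt_img2img, audio_in_img2img] wiring shown above.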