Spaces:
Running
Running
fix: make update_audio fn async to avoid server freeze
Browse files- src/audio/audio_generator.py +62 -44
- src/main.py +0 -1
src/audio/audio_generator.py
CHANGED
@@ -1,18 +1,14 @@
|
|
1 |
import asyncio
|
2 |
from google.genai import types
|
3 |
import wave
|
4 |
-
import queue
|
5 |
import logging
|
6 |
import io
|
7 |
-
import
|
8 |
from config import settings
|
9 |
from services.google import GoogleClientFactory
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
17 |
if user_hash in sessions:
|
18 |
logger.info(
|
@@ -44,7 +40,7 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
|
44 |
logger.info(
|
45 |
f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
|
46 |
)
|
47 |
-
sessions[user_hash] = {"session": session, "queue":
|
48 |
|
49 |
|
50 |
async def change_music_tone(user_hash: str, new_tone):
|
@@ -75,7 +71,7 @@ async def receive_audio(session, user_hash):
|
|
75 |
audio_data = message.server_content.audio_chunks[0].data
|
76 |
queue = sessions[user_hash]["queue"]
|
77 |
# audio_data is already bytes (raw PCM)
|
78 |
-
await
|
79 |
await asyncio.sleep(10**-12)
|
80 |
except Exception as e:
|
81 |
logger.error(f"Error in receive_audio: {e}")
|
@@ -102,44 +98,66 @@ async def cleanup_music_session(user_hash: str):
|
|
102 |
del sessions[user_hash]
|
103 |
|
104 |
|
105 |
-
def update_audio(user_hash):
|
106 |
-
"""
|
|
|
|
|
|
|
107 |
if user_hash == "":
|
108 |
return
|
109 |
|
110 |
logger.info(f"Starting audio update loop for user hash: {user_hash}")
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
from google.genai import types
|
3 |
import wave
|
|
|
4 |
import logging
|
5 |
import io
|
6 |
+
import gradio as gr
|
7 |
from config import settings
|
8 |
from services.google import GoogleClientFactory
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
|
|
|
|
|
|
12 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
13 |
if user_hash in sessions:
|
14 |
logger.info(
|
|
|
40 |
logger.info(
|
41 |
f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
|
42 |
)
|
43 |
+
sessions[user_hash] = {"session": session, "queue": asyncio.Queue()}
|
44 |
|
45 |
|
46 |
async def change_music_tone(user_hash: str, new_tone):
|
|
|
71 |
audio_data = message.server_content.audio_chunks[0].data
|
72 |
queue = sessions[user_hash]["queue"]
|
73 |
# audio_data is already bytes (raw PCM)
|
74 |
+
await queue.put(audio_data)
|
75 |
await asyncio.sleep(10**-12)
|
76 |
except Exception as e:
|
77 |
logger.error(f"Error in receive_audio: {e}")
|
|
|
98 |
del sessions[user_hash]
|
99 |
|
100 |
|
101 |
+
async def update_audio(user_hash: str, request: gr.Request):
|
102 |
+
"""
|
103 |
+
Continuously stream audio from the queue as WAV bytes, and clean up
|
104 |
+
when the user disconnects.
|
105 |
+
"""
|
106 |
if user_hash == "":
|
107 |
return
|
108 |
|
109 |
logger.info(f"Starting audio update loop for user hash: {user_hash}")
|
110 |
+
try:
|
111 |
+
while True:
|
112 |
+
if await request.request.is_disconnected():
|
113 |
+
logger.info(f"Client disconnected for user hash {user_hash}.")
|
114 |
+
break
|
115 |
+
|
116 |
+
if user_hash not in sessions:
|
117 |
+
await asyncio.sleep(0.5)
|
118 |
+
continue
|
119 |
+
|
120 |
+
try:
|
121 |
+
queue = sessions[user_hash]["queue"]
|
122 |
+
pcm_data = await asyncio.wait_for(queue.get(), timeout=1.0)
|
123 |
+
except asyncio.TimeoutError:
|
124 |
+
continue # Check for disconnect again
|
125 |
+
except (KeyError, AttributeError):
|
126 |
+
logger.warning(
|
127 |
+
f"Session or queue for {user_hash} not found. Stopping audio loop."
|
128 |
+
)
|
129 |
+
break
|
130 |
+
|
131 |
+
if not isinstance(pcm_data, bytes):
|
132 |
+
logger.warning(
|
133 |
+
f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
|
134 |
+
)
|
135 |
+
continue
|
136 |
+
|
137 |
+
# Lyria provides stereo, 16-bit PCM at 48kHz.
|
138 |
+
# Ensure the number of bytes is consistent with stereo 16-bit audio.
|
139 |
+
# Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
|
140 |
+
# If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
|
141 |
+
# it might indicate an incomplete chunk or an issue.
|
142 |
+
bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
|
143 |
+
if len(pcm_data) % bytes_per_frame != 0:
|
144 |
+
logger.warning(
|
145 |
+
f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
|
146 |
+
f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
|
147 |
+
)
|
148 |
+
# Depending on strictness, you might want to skip this chunk:
|
149 |
+
# continue
|
150 |
+
|
151 |
+
wav_buffer = io.BytesIO()
|
152 |
+
with wave.open(wav_buffer, "wb") as wf:
|
153 |
+
wf.setnchannels(NUM_CHANNELS)
|
154 |
+
wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
|
155 |
+
wf.setframerate(SAMPLE_RATE)
|
156 |
+
wf.writeframes(pcm_data)
|
157 |
+
wav_bytes = wav_buffer.getvalue()
|
158 |
+
yield wav_bytes
|
159 |
+
finally:
|
160 |
+
logger.info(
|
161 |
+
f"Audio update loop finished for {user_hash}. Cleaning up music session."
|
162 |
+
)
|
163 |
+
await cleanup_music_session(user_hash)
|
src/main.py
CHANGED
@@ -357,7 +357,6 @@ with gr.Blocks(
|
|
357 |
outputs=[game_text, game_image, game_choices, custom_choice],
|
358 |
)
|
359 |
|
360 |
-
demo.unload(cleanup_music_session)
|
361 |
demo.load(
|
362 |
fn=generate_user_hash,
|
363 |
inputs=[],
|
|
|
357 |
outputs=[game_text, game_image, game_choices, custom_choice],
|
358 |
)
|
359 |
|
|
|
360 |
demo.load(
|
361 |
fn=generate_user_hash,
|
362 |
inputs=[],
|