gsavin commited on
Commit
821d1b2
·
1 Parent(s): 7150fb6

fix: make update_audio fn async to avoid server freeze

Browse files
Files changed (2) hide show
  1. src/audio/audio_generator.py +62 -44
  2. src/main.py +0 -1
src/audio/audio_generator.py CHANGED
@@ -1,18 +1,14 @@
1
  import asyncio
2
  from google.genai import types
3
  import wave
4
- import queue
5
  import logging
6
  import io
7
- import time
8
  from config import settings
9
  from services.google import GoogleClientFactory
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
-
14
-
15
-
16
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
17
  if user_hash in sessions:
18
  logger.info(
@@ -44,7 +40,7 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
44
  logger.info(
45
  f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
46
  )
47
- sessions[user_hash] = {"session": session, "queue": queue.Queue()}
48
 
49
 
50
  async def change_music_tone(user_hash: str, new_tone):
@@ -75,7 +71,7 @@ async def receive_audio(session, user_hash):
75
  audio_data = message.server_content.audio_chunks[0].data
76
  queue = sessions[user_hash]["queue"]
77
  # audio_data is already bytes (raw PCM)
78
- await asyncio.to_thread(queue.put, audio_data)
79
  await asyncio.sleep(10**-12)
80
  except Exception as e:
81
  logger.error(f"Error in receive_audio: {e}")
@@ -102,44 +98,66 @@ async def cleanup_music_session(user_hash: str):
102
  del sessions[user_hash]
103
 
104
 
105
- def update_audio(user_hash):
106
- """Continuously stream audio from the queue as WAV bytes."""
 
 
 
107
  if user_hash == "":
108
  return
109
 
110
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
111
- while True:
112
- if user_hash not in sessions:
113
- time.sleep(0.5)
114
- continue
115
- queue = sessions[user_hash]["queue"]
116
- pcm_data = queue.get() # This is raw PCM audio bytes
117
-
118
- if not isinstance(pcm_data, bytes):
119
- logger.warning(
120
- f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
121
- )
122
- continue
123
-
124
- # Lyria provides stereo, 16-bit PCM at 48kHz.
125
- # Ensure the number of bytes is consistent with stereo 16-bit audio.
126
- # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
127
- # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
128
- # it might indicate an incomplete chunk or an issue.
129
- bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
130
- if len(pcm_data) % bytes_per_frame != 0:
131
- logger.warning(
132
- f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
133
- f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
134
- )
135
- # Depending on strictness, you might want to skip this chunk:
136
- # continue
137
-
138
- wav_buffer = io.BytesIO()
139
- with wave.open(wav_buffer, "wb") as wf:
140
- wf.setnchannels(NUM_CHANNELS)
141
- wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
142
- wf.setframerate(SAMPLE_RATE)
143
- wf.writeframes(pcm_data)
144
- wav_bytes = wav_buffer.getvalue()
145
- yield wav_bytes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  from google.genai import types
3
  import wave
 
4
  import logging
5
  import io
6
+ import gradio as gr
7
  from config import settings
8
  from services.google import GoogleClientFactory
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
12
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
13
  if user_hash in sessions:
14
  logger.info(
 
40
  logger.info(
41
  f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
42
  )
43
+ sessions[user_hash] = {"session": session, "queue": asyncio.Queue()}
44
 
45
 
46
  async def change_music_tone(user_hash: str, new_tone):
 
71
  audio_data = message.server_content.audio_chunks[0].data
72
  queue = sessions[user_hash]["queue"]
73
  # audio_data is already bytes (raw PCM)
74
+ await queue.put(audio_data)
75
  await asyncio.sleep(10**-12)
76
  except Exception as e:
77
  logger.error(f"Error in receive_audio: {e}")
 
98
  del sessions[user_hash]
99
 
100
 
101
+ async def update_audio(user_hash: str, request: gr.Request):
102
+ """
103
+ Continuously stream audio from the queue as WAV bytes, and clean up
104
+ when the user disconnects.
105
+ """
106
  if user_hash == "":
107
  return
108
 
109
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
110
+ try:
111
+ while True:
112
+ if await request.request.is_disconnected():
113
+ logger.info(f"Client disconnected for user hash {user_hash}.")
114
+ break
115
+
116
+ if user_hash not in sessions:
117
+ await asyncio.sleep(0.5)
118
+ continue
119
+
120
+ try:
121
+ queue = sessions[user_hash]["queue"]
122
+ pcm_data = await asyncio.wait_for(queue.get(), timeout=1.0)
123
+ except asyncio.TimeoutError:
124
+ continue # Check for disconnect again
125
+ except (KeyError, AttributeError):
126
+ logger.warning(
127
+ f"Session or queue for {user_hash} not found. Stopping audio loop."
128
+ )
129
+ break
130
+
131
+ if not isinstance(pcm_data, bytes):
132
+ logger.warning(
133
+ f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
134
+ )
135
+ continue
136
+
137
+ # Lyria provides stereo, 16-bit PCM at 48kHz.
138
+ # Ensure the number of bytes is consistent with stereo 16-bit audio.
139
+ # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
140
+ # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
141
+ # it might indicate an incomplete chunk or an issue.
142
+ bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
143
+ if len(pcm_data) % bytes_per_frame != 0:
144
+ logger.warning(
145
+ f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
146
+ f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
147
+ )
148
+ # Depending on strictness, you might want to skip this chunk:
149
+ # continue
150
+
151
+ wav_buffer = io.BytesIO()
152
+ with wave.open(wav_buffer, "wb") as wf:
153
+ wf.setnchannels(NUM_CHANNELS)
154
+ wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
155
+ wf.setframerate(SAMPLE_RATE)
156
+ wf.writeframes(pcm_data)
157
+ wav_bytes = wav_buffer.getvalue()
158
+ yield wav_bytes
159
+ finally:
160
+ logger.info(
161
+ f"Audio update loop finished for {user_hash}. Cleaning up music session."
162
+ )
163
+ await cleanup_music_session(user_hash)
src/main.py CHANGED
@@ -357,7 +357,6 @@ with gr.Blocks(
357
  outputs=[game_text, game_image, game_choices, custom_choice],
358
  )
359
 
360
- demo.unload(cleanup_music_session)
361
  demo.load(
362
  fn=generate_user_hash,
363
  inputs=[],
 
357
  outputs=[game_text, game_image, game_choices, custom_choice],
358
  )
359
 
 
360
  demo.load(
361
  fn=generate_user_hash,
362
  inputs=[],