LiKenun committed on
Commit 1c1b97a · 1 Parent(s): 65e848c

Move environment variable querying code out of the inference functions

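Below is a minimal sketch (not code from this repository) of the pattern the commit applies to every tab: the tab-construction function reads its environment variable once and binds the resulting model ID into the Gradio callback with functools.partial, so the inference function itself no longer calls getenv. The EXAMPLE_MODEL variable and the infer/create_example_tab names are hypothetical placeholders.

from functools import partial
from os import getenv

import gradio as gr


def infer(model: str, text: str) -> str:
    # The inference function receives the model ID as a parameter
    # instead of querying the environment itself.
    return f"[{model}] processed: {text}"


def create_example_tab():
    # Query the environment once, at tab-construction time (hypothetical variable name).
    model_id = getenv("EXAMPLE_MODEL", "some-org/some-model")
    prompt = gr.Textbox(label="Prompt")
    run_button = gr.Button("Run")
    output = gr.Textbox(label="Output")
    # partial pre-binds the model ID; Gradio supplies only the user input.
    run_button.click(fn=partial(infer, model_id), inputs=prompt, outputs=output)


with gr.Blocks() as demo:
    create_example_tab()
# demo.launch()  # Uncomment to serve the sketch locally.

Binding the model ID where the tab is built keeps the inference functions reusable with any model and confines environment lookups to the UI-setup layer, which is the change made in each file below.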
automatic_speech_recognition.py CHANGED
@@ -4,7 +4,7 @@ from os import getenv, path, unlink
 import gradio as gr
 from utils import save_audio_to_temp_file, get_model_sample_rate, request_audio
 
-def automatic_speech_recognition(client: InferenceClient, audio: tuple[int, bytes]) -> str:
+def automatic_speech_recognition(client: InferenceClient, model: str, audio: tuple[int, bytes]) -> str:
     """Transcribe audio to text using Hugging Face Inference API.
 
     This function converts speech audio into text transcription. The audio is
@@ -13,6 +13,7 @@ def automatic_speech_recognition(client: InferenceClient, audio: tuple[int, byte
 
     Args:
         client: Hugging Face InferenceClient instance for API calls.
+        model: Hugging Face model ID to use for automatic speech recognition.
         audio: Tuple containing:
             - int: Sample rate of the input audio (e.g., 44100 Hz)
             - bytes: Raw audio data as bytes
@@ -21,18 +22,15 @@ def automatic_speech_recognition(client: InferenceClient, audio: tuple[int, byte
         String containing the transcribed text from the audio.
 
     Note:
-        - The model ID is determined by the AUDIO_TRANSCRIPTION_MODEL environment variable.
         - Audio is automatically resampled to match the model's expected sample rate.
         - Audio is saved as a WAV file for InferenceClient compatibility.
         - Automatically cleans up temporary files after transcription.
-        - Uses openai/whisper-large-v3 or similar ASR models.
     """
     temp_file_path = None
     try:
-        model_id = getenv("AUDIO_TRANSCRIPTION_MODEL")
-        sample_rate = get_model_sample_rate(model_id)
+        sample_rate = get_model_sample_rate(model)
         temp_file_path = save_audio_to_temp_file(sample_rate, audio)
-        result = client.automatic_speech_recognition(temp_file_path, model=model_id)
+        result = client.automatic_speech_recognition(temp_file_path, model=model)
         return result["text"]
     finally:
         if temp_file_path and path.exists(temp_file_path): # Clean up temporary file.
@@ -54,6 +52,7 @@ def create_asr_tab(client: InferenceClient):
     Args:
         client: Hugging Face InferenceClient instance to pass to the automatic_speech_recognition function.
     """
+    model_id = getenv("AUDIO_TRANSCRIPTION_MODEL")
     gr.Markdown("Transcribe audio to text.")
     audio_transcription_url_input = gr.Textbox(label="Audio URL")
     audio_transcription_audio_request_button = gr.Button("Get Audio")
@@ -66,7 +65,7 @@ def create_asr_tab(client: InferenceClient):
     audio_transcription_generate_button = gr.Button("Transcribe")
     audio_transcription_output = gr.Textbox(label="Text")
     audio_transcription_generate_button.click(
-        fn=partial(automatic_speech_recognition, client),
+        fn=partial(automatic_speech_recognition, client, model_id),
         inputs=audio_transcription_audio_input,
         outputs=audio_transcription_output
     )
chatbot.py CHANGED
@@ -8,7 +8,7 @@ _chatbot = None
 _tokenizer = None
 _is_seq2seq = None
 
-def get_chatbot():
+def get_chatbot(model: str):
     """Get or create the chatbot model instance.
 
     This function implements a singleton pattern to load and cache the chatbot
@@ -16,6 +16,9 @@ def get_chatbot():
     models) and sequence-to-sequence models (like BlenderBot). The model type
     is automatically detected from the model configuration.
 
+    Args:
+        model: Hugging Face model ID to use for the chatbot.
+
     Returns:
         Tuple containing:
             - Model: The loaded transformer model (AutoModelForCausalLM or AutoModelForSeq2SeqLM)
@@ -23,7 +26,6 @@ def get_chatbot():
             - bool: Whether the model is a seq2seq model (True) or causal LM (False)
 
     Note:
-        - The model ID is determined by the CHAT_MODEL environment variable.
         - Models are loaded with safetensors for secure loading.
         - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
         - Sets pad_token to eos_token if pad_token is not configured.
@@ -31,15 +33,14 @@ def get_chatbot():
     """
     global _chatbot, _tokenizer, _is_seq2seq
     if _chatbot is None:
-        model_id = getenv("CHAT_MODEL")
         device = get_pytorch_device()
-        _tokenizer = AutoTokenizer.from_pretrained(model_id)
+        _tokenizer = AutoTokenizer.from_pretrained(model)
 
         # Try to determine model type and load accordingly
         # Check tokenizer config or model config to see if it's seq2seq
         try:
             from transformers import AutoConfig
-            config = AutoConfig.from_pretrained(model_id)
+            config = AutoConfig.from_pretrained(model)
             # Seq2seq models have encoder/decoder, causal LMs don't
             _is_seq2seq = hasattr(config, 'is_encoder_decoder') and config.is_encoder_decoder
         except Exception:
@@ -48,12 +49,12 @@ def get_chatbot():
 
         if _is_seq2seq:
             _chatbot = AutoModelForSeq2SeqLM.from_pretrained(
-                model_id,
+                model,
                 use_safetensors=True
             ).to(device)
         else:
             _chatbot = AutoModelForCausalLM.from_pretrained(
-                model_id,
+                model,
                 use_safetensors=True
             ).to(device)
 
@@ -64,7 +65,7 @@ def get_chatbot():
     return _chatbot, _tokenizer, _is_seq2seq
 
 @spaces_gpu
-def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
+def chat(model: str, message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
     """Generate a chatbot response given a user message and conversation history.
 
     This function handles conversation with AI chatbots, supporting both modern
@@ -73,6 +74,7 @@ def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, li
     formats inputs appropriately based on the model type.
 
     Args:
+        model: Hugging Face model ID to use for the chatbot.
         message: The user's current message as a string.
         conversation_history: Optional list of previous conversation messages.
             Each message is a dict with "role" ("user" or "assistant") and "content".
@@ -92,7 +94,7 @@ def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, li
         - Automatically manages conversation context and history
         - Extracts only newly generated text for causal LMs with chat templates
     """
-    model, tokenizer, is_seq2seq = get_chatbot()
+    model_instance, tokenizer, is_seq2seq = get_chatbot(model)
 
     # Initialize conversation history if this is the first message
     if conversation_history is None:
@@ -141,7 +143,7 @@ def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, li
         inputs = tokenizer(dialogue_text, return_tensors="pt", truncation=True, max_length=1024).to(device)
 
     # Generate response
-    outputs = model.generate(
+    outputs = model_instance.generate(
         **inputs,
         max_new_tokens=256,
         do_sample=True,
@@ -188,6 +190,7 @@ def create_chatbot_tab():
     and manages the conversion between Gradio's chat format and the internal
     conversation history format.
     """
+    model_id = getenv("CHAT_MODEL")
    gr.Markdown("Have a conversation with an AI chatbot.")
    chatbot_history = gr.State(value=None) # Store the conversation history.
    chatbot_output = gr.Chatbot(label="Conversation")
@@ -214,7 +217,7 @@ def create_chatbot_tab():
         """
         if not message.strip():
             return history, conversation_state, ""
-        response, updated_conversation = chat(message, conversation_state) # Get response from chatbot.
+        response, updated_conversation = chat(model_id, message, conversation_state) # Get response from chatbot.
         if history is None: # Update Gradio chat history format: list of [user_message, bot_message] pairs.
             history = []
         history.append([message, response])
image_classification.py CHANGED
@@ -8,7 +8,7 @@ from pandas import DataFrame
 from utils import save_image_to_temp_file, request_image
 
 
-def image_classification(client: InferenceClient, image: Image) -> DataFrame:
+def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
     """Classify an image using Hugging Face Inference API.
 
     This function classifies a recyclable item image into categories:
@@ -18,6 +18,7 @@ def image_classification(client: InferenceClient, image: Image) -> DataFrame:
 
     Args:
         client: Hugging Face InferenceClient instance for API calls.
+        model: Hugging Face model ID to use for image classification.
         image: PIL Image object to classify.
 
     Returns:
@@ -26,14 +27,12 @@ def image_classification(client: InferenceClient, image: Image) -> DataFrame:
             - Probability: The confidence score as a percentage string (e.g., "95.23%")
 
     Note:
-        - The model ID is determined by the IMAGE_CLASSIFICATION_MODEL environment variable.
-        - Uses Trash-Net model for recyclable item classification.
         - Automatically cleans up temporary files after classification.
         - Temporary file is created with format preservation if possible.
     """
     try:
         temp_file_path = save_image_to_temp_file(image) # Needed because InferenceClient does not accept PIL Images directly.
-        classifications = client.image_classification(temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL"))
+        classifications = client.image_classification(temp_file_path, model=model)
         return pd.DataFrame({
             "Label": classification.label,
             "Probability": f"{classification.score:.2%}"
@@ -60,6 +59,7 @@ def create_image_classification_tab(client: InferenceClient):
     Args:
         client: Hugging Face InferenceClient instance to pass to the image_classification function.
     """
+    model_id = getenv("IMAGE_CLASSIFICATION_MODEL")
     gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
     image_classification_url_input = gr.Textbox(label="Image URL")
     image_classification_image_request_button = gr.Button("Get Image")
@@ -72,7 +72,7 @@ def create_image_classification_tab(client: InferenceClient):
     image_classification_button = gr.Button("Classify")
     image_classification_output = gr.Dataframe(label="Classification", headers=["Label", "Probability"], interactive=False)
     image_classification_button.click(
-        fn=partial(image_classification, client),
+        fn=partial(image_classification, client, model_id),
         inputs=image_classification_image_input,
         outputs=image_classification_output
     )
image_to_text.py CHANGED
@@ -1,4 +1,5 @@
 import gc
+from functools import partial
 from os import getenv
 import gradio as gr
 from PIL.Image import Image
@@ -7,7 +8,7 @@ from utils import get_pytorch_device, spaces_gpu, request_image
 
 
 @spaces_gpu
-def image_to_text(image: Image) -> list[str]:
+def image_to_text(model: str, image: Image) -> list[str]:
     """Generate text captions for an image using BLIP model.
 
     This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
@@ -15,29 +16,28 @@ def image_to_text(image: Image) -> list[str]:
     loaded, inference is performed, and then cleaned up to free GPU memory.
 
     Args:
+        model: Hugging Face model ID to use for image captioning.
         image: PIL Image object to generate captions for.
 
     Returns:
         List of string captions describing the image.
 
     Note:
-        - The model ID is determined by the IMAGE_TO_TEXT_MODEL environment variable.
         - Uses safetensors for secure model loading.
         - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
         - Cleans up model and GPU memory after inference.
         - Uses beam search with 3 beams, max length 20, min length 5.
     """
-    image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
     pytorch_device = get_pytorch_device()
-    processor = AutoProcessor.from_pretrained(image_to_text_model_id)
-    model = BlipForConditionalGeneration.from_pretrained(
-        image_to_text_model_id,
+    processor = AutoProcessor.from_pretrained(model)
+    model_instance = BlipForConditionalGeneration.from_pretrained(
+        model,
         use_safetensors=True # Use safetensors to avoid torch.load restriction.
     ).to(pytorch_device)
     inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
-    generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)
+    generated_ids = model_instance.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)
     results = processor.batch_decode(generated_ids, skip_special_tokens=True)
-    del model, inputs
+    del model_instance, inputs
     gc.collect()
     return results
 
@@ -51,6 +51,7 @@ def create_image_to_text_tab():
         - Image preview component
         - Caption button and output list
     """
+    model_id = getenv("IMAGE_TO_TEXT_MODEL")
     gr.Markdown("Generate a text description of an image.")
     image_to_text_url_input = gr.Textbox(label="Image URL")
     image_to_text_image_request_button = gr.Button("Get Image")
@@ -63,7 +64,7 @@ def create_image_to_text_tab():
     image_to_text_button = gr.Button("Caption")
     image_to_text_output = gr.List(label="Captions", headers=["Caption"])
     image_to_text_button.click(
-        fn=image_to_text,
+        fn=partial(image_to_text, model_id),
         inputs=image_to_text_image_input,
         outputs=image_to_text_output
     )
text_to_image.py CHANGED
@@ -5,20 +5,18 @@ from PIL.Image import Image
 from huggingface_hub import InferenceClient
 
 
-def text_to_image(client: InferenceClient, prompt: str) -> Image:
+def text_to_image(client: InferenceClient, model: str, prompt: str) -> Image:
     """Generate an image from a text prompt using Hugging Face Inference API.
 
     Args:
         client: Hugging Face InferenceClient instance for API calls.
+        model: Hugging Face model ID to use for text-to-image generation.
         prompt: Text description of the desired image.
 
     Returns:
         PIL Image object representing the generated image.
-
-    Note:
-        The model to use is determined by the TEXT_TO_IMAGE_MODEL environment variable.
     """
-    return client.text_to_image(prompt, model=getenv("TEXT_TO_IMAGE_MODEL"))
+    return client.text_to_image(prompt, model=model)
 
 
 def create_text_to_image_tab(client: InferenceClient):
@@ -30,12 +28,13 @@ def create_text_to_image_tab(client: InferenceClient):
     Args:
         client: Hugging Face InferenceClient instance to pass to the text_to_image function.
     """
+    model_id = getenv("TEXT_TO_IMAGE_MODEL")
     gr.Markdown("Generate an image from a text prompt.")
     text_to_image_prompt = gr.Textbox(label="Prompt")
     text_to_image_generate_button = gr.Button("Generate")
     text_to_image_output = gr.Image(label="Image", type="pil")
     text_to_image_generate_button.click(
-        fn=partial(text_to_image, client),
+        fn=partial(text_to_image, client, model_id),
         inputs=text_to_image_prompt,
         outputs=text_to_image_output
     )
text_to_speech.py CHANGED
@@ -1,4 +1,5 @@
 import gc
+from functools import partial
 from os import getenv
 import gradio as gr
 from transformers import pipeline
@@ -6,7 +7,7 @@ from utils import spaces_gpu
 
 
 @spaces_gpu
-def text_to_speech(text: str) -> tuple[int, bytes]:
+def text_to_speech(model: str, text: str) -> tuple[int, bytes]:
     """Convert text to speech audio using a TTS (Text-to-Speech) model.
 
     This function uses a transformer pipeline to generate speech audio from
@@ -14,6 +15,7 @@ def text_to_speech(text: str) -> tuple[int, bytes]:
     up to free GPU memory.
 
     Args:
+        model: Hugging Face model ID to use for text-to-speech.
         text: Input text string to convert to speech.
 
     Returns:
@@ -22,7 +24,6 @@ def text_to_speech(text: str) -> tuple[int, bytes]:
             - bytes: Raw audio data as bytes
 
     Note:
-        - The model ID is determined by the TEXT_TO_SPEECH_MODEL environment variable.
         - Uses safetensors for secure model loading.
         - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
         - Cleans up model and GPU memory after inference.
@@ -30,7 +31,7 @@ def text_to_speech(text: str) -> tuple[int, bytes]:
     """
     narrator = pipeline(
         "text-to-speech",
-        getenv("TEXT_TO_SPEECH_MODEL"),
+        model,
         model_kwargs={"use_safetensors": True} # Use safetensors to avoid torch.load restriction.
     )
     result = narrator(text)
@@ -45,12 +46,13 @@ def create_text_to_speech_tab():
     This function sets up all UI components for text-to-speech generation,
     including input textbox, generate button, and output audio player.
     """
+    model_id = getenv("TEXT_TO_SPEECH_MODEL")
     gr.Markdown("Generate speech from text.")
     text_to_speech_text = gr.Textbox(label="Text")
     text_to_speech_generate_button = gr.Button("Generate")
     text_to_speech_output = gr.Audio(label="Speech")
     text_to_speech_generate_button.click(
-        fn=text_to_speech,
+        fn=partial(text_to_speech, model_id),
         inputs=text_to_speech_text,
         outputs=text_to_speech_output
     )