Ffftdtd5dtft commited on
Commit
1811735
verified
1 Parent(s): 94334a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -48,13 +48,13 @@ def store_special_tokens(tokenizer, model_name):
48
  def load_special_tokens(tokenizer, model_name):
49
  special_tokens = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
50
  if special_tokens:
51
- tokenizer.pad_token = special_tokens.get('pad_token')
52
  tokenizer.pad_token_id = int(special_tokens.get('pad_token_id', -1))
53
- tokenizer.eos_token = special_tokens.get('eos_token')
54
  tokenizer.eos_token_id = int(special_tokens.get('eos_token_id', -1))
55
- tokenizer.unk_token = special_tokens.get('unk_token')
56
  tokenizer.unk_token_id = int(special_tokens.get('unk_token_id', -1))
57
- tokenizer.bos_token = special_tokens.get('bos_token')
58
  tokenizer.bos_token_id = int(special_tokens.get('bos_token_id', -1))
59
 
60
  def train_and_store_transformers_model(model_name, data):
@@ -299,7 +299,7 @@ def generate_response_from_prompt(prompt, model_name="google/flan-t5-xl"):
299
  responses = generate_random_response([prompt], generator)
300
  return responses[0]
301
 
302
- def generate_image_from_prompt(prompt, image_type, model_name):
303
  if image_type == "diffusers":
304
  image = generate_diffusers_image_from_redis(model_name, prompt)
305
  elif image_type == "stable-diffusion":
@@ -321,9 +321,10 @@ def gradio_app():
321
  with gr.Row():
322
  prompt_image = gr.Textbox(label="Prompt de Imagen")
323
  image_type = gr.Dropdown(["diffusers", "stable-diffusion", "img2img"], label="Tipo de Imagen")
 
324
  image_output = gr.Image(label="Imagen Generada")
325
  image_button = gr.Button("Generar Imagen")
326
- image_button.click(generate_image_from_prompt, inputs=[prompt_image, image_type], outputs=image_output)
327
 
328
  gr.Markdown("## Generaci贸n de Video")
329
  with gr.Row():
@@ -334,10 +335,11 @@ def gradio_app():
334
 
335
  gr.Markdown("## Generaci贸n de Audio con MusicGen")
336
  with gr.Row():
 
337
  text_prompts_audio = gr.Textbox(label="Prompts de Audio")
338
  audio_output = gr.Audio(label="Audio Generado")
339
  audio_button = gr.Button("Generar Audio")
340
- audio_button.click(generate_musicgen_audio_from_redis, inputs=text_prompts_audio, outputs=audio_output)
341
 
342
  gr.Markdown("## Transcripci贸n de Audio con Whisper")
343
  with gr.Row():
@@ -348,19 +350,21 @@ def gradio_app():
348
 
349
  gr.Markdown("## Traducci贸n de Texto")
350
  with gr.Row():
 
351
  text_input = gr.Textbox(label="Texto a Traducir")
352
  translation_output = gr.Textbox(label="Traducci贸n")
353
  src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
354
  tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
355
  translate_button = gr.Button("Traducir Texto")
356
- translate_button.click(translate_text_from_redis, inputs=[text_input, src_lang_input, tgt_lang_input], outputs=translation_output)
357
 
358
  gr.Markdown("## Resumen de Texto")
359
  with gr.Row():
 
360
  text_to_summarize = gr.Textbox(label="Texto para Resumir")
361
  summary_output = gr.Textbox(label="Resumen")
362
  summarize_button = gr.Button("Generar Resumen")
363
- summarize_button.click(summarize_text_from_redis, inputs=text_to_summarize, outputs=summary_output)
364
 
365
  app.launch()
366
 
 
48
  def load_special_tokens(tokenizer, model_name):
49
  special_tokens = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
50
  if special_tokens:
51
+ tokenizer.pad_token = special_tokens.get('pad_token', '').decode("utf-8")
52
  tokenizer.pad_token_id = int(special_tokens.get('pad_token_id', -1))
53
+ tokenizer.eos_token = special_tokens.get('eos_token', '').decode("utf-8")
54
  tokenizer.eos_token_id = int(special_tokens.get('eos_token_id', -1))
55
+ tokenizer.unk_token = special_tokens.get('unk_token', '').decode("utf-8")
56
  tokenizer.unk_token_id = int(special_tokens.get('unk_token_id', -1))
57
+ tokenizer.bos_token = special_tokens.get('bos_token', '').decode("utf-8")
58
  tokenizer.bos_token_id = int(special_tokens.get('bos_token_id', -1))
59
 
60
  def train_and_store_transformers_model(model_name, data):
 
299
  responses = generate_random_response([prompt], generator)
300
  return responses[0]
301
 
302
+ def generate_image_from_prompt(prompt, image_type, model_name="CompVis/stable-diffusion-v1-4"):
303
  if image_type == "diffusers":
304
  image = generate_diffusers_image_from_redis(model_name, prompt)
305
  elif image_type == "stable-diffusion":
 
321
  with gr.Row():
322
  prompt_image = gr.Textbox(label="Prompt de Imagen")
323
  image_type = gr.Dropdown(["diffusers", "stable-diffusion", "img2img"], label="Tipo de Imagen")
324
+ model_name_image = gr.Textbox(label="Nombre del Modelo", value="CompVis/stable-diffusion-v1-4")
325
  image_output = gr.Image(label="Imagen Generada")
326
  image_button = gr.Button("Generar Imagen")
327
+ image_button.click(generate_image_from_prompt, inputs=[prompt_image, image_type, model_name_image], outputs=image_output)
328
 
329
  gr.Markdown("## Generaci贸n de Video")
330
  with gr.Row():
 
335
 
336
  gr.Markdown("## Generaci贸n de Audio con MusicGen")
337
  with gr.Row():
338
+ model_name_audio = gr.Textbox(label="Nombre del Modelo", value="facebook/musicgen-small")
339
  text_prompts_audio = gr.Textbox(label="Prompts de Audio")
340
  audio_output = gr.Audio(label="Audio Generado")
341
  audio_button = gr.Button("Generar Audio")
342
+ audio_button.click(generate_musicgen_audio_from_redis, inputs=[model_name_audio, text_prompts_audio], outputs=audio_output)
343
 
344
  gr.Markdown("## Transcripci贸n de Audio con Whisper")
345
  with gr.Row():
 
350
 
351
  gr.Markdown("## Traducci贸n de Texto")
352
  with gr.Row():
353
+ model_name_translate = gr.Textbox(label="Nombre del Modelo", value="Helsinki-NLP/opus-mt-en-es")
354
  text_input = gr.Textbox(label="Texto a Traducir")
355
  translation_output = gr.Textbox(label="Traducci贸n")
356
  src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
357
  tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
358
  translate_button = gr.Button("Traducir Texto")
359
+ translate_button.click(translate_text_from_redis, inputs=[model_name_translate, text_input, src_lang_input, tgt_lang_input], outputs=translation_output)
360
 
361
  gr.Markdown("## Resumen de Texto")
362
  with gr.Row():
363
+ model_name_summarize = gr.Textbox(label="Nombre del Modelo", value="facebook/bart-large-cnn")
364
  text_to_summarize = gr.Textbox(label="Texto para Resumir")
365
  summary_output = gr.Textbox(label="Resumen")
366
  summarize_button = gr.Button("Generar Resumen")
367
+ summarize_button.click(summarize_text_from_redis, inputs=[model_name_summarize, text_to_summarize], outputs=summary_output)
368
 
369
  app.launch()
370