kayfahaarukku commited on
Commit
638c2b7
·
verified ·
1 Parent(s): f174f96

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +3 -1
gradio_app.py CHANGED
@@ -2,10 +2,12 @@ import gradio as gr
2
  from test_lora import DanbooruTagTester
3
  import sys
4
  import io
 
5
 
6
  # Gradio's state management will hold the instance of our tester
7
  # This is better than a global variable as it's session-specific
8
 
 
9
  def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm=True)):
10
  """
11
  Loads the model and updates the UI.
@@ -41,7 +43,7 @@ def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm
41
  # Return the loaded model instance, the status message, and UI updates
42
  return tester, final_status, gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success)
43
 
44
-
45
  def generate_tags(tester, prompt, max_new_tokens, temperature, top_k, top_p, do_sample):
46
  """
47
  Generates tags using the loaded model.
 
2
  from test_lora import DanbooruTagTester
3
  import sys
4
  import io
5
+ import spaces
6
 
7
  # Gradio's state management will hold the instance of our tester
8
  # This is better than a global variable as it's session-specific
9
 
10
+ @spaces.GPU(duration=300) # Request GPU for model loading, with a 5-min timeout
11
  def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm=True)):
12
  """
13
  Loads the model and updates the UI.
 
43
  # Return the loaded model instance, the status message, and UI updates
44
  return tester, final_status, gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success)
45
 
46
+ @spaces.GPU # Request GPU for generation
47
  def generate_tags(tester, prompt, max_new_tokens, temperature, top_k, top_p, do_sample):
48
  """
49
  Generates tags using the loaded model.