Update gradio_app.py
gradio_app.py CHANGED  (+3, -1)
@@ -2,10 +2,12 @@ import gradio as gr
 from test_lora import DanbooruTagTester
 import sys
 import io
+import spaces
 
 # Gradio's state management will hold the instance of our tester
 # This is better than a global variable as it's session-specific
 
+@spaces.GPU(duration=300) # Request GPU for model loading, with a 5-min timeout
 def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm=True)):
     """
     Loads the model and updates the UI.
@@ -41,7 +43,7 @@ def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm
     # Return the loaded model instance, the status message, and UI updates
     return tester, final_status, gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success)
 
-
+@spaces.GPU # Request GPU for generation
 def generate_tags(tester, prompt, max_new_tokens, temperature, top_k, top_p, do_sample):
     """
     Generates tags using the loaded model.
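For context, these decorators follow the ZeroGPU pattern used on Hugging Face Spaces: import the spaces package and decorate every function that needs CUDA with @spaces.GPU, optionally passing duration (in seconds) when a call, such as model loading, may run longer than the default window. Below is a minimal, self-contained sketch of that pattern in a stand-alone Gradio app; the model name and the generate helper are illustrative placeholders, not code from this Space.

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative placeholder model, not the LoRA checkpoint used in this Space.
MODEL_ID = "gpt2"

# Load weights on CPU at startup; ZeroGPU only attaches a GPU while a
# @spaces.GPU-decorated function is executing.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)


@spaces.GPU(duration=120)  # request a GPU for up to 120 s per call
def generate(prompt: str, max_new_tokens: int = 64) -> str:
    # CUDA is only guaranteed to be available inside the decorated function,
    # so the move to the device happens here rather than at import time.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Output")
    gr.Button("Generate").click(generate, inputs=prompt_box, outputs=output_box)

demo.launch()

The decorator is designed to be a no-op outside of a ZeroGPU Space, so the same script should still run locally on CPU or on a regular GPU.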