kayfahaarukku committed

Commit 0b52264 · verified · Parent: 638c2b7

Upload 2 files

Files changed (2)
  1. gradio_app.py +120 -120
  2. test_lora.py +7 -4
gradio_app.py CHANGED
@@ -1,121 +1,121 @@
-import gradio as gr
-from test_lora import DanbooruTagTester
-import sys
-import io
-import spaces
-
-# Gradio's state management will hold the instance of our tester
-# This is better than a global variable as it's session-specific
-
-@spaces.GPU(duration=300)  # Request GPU for model loading, with a 5-min timeout
-def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm=True)):
-    """
-    Loads the model and updates the UI.
-    Captures stdout to display loading progress in the UI.
-    """
-    # Redirect stdout to capture print statements from the model loading process
-    old_stdout = sys.stdout
-    sys.stdout = captured_output = io.StringIO()
-
-    tester = None
-    status_message = ""
-    success = False
-
-    try:
-        tester = DanbooruTagTester(
-            model_path=model_path,
-            base_model_id=base_model,
-            use_4bit=use_4bit,
-            non_interactive=True  # Ensure no input() calls hang the app
-        )
-        status_message = "Model loaded successfully!"
-        success = True
-    except Exception as e:
-        status_message = f"Error loading model: {e}"
-    finally:
-        # Restore stdout
-        sys.stdout = old_stdout
-
-    # Get captured output and combine with status message
-    log_output = captured_output.getvalue()
-    final_status = log_output + "\n" + status_message
-
-    # Return the loaded model instance, the status message, and UI updates
-    return tester, final_status, gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success)
-
-@spaces.GPU  # Request GPU for generation
-def generate_tags(tester, prompt, max_new_tokens, temperature, top_k, top_p, do_sample):
-    """
-    Generates tags using the loaded model.
-    """
-    if tester is None:
-        return "Error: Model not loaded. Please load a model first."
-
-    try:
-        completion = tester.generate_tags(
-            input_prompt=prompt,
-            max_new_tokens=int(max_new_tokens),
-            temperature=temperature,
-            top_k=int(top_k),
-            top_p=top_p,
-            do_sample=do_sample
-        )
-        return completion
-    except Exception as e:
-        return f"Error during generation: {e}"
-
-# --- Gradio Interface Definition ---
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    tester_state = gr.State(None)
-
-    gr.Markdown("# Danbooru Tag Autocompletion UI")
-    gr.Markdown("Load a LoRA model and generate Danbooru tag completions.")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            gr.Markdown("## 1. Load Model")
-            # Using user's github username "nawka12" as default model path from memory
-            model_path_input = gr.Textbox(label="Model Path (HF Hub or local)", value="kayfahaarukku/chek-8")
-            base_model_input = gr.Textbox(label="Base Model ID", value="google/gemma-3-1b-it")
-            use_4bit_checkbox = gr.Checkbox(label="Use 4-bit Quantization", value=True)
-            load_button = gr.Button("Load Model", variant="primary")
-
-        with gr.Column(scale=2):
-            gr.Markdown("## 2. Generate Tags")
-            # Generation UI is disabled until model is loaded
-            prompt_input = gr.Textbox(label="Input Prompt", lines=2, placeholder="e.g., 1girl, hatsune miku, vocaloid", interactive=False)
-            generate_button = gr.Button("Generate", variant="primary", interactive=False)
-
-            with gr.Accordion("Generation Settings", open=False):
-                max_new_tokens_slider = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens", interactive=False)
-                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature", interactive=False)
-                top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K", interactive=False)
-                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P", interactive=False)
-                do_sample_checkbox = gr.Checkbox(label="Use Sampling", value=True, interactive=False)
-
-    with gr.Row():
-        with gr.Column():
-            gr.Markdown("### Status & Logs")
-            status_output = gr.Textbox(label="Loading Log", lines=8, interactive=False, max_lines=20)
-
-        with gr.Column():
-            gr.Markdown("### Generated Tags")
-            completion_output = gr.Textbox(label="Output", lines=8, interactive=False, max_lines=20)
-
-    # --- Event Handlers ---
-    generation_inputs = [prompt_input, generate_button, max_new_tokens_slider, temperature_slider, top_k_slider, top_p_slider, do_sample_checkbox]
-
-    load_button.click(
-        fn=load_model,
-        inputs=[model_path_input, base_model_input, use_4bit_checkbox],
-        outputs=[tester_state, status_output] + generation_inputs
-    )
-
-    generate_button.click(
-        fn=generate_tags,
-        inputs=[tester_state, prompt_input, max_new_tokens_slider, temperature_slider, top_k_slider, top_p_slider, do_sample_checkbox],
-        outputs=completion_output
-    )
-
-if __name__ == "__main__":
+import gradio as gr
+from test_lora import DanbooruTagTester
+import sys
+import io
+import spaces
+
+# Gradio's state management will hold the instance of our tester
+# This is better than a global variable as it's session-specific
+
+@spaces.GPU(duration=300)  # Request GPU for model loading, with a 5-min timeout
+def load_model(model_path, base_model, use_4bit, progress=gr.Progress(track_tqdm=True)):
+    """
+    Loads the model and updates the UI.
+    Captures stdout to display loading progress in the UI.
+    """
+    # Redirect stdout to capture print statements from the model loading process
+    old_stdout = sys.stdout
+    sys.stdout = captured_output = io.StringIO()
+
+    tester = None
+    status_message = ""
+    success = False
+
+    try:
+        tester = DanbooruTagTester(
+            model_path=model_path,
+            base_model_id=base_model,
+            use_4bit=use_4bit,
+            non_interactive=True  # Ensure no input() calls hang the app
+        )
+        status_message = "Model loaded successfully!"
+        success = True
+    except Exception as e:
+        status_message = f"Error loading model: {e}"
+    finally:
+        # Restore stdout
+        sys.stdout = old_stdout
+
+    # Get captured output and combine with status message
+    log_output = captured_output.getvalue()
+    final_status = log_output + "\n" + status_message
+
+    # Return the loaded model instance, the status message, and UI updates
+    return tester, final_status, gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success), gr.update(interactive=success)
+
+@spaces.GPU  # Request GPU for generation
+def generate_tags(tester, prompt, max_new_tokens, temperature, top_k, top_p, do_sample):
+    """
+    Generates tags using the loaded model.
+    """
+    if tester is None:
+        return "Error: Model not loaded. Please load a model first."
+
+    try:
+        completion = tester.generate_tags(
+            input_prompt=prompt,
+            max_new_tokens=int(max_new_tokens),
+            temperature=temperature,
+            top_k=int(top_k),
+            top_p=top_p,
+            do_sample=do_sample
+        )
+        return completion
+    except Exception as e:
+        return f"Error during generation: {e}"
+
+# --- Gradio Interface Definition ---
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    tester_state = gr.State(None)
+
+    gr.Markdown("# Danbooru Tag Autocompletion UI")
+    gr.Markdown("Load a LoRA model and generate Danbooru tag completions.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("## 1. Load Model")
+            # Using user's github username "nawka12" as default model path from memory
+            model_path_input = gr.Textbox(label="Model Path (HF Hub or local)", value="kayfahaarukku/chek-8")
+            base_model_input = gr.Textbox(label="Base Model ID", value="google/gemma-3-1b-it")
+            use_4bit_checkbox = gr.Checkbox(label="Use 4-bit Quantization", value=True)
+            load_button = gr.Button("Load Model", variant="primary")
+
+        with gr.Column(scale=2):
+            gr.Markdown("## 2. Generate Tags")
+            # Generation UI is disabled until model is loaded
+            prompt_input = gr.Textbox(label="Input Prompt", lines=2, placeholder="e.g., 1girl, hatsune miku, vocaloid", interactive=False)
+            generate_button = gr.Button("Generate", variant="primary", interactive=False)
+
+            with gr.Accordion("Generation Settings", open=False):
+                max_new_tokens_slider = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens", interactive=False)
+                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature", interactive=False)
+                top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K", interactive=False)
+                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P", interactive=False)
+                do_sample_checkbox = gr.Checkbox(label="Use Sampling", value=True, interactive=False)
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("### Status & Logs")
+            status_output = gr.Textbox(label="Loading Log", lines=8, interactive=False, max_lines=20)
+
+        with gr.Column():
+            gr.Markdown("### Generated Tags")
+            completion_output = gr.Textbox(label="Output", lines=8, interactive=False, max_lines=20)
+
+    # --- Event Handlers ---
+    generation_inputs = [prompt_input, generate_button, max_new_tokens_slider, temperature_slider, top_k_slider, top_p_slider, do_sample_checkbox]
+
+    load_button.click(
+        fn=load_model,
+        inputs=[model_path_input, base_model_input, use_4bit_checkbox],
+        outputs=[tester_state, status_output] + generation_inputs
+    )
+
+    generate_button.click(
+        fn=generate_tags,
+        inputs=[tester_state, prompt_input, max_new_tokens_slider, temperature_slider, top_k_slider, top_p_slider, do_sample_checkbox],
+        outputs=completion_output
+    )
+
+if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0")
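
The one substantive change in gradio_app.py is the seventh gr.update(interactive=success) appended to load_model's return: load_button.click lists nine outputs (tester_state, status_output, and the seven components in generation_inputs), but the old code returned only eight values. A less fragile variant, sketched here and not part of the commit, derives the update count from the component list itself; it relies on Gradio mapping a returned list element-wise onto the handler's outputs:

    import gradio as gr

    # Hypothetical helper (not in the commit): one update per generation
    # component, so the return length always matches the outputs list.
    def interactivity_updates(components, success):
        return [gr.update(interactive=success) for _ in components]

With that helper, load_model could end with return [tester, final_status] + interactivity_updates(generation_inputs, success) and never drift out of sync when components are added or removed.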
test_lora.py CHANGED
@@ -30,6 +30,8 @@ class DanbooruTagTester:
     def _load_model(self):
         """Load the base model, LoRA weights, and tokenizer"""
 
+        hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+
         # Configure quantization if requested
         if self.use_4bit:
             try:
@@ -53,14 +55,15 @@
             quantization_config=bnb_config,
             device_map="auto",
             torch_dtype=torch.bfloat16 if not self.use_4bit else None,
+            token=hf_token,
         )
 
         # Check if this is actually a LoRA model or just the base model
         try:
             # Try to load LoRA config to check if it's a LoRA model
-            peft_config = PeftConfig.from_pretrained(self.model_path)
+            peft_config = PeftConfig.from_pretrained(self.model_path, token=hf_token)
             print("Loading LoRA weights...")
-            self.model = PeftModel.from_pretrained(self.base_model, self.model_path)
+            self.model = PeftModel.from_pretrained(self.base_model, self.model_path, token=hf_token)
             print("LoRA model loaded successfully!")
         except Exception as e:
             print(f"Warning: Could not load LoRA weights from {self.model_path}")
@@ -91,10 +94,10 @@
         # Load tokenizer
         print("Loading tokenizer...")
        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, token=hf_token)
         except Exception as e:
             print(f"Could not load tokenizer from model path, trying base model...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id, token=hf_token)
 
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
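
The test_lora.py change threads a Hugging Face token through every from_pretrained call, which is what gated repositories such as google/gemma-3-1b-it require. It assumes HUGGING_FACE_HUB_TOKEN is set in the environment (a repository secret on a Hugging Face Space) and that os is imported earlier in test_lora.py, outside the hunks shown. A minimal sketch of the same pattern in isolation; note that token=None is harmless, since from_pretrained then falls back to any credentials cached by huggingface-cli login:

    import os
    from transformers import AutoTokenizer

    # None is acceptable here: from_pretrained falls back to cached
    # credentials, so the env var only matters for gated/private repos.
    hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it", token=hf_token)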