yonishafir commited on
Commit
a0be13a
·
verified ·
1 Parent(s): 8e98ed6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -50
app.py CHANGED
@@ -22,7 +22,7 @@ import pandas as pd
22
  import json
23
  import requests
24
  from io import BytesIO
25
- from huggingface_hub import hf_hub_download
26
 
27
 
28
  def resize_img(input_image, max_side=1280, min_side=1024, size=None,
@@ -113,14 +113,14 @@ def process_image_by_bbox_larger(input_image, bbox_xyxy, min_bbox_ratio=0.2):
113
 
114
  return resized_image
115
 
116
- def calc_emb_cropped(image, app):
117
  face_image = image.copy()
118
 
119
  face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
120
 
121
  face_info = face_info[0]
122
 
123
- cropped_face_image = process_image_by_bbox_larger(face_image, face_info["bbox"], min_bbox_ratio=0.2)
124
 
125
  return cropped_face_image
126
 
@@ -144,7 +144,29 @@ default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly
144
  CURRENT_LORA_NAME = None
145
 
146
  # Load face detection and recognition package
147
- app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  app.prepare(ctx_id=0, det_size=(640, 640))
149
 
150
 
@@ -155,19 +177,14 @@ hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="controlnet/diffu
155
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="ip-adapter.bin", local_dir="./checkpoints")
156
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints")
157
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/config.json", local_dir="./checkpoints")
158
- # Download Lora weights
159
- hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/3D_illustration/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
160
- hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Avatar_internlm/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
161
- hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Characters/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
162
- hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Storyboards/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
163
- hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Vangogh_Vanilla/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
164
 
165
  device = "cuda" if torch.cuda.is_available() else "cpu"
166
 
167
  # ckpts paths
168
  face_adapter = f"./checkpoints/ip-adapter.bin"
169
  controlnet_path = f"./checkpoints/controlnet"
170
- lora_base_path = "./checkpoints/LoRAs"
171
  base_model_path = f'briaai/BRIA-2.3'
172
  resolution = 1024
173
 
@@ -199,16 +216,16 @@ pipe.load_ip_adapter_instantid(face_adapter)
199
 
200
  clip_embeds=None
201
 
202
- Loras_dict = {
203
- "":"",
204
- "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
205
- "Avatar_internlm": "2d anime sketch avatar of",
206
- "Storyboards": "Illustration style for storyboarding.",
207
- "3D_illustration": "3D object illustration, abstract.",
208
- "Characters": "gaming vector Art."
209
- }
210
 
211
- lora_names = Loras_dict.keys()
212
 
213
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
214
  if randomize_seed:
@@ -217,8 +234,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
217
 
218
 
219
  @spaces.GPU
220
- def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale, progress=gr.Progress(track_tqdm=True)):
221
- global CURRENT_LORA_NAME # Use the global variable to track LoRA
 
222
 
223
  if image_path is None:
224
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
@@ -272,31 +290,31 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
272
 
273
  generator = torch.Generator(device=device).manual_seed(seed)
274
 
275
- if lora_name != CURRENT_LORA_NAME: # Check if LoRA needs to be changed
276
- if CURRENT_LORA_NAME is not None: # If a LoRA is already loaded, unload it
277
- pipe.disable_lora()
278
- pipe.unfuse_lora()
279
- pipe.unload_lora_weights()
280
- print(f"Unloaded LoRA: {CURRENT_LORA_NAME}")
281
 
282
- if lora_name != "": # Load the new LoRA if specified
283
- # pipe.enable_model_cpu_offload()
284
- lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
285
- pipe.load_lora_weights(lora_path)
286
- pipe.fuse_lora(lora_scale)
287
- pipe.enable_lora()
288
 
289
- # lora_prefix = Loras_dict[lora_name]
290
 
291
- print(f"Loaded new LoRA: {lora_name}")
292
 
293
- # Update the current LoRA name
294
- CURRENT_LORA_NAME = lora_name
295
 
296
- if lora_name != "":
297
- full_prompt = f"{Loras_dict[lora_name]} + " " + {prompt}"
298
- else:
299
- full_prompt = prompt
300
 
301
  print("Start inference...")
302
  images = pipe(
@@ -360,7 +378,7 @@ with gr.Blocks(css=css) as demo:
360
  info="Describe what you want to generate or modify in the image."
361
  )
362
 
363
- lora_name = gr.Dropdown(choices=lora_names, label="LoRA", value="", info="Select a LoRA name from the list, not selecting any will disable LoRA.")
364
 
365
  submit = gr.Button("Submit", variant="primary")
366
 
@@ -407,13 +425,13 @@ with gr.Blocks(css=css) as demo:
407
  step=0.01,
408
  value=0.4,
409
  )
410
- lora_scale = gr.Slider(
411
- label="lora_scale",
412
- minimum=0.0,
413
- maximum=1.0,
414
- step=0.01,
415
- value=0.7,
416
- )
417
  seed = gr.Slider(
418
  label="Seed",
419
  minimum=0,
 
22
  import json
23
  import requests
24
  from io import BytesIO
25
+ from huggingface_hub import hf_hub_download, HfApi
26
 
27
 
28
  def resize_img(input_image, max_side=1280, min_side=1024, size=None,
 
113
 
114
  return resized_image
115
 
116
+ def calc_emb_cropped(image, app, min_bbox_ratio=0.2):
117
  face_image = image.copy()
118
 
119
  face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
120
 
121
  face_info = face_info[0]
122
 
123
+ cropped_face_image = process_image_by_bbox_larger(face_image, face_info["bbox"], min_bbox_ratio=min_bbox_ratio)
124
 
125
  return cropped_face_image
126
 
 
144
  CURRENT_LORA_NAME = None
145
 
146
  # Load face detection and recognition package
147
+ # app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
148
+
149
+
150
+ # Define your repository, target directory, and local directory
151
+ repo_id = "briaai/ID_preservation_2.3"
152
+ target_dir = "models_aura/" # The directory you want to download
153
+ local_dir = "./checkpoints" # Local directory to save files
154
+
155
+ # Initialize the API
156
+ api = HfApi()
157
+ # List all files in the repository
158
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model") # Use repo_type="space" for Spaces
159
+
160
+ # Filter files that are in the target directory
161
+ files_in_dir = [file for file in files if file.startswith(target_dir)]
162
+ # Download each file in the target directory
163
+ for file in files_in_dir:
164
+ local_path = os.path.join(local_dir, file)
165
+ os.makedirs(os.path.dirname(local_path), exist_ok=True) # Ensure local directories exist
166
+ print(f"Downloading: {file}")
167
+ hf_hub_download(repo_id=repo_id, filename=file, local_dir=os.path.dirname(local_path))
168
+
169
+ app = FaceAnalysis(name='auraface', root='./checkpoints/models_aura/', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
170
  app.prepare(ctx_id=0, det_size=(640, 640))
171
 
172
 
 
177
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="ip-adapter.bin", local_dir="./checkpoints")
178
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints")
179
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/config.json", local_dir="./checkpoints")
180
+
 
 
 
 
 
181
 
182
  device = "cuda" if torch.cuda.is_available() else "cpu"
183
 
184
  # ckpts paths
185
  face_adapter = f"./checkpoints/ip-adapter.bin"
186
  controlnet_path = f"./checkpoints/controlnet"
187
+ # lora_base_path = "./checkpoints/LoRAs"
188
  base_model_path = f'briaai/BRIA-2.3'
189
  resolution = 1024
190
 
 
216
 
217
  clip_embeds=None
218
 
219
+ # Loras_dict = {
220
+ # "":"",
221
+ # "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
222
+ # "Avatar_internlm": "2d anime sketch avatar of",
223
+ # "Storyboards": "Illustration style for storyboarding.",
224
+ # "3D_illustration": "3D object illustration, abstract.",
225
+ # "Characters": "gaming vector Art."
226
+ # }
227
 
228
+ # lora_names = Loras_dict.keys()
229
 
230
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
231
  if randomize_seed:
 
234
 
235
 
236
  @spaces.GPU
237
+ # def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale, progress=gr.Progress(track_tqdm=True)):
238
+ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, progress=gr.Progress(track_tqdm=True)):
239
+ # global CURRENT_LORA_NAME # Use the global variable to track LoRA
240
 
241
  if image_path is None:
242
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
 
290
 
291
  generator = torch.Generator(device=device).manual_seed(seed)
292
 
293
+ # if lora_name != CURRENT_LORA_NAME: # Check if LoRA needs to be changed
294
+ # if CURRENT_LORA_NAME is not None: # If a LoRA is already loaded, unload it
295
+ # pipe.disable_lora()
296
+ # pipe.unfuse_lora()
297
+ # pipe.unload_lora_weights()
298
+ # print(f"Unloaded LoRA: {CURRENT_LORA_NAME}")
299
 
300
+ # if lora_name != "": # Load the new LoRA if specified
301
+ # # pipe.enable_model_cpu_offload()
302
+ # lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
303
+ # pipe.load_lora_weights(lora_path)
304
+ # pipe.fuse_lora(lora_scale)
305
+ # pipe.enable_lora()
306
 
307
+ # # lora_prefix = Loras_dict[lora_name]
308
 
309
+ # print(f"Loaded new LoRA: {lora_name}")
310
 
311
+ # # Update the current LoRA name
312
+ # CURRENT_LORA_NAME = lora_name
313
 
314
+ # if lora_name != "":
315
+ # full_prompt = f"{Loras_dict[lora_name]} + " " + {prompt}"
316
+ # else:
317
+ full_prompt = prompt
318
 
319
  print("Start inference...")
320
  images = pipe(
 
378
  info="Describe what you want to generate or modify in the image."
379
  )
380
 
381
+ # lora_name = gr.Dropdown(choices=lora_names, label="LoRA", value="", info="Select a LoRA name from the list, not selecting any will disable LoRA.")
382
 
383
  submit = gr.Button("Submit", variant="primary")
384
 
 
425
  step=0.01,
426
  value=0.4,
427
  )
428
+ # lora_scale = gr.Slider(
429
+ # label="lora_scale",
430
+ # minimum=0.0,
431
+ # maximum=1.0,
432
+ # step=0.01,
433
+ # value=0.7,
434
+ # )
435
  seed = gr.Slider(
436
  label="Seed",
437
  minimum=0,