Carlexx commited on
Commit
9f8531b
·
verified ·
1 Parent(s): e0f9c7c

Update dreamo_helpers.py

Browse files
Files changed (1) hide show
  1. dreamo_helpers.py +37 -29
dreamo_helpers.py CHANGED
@@ -62,38 +62,46 @@ class Generator:
62
  if torch.cuda.is_available(): torch.cuda.empty_cache()
63
 
64
  @torch.inference_mode()
65
- # <<<<< MODIFICAÇÃO PRINCIPAL: Aceita uma lista de dicionários de referência >>>>>
66
  def generate_image_with_gpu_management(self, reference_items, prompt, width, height):
67
- ref_conds = []
68
-
69
- for idx, item in enumerate(reference_items):
70
- ref_image_np = item.get('image_np')
71
- ref_task = item.get('task')
72
-
73
- if ref_image_np is not None:
74
- if ref_task == "id":
75
- ref_image = self.get_align_face(ref_image_np)
76
- elif ref_task != "style":
77
- ref_image = self.bg_rm_model.inference(Image.fromarray(ref_image_np))
78
- else: # Style usa a imagem original
79
- ref_image = ref_image_np
80
 
81
- ref_image_tensor = img2tensor(np.array(ref_image), bgr2rgb=False).unsqueeze(0) / 255.0
82
- ref_image_tensor = (2 * ref_image_tensor - 1.0).to(self.gpu_device, dtype=torch.bfloat16)
 
 
 
83
 
84
- # O modelo DreamO espera o índice começando em 1
85
- ref_conds.append({'img': ref_image_tensor, 'task': ref_task, 'idx': idx + 1})
86
-
87
- image = self.dreamo_pipeline(
88
- prompt=prompt,
89
- width=width,
90
- height=height,
91
- num_inference_steps=12,
92
- guidance_scale=4.5,
93
- ref_conds=ref_conds,
94
- generator=torch.Generator(device="cpu").manual_seed(42)
95
- ).images[0]
96
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  @torch.no_grad()
99
  def get_align_face(self, img):
 
62
  if torch.cuda.is_available(): torch.cuda.empty_cache()
63
 
64
  @torch.inference_mode()
65
+ # <<<<< CORREÇÃO IMPLEMENTADA: Gerenciamento de GPU atômico por chamada >>>>>
66
  def generate_image_with_gpu_management(self, reference_items, prompt, width, height):
67
+ try:
68
+ self.to_gpu() # Move os modelos para a GPU no início de CADA chamada
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ ref_conds = []
71
+
72
+ for idx, item in enumerate(reference_items):
73
+ ref_image_np = item.get('image_np')
74
+ ref_task = item.get('task')
75
 
76
+ if ref_image_np is not None:
77
+ if ref_task == "id":
78
+ ref_image = self.get_align_face(ref_image_np)
79
+ elif ref_task != "style":
80
+ ref_image = self.bg_rm_model.inference(Image.fromarray(ref_image_np))
81
+ else: # Style usa a imagem original
82
+ ref_image = ref_image_np
83
+
84
+ ref_image_tensor = img2tensor(np.array(ref_image), bgr2rgb=False).unsqueeze(0) / 255.0
85
+ ref_image_tensor = (2 * ref_image_tensor - 1.0).to(self.gpu_device, dtype=torch.bfloat16)
86
+
87
+ # O modelo DreamO espera o índice começando em 1
88
+ ref_conds.append({'img': ref_image_tensor, 'task': ref_task, 'idx': idx + 1})
89
+
90
+ image = self.dreamo_pipeline(
91
+ prompt=prompt,
92
+ width=width,
93
+ height=height,
94
+ num_inference_steps=12,
95
+ guidance_scale=4.5,
96
+ ref_conds=ref_conds,
97
+ generator=torch.Generator(device="cpu").manual_seed(42)
98
+ ).images[0]
99
+
100
+ return image
101
+
102
+ finally:
103
+ self.to_cpu() # Garante que os modelos voltem para a CPU, mesmo se ocorrer um erro
104
+
105
 
106
  @torch.no_grad()
107
  def get_align_face(self, img):