MohamedRashad commited on
Commit
9459cdf
·
1 Parent(s): 5ee125e

Move FluxPipeline to GPU and back to CPU during image generation; remove unused AutoencoderKL reference

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -21,7 +21,6 @@ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_imag
21
  llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
22
 
23
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
24
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16)
25
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
26
 
27
  def generate_t2i_prompt(item_name):
@@ -72,6 +71,7 @@ def preprocess_pil_image(image: Image.Image) -> Tuple[str, Image.Image]:
72
  @spaces.GPU
73
  def generate_item_image(object_t2i_prompt):
74
  trial_id = ""
 
75
  for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
76
  prompt=object_t2i_prompt,
77
  guidance_scale=3.5,
@@ -80,7 +80,6 @@ def generate_item_image(object_t2i_prompt):
80
  height=1024,
81
  generator=torch.Generator("cpu").manual_seed(0),
82
  output_type="pil",
83
- good_vae=good_vae,
84
  ):
85
  yield trial_id, image
86
  # img_path = t2i_client.predict(
@@ -94,6 +93,7 @@ def generate_item_image(object_t2i_prompt):
94
  # api_name="/infer"
95
  # )[0]
96
  # image = Image.open(img_path)
 
97
  trial_id, processed_image = preprocess_pil_image(image)
98
  yield trial_id, processed_image
99
 
 
21
  llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
22
 
23
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
 
24
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
25
 
26
  def generate_t2i_prompt(item_name):
 
71
  @spaces.GPU
72
  def generate_item_image(object_t2i_prompt):
73
  trial_id = ""
74
+ pipe.to("cuda")
75
  for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
76
  prompt=object_t2i_prompt,
77
  guidance_scale=3.5,
 
80
  height=1024,
81
  generator=torch.Generator("cpu").manual_seed(0),
82
  output_type="pil",
 
83
  ):
84
  yield trial_id, image
85
  # img_path = t2i_client.predict(
 
93
  # api_name="/infer"
94
  # )[0]
95
  # image = Image.open(img_path)
96
+ pipe.to("cpu")
97
  trial_id, processed_image = preprocess_pil_image(image)
98
  yield trial_id, processed_image
99