hysts (HF Staff) committed
Commit 0dd76b1 · 1 Parent(s): 800576d
Files changed (2)
  1. app.py +65 -53
  2. live_preview_helpers.py +0 -166
app.py CHANGED
@@ -1,50 +1,70 @@
+import random
+
 import gradio as gr
 import numpy as np
-import random
+import PIL.Image
 import spaces
 import torch
-from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
-from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
-from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+from diffusers import AutoencoderTiny, DiffusionPipeline
 
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
-good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
-torch.cuda.empty_cache()
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 @spaces.GPU(duration=75)
-def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+def infer(
+    prompt: str,
+    seed: int = 42,
+    randomize_seed: bool = False,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3.5,
+    num_inference_steps: int = 28,
+    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
+) -> tuple[PIL.Image.Image, int]:
+    """Generate an image from a prompt using the Flux.1 [dev] model.
+
+    Args:
+        prompt: The prompt to generate an image from.
+        seed: The seed to use for the image generation. Defaults to 42.
+        randomize_seed: Whether to randomize the seed. Defaults to False.
+        width: The width of the image. Defaults to 1024.
+        height: The height of the image. Defaults to 1024.
+        guidance_scale: The guidance scale to use for the image generation. Defaults to 3.5.
+        num_inference_steps: The number of inference steps to use for the image generation. Defaults to 28.
+        progress: The progress bar to use for the image generation. Defaults to a progress bar that tracks the tqdm progress.
+
+    Returns:
+        A tuple containing the generated image and the seed.
+    """
     if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
+        seed = random.randint(0, MAX_SEED)  # noqa: S311
     generator = torch.Generator().manual_seed(seed)
-
-    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-        prompt=prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-        output_type="pil",
-        good_vae=good_vae,
-    ):
-        yield img, seed
-
+
+    image = pipe(
+        prompt=prompt,
+        width=width,
+        height=height,
+        num_inference_steps=num_inference_steps,
+        generator=generator,
+        guidance_scale=guidance_scale,
+    ).images[0]
+    return image, seed
+
+
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
     "a cat holding a sign that says hello world",
     "an anime illustration of a wiener schnitzel",
 ]
 
-css="""
+css = """
 #col-container {
     margin: 0 auto;
     max-width: 520px;
@@ -52,29 +72,23 @@ css="""
 """
 
 with gr.Blocks(css=css) as demo:
-
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 [dev]
-12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
+        gr.Markdown("""# FLUX.1 [dev]
+12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
 [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
 """)
-
+
         with gr.Row():
-
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
-                container=False,
+                submit_btn=True,
            )
-
-            run_button = gr.Button("Run", scale=0)
-
        result = gr.Image(label="Result", show_label=False)
-
+
        with gr.Accordion("Advanced Settings", open=False):
-
            seed = gr.Slider(
                label="Seed",
                minimum=0,
@@ -82,11 +96,9 @@ with gr.Blocks(css=css) as demo:
                step=1,
                value=0,
            )
-
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
+
            with gr.Row():
-
                width = gr.Slider(
                    label="Width",
                    minimum=256,
@@ -94,7 +106,7 @@ with gr.Blocks(css=css) as demo:
                    step=32,
                    value=1024,
                )
-
+
                height = gr.Slider(
                    label="Height",
                    minimum=256,
@@ -102,9 +114,8 @@ with gr.Blocks(css=css) as demo:
                    step=32,
                    value=1024,
                )
-
-            with gr.Row():
 
+            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
@@ -112,7 +123,7 @@ with gr.Blocks(css=css) as demo:
                    step=0.1,
                    value=3.5,
                )
-
+
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
@@ -120,20 +131,21 @@ with gr.Blocks(css=css) as demo:
                    step=1,
                    value=28,
                )
-
+
        gr.Examples(
-            examples = examples,
-            fn = infer,
-            inputs = [prompt],
-            outputs = [result, seed],
-            cache_examples="lazy"
+            examples=examples,
+            fn=infer,
+            inputs=prompt,
+            outputs=[result, seed],
+            cache_examples=True,
+            cache_mode="lazy",
        )
 
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn = infer,
-        inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result, seed]
+    prompt.submit(
+        fn=infer,
+        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        outputs=[result, seed],
    )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch(mcp_server=True)
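
Note: the line removed above, `pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)`, used Python's descriptor protocol to graft the helper from the (now deleted) live_preview_helpers.py onto the pipeline instance. A minimal, self-contained sketch of that binding pattern, with illustrative names that are not from this repository:

class Pipeline:
    def __init__(self) -> None:
        self.num_steps = 3

def yield_previews(self):
    # Once bound, `self` is the Pipeline instance, just like a normal method.
    for step in range(self.num_steps):
        yield f"preview at step {step}"

pipe = Pipeline()
# Same trick as the removed line: bind a free function to an existing instance.
pipe.yield_previews = yield_previews.__get__(pipe)

for preview in pipe.yield_previews():
    print(preview)

Binding via `__get__` is what let the helper receive the pipeline as `self` and stream intermediate images without subclassing the pipeline class.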
live_preview_helpers.py DELETED
@@ -1,166 +0,0 @@
-import torch
-import numpy as np
-from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
-from typing import Any, Dict, List, Optional, Union
-
-# Helper functions
-def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.16,
-):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu
-
-def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
-    **kwargs,
-):
-    if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
-    if timesteps is not None:
-        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    elif sigmas is not None:
-        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    else:
-        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps
-
-# FLUX pipeline function
-@torch.inference_mode()
-def flux_pipe_call_that_returns_an_iterable_of_images(
-    self,
-    prompt: Union[str, List[str]] = None,
-    prompt_2: Optional[Union[str, List[str]]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_inference_steps: int = 28,
-    timesteps: List[int] = None,
-    guidance_scale: float = 3.5,
-    num_images_per_prompt: Optional[int] = 1,
-    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-    latents: Optional[torch.FloatTensor] = None,
-    prompt_embeds: Optional[torch.FloatTensor] = None,
-    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-    output_type: Optional[str] = "pil",
-    return_dict: bool = True,
-    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    max_sequence_length: int = 512,
-    good_vae: Optional[Any] = None,
-):
-    height = height or self.default_sample_size * self.vae_scale_factor
-    width = width or self.default_sample_size * self.vae_scale_factor
-
-    # 1. Check inputs
-    self.check_inputs(
-        prompt,
-        prompt_2,
-        height,
-        width,
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=pooled_prompt_embeds,
-        max_sequence_length=max_sequence_length,
-    )
-
-    self._guidance_scale = guidance_scale
-    self._joint_attention_kwargs = joint_attention_kwargs
-    self._interrupt = False
-
-    # 2. Define call parameters
-    batch_size = 1 if isinstance(prompt, str) else len(prompt)
-    device = self._execution_device
-
-    # 3. Encode prompt
-    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
-    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
-        prompt=prompt,
-        prompt_2=prompt_2,
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=pooled_prompt_embeds,
-        device=device,
-        num_images_per_prompt=num_images_per_prompt,
-        max_sequence_length=max_sequence_length,
-        lora_scale=lora_scale,
-    )
-    # 4. Prepare latent variables
-    num_channels_latents = self.transformer.config.in_channels // 4
-    latents, latent_image_ids = self.prepare_latents(
-        batch_size * num_images_per_prompt,
-        num_channels_latents,
-        height,
-        width,
-        prompt_embeds.dtype,
-        device,
-        generator,
-        latents,
-    )
-    # 5. Prepare timesteps
-    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-    image_seq_len = latents.shape[1]
-    mu = calculate_shift(
-        image_seq_len,
-        self.scheduler.config.base_image_seq_len,
-        self.scheduler.config.max_image_seq_len,
-        self.scheduler.config.base_shift,
-        self.scheduler.config.max_shift,
-    )
-    timesteps, num_inference_steps = retrieve_timesteps(
-        self.scheduler,
-        num_inference_steps,
-        device,
-        timesteps,
-        sigmas,
-        mu=mu,
-    )
-    self._num_timesteps = len(timesteps)
-
-    # Handle guidance
-    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
-    # 6. Denoising loop
-    for i, t in enumerate(timesteps):
-        if self.interrupt:
-            continue
-
-        timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-        noise_pred = self.transformer(
-            hidden_states=latents,
-            timestep=timestep / 1000,
-            guidance=guidance,
-            pooled_projections=pooled_prompt_embeds,
-            encoder_hidden_states=prompt_embeds,
-            txt_ids=text_ids,
-            img_ids=latent_image_ids,
-            joint_attention_kwargs=self.joint_attention_kwargs,
-            return_dict=False,
-        )[0]
-        # Yield intermediate result
-        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        image = self.vae.decode(latents_for_image, return_dict=False)[0]
-        yield self.image_processor.postprocess(image, output_type=output_type)[0]
-
-        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-        torch.cuda.empty_cache()
-
-    # Final image using good_vae
-    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
-    image = good_vae.decode(latents, return_dict=False)[0]
-    self.maybe_free_model_hooks()
-    torch.cuda.empty_cache()
-    yield self.image_processor.postprocess(image, output_type=output_type)[0]
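
For reference, the deleted `calculate_shift` is a straight linear interpolation between `base_shift` and `max_shift` in the latent token count. A quick check with its default constants (my arithmetic, not code from the repository):

# mu(image_seq_len) = m * image_seq_len + b, using calculate_shift's defaults.
m = (1.16 - 0.5) / (4096 - 256)  # (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = 0.5 - m * 256                # base_shift - m * base_seq_len
print(round(m * 1024 + b, 3))    # 0.632 for a 1024-token latent
print(round(m * 4096 + b, 3))    # 1.16 at the 4096-token cap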