tuan2308 commited on
Commit
6f7ddd9
·
verified ·
1 Parent(s): ddffdaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -119
app.py CHANGED
@@ -1,116 +1,42 @@
1
  import gradio as gr
2
  import numpy as np
 
3
  import spaces
4
  import torch
5
  import random
6
  from PIL import Image
7
-
8
  from diffusers import FluxKontextPipeline
9
- from huggingface_hub import hf_hub_download
10
-
11
- # --------------------
12
- # Globals (CPU-only; DO NOT touch CUDA here)
13
- # --------------------
14
- MAX_SEED = np.iinfo(np.int32).max
15
-
16
- DEVICE = "cpu" # set in gpu_startup()
17
- DTYPE = torch.float32 # set in gpu_startup()
18
 
19
- PIPE = None # cached pipeline (loaded on first GPU run)
20
- LORA_LOADED = False
21
-
22
- # --------------------
23
- # ZeroGPU: allocate GPU at startup & decide dtype
24
- # --------------------
25
- @spaces.GPU
26
- def gpu_startup():
27
- """
28
- Runs when the Space boots on ZeroGPU. Safe place to detect CUDA and dtype.
29
- """
30
- global DEVICE, DTYPE
31
- has_cuda = torch.cuda.is_available()
32
- DEVICE = "cuda" if has_cuda else "cpu"
33
-
34
- # Prefer bfloat16 on CUDA if supported; otherwise fp16; CPU uses fp32
35
- if DEVICE == "cuda":
36
- if torch.cuda.is_bf16_supported():
37
- DTYPE = torch.bfloat16
38
- else:
39
- DTYPE = torch.float16
40
- else:
41
- DTYPE = torch.float32
42
-
43
- print(f"[startup] device={DEVICE}, dtype={DTYPE}")
44
-
45
- # --------------------
46
- # Lazy model loader (runs inside GPU context)
47
- # --------------------
48
- def get_pipeline():
49
- """
50
- Create/cache the FluxKontextPipeline and load LoRA once.
51
- Must be called from within a @spaces.GPU function.
52
- """
53
- global PIPE, LORA_LOADED
54
-
55
- if PIPE is None:
56
- # Load base pipeline with chosen dtype, then move to device
57
- PIPE = FluxKontextPipeline.from_pretrained(
58
- "black-forest-labs/FLUX.1-Kontext-dev",
59
- torch_dtype=DTYPE
60
- ).to(DEVICE)
61
 
62
- if not LORA_LOADED:
63
- # Load LoRA weights once and set adapter
64
- PIPE.load_lora_weights(
65
- "kontext-community/relighting-kontext-dev-lora-v3",
66
- weight_name="relighting-kontext-dev-lora-v3.safetensors",
67
- adapter_name="lora"
68
- )
69
- PIPE.set_adapters(["lora"], adapter_weights=[1.0])
70
- LORA_LOADED = True
71
 
72
- return PIPE
73
 
74
- # --------------------
75
- # Inference (GPU entrypoint)
76
- # --------------------
77
  @spaces.GPU
78
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
79
- pipe = get_pipeline()
80
-
81
  if randomize_seed:
82
  seed = random.randint(0, MAX_SEED)
83
-
84
- if input_image is None:
85
- return None, seed
86
-
87
  input_image = input_image.convert("RGB")
88
-
89
- prompt_with_template = (
90
- f"Change the lighting conditions in this image and add {prompt}. "
91
- "Change the background details but maintain the foreground. "
92
- "Lighting determines how bright or dark different parts of the image appear, "
93
- "where shadows fall, and how colors look. When you relight an image, "
94
- "you're simulating what the photo would look like under different lighting conditions."
95
- )
96
-
97
- generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
98
-
99
  image = pipe(
100
- image=input_image,
101
  prompt=prompt_with_template,
102
- guidance_scale=float(guidance_scale),
103
  width=input_image.size[0],
104
  height=input_image.size[1],
105
- generator=generator,
106
  ).images[0]
107
-
108
  return image, seed
109
 
110
- # --------------------
111
- # UI
112
- # --------------------
113
- css = """
114
  #col-container {
115
  margin: 0 auto;
116
  max-width: 960px;
@@ -118,9 +44,12 @@ css = """
118
  """
119
 
120
  with gr.Blocks(css=css) as demo:
 
121
  with gr.Column(elem_id="col-container"):
122
- gr.Markdown("# FLUX.1 Kontext [dev] Relight 💡")
123
- gr.Markdown("Kontext[dev] used for object relighting ✨")
 
 
124
 
125
  with gr.Row():
126
  with gr.Column():
@@ -134,36 +63,48 @@ with gr.Blocks(css=css) as demo:
134
  container=False,
135
  )
136
  run_button = gr.Button("Run", scale=0)
137
-
138
  with gr.Accordion("Advanced Settings", open=False):
139
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
 
 
 
 
 
 
140
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
141
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
142
-
 
 
 
 
 
 
 
143
  with gr.Column():
144
  result = gr.Image(label="Result", show_label=False, interactive=False)
145
-
146
  gr.Examples(
147
- examples=[
148
- ["./assets/5_before.png", "sunset over sea lighting coming from the top right part of the photo", 0, True, 2.5],
149
- ["./assets/3_before.png", "sci-fi RGB glowing, studio lighting", 0, True, 2.5],
150
- ["./assets/2_before.png", "neon light, city", 0, True, 2.5],
151
- ["./assets/before_6.png", "bright sunlight, warm, luminous", 0, True, 2.5],
152
- ],
153
- inputs=[input_image, prompt, seed, randomize_seed, guidance_scale],
154
- outputs=[result, seed],
155
- fn=infer,
156
- cache_examples="lazy",
157
- )
158
-
159
- gr.on(
160
- triggers=[run_button.click, prompt.submit],
161
- fn=infer,
162
- inputs=[input_image, prompt, seed, randomize_seed, guidance_scale],
163
- outputs=[result, seed],
164
- )
 
165
 
166
- if __name__ == "__main__":
167
- # Ensure ZeroGPU allocates a GPU at boot and dtype/device are decided early
168
- gpu_startup()
169
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import numpy as np
3
+
4
  import spaces
5
  import torch
6
  import random
7
  from PIL import Image
 
8
  from diffusers import FluxKontextPipeline
9
+ from diffusers import FluxTransformer2DModel
10
+ from diffusers.utils import load_image
 
 
 
 
 
 
 
11
 
12
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
15
+ pipe.load_lora_weights("kontext-community/relighting-kontext-dev-lora-v3", weight_name="relighting-kontext-dev-lora-v3.safetensors", adapter_name="lora")
16
+ pipe.set_adapters(["lora"], adapter_weights=[1.0])
 
 
 
 
 
 
17
 
18
+ MAX_SEED = np.iinfo(np.int32).max
19
 
 
 
 
20
  @spaces.GPU
21
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
22
+
 
23
  if randomize_seed:
24
  seed = random.randint(0, MAX_SEED)
25
+
 
 
 
26
  input_image = input_image.convert("RGB")
27
+ prompt_with_template = f"Change the lighting conditions in this image and add {prompt}. change the background details but maintain the forground. Lighting determines how bright or dark different parts of the image appear, where shadows fall, and how colors look. When you relight an image, you're simulating what the photo would look like if it were taken under different lighting conditions."
28
+
 
 
 
 
 
 
 
 
 
29
  image = pipe(
30
+ image=input_image,
31
  prompt=prompt_with_template,
32
+ guidance_scale=guidance_scale,
33
  width=input_image.size[0],
34
  height=input_image.size[1],
35
+ generator=torch.Generator().manual_seed(seed),
36
  ).images[0]
 
37
  return image, seed
38
 
39
+ css="""
 
 
 
40
  #col-container {
41
  margin: 0 auto;
42
  max-width: 960px;
 
44
  """
45
 
46
  with gr.Blocks(css=css) as demo:
47
+
48
  with gr.Column(elem_id="col-container"):
49
+ gr.Markdown(f"""# FLUX.1 Kontext [dev] Relight 💡
50
+ """)
51
+ gr.Markdown(f"""Kontext[dev] used for object relighting ✨
52
+ """)
53
 
54
  with gr.Row():
55
  with gr.Column():
 
63
  container=False,
64
  )
65
  run_button = gr.Button("Run", scale=0)
 
66
  with gr.Accordion("Advanced Settings", open=False):
67
+
68
+ seed = gr.Slider(
69
+ label="Seed",
70
+ minimum=0,
71
+ maximum=MAX_SEED,
72
+ step=1,
73
+ value=0,
74
+ )
75
+
76
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
77
+
78
+ guidance_scale = gr.Slider(
79
+ label="Guidance Scale",
80
+ minimum=1,
81
+ maximum=10,
82
+ step=0.1,
83
+ value=2.5,
84
+ )
85
+
86
  with gr.Column():
87
  result = gr.Image(label="Result", show_label=False, interactive=False)
88
+
89
  gr.Examples(
90
+ examples=[
91
+ ["./assets/5_before.png", "sunset over sea lighting coming from the top right part of the photo", 0, True, 2.5],
92
+ ["./assets/3_before.png", "sci-fi RGB glowing, studio lighting",0, True,2.5],
93
+ ["./assets/2_before.png", "neon light, city",0, True, 2.5],
94
+ ["./assets/before_6.png", "bright sunlight, warm, luminous", 0, True, 2.5]
95
+ ],
96
+ inputs=[input_image, prompt, seed, randomize_seed, guidance_scale],
97
+ outputs=[result, seed],
98
+ fn=infer,
99
+ cache_examples="lazy"
100
+ )
101
+
102
+ gr.on(
103
+ triggers=[run_button.click, prompt.submit],
104
+ fn = infer,
105
+ inputs = [input_image, prompt, seed, randomize_seed, guidance_scale],
106
+ outputs = [result, seed]
107
+ )
108
+
109
 
110
+ demo.launch()