Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import numpy as np | |
import random | |
import spaces | |
import torch | |
import time | |
import os | |
from diffusers import DiffusionPipeline | |
from custom_pipeline import FLUXPipelineWithIntermediateOutputs | |
from transformers import pipeline | |
# Korean -> English translation model; it runs on the CPU.
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device="cpu",
)
# Generation limits and defaults.
MAX_SEED = np.iinfo(np.int32).max  # largest seed value (2**31 - 1)
MAX_IMAGE_SIZE = 2048  # upper clamp for width/height
DEFAULT_WIDTH, DEFAULT_HEIGHT = 1024, 1024
DEFAULT_INFERENCE_STEPS = 1
GPU_DURATION = 15  # GPU allocation time (kept short)
def setup_model():
    """Load FLUX.1-schnell with intermediate-output support and move it to CUDA.

    Returns:
        A ``FLUXPipelineWithIntermediateOutputs`` loaded in float16 and
        placed on the "cuda" device.
    """
    model_id = "black-forest-labs/FLUX.1-schnell"
    loaded = FLUXPipelineWithIntermediateOutputs.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    )
    return loaded.to("cuda")


# Loaded once at import time; generate_image reads this module-level pipeline.
pipe = setup_model()
# UI labels: English key -> text displayed in the interface.
# NOTE(review): the display strings look mis-encoded (Korean rendered as
# Thai characters) — confirm against the original file before shipping.
labels = dict(
    [
        ("Generated Image", "์์ฑ๋ ์ด๋ฏธ์ง"),
        ("Prompt", "ํ๋กฌํํธ"),
        ("Enhance Image", "์ด๋ฏธ์ง ํฅ์"),
        ("Advanced Options", "๊ณ ๊ธ ์ค์ "),
        ("Seed", "์๋"),
        ("Randomize Seed", "๋๋ค ์๋"),
        ("Width", "๋๋น"),
        ("Height", "๋์ด"),
        ("Inference Steps", "์ถ๋ก ๋จ๊ณ"),
        ("Inspiration Gallery", "์๊ฐ ๊ฐค๋ฌ๋ฆฌ"),
    ]
)
def translate_if_korean(text):
    """Return ``text`` translated to English if it contains Hangul, else unchanged.

    Hangul is detected using the Compatibility Jamo range (U+3131-U+3163) and
    the precomposed syllable range (U+AC00-U+D7A3). Any exception raised
    while checking or translating is printed, and the input is returned as-is.
    """
    def _is_hangul(ch):
        return '\u3131' <= ch <= '\u3163' or '\uac00' <= ch <= '\ud7a3'

    try:
        if not any(map(_is_hangul, text)):
            return text
        outputs = translator(text)
        return outputs[0]['translation_text']
    except Exception as e:
        print(f"๋ฒ์ญ ์ค๋ฅ: {e}")
        return text
# Image generation.
# On ZeroGPU hardware, CUDA work must run inside a function decorated with
# @spaces.GPU; `duration` bounds the GPU allocation (seconds) per call.
@spaces.GPU(duration=GPU_DURATION)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
                   randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    """Generate an image for ``prompt``, streaming every pipeline output.

    Args:
        prompt: Text prompt; Korean input is translated to English first.
        seed: Integer seed, or None. Non-int values are discarded and a
            random seed is used instead.
        width, height: Requested size; clamped to [256, MAX_IMAGE_SIZE].
        randomize_seed: When true, ignore ``seed`` and draw a random one.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Yields:
        ``(image, seed, latency_text)`` for each image produced by
        ``pipe.generate_images`` (presumably progressively refined
        intermediates with the final image last — confirm against
        custom_pipeline). On any error, a single ``(None, seed, error_text)``.
    """
    try:
        # Input validation: anything that is not an int/None becomes a random seed.
        if not isinstance(seed, (int, type(None))):
            seed = None
            randomize_seed = True
        prompt = translate_if_korean(prompt)
        if seed is None or randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Size validation. int() because slider values may arrive as floats,
        # while the pipeline expects integer dimensions and step counts.
        width = int(min(max(256, width), MAX_IMAGE_SIZE))
        height = int(min(max(256, height), MAX_IMAGE_SIZE))
        num_inference_steps = int(num_inference_steps)
        generator = torch.Generator().manual_seed(seed)
        start_time = time.time()
        # torch.autocast("cuda") is the non-deprecated spelling of
        # torch.cuda.amp.autocast() (same default float16 dtype).
        with torch.autocast(device_type="cuda"):
            for img in pipe.generate_images(
                prompt=prompt,
                guidance_scale=0,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator
            ):
                latency = f"์ฒ๋ฆฌ ์๊ฐ: {(time.time()-start_time):.2f} ์ด"
                # Release cached CUDA memory between streamed outputs.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                yield img, seed, latency
    except Exception as e:
        print(f"์ด๋ฏธ์ง ์์ฑ ์ค๋ฅ: {e}")
        yield None, seed, f"์ค๋ฅ: {str(e)}"
# Example image generation.
def generate_example_image(prompt):
    """Generate an image for a gallery example and return the final result.

    ``generate_image`` is a generator that may yield several intermediate
    outputs. This drains it and returns the last ``(image, seed, latency)``
    tuple, rather than the first (least refined) one.

    Returns:
        ``(image, seed, latency_text)``, or ``(None, None, error_text)`` if
        generation raised or produced nothing.
    """
    try:
        final = None
        for final in generate_image(prompt, randomize_seed=True):
            pass
        if final is None:
            raise RuntimeError("generate_image produced no output")
        return final
    except Exception as e:
        print(f"์์ ์์ฑ ์ค๋ฅ: {e}")
        return None, None, f"์ค๋ฅ: {str(e)}"
# Prompts offered in the inspiration gallery (the first one is non-English).
examples = [
    "๋น๋ ์๋์ฒผ์ ์ ๋๋ฉ์ด์ ์ผ๋ฌ์คํธ๋ ์ด์ ",
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust",
]

# Page-level CSS: hide Gradio's footer.
css = """
footer {
visibility: hidden;
}
"""
def create_snow_effect(num_flakes=25):
    """Return a ``gr.HTML`` component that renders a falling-snow overlay.

    The effect is pure CSS. ``gr.HTML`` renders its value as inner HTML, and
    browsers do not execute ``<script>`` elements inserted that way, so the
    snowflakes are emitted directly as elements. Each flake gets a random
    horizontal position, fall duration (2-5 s) and negative start delay, so
    the screen is already populated when the page loads. The ``snowfall``
    keyframes loop each flake forever.

    Args:
        num_flakes: Number of snowflake elements to render (negative values
            are treated as 0).

    Returns:
        A ``gr.HTML`` component. Call this inside a ``gr.Blocks`` context so
        the component is attached to the layout.
    """
    snow_css = """
@keyframes snowfall {
    0% {
        transform: translateY(-10vh) translateX(0);
        opacity: 1;
    }
    100% {
        transform: translateY(100vh) translateX(100px);
        opacity: 0.3;
    }
}
.snowflake {
    position: fixed;
    top: 0;
    color: white;
    font-size: 1.5em;
    user-select: none;
    z-index: 1000;
    pointer-events: none;
    animation: snowfall linear infinite;
}
"""
    # Local RNG so building the UI does not advance the global `random` state
    # that generate_image uses for seeds.
    rng = random.Random()
    flakes = []
    for _ in range(max(0, int(num_flakes))):
        duration = rng.uniform(2, 5)
        # Inline animation-duration/delay override the values reset by the
        # `animation` shorthand in the class rule. &#10052; is the snowflake
        # glyph, written as an ASCII entity to be immune to source encoding.
        flakes.append(
            f'<div class="snowflake" style="left: {rng.uniform(0, 100):.2f}vw; '
            f'animation-duration: {duration:.2f}s; '
            f'animation-delay: -{rng.uniform(0, duration):.2f}s;">&#10052;</div>'
        )
    return gr.HTML(f"<style>{snow_css}</style>{''.join(flakes)}")
# Gradio UI construction. create_snow_effect() is called inside the Blocks
# context so its component is attached to this layout.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    create_snow_effect()
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(scale=3):
                result = gr.Image(
                    label=labels["Generated Image"],
                    show_label=False,
                    interactive=False,
                )
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=labels["Prompt"],
                    placeholder="์์ฑํ๊ณ ์ถ์ ์ด๋ฏธ์ง๋ฅผ ์ค๋ช ํด์ฃผ์ธ์...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                enhanceBtn = gr.Button(f"๐ {labels['Enhance Image']}")
                # gr.Column's parameters are keyword-only, so a positional
                # title passed to it is never rendered. Accordion is the
                # container whose first positional argument is its label.
                with gr.Accordion(labels["Advanced Options"]):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        seed = gr.Number(
                            label=labels["Seed"],
                            value=42,
                            precision=0,
                            minimum=0,
                            maximum=MAX_SEED,
                        )
                        randomize_seed = gr.Checkbox(
                            label=labels["Randomize Seed"],
                            value=True,
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label=labels["Width"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_WIDTH,
                        )
                        height = gr.Slider(
                            label=labels["Height"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_HEIGHT,
                        )
                    num_inference_steps = gr.Slider(
                        label=labels["Inference Steps"],
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=DEFAULT_INFERENCE_STEPS,
                    )
        with gr.Row():
            gr.Markdown(f"### ๐ {labels['Inspiration Gallery']}")
        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=generate_example_image,
                inputs=[prompt],
                # generate_example_image returns (image, seed, latency):
                # one output component per returned value.
                outputs=[result, seed, latency],
                cache_examples=False,
            )

    # Event handling.
    def validated_generate(*args):
        """Run generate_image to completion and return its final output.

        Draining the generator (instead of taking only its first yield)
        makes the inference-steps setting affect the displayed image.
        Returns ``(image, seed, latency)``, or ``(None, args[1], error)``.
        """
        try:
            final = None
            for final in generate_image(*args):
                pass
            if final is None:
                raise RuntimeError("generate_image produced no output")
            return final
        except Exception as e:
            print(f"๊ฒ์ฆ ์์ฑ ์ค๋ฅ: {e}")
            return None, args[1], f"์ค๋ฅ: {str(e)}"

    # The button and the live triggers share one non-streaming handler and
    # the same inputs. As a result, the button also honours the
    # randomize-seed and step controls, and no generator is bound to a
    # queue=False event.
    generation_inputs = [prompt, seed, width, height, randomize_seed, num_inference_steps]
    enhanceBtn.click(
        fn=validated_generate,
        inputs=generation_inputs,
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        queue=False,
    )
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=validated_generate,
        inputs=generation_inputs,
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False,
    )
if __name__ == "__main__":
    # Script entry point: serve the Blocks app defined above.
    demo.launch()