|
import torch |
|
from diffusers import FluxImg2ImgPipeline |
|
from PIL import Image |
|
import sys |
|
import spaces |
|
|
|
|
|
|
|
# Pipelines already loaded onto the GPU, keyed by model id. At most one entry
# is kept so that switching models does not pile up several resident pipelines.
_PIPELINE_CACHE = {}


def _get_pipeline(model_id):
    """Return a bfloat16 FluxImg2ImgPipeline for *model_id* on CUDA, loading it at most once."""
    pipe = _PIPELINE_CACHE.get(model_id)
    if pipe is None:
        # Forget any previously cached model before loading a different one.
        _PIPELINE_CACHE.clear()
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        _PIPELINE_CACHE[model_id] = pipe
    return pipe


@spaces.GPU
def process_image(
    image,
    mask_image,
    prompt="a person",
    model_id="black-forest-labs/FLUX.1-schnell",
    strength=0.75,
    seed=0,
    num_inference_steps=4,
    *,
    guidance_scale=0.0,
    max_sequence_length=256,
    height=None,
    width=None,
):
    """Run FLUX image-to-image generation on *image* and return the first result.

    Args:
        image: Input image (a PIL image in this file's ``__main__``). Converted
            to RGB when its mode is not already ``"RGB"``. If ``None``, nothing
            is generated and ``None`` is returned.
        mask_image: Currently unused — ``FluxImg2ImgPipeline`` is called without
            a mask. NOTE(review): if inpainting was intended, diffusers'
            ``FluxInpaintPipeline`` accepts ``mask_image``; confirm intent.
        prompt: Text prompt passed to the pipeline.
        model_id: Hugging Face model id to load. The loaded pipeline is cached
            between calls in this process.
        strength: Img2img strength forwarded to the pipeline.
        seed: Seed for the CUDA ``torch.Generator``, making runs reproducible.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Forwarded to the pipeline. The default of 0 matches the
            original hard-coded value, which suits FLUX.1-schnell.
        max_sequence_length: Prompt token limit forwarded to the pipeline
            (default 256, the original hard-coded value).
        height, width: Optional output size forwarded to the pipeline. ``None``
            leaves the choice to the pipeline, as before.

    Returns:
        ``output.images[0]`` from the pipeline, or ``None`` when *image* is ``None``.
    """
    print("Starting process_image")

    if image is None:
        print("Empty input image returned.")
        return None

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Reuse the already-loaded pipeline instead of reloading weights every call.
    pipe = _get_pipeline(model_id)

    generator = torch.Generator("cuda").manual_seed(seed)

    print(prompt)

    output = pipe(
        prompt=prompt,
        image=image,
        generator=generator,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        max_sequence_length=max_sequence_length,
        height=height,
        width=width,
    )

    return output.images[0]
|
|
|
if __name__ == "__main__":
    # Command-line entry: INPUT_IMAGE MASK_IMAGE OUTPUT_IMAGE (extra args ignored).
    if len(sys.argv) < 4:
        print(
            f"usage: {sys.argv[0]} INPUT_IMAGE MASK_IMAGE OUTPUT_IMAGE",
            file=sys.stderr,
        )
        sys.exit(2)

    input_path, mask_path, output_path = sys.argv[1:4]

    # Context managers close the source image files once inference has finished.
    with Image.open(input_path) as image, Image.open(mask_path) as mask:
        output = process_image(image, mask)

    output.save(output_path)
|
|