Akjava commited on
Commit
ba8ad6f
·
1 Parent(s): d9f0652

resize back 32

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -58,9 +58,9 @@ def process_images(image,prompt="a girl",strength=0.75,seed=0,inference_step=4,p
58
  generators = []
59
  generator = torch.Generator(device).manual_seed(seed)
60
  generators.append(generator)
61
- width,height = convert_to_fit_size(image.size)
62
  #print(f"fit {width}x{height}")
63
- width,height = adjust_to_multiple_of_32(width,height)
64
  #print(f"multiple {width}x{height}")
65
  image = image.resize((width, height), Image.LANCZOS)
66
  #mask_image = mask_image.resize((width, height), Image.NEAREST)
@@ -70,8 +70,15 @@ def process_images(image,prompt="a girl",strength=0.75,seed=0,inference_step=4,p
70
  output = pipe(prompt=prompt, image=image,generator=generator,strength=strength,width=width,height=height
71
  ,guidance_scale=0,num_inference_steps=num_inference_steps,max_sequence_length=256)
72
 
73
- # TODO support mask
74
- return output.images[0]
 
 
 
 
 
 
 
75
 
76
  output = process_img2img(image,prompt,strength,seed,inference_step)
77
 
 
58
  generators = []
59
  generator = torch.Generator(device).manual_seed(seed)
60
  generators.append(generator)
61
+ fit_width,fit_height = convert_to_fit_size(image.size)
62
  #print(f"fit {width}x{height}")
63
+ width,height = adjust_to_multiple_of_32(fit_width,fit_height)
64
  #print(f"multiple {width}x{height}")
65
  image = image.resize((width, height), Image.LANCZOS)
66
  #mask_image = mask_image.resize((width, height), Image.NEAREST)
 
70
  output = pipe(prompt=prompt, image=image,generator=generator,strength=strength,width=width,height=height
71
  ,guidance_scale=0,num_inference_steps=num_inference_steps,max_sequence_length=256)
72
 
73
+ pil_image = Image.fromarray(output.images[0])
74
+ new_width,new_height = pil_image.size
75
+
76
+ # resize back multiple of 32
77
+ if (new_width!=fit_width) or (new_height!=fit_height):
78
+ resized_image= pil_image.resize(fit_width,fit_height,Image.LANCZOS)
79
+ return resized_image
80
+
81
+ return pil_image
82
 
83
  output = process_img2img(image,prompt,strength,seed,inference_step)
84