|
|
import torch |
|
|
|
|
|
from typing_extensions import override |
|
|
from comfy_api.latest import ComfyExtension, io |
|
|
|
|
|
|
|
|
class InstructPixToPixConditioning(io.ComfyNode):
    """Prepare InstructPix2Pix conditioning.

    Encodes the input image with the VAE and attaches the resulting latent
    to both the positive and negative conditioning under the
    ``concat_latent_image`` key, while returning an all-zeros starting latent
    of the same shape.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="InstructPixToPixConditioning",
            category="conditioning/instructpix2pix",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Image.Input("pixels"),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
        # Image tensor layout here is (batch, height, width, channels).
        height, width = pixels.shape[1], pixels.shape[2]
        # Largest dimensions divisible by 8, as required by the VAE's
        # downscale factor.
        target_h = (height // 8) * 8
        target_w = (width // 8) * 8

        if (height, width) != (target_h, target_w):
            # Center-crop: drop half the remainder from each side.
            top = (height % 8) // 2
            left = (width % 8) // 2
            pixels = pixels[:, top:top + target_h, left:left + target_w, :]

        concat_latent = vae.encode(pixels)

        def _attach(conditioning):
            # Copy each conditioning entry's options dict and add the
            # encoded image latent; the original entries are left untouched.
            updated = []
            for cond, options in conditioning:
                options = options.copy()
                options["concat_latent_image"] = concat_latent
                updated.append([cond, options])
            return updated

        # Sampling starts from a zero latent matching the encoded image.
        out_latent = {"samples": torch.zeros_like(concat_latent)}
        return io.NodeOutput(_attach(positive), _attach(negative), out_latent)
|
|
|
|
|
|
|
|
class InstructPix2PixExtension(ComfyExtension):
    """Extension exposing the InstructPix2Pix conditioning node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Single-node extension.
        return [InstructPixToPixConditioning]
|
|
|
|
|
|
|
|
async def comfy_entrypoint() -> InstructPix2PixExtension:
    """Entry point called by ComfyUI to load this extension."""
    extension = InstructPix2PixExtension()
    return extension
|
|
|
|
|
|