Spaces:
Runtime error
Runtime error
import torch | |
from unittest.mock import patch, MagicMock | |
# Mock nodes module to prevent CUDA initialization during import | |
mock_nodes = MagicMock() | |
mock_nodes.MAX_RESOLUTION = 16384 | |
# Mock server module for PromptServer | |
mock_server = MagicMock() | |
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}): | |
from comfy_extras.nodes_images import ImageStitch | |
class TestImageStitch: | |
def create_test_image(self, batch_size=1, height=64, width=64, channels=3): | |
"""Helper to create test images with specific dimensions""" | |
return torch.rand(batch_size, height, width, channels) | |
def test_no_image2_passthrough(self): | |
"""Test that when image2 is None, image1 is returned unchanged""" | |
node = ImageStitch() | |
image1 = self.create_test_image() | |
result = node.stitch(image1, "right", True, 0, "white", image2=None) | |
assert len(result) == 1 | |
assert torch.equal(result[0], image1) | |
def test_basic_horizontal_stitch_right(self): | |
"""Test basic horizontal stitching to the right""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=24) | |
result = node.stitch(image1, "right", False, 0, "white", image2) | |
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width | |
def test_basic_horizontal_stitch_left(self): | |
"""Test basic horizontal stitching to the left""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=24) | |
result = node.stitch(image1, "left", False, 0, "white", image2) | |
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width | |
def test_basic_vertical_stitch_down(self): | |
"""Test basic vertical stitching downward""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=24, width=32) | |
result = node.stitch(image1, "down", False, 0, "white", image2) | |
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height | |
def test_basic_vertical_stitch_up(self): | |
"""Test basic vertical stitching upward""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=24, width=32) | |
result = node.stitch(image1, "up", False, 0, "white", image2) | |
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height | |
def test_size_matching_horizontal(self): | |
"""Test size matching for horizontal concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=64, width=64) | |
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio | |
result = node.stitch(image1, "right", True, 0, "white", image2) | |
# image2 should be resized to match image1's height (64) with preserved aspect ratio | |
expected_width = 64 + 64 # original + resized (32*64/32 = 64) | |
assert result[0].shape == (1, 64, expected_width, 3) | |
def test_size_matching_vertical(self): | |
"""Test size matching for vertical concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=64, width=64) | |
image2 = self.create_test_image(height=32, width=32) | |
result = node.stitch(image1, "down", True, 0, "white", image2) | |
# image2 should be resized to match image1's width (64) with preserved aspect ratio | |
expected_height = 64 + 64 # original + resized (32*64/32 = 64) | |
assert result[0].shape == (1, expected_height, 64, 3) | |
def test_padding_for_mismatched_heights_horizontal(self): | |
"""Test padding when heights don't match in horizontal concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=64, width=32) | |
image2 = self.create_test_image(height=48, width=24) # Shorter height | |
result = node.stitch(image1, "right", False, 0, "white", image2) | |
# Both images should be padded to height 64 | |
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height | |
def test_padding_for_mismatched_widths_vertical(self): | |
"""Test padding when widths don't match in vertical concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=64) | |
image2 = self.create_test_image(height=24, width=48) # Narrower width | |
result = node.stitch(image1, "down", False, 0, "white", image2) | |
# Both images should be padded to width 64 | |
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width | |
def test_spacing_horizontal(self): | |
"""Test spacing addition in horizontal concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=24) | |
spacing_width = 16 | |
result = node.stitch(image1, "right", False, spacing_width, "white", image2) | |
# Expected width: 32 + 16 (spacing) + 24 = 72 | |
assert result[0].shape == (1, 32, 72, 3) | |
def test_spacing_vertical(self): | |
"""Test spacing addition in vertical concatenation""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=24, width=32) | |
spacing_width = 16 | |
result = node.stitch(image1, "down", False, spacing_width, "white", image2) | |
# Expected height: 32 + 16 (spacing) + 24 = 72 | |
assert result[0].shape == (1, 72, 32, 3) | |
def test_spacing_color_values(self): | |
"""Test that spacing colors are applied correctly""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=32) | |
# Test white spacing | |
result_white = node.stitch(image1, "right", False, 16, "white", image2) | |
# Check that spacing region contains white values (close to 1.0) | |
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels | |
assert torch.all(spacing_region >= 0.9) # Should be close to white | |
# Test black spacing | |
result_black = node.stitch(image1, "right", False, 16, "black", image2) | |
spacing_region = result_black[0][:, :, 32:48, :] | |
assert torch.all(spacing_region <= 0.1) # Should be close to black | |
def test_odd_spacing_width_made_even(self): | |
"""Test that odd spacing widths are made even""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=32) | |
# Use odd spacing width | |
result = node.stitch(image1, "right", False, 15, "white", image2) | |
# Should be made even (16), so total width = 32 + 16 + 32 = 80 | |
assert result[0].shape == (1, 32, 80, 3) | |
def test_batch_size_matching(self): | |
"""Test that different batch sizes are handled correctly""" | |
node = ImageStitch() | |
image1 = self.create_test_image(batch_size=2, height=32, width=32) | |
image2 = self.create_test_image(batch_size=1, height=32, width=32) | |
result = node.stitch(image1, "right", False, 0, "white", image2) | |
# Should match larger batch size | |
assert result[0].shape == (2, 32, 64, 3) | |
def test_channel_matching_rgb_to_rgba(self): | |
"""Test that channel differences are handled (RGB + alpha)""" | |
node = ImageStitch() | |
image1 = self.create_test_image(channels=3) # RGB | |
image2 = self.create_test_image(channels=4) # RGBA | |
result = node.stitch(image1, "right", False, 0, "white", image2) | |
# Should have 4 channels (RGBA) | |
assert result[0].shape[-1] == 4 | |
def test_channel_matching_rgba_to_rgb(self): | |
"""Test that channel differences are handled (RGBA + RGB)""" | |
node = ImageStitch() | |
image1 = self.create_test_image(channels=4) # RGBA | |
image2 = self.create_test_image(channels=3) # RGB | |
result = node.stitch(image1, "right", False, 0, "white", image2) | |
# Should have 4 channels (RGBA) | |
assert result[0].shape[-1] == 4 | |
def test_all_color_options(self): | |
"""Test all available color options""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=32) | |
colors = ["white", "black", "red", "green", "blue"] | |
for color in colors: | |
result = node.stitch(image1, "right", False, 16, color, image2) | |
assert result[0].shape == (1, 32, 80, 3) # Basic shape check | |
def test_all_directions(self): | |
"""Test all direction options""" | |
node = ImageStitch() | |
image1 = self.create_test_image(height=32, width=32) | |
image2 = self.create_test_image(height=32, width=32) | |
directions = ["right", "left", "up", "down"] | |
for direction in directions: | |
result = node.stitch(image1, direction, False, 0, "white", image2) | |
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3) | |
def test_batch_size_channel_spacing_integration(self): | |
"""Test integration of batch matching, channel matching, size matching, and spacings""" | |
node = ImageStitch() | |
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3) | |
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4) | |
result = node.stitch(image1, "right", True, 8, "red", image2) | |
# Should handle: batch matching, size matching, channel matching, spacing | |
assert result[0].shape[0] == 2 # Batch size matched | |
assert result[0].shape[-1] == 4 # Channels matched to max | |
assert result[0].shape[1] == 64 # Height from image1 (size matching) | |
# Width should be: 48 + 8 (spacing) + resized_image2_width | |
expected_image2_width = int(64 * (32/32)) # Resized to height 64 | |
expected_total_width = 48 + 8 + expected_image2_width | |
assert result[0].shape[2] == expected_total_width | |