ttoosi committed
Commit 7449d44 · verified · 1 Parent(s): f7dbbd9

Upload 11 files

direct initial upload

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+stimuli/figure_ground.png filter=lfs diff=lfs merge=lfs -text
+stimuli/NeonColorSaeedi.jpg filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,31 @@
+FROM python:3.9-slim
+
+WORKDIR /code
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    git \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy only requirements to leverage Docker caching
+COPY ./requirements.txt /code/requirements.txt
+
+# Install Python dependencies
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# Copy all code and data
+COPY . /code/
+
+# Create necessary directories
+RUN mkdir -p /code/models
+RUN mkdir -p /code/stimuli
+
+# Make sure stimuli and models are writable
+RUN chmod -R 777 /code/models
+RUN chmod -R 777 /code/stimuli
+
+# Set up the command to run the app
+CMD ["python", "app.py"]
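A quick way to exercise this Dockerfile locally (the image tag below is illustrative, not part of the commit); the container runs `app.py`, which listens on port 7860 by default:

```bash
# Build the image from the repository root (tag name is an example)
docker build -t generative-inference-demo .

# Run it and publish the Gradio port used by app.py
docker run --rm -p 7860:7860 generative-inference-demo
```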
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 GenerativeInferenceDemo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -1,14 +1,83 @@
 ---
-title: GenerativeInferenceDemo
-emoji: 🚀
-colorFrom: yellow
-colorTo: yellow
+title: Generative Inference Demo
+emoji: 🧠
+colorFrom: indigo
+colorTo: purple
 sdk: gradio
-sdk_version: 5.23.1
+sdk_version: 3.50.2
 app_file: app.py
 pinned: false
-license: apache-2.0
-short_description: Generative Inference enables ai to see illusions out-of-box
+license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Generative Inference Demo
+
+This Gradio demo showcases how neural networks perceive visual illusions through generative inference. The demo uses both standard and robust ResNet50 models to reveal emergent perception of contours, figure-ground separation, and other visual phenomena.
+
+## Models
+
+- **Robust ResNet50**: A model trained with adversarial examples (ε=3.0), exhibiting more human-like visual perception
+- **Standard ResNet50**: A model trained without adversarial examples (ε=0.0)
+
+## Features
+
+- Upload your own images or use example illusions
+- Choose between robust and standard models
+- Adjust perturbation size (epsilon) and iteration count
+- Visualize how perception emerges over time
+- Includes classic illusions:
+  - Kanizsa shapes
+  - Face-Vase illusions
+  - Figure-Ground segmentation
+  - Neon color spreading
+
+## Usage
+
+1. Select an example image or upload your own
+2. Choose the model type (robust or standard)
+3. Adjust epsilon and iteration parameters
+4. Click "Run Inference" to see how the model perceives the image
+
+## About
+
+This demo is based on research showing how adversarially robust models develop more human-like visual representations. The generative inference process reveals these perceptual biases by optimizing the input to maximize the model's confidence.
+
+## Installation
+
+To run this demo locally:
+
+```bash
+# Clone the repository
+git clone [repo-url]
+cd GenerativeInferenceDemo
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Run the app
+python app.py
+```
+
+The web app will be available at http://localhost:7860 (or another port if 7860 is busy).
+
+## About the Models
+
+- **Robust ResNet50**: A model trained with adversarial examples, making it more robust to small perturbations. These models often exhibit more human-like visual perception.
+- **Standard ResNet50**: A standard ImageNet-trained ResNet50 model.
+
+## How It Works
+
+1. The algorithm starts with an input image
+2. It iteratively updates the image to increase the model's confidence in its predictions
+3. These updates are constrained to a small neighborhood (controlled by epsilon) around the original image
+4. The resulting changes reveal how the network "sees" the image
+
+## Citation
+
+If you use this work in your research, please cite the original paper:
+
+[Citation information will be added here]
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
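The "How It Works" steps in the README above correspond to a short projected-gradient loop. A minimal, illustrative sketch of that idea (simplified from the `inference.py` added in this same commit; names and details here are not meant to match it line for line, and inputs are assumed to be CPU tensors in [0, 1]):

```python
import torch
import torch.nn.functional as F

def generative_inference_sketch(model, image, target_class, eps=0.5, step_size=0.02, n_itr=50):
    """Nudge `image` (a [1, 3, H, W] tensor in [0, 1]) toward higher confidence
    for `target_class`, staying within an eps-ball of the original image."""
    orig = image.clone()
    x = image.clone()
    for _ in range(n_itr):
        x = x.detach().requires_grad_(True)
        logits = model(x)
        # Maximizing confidence == minimizing cross-entropy for the target class
        loss = F.cross_entropy(logits, torch.tensor([target_class]))
        loss.backward()
        with torch.no_grad():
            # Normalized gradient step, then project back into the eps-ball and [0, 1]
            g = x.grad / (x.grad.norm() + 1e-10)
            x = x - step_size * g
            x = orig + torch.clamp(x - orig, -eps, eps)
            x = torch.clamp(x, 0, 1)
    return x.detach()
```

The committed `inference.py` differs in details (it normalizes inputs, targets the least-confident classes, and records intermediate steps), but the constrained update-and-project structure is the same.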
app.py CHANGED
@@ -1,7 +1,115 @@
 import gradio as gr
+import torch
+import numpy as np
+from PIL import Image
+import os
+import argparse
+from inference import GenerativeInferenceModel, get_inference_configs
 
-def greet(name):
-    return "Hello " + name + "!!"
+# Parse command line arguments
+parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
+parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
+args = parser.parse_args()
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+# Create model directories if they don't exist
+os.makedirs("models", exist_ok=True)
+os.makedirs("stimuli", exist_ok=True)
+
+# Initialize model
+model = GenerativeInferenceModel()
+
+def run_inference(image, model_type, illusion_type, eps_value, num_iterations):
+    # Convert eps to float
+    eps = float(eps_value)
+
+    # Load inference configuration
+    config = get_inference_configs(eps=eps, n_itr=int(num_iterations))
+
+    # Run generative inference
+    output_images, all_steps = model.inference(image, model_type, config)
+
+    # Create animation frames
+    frames = []
+    for i, step_image in enumerate(all_steps):
+        # Convert tensor to PIL image
+        step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+        frames.append(step_pil)
+
+    # Return the final inferred image and the animation
+    return output_images, gr.Gallery.update(value=frames)
+
+# Define the interface
+with gr.Blocks(title="Generative Inference Demo") as demo:
+    gr.Markdown("# Generative Inference Demo")
+    gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            # Inputs
+            image_input = gr.Image(label="Upload Image or Select an Illusion", type="pil")
+
+            with gr.Row():
+                model_choice = gr.Dropdown(
+                    choices=["robust_resnet50", "standard_resnet50"],
+                    value="robust_resnet50",
+                    label="Model"
+                )
+
+                illusion_type = gr.Dropdown(
+                    choices=["Kanizsa", "Face-Vase", "Neon-Color", "Figure-Ground"],
+                    value="Kanizsa",
+                    label="Illusion Type"
+                )
+
+            with gr.Row():
+                eps_slider = gr.Slider(minimum=0.01, maximum=3.0, value=0.5, step=0.01, label="Epsilon (Perturbation Size)")
+                iterations_slider = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Number of Iterations")
+
+            run_button = gr.Button("Run Inference")
+
+        with gr.Column(scale=2):
+            # Outputs
+            output_image = gr.Image(label="Final Inferred Image")
+            output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)
+
+    # Set up example images
+    examples = [
+        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "Kanizsa", 0.5, 50],
+        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "Face-Vase", 0.5, 50],
+        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "Figure-Ground", 0.7, 100],
+        [os.path.join("stimuli", "NeonColorSaeedi.jpg"), "robust_resnet50", "Neon-Color", 0.3, 80]
+    ]
+
+    gr.Examples(examples=examples, inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider])
+
+    # Set up event handler
+    run_button.click(
+        fn=run_inference,
+        inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider],
+        outputs=[output_image, output_frames]
+    )
+
+    # Include a description of the technique
+    gr.Markdown("""
+    ## About Generative Inference
+
+    Generative inference is a technique that reveals how neural networks perceive visual stimuli by optimizing the input
+    to increase the network's confidence in its predictions. This process can reveal emergent perception of contours,
+    figure-ground separation, and other visual phenomena similar to human perception.
+
+    This demo allows you to:
+    1. Upload your own images or select from example illusions
+    2. Choose between robust or standard models
+    3. Adjust parameters like perturbation size (epsilon) and number of iterations
+    4. Visualize how the perception emerges over time
+    """)
+
+# Launch the demo with specific settings
+if __name__ == "__main__":
+    print(f"Starting server on port {args.port}")
+    demo.launch(
+        server_name="0.0.0.0",  # Listen on all interfaces
+        server_port=args.port,  # Use the port from command line arguments
+        share=False,
+        debug=True
+    )
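Since `app.py` exposes a `--port` flag via argparse, a local run outside Docker might look like this (the alternate port value is just an example):

```bash
# Default port (7860)
python app.py

# Or pick another port if 7860 is taken
python app.py --port 7861
```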
huggingface-metadata.json ADDED
@@ -0,0 +1,11 @@
+{
+  "title": "Generative Inference Demo",
+  "emoji": "🧠",
+  "colorFrom": "indigo",
+  "colorTo": "purple",
+  "sdk": "gradio",
+  "sdk_version": "3.32.0",
+  "app_file": "app.py",
+  "pinned": false,
+  "license": "mit"
+}
inference.py ADDED
@@ -0,0 +1,251 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torchvision.models.resnet import ResNet50_Weights
+from PIL import Image
+import numpy as np
+import os
+import requests
+import time
+from pathlib import Path
+
+# Check CUDA availability
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Constants
+MODEL_URLS = {
+    'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_3.0.pt',
+    'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_0.0.pt'
+}
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+# Default transform
+transform = transforms.Compose([
+    transforms.Resize(224),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+])
+
+normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+
+# Get ImageNet labels
+def get_imagenet_labels():
+    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
+    response = requests.get(url)
+    if response.status_code == 200:
+        return response.json()
+    else:
+        raise RuntimeError("Failed to fetch ImageNet labels")
+
+# Download model if needed
+def download_model(model_type):
+    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
+        return None  # Use PyTorch's pretrained model
+
+    model_path = Path(f"models/{model_type}.pt")
+    if not model_path.exists():
+        print(f"Downloading {model_type} model...")
+        url = MODEL_URLS[model_type]
+        response = requests.get(url, stream=True)
+        if response.status_code == 200:
+            with open(model_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Model downloaded and saved to {model_path}")
+        else:
+            raise RuntimeError(f"Failed to download model: {response.status_code}")
+    return model_path
+
+class NormalizeByChannelMeanStd(nn.Module):
+    def __init__(self, mean, std):
+        super(NormalizeByChannelMeanStd, self).__init__()
+        if not isinstance(mean, torch.Tensor):
+            mean = torch.tensor(mean)
+        if not isinstance(std, torch.Tensor):
+            std = torch.tensor(std)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+
+    def forward(self, tensor):
+        return self.normalize_fn(tensor, self.mean, self.std)
+
+    def normalize_fn(self, tensor, mean, std):
+        """Differentiable version of torchvision.functional.normalize"""
+        # here we assume the color channel is at dim=1
+        mean = mean[None, :, None, None]
+        std = std[None, :, None, None]
+        return tensor.sub(mean).div(std)
+
+class InferStep:
+    def __init__(self, orig_image, eps, step_size):
+        self.orig_image = orig_image
+        self.eps = eps
+        self.step_size = step_size
+
+    def project(self, x):
+        diff = x - self.orig_image
+        diff = torch.clamp(diff, -self.eps, self.eps)
+        return torch.clamp(self.orig_image + diff, 0, 1)
+
+    def step(self, x, grad):
+        l = len(x.shape) - 1
+        grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1]*l))
+        scaled_grad = grad / (grad_norm + 1e-10)
+        return scaled_grad * self.step_size
+
+def get_inference_configs(eps=0.5, n_itr=50):
+    """Generate inference configuration with customizable parameters."""
+    config = {
+        'loss_infer': 'IncreaseConfidence',    # How to guide the optimization
+        'loss_function': 'CE',                 # Loss function: Cross Entropy
+        'n_itr': n_itr,                        # Number of iterations
+        'eps': eps,                            # Maximum perturbation size
+        'step_size': 0.02,                     # Step size for each iteration
+        'diffusion_noise_ratio': 0.0,          # No diffusion noise
+        'initial_inference_noise_ratio': 0.0,  # No initial noise
+        'top_layer': 'all',                    # Use all layers of the model
+        'inference_normalization': 'on',       # Apply normalization during inference
+        'recognition_normalization': 'on',     # Apply normalization during recognition
+        'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr]  # Specific iterations to visualize
+    }
+    return config
+
+class GenerativeInferenceModel:
+    def __init__(self):
+        self.models = {}
+        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
+        self.labels = get_imagenet_labels()
+
+    def load_model(self, model_type):
+        if model_type in self.models:
+            return self.models[model_type]
+
+        model_path = download_model(model_type)
+
+        # Create standard ResNet50 model
+        model = models.resnet50()
+
+        # Load the model checkpoint
+        if model_path:
+            print(f"Loading {model_type} model from {model_path}...")
+            checkpoint = torch.load(model_path, map_location=device)
+
+            # Handle different checkpoint formats
+            if 'model' in checkpoint:
+                # Format from madrylab robust models
+                state_dict = checkpoint['model']
+            elif 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            else:
+                # Direct state dict
+                state_dict = checkpoint
+
+            # Handle prefix in state dict keys
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                if key.startswith('module.'):
+                    new_key = key[7:]  # Remove 'module.' prefix
+                else:
+                    new_key = key
+                new_state_dict[new_key] = value
+
+            model.load_state_dict(new_state_dict)
+        else:
+            # Fallback to PyTorch's pretrained model
+            model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+
+        model = model.to(device)
+        model.eval()  # Set to evaluation mode
+
+        # Store the model for future use
+        self.models[model_type] = model
+        return model
+
+    def inference(self, image, model_type, config):
+        # Load model if not already loaded
+        model = self.load_model(model_type)
+
+        # Check if image is a file path
+        if isinstance(image, str):
+            if os.path.exists(image):
+                image = Image.open(image).convert('RGB')
+            else:
+                raise ValueError(f"Image path does not exist: {image}")
+
+        # Prepare image tensor
+        image_tensor = transform(image).unsqueeze(0).to(device)
+        image_tensor.requires_grad = True
+
+        # Normalize the image for model input
+        normalized_tensor = normalize_transform(image_tensor)
+
+        # Get original predictions
+        with torch.no_grad():
+            output_original = model(normalized_tensor)
+            probs_orig = F.softmax(output_original, dim=1)
+            conf_orig, classes_orig = torch.max(probs_orig, 1)
+
+        # Get least confident classes
+        _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)
+
+        # Initialize inference step
+        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
+
+        # Storage for inference steps
+        x = image_tensor.clone()
+        all_steps = [image_tensor[0].detach().cpu()]
+
+        # Main inference loop
+        for i in range(config['n_itr']):
+            # Reset gradients
+            x.grad = None
+
+            # Normalize input for the model
+            normalized_x = normalize_transform(x)
+
+            # Forward pass
+            output = model(normalized_x)
+
+            # Calculate loss to maximize confidence for least confident classes
+            target_classes = least_confident_classes[:10]  # Use top 10 least confident classes
+            loss = 0
+            for idx in target_classes:
+                target = torch.tensor([idx.item()], device=device)
+                loss = loss - F.cross_entropy(output, target)  # Negative because we want to maximize confidence
+
+            # Backward pass
+            loss.backward()
+
+            # Update image
+            with torch.no_grad():
+                step = infer_step.step(x, x.grad)
+                x = x + step
+                x = infer_step.project(x)
+
+            # Store step if in iterations_to_show
+            if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
+                all_steps.append(x[0].detach().cpu())
+
+        # Return final image and all stored steps
+        return x[0].detach().cpu(), all_steps
+
+# Utility function to show inference steps
+def show_inference_steps(steps, figsize=(15, 10)):
+    import matplotlib.pyplot as plt
+
+    n_steps = len(steps)
+    fig, axes = plt.subplots(1, n_steps, figsize=figsize)
+
+    for i, step_img in enumerate(steps):
+        img = step_img.permute(1, 2, 0).numpy()
+        axes[i].imshow(img)
+        axes[i].set_title(f"Step {i}")
+        axes[i].axis('off')
+
+    plt.tight_layout()
+    return fig
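For anyone reviewing this commit, a hedged example of how the module above is meant to be driven from a script (the stimulus path refers to a file added in this commit; saving the figure rather than showing it is just a convenience):

```python
from inference import GenerativeInferenceModel, get_inference_configs, show_inference_steps

# Build the model wrapper and a config matching the demo defaults (eps=0.5, 50 iterations)
model = GenerativeInferenceModel()
config = get_inference_configs(eps=0.5, n_itr=50)

# Run generative inference on one of the committed example stimuli
final_image, steps = model.inference("stimuli/Kanizsa_square.jpg", "robust_resnet50", config)

# Visualize the stored intermediate steps and save the figure
fig = show_inference_steps(steps)
fig.savefig("inference_steps.png")
```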
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch
+torchvision
+numpy
+pillow
+gradio
+matplotlib
+requests
+tqdm
+huggingface_hub
stimuli/Kanizsa_square.jpg ADDED
stimuli/NeonColorSaeedi.jpg ADDED

Git LFS Details

  • SHA256: ed0d51b349cfadd27fe9e683217e0bd83f45d23f86919022b56ba7c48463080d
  • Pointer size: 131 Bytes
  • Size of remote file: 839 kB
stimuli/face_vase.png ADDED
stimuli/figure_ground.png ADDED

Git LFS Details

  • SHA256: b366b9e23a3527f1587ed08df3abe3b74dc3410d8ff5a02daa652d364b8f2238
  • Pointer size: 131 Bytes
  • Size of remote file: 297 kB