Spaces:
Sleeping
Sleeping
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
import matplotlib | |
import clip | |
from utils import * | |
# NOTE(review): the return value is discarded; presumably called only so Matplotlib
# resolves/creates its cache directory up front — confirm it is still needed.
matplotlib.get_cachedir()
# Render all figure text (titles, labels, ticks) in Times New Roman.
plt.rc('font', family="Times New Roman")
from sklearn import metrics | |
import torch | |
from torchvision import transforms | |
from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus | |
# Preprocessing applied to every image (and image variant) fed to the CLIP encoder.
data_transform = transforms.Compose([
    # Resize straight to 224x224 with bicubic interpolation. There is no centre
    # crop, so the aspect ratio of the input is not preserved.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    # PIL image -> float tensor [C, H, W] scaled to [0, 1].
    transforms.ToTensor(),
    # Per-channel RGB mean/std normalisation (same constants as the `clip` package's preprocess).
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])
class CLIPModel_Super(torch.nn.Module):
    """Image-only CLIP wrapper whose forward pass yields unit-length embeddings.

    The underlying model is obtained from ``clip.load`` and converted to float32.
    """

    def __init__(self, type="ViT-L/14", download_root=None, device="cuda"):
        """
        Args:
            type: CLIP architecture name passed to ``clip.load``.
            download_root: directory where ``clip.load`` caches the weights.
            device: device the model is loaded onto.
        """
        super().__init__()
        self.device = device
        clip_model, _preprocess = clip.load(type, device=device, download_root=download_root)
        # nn.Module.float() converts in place and returns the module itself.
        self.model = clip_model.float()

    def forward(self, vision_inputs):
        """Encode a batch of images.

        Input:
            vision_inputs: image batch tensor, [B, C, H, W] (cast to float32 here).
        Output:
            embeddings: [B, d] tensor, L2-normalised along the last dimension.
        """
        pixels = vision_inputs.type(torch.float32)
        with torch.no_grad():
            embeddings = self.model.encode_image(pixels)
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
        return embeddings
def transform_vision_data(image):
    """Turn an image array into a normalised CLIP input tensor.

    Input:
        image: numpy array [H, W, C]. ``Image.fromarray`` interprets a 3-channel
            uint8 array as RGB, so a BGR array (e.g. from ``cv2.imread``) must be
            converted to RGB first.
    Output:
        tensor [C, 224, 224] produced by ``data_transform``.
    """
    return data_transform(Image.fromarray(image))
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the CLIP ViT-L/14 image encoder; weights are cached under ./ckpt.
vis_model = CLIPModel_Super("ViT-L/14", device=device, download_root="./ckpt")
vis_model.eval()
vis_model.to(device)
print("load clip model")

# Semantic feature tensor the explainer scores regions against
# (presumably CLIP zero-shot ImageNet class embeddings, judging by the file name).
semantic_path = "./clip_vitl_imagenet_zeroweights.pt"
if not os.path.exists(semantic_path):
    # Fail fast with an actionable message; otherwise the missing file would only
    # surface as a NameError on `semantic_feature` when building the explainer.
    raise FileNotFoundError(
        f"Semantic feature file not found: {semantic_path!r}; "
        "it is required to build the explainer."
    )
semantic_feature = torch.load(semantic_path, map_location="cpu")
semantic_feature = semantic_feature.to(device).type(torch.float32)

# Submodular explainer over CLIP image embeddings; lambda1..lambda4 weight the
# terms of its objective (see MultiModalSubModularExplanationEfficientPlus).
explainer = MultiModalSubModularExplanationEfficientPlus(
    vis_model, semantic_feature, transform_vision_data, device=device,
    lambda1=0.01,
    lambda2=0.05,
    lambda3=20.,
    lambda4=5.)
def add_value_decrease(smdl_mask, json_file):
    """Convert the ordered region masks of a submodular search into an attribution map.

    The marginal gain of step k is

        (consistency[k] + collaboration[k]) - (consistency[k-1] + collaboration[k-1])

    where "step -1" uses ``baseline_score`` for consistency and ``1 - org_score``
    for collaboration. A running value starts at 0 and decreases by ``|gain|`` at
    every step. Each pixel of region k receives the running value after step k,
    so earlier-selected regions end up with higher attribution. The map is then
    min-max normalised to [0, 1].

    Pixels covered by no region keep the initial 0, which is the top value. This
    presumably never happens because the regions partition the image. Where
    regions overlap, the later region wins.

    Args:
        smdl_mask: sequence of per-step region images [H, W, C]; a pixel belongs
            to the region where its channel sum is > 0.
        json_file: mapping with "consistency_score" and "collaboration_score"
            (per-step sequences) plus scalar "baseline_score" and "org_score".

    Returns:
        (attribution_map, values): the [H, W] map in [0, 1] (all zeros if every
        pixel ends up with the same value), and the running value after each step.
    """
    consistency = np.asarray(json_file["consistency_score"], dtype=float)
    collaboration = np.asarray(json_file["collaboration_score"], dtype=float)
    current = consistency + collaboration
    # np.concatenate (not list `+`) so numpy-array scores are concatenated, not broadcast-added.
    previous = (np.concatenate(([json_file["baseline_score"]], consistency[:-1]))
                + np.concatenate(([1 - json_file["org_score"]], collaboration[:-1])))
    gains = current - previous

    single_mask = np.zeros_like(smdl_mask[0].mean(-1))
    values = []
    value = 0
    for region, gain in zip(smdl_mask, gains):
        value = value - abs(gain)
        single_mask[region.sum(-1) > 0] = value
        values.append(value)

    attribution_map = single_mask - single_mask.min()
    peak = attribution_map.max()
    if peak > 0:
        attribution_map /= peak
    # else: constant map -> keep all zeros instead of producing 0/0 = NaN.
    return attribution_map, np.array(values)
def _hide_axis_frame(ax, show_x_axis=False):
    """Strip an image panel down to its content: no spines, no y axis, white background."""
    for side in ("left", "right", "top", "bottom"):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_visible(show_x_axis)
    ax.yaxis.set_visible(False)
    ax.set_facecolor('white')


def _searched_region_overlay(image, insertion_image, draw_edge):
    """Dim pixels of ``image`` not revealed in ``insertion_image``; optionally outline the revealed area in red."""
    # Pixels where the image still differs from the insertion image are not yet revealed.
    hidden = (image - insertion_image).mean(-1)
    hidden[hidden > 0] = 1
    overlay = image.copy()
    overlay[hidden > 0] = overlay[hidden > 0] * 0.5
    if draw_edge:
        # One dilation iteration gives a 1-pixel outline. `iterations` must be passed
        # by keyword: the third positional parameter of cv2.dilate is `dst`.
        kernel = np.ones((3, 3), dtype=np.uint8)
        edge = cv2.dilate(hidden, kernel, iterations=1) - hidden
        overlay[edge > 0] = np.array([255, 0, 0])
    return overlay


def _figure_to_rgb(fig):
    """Draw ``fig`` and return its pixels as a contiguous (H, W, 3) uint8 RGB array."""
    fig.canvas.draw()
    # FigureCanvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10;
    # buffer_rgba() is the supported pixel accessor of the Agg canvas.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])


def visualization(image, submodular_image_set, saved_json_file, vis_image, index=None, compute_params=True):
    """Render the attribution map, the searched region and the insertion curve side by side.

    Args:
        image: [H, W, 3] image the explanation was computed on.
        submodular_image_set: per-step region images (same shape as ``image``) in
            selection order; non-zero pixels count as revealed.
        saved_json_file: mapping with per-step "consistency_score" and
            "collaboration_score" sequences.
        vis_image: [H, W, 3] attribution overlay shown in the first panel.
        index: insertion step to highlight. ``None`` picks the highest-scoring step;
            other values are clamped to the valid step range.
        compute_params: when true, also return summary statistics.

    Returns:
        ``(img_curve, highest_score, insertion_auc, best_index)`` if ``compute_params``,
        otherwise just ``img_curve``, the rendered figure as a uint8 RGB array.
    """
    # Cumulative insertion images: step 0 reveals nothing, step k adds the k-th region.
    insertion_image = np.zeros_like(submodular_image_set[0])
    insertion_images = [insertion_image]
    for region in submodular_image_set:
        insertion_image = insertion_image + region
        insertion_images.append(insertion_image)

    # Step 0 is scored as 1 - (final collaboration score); later steps use consistency scores.
    scores = np.array([1 - saved_json_file["collaboration_score"][-1]]
                      + list(saved_json_file["consistency_score"]))

    # Fraction of the image revealed at each step.
    n_pixels = image.shape[0] * image.shape[1]
    revealed = [(step.sum(-1) != 0).sum() / n_pixels for step in insertion_images]
    curve_y = scores[:len(revealed)]

    if index is None:
        best_index = int(np.argmax(scores))
    else:
        # The UI slider has a fixed range that can exceed the number of steps.
        best_index = min(max(int(index), 0), len(revealed) - 1)

    fig, (ax_map, ax_region, ax_curve) = plt.subplots(
        1, 3, gridspec_kw={'width_ratios': [1, 1, 1.5]}, figsize=(30, 8))
    try:
        # Panel 1: attribution overlay.
        _hide_axis_frame(ax_map)
        ax_map.set_title('Attribution Map', fontsize=54)
        ax_map.imshow(vis_image.astype(np.uint8))

        # Panel 2: region revealed at best_index; the x axis is kept only for its label.
        _hide_axis_frame(ax_region, show_x_axis=True)
        ax_region.set_title('Searched Region', fontsize=54)
        ax_region.set_xlabel("Highest conf. {:.4f}".format(scores.max()), fontsize=44)
        ax_region.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax_region.imshow(_searched_region_overlay(
            image, insertion_images[best_index], draw_edge=best_index != 0))

        # Panel 3: insertion curve with the highlighted step marked in red.
        ax_curve.set_xlim((0, 1))
        ax_curve.set_ylim((0, 1))
        ax_curve.set_ylabel('Recognition Score', fontsize=44)
        ax_curve.set_xlabel('Percentage of image revealed', fontsize=44)
        ax_curve.tick_params(axis='both', which='major', labelsize=36)
        ax_curve.plot(revealed, curve_y, color='dodgerblue', linewidth=3.5)
        ax_curve.set_facecolor('white')
        for side in ('bottom', 'left'):
            ax_curve.spines[side].set_color('black')
            ax_curve.spines[side].set_linewidth(2.0)
        for side in ('top', 'right'):
            ax_curve.spines[side].set_color('none')
        ax_curve.scatter(revealed[-1], curve_y[-1], color='dodgerblue', s=54)  # final point
        ax_curve.fill_between(revealed, curve_y, color='dodgerblue', alpha=0.1)  # shade under the curve
        ax_curve.axvline(x=revealed[best_index], color='red', linewidth=3.5)
        # Always titled, so slider re-renders match the initial render.
        ax_curve.set_title('Insertion Curve', fontsize=54)

        img_curve = _figure_to_rgb(fig)
    finally:
        plt.close(fig)  # release the figure even if drawing fails

    if not compute_params:
        return img_curve
    auc = metrics.auc(revealed, scores)
    return img_curve, scores.max(), auc, best_index
def gen_cam(image, mask):
    """
    Blend a colour-mapped mask over an image.
    :param image: [H,W,C] image
    :param mask: [H,W] values expected in 0-255 (converted with np.uint8 before
                 colour mapping, e.g. the output of norm_image; a 0-1 float mask
                 would collapse to 0/1)
    :return: tuple(cam, heatmap): cam is the float32 50/50 blend of heatmap and
             image; heatmap is the uint8 colour map
    """
    # NOTE(review): cv2.applyColorMap returns BGR channels while `image` here is RGB,
    # so the blend is channel-swapped relative to the named colormap — confirm intended.
    colour_map = np.float32(cv2.applyColorMap(np.uint8(mask), cv2.COLORMAP_COOL))
    blended = 0.5 * (colour_map + np.float32(image))
    return blended, colour_map.astype(np.uint8)
def norm_image(image):
    """
    Min-max normalise an array to 0-255 and return it as uint8.
    :param image: array of any shape, e.g. an [H,W] attribution map or an [H,W,C]
                  image. Integer input is converted to float64; float input keeps
                  its dtype. The input is never modified.
    :return: uint8 array of the same shape (values truncated, not rounded). A
             constant input maps to all zeros.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        image = image.copy()
    else:
        # In-place true division below needs a float buffer.
        image = image.astype(np.float64)
    # Shift so the minimum becomes 0 (negative minima included).
    image -= image.min()
    peak = image.max()
    if peak > 0:
        image /= peak
    # else: constant input -> stays all zeros instead of producing 0/0 = NaN.
    image *= 255.
    return np.uint8(image)
def read_image(file_path):
    """Load an image file as an RGB uint8 numpy array of shape [H, W, 3]."""
    # Context manager closes the underlying file deterministically instead of
    # relying on Pillow's lazy-loading close behaviour or garbage collection.
    with Image.open(file_path) as image:
        return np.array(image.convert("RGB"))
# Example images offered as one-click inputs in the UI: display name -> RGB array.
# Each entry loads its own file (shark and quail).
default_images = {
    name: read_image(path)
    for name, path in (
        ("Example: Tiger Shark", "images/shark.png"),
        ("Example: Quail", "images/bird.png"),
    )
}
def interpret_image(uploaded_image, slider, text_input):
    """Explain the CLIP prediction for the uploaded image and render the results.

    ``slider`` and ``text_input`` are passed by the Gradio click handler but are
    not used here.

    Returns:
        A 4-tuple matching the handler's outputs: (figure image, highest
        confidence, insertion AUC score, explanation text). With no image,
        returns ``(None, 0, 0, "")``.

    Side effects:
        Stores ``submodular_image_set``, ``saved_json_file`` and ``im`` in module
        globals so ``visualization_slider`` can re-render without recomputing.
    """
    global submodular_image_set, saved_json_file, im

    if uploaded_image is None:
        # One value per output component (4); returning fewer makes Gradio raise.
        return None, 0, 0, ""

    image = cv2.resize(np.array(uploaded_image), (224, 224))
    # Partition the image into sub-regions (presumably SLICO superpixels) that the
    # submodular search selects from.
    element_sets_V = SubRegionDivision(image, mode="slico", region_size=30)
    explainer.k = len(element_sets_V)  # presumably the subset size: rank every region

    _, submodular_image_set, saved_json_file = explainer(element_sets_V, id=None)
    attribution_map, _ = add_value_decrease(submodular_image_set, saved_json_file)
    im, _ = gen_cam(image, norm_image(attribution_map))
    image_curve, highest_confidence, insertion_auc_score, _ = visualization(
        image, submodular_image_set, saved_json_file, im, index=None)
    # The model loaded for `vis_model` is CLIP ViT-L/14.
    text_output_class = "The method explains why the CLIP (ViT-L/14) model identifies an image as {}.".format(
        imagenet_classes[explainer.target_label])
    return image_curve, highest_confidence, insertion_auc_score, text_output_class
def visualization_slider(uploaded_image, slider):
    """Re-render the figure highlighting insertion step ``slider``.

    Reuses the results stored in module globals by the last ``interpret_image``
    call. Returns the rendered figure, or ``None`` when there is no image or no
    explanation has been computed yet. Gradio wires this handler to a single
    output component, so exactly one value is returned.
    """
    if uploaded_image is None:
        return None
    state = globals()
    if any(name not in state for name in ("submodular_image_set", "saved_json_file", "im")):
        # Slider moved before "Interpreting Model" was clicked: nothing to re-render.
        return None
    image = cv2.resize(np.array(uploaded_image), (224, 224))
    return visualization(image, submodular_image_set, saved_json_file, im,
                         index=slider, compute_params=False)
def update_image(thumbnail_name):
    """Return the example image array registered under ``thumbnail_name`` in ``default_images``."""
    selected = default_images[thumbnail_name]
    return selected
# Build the Gradio interface. Layout is defined by the nesting order of the
# Row/Column context managers below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # First row: image upload box next to the usage instructions.
            with gr.Row():
                # Image upload input (delivered to handlers as a numpy array).
                image_input = gr.Image(label="Upload Image", type="numpy")
                # A "Default Image" thumbnail + button used to live in this column;
                # it is disabled and the column now holds the instructions.
                with gr.Column():
                    gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)
            # Second row: one thumbnail + button per example image.
            with gr.Row():
                for key in default_images.keys():
                    with gr.Column():
                        gr.Image(value=default_images[key], type="numpy")
                        button = gr.Button(value=key)
                        # `k=key` binds the current key as a default argument so each
                        # button loads its own image (avoids late-binding closures).
                        button.click(
                            fn=lambda k=key: update_image(k),
                            inputs=[],
                            outputs=image_input
                        )
            # Optional text input; interpret_image receives it but does not use it.
            text_input = gr.Textbox(label="Text Input", placeholder="Enter some text here... (optional)")
        with gr.Column():
            # Output image and controls.
            image_output = gr.Image(label="Output Image")
            slider = gr.Slider(minimum=0, maximum=50, step=1, label="Confidence Slider")
            text_output_class = gr.Textbox(label="Explaining Category")
            with gr.Row():
                # Highest confidence and insertion AUC score shown side by side.
                text_output_confidence = gr.Textbox(label="Highest Confidence")
                text_output_auc = gr.Textbox(label="Insertion AUC Score")
    interpret_button = gr.Button("Interpreting Model")
    # Clicking the button runs the full explanation and fills all four outputs.
    interpret_button.click(
        fn=interpret_image,
        inputs=[image_input, slider, text_input],
        outputs=[image_output, text_output_confidence, text_output_auc, text_output_class]
    )
    # Re-render the output figure whenever the slider value changes.
    slider.change(
        fn=visualization_slider,
        inputs=[image_input, slider],
        outputs=[image_output]
    )
# Launch the Gradio app.
demo.launch()