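"""Gradio demo: submodular subset-selection attribution for CLIP (ViT-L/14) zero-shot ImageNet classification."""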
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import matplotlib
import clip
import os

from utils import *  # expected to provide SubRegionDivision and imagenet_classes
matplotlib.get_cachedir()
plt.rc('font', family="Times New Roman")
from sklearn import metrics
import torch
from torchvision import transforms
from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus
data_transform = transforms.Compose(
[
transforms.Resize(
(224,224), interpolation=transforms.InterpolationMode.BICUBIC
),
# transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
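# Thin wrapper around OpenAI CLIP: keeps only the image encoder and returns
# L2-normalised image embeddings, which are later scored against the
# pre-computed ImageNet zero-shot text embeddings loaded further below.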
class CLIPModel_Super(torch.nn.Module):
def __init__(self,
type="ViT-L/14",
download_root=None,
device = "cuda"):
super().__init__()
self.device = device
self.model, _ = clip.load(type, device=self.device, download_root=download_root)
self.model = self.model.float()
def forward(self, vision_inputs):
"""
Input:
vision_inputs: torch.size([B,C,W,H])
Output:
embeddings: a d-dimensional vector torch.size([B,d])
"""
vision_inputs = vision_inputs.type(torch.float32)
with torch.no_grad():
image_features = self.model.encode_image(vision_inputs)
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features
def transform_vision_data(image):
"""
Input:
image: An image read by opencv [w,h,c]
Output:
image: After preproccessing, is a tensor [c,w,h]
"""
image = Image.fromarray(image)
image = data_transform(image)
return image
device = "cuda" if torch.cuda.is_available() else "cpu"
# Instantiate model
vis_model = CLIPModel_Super("ViT-L/14", device=device, download_root="./ckpt")
vis_model.eval()
vis_model.to(device)
print("load clip model")
semantic_path = "./clip_vitl_imagenet_zeroweights.pt"
if os.path.exists(semantic_path):
    semantic_feature = torch.load(semantic_path, map_location="cpu")
    semantic_feature = semantic_feature.to(device)
    semantic_feature = semantic_feature.type(torch.float32)
else:
    raise FileNotFoundError(f"Pre-computed ImageNet text embeddings not found: {semantic_path}")
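# Build the submodular explainer. lambda1..lambda4 weight the individual terms
# of the submodular objective (the code below tracks consistency and
# collaboration scores); the values here are the demo defaults and are not
# tuned per image.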
explainer = MultiModalSubModularExplanationEfficientPlus(
vis_model, semantic_feature, transform_vision_data, device=device,
lambda1=0.01,
lambda2=0.05,
lambda3=20.,
lambda4=5.)
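# Turn the per-step submodular search results into a dense attribution map:
# each region receives a value derived from the marginal change of the
# (consistency + collaboration) score at the step it was added, and the map is
# then min-max normalised to [0, 1].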
def add_value_decrease(smdl_mask, json_file):
single_mask = np.zeros_like(smdl_mask[0].mean(-1))
value_list_1 = np.array(json_file["consistency_score"]) + np.array(json_file["collaboration_score"])
value_list_2 = np.array([json_file["baseline_score"]] + json_file["consistency_score"][:-1]) + np.array([1 - json_file["org_score"]] + json_file["collaboration_score"][:-1])
value_list = value_list_1 - value_list_2
values = []
value = 0
for smdl_single_mask, smdl_value in zip(smdl_mask, value_list):
value = value - abs(smdl_value)
single_mask[smdl_single_mask.sum(-1)>0] = value
values.append(value)
attribution_map = single_mask - single_mask.min()
attribution_map /= attribution_map.max()
return attribution_map, np.array(values)
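# Render the three-panel result figure: the attribution heatmap (left), the
# searched region at the chosen step outlined in red (middle), and the
# insertion curve of recognition score vs. fraction of the image revealed
# (right). Returns the figure as a numpy image, plus the highest confidence,
# insertion AUC and best step index when compute_params is True.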
def visualization(image, submodular_image_set, saved_json_file, vis_image, index=None, compute_params=True):
insertion_ours_images = []
# deletion_ours_images = []
    insertion_image = np.zeros_like(submodular_image_set[0])  # start from an empty (all-black) image
insertion_ours_images.append(insertion_image)
# deletion_ours_images.append(image - insertion_image)
for smdl_sub_mask in submodular_image_set[:]:
insertion_image = insertion_image.copy() + smdl_sub_mask
insertion_ours_images.append(insertion_image)
# deletion_ours_images.append(image - insertion_image)
insertion_ours_images_input_results = np.array([1-saved_json_file["collaboration_score"][-1]] + saved_json_file["consistency_score"])
    if index is None:
ours_best_index = np.argmax(insertion_ours_images_input_results)
else:
ours_best_index = index
x = [(insertion_ours_image.sum(-1)!=0).sum() / (image.shape[0] * image.shape[1]) for insertion_ours_image in insertion_ours_images]
i = len(x)
fig, [ax1, ax2, ax3] = plt.subplots(1,3, gridspec_kw = {'width_ratios':[1, 1, 1.5]}, figsize=(30,8))
ax1.spines["left"].set_visible(False)
ax1.spines["right"].set_visible(False)
ax1.spines["top"].set_visible(False)
ax1.spines["bottom"].set_visible(False)
ax1.xaxis.set_visible(False)
ax1.yaxis.set_visible(False)
ax1.set_title('Attribution Map', fontsize=54)
ax1.set_facecolor('white')
ax1.imshow(vis_image.astype(np.uint8))
ax2.spines["left"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax2.spines["top"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
ax2.xaxis.set_visible(True)
ax2.yaxis.set_visible(False)
ax2.set_title('Searched Region', fontsize=54)
ax2.set_facecolor('white')
ax2.set_xlabel("Highest conf. {:.4f}".format(insertion_ours_images_input_results.max()), fontsize=44)
ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax3.set_xlim((0, 1))
ax3.set_ylim((0, 1))
ax3.set_ylabel('Recognition Score', fontsize=44)
ax3.set_xlabel('Percentage of image revealed', fontsize=44)
ax3.tick_params(axis='both', which='major', labelsize=36)
x_ = x[:i]
ours_y = insertion_ours_images_input_results[:i]
ax3.plot(x_, ours_y, color='dodgerblue', linewidth=3.5) # draw curve
ax3.set_facecolor('white')
ax3.spines['bottom'].set_color('black')
ax3.spines['bottom'].set_linewidth(2.0)
ax3.spines['top'].set_color('none')
ax3.spines['left'].set_color('black')
ax3.spines['left'].set_linewidth(2.0)
ax3.spines['right'].set_color('none')
# plt.legend(["Ours"], fontsize=40, loc="upper left")
ax3.scatter(x_[-1], ours_y[-1], color='dodgerblue', s=54) # Plot latest point
    # Fill the area under the curve with light blue
ax3.fill_between(x_, ours_y, color='dodgerblue', alpha=0.1)
kernel = np.ones((3, 3), dtype=np.uint8)
    # ax3.plot([x_[ours_best_index], x_[ours_best_index]], [0, 1], color='red', linewidth=3.5)  # draw a red curve
    ax3.axvline(x=x_[int(ours_best_index)], color='red', linewidth=3.5)  # red vertical line marking the selected step
# Ours
mask = (image - insertion_ours_images[int(ours_best_index)]).mean(-1)
mask[mask>0] = 1
if int(ours_best_index) != 0:
        dilate = cv2.dilate(mask, kernel, iterations=3)
# erosion = cv2.erode(dilate, kernel, iterations=3)
# dilate = cv2.dilate(erosion, kernel, 2)
edge = dilate - mask
# erosion = cv2.erode(dilate, kernel, iterations=1)
image_debug = image.copy()
image_debug[mask>0] = image_debug[mask>0] * 0.5
if int(ours_best_index) != 0:
image_debug[edge>0] = np.array([255,0,0])
ax2.imshow(image_debug)
if compute_params:
auc = metrics.auc(x, insertion_ours_images_input_results)
ax3.set_title('Insertion Curve', fontsize=54)
fig.canvas.draw()
img_curve = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
img_curve = img_curve.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # close the figure to free resources
if compute_params:
return img_curve, insertion_ours_images_input_results.max(), auc, ours_best_index
else:
return img_curve
def gen_cam(image, mask):
"""
Generate heatmap
:param image: [H,W,C]
:param mask: [H,W],range 0-1
:return: tuple(cam,heatmap)
"""
# Read image
# image = cv2.resize(cv2.imread(image_path), (224,224))
# mask->heatmap
heatmap = cv2.applyColorMap(np.uint8(mask), cv2.COLORMAP_COOL)
heatmap = np.float32(heatmap)
# merge heatmap to original image
cam = 0.5*heatmap + 0.5*np.float32(image)
return cam, (heatmap).astype(np.uint8)
def norm_image(image):
"""
Normalization image
:param image: [H,W,C]
:return:
"""
image = image.copy()
    image -= max(np.min(image), 0)  # shift so the (non-negative) minimum maps to zero
image /= np.max(image)
image *= 255.
return np.uint8(image)
def read_image(file_path):
image = Image.open(file_path)
image = image.convert("RGB")
return np.array(image)
# Example images shipped with the demo
default_images = {
# "Default Image": read_image("images/shark.png"),
"Example: Tiger Shark": read_image("images/shark.png"),
"Example: Quail": read_image("images/bird.png") # 所有选项都使用相同的图片
}
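# Main Gradio callback: split the uploaded image into SLICO superpixel regions,
# run the submodular search over them, and cache the intermediate results in
# module-level globals so the slider callback can re-render the figure without
# re-running the search.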
def interpret_image(uploaded_image, slider, text_input):
    # Use the uploaded image if there is one; otherwise return empty outputs
    if uploaded_image is not None:
        image = np.array(uploaded_image)
    else:
        return None, 0, 0, ""
image = cv2.resize(image, (224, 224))
element_sets_V = SubRegionDivision(image, mode="slico", region_size=30)
explainer.k = len(element_sets_V)
global submodular_image_set
global saved_json_file
global im
    submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=None)  # id=None: let the explainer pick the target class (reported below via explainer.target_label)
attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
im, heatmap = gen_cam(image, norm_image(attribution_map))
image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, im, index=None)
    text_output_class = "The method explains why the CLIP (ViT-L/14) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
return image_curve, highest_confidence, insertion_auc_score, text_output_class
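# Slider callback: re-draw the figure for the chosen step using the cached
# results from interpret_image; no new model evaluation is performed. Assumes
# interpret_image has already been run at least once.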
def visualization_slider(uploaded_image, slider):
    # Use the uploaded image if there is one; otherwise return no figure
    if uploaded_image is not None:
        image = np.array(uploaded_image)
    else:
        return None
image = cv2.resize(image, (224, 224))
image_curve = visualization(image, submodular_image_set, saved_json_file, im, index=slider, compute_params=False)
return image_curve
def update_image(thumbnail_name):
    # Return the image corresponding to the selected thumbnail
return default_images[thumbnail_name]
# Build the Gradio interface
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
            # First row: image upload box and the usage instructions
with gr.Row():
                # Image upload input
image_input = gr.Image(label="Upload Image", type="numpy")
                # Instructions panel (the default-thumbnail code below is commented out)
with gr.Column():
# gr.Image(value=default_images["Default Image"], type="numpy")
# button_default = gr.Button(value="Default Image")
# button_default.click(
# fn=lambda k="Default Image": update_image(k),
# inputs=[],
# outputs=image_input
# )
gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)
            # Second row: two example thumbnails
with gr.Row():
for key in default_images.keys():
with gr.Column():
gr.Image(value=default_images[key], type="numpy")
button = gr.Button(value=key)
button.click(
fn=lambda k=key: update_image(k),
inputs=[],
outputs=image_input
)
            # Text input box (optional; not used by the current demo)
text_input = gr.Textbox(label="Text Input", placeholder="Enter some text here... (optional)")
with gr.Column():
            # Output image and controls
image_output = gr.Image(label="Output Image")
slider = gr.Slider(minimum=0, maximum=50, step=1, label="Confidence Slider")
text_output_class = gr.Textbox(label="Explaining Category")
with gr.Row():
                # Highest confidence and insertion AUC score, shown side by side
text_output_confidence = gr.Textbox(label="Highest Confidence")
text_output_auc = gr.Textbox(label="Insertion AUC Score")
interpret_button = gr.Button("Interpreting Model")
    # Click handler for the "Interpreting Model" button
interpret_button.click(
fn=interpret_image,
inputs=[image_input, slider, text_input],
outputs=[image_output, text_output_confidence, text_output_auc, text_output_class]
)
    # Slider updates the figure in real time
slider.change(
fn=visualization_slider,
inputs=[image_input, slider],
outputs=[image_output]
)
# Launch the Gradio app
demo.launch()