# Spaces: Runtime error
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision | |
import clip | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import gradio as gr | |
# Run on the GPU whenever torch reports one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = 'ViT-B/16'  #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

# clip.load returns the network together with a `preprocess` object
# (the latter is not used by the code visible in this file).
model, preprocess = clip.load(model_name)
model.to(DEVICE)
model.eval()

# Square resize to the side length the visual encoder reports as its input size.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution, resolution))
def create_rgb_tensor(color):
    """Build a float tensor of shape (1, 3, 1, 1) on DEVICE from three channel values.

    Args:
        color: Three channel values, e.g. [1, 0, 0].

    Returns:
        A float32 tensor of shape (1, 3, 1, 1) located on DEVICE.

    Raises:
        RuntimeError: From ``reshape`` when ``color`` does not hold exactly
            three values.
    """
    # Pin the dtype: without it an all-integer list such as [1, 0, 0] would
    # produce an int64 tensor instead of a floating-point one.
    return torch.tensor(color, dtype=torch.float32, device=DEVICE).reshape((1, 3, 1, 1))
def encode_color(color):
    """Return the image encoding of a single colour; color is e.g. [1,0,0].

    The colour is turned into a tensor by ``create_rgb_tensor``, resized with
    the module-level ``resizer`` and passed to ``model.encode_image``.
    """
    resized = resizer(create_rgb_tensor(color))
    return model.encode_image(resized)
def encode_text(text):
    """Tokenize ``text`` with CLIP, move the tokens to DEVICE and encode them."""
    return model.encode_text(clip.tokenize(text).to(DEVICE))
def lerp(x, y, steps=11):
    """Linear interpolation between two tensors.

    Args:
        x: Start tensor (weight 0 reproduces ``x``).
        y: End tensor (weight 1 reproduces ``y``).
        steps: Number of evenly spaced weights in [0, 1]. Coerced to ``int``
            so a float such as ``11.0`` is accepted.

    Returns:
        ``x * (1 - w) + y * w`` for each weight ``w``. The weights have shape
        (steps, 1, 1, 1), so the result is the broadcast of that shape against
        the inputs (e.g. (steps, 3, 1, 1) for inputs of shape (1, 3, 1, 1)).
    """
    steps = int(steps)
    # Build the weights in the inputs' own floating dtype and on their device,
    # so the result is not silently promoted to float64 and does not depend on
    # a module-level device constant. Non-floating inputs use torch's default
    # floating dtype for the weights.
    dtype = torch.result_type(x, y)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    weights = torch.linspace(0, 1, steps, dtype=dtype, device=x.device).reshape([-1, 1, 1, 1])
    return x * (1 - weights) + y * weights
def get_interpolated_scores(x, y, encoded_text, steps=11):
    """Score every interpolation step between two colours against a text encoding.

    Args:
        x: Start colour tensor, as accepted by ``lerp``.
        y: End colour tensor, as accepted by ``lerp``.
        encoded_text: Text encoding to compare each image encoding with.
        steps: Number of interpolation steps, forwarded to ``lerp``.

    Returns:
        A DataFrame with one row per step and the columns ``step`` (row
        number), ``similarity`` (cosine similarity between the step's image
        encoding and ``encoded_text``) and ``color`` (the step's colour as
        produced by ``rgb2hex``).
    """
    # Pure inference: no gradients are needed, so avoid building an autograd
    # graph for the forward pass.
    with torch.no_grad():
        interpolated = lerp(x, y, steps)
        interpolated_encodings = model.encode_image(resizer(interpolated))
        similarity = torch.cosine_similarity(interpolated_encodings, encoded_text)
    scores = similarity.detach().cpu().numpy()
    # One row of three channel values per interpolation step.
    rgb_rows = interpolated.detach().cpu().numpy().reshape(-1, 3)
    interpolated_hex = [rgb2hex(row) for row in rgb_rows]
    data = pd.DataFrame({
        'similarity': scores,
        'color': interpolated_hex
    }).reset_index().rename(columns={'index': 'step'})
    return data
def rgb2hex(rgb):
    """Convert three channel values in [0, 1] to a ``#rrggbb`` hex string.

    Args:
        rgb: Sequence or array of three numbers (red, green, blue).

    Returns:
        A lowercase seven-character string such as ``"#ff0000"``.

    Raises:
        ValueError: From unpacking when ``rgb`` does not hold exactly three
            values.
    """
    # Clip first so out-of-range input cannot yield a negative sign or a
    # three-digit channel, either of which would make the string malformed.
    channels = np.clip(np.asarray(rgb, dtype=float), 0.0, 1.0)
    # Round rather than truncate: a value like 0.9999999 (float rounding noise
    # around 1.0) must still map to 255, not 254.
    r, g, b = (int(value) for value in np.rint(channels * 255))
    return "#{:02x}{:02x}{:02x}".format(r, g, b)
def similarity_plot(data, text_prompt):
    """Draw the ``similarity`` column of ``data`` as a bar chart.

    The bars take their colours from ``data['color']``, the prompt is quoted
    in the title and the x axis is hidden. Returns the matplotlib figure.
    """
    chart_title = f'CLIP Cosine Similarity Prompt="{text_prompt}"'
    fig, ax = plt.subplots()
    bar_options = dict(
        kind='bar',
        ax=ax,
        stacked=True,
        title=chart_title,
        color=data['color'],
        width=1.0,
        xlim=(0, 2),
        grid=False,
    )
    axes = data['similarity'].plot(**bar_options)
    axes.get_xaxis().set_visible(False)
    return fig
def interpolation_experiment(rgb_start, rgb_end, text_prompt, steps=11):
    """Plot how similar each colour between two endpoints is to a text prompt.

    Both RGB triples are converted with ``create_rgb_tensor``, the prompt is
    encoded with ``encode_text``, the per-step scores come from
    ``get_interpolated_scores`` and the figure from ``similarity_plot``.
    """
    first, last = (create_rgb_tensor(rgb) for rgb in (rgb_start, rgb_end))
    prompt_encoding = encode_text(text_prompt)
    scores = get_interpolated_scores(first, last, prompt_encoding, steps)
    return similarity_plot(scores, text_prompt)
# NOTE(review): gr.inputs.* with `default=` looks like the legacy Gradio
# component API -- confirm it exists in the Gradio version this app is
# deployed with.

# Format hint for the two RGB text boxes. It used to be a bare string
# statement with no effect; its leading space shows it was meant to be
# appended to a label, so it is now attached to both RGB labels.
RGB_FORMAT_HINT = ' (Comma separated numbers between 0 and 1)'

start_input = gr.inputs.Textbox(lines=1, default="1, 0, 0", label="Start RGB" + RGB_FORMAT_HINT)
end_input = gr.inputs.Textbox(lines=1, default="0, 1, 0", label="End RGB" + RGB_FORMAT_HINT)
text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Interpolation Steps")
def gradio_fn(rgb_start, rgb_end, text_prompt, steps=11):
    """Gradio callback: parse the RGB text fields and run the experiment.

    Args:
        rgb_start: Comma-separated channel values for the first colour,
            e.g. ``"1, 0, 0"``.
        rgb_end: Comma-separated channel values for the last colour.
        text_prompt: Prompt forwarded unchanged to ``interpolation_experiment``.
        steps: Number of interpolation steps.

    Returns:
        Whatever ``interpolation_experiment`` returns for the parsed inputs.

    Raises:
        ValueError: If a field contains a non-numeric entry or does not hold
            exactly three values.
    """
    def _parse_rgb(text, field_name):
        # Turn "r, g, b" into a list of three floats with a clear error message.
        channels = [float(part.strip()) for part in text.split(',')]
        if len(channels) != 3:
            raise ValueError(
                f"{field_name} must contain exactly 3 comma separated numbers, "
                f"got {len(channels)}"
            )
        return channels

    start_channels = _parse_rgb(rgb_start, "Start RGB")
    end_channels = _parse_rgb(rgb_end, "End RGB")
    # The step count is used as a number of samples downstream, so it must be
    # an int; the slider widget may deliver it as a float (e.g. 11.0).
    steps = int(steps)
    return interpolation_experiment(start_channels, end_channels, text_prompt, steps)
# Connect the callback and its four input widgets to a plot output, then
# start serving the interface.
iface = gr.Interface(
    fn=gradio_fn,
    inputs=[start_input, end_input, text_input, steps_input],
    outputs="plot",
)
iface.launch(debug=True, share=False)