# Last commit 0d248fe by miccull: "no more cuda"
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Run on the GPU when one is available; every tensor below is created on DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
# Load straight onto DEVICE instead of relying on clip.load's own device
# default, so the model and the tensors created below always agree.
# NOTE(review): `preprocess` is never used below, so images reach the encoder
# without whatever normalization it applies — confirm this is intended.
model, preprocess = clip.load(model_name, device=DEVICE)
# Inference only: with frozen weights no autograd graph is recorded for the
# forward passes, which keeps memory flat for large interpolation batches.
model.requires_grad_(False)
model.to(DEVICE).eval()
# Side length (pixels) of the square images the visual encoder expects.
resolution = model.visual.input_resolution
# Upscales the 1x1 solid-color images to the encoder's input size.
resizer = torchvision.transforms.Resize(size=(resolution, resolution))
def create_rgb_tensor(color, device=None):
    """Build a 1x1 solid-color image tensor.

    Args:
        color: Three channel values (R, G, B), e.g. ``[1, 0, 0]`` for red.
            Integer values are accepted and converted to floats.
        device: Device to place the tensor on; defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        A float32 tensor of shape ``(1, 3, 1, 1)`` (batch, channel, height, width).

    Raises:
        ValueError: If ``color`` does not have exactly three components.
    """
    if device is None:
        device = DEVICE
    # Force a float dtype so an integer list like [1, 0, 0] does not become an
    # int64 tensor that then flows through the interpolation arithmetic.
    rgb = torch.tensor(color, dtype=torch.float32, device=device)
    if rgb.numel() != 3:
        # Fail with a clear message instead of an opaque reshape error.
        raise ValueError(f"expected 3 color components (R, G, B), got {rgb.numel()}")
    return rgb.reshape((1, 3, 1, 1))
def encode_color(color):
    """Return CLIP's image embedding for a single solid color.

    ``color`` is an (R, G, B) triple such as ``[1, 0, 0]``. It is turned into
    a 1x1 image, upscaled to the encoder's input resolution, then encoded.
    """
    swatch = resizer(create_rgb_tensor(color))
    embedding = model.encode_image(swatch)
    return embedding
def encode_text(text):
    """Return CLIP's text embedding for the prompt ``text``."""
    tokens = clip.tokenize(text)
    return model.encode_text(tokens.to(DEVICE))
def lerp(x, y, steps=11):
    """Linearly interpolate between two tensors at evenly spaced stops.

    Args:
        x: Start tensor, e.g. a ``(1, 3, 1, 1)`` color image.
        y: End tensor, broadcast-compatible with ``x``.
        steps: Number of stops including both endpoints (``steps=1`` yields
            only ``x``). Float values such as ``11.0`` are accepted.

    Returns:
        Tensor with a new leading dimension of length ``steps``; stop ``i`` is
        ``x * (1 - w_i) + y * w_i`` with ``w_i`` running evenly from 0 to 1.
    """
    # Build the weights in torch with x's dtype and device. Going through
    # np.linspace produced float64 weights that silently promoted the result
    # to float64, and placed them on the global DEVICE rather than x's device.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    weights = torch.linspace(0, 1, int(steps), dtype=dtype, device=x.device)
    # One weight per leading slice, broadcast over the remaining 3 dims.
    weights = weights.reshape([-1, 1, 1, 1])
    return x * (1 - weights) + y * weights
def get_interpolated_scores(x, y, encoded_text, steps=11):
    """Score CLIP similarity between a text embedding and a color gradient.

    Args:
        x: Start color image tensor (as built by ``create_rgb_tensor``).
        y: End color image tensor.
        encoded_text: Text embedding, e.g. from ``encode_text``.
        steps: Number of interpolation stops, including both endpoints.

    Returns:
        DataFrame with one row per stop and columns ``step`` (0-based stop
        index), ``similarity`` (cosine similarity to the text) and ``color``
        (hex string of the stop's color).
    """
    interpolated = lerp(x, y, steps)
    # Pure inference: avoid recording an autograd graph across up to `steps`
    # full image-encoder passes.
    with torch.no_grad():
        image_features = model.encode_image(resizer(interpolated))
        scores = torch.cosine_similarity(image_features, encoded_text)
    scores = scores.detach().cpu().numpy()
    # lerp yields (steps, 3, 1, 1): flatten to one (R, G, B) row per stop.
    rgb_rows = interpolated.detach().cpu().numpy().reshape(-1, 3)
    data = pd.DataFrame({
        'similarity': scores,
        'color': [rgb2hex(row) for row in rgb_rows],
    })
    return data.reset_index().rename(columns={'index': 'step'})
def rgb2hex(rgb):
    """Convert an (R, G, B) triple of floats in [0, 1] to a ``#rrggbb`` string.

    Each channel is rounded to the nearest 8-bit level, rather than truncated,
    and clipped to [0, 255]. Out-of-range inputs therefore still yield a valid
    color string instead of something like ``#-1...``.

    Raises:
        ValueError: If ``rgb`` does not hold exactly three values.
    """
    levels = np.clip(np.rint(np.asarray(rgb, dtype=float) * 255), 0, 255).astype(int)
    r, g, b = (int(level) for level in levels)
    return "#{:02x}{:02x}{:02x}".format(r, g, b)
def similarity_plot(data, text_prompt):
    """Draw one bar per interpolation step, colored by that step's color.

    Args:
        data: DataFrame with ``similarity`` and ``color`` columns, e.g. as
            returned by ``get_interpolated_scores``.
        text_prompt: Prompt text, shown in the chart title.

    Returns:
        A ``matplotlib.figure.Figure`` holding the chart.
    """
    # Build the figure without pyplot. pyplot keeps a global reference to every
    # figure it creates until it is explicitly closed, so a long-running server
    # would otherwise accumulate one figure per request.
    from matplotlib.figure import Figure

    title = f'CLIP Cosine Similarity Prompt="{text_prompt}"'
    fig = Figure()
    ax = fig.subplots()
    data['similarity'].plot(
        kind='bar',
        ax=ax,
        title=title,
        color=data['color'],
        width=1.0,  # adjacent bars touch, so the chart reads as a gradient
        grid=False,
    )
    # The step index adds nothing beyond bar order; hide the x axis.
    ax.get_xaxis().set_visible(False)
    return fig
def interpolation_experiment(rgb_start, rgb_end, text_prompt, steps=11):
    """Plot CLIP similarity to ``text_prompt`` along a gradient between two colors.

    ``rgb_start`` and ``rgb_end`` are (R, G, B) triples. ``steps`` is the
    number of stops, endpoints included. Returns the figure built by
    ``similarity_plot``.
    """
    start_image, end_image = create_rgb_tensor(rgb_start), create_rgb_tensor(rgb_end)
    text_features = encode_text(text_prompt)
    scores = get_interpolated_scores(start_image, end_image, text_features, steps)
    return similarity_plot(scores, text_prompt)
# Format hint appended to both RGB field labels. As a bare string statement
# it was evaluated and discarded, so it never appeared in the UI.
rgb_hint = ' (Comma separated numbers between 0 and 1)'
start_input = gr.inputs.Textbox(lines=1, default="1, 0, 0", label="Start RGB" + rgb_hint)
end_input = gr.inputs.Textbox(lines=1, default="0, 1, 0", label="End RGB" + rgb_hint)
text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Interpolation Steps")
def gradio_fn(rgb_start, rgb_end, text_prompt, steps=11):
rgb_start = [float(x.strip()) for x in rgb_start.split(',')]
rgb_end = [float(x.strip()) for x in rgb_end.split(',')]
out = interpolation_experiment(rgb_start, rgb_end, text_prompt, steps)
return out
# Wire the callback to the input widgets; the result is rendered as a plot.
iface = gr.Interface(fn=gradio_fn, inputs=[start_input, end_input, text_input, steps_input], outputs="plot")

# Start the server only when run as a script, so importing this module
# (e.g. to reuse the helpers above) does not launch the app.
if __name__ == "__main__":
    iface.launch(debug=True, share=False)