small fixes, CLIP vecs graph

- data/CLIP_vecs.pkl +3 -0
- pages/1_Disentanglement.py +1 -1
- pages/2_Concepts_comparison.py +64 -11
- view_predictions.ipynb +67 -0
data/CLIP_vecs.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2971a01a74a391c752fff9ba91c2939ffc6b29165842a87b911e67d9658df53
+size 412234

pages/1_Disentanglement.py
CHANGED
@@ -128,7 +128,7 @@ with input_col_2:
 random_id = st.form_submit_button('Generate a random image')
 
 if random_id:
-    image_id = random.randint(0,
+    image_id = random.randint(0, 50000)
     st.session_state.image_id = image_id
     chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
 

pages/2_Concepts_comparison.py
CHANGED
@@ -25,11 +25,11 @@ st.write('> **What is their join impact on the image?**')
 st.write("""Description to write""")
 
 
-annotations_file = './data/annotated_files/seeds0000-
+annotations_file = './data/annotated_files/seeds0000-50000.pkl'
 with open(annotations_file, 'rb') as f:
     annotations = pickle.load(f)
 
-ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
 concepts = './data/concepts.txt'
 
 with open(concepts) as f:
@@ -57,13 +57,6 @@ with input_col_1:
 # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
 concept_ids = st.multiselect('Concept:', tuple(labels))
 
-choose_text_button = st.form_submit_button('Choose the defined concepts')
-# random_text = st.form_submit_button('Select a random concept')
-
-# if random_text:
-# concept_id = random.choice(labels)
-# st.session_state.concept_id = concept_id
-# chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
 st.write('**Choose a latent space to disentangle**')
 # chosen_text_id_input = st.empty()
 # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
@@ -85,8 +78,8 @@ st.subheader('Concept vector')
 # perform attack container
 # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
 # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
-header_col_1, header_col_2 = st.columns([
-output_col_1, output_col_2 = st.columns([
+header_col_1, header_col_2 = st.columns([1,1])
+output_col_1, output_col_2 = st.columns([1,1])
 
 st.subheader('Derivations along the concept vector')
 
@@ -157,6 +150,66 @@ with output_col_1:
 # Load HTML file in HTML component for display on Streamlit page
 components.html(HtmlFile.read(), height=435)
 
+with output_col_2:
+    with open('data/CLIP_vecs.pkl', 'rb') as f:
+        vectors = pickle.load(f)
+
+    # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
+    #st.write('Concept vector', separation_vector)
+    header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')# - Nodes {",".join(list(imp_nodes))}')
+
+    edges = []
+    for i in range(len(concept_ids)):
+        for j in range(len(concept_ids)):
+            if i != j:
+                print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
+                similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
+                print(np.round(similarity[0][0], 3))
+                edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
+
+    # # Create an empty graph
+    # G = nx.Graph()
+
+    # # Add edges with weights to the graph
+    # for edge in edges:
+    #     node1, node2, weight = edge
+    #     G.add_edge(node1, node2, weight=weight)
+
+    net = Network(height="750px", width="100%",)
+    for e in edges:
+        src = e[0]
+        dst = e[1]
+        w = e[2]
+
+        net.add_node(src, src, title=src)
+        net.add_node(dst, dst, title=dst)
+        net.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' +str(w))
+    print(net)
+
+    # Generate network with specific layout settings
+    net.repulsion(
+        node_distance=420,
+        central_gravity=0.33,
+        spring_length=110,
+        spring_strength=0.10,
+        damping=0.95
+    )
+
+    # Save and read graph as HTML file (on Streamlit Sharing)
+    try:
+        path = '/tmp'
+        net.save_graph(f'{path}/pyvis_graph_clip.html')
+        HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
+
+    # Save and read graph as HTML file (locally)
+    except:
+        path = '/html_files'
+        net.save_graph(f'{path}/pyvis_graph_clip.html')
+        HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
+
+    # Load HTML file in HTML component for display on Streamlit page
+    components.html(HtmlFile.read(), height=435)
+
 # ----------------------------- INPUT column 2 & 3 ----------------------------
 # with input_col_2:
 #     with st.form('image_form'):

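For reference, a minimal standalone sketch of the similarity-graph step added above, not the app's exact code: it assumes data/CLIP_vecs.pkl holds the {label: vector} dict built in view_predictions.ipynb, uses a hypothetical concept_ids selection, and looks each vector up by concept label rather than by positional index before wiring the pyvis graph.

# Sketch (assumes CLIP_vecs.pkl maps concept label -> 768-d numpy vector,
# as produced in view_predictions.ipynb; concept_ids is a hypothetical selection)
import pickle

import numpy as np
from pyvis.network import Network
from sklearn.metrics.pairwise import cosine_similarity

with open('data/CLIP_vecs.pkl', 'rb') as f:
    clip_vecs = pickle.load(f)  # dict: label -> np.ndarray of shape (768,)

concept_ids = ['Abstract', 'Baroque']  # placeholder for the multiselect choice

net = Network(height='750px', width='100%')
for label in concept_ids:
    net.add_node(label, label, title=label)

for i, src in enumerate(concept_ids):
    for dst in concept_ids[i + 1:]:
        # cosine similarity between the two concepts' CLIP text vectors
        sim = cosine_similarity(clip_vecs[src].reshape(1, -1),
                                clip_vecs[dst].reshape(1, -1))[0][0]
        net.add_edge(src, dst, value=float(np.round(sim, 3)),
                     title=f'{src} to {dst} similarity {np.round(sim, 3)}')

net.save_graph('pyvis_graph_clip.html')

Looking vectors up by label keeps the edge weights aligned with whichever subset of concepts the multiselect returns.
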
view_predictions.ipynb
CHANGED
@@ -208,6 +208,73 @@
 "images[-1]"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 20,
+"id": "f5390d8f",
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/Users/ludovicaschaerf/anaconda3/envs/art-reco_x86/lib/python3.8/site-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
+" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"(132, 768)\n"
+]
+}
+],
+"source": [
+"import open_clip\n",
+"import os\n",
+"import random\n",
+"from tqdm import tqdm\n",
+"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
+"\n",
+"model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
+"tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
+"\n",
+"pre_prompt = \"Artwork, \" #@param {type:\"string\"}\n",
+"text_descriptions = [f\"{pre_prompt}{label}\" for label in labels]\n",
+"text_tokens = tokenizer(text_descriptions)\n",
+"\n",
+"with torch.no_grad(), torch.cuda.amp.autocast():\n",
+" text_features = model_clip.encode_text(text_tokens).float()\n",
+" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
+" \n",
+"text_features = text_features.cpu().numpy()\n",
+"print(text_features.shape)\n",
+"\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 22,
+"id": "f7858bbf",
+"metadata": {},
+"outputs": [],
+"source": [
+"dic_clip_vecs = {l:v for l,v in zip(labels, text_features)}"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 26,
+"id": "89b4a6fc",
+"metadata": {},
+"outputs": [],
+"source": [
+"dic_clip_vecs['Abstract'].shape\n",
+"with open('data/CLIP_vecs.pkl', 'wb') as f:\n",
+" pickle.dump(dic_clip_vecs, f)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 8,
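
For context, a minimal standalone sketch of the notebook cell added above, with the imports and the labels list it takes from earlier cells filled in (the two labels here are placeholders): it encodes one "Artwork, <label>" prompt per concept with open_clip, L2-normalises the embeddings, and pickles them as the {label: vector} dict that pages/2_Concepts_comparison.py reads back.

# Sketch of the CLIP text-embedding step (labels below are placeholders
# for the concepts.txt labels used in the notebook)
import pickle

import open_clip
import torch

labels = ['Abstract', 'Baroque']  # placeholder concept labels

model_clip, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-14', pretrained='laion2b_s32b_b82k')
tokenizer = open_clip.get_tokenizer('ViT-L-14')

pre_prompt = 'Artwork, '
text_tokens = tokenizer([f'{pre_prompt}{label}' for label in labels])

with torch.no_grad():
    text_features = model_clip.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)  # L2-normalise

# (num_labels, 768) array -> {label: vector} dict, pickled for the Streamlit app
dic_clip_vecs = {l: v for l, v in zip(labels, text_features.cpu().numpy())}
with open('data/CLIP_vecs.pkl', 'wb') as f:
    pickle.dump(dic_clip_vecs, f)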