Spaces:
Runtime error
Runtime error
File size: 7,971 Bytes
748cc87 982eb46 748cc87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import random
import os
import torch
import pandas as pd
from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image
CLIPmodel_import = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
CLIPprocessor_import = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
DetrFeatureExtractor_import = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
DetrModel_import = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
# Import list of coco example objects
script_path = os.path.dirname(__file__)
coco_objects = open(script_path+"/coco-labels-paper.txt", "r")
coco_objects = coco_objects.read()
coco_objects = coco_objects.split("\n")
# Example image
#test_image = Image.open('pages/Functions/test_image.png')
#test_image = Image.open('pages/Functions/test_imageIV.png')
###### Helper functions
def Coco_object_set(included_object, set_length=6):
'''
Creates set of object based on coco objects and the currently correct object.
'''
curr_object_set = set([included_object])
while len(curr_object_set)<set_length:
temp_object = random.choice(coco_objects)
curr_object_set.add(temp_object)
return list(curr_object_set)
def Object_set_creator(included_object, list_of_all_objects = coco_objects, excluded_objects_list = [], set_length=6):
'''
Creates set of object based on list_of_all_objects.
The included object will always be in the list.
Optional list of objects to be excluded from the set.
'''
curr_object_set = set([included_object])
# Check that the included object is not contained in the excluded objects
if included_object in excluded_objects_list:
raise ValueError('The included_object can not be part of the excluded_objects list.')
while len(curr_object_set)<set_length:
temp_object = random.choice(list_of_all_objects)
if temp_object not in excluded_objects_list:
curr_object_set.add(temp_object)
return list(curr_object_set)
###### Single object recognition
def CLIP_single_object_classifier(img, object_class, task_specific_label=None):
'''
Test presence of object in image by using the "red herring strategy" and CLIP algorithm.
Note that the task_specific_label is not used for this classifier.
'''
# Define model and parameters
word_list = Coco_object_set(object_class)
inputs = CLIPprocessor_import(text=word_list, images=img, return_tensors="pt", padding=True)
# Run inference
outputs = CLIPmodel_import(**inputs)
# Get image-text similarity score
logits_per_image = outputs.logits_per_image
# Get probabilities
probs = logits_per_image.softmax(dim=1)
# Return true if the highest prob value is recognised
if word_list[probs.argmax().item()]==object_class:
return True
else:
return False
def CLIP_object_recognition(img, object_class, tested_classes):
'''
More general CLIP object recogntintion implementation
'''
if object_class not in tested_classes:
raise ValueError('The object_class has to be part of the tested_classes list.')
# Define model and parameters
inputs = CLIPprocessor_import(text=tested_classes, images=img, return_tensors="pt", padding=True)
# Run inference
outputs = CLIPmodel_import(**inputs)
# Get image-text similarity score
logits_per_image = outputs.logits_per_image
# Get probabilities
probs = logits_per_image.softmax(dim=1)
# Return true if the highest prob value is recognised
if tested_classes[probs.argmax().item()]==object_class:
return True
else:
return False
###### Multi object recognition
#list_of_objects = ['cat','apple','cow']
def CLIP_multi_object_recognition(img, list_of_objects):
'''
Algorithm based on CLIP to test presence of multiple objects.
Currently has a debugging print call in.
'''
# Loop over list of objects, test for presence of each inidividually, making sure that non of the other objects is part of test set
for i_object in list_of_objects:
# Create list with objects not in test set (all objects which arent i_object)
untested_objects = [x for x in list_of_objects if x!= i_object]
# Create set going into clip object recogniser and test this set using standard recognition function
CLIP_test_classes = Object_set_creator(included_object=i_object, excluded_objects_list=untested_objects)
i_object_present = CLIP_object_recognition(img, i_object, CLIP_test_classes)
print(i_object+str(i_object_present))
# Stop loop and return false if one of the objects is not recognised by CLIP
if i_object_present == False:
return False
# Return true if all objects were recognised
return True
def CLIP_multi_object_recognition_DSwrapper(img, representations, task_specific_label=None):
'''
Dashboard wrapper of CLIP_multi_object_recognition
Note that the task_specific_label is not used for this classifier.
'''
list_of_objects = representations.split(', ')
return CLIP_multi_object_recognition(img,list_of_objects)
###### Negation
def CLIP_object_negation(img, present_object, absent_object):
'''
Algorithm based on CLIP to test negation prompts
'''
# Create sets of objects for present and absent object
tested_classes_present = Object_set_creator(
included_object=present_object, excluded_objects_list=[absent_object])
tested_classes_absent = Object_set_creator(
included_object=absent_object, excluded_objects_list=[present_object],set_length=10)
# Use CLIP object recognition to test for objects.
presence_test = CLIP_object_recognition(img, present_object, tested_classes_present)
absence_test = CLIP_object_recognition(img, absent_object, tested_classes_absent)
if presence_test==True and absence_test==False:
return True
else:
return False
###### Counting / arithmetic
'''
test_image = Image.open('pages/Functions/test_imageIII.jpeg')
object_classes = ['cat','remote']
object_counts = [2,2]
'''
def DETR_multi_object_counting(img, object_classes, object_counts, confidence_treshold=0.5):
# Apply Detr to image
inputs = DetrFeatureExtractor_import(images=img, return_tensors="pt")
outputs = DetrModel_import(**inputs)
# Convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([img.size[::-1]])
results = DetrFeatureExtractor_import.post_process_object_detection(
outputs, threshold=confidence_treshold, target_sizes=target_sizes)[0]
# Create dict with value_counts
count_dict = pd.Series(results['labels'].numpy())
count_dict = count_dict.map(DetrModel_import.config.id2label)
count_dict = count_dict.value_counts().to_dict()
# Create dict for correct response
label_dict = dict(zip(object_classes, object_counts))
# Return False is the count for a given label does not match
for i_item in label_dict.items():
# Check whether current label item exists in count dict, else return false
if i_item[0] not in count_dict:
return False
# Now that we checked the label item is in count dict, check that the count matches
if int(count_dict[i_item[0]])==int(i_item[1]): # Adding type control for comparison due to str read in
print(str(i_item)+'_true')
else:
print(str(i_item)+'_false')
print("oberserved: "+str(count_dict[i_item[0]]))
return False
# If all match, return true
return True
def DETR_multi_object_counting_DSwrapper(img, representations, Task_specific_label):
'''
Dashboard wrapper of DETR_multi_object_counting
'''
list_of_objects = representations.split(', ')
object_counts = Task_specific_label.split(', ')
return DETR_multi_object_counting(img,list_of_objects, object_counts, confidence_treshold=0.5) |