import logging
import os

import clip
import pandas as pd
import torch
from PIL import Image


class SearchEngineModel:
    def __init__(self, image_root_dir=None, csv_file_path=None):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Use the GPU when available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()
        # Corpus locations read by __search_image_auxiliar_func__ when
        # encoding the image collection.
        self.image_root_dir = image_root_dir
        self.csv_file_path = csv_file_path
    def load_clip_model(self):
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess
    def read_image(self, image_path):
        pil_image = Image.open(image_path)
        return pil_image
    def encode_text(self, text_prompt):
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        # Move to the CPU before converting: .numpy() fails on CUDA tensors.
        text_features = text_features.cpu().numpy()
        return text_features
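
    def encode_images(self, model, preprocess, image_root_dir, csv_file_path):
        # NOTE: __search_image_auxiliar_func__ below calls encode_images, but
        # the method is not defined in this file. This is a minimal sketch of
        # what it plausibly does, assuming csv_file_path points at a CSV whose
        # first column holds image file names relative to image_root_dir;
        # adapt the column layout to your data.
        names = pd.read_csv(csv_file_path).iloc[:, 0]
        image_paths = [os.path.join(image_root_dir, str(name)) for name in names]
        encoded = []
        for path in image_paths:
            image = preprocess(Image.open(path)).unsqueeze(0).to(self.device)
            with torch.no_grad():
                encoded.append(model.encode_image(image))
        # Stack into an (N, D) feature matrix on the active device.
        return torch.cat(encoded, dim=0), image_paths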
    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        encoded_images, image_paths = self.encode_images(
            self.model, self.preprocess, self.image_root_dir, self.csv_file_path
        )
        # Dot-product similarity of the query against every image feature.
        similarity = encoded_images @ prompt_features.T
        values, indices = similarity.topk(nofimages_to_show, dim=0)
        results = []
        for value, index in zip(values.flatten(), indices.flatten()):
            results.append(image_paths[index])
        return results
    def search_image_by_text_prompt(self, text_features, images_features):
        # images_features: first column is the image name, the remaining
        # columns are the precomputed embedding.
        names_column = images_features.values[:, 0]
        search_results = images_features.values[:, 1:].astype(float) @ text_features.T
        search_results_df = pd.DataFrame(search_results)
        search_results_df.insert(0, "images_names", names_column)
        search_results_df.columns = ['images_names', 'similarity']
        # Best match first.
        search_results_df = search_results_df.sort_values(by='similarity', ascending=False)
        search_results = search_results_df.values
        return search_results
    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        # image_prompt is a PIL image (e.g. from read_image); preprocess
        # resizes and normalizes it into a (1, 3, 224, 224) tensor.
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)
        return search_results
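
# Minimal usage sketch; the file paths, CSV layout, and prompt below are
# assumptions for illustration, not part of the original module.
if __name__ == "__main__":
    engine = SearchEngineModel(image_root_dir="images/", csv_file_path="features.csv")

    # Text-to-image search against precomputed features: a DataFrame whose
    # first column is the image name and remaining columns the CLIP embedding.
    text_features = engine.encode_text("a photo of a dog")
    images_features = pd.read_csv("features.csv")
    print(engine.search_image_by_text_prompt(text_features, images_features)[:5])

    # Image-to-image search: encode the query image on the fly and rank the corpus.
    query_image = engine.read_image("query.jpg")
    print(engine.search_image_by_image_prompt(query_image, nofimages_to_show=5))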