CLIP_Text_Embeddings / search_engine_model.py
import clip
import logging
import os
import pandas as pd
from PIL import Image
import random
import torch
class SearchEngineModel():
    """Small wrapper around OpenAI CLIP for text-to-image and image-to-image search."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Use the GPU when available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        # Load the ViT-B/32 CLIP checkpoint together with its image preprocessing pipeline.
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def read_image(self, image_path):
        pil_image = Image.open(image_path)
        return pil_image
    def encode_text(self, model, text_prompt):
        # `model` is accepted for compatibility with existing callers, but the
        # instance's own CLIP model is used, matching the original behaviour.
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        # Move the tensor to the CPU before converting, so this also works on CUDA.
        text_features = text_features.cpu().numpy()
        return text_features
    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        # Expects `self.image_root_dir` and `self.csv_file_path` to have been set
        # on the instance before an image search is performed.
        encoded_images, image_paths = self.encode_images(self.model, self.preprocess, self.image_root_dir, self.csv_file_path)
        similarity = encoded_images @ prompt_features.T
        values, indices = similarity.topk(nofimages_to_show, dim=0)
        results = []
        for value, index in zip(values, indices):
            results.append(image_paths[int(index)])
        return results
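
    # NOTE: `encode_images` is called above but is not defined in this file. The
    # method below is a hypothetical sketch, assuming the CSV's first column lists
    # image file names relative to `image_root_dir`; if the project already defines
    # `encode_images` elsewhere, prefer that implementation.
    def encode_images(self, model, preprocess, image_root_dir, csv_file_path):
        images_df = pd.read_csv(csv_file_path)
        image_names = images_df.iloc[:, 0].tolist()
        image_paths = [os.path.join(image_root_dir, name) for name in image_names]
        # Preprocess every image and encode the whole batch with CLIP.
        batch = torch.stack([preprocess(self.read_image(path)) for path in image_paths]).to(self.device)
        with torch.no_grad():
            encoded_images = model.encode_image(batch)
        return encoded_images, image_paths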
    def search_image_by_text_prompt(self, text_features, images_features):
        # `images_features` is expected to be a DataFrame whose first column holds
        # the image names and whose remaining columns hold the image embeddings.
        names_column = images_features.values[:, 0]
        search_results = images_features.values[:, 1:].astype(float) @ text_features.T
        search_results_df = pd.DataFrame(search_results)
        search_results_df.insert(0, "images_names", names_column)
        search_results_df.columns = ['images_names', 'similarity']
        # Most similar images first.
        search_results_df = search_results_df.sort_values(by='similarity', ascending=False)
        search_results = search_results_df.values
        return search_results
    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)
        return search_results
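

# A minimal end-to-end sketch (not part of the original file). The file names,
# directory, and CSV layout below are illustrative assumptions.
if __name__ == "__main__":
    engine = SearchEngineModel()

    # Text-to-image search: rank precomputed image embeddings against a text prompt.
    text_features = engine.encode_text(engine.model, "a photo of a dog")
    images_features = pd.read_csv("image_features.csv")  # assumed: first column = image names
    ranked = engine.search_image_by_text_prompt(text_features, images_features)
    print(ranked[:5])

    # Image-to-image search: these attributes are read by the internal search helper.
    engine.image_root_dir = "images"          # assumed directory of the indexed images
    engine.csv_file_path = "image_names.csv"  # assumed CSV listing the image file names
    query_image = engine.read_image("query.jpg")
    print(engine.search_image_by_image_prompt(query_image, nofimages_to_show=5))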