
We updated search_engine_model.py so that search results are sorted in the correct order (highest similarity first); previously the order was reversed.
f0c9860
verified
import clip | |
import logging | |
import os | |
import pandas as pd | |
from PIL import Image | |
import random | |
import torch | |
class SearchEngineModel:
    """Search images by text or image prompts using a CLIP model."""

    def __init__(self):
        """Set up logging, pick a device and load the CLIP model.

        Note: ``logging.basicConfig`` configures the root logger at INFO level
        as a side effect of constructing this object.
        """
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Prefer the GPU whenever PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self, model_name="ViT-B/32"):
        """Load a CLIP model and its image-preprocessing transform.

        Args:
            model_name: model identifier passed to ``clip.load``; defaults to
                ``"ViT-B/32"``, the value previously hard-coded here.

        Returns:
            ``(model, preprocess)`` as returned by ``clip.load`` for
            ``self.device``.
        """
        model, preprocess = clip.load(model_name, device=self.device)
        return model, preprocess

    def read_image(self, image_path):
        """Open the image at ``image_path`` and return it as a PIL image.

        ``Image.open`` is lazy and keeps the file handle open until the pixel
        data is read. Loading eagerly inside ``with`` releases the handle
        immediately while leaving the returned image usable.
        """
        with Image.open(image_path) as pil_image:
            pil_image.load()
        return pil_image

    def encode_text(self, model, text_prompt):
        """Encode a single text prompt into CLIP text features.

        Args:
            model: CLIP model whose ``encode_text`` is used; when ``None``,
                ``self.model`` is used instead.
            text_prompt: the text to encode.

        Returns:
            A NumPy array holding the features of the one tokenized prompt.
        """
        clip_model = model if model is not None else self.model
        tokens = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = clip_model.encode_text(tokens)
        # Tensor.numpy() only works on CPU tensors, so move off the GPU first.
        return text_features.cpu().numpy()

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Return the paths of the images most similar to ``prompt_features``.

        Similarity is the raw matrix product of the encoded images with the
        transposed prompt features. Results are ordered from most to least
        similar (``topk`` sorts in descending order).

        Args:
            prompt_features: tensor of prompt features; presumably shaped
                (1, D) — TODO confirm against callers.
            nofimages_to_show: maximum number of paths to return; capped at the
                number of encoded images.

        NOTE(review): ``encode_images``, ``image_root_dir`` and
        ``csv_file_path`` are not defined in this class as shown; they are
        presumably provided elsewhere (subclass or later assignment) — confirm.
        """
        encoded_images, image_paths = self.encode_images(
            self.model, self.preprocess, self.image_root_dir, self.csv_file_path
        )
        similarity = encoded_images @ prompt_features.T
        # topk raises if k exceeds the size of dim 0, so never ask for more
        # results than there are images.
        k = min(nofimages_to_show, similarity.shape[0])
        _, indices = similarity.topk(k, dim=0)
        # Convert the index tensor to plain ints so any sequence of paths can
        # be indexed.
        return [image_paths[i] for i in indices.flatten().tolist()]

    def search_image_by_text_prompt(self, text_features, images_features):
        """Rank precomputed image features against text features.

        Args:
            text_features: array of text features; must produce a single score
                column when multiplied (e.g. shape (1, D)).
            images_features: DataFrame whose first column holds image names and
                whose remaining columns hold each image's feature vector.

        Returns:
            A 2-column array of ``[image_name, similarity]`` rows sorted by
            similarity, highest first. Similarity is a raw dot product; no
            normalization is applied here.
        """
        raw_values = images_features.values
        names_column = raw_values[:, 0]
        scores = raw_values[:, 1:].astype(float) @ text_features.T
        results_df = pd.DataFrame(scores)
        results_df.insert(0, "images_names", names_column)
        results_df.columns = ['images_names', 'similarity']
        results_df = results_df.sort_values(by='similarity', ascending=False)
        return results_df.values

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Find the images most similar to ``image_prompt``.

        Args:
            image_prompt: image accepted by ``self.preprocess`` (presumably a
                PIL image, as returned by ``read_image``).
            nofimages_to_show: maximum number of image paths to return.

        Returns:
            List of image paths, most similar first.
        """
        # Add a batch dimension of 1 before encoding.
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return self.__search_image_auxiliar_func__(image_features, nofimages_to_show)