import logging
import os
import random

import clip
import pandas as pd
from PIL import Image
import torch


class SearchEngineModel:
    """CLIP-backed image search by text prompt or by example image.

    NOTE(review): ``__search_image_auxiliar_func__`` relies on
    ``self.encode_images``, ``self.image_root_dir`` and
    ``self.csv_file_path``, none of which are defined in the visible part
    of this class -- confirm they are provided elsewhere (subclass, mixin,
    or attributes assigned by the caller) before using the search methods.
    """

    def __init__(self):
        """Configure logging, pick the compute device and load CLIP."""
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Prefer the GPU when one is available; otherwise run on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        """Load the ``ViT-B/32`` CLIP model onto ``self.device``.

        Returns:
            The ``(model, preprocess)`` pair returned by ``clip.load``.
        """
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def read_image(self, image_path):
        """Open the image at ``image_path`` with PIL and return it.

        The returned image is not closed here; the caller owns it.
        """
        return Image.open(image_path)

    def encode_image(self, model, preprocess, image_path) -> pd.DataFrame:
        """Encode a single image file into a DataFrame of CLIP features.

        Args:
            model: Object exposing ``encode_image`` (e.g. the CLIP model).
            preprocess: Callable turning a PIL image into a tensor.
            image_path: Path of the image file to encode.

        Returns:
            A ``pandas.DataFrame`` built from the encoded feature array.
        """
        # Close the file handle as soon as the pixels are in a tensor.
        with Image.open(image_path) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = model.encode_image(image)
        # Tensor.numpy() only works for CPU tensors, so move the result
        # off the GPU first (a no-op when already on the CPU).
        return pd.DataFrame(image_features.cpu().numpy())

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Return the paths of the images most similar to the prompt.

        Args:
            prompt_features: Encoded prompt (text or image) features.
            nofimages_to_show: Maximum number of results to return.

        Returns:
            A list with the ``image_paths`` entries for the top matches,
            best match first.
        """
        # NOTE(review): encode_images / image_root_dir / csv_file_path are
        # not defined in the visible code -- verify where they come from.
        encoded_images, image_paths = self.encode_images(
            self.model,
            self.preprocess,
            self.image_root_dir,
            self.csv_file_path,
        )
        similarity = encoded_images @ prompt_features.T
        # topk raises if asked for more rows than exist, so cap the request
        # at the number of rows in the similarity matrix.
        k = min(nofimages_to_show, similarity.shape[0])
        _, indices = similarity.topk(k, dim=0)
        return [image_paths[index] for index in indices]

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        """Search the indexed images using a free-text prompt.

        Args:
            text_prompt: Query text; tokenized with ``clip.tokenize``.
            nofimages_to_show: Maximum number of results to return.

        Returns:
            The list produced by ``__search_image_auxiliar_func__``.
        """
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        return self.__search_image_auxiliar_func__(text_features, nofimages_to_show)

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Search the indexed images using an example image.

        Args:
            image_prompt: Image accepted by ``self.preprocess``
                (presumably a PIL image -- confirm against callers).
            nofimages_to_show: Maximum number of results to return.

        Returns:
            The list produced by ``__search_image_auxiliar_func__``.
        """
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return self.__search_image_auxiliar_func__(image_features, nofimages_to_show)