|
import clip |
|
import logging |
|
import os |
|
import pandas as pd |
|
from PIL import Image |
|
import random |
|
import torch |
|
|
|
class SearchEngineModel:
    """CLIP-based image search: rank gallery images against a text or image prompt.

    Attributes:
        logger: logger named after this module.
        device: ``"cuda"`` when ``torch.cuda.is_available()``, otherwise ``"cpu"``.
        model, preprocess: CLIP model and image transform returned by ``clip.load``.
        image_root_dir, csv_file_path: gallery location forwarded to
            ``encode_images``; only assigned here when passed to the constructor.
    """

    def __init__(self, image_root_dir=None, csv_file_path=None):
        """Set up logging and load the CLIP model on the best available device.

        Args:
            image_root_dir: optional gallery image directory, stored as
                ``self.image_root_dir`` for ``encode_images``.
            csv_file_path: optional gallery CSV path, stored as
                ``self.csv_file_path`` for ``encode_images``.
        """
        self.logger = logging.getLogger(__name__)
        # NOTE(review): basicConfig configures the *root* logger (no-op once it has
        # handlers); library classes usually leave this to the application.
        logging.basicConfig(level=logging.INFO)

        # Assign only when given, so values provided as class attributes
        # (e.g. by a subclass) are not shadowed with None.
        if image_root_dir is not None:
            self.image_root_dir = image_root_dir
        if csv_file_path is not None:
            self.csv_file_path = csv_file_path

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self, model_name="ViT-B/32"):
        """Load a CLIP model and its preprocessing transform onto ``self.device``.

        Args:
            model_name: architecture name accepted by ``clip.load``.

        Returns:
            The ``(model, preprocess)`` pair returned by ``clip.load``.
        """
        model, preprocess = clip.load(model_name, device=self.device)
        return model, preprocess

    def read_image(self, image_path):
        """Open ``image_path`` with PIL and return the image object.

        ``Image.open`` is lazy: the file stays open until the image is loaded
        or closed by the caller.
        """
        return Image.open(image_path)

    def _encode_pil_image(self, model, preprocess, pil_image):
        """Preprocess one PIL image and return its CLIP embedding (batch of one)."""
        batch = preprocess(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return model.encode_image(batch)

    def encode_image(self, model, preprocess, image_path):
        """Encode a single image file into a one-row ``pandas.DataFrame`` of features.

        Args:
            model: CLIP model providing ``encode_image``.
            preprocess: transform applied to the PIL image before encoding.
            image_path: anything ``PIL.Image.open`` accepts (path or file object).

        Returns:
            ``pd.DataFrame`` holding the float32 embedding, one row per image.
        """
        # The context manager closes files that PIL opened itself; the transform
        # produces a tensor, so nothing references the file afterwards.
        with Image.open(image_path) as pil_image:
            image_features = self._encode_pil_image(model, preprocess, pil_image)

        # Tensor.numpy() only works on CPU tensors, so move off the GPU first;
        # .float() normalizes fp16 CUDA outputs to float32.
        return pd.DataFrame(image_features.float().cpu().numpy())

    @staticmethod
    def _to_feature_tensor(features):
        """Convert features (tensor, ndarray, DataFrame/Series, nested list) to a 2-D float32 CPU tensor."""
        if isinstance(features, (pd.DataFrame, pd.Series)):
            features = features.to_numpy(dtype="float32")
        tensor = torch.as_tensor(features)
        # Compare everything on CPU in float32: features may come from a CUDA
        # model (possibly fp16) or from pandas, and mixed devices/dtypes fail.
        return torch.atleast_2d(tensor.detach().to(device="cpu", dtype=torch.float32))

    @classmethod
    def _rank_by_similarity(cls, gallery_features, query_features, image_paths, nofimages_to_show):
        """Return ``image_paths`` entries ranked by cosine similarity to the query.

        Args:
            gallery_features: one embedding per row, aligned with ``image_paths``.
            query_features: a single embedding, shape ``(D,)`` or ``(1, D)``.
            image_paths: sequence (list, Series, array, ...) aligned with the rows.
            nofimages_to_show: maximum number of results; clamped to the gallery size.

        Returns:
            List of at most ``nofimages_to_show`` paths, most similar first
            (empty when the gallery is empty or the count is <= 0).

        Raises:
            ValueError: if paths and gallery rows differ in count, or the query is
                not a single embedding with the gallery's dimensionality.
        """
        # list() gives positional access for lists, Series (ignoring labels) and arrays.
        paths = list(image_paths)
        k = min(int(nofimages_to_show), len(paths))
        if k <= 0:
            return []

        gallery = cls._to_feature_tensor(gallery_features)
        query = cls._to_feature_tensor(query_features).reshape(-1)

        if gallery.shape[0] != len(paths):
            raise ValueError(
                f"got {gallery.shape[0]} feature rows but {len(paths)} image paths"
            )
        if query.numel() != gallery.shape[1]:
            raise ValueError(
                f"query must be a single {gallery.shape[1]}-dim embedding, "
                f"got {query.numel()} values"
            )

        # CLIP similarity is cosine similarity: L2-normalize both sides so the
        # ranking is not biased towards embeddings with a large norm.
        gallery = torch.nn.functional.normalize(gallery, dim=1)
        query = torch.nn.functional.normalize(query, dim=0)
        scores = gallery @ query

        top_indices = torch.topk(scores, k).indices.tolist()  # sorted, best first
        return [paths[i] for i in top_indices]

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Encode the gallery and return the paths most similar to ``prompt_features``.

        NOTE(review): names with leading *and* trailing double underscores are
        reserved for Python special methods; kept as-is for existing callers.
        NOTE(review): ``encode_images`` is not defined in this part of the file; it
        is expected to return ``(features, paths)`` with one feature row per path —
        confirm. It runs on every search, re-encoding the whole gallery per query;
        consider caching if the gallery is static.
        """
        encoded_images, image_paths = self.encode_images(
            self.model, self.preprocess, self.image_root_dir, self.csv_file_path
        )
        return self._rank_by_similarity(
            encoded_images, prompt_features, image_paths, nofimages_to_show
        )

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        """Return up to ``nofimages_to_show`` gallery paths ranked against a text prompt."""
        tokens = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
        return self.__search_image_auxiliar_func__(text_features, nofimages_to_show)

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Return up to ``nofimages_to_show`` gallery paths ranked against an image prompt.

        Args:
            image_prompt: a PIL image, or a filesystem path (``str``/``os.PathLike``)
                which is opened and closed here.
            nofimages_to_show: maximum number of results.
        """
        if isinstance(image_prompt, (str, os.PathLike)):
            with Image.open(image_prompt) as pil_image:
                image_features = self._encode_pil_image(self.model, self.preprocess, pil_image)
        else:
            image_features = self._encode_pil_image(self.model, self.preprocess, image_prompt)
        return self.__search_image_auxiliar_func__(image_features, nofimages_to_show)