CLIP_Images_Embeddings / search_engine_model.py
import clip
import logging
import os
import pandas as pd
from PIL import Image
import torch


class SearchEngineModel():
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)

        # Run on the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        # Load the pretrained ViT-B/32 CLIP model together with its image
        # preprocessing pipeline.
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def read_image(self, image_path):
        pil_image = Image.open(image_path)
        return pil_image

    def encode_image(self, model, preprocess, image_path):
        # Preprocess a single image and compute its CLIP embedding without
        # tracking gradients.
        image = preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = model.encode_image(image)

        # Move the features to the CPU before converting to a DataFrame;
        # calling .numpy() directly fails for tensors that live on the GPU.
        image_features = pd.DataFrame(image_features.cpu().numpy())
        return image_features
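
    # NOTE: encode_images is called by __search_image_auxiliar_func__ below but is
    # not defined anywhere in this file. The method below is a minimal sketch of
    # what it presumably does, reconstructed from its call site: read an image
    # listing from a CSV file, encode every listed image with CLIP, and return the
    # stacked feature matrix together with the image paths. The CSV column name
    # "image_name" and the instance attributes image_root_dir / csv_file_path are
    # assumptions, not part of the original code.
    def encode_images(self, model, preprocess, image_root_dir, csv_file_path):
        listing = pd.read_csv(csv_file_path)
        image_paths = [os.path.join(image_root_dir, name) for name in listing["image_name"]]

        features = []
        with torch.no_grad():
            for image_path in image_paths:
                image = preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
                features.append(model.encode_image(image))

        # Stack into an (N, D) tensor so it can be matrix-multiplied with the
        # encoded prompt in __search_image_auxiliar_func__.
        encoded_images = torch.cat(features, dim=0)
        return encoded_images, image_paths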
    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        # Encode the whole image collection and rank it against the encoded prompt.
        # image_root_dir and csv_file_path are expected to be set on the instance
        # before a search is run; they are not initialised in __init__.
        encoded_images, image_paths = self.encode_images(self.model, self.preprocess, self.image_root_dir, self.csv_file_path)

        # Dot-product similarity between the image features and the prompt
        # features, keeping the nofimages_to_show best matches.
        similarity = encoded_images @ prompt_features.T
        values, indices = similarity.topk(nofimages_to_show, dim=0)

        results = []
        for value, index in zip(values, indices):
            results.append(image_paths[int(index)])
        return results
    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        # Tokenize and encode the text prompt with CLIP, then retrieve the
        # nofimages_to_show most similar images.
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)

        search_results = self.__search_image_auxiliar_func__(text_features, nofimages_to_show)
        return search_results
    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        # Encode a query image (given as a PIL image) with CLIP, then retrieve
        # the nofimages_to_show most similar images.
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)

        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)
        return search_results
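
# Minimal usage sketch. The paths below are hypothetical examples, and the
# image_root_dir / csv_file_path attributes have to be set manually because
# __init__ does not define them.
if __name__ == "__main__":
    engine = SearchEngineModel()
    engine.image_root_dir = "images/"
    engine.csv_file_path = "images/listing.csv"

    # Retrieve the five images most similar to a free-text description.
    best_matches = engine.search_image_by_text_prompt("a dog playing in the snow", 5)
    print(best_matches)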