import numpy as np | |
import pandas as pd | |
from search_engine_model import SearchEngineModel | |
def main(
    text_prompt='cat',
    *,
    num_images=50,
    embedding_dim=512,
    seed=None,
    search_engine_model=None,
):
    """Run a demo text-to-image search against randomly generated image features.

    Loads the CLIP model through ``SearchEngineModel``, encodes ``text_prompt``
    and searches a synthetic table of ``num_images`` random feature vectors,
    each named ``image_<i>.png``.

    Args:
        text_prompt: Text query to encode and search for.
        num_images: Number of synthetic images (rows) to generate.
        embedding_dim: Width of each synthetic feature vector. Defaults to the
            512 the original demo hard-coded; presumably this must match the
            dimension of the model's text embeddings -- TODO confirm.
        seed: Optional seed for the random feature generator. Pass an int to
            make the demo reproducible; ``None`` draws fresh randomness.
        search_engine_model: Optional pre-built search engine instance. A new
            ``SearchEngineModel`` is created when ``None``.

    Returns:
        Whatever ``search_image_by_text_prompt`` returns for the query.
        Previously this value was computed and silently discarded.
    """
    if search_engine_model is None:
        search_engine_model = SearchEngineModel()
    # The preprocessing transform is not needed for a text-only query.
    model, _preprocess = search_engine_model.load_clip_model()
    text_embeddings = search_engine_model.encode_text(model, text_prompt)

    # Synthetic stand-in for real image embeddings: uniform values in [0, 1).
    rng = np.random.default_rng(seed)
    random_features = rng.random((num_images, embedding_dim))
    input_df = pd.DataFrame(random_features)
    names_column = [f'image_{idx}.png' for idx in range(len(random_features))]
    # Image names go in the first column, ahead of the numeric feature columns.
    input_df.insert(0, 'images_names', names_column)

    return search_engine_model.search_image_by_text_prompt(text_embeddings, input_df)
# Only run the demo when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()