|
import configparser |
|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
from search_engine_model import SearchEngineModel |
|
|
|
# Process-wide cache holding the SearchEngineModel and its loaded CLIP model, so
# the model is loaded once per process instead of on every Gradio request.
_SEARCH_ENGINE_CACHE = {}


def _get_search_engine():
    """Return a cached ``(SearchEngineModel, clip_model)`` pair, creating it on first use."""
    cached = _SEARCH_ENGINE_CACHE.get("engine")
    if cached is None:
        engine = SearchEngineModel()
        # The second value returned by load_clip_model is not needed here.
        model, _ = engine.load_clip_model()
        cached = (engine, model)
        _SEARCH_ENGINE_CACHE["engine"] = cached
    return cached


def get_text_embeddings(text_prompt, input_np_array):
    """Encode a text prompt with CLIP and search it against image embeddings.

    Args:
        text_prompt: Text entered in the Gradio textbox.
        input_np_array: Array coming from the Gradio input Dataframe. As set up
            in ``main`` it holds image names in the first column followed by
            feature columns — presumably the layout that
            ``search_image_by_text_prompt`` expects; confirm against that method.

    Returns:
        Tuple ``(text_embeddings, search_result)``: the output of
        ``SearchEngineModel.encode_text`` and of
        ``SearchEngineModel.search_image_by_text_prompt`` respectively.
    """
    search_engine_model, model = _get_search_engine()

    text_embeddings = search_engine_model.encode_text(model, text_prompt)

    # pd.DataFrame on a bare array assigns integer column labels (0, 1, ...);
    # the header names displayed in the UI are not carried over.
    input_df = pd.DataFrame(input_np_array)

    search_result = search_engine_model.search_image_by_text_prompt(text_embeddings, input_df)

    return text_embeddings, search_result
|
|
|
def _read_server_settings(config_path):
    """Return ``(host, port)`` from the ``[SERVER]`` section of *config_path*.

    Raises:
        FileNotFoundError: if the file could not be read. ``ConfigParser.read``
            silently skips unreadable files, which would otherwise surface
            later as a confusing missing-section error.
        configparser.NoSectionError: if there is no ``[SERVER]`` section.
        configparser.NoOptionError: if ``HOST_IP_ADDRESS`` or ``PORT_NUMBER``
            is missing.
        ValueError: if ``PORT_NUMBER`` is not an integer.
    """
    config = configparser.ConfigParser()
    # read() returns the list of files it successfully parsed.
    if not config.read(config_path):
        raise FileNotFoundError(f"Could not read configuration file: {config_path!r}")
    host = config.get("SERVER", "HOST_IP_ADDRESS")
    port = config.getint("SERVER", "PORT_NUMBER")
    return host, port


def _build_demo_dataframe(num_images, feature_dim):
    """Return random demo features with an ``image_name`` column prepended."""
    features = np.random.rand(num_images, feature_dim)
    frame = pd.DataFrame(features)
    frame.insert(0, "image_name", [f"image_{i}.png" for i in range(num_images)])
    return frame


def main(config_path="./config.cfg", *, num_images=50, feature_dim=512):
    """Build the CLIP text-embedding Gradio app and serve it.

    Server host and port are read from the ``[SERVER]`` section of
    *config_path* before the UI is built, so a bad config fails fast.

    Args:
        config_path: Path to the INI-style configuration file.
        num_images: Number of random demo image rows shown in the input table.
        feature_dim: Number of random feature columns per demo image.
    """
    host, port = _read_server_settings(config_path)

    demo_df = _build_demo_dataframe(num_images, feature_dim)
    feature_headers = [f"feature_{i}" for i in range(feature_dim)]

    main_app = gr.Interface(
        fn=get_text_embeddings,
        inputs=[
            gr.Textbox(),
            gr.Dataframe(
                # Pass plain values; the column names come from `headers`.
                demo_df.values,
                headers=["image_name"] + feature_headers,
                type="numpy",
                interactive=False,
            ),
        ],
        outputs=[
            gr.Dataframe(type="numpy", headers=feature_headers),
            gr.Dataframe(type="numpy", headers=["image_name", "similarity"]),
        ],
        title="CLIP Text Embeddings",
        description="Obtain the embeddings of a given text and use the API to compare with a set of images' embeddings.",
        flagging_mode="never",
    )

    main_app.launch(server_name=host, server_port=port, show_error=True)
|
|
|
# Launch the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()