File size: 1,928 Bytes
08614a1 685680e 08614a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import configparser
import functools

import gradio as gr
import numpy as np
import pandas as pd

from search_engine_model import SearchEngineModel
@functools.lru_cache(maxsize=1)
def _load_search_engine():
    """Create the search engine and load its CLIP model once per process.

    Returns:
        Tuple ``(search_engine_model, model)``. The second value returned by
        ``load_clip_model`` is discarded, exactly as the per-call code did.
    """
    # NOTE(review): caching assumes the SearchEngineModel instance and the
    # loaded model hold no per-search state and are safe to reuse across
    # requests — confirm against search_engine_model.py.
    search_engine_model = SearchEngineModel()
    model, _ = search_engine_model.load_clip_model()
    return search_engine_model, model


def get_text_embeddings(text_prompt, input_np_array):
    """Encode a text prompt with CLIP and search the given image table with it.

    Gradio handler: receives the textbox value and the rows of the input
    ``gr.Dataframe``. The search engine and model are loaded lazily on the
    first call and reused afterwards instead of being reloaded per request.

    Args:
        text_prompt: Text to encode.
        input_np_array: Table rows from the input ``gr.Dataframe``
            (``type='numpy'``); wrapped in a ``pd.DataFrame`` before the
            search. In this app the first column holds image file names.
            NOTE(review): confirm ``search_image_by_text_prompt`` expects that
            leading name column.

    Returns:
        Tuple ``(text_embeddings, search_result)`` as produced by
        ``encode_text`` and ``search_image_by_text_prompt``.
    """
    search_engine_model, model = _load_search_engine()
    text_embeddings = search_engine_model.encode_text(model, text_prompt)
    input_df = pd.DataFrame(input_np_array)
    search_result = search_engine_model.search_image_by_text_prompt(text_embeddings, input_df)
    return text_embeddings, search_result
def main(config_path='./config.cfg'):
    """Build the CLIP text-embedding Gradio interface and launch its server.

    Reads the bind address and port from the ``[SERVER]`` section of an
    INI-style config file, builds a demo table of random feature vectors
    named ``image_<i>.png``, and serves a ``gr.Interface`` that passes a text
    prompt plus that table to ``get_text_embeddings``.

    Args:
        config_path: Path of the config file to read. Defaults to
            ``./config.cfg``, i.e. relative to the current working directory.

    Raises:
        FileNotFoundError: If ``config_path`` could not be read.
        KeyError: If the ``SERVER`` section, ``HOST_IP_ADDRESS`` or
            ``PORT_NUMBER`` is missing.
        ValueError: If ``PORT_NUMBER`` is not an integer.
    """
    config_manager_obj = configparser.ConfigParser()
    # ConfigParser.read silently skips files it cannot open and returns the
    # list of files it parsed; fail with a clear error instead of a later
    # KeyError on the missing 'SERVER' section.
    if not config_manager_obj.read(config_path):
        raise FileNotFoundError(f"Could not read config file: {config_path!r}")

    # Resolve server settings before building the UI so a bad config fails fast.
    server_section = config_manager_obj['SERVER']
    host_ip_address = server_section['HOST_IP_ADDRESS']
    port_number = int(server_section['PORT_NUMBER'])

    # Demo data: 50 rows of 512 random features standing in for image embeddings.
    random_features = np.random.rand(50, 512)
    feature_headers = [f'feature_{it}' for it in range(random_features.shape[1])]
    image_names = [f'image_{it}.png' for it in range(len(random_features))]

    initial_dataframe = pd.DataFrame(random_features, columns=feature_headers)
    # Same 'image_name' label as the displayed header and the search-result
    # output, so the column is named consistently everywhere.
    initial_dataframe.insert(0, 'image_name', image_names)

    main_app = gr.Interface(
        fn=get_text_embeddings,
        inputs=[
            gr.Textbox(),
            gr.Dataframe(
                initial_dataframe.values,
                headers=list(initial_dataframe.columns),
                type='numpy',
                interactive=False,
            ),
        ],
        outputs=[
            gr.Dataframe(type='numpy', headers=feature_headers),
            gr.Dataframe(type='numpy', headers=['image_name', 'similarity']),
        ],
        title="CLIP Text Embeddings",
        description="Obtain the embeddings of a given text and use the API to compare with a set of images' embeddings.",
        flagging_mode="never",
    )
    main_app.launch(server_name=host_ip_address, server_port=port_number, show_error=True)
# Launch only when run as a script, so importing this module (e.g. to reuse
# get_text_embeddings) does not start the server.
if __name__ == "__main__":
    main()