DanielIglesias97's picture
We have changed the app code to add the names of the images to the
95ec4d8
raw
history blame contribute delete
1.77 kB
import configparser
import gradio as gr
import numpy as np
import pandas as pd
from search_engine_model import SearchEngineModel
def get_image_embeddings(input_image_paths_list):
    """Encode the given images with CLIP and return the embeddings as a table.

    Each input image becomes one row. The first column, ``image_name``,
    holds the corresponding entry of ``input_image_paths_list`` unchanged;
    the remaining columns, ``feature_0``, ``feature_1``, ..., hold the
    values of that image's embedding.

    Args:
        input_image_paths_list: Images to encode, one entry per image, as
            handed over by the ``gr.File(file_count="multiple")`` input of
            the interface. ``None`` is treated as an empty list.

    Returns:
        A ``gr.DataFrame`` component (``type="numpy"``) carrying the column
        headers and the table values.
    """
    # NOTE(review): gr.File presumably delivers None when nothing was
    # uploaded -- confirm against the Gradio version in use. Without this
    # guard the iteration below raises TypeError on None.
    if input_image_paths_list is None:
        input_image_paths_list = []

    # NOTE(review): the CLIP model is loaded on every call; consider caching
    # it if per-request latency matters.
    search_engine_model = SearchEngineModel()
    model, preprocess = search_engine_model.load_clip_model()

    # encode_image() returns an object exposing ``.values``; its first row is
    # used as the embedding vector of the image.
    image_embeddings_list = [
        search_engine_model.encode_image(model, preprocess, image_path).values[0]
        for image_path in input_image_paths_list
    ]

    image_embeddings_df = pd.DataFrame(np.array(image_embeddings_list))
    # Count the embedding columns before the name column is prepended.
    num_features = len(image_embeddings_df.columns)
    image_embeddings_df.insert(0, "image_name", input_image_paths_list)

    return gr.DataFrame(
        type="numpy",
        headers=["image_name"] + [f"feature_{it}" for it in range(num_features)],
        value=image_embeddings_df.values,
    )
def main():
    """Build the Gradio interface and serve it on the configured address.

    The host and port are read from the ``[SERVER]`` section of
    ``./config.cfg`` (keys ``HOST_IP_ADDRESS`` and ``PORT_NUMBER``); the
    path is relative to the current working directory.

    Raises:
        FileNotFoundError: If ``./config.cfg`` could not be read.
        KeyError: If the ``SERVER`` section or one of its keys is missing.
        ValueError: If ``PORT_NUMBER`` is not a valid integer.
    """
    config_path = './config.cfg'
    config_manager_obj = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed, so an empty result means the
    # configuration is missing. Fail here with a clear message instead of a
    # later, opaque KeyError on 'SERVER'.
    if not config_manager_obj.read(config_path):
        raise FileNotFoundError(
            f"Configuration file not found or unreadable: {config_path}"
        )

    # Resolve the server settings before building the UI so that a broken
    # configuration is reported immediately.
    server_config = config_manager_obj['SERVER']
    host_ip_address = server_config['HOST_IP_ADDRESS']
    port_number = int(server_config['PORT_NUMBER'])

    main_app = gr.Interface(
        fn=get_image_embeddings,
        inputs=[
            gr.File(label="Upload Image", file_count="multiple"),
        ],
        outputs=[
            gr.Dataframe(type='numpy'),
        ],
        title="CLIP Image Embeddings",
        description="Obtain the embeddings of the input images",
        flagging_mode="never",
    )
    main_app.launch(server_name=host_ip_address, server_port=port_number)
# Start the app when this file is executed. NOTE: there is no
# ``if __name__ == "__main__":`` guard, so importing this module also
# calls main() and launches the server.
main()