File size: 1,928 Bytes
08614a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685680e
08614a1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import configparser
import functools

import gradio as gr
import numpy as np
import pandas as pd

from search_engine_model import SearchEngineModel

@functools.lru_cache(maxsize=1)
def _get_search_engine():
    """Create the search engine and load its CLIP model once per process.

    The original code rebuilt ``SearchEngineModel`` and reloaded the CLIP
    weights on every request. Caching the pair means only the first request
    pays that cost.

    Returns:
        A ``(search_engine_model, model)`` tuple shared by all requests.
    """
    search_engine_model = SearchEngineModel()
    # load_clip_model() returns a pair; only the first element (the model) is
    # used here.
    model, _ = search_engine_model.load_clip_model()
    return search_engine_model, model


def get_text_embeddings(text_prompt, input_np_array):
    """Encode a text prompt with CLIP and search it against image embeddings.

    This is the Gradio callback: its two return values fill the two output
    tables of the interface.

    Args:
        text_prompt: The text typed by the user.
        input_np_array: Table contents delivered by ``gr.Dataframe(type='numpy')``.
            In this app that is the image-name column followed by the feature
            columns (assumption: ``search_image_by_text_prompt`` expects that
            layout; TODO confirm).

    Returns:
        ``(text_embeddings, search_result)``: the encoded prompt and whatever
        ``SearchEngineModel.search_image_by_text_prompt`` returns for it.
    """
    search_engine_model, model = _get_search_engine()
    text_embeddings = search_engine_model.encode_text(model, text_prompt)

    # The numpy table carries no column labels, so pandas assigns 0..N-1.
    input_df = pd.DataFrame(input_np_array)
    search_result = search_engine_model.search_image_by_text_prompt(text_embeddings, input_df)

    return text_embeddings, search_result

def main():
    """Build the CLIP text-embedding Gradio demo and serve it.

    Reads ``HOST_IP_ADDRESS`` and ``PORT_NUMBER`` from the ``[SERVER]`` section
    of ``./config.cfg``. The path is resolved against the current working
    directory, not this file's location. The input table is seeded with 50
    random 512-dimensional placeholder feature vectors named
    ``image_<i>.png``, and then the app is launched.

    Raises:
        FileNotFoundError: If ``./config.cfg`` could not be read.
        KeyError: If the ``[SERVER]`` section or one of its keys is missing.
        ValueError: If ``PORT_NUMBER`` is not an integer.
    """
    config_path = './config.cfg'
    config_manager_obj = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files and returns the list
    # of files it parsed. Fail loudly here instead of with an obscure
    # KeyError('SERVER') later.
    if not config_manager_obj.read(config_path):
        raise FileNotFoundError(f"Could not read server configuration from {config_path!r}")

    # Resolve the server settings before building the UI so that a bad config
    # fails fast.
    server_config = config_manager_obj['SERVER']
    host_ip_address = server_config['HOST_IP_ADDRESS']
    port_number = int(server_config['PORT_NUMBER'])

    # Placeholder embeddings shown in (and submitted from) the input table.
    random_features = np.random.rand(50, 512)
    feature_headers = [f'feature_{it}' for it in range(random_features.shape[1])]

    initial_dataframe = pd.DataFrame(random_features)
    names_column = [f'image_{it}.png' for it in range(len(random_features))]
    # Use the same label that the table header displays.
    initial_dataframe.insert(0, 'image_name', names_column)

    main_app = gr.Interface(
        fn=get_text_embeddings,
        inputs=[
            gr.Textbox(),
            gr.Dataframe(
                initial_dataframe.values,
                headers=["image_name"] + feature_headers,
                type='numpy',
                interactive=False,
            ),
        ],
        outputs=[
            gr.Dataframe(type='numpy', headers=feature_headers),
            gr.Dataframe(type='numpy', headers=['image_name', 'similarity']),
        ],
        title="CLIP Text Embeddings",
        description="Obtain the embeddings of a given text and use the API to compare with a set of images' embeddings.",
        flagging_mode="never",
    )

    main_app.launch(server_name=host_ip_address, server_port=port_number, show_error=True)

# Launch the server only when run as a script, so that importing this module
# (tests, tooling) has no side effects.
if __name__ == "__main__":
    main()