File size: 2,496 Bytes
fa8fd97
 
b1ce51e
a9b1577
e4d63c4
6ac9a13
b1ce51e
6ac9a13
4919963
6ac9a13
4919963
 
6ac9a13
 
 
 
 
 
 
a9b1577
 
 
 
 
e4d63c4
 
 
 
 
4919963
 
 
 
 
26310fb
879dab1
 
6ac9a13
4919963
6ac9a13
4919963
 
 
 
 
 
 
e4d63c4
4919963
 
 
 
 
 
 
26310fb
4919963
 
 
 
 
 
 
 
 
 
fa8fd97
 
e4d63c4
26310fb
6ac9a13
fa8fd97
6ac9a13
fa8fd97
 
e4d63c4
fa8fd97
 
 
 
 
 
 
e4d63c4
fa8fd97
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pandas as pd
import altair as alt
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from sentence_transformers import SentenceTransformer

# page setup: render the app with the wide layout
st.set_page_config(
    layout='wide',
)

# constants
# Upper bound on how many sentences the session's list holds; once the list
# reaches this length, new input overwrites existing entries instead.
LIST_SIZE: int = 5


# cached resources
@st.cache_resource
def get_model() -> SentenceTransformer:
    """Create the sentence-embedding model used by the playground.

    Decorated with ``st.cache_resource`` so the constructed model is reused
    instead of being rebuilt on every script rerun.

    Returns:
        A ``SentenceTransformer`` built from the
        ``panalexeu/xlm-roberta-ua-distilled`` identifier.
    """
    model_name = 'panalexeu/xlm-roberta-ua-distilled'
    return SentenceTransformer(model_name)


@st.cache_resource
def get_pca() -> PCA:
    """Create a two-component PCA projector.

    Decorated with ``st.cache_resource`` so the same instance is handed back
    on later calls. It is not called anywhere in the visible part of this
    script.

    Returns:
        A ``PCA`` estimator configured with ``n_components=2``.
    """
    projector = PCA(n_components=2)
    return projector


@st.cache_resource
def get_mds() -> MDS:
    """Create the two-component MDS projector used to lay out embeddings.

    Decorated with ``st.cache_resource`` so the same instance is handed back
    on later calls.

    Returns:
        An ``MDS`` estimator configured with ``n_components=2``.
    """
    projector = MDS(n_components=2)
    return projector


# define state
# Seed the per-session values on the first run: the list of entered sentences
# and 's_i', the index of the slot that is overwritten once the list is full.
if 'sentences' not in st.session_state:
    st.session_state['sentences'] = []
    st.session_state['s_i'] = 0

# layout
st.title('Playground')

# Two equally weighted, bordered columns: sentences on the left, chart on the right.
left_column, right_column = st.columns(
    [1, 1],
    gap='large',
    border=True,
)

# sentences
# Left column: collect sentences from the chat input (at most LIST_SIZE are
# kept), offer a button that clears them, and list what is currently stored.
with left_column:
    st.subheader('Sentences')

    sub_left, sub_right = st.columns(2, gap='small', vertical_alignment='bottom')

    # read input
    with sub_left:
        if text_input := st.chat_input(placeholder='What is on your mind?'):
            sentences = st.session_state['sentences']

            if len(sentences) < LIST_SIZE:
                sentences.append(text_input)
            else:
                # The list is full: overwrite the slot under the cursor, then
                # advance the cursor, wrapping back to 0 after the last slot.
                cursor = st.session_state['s_i']
                sentences[cursor] = text_input
                st.session_state['s_i'] = (cursor + 1) % LIST_SIZE

    with sub_right:
        if st.button(label='Clear', icon=':material/clear:', type='primary'):
            st.session_state['sentences'] = []
            # Reset the overwrite cursor together with the list. Otherwise,
            # once the list fills up again, overwriting would resume at the
            # stale position left over from before the clear instead of at
            # slot 0.
            st.session_state['s_i'] = 0

    # display sentences
    for i, sentence in enumerate(st.session_state['sentences']):
        st.write(i, sentence)

# embeds display
# Right column: on demand, embed the stored sentences, project the embeddings
# to two dimensions with MDS, and draw them as an interactive scatter plot.
with right_column:
    st.subheader('xlm-roberta-ua-distilled')

    if st.button('Compute', icon='🤖'):
        sentences = st.session_state['sentences']

        if not sentences:
            # With no sentences there is nothing to encode and the projection
            # has no samples to fit, so tell the user instead of failing.
            st.warning('Add at least one sentence before computing embeddings.')
        else:
            # calculating embeddings
            embeds = get_model().encode(sentences)

            # projecting them
            proj_points = get_mds().fit_transform(embeds)

            # displaying projected embeddings
            points_df = pd.DataFrame(proj_points, columns=['x', 'y'])
            points_df['sentence'] = sentences  # shown as the tooltip of each point
            chart = (
                alt.Chart(points_df)
                .mark_circle(size=100)
                .encode(x='x', y='y', tooltip='sentence')
                .properties(title='2-d Projection of Sentence Embeddings (MDS)')
                .interactive()
            )

            st.altair_chart(chart)