"""Streamlit playground: collect a few sentences and plot their embeddings.

The left column gathers up to ``LIST_SIZE`` sentences from a chat input; the
right column encodes them with a SentenceTransformer model, projects the
embeddings to two dimensions with MDS and renders an Altair scatter plot.
"""
import pandas as pd
import altair as alt
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from sentence_transformers import SentenceTransformer

st.set_page_config(layout='wide')

# constants
LIST_SIZE = 5  # maximum number of sentences kept at the same time


# cached resources
@st.cache_resource
def get_model() -> SentenceTransformer:
    """Return the sentence-embedding model (wrapped in ``st.cache_resource``)."""
    return SentenceTransformer('panalexeu/xlm-roberta-ua-distilled')


@st.cache_resource
def get_pca() -> PCA:
    """Return a 2-component PCA estimator (wrapped in ``st.cache_resource``)."""
    return PCA(n_components=2)


@st.cache_resource
def get_mds() -> MDS:
    """Return a 2-component MDS estimator (wrapped in ``st.cache_resource``)."""
    return MDS(n_components=2)


def _store_sentence(text: str) -> None:
    """Append *text*, or overwrite the slot under the cursor once the list is full."""
    sentences = st.session_state['sentences']
    if len(sentences) < LIST_SIZE:
        sentences.append(text)
        return

    # The list is full: rewrite previous entries in order, wrapping the
    # cursor back to the first slot after the last one.
    cursor = st.session_state['s_i']
    sentences[cursor] = text
    st.session_state['s_i'] = (cursor + 1) % LIST_SIZE


def _clear_sentences() -> None:
    """Drop every stored sentence and rewind the overwrite cursor."""
    st.session_state['sentences'] = []
    # Without this reset the cursor would keep its old position, so after a
    # refill the first overwrite would hit an arbitrary slot instead of slot 0.
    st.session_state['s_i'] = 0


# define state
if 'sentences' not in st.session_state:
    st.session_state['s_i'] = 0  # index of the next slot to overwrite
    st.session_state['sentences'] = []

# layout
st.title('Playground')
left_column, right_column = st.columns([1, 1], gap='large', border=True)

# sentences
with left_column:
    st.subheader('Sentences')
    sub_left, sub_right = st.columns(2, gap='small', vertical_alignment='bottom')

    # read input
    with sub_left:
        if text_input := st.chat_input(placeholder='What is on your mind?'):
            _store_sentence(text_input)

    with sub_right:
        if st.button(label='Clear', icon=':material/clear:', type='primary'):
            _clear_sentences()

    # display sentences
    for i, sentence in enumerate(st.session_state['sentences']):
        st.write(i, sentence)

# embeds display
with right_column:
    st.subheader('xlm-roberta-ua-distilled')

    if st.button('Compute', icon='🤖'):
        sentences = st.session_state['sentences']

        if not sentences:
            # An empty batch cannot be projected; tell the user instead of
            # letting the estimator raise on zero samples.
            st.warning('Add at least one sentence before computing.')
        else:
            # calculating embeddings
            embeds = get_model().encode(sentences)

            # projecting them
            proj_points = get_mds().fit_transform(embeds)

            # displaying projected embeddings
            points_df = pd.DataFrame(proj_points, columns=['x', 'y'])
            points_df['sentence'] = sentences

            chart = (
                alt.Chart(points_df)
                .mark_circle(size=100)
                .encode(x='x', y='y', tooltip='sentence')
                .properties(title='2-d Projection of Sentence Embeddings (MDS)')
                .interactive()
            )
            st.altair_chart(chart)