Oleksii Horhcynskyi
refactoring, minor bug fix in list displaying
26310fb
raw
history blame
2.5 kB
import pandas as pd
import altair as alt
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from sentence_transformers import SentenceTransformer
# Use the wide page layout so the two side-by-side panels created below have room.
st.set_page_config(layout='wide')
# constants
# Maximum number of sentences kept in session state; once the list is this long,
# new inputs overwrite existing entries instead of being appended.
LIST_SIZE = 5
# cached resources
@st.cache_resource
def get_model() -> SentenceTransformer:
    """Build the sentence-embedding model used by the playground.

    Wrapped in ``st.cache_resource`` so the model object is created once and
    reused instead of being rebuilt on every script rerun.
    """
    checkpoint = 'panalexeu/xlm-roberta-ua-distilled'
    encoder = SentenceTransformer(checkpoint)
    return encoder
@st.cache_resource
def get_pca() -> PCA:
    """Return a cached PCA estimator configured for a 2-component projection.

    NOTE(review): nothing in the visible script calls this; the chart is built
    from the MDS projection instead.
    """
    reducer = PCA(n_components=2)
    return reducer
@st.cache_resource
def get_mds() -> MDS:
    """Return a cached MDS estimator configured for a 2-component embedding.

    The embeddings panel calls ``fit_transform`` on this estimator to place
    each sentence embedding on a 2-d scatter plot.
    """
    scaler = MDS(n_components=2)
    return scaler
# define state
# Seed the session state the first time 'sentences' is missing:
#   's_i'       -> index of the slot overwritten next once the list is full
#   'sentences' -> the sentences entered so far
if 'sentences' not in st.session_state:
    for state_key, initial_value in (('s_i', 0), ('sentences', [])):
        st.session_state[state_key] = initial_value
# layout
st.title('Playground')
# Two equally weighted, bordered panes: sentence entry goes in the left one,
# the embedding projection in the right one.
left_column, right_column = st.columns(
    spec=[1, 1],
    gap='large',
    border=True,
)
# sentences
# Left pane: collect up to LIST_SIZE sentences, allow clearing them, and list them.
with left_column:
    st.subheader('Sentences')
    sub_left, sub_right = st.columns(2, gap='small', vertical_alignment='bottom')

    # read input
    with sub_left:
        if text_input := st.chat_input(placeholder='What is on your mind?'):
            sentences = st.session_state['sentences']
            if len(sentences) < LIST_SIZE:
                # Still room: simply append (mutates the list held in session state).
                sentences.append(text_input)
            else:
                # List is full: overwrite the slot under the cursor, then advance
                # the cursor, wrapping back to 0 after the last slot.
                cursor = st.session_state['s_i']
                sentences[cursor] = text_input
                st.session_state['s_i'] = (cursor + 1) % LIST_SIZE

    with sub_right:
        if st.button(label='Clear', icon=':material/clear:', type='primary'):
            st.session_state['sentences'] = []
            # Reset the overwrite cursor together with the list; otherwise a stale
            # cursor makes the next overwrite cycle start in the middle of the
            # list instead of at the oldest entry.
            st.session_state['s_i'] = 0

    # display sentences (after input/clear handling so the list shown is current)
    for i, sentence in enumerate(st.session_state['sentences']):
        st.write(i, sentence)
# embeds display
# Right pane: embed the collected sentences, project them to 2-d and plot them.
with right_column:
    st.subheader('xlm-roberta-ua-distilled')
    # The icon must be a real emoji (or a ':material/...:' name); the previous
    # value was the mis-decoded byte sequence of the robot emoji.
    if st.button('Compute', icon='🤖'):
        sentences = st.session_state['sentences']
        if not sentences:
            # Nothing to embed: projecting an empty batch would fail, so tell the
            # user what to do instead of surfacing a traceback.
            st.warning('Add at least one sentence before computing embeddings.')
        else:
            # calculating embeddings
            embeds = get_model().encode(sentences)

            # projecting them
            proj_points = get_mds().fit_transform(embeds)

            # displaying projected embeddings; each point carries its sentence
            # so it can be shown as a tooltip
            points_df = pd.DataFrame(proj_points, columns=['x', 'y'])
            points_df['sentence'] = sentences

            chart = (
                alt.Chart(points_df)
                .mark_circle(size=100)
                .encode(x='x', y='y', tooltip='sentence')
                .properties(title='2-d Projection of Sentence Embeddings (MDS)')
                .interactive()
            )
            st.altair_chart(chart)