Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import altair as alt | |
| import streamlit as st | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import MDS | |
| from sentence_transformers import SentenceTransformer | |
# Render the app in Streamlit's wide page layout.
st.set_page_config(
    layout='wide',
)

# constants
# Maximum number of sentences kept in the session; once the list is full,
# new inputs overwrite existing entries instead of growing it.
LIST_SIZE: int = 5
| # cached resources | |
@st.cache_resource
def get_model() -> SentenceTransformer:
    """Return the sentence-embedding model used by the playground.

    The model is wrapped in ``st.cache_resource`` so it is constructed once
    and then reused across script reruns, instead of being rebuilt every time
    the "Compute" button triggers a rerun.

    Returns:
        A ``SentenceTransformer`` built from the
        ``panalexeu/xlm-roberta-ua-distilled`` checkpoint name.
    """
    return SentenceTransformer('panalexeu/xlm-roberta-ua-distilled')
def get_pca() -> PCA:
    """Build a PCA reducer configured for two output components.

    A new, unfitted estimator is created on every call.
    """
    reducer = PCA(n_components=2)
    return reducer
def get_mds() -> MDS:
    """Build an MDS reducer configured for two output components.

    A new, unfitted estimator is created on every call.
    """
    reducer = MDS(n_components=2)
    return reducer
# define state
# When the session has no sentence list yet, start with an empty list and
# point the overwrite cursor ('s_i') at the first slot.
if 'sentences' not in st.session_state:
    st.session_state['sentences'] = []
    st.session_state['s_i'] = 0
# layout
st.title('Playground')

# Two equally weighted, bordered panels: the left one collects sentences,
# the right one shows their projected embeddings.
left_column, right_column = st.columns(
    [1, 1],
    gap='large',
    border=True,
)
# sentences
with left_column:
    st.subheader('Sentences')
    sub_left, sub_right = st.columns(2, gap='small', vertical_alignment='bottom')

    # read input
    with sub_left:
        if text_input := st.chat_input(placeholder='What is on your mind?'):
            sentences = st.session_state['sentences']
            if len(sentences) < LIST_SIZE:
                # Still room: grow the list.
                sentences.append(text_input)
            else:
                # List is full: overwrite the slot under the cursor, then
                # advance the cursor, wrapping back to 0 after the last slot.
                slot = st.session_state['s_i']
                sentences[slot] = text_input
                st.session_state['s_i'] = (slot + 1) % LIST_SIZE

    with sub_right:
        if st.button(label='Clear', icon=':material/clear:', type='primary'):
            st.session_state['sentences'] = []
            # Reset the overwrite cursor together with the list; otherwise a
            # stale cursor makes later overwrites start at an arbitrary slot
            # instead of the oldest entry.
            st.session_state['s_i'] = 0

    # display sentences
    for i, sentence in enumerate(st.session_state['sentences']):
        st.write(i, sentence)
# embeds display
with right_column:
    st.subheader('xlm-roberta-ua-distilled')

    # The icon is the rocket emoji, written as a named escape so the source
    # stays ASCII and cannot be corrupted by a wrong file encoding.
    if st.button('Compute', icon='\N{ROCKET}'):
        sentences = st.session_state['sentences']

        if not sentences:
            # Nothing to embed: the projection step cannot run on empty
            # input, so tell the user instead of raising.
            st.info('Add at least one sentence before computing.')
        else:
            # calculating embeddings
            embeds = get_model().encode(sentences)

            # projecting them down to two dimensions
            proj_points = get_mds().fit_transform(embeds)

            # displaying projected embeddings; each point carries its
            # sentence so it can be shown as a tooltip
            points_df = pd.DataFrame(proj_points, columns=['x', 'y'])
            points_df['sentence'] = sentences

            chart = (
                alt.Chart(points_df)
                .mark_circle(size=100)
                .encode(x='x', y='y', tooltip='sentence')
                .properties(title='2-d Projection of Sentence Embeddings (MDS)')
                .interactive()
            )
            st.altair_chart(chart)