Spaces:
Runtime error
Runtime error
File size: 2,496 Bytes
fa8fd97 b1ce51e a9b1577 e4d63c4 6ac9a13 b1ce51e 6ac9a13 4919963 6ac9a13 4919963 6ac9a13 a9b1577 e4d63c4 4919963 26310fb 879dab1 6ac9a13 4919963 6ac9a13 4919963 e4d63c4 4919963 26310fb 4919963 fa8fd97 e4d63c4 26310fb 6ac9a13 fa8fd97 6ac9a13 fa8fd97 e4d63c4 fa8fd97 e4d63c4 fa8fd97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import pandas as pd
import altair as alt
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from sentence_transformers import SentenceTransformer
# Page-wide settings, applied before any element is rendered: use the full browser width.
st.set_page_config(layout='wide')

# constants
# Maximum number of sentences held in session state. Once the list is full,
# new input overwrites existing entries in round-robin order.
LIST_SIZE: int = 5
# cached resources
@st.cache_resource
def get_model() -> SentenceTransformer:
    """Return the ``panalexeu/xlm-roberta-ua-distilled`` SentenceTransformer.

    Decorated with ``st.cache_resource`` so the model is constructed once and
    the same object is reused on later script reruns.
    """
    checkpoint = 'panalexeu/xlm-roberta-ua-distilled'
    encoder = SentenceTransformer(checkpoint)
    return encoder
@st.cache_resource
def get_pca() -> PCA:
    """Return a cached PCA instance configured for a 2-component projection.

    NOTE(review): not called anywhere in the visible script (the chart is
    built with get_mds()); confirm whether it is still needed.
    """
    projector = PCA(n_components=2)
    return projector
@st.cache_resource
def get_mds() -> MDS:
    """Return a cached MDS estimator that projects points into 2 dimensions.

    NOTE(review): st.cache_resource hands this same instance to every rerun
    and session, and scikit-learn estimators store fitted attributes on the
    instance during fit_transform. Confirm that concurrent sessions sharing
    it is acceptable.
    """
    estimator = MDS(n_components=2)
    return estimator
# define state
# Runs only on a session's first script run. `s_i` is the next slot to
# overwrite once the list is full; `sentences` holds the entered texts.
if 'sentences' not in st.session_state:
    for key, initial in (('s_i', 0), ('sentences', [])):
        st.session_state[key] = initial
# layout
st.title('Playground')
# Two equally weighted, bordered columns: the sentence list on the left,
# the embedding projection on the right.
column_weights = [1, 1]
left_column, right_column = st.columns(column_weights, gap='large', border=True)
# sentences
# Left column: collect up to LIST_SIZE sentences in session state, offer a
# Clear button, and list what is currently stored.
with left_column:
    st.subheader('Sentences')
    sub_left, sub_right = st.columns(2, gap='small', vertical_alignment='bottom')

    # read input
    with sub_left:
        if text_input := st.chat_input(placeholder='What is on your mind?'):
            sentences = st.session_state['sentences']
            if len(sentences) < LIST_SIZE:
                # Still filling up: append in arrival order.
                sentences.append(text_input)
            else:
                # Full: overwrite the slot under the cursor, then advance it
                # with wrap-around so entries are replaced round-robin.
                sentences[st.session_state['s_i']] = text_input
                st.session_state['s_i'] = (st.session_state['s_i'] + 1) % LIST_SIZE

    with sub_right:
        if clear_btn := st.button(label='Clear', icon=':material/clear:', type='primary'):
            st.session_state['sentences'] = []
            # Reset the overwrite cursor as well. Otherwise, once the list
            # refills, the first overwrite hits a stale slot instead of
            # slot 0 (the oldest entry).
            st.session_state['s_i'] = 0

    # display sentences
    for i, sentence in enumerate(st.session_state['sentences']):
        st.write(i, sentence)
# embeds display
# Right column: on "Compute", embed the stored sentences, project them to 2-D
# with MDS, and plot them as an interactive Altair scatter chart.
with right_column:
    st.subheader('xlm-roberta-ua-distilled')
    if compute_btn := st.button('Compute', icon='🤖'):
        sentences = st.session_state['sentences']
        # MDS works on pairwise distances. An empty list makes fit_transform
        # raise (0 samples), and a single point gives a degenerate layout,
        # so ask for more input instead of crashing the app.
        if len(sentences) < 2:
            st.warning('Add at least two sentences before computing the projection.')
        else:
            # calculating embeddings
            embeds = get_model().encode(sentences)
            # projecting them to 2-D
            proj_points = get_mds().fit_transform(embeds)
            # displaying projected embeddings, one row per sentence
            points_df = pd.DataFrame(proj_points, columns=['x', 'y'])
            points_df['sentence'] = sentences
            chart = (
                alt.Chart(points_df)
                .mark_circle(size=100)
                .encode(x='x', y='y', tooltip='sentence')
                .properties(title='2-d Projection of Sentence Embeddings (MDS)')
                .interactive()
            )
            st.altair_chart(chart)
|