import streamlit as st
import vec2text
import torch
from umap import UMAP
import plotly.express as px
import numpy as np
from streamlit_plotly_events import plotly_events
import utils
import pandas as pd
from scipy.spatial import distance
from resources import get_gtr_embeddings
from transformers import PreTrainedModel, PreTrainedTokenizer

# Name of the 2-D projection shown in the scatter plot; used for the axis
# labels and the chart title so the two always agree.
dimensionality_reduction_model_name = "PCA"

# Number of correction steps passed to vec2text.invert_embeddings.
_INVERSION_STEPS = 20


def diffs(
    embeddings: np.ndarray,
    corrector,
    encoder: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
) -> None:
    """Render the "sentence arithmetic" page.

    The user enters three sentences; on submit they are embedded with
    ``get_gtr_embeddings``, combined as ``v1 - v2 + v3`` and the resulting
    vector is inverted back to text with ``vec2text.invert_embeddings``.
    The inverted text is shown in a disabled "Sentence 4" input.

    Args:
        embeddings: Not used by this page; kept so the signature stays
            compatible with existing callers.
        corrector: Passed through to ``vec2text.invert_embeddings``.
        encoder: Model handed to ``get_gtr_embeddings``; its ``device``
            is where the vectors are placed.
        tokenizer: Tokenizer handed to ``get_gtr_embeddings``.
    """
    st.title('"A man is to king, what woman is to queen"')
    st.markdown(
        "A well known phenomenon in semantic vectors is the way we can do "
        "vector operations like addition and subtraction to find spacial "
        "relations in the vector space."
    )
    st.markdown(
        'In word embedding models, we have found that the relationship between words can be captured mathematically, '
        'such that "king" is to "man" as "queen" is to "woman," demonstrating that vector arithmetic can encode analogies and semantic relationships in high-dimensional space ([Mikolov et al., 2013](https://arxiv.org/abs/1301.3781)).'
    )
    st.markdown(
        "This application lets you freely explore to which extent that "
        "property applies to embedding inversion models given the other "
        "factors of inaccuracy"
    )

    generated_sentence = ""
    device = encoder.device

    with st.form(key="foo"):
        # The submit button is created first, so it is rendered above the
        # sentence inputs.
        submit_button = st.form_submit_button("Synthesize")

        sent1 = st.text_input("Sentence 1", value="I am a king")
        st.latex("-")
        sent2 = st.text_input("Sentence 2", value="I am a man")
        st.latex("+")
        sent3 = st.text_input("Sentence 3", value="I am a woman")
        st.latex("=")

        if submit_button:
            v1, v2, v3 = get_gtr_embeddings(
                [sent1, sent2, sent3], encoder, tokenizer, device=device
            ).to(device)
            v4 = v1 - v2 + v3
            # invert_embeddings is given a batch of one vector; the
            # single-element unpacking asserts exactly one result.
            (generated_sentence,) = vec2text.invert_embeddings(
                embeddings=v4.unsqueeze(0).to(device),
                corrector=corrector,
                num_steps=_INVERSION_STEPS,
            )
            generated_sentence = generated_sentence.strip()

        # Read-only output field; empty until the form has been submitted.
        st.text_input("Sentence 4", value=generated_sentence, disabled=True)

    # NOTE: icon attribution, currently not rendered:
    # st.html('Array icons created by Voysla - Flaticon')


def plot(
    df: pd.DataFrame,
    embeddings: np.ndarray,
    vectors_2d,
    reducer,
    corrector,
    device,
) -> None:
    """Render the interactive 2-D scatter plot and the inversion panel.

    The left column shows the scatter plot of ``vectors_2d`` plus a form
    with X/Y inputs (pre-filled from the last clicked point). On submit,
    the 2-D point is mapped back through ``reducer.inverse_transform`` and
    the result is inverted to text with ``vec2text.invert_embeddings``.
    The right column shows the title matching the chosen coordinates (as
    located by ``utils.find_exact_match``), the synthesized text and, when
    both are available, the Euclidean and cosine distances between the
    matched embedding and the inferred one.

    Args:
        df: Data frame with a ``"title"`` column, used for hover text and
            for the selected-text panel.
        embeddings: Indexed by the row index returned from
            ``utils.find_exact_match``.
        vectors_2d: 2-D array whose columns 0 and 1 are plotted as x and y.
        reducer: Object providing ``inverse_transform``.
        corrector: Passed through to ``vec2text.invert_embeddings``.
        device: Device the inferred embedding tensor is moved to.
    """
    model_name = dimensionality_reduction_model_name

    fig = px.scatter(
        x=vectors_2d[:, 0],
        y=vectors_2d[:, 1],
        opacity=0.6,
        hover_data={"Title": df["title"]},
        labels={
            "x": f"{model_name} Component 1",
            "y": f"{model_name} Component 2",
        },
        # Derived from the same constant as the axis labels so the title
        # cannot name a different projection than the axes do.
        title=f"{model_name} Scatter Plot of Reddit Titles",
        color_discrete_sequence=["#ff504c"],  # red accent colour for points
    )

    # No explicit template and fully transparent backgrounds.
    fig.update_layout(
        template=None,
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )

    # Defaults used until the user clicks a point / submits the form.
    x, y = 0.0, 0.0
    vec = np.array([x, y]).astype("float32")
    inferred_embedding = None
    inversion_output_text = None

    # Wider column for the plot and form, narrower one for the result card.
    col1, col2 = st.columns([0.6, 0.4])

    with col1:
        selected_points = plotly_events(fig, click_event=True, hover_event=False)

        with st.form(key="form1_main"):
            if selected_points:
                clicked_point = selected_points[0]
                # Coerce to float so the value always matches the float
                # format of the number inputs below.
                x = float(clicked_point["x"])
                y = float(clicked_point["y"])

            x = st.number_input("X Coordinate", value=x, format="%.10f")
            y = st.number_input("Y Coordinate", value=y, format="%.10f")
            vec = np.array([x, y]).astype("float32")

            submit_button = st.form_submit_button("Synthesize")

            if submit_button:
                inferred_embedding = reducer.inverse_transform(
                    np.array([[x, y]])
                ).astype("float32")
                # Batch of one embedding in, exactly one text out.
                (inversion_output_text,) = vec2text.invert_embeddings(
                    embeddings=torch.tensor(inferred_embedding).to(device),
                    corrector=corrector,
                    num_steps=_INVERSION_STEPS,
                )
            else:
                st.text("Click on a point in the scatterplot to see its coordinates.")

    with col2:
        closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
        # A negative index is treated as "no matching point"
        # (presumably find_exact_match returns -1 in that case; verify
        # against utils).
        has_match = closest_sentence_index > -1

        selected_sentence = (
            df["title"].iloc[closest_sentence_index] if has_match else None
        )
        selected_sentence_embedding = (
            embeddings[closest_sentence_index] if has_match else None
        )

        st.markdown(
            f"### Selected text:\n```console\n{selected_sentence}\n```"
        )
        st.markdown(
            f"### Synthesized text:\n```console\n{inversion_output_text}\n```"
        )

        if inferred_embedding is not None and has_match:
            couple = (
                selected_sentence_embedding.squeeze(),
                inferred_embedding.squeeze(),
            )
            st.markdown("### Inferred embedding distance:")
            st.number_input(
                "Euclidean", value=distance.euclidean(*couple), disabled=True
            )
            st.number_input(
                "Cosine", value=distance.cosine(*couple), disabled=True
            )