remiai3 commited on
Commit
631a798
·
verified ·
1 Parent(s): 164f469

Upload 2 files

Browse files
Files changed (2) hide show
  1. data/sample_sentences.txt +8 -0
  2. src/visualizer.py +41 -0
data/sample_sentences.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Artificial intelligence is transforming the world.
2
+ Cats are amazing pets.
3
+ The capital of France is Paris.
4
+ The Eiffel Tower is in France.
5
+ Deep learning enables image recognition.
6
+ Dogs are loyal companions.
7
+ The sun rises in the east.
8
+ The moon orbits the Earth.
src/visualizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sentence_transformers import SentenceTransformer
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.decomposition import PCA
5
+ from sklearn.manifold import TSNE
6
+
7
+ # Detect device
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ print(f"Using device: {device}")
10
+
11
+ # Load dataset
12
+ with open("../data/sample_sentences.txt", "r", encoding="utf-8") as f:
13
+ sentences = [line.strip() for line in f if line.strip()]
14
+
15
+ # Load embedding model
16
+ model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
17
+
18
+ # Create embeddings
19
+ embeddings = model.encode(sentences)
20
+
21
+ # PCA Visualization
22
+ pca = PCA(n_components=2)
23
+ pca_result = pca.fit_transform(embeddings)
24
+
25
+ plt.figure(figsize=(8,6))
26
+ plt.scatter(pca_result[:,0], pca_result[:,1])
27
+ for i, txt in enumerate(sentences):
28
+ plt.annotate(txt, (pca_result[i,0], pca_result[i,1]))
29
+ plt.title("Text Embeddings (PCA)")
30
+ plt.show()
31
+
32
+ # t-SNE Visualization
33
+ tsne = TSNE(n_components=2, random_state=42, perplexity=5)
34
+ tsne_result = tsne.fit_transform(embeddings)
35
+
36
+ plt.figure(figsize=(8,6))
37
+ plt.scatter(tsne_result[:,0], tsne_result[:,1])
38
+ for i, txt in enumerate(sentences):
39
+ plt.annotate(txt, (tsne_result[i,0], tsne_result[i,1]))
40
+ plt.title("Text Embeddings (t-SNE)")
41
+ plt.show()