Metin committed on
Commit
d97a439
·
1 Parent(s): 39daa82

Initial commit

Browse files
Files changed (36) hide show
  1. input/test_data_with_embeddings.parquet +3 -0
  2. input/train_data_with_embeddings.parquet +3 -0
  3. input/val_data_with_embeddings.parquet +3 -0
  4. models/embedding/gte-multilingual-base/.gitattributes +36 -0
  5. models/embedding/gte-multilingual-base/1_Pooling/config.json +7 -0
  6. models/embedding/gte-multilingual-base/README.md +0 -0
  7. models/embedding/gte-multilingual-base/config.json +43 -0
  8. models/embedding/gte-multilingual-base/model.safetensors +3 -0
  9. models/embedding/gte-multilingual-base/modules.json +20 -0
  10. models/embedding/gte-multilingual-base/scripts/gte_embedding.py +154 -0
  11. models/embedding/gte-multilingual-base/sentence_bert_config.json +4 -0
  12. models/embedding/gte-multilingual-base/special_tokens_map.json +51 -0
  13. models/embedding/gte-multilingual-base/tokenizer.json +3 -0
  14. models/embedding/gte-multilingual-base/tokenizer_config.json +54 -0
  15. models/no_edge_gnn/gnn_classifier_model.pth +3 -0
  16. models/no_edge_gnn/gnn_graph_data.pt +3 -0
  17. models/no_edge_gnn/label_mapping.pt +3 -0
  18. models/no_edge_gnn/title_to_id.pt +3 -0
  19. models/undirected_gnn/gnn_classifier_model.pth +3 -0
  20. models/undirected_gnn/gnn_graph_data.pt +3 -0
  21. models/undirected_gnn/label_mapping.pt +3 -0
  22. models/undirected_gnn/title_to_id.pt +3 -0
  23. src/__pycache__/config.cpython-311.pyc +0 -0
  24. src/__pycache__/embedding.cpython-311.pyc +0 -0
  25. src/__pycache__/gnn.cpython-311.pyc +0 -0
  26. src/__pycache__/heuristic.cpython-311.pyc +0 -0
  27. src/__pycache__/utils.cpython-311.pyc +0 -0
  28. src/__pycache__/visualization.cpython-311.pyc +0 -0
  29. src/config.py +115 -0
  30. src/demo.py +265 -0
  31. src/embedding.py +31 -0
  32. src/gnn.py +164 -0
  33. src/heuristic.py +87 -0
  34. src/streamlit_app.py +0 -40
  35. src/utils.py +140 -0
  36. src/visualization.py +23 -0
input/test_data_with_embeddings.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:039cbc8a7b595f3e10a29e496bd3c44eeba116cb7c6977c86277f05a398e9dcf
3
+ size 6799466
input/train_data_with_embeddings.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e6b545f46a7ab93b8b11092b2bee5ce78973ea9c91d59b8d3e7ff88a7f5beb6
3
+ size 121744699
input/val_data_with_embeddings.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4888805e8ca378c6040a92334ca68c04449d667907c3b758e53b8ef312616fe8
3
+ size 6695324
models/embedding/gte-multilingual-base/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
models/embedding/gte-multilingual-base/1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": true,
4
+ "pooling_mode_mean_tokens": false,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
models/embedding/gte-multilingual-base/README.md ADDED
The diff for this file is too large to render. See raw diff
 
models/embedding/gte-multilingual-base/config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "NewModel",
4
+ "NewForTokenClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "Alibaba-NLP/new-impl--configuration.NewConfig",
9
+ "AutoModelForMaskedLM": "Alibaba-NLP/new-impl--modeling.NewForMaskedLM",
10
+ "AutoModel": "Alibaba-NLP/new-impl--modeling.NewModel",
11
+ "AutoModelForMultipleChoice": "Alibaba-NLP/new-impl--modeling.NewForMultipleChoice",
12
+ "AutoModelForQuestionAnswering": "Alibaba-NLP/new-impl--modeling.NewForQuestionAnswering",
13
+ "AutoModelForSequenceClassification": "Alibaba-NLP/new-impl--modeling.NewForSequenceClassification",
14
+ "AutoModelForTokenClassification": "Alibaba-NLP/new-impl--modeling.NewForTokenClassification"
15
+ },
16
+ "classifier_dropout": 0.0,
17
+ "hidden_act": "gelu",
18
+ "hidden_dropout_prob": 0.1,
19
+ "hidden_size": 768,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 3072,
22
+ "layer_norm_eps": 1e-12,
23
+ "layer_norm_type": "layer_norm",
24
+ "max_position_embeddings": 8192,
25
+ "model_type": "new",
26
+ "num_attention_heads": 12,
27
+ "num_hidden_layers": 12,
28
+ "num_labels": 1,
29
+ "pack_qkv": true,
30
+ "pad_token_id": 1,
31
+ "position_embedding_type": "rope",
32
+ "rope_scaling": {
33
+ "factor": 8.0,
34
+ "type": "ntk"
35
+ },
36
+ "rope_theta": 20000,
37
+ "torch_dtype": "float16",
38
+ "transformers_version": "4.39.1",
39
+ "type_vocab_size": 1,
40
+ "unpad_inputs": false,
41
+ "use_memory_efficient_attention": false,
42
+ "vocab_size": 250048
43
+ }
models/embedding/gte-multilingual-base/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5a35a10faa54da7717870af1517c9b41e9bd8e3880bc5a8e9363d4c3c63e9b0
3
+ size 610753338
models/embedding/gte-multilingual-base/modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
models/embedding/gte-multilingual-base/scripts/gte_embedding.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+
5
+ from collections import defaultdict
6
+ from typing import Dict, List, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
11
+ from transformers.utils import is_torch_npu_available
12
+
13
+
14
class GTEEmbeddidng(torch.nn.Module):
    """Dense and sparse text encoder wrapping a token-classification checkpoint.

    Dense vectors are the hidden state of the first token of each text
    (optionally truncated and L2-normalised). Sparse vectors map each decoded
    token to its ReLU-ed logit weight.

    NOTE: the class name is misspelt ("Embeddidng"); it is kept as-is because
    it is the public name callers import.
    """

    def __init__(self,
                 model_name: str = None,
                 normalized: bool = True,
                 use_fp16: bool = True,
                 device: str = None
                 ):
        """Load the tokenizer and model and move the model to the target device.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            normalized: L2-normalise dense embeddings when True.
            use_fp16: Load weights as float16. Ignored (forced to False) when
                the resolved device is the CPU.
            device: Explicit torch device string. When falsy, the first
                available of cuda / mps / npu is used, falling back to cpu.
        """
        super().__init__()
        self.normalized = normalized
        if device:
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
        # Half precision is switched off on CPU regardless of how the device
        # was chosen; previously an explicit device="cpu" kept fp16 enabled.
        if self.device.type == "cpu":
            use_fp16 = False
        self.use_fp16 = use_fp16
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): trust_remote_code=True runs Python shipped with the
        # checkpoint; only point model_name at sources you trust.
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None
        )
        self.vocab_size = self.model.config.vocab_size
        self.model.to(self.device)

    def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
        """Convert per-position weights into a ``{token_text: max_weight}`` dict.

        Special tokens (cls, eos, pad, unk) and non-positive weights are
        skipped. When a token occurs several times its largest weight wins.
        """
        result = defaultdict(int)
        unused_tokens = {
            self.tokenizer.cls_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
            self.tokenizer.unk_token_id,
        }
        for w, idx in zip(token_weights, input_ids):
            if idx not in unused_tokens and w > 0:
                token = self.tokenizer.decode([int(idx)])
                if w > result[token]:
                    result[token] = w
        return result

    @torch.no_grad()
    def encode(self,
               texts: List[str],
               dimension: int = None,
               max_length: int = 8192,
               batch_size: int = 16,
               return_dense: bool = True,
               return_sparse: bool = False):
        """Encode texts batch by batch.

        Args:
            texts: A list of strings; a single string is also accepted and is
                treated as a one-element list.
            dimension: Number of leading hidden dimensions to keep. Defaults to
                the model's ``hidden_size``.
            max_length: Tokenizer truncation length.
            batch_size: Number of texts per forward pass.
            return_dense: Compute dense embeddings.
            return_sparse: Compute per-token sparse weights.

        Returns:
            A dict with ``"dense_embeddings"`` (one tensor concatenated over
            all batches, or ``None`` when no dense vectors were produced) and
            ``"token_weights"`` (a list with one dict per text, empty when
            ``return_sparse`` is False).
        """
        if dimension is None:
            dimension = self.model.config.hidden_size
        if isinstance(texts, str):
            texts = [texts]
        num_texts = len(texts)
        all_dense_vecs = []
        all_token_weights = []
        for start in range(0, num_texts, batch_size):
            batch = texts[start: start + batch_size]
            results = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse)
            if return_dense:
                all_dense_vecs.append(results['dense_embeddings'])
            if return_sparse:
                all_token_weights.extend(results['token_weights'])
        # torch.cat raises on an empty list, which used to crash sparse-only
        # calls (return_dense=False) and calls with no texts.
        dense = torch.cat(all_dense_vecs, dim=0) if all_dense_vecs else None
        return {
            "dense_embeddings": dense,
            "token_weights": all_token_weights
        }

    @torch.no_grad()
    def _encode(self,
                texts: List[str] = None,
                dimension: int = None,
                max_length: int = 1024,
                batch_size: int = 16,
                return_dense: bool = True,
                return_sparse: bool = False):
        """Run one tokenised batch through the model.

        ``batch_size`` is accepted for signature symmetry with ``encode`` but
        is not used here. Returns a dict that contains ``"dense_embeddings"``
        and/or ``"token_weights"`` depending on the two flags.
        """
        text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
        text_input = {k: v.to(self.model.device) for k, v in text_input.items()}
        model_out = self.model(**text_input, return_dict=True)

        output = {}
        if return_dense:
            # First-token hidden state, truncated to `dimension` features.
            dense_vecs = model_out.last_hidden_state[:, 0, :dimension]
            if self.normalized:
                dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
            output['dense_embeddings'] = dense_vecs
        if return_sparse:
            token_weights = torch.relu(model_out.logits).squeeze(-1)
            token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
                                     text_input['input_ids'].cpu().numpy().tolist()))
            output['token_weights'] = token_weights

        return output

    def _compute_sparse_scores(self, embs1, embs2):
        """Sum ``weight1 * weight2`` over the tokens present in both dicts."""
        scores = 0
        for token, weight in embs1.items():
            if token in embs2:
                scores += weight * embs2[token]
        return scores

    def compute_sparse_scores(self, embs1, embs2):
        """Pairwise sparse scores for two equally ordered lists of token dicts."""
        scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
        return np.array(scores)

    def compute_dense_scores(self, embs1, embs2):
        """Row-wise dot product of two dense embedding tensors as a NumPy array."""
        scores = torch.sum(embs1 * embs2, dim=-1).cpu().detach().numpy()
        return scores

    @torch.no_grad()
    def compute_scores(self,
                       text_pairs: List[Tuple[str, str]],
                       dimension: int = None,
                       max_length: int = 1024,
                       batch_size: int = 16,
                       dense_weight=1.0,
                       sparse_weight=0.1):
        """Score text pairs as ``dense * dense_weight + sparse * sparse_weight``.

        Returns a plain Python list with one score per pair.
        """
        text1_list = [text_pair[0] for text_pair in text_pairs]
        text2_list = [text_pair[1] for text_pair in text_pairs]
        embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings']) * dense_weight + \
            self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
        scores = scores.tolist()
        return scores
142
+
143
+
144
if __name__ == '__main__':
    # Manual smoke run: encode three sample sentences and print both the
    # dense and the sparse representation.
    sample_model = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
    sample_docs = [
        "黑龙江离俄罗斯很近",
        "哈尔滨是中国黑龙江省的省会,位于中国东北",
        "you are the hero",
    ]
    print('docs', sample_docs)
    encoded = sample_model.encode(sample_docs, return_dense=True, return_sparse=True)
    for label, key in (('dense vecs', 'dense_embeddings'), ('sparse vecs', 'token_weights')):
        print(label, encoded[key])
models/embedding/gte-multilingual-base/sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 8192,
3
+ "do_lower_case": false
4
+ }
models/embedding/gte-multilingual-base/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
models/embedding/gte-multilingual-base/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f59925fcb90c92b894cb93e51bb9b4a6105c5c249fe54ce1c704420ac39b81af
3
+ size 17082756
models/embedding/gte-multilingual-base/tokenizer_config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 32768,
50
+ "pad_token": "<pad>",
51
+ "sep_token": "</s>",
52
+ "tokenizer_class": "XLMRobertaTokenizer",
53
+ "unk_token": "<unk>"
54
+ }
models/no_edge_gnn/gnn_classifier_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7c60400474e5da38560a2146fa9dc997a949c293816f14095120cb3e863cfeb
3
+ size 417848
models/no_edge_gnn/gnn_graph_data.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdce3a0f4ddd0abf8293ae2ddfab06bb4ece1ac0e55d113bc8c28dbad4d77cb8
3
+ size 30912782
models/no_edge_gnn/label_mapping.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd27285c0b3d71f944f4fa2d8096f4a8ddc49c34d8728ccb2a7a82bceccd99ff
3
+ size 1720
models/no_edge_gnn/title_to_id.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c66e0124c82ab11fcc6db287aa5898d8e4745d1a2467b1c5f1d3779729420e3b
3
+ size 281136
models/undirected_gnn/gnn_classifier_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a3756f5565bf83a4ed03713814c60a52fa5f94140b2e0dd493035e1ba955b13
3
+ size 417720
models/undirected_gnn/gnn_graph_data.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f0ebba135027659290f04a40ea2f46bbf7ec219a12331ec78030b6da29633a7
3
+ size 31942670
models/undirected_gnn/label_mapping.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd27285c0b3d71f944f4fa2d8096f4a8ddc49c34d8728ccb2a7a82bceccd99ff
3
+ size 1720
models/undirected_gnn/title_to_id.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c66e0124c82ab11fcc6db287aa5898d8e4745d1a2467b1c5f1d3779729420e3b
3
+ size 281136
src/__pycache__/config.cpython-311.pyc ADDED
Binary file (4.22 kB). View file
 
src/__pycache__/embedding.cpython-311.pyc ADDED
Binary file (2.87 kB). View file
 
src/__pycache__/gnn.cpython-311.pyc ADDED
Binary file (8.88 kB). View file
 
src/__pycache__/heuristic.cpython-311.pyc ADDED
Binary file (3.3 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.98 kB). View file
 
src/__pycache__/visualization.cpython-311.pyc ADDED
Binary file (1.32 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path

from pydantic_settings import BaseSettings
2
+
3
+
4
# Repository root. This module lives in <root>/src/, so the root is two
# levels above this file. Deriving every path from it keeps the project
# runnable from any checkout instead of one developer's Windows desktop.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_UNDIRECTED_GNN_DIR = _PROJECT_ROOT / "models" / "undirected_gnn"


class Config(BaseSettings):
    """Paths to model/data artifacts and per-topic display settings.

    All paths default to locations inside the repository (``models/`` and
    ``input/``) and are stored as plain strings.
    NOTE(review): as BaseSettings fields these are presumably overridable via
    environment variables of the same name — confirm against pydantic-settings.
    """

    EMBEDDING_MODEL_PATH: str = str(_PROJECT_ROOT / "models" / "embedding" / "gte-multilingual-base")
    TRAINING_DATA_PATH: str = str(_PROJECT_ROOT / "input" / "train_data_with_embeddings.parquet")

    # Artifacts of the "undirected" GNN variant.
    GNN_MODEL_PATH: str = str(_UNDIRECTED_GNN_DIR / "gnn_classifier_model.pth")
    GNN_GRAPH_DATA_PATH: str = str(_UNDIRECTED_GNN_DIR / "gnn_graph_data.pt")
    LABEL_MAPPING_PATH: str = str(_UNDIRECTED_GNN_DIR / "label_mapping.pt")
    TITLE_TO_ID_PATH: str = str(_UNDIRECTED_GNN_DIR / "title_to_id.pt")

    # Topic label -> icon name used when drawing nodes.
    ICON_MAPPING: dict[str, str] = {
        'Africa': 'language',
        'Americas': 'language',
        'Architecture': 'home',
        'Asia': 'language',
        'Biography': 'person',
        'Biology': 'science',
        'Business_and_economics': 'account_balance',
        'Chemistry': 'science',
        'Computing': 'laptop',
        'Earth_and_environment': 'cloud',
        'Education': 'description',
        'Engineering': 'factory',
        'Entertainment': 'campaign',
        'Europe': 'language',
        'Fashion': 'sell',
        'Films': 'monitor',
        'Food_and_drink': 'store',
        'Geographical': 'place',
        'History': 'inventory',
        'Internet_culture': 'alternate_email',
        'Libraries_&_Information': 'folder',
        'Linguistics': 'translate',
        'Literature': 'description',
        'Mathematics': 'analytics',
        'Media': 'chat',
        'Medicine_&_Health': 'medical_services',
        'Military_and_warfare': 'flag',
        'Music': 'campaign',
        'Oceania': 'language',
        'Performing_arts': 'group',
        'Philosophy_and_religion': 'assured_workload',
        'Physics': 'science',
        'Politics_and_government': 'assured_workload',
        'STEM': 'science',
        'Society': 'group',
        'Space': 'cloud',
        'Sports': 'group',
        'Technology': 'smartphone',
        'Television': 'monitor',
        'Transportation': 'directions_car',
        'Video_games': 'smartphone',
        'Visual_arts': 'description'
    }

    # Topic label -> hex colour; topics are grouped into six colour families.
    COLOR_MAPPING: dict[str, str] = {
        # STEM & Natural Sciences -> Emerald (#06d6a0)
        'Biology': '#06d6a0',
        'Chemistry': '#06d6a0',
        'Earth_and_environment': '#06d6a0',
        'Mathematics': '#06d6a0',
        'Physics': '#06d6a0',
        'STEM': '#06d6a0',
        'Space': '#06d6a0',

        # Geography & Places -> Ocean Blue (#118ab2)
        'Africa': '#118ab2',
        'Americas': '#118ab2',
        'Asia': '#118ab2',
        'Europe': '#118ab2',
        'Oceania': '#118ab2',
        'Geographical': '#118ab2',

        # Arts, Entertainment & Culture -> Bubblegum Pink (#ef476f)
        'Entertainment': '#ef476f',
        'Fashion': '#ef476f',
        'Films': '#ef476f',
        'Music': '#ef476f',
        'Performing_arts': '#ef476f',
        'Television': '#ef476f',
        'Visual_arts': '#ef476f',
        'Literature': '#ef476f',

        # Tech, Engineering & Infrastructure -> Dark Teal (#073b4c)
        'Architecture': '#073b4c',
        'Computing': '#073b4c',
        'Engineering': '#073b4c',
        'Internet_culture': '#073b4c',
        'Technology': '#073b4c',
        'Transportation': '#073b4c',
        'Video_games': '#073b4c',

        # Society, Humanities & Lifestyle -> Coral Glow (#f78c6b)
        'Biography': '#f78c6b',
        'Food_and_drink': '#f78c6b',
        'Linguistics': '#f78c6b',
        'Media': '#f78c6b',
        'Medicine_&_Health': '#f78c6b',
        'Society': '#f78c6b',
        'Sports': '#f78c6b',

        # Institutions, History & Governance -> Royal Gold (#ffd166)
        'Business_and_economics': '#ffd166',
        'Education': '#ffd166',
        'History': '#ffd166',
        'Libraries_&_Information': '#ffd166',
        'Military_and_warfare': '#ffd166',
        'Philosophy_and_religion': '#ffd166',
        'Politics_and_government': '#ffd166',
    }


# Shared settings instance imported by the rest of the package.
config = Config()
src/demo.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import networkx as nx
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from src.config import config
8
+ from src.embedding import Embedder
9
+ from src.utils import (create_graph_from_df, gather_neighbors,
10
+ get_unique_article_titles)
11
+ from src.heuristic import predict_topic_nth_degree
12
+ from src.gnn import GNNClassifier, load_data, infer_new_node
13
+ from st_link_analysis import EdgeStyle, NodeStyle, st_link_analysis
14
+ from src.visualization import get_edge_styles, get_node_styles
15
+
16
+ import torch
17
+
18
st.set_page_config(
    page_title="Semantic Article Graph", layout="wide", initial_sidebar_state="expanded"
)


def _load_gnn(version, weights_path):
    """Load graph artifacts and a weight-initialised GNNClassifier for ``version``.

    Returns ``(model, graph_data, title_to_id, label_mapping)``.
    """
    graph_data, title_to_id, label_mapping = load_data(version=version)
    model = GNNClassifier(
        input_dim=768,
        hidden_dim=128,
        layers=2,
        output_dim=len(label_mapping),
        dropout_rate=0.5,
    )
    # map_location="cpu": inference in this app is pinned to the CPU, and
    # without it a checkpoint saved from a GPU fails to load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model, graph_data, title_to_id, label_mapping


def _store_gnn(prefix, bundle):
    """Put a ``_load_gnn`` result into session state under ``<prefix>_*`` keys."""
    model, graph_data, title_to_id, label_mapping = bundle
    st.session_state[f"{prefix}_gnn_model"] = model
    st.session_state[f"{prefix}_graph_data"] = graph_data
    st.session_state[f"{prefix}_title_to_id"] = title_to_id
    st.session_state[f"{prefix}_label_mapping"] = label_mapping


# One-time initialisation per browser session: everything heavy is cached in
# st.session_state so Streamlit reruns skip this section.
if "setup_complete" not in st.session_state:
    loader = st.empty()

    with loader.container():
        st.subheader("🚀 Starting...")

        with st.status("Loading...", expanded=True) as status:
            st.write("Initializing Embedding Model...")
            st.session_state.embedder = Embedder(path=config.EMBEDDING_MODEL_PATH)

            st.write("Initializing GNN Model (Undirected)...")
            _store_gnn("undirected", _load_gnn("undirected", config.GNN_MODEL_PATH))

            st.write("Initializing GNN Model (No Edges)...")
            # The no-edge checkpoint sits in a sibling directory of the
            # undirected one, so its path is derived by swapping the folder name.
            _store_gnn(
                "no_edge",
                _load_gnn("no_edge", config.GNN_MODEL_PATH.replace("undirected_gnn", "no_edge_gnn")),
            )

            st.write("Reading training data...")
            training_data = pd.read_parquet(config.TRAINING_DATA_PATH)
            # SECURITY(review): eval() executes whatever text is stored in the
            # "embedding" column. Acceptable only while the parquet file is a
            # trusted, locally produced artifact; if the cells are plain list
            # literals, ast.literal_eval would be the safe replacement.
            training_data["embedding"] = training_data["embedding"].apply(lambda x: eval(x))
            st.session_state.training_data = training_data

            st.write("Creating graph for visualization...")
            st.session_state.directed_graph = create_graph_from_df(training_data, directed=True)
            st.session_state.undirected_graph = create_graph_from_df(training_data, directed=False)

            status.update(label="Done!", state="complete", expanded=False)

        time.sleep(0.5)

    loader.empty()
    st.session_state.setup_complete = True


node_styles = get_node_styles()
edge_styles = get_edge_styles()

if "existing_nodes" not in st.session_state:
    article_titles = get_unique_article_titles(st.session_state.training_data)
    st.session_state.existing_nodes = article_titles

# Class labels in the order they are declared in the icon mapping.
CLASSES = list(config.ICON_MAPPING.keys())
107
+
108
+
109
def get_dummy_probabilities():
    """Return ten random placeholder class scores, highest first.

    One probability vector over ``CLASSES`` is drawn from a flat Dirichlet
    distribution; the result is a DataFrame with "Class" and "Score" columns
    holding the ten largest entries.
    """
    flat_prior = np.ones(len(CLASSES))
    scores = np.random.dirichlet(flat_prior, size=1)[0]
    table = pd.DataFrame({"Class": CLASSES, "Score": scores})
    ranked = table.sort_values(by="Score", ascending=False)
    return ranked.head(10)
115
+
116
+
117
# ---------------------------------------------------------------------------
# Page body: left column collects the new node and method settings, right
# column shows the neighbourhood graph and the classification result.
# ---------------------------------------------------------------------------
st.title("📄 Semantic Article Graph")
st.markdown("---")

col_input, col_vis = st.columns([1, 2], gap="large")

with col_input:
    st.subheader("1. New Node Details")

    new_title = st.text_input("Node Title", placeholder="e.g., Istanbul")
    new_content = st.text_area("Content", height=150, placeholder="Paste content here...")
    references = st.multiselect(
        "References (Select existing nodes)",
        options=st.session_state.existing_nodes,
        help="Search and select multiple papers this node cites.",
    )

    st.markdown("---")
    st.subheader("2. Methodology Configuration")

    method = st.selectbox(
        "Select Classification Method",
        ["GNN (Graph Neural Network)", "Rule-Based"],
    )

    # Defaults used when the chosen method does not expose the setting.
    model_params = {}
    is_directed = False
    max_depth = 2

    if method == "GNN (Graph Neural Network)":
        use_edges = st.checkbox("Use Graph Edges", value=True)
    elif method == "Rule-Based":
        max_depth = st.slider("Max Depth", 1, 3, 1)
        is_weighted = st.checkbox("Apply Weights", value=True)
        is_directed = st.checkbox("Use Directed Graph", value=False)
        model_params = {"max_depth": max_depth, "is_weighted": is_weighted}
    else:
        st.warning("Please select a valid method.")

    st.markdown("---")

    run_inference = st.button("Add Node & Run Inference", type="primary", width="stretch")


with col_vis:
    if not run_inference:
        # Idle state: hint plus an empty placeholder frame.
        st.info(
            "👈 Enter node details on the left and click 'Add' to see the graph and predictions."
        )
        st.markdown(
            """
            <div style="height: 600px; border: 2px dashed #ccc; border-radius: 10px;
            display: flex; align-items: center; justify-content: center; color: #ccc;">
            Waiting for input...
            </div>
            """,
            unsafe_allow_html=True,
        )
    elif not new_title:
        st.error("Please enter a title for the node.")
    else:
        st.subheader("🌐 Graph Neighborhood (k-hop)")

        with st.spinner("Updating Graph Topology..."):
            time.sleep(1)

        # The directed graph is only used when the rule-based method asks for it.
        if is_directed:
            graph = st.session_state.directed_graph
        else:
            graph = st.session_state.undirected_graph

        with st.container(border=True):
            elements = gather_neighbors(graph, new_title, references, depth=max_depth)
            st_link_analysis(elements, "cose", node_styles, edge_styles)
            st.caption(
                f"Visualizing neighbors for: **{new_title}** with {len(references)} connections."
            )

        st.markdown("---")
        st.subheader("📊 Classification Results")

        with st.spinner(f"Running {method}..."):
            time.sleep(1.5)
            embedding = st.session_state.embedder.generate_embedding(new_content)

            if method == "GNN (Graph Neural Network)":
                # Pick the artifact set stored during setup for the chosen variant.
                variant = "undirected" if use_edges else "no_edge"
                df_results = infer_new_node(
                    base_data=st.session_state[f"{variant}_graph_data"],
                    model=st.session_state[f"{variant}_gnn_model"],
                    new_embedding=embedding,
                    referenced_titles=references,
                    title_to_id=st.session_state[f"{variant}_title_to_id"],
                    label_mapping=st.session_state[f"{variant}_label_mapping"],
                    device=torch.device("cpu"),
                    make_undirected_for_new_node=not is_directed,
                    use_edges=use_edges,
                )
            elif method == "Rule-Based":
                df_results = predict_topic_nth_degree(
                    new_article_title=new_title,
                    new_article_embedding=embedding,
                    edges=references,
                    G=graph,
                    decay_factor=1.0,
                    **model_params,
                )
            else:
                st.error("Invalid method selected.")
                st.stop()

        # The first row of the result table is reported as the prediction.
        top_class = df_results.iloc[0]
        st.success(
            f"**Predicted Class:** {top_class['Class']} ({top_class['Score']:.2%})"
        )

        confidence_column = st.column_config.ProgressColumn(
            "Confidence",
            help="The model's confidence score",
            format="%.2f",
            min_value=0,
            max_value=1,
        )
        st.dataframe(
            df_results,
            column_config={"Class": "Class Name", "Score": confidence_column},
            hide_index=True,
            width="stretch",
        )
src/embedding.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+ from transformers import AutoModel, AutoTokenizer
8
+
9
+
10
class Embedder:
    """Turn text into L2-normalised embedding vectors with a Hugging Face encoder.

    The tokenizer and model are loaded once from ``path`` and the model is
    moved to CUDA when available, otherwise it stays on the CPU.
    """

    def __init__(self, path):
        """Load the tokenizer and encoder stored at (or named by) ``path``."""
        self.model_name_or_path = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path, trust_remote_code=True
        )
        self.model.to(self.device)

    def generate_embedding(self, text):
        """Embed ``text`` and return the result as a NumPy array.

        The hidden state of the first token is taken for every sequence,
        truncated to at most 768 features and L2-normalised along the
        feature axis. Size-1 axes are squeezed away before returning, so a
        single input string yields a 1-D vector.
        """
        inputs = self.tokenizer(
            text, max_length=8192, padding=True, truncation=True, return_tensors="pt"
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        dimension = 768
        # Truncate along the feature axis. The previous `[:, 0][:dimension]`
        # sliced the batch axis instead, so the width was never limited.
        embeddings = outputs.last_hidden_state[:, 0, :dimension]
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        return normalized_embeddings.squeeze().cpu().numpy()
src/gnn.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch_geometric.nn import GCNConv
4
+ from torch_geometric.data import Data
5
+ import pandas as pd
6
+
7
+ from src.config import config
8
+
9
class GNNClassifier(torch.nn.Module):
    """Node classifier built from two or three stacked ``GCNConv`` layers.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Width of the hidden representation(s).
        layers: Number of graph convolution layers; must be 2 or 3.
        output_dim: Number of output classes (size of the logit vector).
        dropout_rate: Dropout probability applied after each hidden layer.

    Raises:
        ValueError: If ``layers`` is neither 2 nor 3.
    """

    def __init__(self, input_dim, hidden_dim, layers, output_dim, dropout_rate=0.5):
        super().__init__()
        # Fail fast: any other depth would create a model without conv layers
        # that only breaks later, inside forward(), with an AttributeError.
        if layers not in (2, 3):
            raise ValueError(f"layers must be 2 or 3, got {layers!r}")

        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.output_dim = output_dim

        # The attribute names conv1/conv2/conv3 are part of the state-dict
        # layout of saved checkpoints and must not be renamed.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        if layers == 2:
            self.conv2 = GCNConv(hidden_dim, output_dim)
        else:
            self.conv2 = GCNConv(hidden_dim, hidden_dim)
            self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        """Return raw class logits for every node in ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index

        # Layer 1, followed by non-linearity and dropout (active in training only).
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Layer 2 is the output layer for the 2-layer variant.
        x = self.conv2(x, edge_index)

        if self.layers == 3:
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
            x = self.conv3(x, edge_index)

        return x
47
+
48
def load_data(version: str = "undirected"):
    """Load the serialized graph data, title-to-id map and label mapping.

    Args:
        version: ``"undirected"`` reads the paths configured in ``config``;
            ``"no_edge"`` reads the same paths with ``"undirected_gnn"``
            replaced by ``"no_edge_gnn"``.

    Returns:
        A ``(graph_data, title_to_id, label_mapping)`` tuple.

    Raises:
        ValueError: If ``version`` is not one of the two supported values.
    """
    if version not in ("undirected", "no_edge"):
        raise ValueError(f"Unknown version: {version}")

    paths = [
        config.GNN_GRAPH_DATA_PATH,
        config.TITLE_TO_ID_PATH,
        config.LABEL_MAPPING_PATH,
    ]
    if version == "no_edge":
        paths = [path.replace("undirected_gnn", "no_edge_gnn") for path in paths]

    graph_data, title_to_id, label_mapping = (torch.load(path) for path in paths)
    return graph_data, title_to_id, label_mapping
62
+
63
def infer_new_node(
    base_data: Data,
    model: torch.nn.Module,
    new_embedding,  # flat embedding: list / NumPy array / torch tensor
    referenced_titles: list[str],  # titles the user selected
    title_to_id: dict[str, int],
    label_mapping: dict[str, int],
    device: torch.device,
    make_undirected_for_new_node: bool = True,
    use_edges: bool = True,
):
    """Classify a new node by appending it to the base graph and running the model.

    Args:
        base_data: Graph to extend; its ``x`` and ``edge_index`` are read.
        model: Trained classifier returning raw logits per node. It is put
            into eval mode and moved to ``device``.
        new_embedding: Feature vector of the new node; reshaped to one row.
        referenced_titles: Titles of existing nodes the new node links to.
            Titles missing from ``title_to_id`` are ignored.
        title_to_id: Maps a node title to its row index in ``base_data.x``.
        label_mapping: Maps a class name to its class index.
        device: Device on which inference is executed.
        make_undirected_for_new_node: Also add the reverse edge
            (new node -> referenced node) for every reference.
        use_edges: When False, no edges are attached to the new node.

    Returns:
        A DataFrame with columns ``"Class"`` and ``"Score"`` (the softmax
        probability of each class for the new node), sorted by descending score.
    """
    model.eval()

    model = model.to(device)
    base_data = base_data.to(device)

    # --- 1) Append the new node's features as the last row ---
    x_old = base_data.x
    # as_tensor avoids the copy-construct warning when a tensor is passed in;
    # reshape (unlike view) also accepts non-contiguous input.
    new_x = torch.as_tensor(new_embedding, dtype=x_old.dtype, device=device).reshape(1, -1)
    x = torch.cat([x_old, new_x], dim=0)
    new_id = x.size(0) - 1

    # --- 2) Attach the new node to the referenced nodes ---
    edge_index = base_data.edge_index
    if use_edges:
        src_list = []
        tgt_list = []
        for title in referenced_titles:
            if title not in title_to_id:
                continue
            old_id = title_to_id[title]

            # Edge old -> new lets the referenced node influence the new
            # node within a single hop.
            src_list.append(old_id)
            tgt_list.append(new_id)

            if make_undirected_for_new_node:
                src_list.append(new_id)
                tgt_list.append(old_id)

        # With no usable references the node stays isolated and the base
        # edge index is used unchanged.
        if src_list:
            new_edges = torch.tensor(
                [src_list, tgt_list], dtype=torch.long, device=device
            )
            edge_index = torch.cat([edge_index, new_edges], dim=1)

    # --- 3) Run inference on the augmented graph ---
    data_aug = Data(x=x, edge_index=edge_index).to(device)

    with torch.no_grad():
        out = model(data_aug)  # raw logits for every node
        log_probs = F.log_softmax(out, dim=1)[new_id]  # new node only

    probs = log_probs.exp().detach().cpu()
    inv_label_mapping = {v: k for k, v in label_mapping.items()}

    result_df = pd.DataFrame(
        [(inv_label_mapping[i], prob.item()) for i, prob in enumerate(probs)],
        columns=["Class", "Score"],
    ).sort_values(by="Score", ascending=False)

    return result_df
138
+
139
if __name__ == "__main__":
    # Manual smoke test: embed a sample text, attach it to the stored graph
    # and print the predicted class distribution.
    from src.embedding import Embedder

    graph_data, title_to_id, label_mapping = load_data()

    classifier = GNNClassifier(
        input_dim=768,
        hidden_dim=128,
        layers=2,
        output_dim=len(label_mapping),
        dropout_rate=0.5,
    )
    weights = torch.load(r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth")
    classifier.load_state_dict(weights)

    sample_text = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
    text_embedder = Embedder(path=r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base")
    sample_embedding = text_embedder.generate_embedding(sample_text)
    sample_references = ["forum istanbul", "istanbul film festivali", "akıllı şehir"]

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prediction = infer_new_node(
        base_data=graph_data,
        model=classifier,
        new_embedding=sample_embedding,
        referenced_titles=sample_references,
        title_to_id=title_to_id,
        label_mapping=label_mapping,
        device=run_device,
        make_undirected_for_new_node=True,
    )

    print("Prediction Results for New Node:")
    print(prediction)
src/heuristic.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sklearn.metrics.pairwise import cosine_similarity
3
+
4
+ from typing import List, Any, Optional
5
+ from collections import defaultdict, deque
6
+
7
def predict_topic_nth_degree(
    new_article_title: str,
    new_article_embedding: List[float],
    edges: List[str],
    G: Any,
    max_depth: int = 1,
    is_weighted: bool = False,
    decay_factor: float = 1.0,
) -> Optional[pd.DataFrame]:
    """
    Scores topics for a new article from its neighbors up to n-degrees away.

    Args:
        new_article_title: Title of the new article; never visited itself.
        new_article_embedding: Embedding of the new article, used only when
            ``is_weighted`` is True.
        edges: Titles referenced by the new article. Titles that are not in
            ``G`` are ignored.
        G: Graph supporting ``in``, ``G.nodes[name]`` (a mapping with a
            ``"label"`` and, for weighted scoring, an ``"embedding"`` entry)
            and ``G.neighbors(name)``.
        max_depth: How many hops to traverse (1 = direct neighbors, 2 = neighbors of neighbors).
        is_weighted: If True, each vote is the cosine similarity between the
            new article and the neighbor; otherwise each vote counts 1.0.
        decay_factor: Multiplier for distance. 1.0 = no decay.
            0.5 means a neighbor at depth 2 has half the voting power of depth 1.

    Returns:
        A DataFrame with columns ``"Class"`` and ``"Score"`` sorted by
        descending score, or ``None`` when no referenced title is in ``G``
        or no visited node carries a label. Scores are accumulated votes,
        not normalised probabilities.
    """

    # 1. Setup BFS
    # Queue stores: (current_node_name, current_depth)
    queue = deque()

    # The visited set avoids cycles and scoring the same node twice.
    visited = {new_article_title}

    # 2. Initialize BFS with the "virtual" first hop.
    # The input list 'edges' is walked manually because the new article isn't in G.
    for ref in edges:
        if ref in G and ref not in visited:
            visited.add(ref)
            queue.append((ref, 1))  # Depth 1

    if not queue:
        return None

    topic_scores = defaultdict(float)

    # 3. Process BFS
    while queue:
        current_node, current_depth = queue.popleft()

        # --- Score calculation ---
        node_data = G.nodes[current_node]
        topic = node_data.get("label")

        if topic:
            if is_weighted:
                base_score = cosine_similarity(
                    [new_article_embedding], [node_data["embedding"]]
                )[0][0]
            else:
                base_score = 1.0

            # Distance decay: score * decay ** (depth - 1), so depth 1 is
            # unscaled and depth 2 is scaled by decay once.
            topic_scores[topic] += base_score * (decay_factor ** (current_depth - 1))

        # --- Expand to next level if within limit ---
        if current_depth < max_depth:
            for neighbor in G.neighbors(current_node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, current_depth + 1))

    # 4. Rank topics
    if not topic_scores:
        return None

    columns = ["Class", "Score"]
    result_df = pd.DataFrame(
        [(topic, score) for topic, score in topic_scores.items()], columns=columns
    ).sort_values(by="Score", ascending=False)
    return result_df
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ import pandas as pd
3
+
4
+
5
def get_unique_article_titles(df: pd.DataFrame) -> list[str]:
    """Return the distinct values of ``article_title_processed`` in ascending order."""
    titles = df["article_title_processed"].unique().tolist()
    titles.sort()
    return titles
9
+
10
+
11
def create_graph_from_df(df, directed: bool = False) -> nx.Graph:
    """Build a citation graph from an article DataFrame.

    Every row becomes a node keyed by ``article_title_processed`` with the
    attributes ``label`` (from ``predicted_topic``) and ``embedding``. For
    every title listed in a row's ``links_processed`` that is itself a node
    (self-references excluded) an edge article -> reference is added.

    Args:
        df: DataFrame with the columns ``article_title_processed``,
            ``predicted_topic``, ``embedding`` and ``links_processed``.
        directed: If True an ``nx.DiGraph`` is returned, so edges keep their
            article -> reference orientation; otherwise an undirected
            ``nx.Graph``.

    Returns:
        The populated graph (``nx.DiGraph`` is a subclass of ``nx.Graph``).
    """
    # Previously the graph was always nx.Graph(), so `directed` had no effect.
    G = nx.DiGraph() if directed else nx.Graph()

    # First pass: register all nodes so that the membership test below sees
    # every article regardless of row order.
    for _, row in df.iterrows():
        G.add_node(
            row["article_title_processed"],
            label=row["predicted_topic"],
            embedding=row["embedding"],
        )

    # Second pass: connect articles to the references that exist as nodes.
    for _, row in df.iterrows():
        node_title = row["article_title_processed"]
        # NOTE(review): eval() executes arbitrary code stored in the column.
        # Only safe for trusted data; consider ast.literal_eval if the cells
        # are plain list literals.
        references = eval(row["links_processed"])

        for ref in references:
            if ref in G and ref != node_title:
                # In an undirected graph a single add_edge already covers
                # both directions.
                G.add_edge(node_title, ref)

    return G
30
+
31
+
32
def gather_neighbors(
    graph: nx.DiGraph, node_title: str, references: list[str], depth: int = 1
):
    """Return the visualizer payload for a new node and its neighborhood.

    A copy of ``graph`` is extended with ``node_title`` and with an edge to
    every reference that already exists in the graph (self-references are
    skipped); ``graph`` itself is left untouched. The neighborhood within
    ``depth`` hops is then formatted by ``get_neighbors_for_visualizer``.
    """
    augmented = graph.copy()
    augmented.add_node(node_title)

    linkable = [
        ref for ref in references if ref != node_title and ref in augmented
    ]
    for ref in linkable:
        augmented.add_edge(node_title, ref)

    return get_neighbors_for_visualizer(augmented, node_title, depth=depth)
48
+
49
+
50
def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
    """
    Collect the neighborhood of a node in a Cytoscape-style element format.

    Args:
        graph (nx.Graph): The source NetworkX graph.
        start_node: The title/ID of the node to start from.
        depth (int): How many hops (degrees of separation) to traverse.

    Returns:
        dict: ``{"nodes": [...], "edges": [...]}``. Each entry is a dict with
        a ``"data"`` payload. Nodes carry ``id`` (1-based integer), ``label``
        and ``name``; edges carry ``id``, ``label``, ``source`` and
        ``target``. Both lists are empty if ``start_node`` is not in the graph.
    """
    if start_node not in graph:
        return {"nodes": [], "edges": []}

    ego = nx.ego_graph(graph, start_node, radius=depth)

    # Node titles are mapped to consecutive integer ids starting at 1.
    node_ids = {node: idx for idx, node in enumerate(ego.nodes(), start=1)}

    # The 'embedding' attribute is deliberately not exported.
    nodes_data = [
        {
            "data": {
                "id": node_ids[node],
                "label": ego.nodes[node].get("label", "Unknown"),
                "name": str(node),
            }
        }
        for node in ego.nodes()
    ]

    # Edge ids continue after the last node id so that all ids are unique.
    first_edge_id = len(node_ids) + 1
    edges_data = [
        {
            "data": {
                "id": edge_id,
                "label": ego.edges[u, v].get("label", "CITES"),
                "source": node_ids[u],
                "target": node_ids[v],
            }
        }
        for edge_id, (u, v) in enumerate(ego.edges(), start=first_edge_id)
    ]

    return {"nodes": nodes_data, "edges": edges_data}
126
+
127
+
128
if __name__ == "__main__":
    # Manual smoke test: build the graph from the training data and print
    # the 2-hop neighborhood of a made-up article.
    train_df = pd.read_parquet(
        r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\input\train_data_with_embeddings.parquet"
    )
    article_graph = create_graph_from_df(train_df)

    test_title = "Sample Article Title"
    test_references = ["finansal matematik", "genel yapay zekâ", "andrej karpathy"]

    neighborhood = gather_neighbors(
        article_graph, test_title, test_references, depth=2
    )

    print(f"Neighbors of '{test_title}': {neighborhood}")
src/visualization.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from st_link_analysis import EdgeStyle, NodeStyle
2
+ from src.config import config
3
+
4
def get_node_styles() -> list[NodeStyle]:
    """Build one ``NodeStyle`` per class listed in ``config.ICON_MAPPING``.

    The color comes from ``config.COLOR_MAPPING`` and falls back to gray
    (``#888888``) for classes without a configured color.
    """
    return [
        NodeStyle(
            label=class_name,
            color=config.COLOR_MAPPING.get(class_name, "#888888"),
            icon=icon,
        )
        for class_name, icon in config.ICON_MAPPING.items()
    ]
15
+
16
def get_edge_styles() -> list[EdgeStyle]:
    """Return the edge styles used by the graph view: a single ``CITES`` style."""
    cites_style = EdgeStyle(label="CITES")
    return [cites_style]