Spaces:
Sleeping
Sleeping
Initial commit
Browse files- input/test_data_with_embeddings.parquet +3 -0
- input/train_data_with_embeddings.parquet +3 -0
- input/val_data_with_embeddings.parquet +3 -0
- models/embedding/gte-multilingual-base/.gitattributes +36 -0
- models/embedding/gte-multilingual-base/1_Pooling/config.json +7 -0
- models/embedding/gte-multilingual-base/README.md +0 -0
- models/embedding/gte-multilingual-base/config.json +43 -0
- models/embedding/gte-multilingual-base/model.safetensors +3 -0
- models/embedding/gte-multilingual-base/modules.json +20 -0
- models/embedding/gte-multilingual-base/scripts/gte_embedding.py +154 -0
- models/embedding/gte-multilingual-base/sentence_bert_config.json +4 -0
- models/embedding/gte-multilingual-base/special_tokens_map.json +51 -0
- models/embedding/gte-multilingual-base/tokenizer.json +3 -0
- models/embedding/gte-multilingual-base/tokenizer_config.json +54 -0
- models/no_edge_gnn/gnn_classifier_model.pth +3 -0
- models/no_edge_gnn/gnn_graph_data.pt +3 -0
- models/no_edge_gnn/label_mapping.pt +3 -0
- models/no_edge_gnn/title_to_id.pt +3 -0
- models/undirected_gnn/gnn_classifier_model.pth +3 -0
- models/undirected_gnn/gnn_graph_data.pt +3 -0
- models/undirected_gnn/label_mapping.pt +3 -0
- models/undirected_gnn/title_to_id.pt +3 -0
- src/__pycache__/config.cpython-311.pyc +0 -0
- src/__pycache__/embedding.cpython-311.pyc +0 -0
- src/__pycache__/gnn.cpython-311.pyc +0 -0
- src/__pycache__/heuristic.cpython-311.pyc +0 -0
- src/__pycache__/utils.cpython-311.pyc +0 -0
- src/__pycache__/visualization.cpython-311.pyc +0 -0
- src/config.py +115 -0
- src/demo.py +265 -0
- src/embedding.py +31 -0
- src/gnn.py +164 -0
- src/heuristic.py +87 -0
- src/streamlit_app.py +0 -40
- src/utils.py +140 -0
- src/visualization.py +23 -0
input/test_data_with_embeddings.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:039cbc8a7b595f3e10a29e496bd3c44eeba116cb7c6977c86277f05a398e9dcf
|
| 3 |
+
size 6799466
|
input/train_data_with_embeddings.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e6b545f46a7ab93b8b11092b2bee5ce78973ea9c91d59b8d3e7ff88a7f5beb6
|
| 3 |
+
size 121744699
|
input/val_data_with_embeddings.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4888805e8ca378c6040a92334ca68c04449d667907c3b758e53b8ef312616fe8
|
| 3 |
+
size 6695324
|
models/embedding/gte-multilingual-base/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
models/embedding/gte-multilingual-base/1_Pooling/config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"word_embedding_dimension": 768,
|
| 3 |
+
"pooling_mode_cls_token": true,
|
| 4 |
+
"pooling_mode_mean_tokens": false,
|
| 5 |
+
"pooling_mode_max_tokens": false,
|
| 6 |
+
"pooling_mode_mean_sqrt_len_tokens": false
|
| 7 |
+
}
|
models/embedding/gte-multilingual-base/README.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/embedding/gte-multilingual-base/config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"NewModel",
|
| 4 |
+
"NewForTokenClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "Alibaba-NLP/new-impl--configuration.NewConfig",
|
| 9 |
+
"AutoModelForMaskedLM": "Alibaba-NLP/new-impl--modeling.NewForMaskedLM",
|
| 10 |
+
"AutoModel": "Alibaba-NLP/new-impl--modeling.NewModel",
|
| 11 |
+
"AutoModelForMultipleChoice": "Alibaba-NLP/new-impl--modeling.NewForMultipleChoice",
|
| 12 |
+
"AutoModelForQuestionAnswering": "Alibaba-NLP/new-impl--modeling.NewForQuestionAnswering",
|
| 13 |
+
"AutoModelForSequenceClassification": "Alibaba-NLP/new-impl--modeling.NewForSequenceClassification",
|
| 14 |
+
"AutoModelForTokenClassification": "Alibaba-NLP/new-impl--modeling.NewForTokenClassification"
|
| 15 |
+
},
|
| 16 |
+
"classifier_dropout": 0.0,
|
| 17 |
+
"hidden_act": "gelu",
|
| 18 |
+
"hidden_dropout_prob": 0.1,
|
| 19 |
+
"hidden_size": 768,
|
| 20 |
+
"initializer_range": 0.02,
|
| 21 |
+
"intermediate_size": 3072,
|
| 22 |
+
"layer_norm_eps": 1e-12,
|
| 23 |
+
"layer_norm_type": "layer_norm",
|
| 24 |
+
"max_position_embeddings": 8192,
|
| 25 |
+
"model_type": "new",
|
| 26 |
+
"num_attention_heads": 12,
|
| 27 |
+
"num_hidden_layers": 12,
|
| 28 |
+
"num_labels": 1,
|
| 29 |
+
"pack_qkv": true,
|
| 30 |
+
"pad_token_id": 1,
|
| 31 |
+
"position_embedding_type": "rope",
|
| 32 |
+
"rope_scaling": {
|
| 33 |
+
"factor": 8.0,
|
| 34 |
+
"type": "ntk"
|
| 35 |
+
},
|
| 36 |
+
"rope_theta": 20000,
|
| 37 |
+
"torch_dtype": "float16",
|
| 38 |
+
"transformers_version": "4.39.1",
|
| 39 |
+
"type_vocab_size": 1,
|
| 40 |
+
"unpad_inputs": false,
|
| 41 |
+
"use_memory_efficient_attention": false,
|
| 42 |
+
"vocab_size": 250048
|
| 43 |
+
}
|
models/embedding/gte-multilingual-base/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5a35a10faa54da7717870af1517c9b41e9bd8e3880bc5a8e9363d4c3c63e9b0
|
| 3 |
+
size 610753338
|
models/embedding/gte-multilingual-base/modules.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "sentence_transformers.models.Transformer"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"idx": 1,
|
| 10 |
+
"name": "1",
|
| 11 |
+
"path": "1_Pooling",
|
| 12 |
+
"type": "sentence_transformers.models.Pooling"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"idx": 2,
|
| 16 |
+
"name": "2",
|
| 17 |
+
"path": "2_Normalize",
|
| 18 |
+
"type": "sentence_transformers.models.Normalize"
|
| 19 |
+
}
|
| 20 |
+
]
|
models/embedding/gte-multilingual-base/scripts/gte_embedding.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The GTE Team Authors and Alibaba Group.
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 11 |
+
from transformers.utils import is_torch_npu_available
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GTEEmbeddidng(torch.nn.Module):
    """Produce dense and sparse (per-token weight) embeddings from a GTE checkpoint.

    The checkpoint is loaded through ``AutoModelForTokenClassification``; the
    dense vector is read from the first token position of ``last_hidden_state``
    and the sparse weights from the ReLU of the per-token ``logits``.

    NOTE: the class name contains a typo ("Embeddidng"); it is kept as-is
    because it is the public name callers import.
    """

    def __init__(self,
                 model_name: str = None,
                 normalized: bool = True,
                 use_fp16: bool = True,
                 device: str = None
                 ):
        """Load the tokenizer and model and move the model to the target device.

        Args:
            model_name: Name or path handed to ``from_pretrained``.
            normalized: If true, dense vectors are L2-normalized in ``_encode``.
            use_fp16: Load weights as ``torch.float16``. Forced to ``False``
                when no device is given and auto-detection falls back to CPU.
            device: Explicit device string. When falsy, the first available of
                CUDA, MPS, NPU is used, falling back to CPU.
        """
        super().__init__()
        self.normalized = normalized
        if device:
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            # Auto-detected CPU: half precision is switched off.
            self.device = torch.device("cpu")
            use_fp16 = False
        self.use_fp16 = use_fp16
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(security): trust_remote_code=True lets the checkpoint's own
        # modeling code run locally; only load checkpoints you trust.
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None
        )
        self.vocab_size = self.model.config.vocab_size
        self.model.to(self.device)

    def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
        """Collapse one sequence's per-position weights into a token -> weight mapping.

        Positions holding the CLS/EOS/PAD/UNK ids and positions whose weight is
        not strictly positive are dropped. When the same decoded token occurs
        several times, the largest weight is kept.

        Args:
            token_weights: Per-position weights for a single sequence.
            input_ids: Token ids aligned with ``token_weights``.

        Returns:
            A ``defaultdict(int)`` mapping decoded token text to its max weight.
        """
        # convert to dict
        result = defaultdict(int)
        unused_tokens = {
            self.tokenizer.cls_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
            self.tokenizer.unk_token_id,
        }
        for w, idx in zip(token_weights, input_ids):
            if idx not in unused_tokens and w > 0:
                token = self.tokenizer.decode([int(idx)])
                if w > result[token]:
                    result[token] = w
        return result

    @torch.no_grad()
    def encode(self,
               texts: List[str],
               dimension: int = None,
               max_length: int = 8192,
               batch_size: int = 16,
               return_dense: bool = True,
               return_sparse: bool = False):
        """Encode texts in batches and gather the per-batch outputs.

        Args:
            texts: A list of strings; a single ``str`` is also accepted and is
                treated as a one-element list.
            dimension: Number of leading features kept from each dense vector.
                Defaults to the model's ``hidden_size``.
            max_length: Truncation length passed to the tokenizer.
            batch_size: Number of texts per forward pass.
            return_dense: Whether to compute dense embeddings.
            return_sparse: Whether to compute sparse token weights.

        Returns:
            A dict with two keys. ``"dense_embeddings"`` is the concatenation
            of all batch tensors along dim 0, or ``None`` when no dense vectors
            were produced (``return_dense=False`` or ``texts`` is empty).
            ``"token_weights"`` is a list with one mapping per text (empty when
            ``return_sparse`` is false).
        """
        if dimension is None:
            dimension = self.model.config.hidden_size
        if isinstance(texts, str):
            texts = [texts]

        dense_chunks = []
        all_token_weights = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start: start + batch_size]
            results = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse)
            if return_dense:
                dense_chunks.append(results['dense_embeddings'])
            if return_sparse:
                all_token_weights.extend(results['token_weights'])

        # torch.cat raises on an empty list, which previously made
        # return_dense=False (and empty input) crash here.
        all_dense_vecs = torch.cat(dense_chunks, dim=0) if dense_chunks else None
        return {
            "dense_embeddings": all_dense_vecs,
            "token_weights": all_token_weights
        }

    @torch.no_grad()
    def _encode(self,
                texts: List[str] = None,
                dimension: int = None,
                max_length: int = 1024,
                batch_size: int = 16,
                return_dense: bool = True,
                return_sparse: bool = False):
        """Run one tokenizer + model pass over a single batch of texts.

        Args:
            texts: The batch handed to the tokenizer.
            dimension: Number of leading features kept from the first-token
                hidden state.
            max_length: Truncation length passed to the tokenizer.
            batch_size: Unused here; kept so the signature stays unchanged.
            return_dense: Whether to emit ``"dense_embeddings"``.
            return_sparse: Whether to emit ``"token_weights"``.

        Returns:
            A dict holding only the requested keys.
        """
        text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
        text_input = {k: v.to(self.model.device) for k, v in text_input.items()}
        model_out = self.model(**text_input, return_dict=True)

        output = {}
        if return_dense:
            # First token position, truncated to the requested width.
            dense_vecs = model_out.last_hidden_state[:, 0, :dimension]
            if self.normalized:
                dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
            output['dense_embeddings'] = dense_vecs
        if return_sparse:
            token_weights = torch.relu(model_out.logits).squeeze(-1)
            token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
                                     text_input['input_ids'].cpu().numpy().tolist()))
            output['token_weights'] = token_weights

        return output

    def _compute_sparse_scores(self, embs1, embs2):
        """Return the sum of weight products over tokens present in both mappings."""
        scores = 0
        for token, weight in embs1.items():
            if token in embs2:
                scores += weight * embs2[token]
        return scores

    def compute_sparse_scores(self, embs1, embs2):
        """Score aligned pairs of token-weight mappings; returns a NumPy array."""
        scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
        return np.array(scores)

    def compute_dense_scores(self, embs1, embs2):
        """Return the row-wise dot product of two dense tensors as a NumPy array."""
        scores = torch.sum(embs1 * embs2, dim=-1).cpu().detach().numpy()
        return scores

    @torch.no_grad()
    def compute_scores(self,
                       text_pairs: List[Tuple[str, str]],
                       dimension: int = None,
                       max_length: int = 1024,
                       batch_size: int = 16,
                       dense_weight=1.0,
                       sparse_weight=0.1):
        """Score text pairs as a weighted sum of dense and sparse similarity.

        Args:
            text_pairs: ``(text1, text2)`` tuples to compare.
            dimension: Forwarded to ``encode``.
            max_length: Forwarded to ``encode``.
            batch_size: Forwarded to ``encode``.
            dense_weight: Multiplier applied to the dense scores.
            sparse_weight: Multiplier applied to the sparse scores.

        Returns:
            A Python list with one combined score per pair.
        """
        text1_list = [text_pair[0] for text_pair in text_pairs]
        text2_list = [text_pair[1] for text_pair in text_pairs]
        embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        dense_scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings'])
        sparse_scores = self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights'])
        scores = dense_scores * dense_weight + sparse_scores * sparse_weight
        scores = scores.tolist()
        return scores
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
if __name__ == '__main__':
    # Manual smoke run: encode three sample sentences and print both outputs.
    sample_docs = [
        "黑龙江离俄罗斯很近",
        "哈尔滨是中国黑龙江省的省会,位于中国东北",
        "you are the hero"
    ]
    model = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
    print('docs', sample_docs)
    encoded = model.encode(sample_docs, return_dense=True, return_sparse=True)
    for label, key in (('dense vecs', 'dense_embeddings'), ('sparse vecs', 'token_weights')):
        print(label, encoded[key])
|
models/embedding/gte-multilingual-base/sentence_bert_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_seq_length": 8192,
|
| 3 |
+
"do_lower_case": false
|
| 4 |
+
}
|
models/embedding/gte-multilingual-base/special_tokens_map.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"cls_token": {
|
| 10 |
+
"content": "<s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "</s>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"mask_token": {
|
| 24 |
+
"content": "<mask>",
|
| 25 |
+
"lstrip": true,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"pad_token": {
|
| 31 |
+
"content": "<pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
},
|
| 37 |
+
"sep_token": {
|
| 38 |
+
"content": "</s>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false
|
| 43 |
+
},
|
| 44 |
+
"unk_token": {
|
| 45 |
+
"content": "<unk>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false
|
| 50 |
+
}
|
| 51 |
+
}
|
models/embedding/gte-multilingual-base/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f59925fcb90c92b894cb93e51bb9b4a6105c5c249fe54ce1c704420ac39b81af
|
| 3 |
+
size 17082756
|
models/embedding/gte-multilingual-base/tokenizer_config.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<s>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<pad>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"250001": {
|
| 36 |
+
"content": "<mask>",
|
| 37 |
+
"lstrip": true,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"bos_token": "<s>",
|
| 45 |
+
"clean_up_tokenization_spaces": true,
|
| 46 |
+
"cls_token": "<s>",
|
| 47 |
+
"eos_token": "</s>",
|
| 48 |
+
"mask_token": "<mask>",
|
| 49 |
+
"model_max_length": 32768,
|
| 50 |
+
"pad_token": "<pad>",
|
| 51 |
+
"sep_token": "</s>",
|
| 52 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
| 53 |
+
"unk_token": "<unk>"
|
| 54 |
+
}
|
models/no_edge_gnn/gnn_classifier_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7c60400474e5da38560a2146fa9dc997a949c293816f14095120cb3e863cfeb
|
| 3 |
+
size 417848
|
models/no_edge_gnn/gnn_graph_data.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cdce3a0f4ddd0abf8293ae2ddfab06bb4ece1ac0e55d113bc8c28dbad4d77cb8
|
| 3 |
+
size 30912782
|
models/no_edge_gnn/label_mapping.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fd27285c0b3d71f944f4fa2d8096f4a8ddc49c34d8728ccb2a7a82bceccd99ff
|
| 3 |
+
size 1720
|
models/no_edge_gnn/title_to_id.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c66e0124c82ab11fcc6db287aa5898d8e4745d1a2467b1c5f1d3779729420e3b
|
| 3 |
+
size 281136
|
models/undirected_gnn/gnn_classifier_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a3756f5565bf83a4ed03713814c60a52fa5f94140b2e0dd493035e1ba955b13
|
| 3 |
+
size 417720
|
models/undirected_gnn/gnn_graph_data.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f0ebba135027659290f04a40ea2f46bbf7ec219a12331ec78030b6da29633a7
|
| 3 |
+
size 31942670
|
models/undirected_gnn/label_mapping.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fd27285c0b3d71f944f4fa2d8096f4a8ddc49c34d8728ccb2a7a82bceccd99ff
|
| 3 |
+
size 1720
|
models/undirected_gnn/title_to_id.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c66e0124c82ab11fcc6db287aa5898d8e4745d1a2467b1c5f1d3779729420e3b
|
| 3 |
+
size 281136
|
src/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
src/__pycache__/embedding.cpython-311.pyc
ADDED
|
Binary file (2.87 kB). View file
|
|
|
src/__pycache__/gnn.cpython-311.pyc
ADDED
|
Binary file (8.88 kB). View file
|
|
|
src/__pycache__/heuristic.cpython-311.pyc
ADDED
|
Binary file (3.3 kB). View file
|
|
|
src/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
src/__pycache__/visualization.cpython-311.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

from pydantic_settings import BaseSettings
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Repository root: this file lives in ``src/``, next to ``models/`` and ``input/``.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_EMBEDDING_DIR = _PROJECT_ROOT / "models" / "embedding"
_UNDIRECTED_GNN_DIR = _PROJECT_ROOT / "models" / "undirected_gnn"


class Config(BaseSettings):
    """Application settings: artifact locations and per-topic display styling.

    Path defaults are resolved relative to the repository root instead of an
    absolute path on one developer machine, so the app can start from any
    checkout location. All path fields stay plain ``str`` values because
    callers apply string operations to them.
    """

    # Local copy of the sentence-embedding model.
    EMBEDDING_MODEL_PATH: str = str(_EMBEDDING_DIR / "gte-multilingual-base")
    # Training split with precomputed embeddings.
    TRAINING_DATA_PATH: str = str(_PROJECT_ROOT / "input" / "train_data_with_embeddings.parquet")

    # Artifacts of the undirected-graph GNN variant.
    GNN_MODEL_PATH: str = str(_UNDIRECTED_GNN_DIR / "gnn_classifier_model.pth")
    GNN_GRAPH_DATA_PATH: str = str(_UNDIRECTED_GNN_DIR / "gnn_graph_data.pt")
    LABEL_MAPPING_PATH: str = str(_UNDIRECTED_GNN_DIR / "label_mapping.pt")
    TITLE_TO_ID_PATH: str = str(_UNDIRECTED_GNN_DIR / "title_to_id.pt")

    # Topic label -> icon name used when drawing graph nodes.
    ICON_MAPPING: dict[str, str] = {
        'Africa': 'language',
        'Americas': 'language',
        'Architecture': 'home',
        'Asia': 'language',
        'Biography': 'person',
        'Biology': 'science',
        'Business_and_economics': 'account_balance',
        'Chemistry': 'science',
        'Computing': 'laptop',
        'Earth_and_environment': 'cloud',
        'Education': 'description',
        'Engineering': 'factory',
        'Entertainment': 'campaign',
        'Europe': 'language',
        'Fashion': 'sell',
        'Films': 'monitor',
        'Food_and_drink': 'store',
        'Geographical': 'place',
        'History': 'inventory',
        'Internet_culture': 'alternate_email',
        'Libraries_&_Information': 'folder',
        'Linguistics': 'translate',
        'Literature': 'description',
        'Mathematics': 'analytics',
        'Media': 'chat',
        'Medicine_&_Health': 'medical_services',
        'Military_and_warfare': 'flag',
        'Music': 'campaign',
        'Oceania': 'language',
        'Performing_arts': 'group',
        'Philosophy_and_religion': 'assured_workload',
        'Physics': 'science',
        'Politics_and_government': 'assured_workload',
        'STEM': 'science',
        'Society': 'group',
        'Space': 'cloud',
        'Sports': 'group',
        'Technology': 'smartphone',
        'Television': 'monitor',
        'Transportation': 'directions_car',
        'Video_games': 'smartphone',
        'Visual_arts': 'description'
    }

    # Topic label -> hex color; topics are grouped into six color families.
    COLOR_MAPPING: dict[str, str] = {
        # STEM & Natural Sciences -> Emerald (#06d6a0)
        'Biology': '#06d6a0',
        'Chemistry': '#06d6a0',
        'Earth_and_environment': '#06d6a0',
        'Mathematics': '#06d6a0',
        'Physics': '#06d6a0',
        'STEM': '#06d6a0',
        'Space': '#06d6a0',

        # Geography & Places -> Ocean Blue (#118ab2)
        'Africa': '#118ab2',
        'Americas': '#118ab2',
        'Asia': '#118ab2',
        'Europe': '#118ab2',
        'Oceania': '#118ab2',
        'Geographical': '#118ab2',

        # Arts, Entertainment & Culture -> Bubblegum Pink (#ef476f)
        'Entertainment': '#ef476f',
        'Fashion': '#ef476f',
        'Films': '#ef476f',
        'Music': '#ef476f',
        'Performing_arts': '#ef476f',
        'Television': '#ef476f',
        'Visual_arts': '#ef476f',
        'Literature': '#ef476f',

        # Tech, Engineering & Infrastructure -> Dark Teal (#073b4c)
        'Architecture': '#073b4c',
        'Computing': '#073b4c',
        'Engineering': '#073b4c',
        'Internet_culture': '#073b4c',
        'Technology': '#073b4c',
        'Transportation': '#073b4c',
        'Video_games': '#073b4c',

        # Society, Humanities & Lifestyle -> Coral Glow (#f78c6b)
        'Biography': '#f78c6b',
        'Food_and_drink': '#f78c6b',
        'Linguistics': '#f78c6b',
        'Media': '#f78c6b',
        'Medicine_&_Health': '#f78c6b',
        'Society': '#f78c6b',
        'Sports': '#f78c6b',

        # Institutions, History & Governance -> Royal Gold (#ffd166)
        'Business_and_economics': '#ffd166',
        'Education': '#ffd166',
        'History': '#ffd166',
        'Libraries_&_Information': '#ffd166',
        'Military_and_warfare': '#ffd166',
        'Philosophy_and_religion': '#ffd166',
        'Politics_and_government': '#ffd166',
    }


# Module-level settings instance imported by the rest of the app.
config = Config()
|
src/demo.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import networkx as nx
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from src.config import config
|
| 8 |
+
from src.embedding import Embedder
|
| 9 |
+
from src.utils import (create_graph_from_df, gather_neighbors,
|
| 10 |
+
get_unique_article_titles)
|
| 11 |
+
from src.heuristic import predict_topic_nth_degree
|
| 12 |
+
from src.gnn import GNNClassifier, load_data, infer_new_node
|
| 13 |
+
from st_link_analysis import EdgeStyle, NodeStyle, st_link_analysis
|
| 14 |
+
from src.visualization import get_edge_styles, get_node_styles
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
st.set_page_config(
    page_title="Semantic Article Graph", layout="wide", initial_sidebar_state="expanded"
)


def _load_gnn_bundle(version: str, weights_path: str):
    """Load one GNN variant: its graph artifacts plus the trained classifier.

    Args:
        version: Variant name understood by ``load_data`` ("undirected" or "no_edge").
        weights_path: Path to the ``state_dict`` checkpoint for that variant.

    Returns:
        Tuple ``(model, graph_data, title_to_id, label_mapping)``.
    """
    graph_data, title_to_id, label_mapping = load_data(version=version)
    model = GNNClassifier(
        input_dim=768,
        hidden_dim=128,
        layers=2,
        output_dim=len(label_mapping),
        dropout_rate=0.5,
    )
    # This app always runs inference on CPU, so map the checkpoint onto CPU as
    # well; otherwise a checkpoint saved from CUDA tensors cannot be loaded on
    # a host without a GPU.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
    return model, graph_data, title_to_id, label_mapping


# One-time, per-session initialisation. Everything expensive is cached in
# st.session_state so Streamlit reruns do not reload models or data.
if "setup_complete" not in st.session_state:
    loader = st.empty()

    with loader.container():
        st.subheader("🚀 Starting...")

        with st.status("Loading...", expanded=True) as status:
            st.write("Initializing Embedding Model...")
            embedder = Embedder(path=config.EMBEDDING_MODEL_PATH)
            st.session_state.embedder = embedder

            st.write("Initializing GNN Model (Undirected)...")
            (
                st.session_state.undirected_gnn_model,
                st.session_state.undirected_graph_data,
                st.session_state.undirected_title_to_id,
                st.session_state.undirected_label_mapping,
            ) = _load_gnn_bundle("undirected", config.GNN_MODEL_PATH)

            st.write("Initializing GNN Model (No Edges)...")
            (
                st.session_state.no_edge_gnn_model,
                st.session_state.no_edge_graph_data,
                st.session_state.no_edge_title_to_id,
                st.session_state.no_edge_label_mapping,
            ) = _load_gnn_bundle(
                "no_edge",
                config.GNN_MODEL_PATH.replace("undirected_gnn", "no_edge_gnn"),
            )

            st.write("Reading training data...")
            training_data = pd.read_parquet(config.TRAINING_DATA_PATH)
            # NOTE(review): eval() executes whatever expression is stored in the
            # column. Acceptable only while the parquet file is a trusted project
            # artifact; consider ast.literal_eval if the cells are plain list
            # literals (confirm the stored format first).
            training_data["embedding"] = training_data["embedding"].apply(lambda x: eval(x))
            st.session_state.training_data = training_data

            st.write("Creating graph for visualization...")
            directed_graph = create_graph_from_df(training_data, directed=True)
            st.session_state.directed_graph = directed_graph
            undirected_graph = create_graph_from_df(training_data, directed=False)
            st.session_state.undirected_graph = undirected_graph

            status.update(label="Done!", state="complete", expanded=False)

        # Short pause so the "Done!" status is visible before the loader clears.
        time.sleep(0.5)

    loader.empty()
    st.session_state.setup_complete = True


node_styles = get_node_styles()
edge_styles = get_edge_styles()

# Titles offered in the "References" multiselect; computed once per session.
if "existing_nodes" not in st.session_state:
    article_titles = get_unique_article_titles(st.session_state.training_data)
    st.session_state.existing_nodes = article_titles

CLASSES = list(config.ICON_MAPPING.keys())
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_dummy_probabilities():
    """Return a placeholder score table: random probabilities over CLASSES.

    The scores are one Dirichlet sample (so they sum to 1). The result is a
    DataFrame with columns "Class" and "Score", sorted by score descending
    and limited to the ten best rows.
    """
    class_count = len(CLASSES)
    sample = np.random.dirichlet(np.ones(class_count), size=1)
    scores = sample[0]
    table = pd.DataFrame({"Class": CLASSES, "Score": scores})
    ranked = table.sort_values(by="Score", ascending=False)
    return ranked.head(10)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
st.title("📄 Semantic Article Graph")
st.markdown("---")

col_input, col_vis = st.columns([1, 2], gap="large")

# ---- Left column: new-node details and method configuration ----
with col_input:
    st.subheader("1. New Node Details")

    new_title = st.text_input("Node Title", placeholder="e.g., Istanbul")
    new_content = st.text_area(
        "Content", height=150, placeholder="Paste content here..."
    )

    references = st.multiselect(
        "References (Select existing nodes)",
        options=st.session_state.existing_nodes,
        help="Search and select multiple papers this node cites.",
    )

    st.markdown("---")
    st.subheader("2. Methodology Configuration")

    method = st.selectbox(
        "Select Classification Method",
        ["GNN (Graph Neural Network)", "Rule-Based"],
    )

    # Defaults used when the selected method does not expose the option.
    model_params = {}
    is_directed = False
    max_depth = 2

    if method == "GNN (Graph Neural Network)":
        use_edges = st.checkbox("Use Graph Edges", value=True)

    elif method == "Rule-Based":
        max_depth = st.slider("Max Depth", 1, 3, 1)
        is_weighted = st.checkbox("Apply Weights", value=True)
        is_directed = st.checkbox("Use Directed Graph", value=False)
        model_params = {"max_depth": max_depth, "is_weighted": is_weighted}
    else:
        st.warning("Please select a valid method.")

    st.markdown("---")

    run_inference = st.button(
        "Add Node & Run Inference", type="primary", width="stretch"
    )


# ---- Right column: neighbourhood graph and classification results ----
with col_vis:
    if run_inference:
        if not new_title:
            st.error("Please enter a title for the node.")
        else:
            st.subheader("🌐 Graph Neighborhood (k-hop)")

            with st.spinner("Updating Graph Topology..."):
                time.sleep(1)

            graph_container = st.container(border=True)
            with graph_container:
                graph = (
                    st.session_state.directed_graph
                    if is_directed
                    else st.session_state.undirected_graph
                )
                elements = gather_neighbors(
                    graph, new_title, references, depth=max_depth
                )
                st_link_analysis(elements, "cose", node_styles, edge_styles)
                st.caption(
                    f"Visualizing neighbors for: **{new_title}** with {len(references)} connections."
                )

            st.markdown("---")
            st.subheader("📊 Classification Results")

            with st.spinner(f"Running {method}..."):
                time.sleep(1.5)
                embedding = st.session_state.embedder.generate_embedding(new_content)
                if method == "GNN (Graph Neural Network)":
                    base_data = st.session_state.undirected_graph_data if use_edges else st.session_state.no_edge_graph_data
                    title_to_id = st.session_state.undirected_title_to_id if use_edges else st.session_state.no_edge_title_to_id
                    label_mapping = st.session_state.undirected_label_mapping if use_edges else st.session_state.no_edge_label_mapping
                    model = st.session_state.undirected_gnn_model if use_edges else st.session_state.no_edge_gnn_model
                    df_results = infer_new_node(
                        base_data=base_data,
                        model=model,
                        new_embedding=embedding,
                        referenced_titles=references,
                        title_to_id=title_to_id,
                        label_mapping=label_mapping,
                        device=torch.device("cpu"),
                        make_undirected_for_new_node=not is_directed,
                        use_edges=use_edges,
                    )
                elif method == "Rule-Based":
                    graph = (
                        st.session_state.directed_graph
                        if is_directed
                        else st.session_state.undirected_graph
                    )
                    df_results = predict_topic_nth_degree(
                        new_article_title=new_title,
                        new_article_embedding=embedding,
                        edges=references,
                        G=graph,
                        decay_factor=1.0,
                        **model_params,
                    )
                else:
                    st.error("Invalid method selected.")
                    st.stop()

            # The rule-based predictor returns None when no selected reference
            # exists in the graph (e.g. the user selected nothing). Report that
            # instead of crashing on `.iloc[0]`.
            if df_results is None or df_results.empty:
                st.warning(
                    "No prediction could be made: select at least one reference "
                    "that exists in the graph and try again."
                )
            else:
                if method == "Rule-Based":
                    # Rule-based scores are accumulated votes, not probabilities.
                    # Convert them to shares so the percentage label and the
                    # 0..1 progress bar below are meaningful. Skipped when a
                    # score is negative or the total is not positive, because
                    # shares would then be ill-defined.
                    total_score = df_results["Score"].sum()
                    if total_score > 0 and (df_results["Score"] >= 0).all():
                        df_results = df_results.copy()
                        df_results["Score"] = df_results["Score"] / total_score

                top_class = df_results.iloc[0]
                st.success(
                    f"**Predicted Class:** {top_class['Class']} ({top_class['Score']:.2%})"
                )

                st.dataframe(
                    df_results,
                    column_config={
                        "Class": "Class Name",
                        "Score": st.column_config.ProgressColumn(
                            "Confidence",
                            help="The model's confidence score",
                            format="%.2f",
                            min_value=0,
                            max_value=1,
                        ),
                    },
                    hide_index=True,
                    width="stretch",
                )

    else:
        st.info(
            "👈 Enter node details on the left and click 'Add' to see the graph and predictions."
        )
        st.markdown(
            """
            <div style="height: 600px; border: 2px dashed #ccc; border-radius: 10px;
            display: flex; align-items: center; justify-content: center; color: #ccc;">
            Waiting for input...
            </div>
            """,
            unsafe_allow_html=True,
        )
|
src/embedding.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import AutoModel, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Embedder:
    """Produce L2-normalised text embeddings from a Hugging Face encoder."""

    def __init__(self, path):
        """Load the tokenizer and model found at ``path``.

        The model is placed on CUDA when available, otherwise on CPU.
        """
        self.model_name_or_path = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path, trust_remote_code=True
        )
        self.model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.model.eval()

    def generate_embedding(self, text, dimension=768):
        """Embed ``text`` and return the normalised vector(s) as a NumPy array.

        Args:
            text: A string (or list of strings) passed to the tokenizer.
            dimension: Number of leading embedding features to keep.

        Returns:
            ``numpy.ndarray`` built from the first-token hidden state, cut to
            ``dimension`` features and L2-normalised per row; size-1 axes are
            squeezed, so a single text yields a 1-D vector.
        """
        inputs = self.tokenizer(
            text, max_length=8192, padding=True, truncation=True, return_tensors="pt"
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Take the first token of every sequence and truncate along the FEATURE
        # axis. The previous `[:, 0][:dimension]` sliced the batch axis instead,
        # so features were never truncated.
        embeddings = outputs.last_hidden_state[:, 0, :dimension]
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        return normalized_embeddings.squeeze().cpu().numpy()
|
src/gnn.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch_geometric.nn import GCNConv
|
| 4 |
+
from torch_geometric.data import Data
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from src.config import config
|
| 8 |
+
|
| 9 |
+
class GNNClassifier(torch.nn.Module):
    """GCN node classifier with two or three graph-convolution layers.

    ``forward`` returns raw logits (no softmax), one row per node.
    """

    def __init__(self, input_dim, hidden_dim, layers, output_dim, dropout_rate=0.5):
        """Build the convolution stack.

        Args:
            input_dim: Size of each node feature vector.
            hidden_dim: Width of the hidden representation(s).
            layers: Number of GCN layers; must be 2 or 3.
            output_dim: Number of classes (logits per node).
            dropout_rate: Dropout probability applied between layers.

        Raises:
            ValueError: If ``layers`` is not 2 or 3.
        """
        super().__init__()
        # Fail fast: previously an unsupported value created a model without
        # any conv layers, which only failed later inside forward().
        if layers not in (2, 3):
            raise ValueError(f"layers must be 2 or 3, got {layers!r}")

        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.output_dim = output_dim

        # Two layers by default to limit over-smoothing. Attribute names
        # (conv1/conv2/conv3) are part of the saved state_dict keys.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        if layers == 2:
            self.conv2 = GCNConv(hidden_dim, output_dim)
        else:
            self.conv2 = GCNConv(hidden_dim, hidden_dim)
            self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        """Return per-node logits for ``data`` (reads ``data.x`` and ``data.edge_index``)."""
        x, edge_index = data.x, data.edge_index

        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.relu(x)

        # Dropout is active only in training mode (self.training).
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Layer 2
        x = self.conv2(x, edge_index)

        if self.layers == 3:
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
            x = self.conv3(x, edge_index)

        return x
|
| 47 |
+
|
| 48 |
+
def load_data(version: str = "undirected"):
    """Load the saved graph, title->id map and label map for one model variant.

    Args:
        version: "undirected" uses the configured paths as-is; "no_edge" uses
            the same paths with "undirected_gnn" replaced by "no_edge_gnn".

    Returns:
        Tuple ``(graph_data, title_to_id, label_mapping)``.

    Raises:
        ValueError: If ``version`` is not one of the two supported names.
    """
    if version not in ("undirected", "no_edge"):
        raise ValueError(f"Unknown version: {version}")

    def _variant_path(path):
        # The no-edge artifacts live in a sibling directory of the undirected ones.
        if version == "no_edge":
            return path.replace("undirected_gnn", "no_edge_gnn")
        return path

    # Load onto CPU so artifacts saved from CUDA tensors can still be read on
    # hosts without a GPU; infer_new_node moves the graph to the requested
    # device itself.
    cpu = torch.device("cpu")
    graph_data = torch.load(_variant_path(config.GNN_GRAPH_DATA_PATH), map_location=cpu)
    title_to_id = torch.load(_variant_path(config.TITLE_TO_ID_PATH), map_location=cpu)
    label_mapping = torch.load(_variant_path(config.LABEL_MAPPING_PATH), map_location=cpu)

    return graph_data, title_to_id, label_mapping
|
| 62 |
+
|
| 63 |
+
def infer_new_node(
    base_data: Data,
    model: torch.nn.Module,
    new_embedding,  # shape (768,) list/np array/torch
    referenced_titles: list[str],  # titles the user selected
    title_to_id: dict[str, int],
    label_mapping: dict[str, int],
    device: torch.device,
    make_undirected_for_new_node: bool = True,
    use_edges: bool = True,
) -> pd.DataFrame:
    """Classify a new node by attaching it to the existing graph.

    The new node's feature row is appended to ``base_data.x``; optionally,
    edges to the referenced nodes are appended to ``base_data.edge_index``.
    The model is then run on that augmented graph and the class
    probabilities of the new node are returned.

    Args:
        base_data: Existing graph (``x`` and ``edge_index`` are read).
        model: Classifier returning raw logits, one row per node.
        new_embedding: Feature vector of the new node (list, NumPy array or tensor).
        referenced_titles: Titles the new node references; titles missing
            from ``title_to_id`` are skipped.
        title_to_id: Maps a title to its node index in ``base_data``.
        label_mapping: Maps a class label to its output index.
        device: Device on which the model and data are placed.
        make_undirected_for_new_node: Also add the reverse edge new -> referenced.
        use_edges: When False, no edges are added for the new node.

    Returns:
        DataFrame with columns "Class" and "Score" (softmax probability),
        sorted by score descending.
    """
    model.eval()

    # Move model and graph to the target device.
    model = model.to(device)
    base_data = base_data.to(device)

    # --- 1) Append the new node's feature row ---
    x_old = base_data.x
    # as_tensor accepts lists, NumPy arrays and tensors alike; torch.tensor()
    # on an existing tensor emits a copy-construct warning.
    new_x = torch.as_tensor(new_embedding, dtype=x_old.dtype, device=device).reshape(1, -1)
    x = torch.cat([x_old, new_x], dim=0)

    new_id = x.size(0) - 1

    # --- 2) Build the edges that attach the new node ---
    edge_index = base_data.edge_index
    if use_edges:
        src_list = []
        tgt_list = []
        for title in referenced_titles:
            if title not in title_to_id:
                continue
            old_id = title_to_id[title]

            # Edge old -> new: lets the referenced node influence the new
            # node within one hop.
            src_list.append(old_id)
            tgt_list.append(new_id)

            # Optional reverse edge new -> old to keep the attachment symmetric.
            if make_undirected_for_new_node:
                src_list.append(new_id)
                tgt_list.append(old_id)

        # With no usable references the node stays isolated; the existing
        # edge_index is reused unchanged.
        if src_list:
            new_edges = torch.tensor([src_list, tgt_list], dtype=torch.long, device=device)
            edge_index = torch.cat([edge_index, new_edges], dim=1)

    # --- 3) Run inference on the augmented graph ---
    data_aug = Data(x=x, edge_index=edge_index).to(device)

    with torch.no_grad():
        out = model(data_aug)  # raw logits
        log_probs = F.log_softmax(out, dim=1)[new_id]  # new node only

    probs = log_probs.exp().detach().cpu()  # log-probs -> probs

    # Output index -> class label.
    inv_label_mapping = {v: k for k, v in label_mapping.items()}

    result_df = pd.DataFrame(
        [(inv_label_mapping[i], prob.item()) for i, prob in enumerate(probs)],
        columns=["Class", "Score"],
    ).sort_values(by="Score", ascending=False)

    return result_df
|
| 138 |
+
|
| 139 |
+
if __name__ == "__main__":
    # Manual smoke test: classify one hand-written article with the
    # undirected GNN variant.
    from src.embedding import Embedder

    graph_data, title_to_id, label_mapping = load_data()

    model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
    # Use the configured checkpoint/model locations instead of absolute paths
    # that only exist on one developer machine.
    model.load_state_dict(torch.load(config.GNN_MODEL_PATH, map_location=torch.device("cpu")))

    new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
    embedder = Embedder(path=config.EMBEDDING_MODEL_PATH)
    new_embedding = embedder.generate_embedding(new_node_content)
    referenced_titles = ["forum istanbul", "istanbul film festivali", "akıllı şehir"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = infer_new_node(
        base_data=graph_data,
        model=model,
        new_embedding=new_embedding,
        referenced_titles=referenced_titles,
        title_to_id=title_to_id,
        label_mapping=label_mapping,
        device=device,
        make_undirected_for_new_node=True,
    )

    print("Prediction Results for New Node:")
    print(result)
|
src/heuristic.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 3 |
+
|
| 4 |
+
from typing import List, Any, Optional
|
| 5 |
+
from collections import defaultdict, deque
|
| 6 |
+
|
| 7 |
+
def predict_topic_nth_degree(
    new_article_title: str,
    new_article_embedding: List[float],
    edges: List[str],
    G: Any,
    max_depth: int = 1,
    is_weighted: bool = False,
    decay_factor: float = 1.0,
) -> Optional[pd.DataFrame]:
    """
    Score topics for a new article by voting over neighbors up to n hops away.

    Args:
        new_article_title: Title of the new article (never counted as a voter).
        new_article_embedding: Embedding of the new article; used only when
            is_weighted is True.
        edges: Titles referenced by the new article; those not in G are ignored.
        G: Graph supporting `in`, `G.nodes[n]` (dict with "label" and, for
            weighted voting, "embedding") and `G.neighbors(n)`.
        max_depth: How many hops to traverse (1 = direct neighbors, 2 = neighbors of neighbors).
        is_weighted: If True a vote is worth the cosine similarity between the
            new article and the voter; otherwise every vote is worth 1.0.
        decay_factor: Multiplier for distance. 1.0 = no decay.
                      0.5 means a neighbor at depth 2 has half the voting power of depth 1.

    Returns:
        DataFrame with columns "Class" and "Score" (accumulated, unnormalised
        votes) sorted by score descending, or None when no referenced title is
        in G or no visited node carries a label.
    """

    # 1. Setup BFS
    # Queue stores: (current_node_name, current_depth)
    queue = deque()

    # Visited set avoids cycles and counting the same node twice.
    visited = {new_article_title}

    # 2. Initialize BFS with the "virtual" first hop.
    # The new article is not in G, so its references are enqueued by hand.
    for ref in edges:
        if ref in G and ref not in visited:
            visited.add(ref)
            queue.append((ref, 1))  # Depth 1

    if not queue:
        return None

    topic_scores = defaultdict(float)

    # 3. Process BFS
    while queue:
        current_node, current_depth = queue.popleft()

        # --- Score calculation ---
        node_data = G.nodes[current_node]
        topic = node_data.get("label")

        if topic:
            # Base weight of this vote.
            if is_weighted:
                neighbor_embedding = node_data["embedding"]
                base_score = cosine_similarity(
                    [new_article_embedding], [neighbor_embedding]
                )[0][0]
            else:
                base_score = 1.0

            # Distance decay: score * decay ** (depth - 1), i.e. depth 1 is
            # undamped and each further hop multiplies by decay_factor once.
            weighted_score = base_score * (decay_factor ** (current_depth - 1))

            topic_scores[topic] += weighted_score

        # --- Expand to the next level if within the limit ---
        if current_depth < max_depth:
            for neighbor in G.neighbors(current_node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, current_depth + 1))

    # 4. Rank topics
    if not topic_scores:
        return None

    result_df = pd.DataFrame(
        list(topic_scores.items()), columns=["Class", "Score"]
    ).sort_values(by="Score", ascending=False)
    return result_df
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
# Spiral parameters chosen by the user.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Position along the spiral in [0, 1]; it also serves as the radius.
indices = np.linspace(0, 1, num_points)
radius = indices
theta = 2 * np.pi * num_turns * indices

# Polar -> Cartesian.
x = radius * np.cos(theta)
y = radius * np.sin(theta)

df = pd.DataFrame(
    {
        "x": x,
        "y": y,
        "idx": indices,
        "rand": np.random.randn(num_points),
    }
)

# Scatter plot: colour follows the position, size follows the random column.
spiral_chart = (
    alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    )
)
st.altair_chart(spiral_chart)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_unique_article_titles(df: pd.DataFrame) -> list[str]:
    """Return the distinct values of ``article_title_processed`` in ascending order."""
    titles = df["article_title_processed"].unique().tolist()
    titles.sort()
    return titles
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_graph_from_df(df, directed: bool = False) -> nx.Graph:
    """Build the article graph from a DataFrame.

    Every row becomes a node keyed by ``article_title_processed`` with the
    attributes ``label`` (from ``predicted_topic``) and ``embedding``. An edge
    article -> reference is added for every entry of ``links_processed`` that
    is itself a node; self-references are skipped.

    Args:
        df: Frame with columns ``article_title_processed``, ``predicted_topic``,
            ``embedding`` and ``links_processed`` (a string that evaluates to
            an iterable of titles).
        directed: If True return an ``nx.DiGraph`` whose edges point from the
            citing article to the referenced one; otherwise an undirected
            ``nx.Graph``.

    Returns:
        The populated graph (``nx.DiGraph`` is a subclass of ``nx.Graph``).
    """
    # The flag used to be ignored: the graph was always an undirected
    # nx.Graph, so "directed" graphs behaved exactly like undirected ones.
    G = nx.DiGraph() if directed else nx.Graph()

    # Pass 1: add all nodes first so edge creation can test membership.
    for _, row in df.iterrows():
        G.add_node(
            row["article_title_processed"],
            label=row["predicted_topic"],
            embedding=row["embedding"],
        )

    # Pass 2: add edges between known nodes.
    for _, row in df.iterrows():
        node_title = row["article_title_processed"]
        # NOTE(review): eval() executes whatever expression the cell contains.
        # Acceptable only for trusted project data; consider ast.literal_eval
        # if the cells are plain list literals (confirm the stored format).
        references = eval(row["links_processed"])

        for ref in references:
            if ref in G and ref != node_title:
                # In the undirected graph this single call already connects
                # both endpoints, so no reverse edge is needed.
                G.add_edge(node_title, ref)

    return G
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def gather_neighbors(
    graph: nx.DiGraph, node_title: str, references: list[str], depth: int = 1
):
    """Return the visualizer payload for a new node attached to its references.

    Works on a copy of ``graph``: adds ``node_title`` as a node, links it to
    every reference already present in the graph (self-references skipped),
    and returns the neighborhood of the new node up to ``depth`` hops as
    produced by ``get_neighbors_for_visualizer``.
    """
    augmented = graph.copy()
    augmented.add_node(node_title)

    # Only link to references that exist; the membership test is unaffected by
    # the edges being added because both endpoints are already nodes.
    augmented.add_edges_from(
        (node_title, ref)
        for ref in references
        if ref in augmented and ref != node_title
    )

    return get_neighbors_for_visualizer(augmented, node_title, depth=depth)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
    """
    Describe the neighborhood of a node in Cytoscape-style element format.

    Args:
        graph (nx.Graph): The source NetworkX graph.
        start_node: The title/ID of the node to start from.
        depth (int): Radius (number of hops) of the neighborhood.

    Returns:
        dict: {"nodes": [...], "edges": [...]}; both lists are empty when
        start_node is not in the graph. Node ids are 1-based integers and
        edge ids continue the same counter, so all ids are unique.
    """
    if start_node not in graph:
        return {"nodes": [], "edges": []}

    ego = nx.ego_graph(graph, start_node, radius=depth)

    # Title -> 1-based integer id, in the subgraph's node iteration order.
    id_of = {title: idx for idx, title in enumerate(ego.nodes(), start=1)}

    # Nodes: the embedding attribute is intentionally left out of the payload.
    nodes_data = [
        {
            "data": {
                "id": node_id,
                "label": ego.nodes[title].get("label", "Unknown"),
                "name": str(title),
            }
        }
        for title, node_id in id_of.items()
    ]

    # Edges: ids start right after the last node id; "CITES" is the label
    # used when an edge carries none.
    first_edge_id = len(id_of) + 1
    edges_data = [
        {
            "data": {
                "id": edge_id,
                "label": ego.edges[u, v].get("label", "CITES"),
                "source": id_of[u],
                "target": id_of[v],
            }
        }
        for edge_id, (u, v) in enumerate(ego.edges(), start=first_edge_id)
    ]

    return {"nodes": nodes_data, "edges": edges_data}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
    # Manual smoke test: attach a dummy article to three references and print
    # its 2-hop neighborhood.
    from pathlib import Path

    # Resolve the training data relative to the repository root
    # (<repo>/src/utils.py -> <repo>/input/...) instead of an absolute path
    # that only exists on one developer machine.
    data_path = (
        Path(__file__).resolve().parent.parent
        / "input"
        / "train_data_with_embeddings.parquet"
    )
    data = pd.read_parquet(data_path)
    graph = create_graph_from_df(data)

    test_title = "Sample Article Title"
    test_references = ["finansal matematik", "genel yapay zekâ", "andrej karpathy"]

    neighbors = gather_neighbors(graph, test_title, test_references, depth=2)

    print(f"Neighbors of '{test_title}': {neighbors}")
|
src/visualization.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from st_link_analysis import EdgeStyle, NodeStyle
|
| 2 |
+
from src.config import config
|
| 3 |
+
|
| 4 |
+
def get_node_styles() -> list[NodeStyle]:
    """Return one NodeStyle per class listed in ``config.ICON_MAPPING``.

    The colour comes from ``config.COLOR_MAPPING`` and falls back to gray
    ("#888888") for classes without a configured colour.
    """
    return [
        NodeStyle(
            label=class_name,
            color=config.COLOR_MAPPING.get(class_name, "#888888"),
            icon=config.ICON_MAPPING.get(class_name, None),
        )
        for class_name in config.ICON_MAPPING.keys()
    ]
|
| 15 |
+
|
| 16 |
+
def get_edge_styles() -> list[EdgeStyle]:
    """Return the edge styles for the visualizer: a single style labelled "CITES"."""
    cites_style = EdgeStyle(label="CITES")
    return [cites_style]
|