Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
- .gitignore +132 -0
- README.md +314 -0
- conf.py +68 -0
- deeprobust/__init__.py +4 -0
- deeprobust/graph/data/attacked_data.py +218 -0
- deeprobust/graph/defense/adv_training.py +57 -0
- deeprobust/graph/defense/chebnet.py +215 -0
- deeprobust/graph/defense/data/processed/pre_filter.pt +3 -0
- deeprobust/graph/defense/gat.py +222 -0
- deeprobust/graph/defense/gcn.py +377 -0
- deeprobust/graph/defense/gcn_cgscore.py +413 -0
- deeprobust/graph/defense/gcn_guard.py +411 -0
- deeprobust/graph/defense/gcn_preprocess.py +488 -0
- deeprobust/graph/defense/median_gcn.py +296 -0
- deeprobust/graph/defense/node_embedding.py +538 -0
- deeprobust/graph/defense/prognn.py +314 -0
- deeprobust/graph/defense/r_gcn.py +367 -0
- deeprobust/graph/defense/r_gcn.py.backup +200 -0
- deeprobust/graph/defense/sgc.py +196 -0
- deeprobust/graph/defense_pyg/airgnn.py +186 -0
- deeprobust/graph/defense_pyg/sage.py +157 -0
- deeprobust/graph/global_attack/__init__.py +18 -0
- deeprobust/graph/global_attack/dice.py +123 -0
- deeprobust/graph/global_attack/ig_attack.py.backup +192 -0
- deeprobust/graph/global_attack/mettack.py +572 -0
- deeprobust/graph/global_attack/nipa.py +285 -0
- deeprobust/graph/global_attack/random_attack.py +144 -0
- deeprobust/graph/global_attack/topology_attack.py +323 -0
- deeprobust/graph/rl/nipa_config.py +59 -0
- deeprobust/graph/rl/nipa_nstep_replay_mem.py +56 -0
- deeprobust/graph/rl/nipa_q_net_node.py +242 -0
- deeprobust/graph/rl/nstep_replay_mem.py +157 -0
- deeprobust/graph/rl/rl_s2v_env.py +256 -0
- deeprobust/graph/targeted_attack/rl_s2v.py +262 -0
- deeprobust/graph/visualization.py +91 -0
- deeprobust/image/attack/BPDA.py +105 -0
- deeprobust/image/attack/Universal.py +151 -0
- deeprobust/image/attack/YOPOpgd.py +113 -0
- deeprobust/image/attack/base_attack.py +88 -0
- deeprobust/image/attack/l2_attack.py +174 -0
- deeprobust/image/attack/lbfgs.py +203 -0
- deeprobust/image/attack/pgd.py +150 -0
- deeprobust/image/config.py +69 -0
- deeprobust/image/defense/LIDclassifier.py +145 -0
- docs/graph/attack.rst +109 -0
- docs/graph/data.rst +188 -0
- docs/graph/pyg.rst +155 -0
- docs/image/example.rst +58 -0
- docs/notes/installation.rst +23 -0
- requirements.txt +14 -0
.gitignore
ADDED
@@ -0,0 +1,132 @@
# Auxiliary file on MacOS
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
README.md
ADDED
@@ -0,0 +1,314 @@
[contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
[contributing-url]: https://github.com/rusty1s/pytorch_geometric/blob/master/CONTRIBUTING.md

<p align="center">
<img center src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/Deeprobust.png" width = "450" alt="logo">
</p>

---------------------
<!--
<a href="https://github.com/DSE-MSU/DeepRobust/stargazers"><img alt="GitHub stars" src="https://img.shields.io/github/stars/DSE-MSU/DeepRobust"></a> <a href="https://github.com/DSE-MSU/DeepRobust/network/members" ><img alt="GitHub forks" src="https://img.shields.io/github/forks/DSE-MSU/DeepRobust">
</a>
-->

<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/DSE-MSU/DeepRobust"> <a href="https://github.com/DSE-MSU/DeepRobust/issues"> <img alt="GitHub issues" src="https://img.shields.io/github/issues/DSE-MSU/DeepRobust"></a> <img alt="GitHub" src="https://img.shields.io/github/license/DSE-MSU/DeepRobust">
[![Contributing][contributing-image]][contributing-url]
[](https://twitter.com/intent/tweet?text=Build%20your%20robust%20machine%20learning%20models%20with%20DeepRobust%20in%2060%20seconds&url=https://github.com/DSE-MSU/DeepRobust&via=dse_msu&hashtags=MachineLearning,DeepLearning,secruity,data,developers)

<!-- <img alt="GitHub top language" src="https://img.shields.io/github/languages/top/DSE-MSU/DeepRobust"> -->

<!--
<div align=center><img src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversarial.png" width="500"/></div>
<div align=center><img src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/graph_attack_example.png" width="00" /></div>
-->
**[Documentation](https://deeprobust.readthedocs.io/en/latest/)** | **[Paper](https://arxiv.org/abs/2005.06149)** | **[Samples](https://github.com/DSE-MSU/DeepRobust/tree/master/examples)**

[AAAI 2021] DeepRobust is a PyTorch adversarial library for attack and defense methods on images and graphs.
* If you are new to DeepRobust, we highly recommend reading the [documentation page](https://deeprobust.readthedocs.io/en/latest/) or the rest of this README to learn how to use it.
* If you have any questions or suggestions regarding this library, feel free to open an issue [here](https://github.com/DSE-MSU/DeepRobust/issues). We will reply as soon as possible :)

<p float="left">
  <img src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/adversarial.png" width="430" />
  <img src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/graph_attack_example.png" width="380" />
</p>

**The list of included algorithms can be found in the [[Image Package]](https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/image) and the [[Graph Package]](https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph).**

[Environment & Installation](#environment)

Usage
* [Image Attack and Defense](#image-attack-and-defense)
* [Graph Attack and Defense](#graph-attack-and-defense)

[Acknowledgement](#acknowledgement)

For more details about attacks and defenses, you can read the following papers.
* [Adversarial Attacks and Defenses on Graphs: A Review, A Tool and Empirical Studies](https://arxiv.org/abs/2003.00653)
* [Adversarial Attacks and Defenses in Images, Graphs and Text: A Review](https://arxiv.org/pdf/1909.08072.pdf)

If our work could help your research, please cite:
[DeepRobust: A PyTorch Library for Adversarial Attacks and Defenses](https://arxiv.org/abs/2005.06149)
```
@article{li2020deeprobust,
title={Deeprobust: A pytorch library for adversarial attacks and defenses},
author={Li, Yaxin and Jin, Wei and Xu, Han and Tang, Jiliang},
journal={arXiv preprint arXiv:2005.06149},
year={2020}
}
```

# Changelog
* [11/2023] Try <span style="color:red"> `git clone https://github.com/DSE-MSU/DeepRobust.git; cd DeepRobust; python setup_empty.py install` </span> to install DeepRobust directly without installing the dependency packages.
* [11/2023] DeepRobust 0.2.9 released. Please try `pip install deeprobust==0.2.9`. We have fixed the OOM issue of metattack on new PyTorch versions.
* [06/2023] We have added a backdoor attack [UGBA, WWW'23](https://arxiv.org/abs/2303.01263) to the graph package. We can now use UGBA to conduct an unnoticeable backdoor attack on large-scale graphs such as ogb-arxiv (see the example in [test_ugba.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_ugba.py))!
* [02/2023] DeepRobust 0.2.8 released. Please try `pip install deeprobust==0.2.8`! We have added a scalable attack [PRBCD, NeurIPS'21](https://arxiv.org/abs/2110.14038) to the graph package. We can now use PRBCD to attack large-scale graphs such as ogb-arxiv (see the example in [test_prbcd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prbcd.py))!
* [02/2023] Add a robust model [AirGNN, NeurIPS'21](https://proceedings.neurips.cc/paper/2021/file/50abc3e730e36b387ca8e02c26dc0a22-Paper.pdf) to the graph package. Try `python examples/graph/test_airgnn.py`! See details in [test_airgnn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_airgnn.py)
* [11/2022] DeepRobust 0.2.6 released. Please try `pip install deeprobust==0.2.6`! We have more updates coming. Please stay tuned!
* [11/2021] A subpackage that includes popular black-box attacks in the image domain has been released. Find it here: [Link](https://github.com/I-am-Bot/Black-Box-Attacks)
* [11/2021] DeepRobust 0.2.4 released. Please try `pip install deeprobust==0.2.4`!
* [10/2021] Add a scalable attack and MedianGCN. Thanks to [Jintang](https://github.com/EdisonLeeeee) for his contribution!
* [06/2021] [Image Package] Add preprocessing method: APE-GAN.
* [05/2021] DeepRobust is published at AAAI 2021. Check [here](https://ojs.aaai.org/index.php/AAAI/article/view/18017)!
* [05/2021] DeepRobust 0.2.2 released. Please try `pip install deeprobust==0.2.2`!
* [04/2021] [Image Package] Add support for ImageNet. See details in [test_ImageNet.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/image/test_ImageNet.py)
* [04/2021] [Graph Package] Add support for OGB datasets. See more details on the [tutorial page](https://deeprobust.readthedocs.io/en/latest/graph/pyg.html).
* [03/2021] [Graph Package] Added node embedding attack and victim models! See this [tutorial page](https://deeprobust.readthedocs.io/en/latest/graph/node_embedding.html).
* [02/2021] **[Graph Package] DeepRobust now provides tools for converting datasets between [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) and DeepRobust. See more details on the [tutorial page](https://deeprobust.readthedocs.io/en/latest/graph/pyg.html)!** DeepRobust now also supports GAT, ChebNet and SGC based on pyg; see details in [test_gat.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_gat.py), [test_chebnet.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_chebnet.py) and [test_sgc.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_sgc.py)
* [12/2020] DeepRobust can now be installed via pip! Try `pip install deeprobust`!
* [12/2020] [Graph Package] Add four more [datasets](https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph/#supported-datasets) and one defense algorithm. More details can be found [here](https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph/#defense-methods). More datasets and algorithms will be added later. Stay tuned :)
* [07/2020] Add the [documentation](https://deeprobust.readthedocs.io/en/latest/) page!
* [06/2020] Add docstrings to both the image and graph packages

# Basic Environment
* `python >= 3.6` (python 3.5 should also work)
* `pytorch >= 1.2.0`

See `setup.py` or `requirements.txt` for more information.

# Installation
## Install from pip
```
pip install deeprobust
```
## Install from source
```
git clone https://github.com/DSE-MSU/DeepRobust.git
cd DeepRobust
python setup.py install
```
If you find the dependencies hard to install, please try
```python setup_empty.py install``` (installs only deeprobust, without the other packages)

# Test Examples

```
python examples/image/test_PGD.py
python examples/image/test_pgdtraining.py
python examples/graph/test_gcn_jaccard.py --dataset cora
python examples/graph/test_mettack.py --dataset cora --ptb_rate 0.05
```

# Usage
## Image Attack and Defense
1. Train a model.

Example: train a simple CNN model on the MNIST dataset for 20 epochs on GPU.
```python
import deeprobust.image.netmodels.train_model as trainmodel
trainmodel.train('CNN', 'MNIST', 'cuda', 20)
```
The trained model will be saved in `deeprobust/trained_models/`.
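
The saved checkpoint can then be loaded back with plain PyTorch. A minimal sketch; the exact file name written under `deeprobust/trained_models/` is an assumption here, so check the directory after training:
```python
import torch
from deeprobust.image.netmodels.CNN import Net

model = Net()
# The checkpoint name below is an assumption; trainmodel.train names the file
# after the dataset, architecture and epoch count.
state = torch.load("deeprobust/trained_models/MNIST_CNN_epoch_20.pt", map_location='cpu')
model.load_state_dict(state)
model.eval()  # switch to inference mode before evaluating or attacking
```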

2. Instantiate attack and defense methods.

Example: generate adversarial examples with the PGD attack.
```python
from deeprobust.image.attack.pgd import PGD
from deeprobust.image.config import attack_params
from deeprobust.image.utils import download_model
import torch
import deeprobust.image.netmodels.resnet as resnet
from torchvision import transforms, datasets

URL = "https://github.com/I-am-Bot/deeprobust_model/raw/master/CIFAR10_ResNet18_epoch_20.pt"
download_model(URL, "$MODEL_PATH$")

model = resnet.ResNet18().to('cuda')
model.load_state_dict(torch.load("$MODEL_PATH$"))
model.eval()

transform_val = transforms.Compose([transforms.ToTensor()])
test_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('deeprobust/image/data', train=False, download=True,
                transform=transform_val),
                batch_size=10, shuffle=True)

x, y = next(iter(test_loader))
x = x.to('cuda').float()

adversary = PGD(model, 'cuda')
Adv_img = adversary.generate(x, y, **attack_params['PGD_CIFAR10'])
```
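
A quick way to see the attack working is to compare the model's predictions on the clean batch and on `Adv_img`. A small sketch continuing directly from the snippet above:
```python
y = y.to('cuda')
with torch.no_grad():
    pred_clean = model(x).argmax(dim=1)      # predictions on clean inputs
    pred_adv = model(Adv_img).argmax(dim=1)  # predictions on PGD inputs
print('clean accuracy:      ', (pred_clean == y).float().mean().item())
print('adversarial accuracy:', (pred_adv == y).float().mean().item())
```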

Example: train a defense model.
```python
from deeprobust.image.defense.pgdtraining import PGDtraining
from deeprobust.image.config import defense_params
from deeprobust.image.netmodels.CNN import Net
import torch
from torchvision import datasets, transforms

model = Net()
train_loader = torch.utils.data.DataLoader(
                datasets.MNIST('deeprobust/image/defense/data', train=True, download=True,
                transform=transforms.Compose([transforms.ToTensor()])),
                batch_size=100, shuffle=True)

test_loader = torch.utils.data.DataLoader(
              datasets.MNIST('deeprobust/image/defense/data', train=False,
              transform=transforms.Compose([transforms.ToTensor()])),
              batch_size=1000, shuffle=True)

defense = PGDtraining(model, 'cuda')
defense.generate(train_loader, test_loader, **defense_params["PGDtraining_MNIST"])
```
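
After training, the hardened weights can be persisted with plain PyTorch. A sketch that assumes `PGDtraining` updates the passed-in `model` in place:
```python
import os

os.makedirs('deeprobust/trained_models', exist_ok=True)
# Save the adversarially trained weights; the path is arbitrary.
torch.save(model.state_dict(), 'deeprobust/trained_models/MNIST_CNN_pgdtraining.pt')
```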

More example code can be found in `deeprobust/examples`.

3. Use our evaluation program to test an attack algorithm against a defense.

Example:
```
cd DeepRobust
python examples/image/test_train.py
python deeprobust/image/evaluation_attack.py
```

## Graph Attack and Defense

### Attacking Graph Neural Networks

1. Load the dataset
```python
import torch
import numpy as np
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN
from deeprobust.graph.global_attack import Metattack

data = Dataset(root='/tmp/', name='cora', setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
idx_unlabeled = np.union1d(idx_val, idx_test)
```

2. Set up the surrogate model
```python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
                with_relu=False, device=device)
surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train)
```

3. Set up the attack model and generate perturbations
```python
model = Metattack(model=surrogate, nnodes=adj.shape[0], feature_shape=features.shape, device=device)
model = model.to(device)
# Perturb 5% of the edges; adj.sum() // 2 is the number of undirected edges.
perturbations = int(0.05 * (adj.sum() // 2))
model.attack(features, adj, labels, idx_train, idx_unlabeled, perturbations, ll_constraint=False)
modified_adj = model.modified_adj
```
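
To quantify the damage, train one GCN on the clean graph and another on `modified_adj`, then compare test accuracy. A sketch continuing from the snippets above:
```python
# GCN trained on the clean graph
gcn_clean = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
                nhid=16, device=device).to(device)
gcn_clean.fit(features, adj, labels, idx_train)
gcn_clean.test(idx_test)

# GCN trained on the attacked graph; accuracy typically drops
gcn_attacked = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
                   nhid=16, device=device).to(device)
gcn_attacked.fit(features, modified_adj, labels, idx_train)
gcn_attacked.test(idx_test)
```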

For more details please refer to [test_mettack.py](https://github.com/I-am-Bot/DeepRobust/blob/master/examples/graph/test_mettack.py) or run
```
python examples/graph/test_mettack.py --dataset cora --ptb_rate 0.05
```

### Defending Against Graph Attacks

1. Load the dataset
```python
import torch
from deeprobust.graph.data import Dataset, PtbDataset
from deeprobust.graph.defense import GCN, GCNJaccard
import numpy as np
np.random.seed(15)

# load the clean graph
data = Dataset(root='/tmp/', name='cora', setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# load the graph pre-attacked by mettack
perturbed_data = PtbDataset(root='/tmp/', name='cora')
perturbed_adj = perturbed_data.adj
```
2. Test
```python
# Set up the defense model and test its performance
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = GCNJaccard(nfeat=features.shape[1], nclass=labels.max()+1, nhid=16, device=device)
model = model.to(device)
model.fit(features, perturbed_adj, labels, idx_train)
model.eval()
output = model.test(idx_test)

# Test on a plain GCN for comparison
model = GCN(nfeat=features.shape[1], nclass=labels.max()+1, nhid=16, device=device)
model = model.to(device)
model.fit(features, perturbed_adj, labels, idx_train)
model.eval()
output = model.test(idx_test)
```
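
It is also worth checking that the preprocessing defense does not cost much accuracy when the graph is clean. A sketch reusing the objects above:
```python
# Sanity check: GCNJaccard trained on the clean (unperturbed) graph
model = GCNJaccard(nfeat=features.shape[1], nclass=labels.max()+1, nhid=16, device=device)
model = model.to(device)
model.fit(features, adj, labels, idx_train)
model.eval()
model.test(idx_test)
```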

For more details please refer to [test_gcn_jaccard.py](https://github.com/I-am-Bot/DeepRobust/blob/master/examples/graph/test_gcn_jaccard.py) or run
```
python examples/graph/test_gcn_jaccard.py --dataset cora
```

## Sample Results
Adversarial examples generated by FGSM:
<div align="center">
<img height=140 src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/mnist_advexample_fgsm_ori.png"/><img height=140 src="https://github.com/DSE-MSU/DeepRobust/blob/master/adversary_examples/mnist_advexample_fgsm_adv.png"/>
</div>
Left: original, classified as 6; right: adversarial, classified as 4.

Several trained models can be found here: https://drive.google.com/open?id=1uGLiuCyd8zCAQ8tPz9DDUQH6zm-C4tEL

## Acknowledgement
Some of the algorithms follow the paper authors' implementations. References can be found at the top of each file.

The implementation of the network structures follows weiaicunzai's GitHub repository. The original code can be found here:
[pytorch-cifar100](https://github.com/weiaicunzai/pytorch-cifar100)

Thanks for their outstanding work!


<!----
We would be glad if you find our work useful and cite the paper.

'''
@misc{jin2020adversarial,
title={Adversarial Attacks and Defenses on Graphs: A Review and Empirical Study},
author={Wei Jin and Yaxin Li and Han Xu and Yiqi Wang and Jiliang Tang},
year={2020},
eprint={2003.00653},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
'''
```
@article{xu2019adversarial,
title={Adversarial attacks and defenses in images, graphs and text: A review},
author={Xu, Han and Ma, Yao and Liu, Haochen and Deb, Debayan and Liu, Hui and Tang, Jiliang and Jain, Anil},
journal={arXiv preprint arXiv:1909.08072},
year={2019}
}
```
---->
conf.py
ADDED
@@ -0,0 +1,68 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
# sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('../'))
# sys.path.append('/home/jinwei/Laboratory/api')
sys.path.append('/Users/yaxinli/Desktop/MSU/DeepRobust')

# -- Project information -----------------------------------------------------

project = 'x'
copyright = '2020, x'
author = 'x'

# The full version, including alpha/beta/rc tags
release = 'x'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.autosummary',
]
# extensions = ['sphinx.ext.napoleon']
autodoc_mock_imports = ['torch', 'torchvision', 'texttable', 'tensorboardX',
]

# remove undoc members
#autodoc_default_flags = ['members']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'alabaster'
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

add_module_names = False
deeprobust/__init__.py
ADDED
@@ -0,0 +1,4 @@
# from deeprobust import image
# from deeprobust import graph

# __all__ = ['image', 'graph']
deeprobust/graph/data/attacked_data.py
ADDED
@@ -0,0 +1,218 @@
import numpy as np
import scipy.sparse as sp
import os.path as osp
import os
import urllib.request
import warnings
import json

class PtbDataset:
    """Dataset class that manages pre-attacked adjacency matrices for different datasets. Currently it only supports metattack with 5% perturbation. Note the metattack graphs here are generated by deeprobust/graph/global_attack/metattack.py, while PrePtbDataset provides pre-attacked graphs generated by Zugner, https://github.com/danielzuegner/gnn-meta-attack. The attacked graphs are downloaded from https://github.com/ChandlerBang/pytorch-gnn-meta-attack/tree/master/pre-attacked.

    Parameters
    ----------
    root :
        root directory where the dataset should be saved.
    name :
        dataset name. It can be chosen from ['cora', 'citeseer', 'cora_ml', 'polblogs', 'pubmed']
    attack_method :
        currently this class only supports metattack. Users can pass 'meta', 'metattack' or 'mettack'; all of them are interpreted as the same attack.
    seed :
        random seed for splitting training/validation/test.

    Examples
    --------

    >>> from deeprobust.graph.data import Dataset, PtbDataset
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> perturbed_data = PtbDataset(root='/tmp/',
                            name='cora',
                            attack_method='meta')
    >>> perturbed_adj = perturbed_data.adj

    """

    def __init__(self, root, name, attack_method='mettack'):
        assert attack_method in ['mettack', 'metattack', 'meta'], \
            'Currently the database only stores graphs perturbed by 5% mettack'

        self.name = name.lower()
        assert self.name in ['cora', 'citeseer', 'polblogs'], \
            'Currently only support cora, citeseer, polblogs'

        self.attack_method = 'mettack'  # attack_method
        self.url = 'https://raw.githubusercontent.com/ChandlerBang/pytorch-gnn-meta-attack/master/pre-attacked/{}_{}_0.05.npz'.format(self.name, self.attack_method)
        self.root = osp.expanduser(osp.normpath(root))
        self.data_filename = osp.join(root,
                '{}_{}_0.05.npz'.format(self.name, self.attack_method))
        self.adj = self.load_data()

    def load_data(self):
        if not osp.exists(self.data_filename):
            self.download_npz()
        print('Loading {} dataset perturbed by 0.05 mettack...'.format(self.name))
        adj = sp.load_npz(self.data_filename)
        warnings.warn('''the adjacency matrix is perturbed using the data splits under seed 15 (the default seed for deeprobust.graph.data.Dataset), so if you are going to verify the attack performance, you should use the same data splits''')
        return adj

    def download_npz(self):
        print('Downloading from {} to {}'.format(self.url, self.data_filename))
        try:
            urllib.request.urlretrieve(self.url, self.data_filename)
        except:
            raise Exception('''Download failed! Make sure you have
                    a stable Internet connection and entered the right name''')


class PrePtbDataset:
    """Dataset class that manages pre-attacked adjacency matrices for different datasets. Note the metattack graphs here are generated by deeprobust/graph/global_attack/metattack.py, while PrePtbDataset provides pre-attacked graphs generated by Zugner, https://github.com/danielzuegner/gnn-meta-attack. The attacked graphs are downloaded from https://github.com/ChandlerBang/Pro-GNN/tree/master/meta.

    Parameters
    ----------
    root :
        root directory where the dataset should be saved.
    name :
        dataset name. It can be chosen from ['cora', 'citeseer', 'polblogs', 'pubmed']
    attack_method :
        currently this class only supports metattack and nettack. Note 'meta', 'metattack' and 'mettack' are interpreted as the same attack.
    seed :
        random seed for splitting training/validation/test.

    Examples
    --------

    >>> from deeprobust.graph.data import Dataset, PrePtbDataset
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> # Load meta attacked data
    >>> perturbed_data = PrePtbDataset(root='/tmp/',
                            name='cora',
                            attack_method='meta',
                            ptb_rate=0.05)
    >>> perturbed_adj = perturbed_data.adj
    >>> # Load nettacked data
    >>> perturbed_data = PrePtbDataset(root='/tmp/',
                            name='cora',
                            attack_method='nettack',
                            ptb_rate=1.0)
    >>> perturbed_adj = perturbed_data.adj
    >>> target_nodes = perturbed_data.target_nodes
    """

    def __init__(self, root, name, attack_method='meta', ptb_rate=0.05):

        if attack_method == 'mettack' or attack_method == 'metattack':
            attack_method = 'meta'

        assert attack_method in ['meta', 'nettack'], \
            'Currently the database only stores graphs perturbed by metattack, nettack'
        # assert attack_method in ['meta'], \
        #     'Currently the database only stores graphs perturbed by metattack. Will update nettack soon.'

        self.name = name.lower()
        assert self.name in ['cora', 'citeseer', 'polblogs', 'pubmed', 'cora_ml'], \
            'Currently only support cora, citeseer, pubmed, polblogs, cora_ml'

        self.attack_method = attack_method
        self.ptb_rate = ptb_rate
        self.url = 'https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/master/{}/{}_{}_adj_{}.npz'.\
                format(self.attack_method, self.name, self.attack_method, self.ptb_rate)
        # self.url = 'https://github.com/ChandlerBang/Pro-GNN/blob/master/{}/{}_{}_adj_{}.npz'.\
        #         format(self.attack_method, self.name, self.attack_method, self.ptb_rate)
        self.root = osp.expanduser(osp.normpath(root))
        self.data_filename = osp.join(root,
                '{}_{}_adj_{}.npz'.format(self.name, self.attack_method, self.ptb_rate))
        self.target_nodes = None
        self.adj = self.load_data()

    def load_data(self):
        if not osp.exists(self.data_filename):
            self.download_npz()
        print('Loading {} dataset perturbed by {} {}...'.format(self.name, self.ptb_rate, self.attack_method))

        if self.attack_method == 'meta':
            warnings.warn("The pre-attacked graph is perturbed under the data splits provided by ProGNN. So if you are going to verify the attack performance, you should use the same data splits (setting='prognn').")
            adj = sp.load_npz(self.data_filename)

        if self.attack_method == 'nettack':
            # assert True, "Will update pre-attacked data by nettack soon"
            warnings.warn("The pre-attacked graph is perturbed under the data splits provided by ProGNN. So if you are going to verify the attack performance, you should use the same data splits (setting='prognn').")
            adj = sp.load_npz(self.data_filename)
            self.target_nodes = self.get_target_nodes()
        return adj

    def get_target_nodes(self):
        """Get target node indices, i.e., the nodes with degree > 10 in the test set."""
        url = 'https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/master/nettack/{}_nettacked_nodes.json'.format(self.name)
        json_file = osp.join(self.root,
                '{}_nettacked_nodes.json'.format(self.name))

        if not osp.exists(json_file):
            self.download_file(url, json_file)
        # with open(f'/mnt/home/jinwei2/Projects/nettack/{dataset}_nettacked_nodes.json', 'r') as f:
        with open(json_file, 'r') as f:
            idx = json.loads(f.read())
        return idx["attacked_test_nodes"]

    def download_file(self, url, file):
        print('Downloading from {} to {}'.format(url, file))
        try:
            urllib.request.urlretrieve(url, file)
        except:
            raise Exception("Download failed! Make sure you have \
                    a stable Internet connection and entered the right name")

    def download_npz(self):
        print('Downloading from {} to {}'.format(self.url, self.data_filename))
        try:
            urllib.request.urlretrieve(self.url, self.data_filename)
        except:
            raise Exception("Download failed! Make sure you have \
                    a stable Internet connection and entered the right name")


class RandomAttack():

    def __init__(self):
        self.name = 'RandomAttack'

    def attack(self, adj, ratio=0.4):
        print('random attack: ratio=%s' % ratio)
        modified_adj = self._random_add_edges(adj, ratio)
        return modified_adj

    def _random_add_edges(self, adj, add_ratio):

        def sample_zero_forever(mat):
            # Lazily sample (row, col) pairs that are currently zero in the matrix.
            nonzero_or_sampled = set(zip(*mat.nonzero()))
            while True:
                t = tuple(np.random.randint(0, mat.shape[0], 2))
                if t not in nonzero_or_sampled:
                    yield t
                    nonzero_or_sampled.add(t)
                    nonzero_or_sampled.add((t[1], t[0]))

        def sample_zero_n(mat, n=100):
            itr = sample_zero_forever(mat)
            return [next(itr) for _ in range(n)]

        assert np.abs(adj - adj.T).sum() == 0, "Input graph is not symmetric"
        non_zeros = [(x, y) for x, y in np.argwhere(adj != 0) if x < y]  # (x, y)

        added = sample_zero_n(adj, n=int(add_ratio * len(non_zeros)))
        for x, y in added:
            adj[x, y] = 1
            adj[y, x] = 1
        return adj


if __name__ == '__main__':
    perturbed_data = PrePtbDataset(root='/tmp/',
                        name='cora',
                        attack_method='meta',
                        ptb_rate=0.05)
    perturbed_adj = perturbed_data.adj
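
For illustration, the nettack variant of `PrePtbDataset` pairs the perturbed adjacency with the attacked test nodes; a minimal sketch based on the docstring above:
```python
from deeprobust.graph.data import Dataset, PrePtbDataset

data = Dataset(root='/tmp/', name='cora', setting='prognn')  # matching splits
perturbed_data = PrePtbDataset(root='/tmp/', name='cora',
                               attack_method='nettack', ptb_rate=1.0)
perturbed_adj = perturbed_data.adj
# Nodes the attack targeted (degree > 10 in the test set)
target_nodes = perturbed_data.target_nodes
print(len(target_nodes), 'target nodes')
```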
deeprobust/graph/defense/adv_training.py
ADDED
@@ -0,0 +1,57 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from deeprobust.graph.defense import GCN
# RND was used below without an import in the original file; it is assumed to
# come from the targeted_attack subpackage.
from deeprobust.graph.targeted_attack import RND
from tqdm import tqdm
import scipy.sparse as sp
import numpy as np


class AdvTraining:
    """Adversarial training framework for defending against attacks.

    Parameters
    ----------
    model :
        model to protect, e.g., GCN
    adversary :
        attack model; it is expected to expose attack(features, adj) returning
        a modified adjacency matrix
    device : str
        'cpu' or 'cuda'
    """

    def __init__(self, model, adversary=None, device='cpu'):

        self.model = model
        if adversary is None:
            adversary = RND()
        self.adversary = adversary
        self.device = device

    def adv_train(self, features, adj, labels, idx_train, train_iters, **kwargs):
        """Start adversarial training.

        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format can be a torch.Tensor or a scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), the GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        """
        for i in range(train_iters):
            modified_adj = self.adversary.attack(features, adj)
            # The original call dropped `labels` and `idx_train`; they are
            # required by GCN.fit, so they are passed through here.
            self.model.fit(features, modified_adj, labels, idx_train,
                           train_iters=train_iters, initialize=False)
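
A usage sketch for this framework; since `adv_train` only expects the adversary to expose `attack(features, adj)` returning a modified adjacency, a toy random adversary is stubbed in here (hypothetical, for illustration), and `AdvTraining` is assumed to be exported from `deeprobust.graph.defense`:
```python
import numpy as np
import torch
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN, AdvTraining

class FlipKEdges:
    """Toy adversary matching the attack(features, adj) interface."""
    def attack(self, features, adj, k=50):
        modified = adj.tolil()  # LIL format supports item assignment
        n = adj.shape[0]
        for r, c in zip(np.random.randint(0, n, k), np.random.randint(0, n, k)):
            modified[r, c] = 1 - modified[r, c]  # flip the edge symmetrically
            modified[c, r] = modified[r, c]
        return modified.tocsr()

data = Dataset(root='/tmp/', name='cora')
adj, features, labels = data.adj, data.features, data.labels
device = 'cuda' if torch.cuda.is_available() else 'cpu'
victim = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
             nhid=16, device=device).to(device)
victim.fit(features, adj, labels, data.idx_train)  # pre-train before hardening

trainer = AdvTraining(victim, adversary=FlipKEdges(), device=device)
trainer.adv_train(features, adj, labels, data.idx_train, train_iters=10)
```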
deeprobust/graph/defense/chebnet.py
ADDED
@@ -0,0 +1,215 @@
"""
Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
"""
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
from torch_geometric.nn import ChebConv

class ChebNet(nn.Module):
    """ 2-layer ChebNet based on pytorch geometric.

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    num_hops: int
        number of hops in ChebConv
    dropout : float
        dropout rate for ChebNet
    lr : float
        learning rate for ChebNet
    weight_decay : float
        weight decay coefficient (l2 regularization) for ChebNet.
    with_bias: bool
        whether to include the bias term in ChebNet weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
    We can first load the dataset and then train ChebNet.

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import ChebNet
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> cheby = ChebNet(nfeat=features.shape[1],
              nhid=16, num_hops=3,
              nclass=labels.max().item() + 1,
              dropout=0.5, device='cpu')
    >>> cheby = cheby.to('cpu')
    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> cheby.fit(pyg_data, patience=10, verbose=True) # train with early stopping
    """

    def __init__(self, nfeat, nhid, nclass, num_hops=3, dropout=0.5, lr=0.01,
            weight_decay=5e-4, with_bias=True, device=None):

        super(ChebNet, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        self.conv1 = ChebConv(
            nfeat,
            nhid,
            K=num_hops,
            bias=with_bias)

        self.conv2 = ChebConv(
            nhid,
            nclass,
            K=num_hops,
            bias=with_bias)

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None
        self.best_model = None
        self.best_output = None

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of ChebNet.
        """
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=500, **kwargs):
        """Train the ChebNet model; when a validation set is available, pick the
        best model according to the validation loss.

        Parameters
        ----------
        pyg_data :
            pytorch geometric dataset object
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        patience : int
            patience for early stopping, only valid when a validation set is given
        """

        self.device = self.conv1.weight.device
        if initialize:
            self.initialize()

        self.data = pyg_data[0].to(self.device)
        # By default, it is trained with early stopping on validation
        self.train_with_early_stopping(train_iters, patience, verbose)

    def train_with_early_stopping(self, train_iters, patience, verbose):
        """Early stopping based on the validation loss.
        """
        if verbose:
            print('=== training ChebNet model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        labels = self.data.y
        train_mask, val_mask = self.data.train_mask, self.data.val_mask

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.data)

            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.data)
            loss_val = F.nll_loss(output[val_mask], labels[val_mask])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self):
        """Evaluate ChebNet performance on the test set.
        """
        self.eval()
        test_mask = self.data.test_mask
        labels = self.data.y
        output = self.forward(self.data)
        # output = self.output
        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self):
        """
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of ChebNet
        """

        self.eval()
        return self.forward(self.data)


if __name__ == "__main__":
    from deeprobust.graph.data import Dataset, Dpr2Pyg
    # from deeprobust.graph.defense import ChebNet
    data = Dataset(root='/tmp/', name='cora')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    cheby = ChebNet(nfeat=features.shape[1],
          nhid=16,
          nclass=labels.max().item() + 1,
          dropout=0.5, device='cpu')
    cheby = cheby.to('cpu')
    pyg_data = Dpr2Pyg(data)
    cheby.fit(pyg_data, verbose=True)  # train with early stopping
    cheby.test()
    print(cheby.predict())
deeprobust/graph/defense/data/processed/pre_filter.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b235e5f068a00e5bf391cec33fad69292177c841d5a8fd2ab76f764489fb6e7
size 431
deeprobust/graph/defense/gat.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
"""
|
2 |
+
Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
|
3 |
+
"""
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.optim as optim
|
9 |
+
from torch.nn.parameter import Parameter
|
10 |
+
from torch.nn.modules.module import Module
|
11 |
+
from deeprobust.graph import utils
|
12 |
+
from copy import deepcopy
|
13 |
+
from torch_geometric.nn import GATConv
|
14 |
+
|
15 |
+
|
16 |
+
class GAT(nn.Module):
|
17 |
+
""" 2 Layer Graph Attention Network based on pytorch geometric.
|
18 |
+
|
19 |
+
Parameters
|
20 |
+
----------
|
21 |
+
nfeat : int
|
22 |
+
size of input feature dimension
|
23 |
+
nhid : int
|
24 |
+
number of hidden units
|
25 |
+
nclass : int
|
26 |
+
size of output dimension
|
27 |
+
heads: int
|
28 |
+
number of attention heads
|
29 |
+
output_heads: int
|
30 |
+
number of attention output heads
|
31 |
+
dropout : float
|
32 |
+
dropout rate for GAT
|
33 |
+
lr : float
|
34 |
+
learning rate for GAT
|
35 |
+
weight_decay : float
|
36 |
+
weight decay coefficient (l2 normalization) for GCN.
|
37 |
+
When `with_relu` is True, `weight_decay` will be set to 0.
|
38 |
+
with_bias: bool
|
39 |
+
whether to include bias term in GAT weights.
|
40 |
+
device: str
|
41 |
+
'cpu' or 'cuda'.
|
42 |
+
|
43 |
+
Examples
|
44 |
+
--------
|
45 |
+
We can first load dataset and then train GAT.
|
46 |
+
|
47 |
+
>>> from deeprobust.graph.data import Dataset
|
48 |
+
>>> from deeprobust.graph.defense import GAT
|
49 |
+
>>> data = Dataset(root='/tmp/', name='cora')
|
50 |
+
>>> adj, features, labels = data.adj, data.features, data.labels
|
51 |
+
>>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
52 |
+
>>> gat = GAT(nfeat=features.shape[1],
|
53 |
+
                  nhid=8, heads=8,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device='cpu')
    >>> gat = gat.to('cpu')
    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> gat.fit(pyg_data, patience=100, verbose=True) # train with early stopping
    """

    def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01,
                 weight_decay=5e-4, with_bias=True, device=None):

        super(GAT, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        self.conv1 = GATConv(
            nfeat,
            nhid,
            heads=heads,
            dropout=dropout,
            bias=with_bias)

        self.conv2 = GATConv(
            nhid * heads,
            nclass,
            heads=output_heads,
            concat=False,
            dropout=dropout,
            bias=with_bias)

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None
        self.best_model = None
        self.best_output = None

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GAT.
        """
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=100, **kwargs):
        """Train the GAT model and pick the best checkpoint according to the
        validation loss of `pyg_data`.

        Parameters
        ----------
        pyg_data :
            pytorch geometric dataset object
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        patience : int
            patience for early stopping on the validation loss
        """

        if initialize:
            self.initialize()

        self.data = pyg_data[0].to(self.device)
        # By default, the model is trained with early stopping on validation
        self.train_with_early_stopping(train_iters, patience, verbose)

    def train_with_early_stopping(self, train_iters, patience, verbose):
        """Early stopping based on the validation loss.
        """
        if verbose:
            print('=== training GAT model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        labels = self.data.y
        train_mask, val_mask = self.data.train_mask, self.data.val_mask

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.data)

            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.data)
            loss_val = F.nll_loss(output[val_mask], labels[val_mask])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1

            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self):
        """Evaluate GAT performance on the test set. The test mask is taken
        from `self.data`, so no index argument is needed.
        """
        self.eval()
        test_mask = self.data.test_mask
        labels = self.data.y
        output = self.forward(self.data)
        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self):
        """
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GAT
        """
        self.eval()
        return self.forward(self.data)


if __name__ == "__main__":
    from deeprobust.graph.data import Dataset, Dpr2Pyg
    data = Dataset(root='/tmp/', name='cora')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    gat = GAT(nfeat=features.shape[1],
              nhid=8, heads=8,
              nclass=labels.max().item() + 1,
              dropout=0.5, device='cpu')
    gat = gat.to('cpu')
    pyg_data = Dpr2Pyg(data)
    gat.fit(pyg_data, verbose=True)  # train with early stopping
    gat.test()
    print(gat.predict())
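
A quick robustness check can be built from the pieces above. The following is a minimal sketch, assuming the `PrePtbDataset` loader and the `Dpr2Pyg.update_edge_index` helper from this library behave as described in its docs:

# Sketch: train GAT on clean Cora, then swap in a pre-attacked adjacency
# and measure the accuracy drop. Illustrative only.
from deeprobust.graph.data import Dataset, PrePtbDataset, Dpr2Pyg
from deeprobust.graph.defense import GAT

data = Dataset(root='/tmp/', name='cora')        # clean graph and splits
pyg_data = Dpr2Pyg(data)                         # convert to a pyg dataset

gat = GAT(nfeat=data.features.shape[1], nhid=8, heads=8,
          nclass=data.labels.max().item() + 1,
          dropout=0.5, device='cpu').to('cpu')
gat.fit(pyg_data, patience=100, verbose=False)   # early stopping on validation
clean_acc = gat.test()

perturbed = PrePtbDataset(root='/tmp/', name='cora')  # pre-attacked adjacency
pyg_data.update_edge_index(perturbed.adj)             # rebuild edge_index in place
gat.data = pyg_data[0].to('cpu')
ptb_acc = gat.test()
print('accuracy drop: {:.4f}'.format(clean_acc - ptb_acc))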
deeprobust/graph/defense/gcn.py
ADDED
@@ -0,0 +1,377 @@
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
from sklearn.metrics import f1_score

from collections import defaultdict
from functools import reduce

class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """Graph Convolutional Layer forward function: computes (A X) W.
        """
        # Step 1: compute AX. torch.spmm requires its first operand to be
        # sparse, so branch on the sparsity of adj (not of input).
        if adj.is_sparse:
            support = torch.spmm(adj, input)   # sparse adjacency
        else:
            support = torch.mm(adj, input)     # dense adjacency

        # Step 2: compute (AX)W; support is dense at this point
        output = torch.mm(support, self.weight)

        # Step 3: add bias if applicable
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
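
To make the propagation order above concrete, here is a small self-contained check of the layer (illustrative only; the graph and sizes are made up):

# Toy check of the (A X) W propagation: 3 nodes on a path graph with
# self-loops, 2 input features, 4 output features.
import torch
from deeprobust.graph.defense.gcn import GraphConvolution

A = torch.tensor([[1., 1., 0.],
                  [1., 1., 1.],
                  [0., 1., 1.]])          # dense adjacency with self-loops
X = torch.rand(3, 2)                      # node features

layer = GraphConvolution(in_features=2, out_features=4)
out = layer(X, A)                         # (A X) W + b
assert out.shape == (3, 4)

# The same result, computed by hand:
manual = A @ X @ layer.weight + layer.bias
assert torch.allclose(out, manual, atol=1e-6)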
class GCN(nn.Module):
    """ 2 Layer Graph Convolutional Network.

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    dropout : float
        dropout rate for GCN
    lr : float
        learning rate for GCN
    weight_decay : float
        weight decay coefficient (l2 normalization) for GCN.
        When `with_relu` is False, `weight_decay` will be set to 0.
    with_relu : bool
        whether to use relu activation function. If False, GCN will be linearized.
    with_bias: bool
        whether to include bias term in GCN weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
    We can first load the dataset and then train GCN.

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device='cpu')
    >>> gcn = gcn.to('cpu')
    >>> gcn.fit(features, adj, labels, idx_train) # train without early stopping
    >>> gcn.fit(features, adj, labels, idx_train, idx_val, patience=30) # train with early stopping
    >>> gcn.test(idx_test)
    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, device=None):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.gc1 = GraphConvolution(nfeat, nhid, with_bias=with_bias)
        self.gc2 = GraphConvolution(nhid, nclass, with_bias=with_bias)
        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None

    def forward(self, x, adj):
        if self.with_relu:
            x = F.relu(self.gc1(x, adj))
        else:
            x = self.gc1(x, adj)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GCN.
        """
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200,
            initialize=True, verbose=False, normalize=True, patience=500, **kwargs):
        """Train the GCN model. When idx_val is not None, pick the best model
        according to the validation loss.

        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        normalize : bool
            whether to normalize the input adjacency matrix.
        patience : int
            patience for early stopping, only valid when `idx_val` is given
        """

        self.device = self.gc1.weight.device
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        self.adj_norm = adj_norm
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val, train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val, train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1

            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self, features=None, adj=None):
        """By default, the inputs should be the unnormalized adjacency.

        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use the `features` and `adj` stored during training to make predictions.
        adj :
            adjacency matrix. If `features` and `adj` are not given, this function will use the `features` and `adj` stored during training to make predictions.

        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)
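
`fit()` above delegates normalization to `utils.normalize_adj_tensor`. For reference, a standalone numpy/scipy sketch of the standard symmetric normalization it corresponds to, A_hat = D^{-1/2} (A + I) D^{-1/2} (the exact implementation lives in `deeprobust.graph.utils`):

# Sketch of symmetric adjacency normalization, for illustration only.
import numpy as np
import scipy.sparse as sp

def normalize_adj(adj):
    """adj: scipy sparse adjacency matrix without self-loops."""
    adj = adj + sp.eye(adj.shape[0])            # add self-loops
    deg = np.asarray(adj.sum(1)).flatten()      # node degrees
    d_inv_sqrt = np.power(deg, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.       # guard isolated nodes
    D_inv_sqrt = sp.diags(d_inv_sqrt)
    return (D_inv_sqrt @ adj @ D_inv_sqrt).tocsr()

A = sp.csr_matrix(np.array([[0, 1, 0],
                            [1, 0, 1],
                            [0, 1, 0]], dtype=float))
print(np.round(normalize_adj(A).toarray(), 3))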
deeprobust/graph/defense/gcn_cgscore.py
ADDED
@@ -0,0 +1,413 @@
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
from sklearn.metrics import f1_score

from collections import defaultdict
from functools import reduce

class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn.
    Note this variant uses the classic order: support = X W, output = A support.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """Graph Convolutional Layer forward function
        """
        if input.data.is_sparse:
            support = torch.spmm(input, self.weight)
        else:
            support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class GraphConvolution_cg(Module):
    """GCN layer that consumes pre-propagated features: the product A_hat X is
    computed once (in GCNScore.fit), so forward only multiplies by the weight.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution_cg, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, cg_features):
        """Graph Convolutional Layer forward function on pre-propagated
        features. Dispatch on sparsity: torch.spmm needs a sparse first operand.
        """
        if cg_features.is_sparse:
            output = torch.spmm(cg_features, self.weight)
        else:
            output = torch.mm(cg_features, self.weight)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
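
The point of `GraphConvolution_cg` is to decouple propagation from training: once A_hat X is computed, every optimization step is a plain dense matmul with no repeated sparse propagation. A minimal sketch of that pattern, with made-up sizes:

# Pre-propagate once, then train on dense products only. Illustrative sketch.
import torch

n, f, h = 300, 50, 16
A_hat = torch.rand(n, n).to_sparse()   # stand-in for the normalized adjacency
X = torch.rand(n, f)

AX = torch.spmm(A_hat, X)              # propagate once, outside the loop
W = torch.rand(f, h, requires_grad=True)

for step in range(5):                  # the loop only touches AX @ W
    out = AX @ W
    loss = out.pow(2).mean()
    loss.backward()
    W.data -= 0.01 * W.grad
    W.grad.zero_()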
class GCNScore(nn.Module):
    """ 2 Layer Graph Convolutional Network operating on pre-propagated
    features (A_hat X is computed once in `fit`).

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    dropout : float
        dropout rate for GCN
    lr : float
        learning rate for GCN
    weight_decay : float
        weight decay coefficient (l2 normalization) for GCN.
        When `with_relu` is False, `weight_decay` will be set to 0.
    with_relu : bool
        whether to use relu activation function. If False, GCN will be linearized.
    with_bias: bool
        whether to include bias term in GCN weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
    We can first load the dataset and then train GCNScore.

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense.gcn_cgscore import GCNScore
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> gcn = GCNScore(nfeat=features.shape[1],
                       nhid=16,
                       nclass=labels.max().item() + 1,
                       dropout=0.5, device='cpu')
    >>> gcn = gcn.to('cpu')
    >>> gcn.fit(features, adj, labels, idx_train) # train without early stopping
    >>> gcn.fit(features, adj, labels, idx_train, idx_val, patience=30) # train with early stopping
    >>> gcn.test(idx_test)
    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, device=None):

        super(GCNScore, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.gc1 = GraphConvolution_cg(nfeat, nhid, with_bias=with_bias)
        self.gc2 = GraphConvolution_cg(nhid, nclass, with_bias=with_bias)
        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None

    def forward(self, x):
        if self.with_relu:
            x = F.relu(self.gc1(x))
        else:
            x = self.gc1(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GCN.
        """
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200,
            initialize=True, verbose=False, normalize=True, patience=500, **kwargs):
        """Train the gcn model. When idx_val is not None, pick the best model
        according to the validation loss.

        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        normalize : bool
            whether to normalize the input adjacency matrix.
        patience : int
            patience for early stopping, only valid when `idx_val` is given
        """

        self.device = self.gc1.weight.device
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        self.adj_norm = adj_norm
        self.features = features
        self.labels = labels

        # Pre-propagate once: feature_ax = A_hat X. Use the normalized
        # adjacency in both branches (the original dense branch mistakenly
        # used the raw adj), and branch on the adjacency's sparsity since
        # torch.spmm needs a sparse first operand.
        if features.is_sparse:
            features = features.to_dense()  # spmm needs a dense right operand
        if adj_norm.is_sparse:
            feature_ax = torch.spmm(adj_norm, features)
        else:
            feature_ax = torch.mm(adj_norm, features)

        self.features_ax = feature_ax

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val, train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features_ax)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features_ax)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features_ax)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features_ax)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val, train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features_ax)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features_ax)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1

            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.forward(self.features_ax)
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self, features=None, adj=None):
        """By default, the inputs should be the unnormalized adjacency.

        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use the `features` and `adj` stored during training to make predictions.
        adj :
            adjacency matrix. If `features` and `adj` are not given, this function will use the `features` and `adj` stored during training to make predictions.

        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features_ax)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            # Re-propagate with the new graph before predicting
            if features.is_sparse:
                features = features.to_dense()
            if self.adj_norm.is_sparse:
                self.features_ax = torch.spmm(self.adj_norm, features)
            else:
                self.features_ax = torch.mm(self.adj_norm, features)
            return self.forward(self.features_ax)
deeprobust/graph/defense/gcn_guard.py
ADDED
@@ -0,0 +1,411 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
import scipy
import scipy.sparse as sp  # sp is used below (sp.diags); import it explicitly
from sklearn.metrics import jaccard_score
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
import numpy as np
from deeprobust.graph.utils import *
from torch_geometric.nn import GINConv, GATConv, GCNConv, JumpingKnowledge
from torch.nn import Sequential, Linear, ReLU
from sklearn.preprocessing import normalize
from scipy.sparse import lil_matrix


class GCNGuard(nn.Module):

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, drop=False, weight_decay=5e-4, with_relu=True,
                 with_bias=True, device=None):

        super(GCNGuard, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.dropout = dropout
        self.lr = lr

        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.gate = Parameter(torch.rand(1))  # create a scalar gate in [0, 1]
        self.test_value = Parameter(torch.rand(1))
        self.drop_learn_1 = Linear(2, 1)
        self.drop_learn_2 = Linear(2, 1)
        self.drop = drop
        self.bn1 = torch.nn.BatchNorm1d(nhid)
        self.bn2 = torch.nn.BatchNorm1d(nhid)
        nclass = int(nclass)

        """GCN backbone from torch-geometric"""
        self.gc1 = GCNConv(nfeat, nhid, bias=True)
        self.gc2 = GCNConv(nhid, nclass, bias=True)

        # Alternative backbones (disabled):
        # """GAT from torch-geometric"""
        # self.gc1 = GATConv(nfeat, nhid, heads=8, dropout=0.6)
        # self.gc2 = GATConv(nhid*8, nclass, heads=1, concat=True, dropout=0.6)
        #
        # """GIN from torch-geometric"""
        # dim = 32
        # self.gc1 = GINConv(Sequential(Linear(nfeat, dim), ReLU()))
        # self.gc2 = GINConv(Sequential(Linear(dim, dim), ReLU()))
        # self.jump = JumpingKnowledge(mode='cat')
        # self.fc2 = Linear(dim, int(nclass))
        #
        # """JK-Nets"""
        # dim = 32
        # self.gc1 = GINConv(Sequential(Linear(nfeat, dim), ReLU()))
        # self.gc2 = GINConv(Sequential(Linear(dim, dim), ReLU()))
        # self.gc3 = GINConv(Sequential(Linear(dim, dim), ReLU()))
        # self.jump = JumpingKnowledge(mode='cat')  # 'cat', 'lstm', 'max'
        # self.fc2 = Linear(dim*2, int(nclass))

    def forward(self, x, adj):
        """We don't change the edge_index, just update the edge_weight;
        an edge_weight of zero is treated as a removed edge."""
        if x.is_sparse:
            # Densify features here; needed for meta attack and DICE, should be
            # skipped for topology attack (and is unnecessary for wisconsin).
            x = x.to_dense()

        """GCN and GAT"""
        if self.attention:  # self.attention is set in fit()
            adj = self.att_coef(x, adj, i=0)
            adj = adj.to(self.device)

        edge_index = adj._indices()
        x = self.gc1(x, edge_index, edge_weight=adj._values())
        x = F.relu(x)
        if self.attention:  # if attention=True, use attention mechanism
            adj_2 = self.att_coef(x, adj, i=1)
            adj_2 = adj_2.to(self.device)
            # Memory gate: blend the previous and the new attention adjacency
            adj_memory = self.gate * adj.to_dense() + (1 - self.gate) * adj_2.to_dense()
            nz = adj_memory.nonzero()
            row, col = nz[:, 0], nz[:, 1]
            edge_index = torch.stack((row, col), dim=0).to(self.device)
            adj_values = adj_memory[row, col].to(self.device)
        else:
            edge_index = adj._indices().to(self.device)
            adj_values = adj._values().to(self.device)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, edge_index, edge_weight=adj_values)

        return F.log_softmax(x, dim=1)

    def initialize(self):
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()
        self.drop_learn_1.reset_parameters()
        self.drop_learn_2.reset_parameters()
        try:
            self.gate.reset_parameters()
            self.fc2.reset_parameters()
        except:
            pass

    def att_coef(self, fea, edge_index, is_lil=False, i=0):

        if is_lil == False:
            edge_index = edge_index._indices()
        else:
            edge_index = edge_index.tocoo()

        n_node = fea.shape[0]
        row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu().data.numpy()[:]

        fea_copy = fea.cpu().data.numpy()
        sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy)  # cosine similarity
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0  # prune edges whose endpoints are dissimilar

        # Alternative: jaccard similarity for binary features, cosine for
        # numeric features (disabled; the cosine path above is used for both).

        """build an attention matrix"""
        att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
        att_dense[row, col] = sim
        if att_dense[0, 0] == 1:
            att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0, format="lil")
        # normalization: make each row sum to 1
        att_dense_norm = normalize(att_dense, axis=1, norm='l1')

        """add learnable dropout, make character vector"""
        if self.drop:
            character = np.vstack((att_dense_norm[row, col].A1,
                                   att_dense_norm[col, row].A1))
            character = torch.from_numpy(character.T)
            drop_score = self.drop_learn_1(character)
            drop_score = torch.sigmoid(drop_score)  # do not use softmax since we only have one element
            mm = torch.nn.Threshold(0.5, 0)
            drop_score = mm(drop_score)
            mm_2 = torch.nn.Threshold(-0.49, 1)
            drop_score = mm_2(-drop_score)
            drop_decision = drop_score.clone().requires_grad_()
            drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
            drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1)
            att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr())  # update, remove the zeroed edges

        if att_dense_norm[0, 0] == 0:  # add self-loop weights only at the first layer
            degree = (att_dense_norm != 0).sum(1).A1
            lam = 1 / (degree + 1)  # degree + 1 accounts for the node itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self loop
        else:
            att = att_dense_norm

        row, col = att.nonzero()
        att_adj = np.vstack((row, col))
        att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponent, kind of softmax
        att_edge_weight = torch.tensor(np.array(att_edge_weight)[0], dtype=torch.float32)
        att_adj = torch.tensor(att_adj, dtype=torch.int64)

        shape = (n_node, n_node)
        new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape)
        return new_adj
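
The heart of `att_coef` is plain feature-similarity pruning. A toy standalone rendition, using the same sklearn/scipy calls, with made-up numbers:

# Score each edge by the cosine similarity of its endpoint features, zero out
# low-similarity edges, then row-normalize what survives. Illustrative only.
import numpy as np
from scipy.sparse import lil_matrix
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

fea = np.random.rand(4, 8)              # 4 nodes, 8 features
row = np.array([0, 1, 1, 2, 3])         # edge list (row -> col)
col = np.array([1, 0, 2, 1, 0])

sim = cosine_similarity(fea)[row, col]  # one score per edge
sim[sim < 0.1] = 0                      # prune dissimilar endpoints

att = lil_matrix((4, 4), dtype=np.float32)
att[row, col] = sim
att = normalize(att, axis=1, norm='l1') # each surviving row sums to 1
print(att.toarray())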
    def add_loop_sparse(self, adj, fill_value=1):
        # build a sparse identity matrix and add it to adj
        row = torch.arange(0, int(adj.shape[0]), dtype=torch.int64)  # torch.range is deprecated
        i = torch.stack((row, row), dim=0)
        v = torch.ones(adj.shape[0], dtype=torch.float32)
        shape = adj.shape
        I_n = torch.sparse.FloatTensor(i, v, shape)
        return adj + I_n.to(self.device)

    def fit(self, features, adj, labels, idx_train, idx_val=None, idx_test=None, train_iters=81, att_0=None,
            attention=False, model_name=None, initialize=True, verbose=False, normalize=True, patience=510):
        """Train the gcn model. When idx_val is not None, pick the best model
        according to the validation loss.
        """
        self.sim = None
        self.idx_test = idx_test
        self.attention = attention

        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # add self-loops; the normalization itself is done inside GCNConv
        adj = self.add_loop_sparse(adj)

        self.adj_norm = adj
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val, train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train], weight=None)  # per-node training weights
            loss_train.backward()
            optimizer.step()
            if verbose and i % 20 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            self.eval()

            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose:
                print('Epoch {}, training loss: {}, val acc: {}, '.format(i, loss_train.item(), acc_val))

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val, train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.adj_norm)

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1

            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test):
        self.eval()
        output = self.predict()  # uses the self.features and self.adj_norm from the training stage
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test, output

    def _set_parameters(self):
        # TODO
        pass

    def predict(self, features=None, adj=None):
        """By default, inputs are unnormalized data."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)
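
A minimal usage sketch for GCNGuard on a perturbed graph, mirroring the docstring examples of the other defenses in this diff (the loaders are the same ones used above; `attention=True` turns on the similarity-based edge weighting):

# Sketch: defend against a pre-computed attack with GCNGuard.
from deeprobust.graph.data import Dataset, PrePtbDataset
from deeprobust.graph.defense.gcn_guard import GCNGuard

data = Dataset(root='/tmp/', name='cora', seed=15)
features, labels = data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
perturbed_adj = PrePtbDataset(root='/tmp/', name='cora').adj

model = GCNGuard(nfeat=features.shape[1], nhid=16,
                 nclass=labels.max().item() + 1,
                 dropout=0.5, device='cpu').to('cpu')
model.fit(features, perturbed_adj, labels, idx_train, idx_val,
          idx_test=idx_test, attention=True, verbose=False)
acc, output = model.test(idx_test)   # note: this test() returns (acc, logits)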
deeprobust/graph/defense/gcn_preprocess.py
ADDED
@@ -0,0 +1,488 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from deeprobust.graph.defense import GCN
from tqdm import tqdm
import scipy.sparse as sp
import numpy as np
from numba import njit

class GCNSVD(GCN):
    """GCNSVD is a 2 Layer Graph Convolutional Network with Truncated SVD as
    preprocessing. See more details in All You Need Is Low (Rank): Defending
    Against Adversarial Attacks on Graphs,
    https://dl.acm.org/doi/abs/10.1145/3336191.3371789.

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    dropout : float
        dropout rate for GCN
    lr : float
        learning rate for GCN
    weight_decay : float
        weight decay coefficient (l2 normalization) for GCN. When `with_relu` is False, `weight_decay` will be set to 0.
    with_relu : bool
        whether to use relu activation function. If False, GCN will be linearized.
    with_bias: bool
        whether to include bias term in GCN weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
    We can first load the dataset and then train GCNSVD.

    >>> from deeprobust.graph.data import PrePtbDataset, Dataset
    >>> from deeprobust.graph.defense import GCNSVD
    >>> # load clean graph data
    >>> data = Dataset(root='/tmp/', name='cora', seed=15)
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> # load perturbed graph data
    >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
    >>> perturbed_adj = perturbed_data.adj
    >>> # train defense model
    >>> model = GCNSVD(nfeat=features.shape[1],
                       nhid=16,
                       nclass=labels.max().item() + 1,
                       dropout=0.5, device='cpu').to('cpu')
    >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, k=20)

    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4, with_relu=True, with_bias=True, device='cpu'):

        super(GCNSVD, self).__init__(nfeat, nhid, nclass, dropout, lr, weight_decay, with_relu, with_bias, device=device)
        self.device = device
        self.k = None

    def fit(self, features, adj, labels, idx_train, idx_val=None, k=50, train_iters=200, initialize=True, verbose=True, **kwargs):
        """First perform a rank-k approximation of the adjacency matrix via
        truncated SVD, then train the gcn model on the processed graph.
        When idx_val is not None, pick the best model according to
        the validation loss.

        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), GCN training process will not adopt early stopping
        k : int
            number of singular values and vectors to compute.
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        """
        if type(adj) is torch.Tensor:
            adj = adj.to('cpu')  # the SVD runs on CPU; scipy input passes through
        modified_adj = self.truncatedSVD(adj, k=k)
        self.k = k
        features, modified_adj, labels = utils.to_tensor(features, modified_adj, labels, device=self.device)

        self.modified_adj = modified_adj
        self.features = features
        self.labels = labels
        super().fit(features, modified_adj, labels, idx_train, idx_val, train_iters=train_iters, initialize=initialize, verbose=verbose)

    def truncatedSVD(self, data, k=50):
|
108 |
+
"""Truncated SVD on input data.
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
----------
|
112 |
+
data :
|
113 |
+
input matrix to be decomposed
|
114 |
+
k : int
|
115 |
+
number of singular values and vectors to compute.
|
116 |
+
|
117 |
+
Returns
|
118 |
+
-------
|
119 |
+
numpy.array
|
120 |
+
reconstructed matrix.
|
121 |
+
"""
|
122 |
+
print('=== GCN-SVD: rank={} ==='.format(k))
|
123 |
+
if sp.issparse(data):
|
124 |
+
data = data.asfptype()
|
125 |
+
U, S, V = sp.linalg.svds(data, k=k)
|
126 |
+
print("rank_after = {}".format(len(S.nonzero()[0])))
|
127 |
+
diag_S = np.diag(S)
|
128 |
+
else:
|
129 |
+
U, S, V = np.linalg.svd(data)
|
130 |
+
U = U[:, :k]
|
131 |
+
S = S[:k]
|
132 |
+
V = V[:k, :]
|
133 |
+
print("rank_before = {}".format(len(S.nonzero()[0])))
|
134 |
+
diag_S = np.diag(S)
|
135 |
+
print("rank_after = {}".format(len(diag_S.nonzero()[0])))
|
136 |
+
|
137 |
+
return U @ diag_S @ V
|
138 |
+
|
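As a quick sanity check of the rank-k reconstruction performed above, here is a minimal standalone sketch (the toy matrix and sizes are illustrative, not part of the library):

    import numpy as np
    import scipy.sparse as sp
    from scipy.sparse.linalg import svds

    A = sp.random(50, 50, density=0.1, format='csr', random_state=0)  # toy sparse "adjacency"
    A = A + A.T                                                       # symmetrize
    U, S, V = svds(A, k=5)                                            # top-5 singular triplets
    A_k = U @ np.diag(S) @ V                                          # dense rank-5 approximation
    print(np.linalg.matrix_rank(A_k))                                 # at most 5
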
+    def predict(self, features=None, adj=None):
+        """By default, the inputs should be the unnormalized adjacency matrix.
+
+        Parameters
+        ----------
+        features :
+            node features. If `features` and `adj` are not given, this function will use previously stored `features` and `adj` from training to make predictions.
+        adj :
+            adjacency matrix. If `features` and `adj` are not given, this function will use previously stored `features` and `adj` from training to make predictions.
+
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities) of GCNSVD
+        """
+
+        self.eval()
+        if features is None and adj is None:
+            return self.forward(self.features, self.adj_norm)
+        else:
+            adj = self.truncatedSVD(adj, k=self.k)
+            if type(adj) is not torch.Tensor:
+                features, adj = utils.to_tensor(features, adj, device=self.device)
+
+            self.features = features
+            if utils.is_sparse_tensor(adj):
+                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
+            else:
+                self.adj_norm = utils.normalize_adj_tensor(adj)
+            return self.forward(self.features, self.adj_norm)
+
+
+class GCNJaccard(GCN):
+    """GCNJaccard first preprocesses the input graph by dropping dissimilar
+    edges and then trains a GCN on the processed graph. See more details in
+    Adversarial Examples on Graph Data: Deep Insights into Attack and Defense,
+    https://arxiv.org/pdf/1903.01610.pdf.
+
+    Parameters
+    ----------
+    nfeat : int
+        size of input feature dimension
+    nhid : int
+        number of hidden units
+    nclass : int
+        size of output dimension
+    dropout : float
+        dropout rate for GCN
+    lr : float
+        learning rate for GCN
+    weight_decay : float
+        weight decay coefficient (l2 normalization) for GCN. When `with_relu` is True, `weight_decay` will be set to 0.
+    with_relu : bool
+        whether to use relu activation function. If False, GCN will be linearized.
+    with_bias : bool
+        whether to include bias term in GCN weights.
+    device : str
+        'cpu' or 'cuda'.
+
+    Examples
+    --------
+    We can first load the dataset and then train GCNJaccard.
+
+    >>> from deeprobust.graph.data import PrePtbDataset, Dataset
+    >>> from deeprobust.graph.defense import GCNJaccard
+    >>> # load clean graph data
+    >>> data = Dataset(root='/tmp/', name='cora', seed=15)
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> # load perturbed graph data
+    >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
+    >>> perturbed_adj = perturbed_data.adj
+    >>> # train defense model
+    >>> model = GCNJaccard(nfeat=features.shape[1],
+              nhid=16,
+              nclass=labels.max().item() + 1,
+              dropout=0.5, device='cpu').to('cpu')
+    >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, threshold=0.03)
+
+    """
+    def __init__(self, nfeat, nhid, nclass, binary_feature=True, dropout=0.5, lr=0.01, weight_decay=5e-4, with_relu=True, with_bias=True, device='cpu'):
+
+        super(GCNJaccard, self).__init__(nfeat, nhid, nclass, dropout, lr, weight_decay, with_relu, with_bias, device=device)
+        self.device = device
+        self.binary_feature = binary_feature
+
+    def fit(self, features, adj, labels, idx_train, idx_val=None, threshold=0.01, train_iters=200, initialize=True, verbose=True, **kwargs):
+        """First drop dissimilar edges with similarity smaller than the given
+        threshold and then train the GCN model on the processed graph.
+        When idx_val is not None, pick the best model according to the
+        validation loss.
+
+        Parameters
+        ----------
+        features :
+            node features. The format can be numpy.array or scipy matrix
+        adj :
+            the adjacency matrix.
+        labels :
+            node labels
+        idx_train :
+            node training indices
+        idx_val :
+            node validation indices. If not given (None), GCN training process will not adopt early stopping
+        threshold : float
+            similarity threshold for dropping edges. If two connected nodes have similarity smaller than the threshold, the edge between them will be removed.
+        train_iters : int
+            number of training epochs
+        initialize : bool
+            whether to initialize parameters before training
+        verbose : bool
+            whether to show verbose logs
+        """
+
+        self.threshold = threshold
+        modified_adj = self.drop_dissimilar_edges(features, adj)
+        # modified_adj_tensor = utils.sparse_mx_to_torch_sparse_tensor(self.modified_adj)
+        features, modified_adj, labels = utils.to_tensor(features, modified_adj, labels, device=self.device)
+        self.modified_adj = modified_adj
+        self.features = features
+        self.labels = labels
+        super().fit(features, modified_adj, labels, idx_train, idx_val, train_iters=train_iters, initialize=initialize, verbose=verbose)
+
+    def drop_dissimilar_edges(self, features, adj, metric='similarity'):
+        """Drop dissimilar edges. (Faster version using numba)
+        """
+        if not sp.issparse(adj):
+            if isinstance(adj, torch.Tensor):  # move tensors off the GPU before conversion
+                adj = adj.to('cpu')
+            adj = sp.csr_matrix(adj)
+
+        adj_triu = sp.triu(adj, format='csr')
+
+        if sp.issparse(features):
+            features = features.todense().A  # make it easier for njit processing
+
+        if metric == 'distance':
+            removed_cnt = dropedge_dis(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.threshold)
+        else:
+            if self.binary_feature:
+                removed_cnt = dropedge_jaccard(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.threshold)
+            else:
+                removed_cnt = dropedge_cosine(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.threshold)
+        print('removed %s edges in the original graph' % removed_cnt)
+        modified_adj = adj_triu + adj_triu.transpose()
+        return modified_adj
+
+    def predict(self, features=None, adj=None):
+        """By default, the inputs should be the unnormalized adjacency matrix.
+
+        Parameters
+        ----------
+        features :
+            node features. If `features` and `adj` are not given, this function will use previously stored `features` and `adj` from training to make predictions.
+        adj :
+            adjacency matrix. If `features` and `adj` are not given, this function will use previously stored `features` and `adj` from training to make predictions.
+
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities) of GCNJaccard
+        """
+
+        self.eval()
+        if features is None and adj is None:
+            return self.forward(self.features, self.adj_norm)
+        else:
+            adj = self.drop_dissimilar_edges(features, adj)
+            if type(adj) is not torch.Tensor:
+                features, adj = utils.to_tensor(features, adj, device=self.device)
+
+            self.features = features
+            if utils.is_sparse_tensor(adj):
+                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
+            else:
+                self.adj_norm = utils.normalize_adj_tensor(adj)
+            return self.forward(self.features, self.adj_norm)
+
+    def _drop_dissimilar_edges(self, features, adj):
+        """Drop dissimilar edges. (Slower version)
+        """
+        if not sp.issparse(adj):
+            adj = sp.csr_matrix(adj)
+        modified_adj = adj.copy().tolil()
+
+        # preprocessing based on features
+        print('=== GCN-Jaccard ===')
+        edges = np.array(modified_adj.nonzero()).T
+        removed_cnt = 0
+        for edge in tqdm(edges):
+            n1 = edge[0]
+            n2 = edge[1]
+            if n1 > n2:
+                continue
+
+            if self.binary_feature:
+                J = self._jaccard_similarity(features[n1], features[n2])
+
+                if J < self.threshold:
+                    modified_adj[n1, n2] = 0
+                    modified_adj[n2, n1] = 0
+                    removed_cnt += 1
+            else:
+                # for non-binary features, use cosine similarity
+                C = self._cosine_similarity(features[n1], features[n2])
+                if C < self.threshold:
+                    modified_adj[n1, n2] = 0
+                    modified_adj[n2, n1] = 0
+                    removed_cnt += 1
+        print('removed %s edges in the original graph' % removed_cnt)
+        return modified_adj
+
+    def _jaccard_similarity(self, a, b):
+        intersection = a.multiply(b).count_nonzero()
+        J = intersection * 1.0 / (a.count_nonzero() + b.count_nonzero() - intersection)
+        return J
+
+    def _cosine_similarity(self, a, b):
+        inner_product = (a * b).sum()
+        C = inner_product / (np.sqrt(np.square(a).sum()) * np.sqrt(np.square(b).sum()) + 1e-10)
+        return C
+
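To make the similarity rules used by `_jaccard_similarity` and `_cosine_similarity` concrete, a minimal sketch on two toy feature vectors (the numbers are illustrative only):

    import numpy as np

    a = np.array([1, 0, 1, 1, 0], dtype=float)
    b = np.array([1, 1, 1, 0, 0], dtype=float)

    # Jaccard similarity for binary features: |a AND b| / |a OR b|
    intersection = np.count_nonzero(a * b)  # 2
    jaccard = intersection / (np.count_nonzero(a) + np.count_nonzero(b) - intersection)  # 2 / 4 = 0.5

    # cosine similarity for real-valued features
    cosine = (a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)  # ~0.667

    # with threshold=0.01 (the fit() default), neither rule would drop the edge (n1, n2)
    print(jaccard, cosine)
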
+def __dropedge_jaccard(A, iA, jA, features, threshold):
+    # deprecated: for sparse feature matrix...
+    removed_cnt = 0
+    for row in range(len(iA)-1):
+        for i in range(iA[row], iA[row+1]):
+            # print(row, jA[i], A[i])
+            n1 = row
+            n2 = jA[i]
+            a, b = features[n1], features[n2]
+
+            intersection = a.multiply(b).count_nonzero()
+            J = intersection * 1.0 / (a.count_nonzero() + b.count_nonzero() - intersection)
+
+            if J < threshold:
+                A[i] = 0
+                # A[n2, n1] = 0
+                removed_cnt += 1
+    return removed_cnt
+
+@njit
+def dropedge_jaccard(A, iA, jA, features, threshold):
+    removed_cnt = 0
+    for row in range(len(iA)-1):
+        for i in range(iA[row], iA[row+1]):
+            # print(row, jA[i], A[i])
+            n1 = row
+            n2 = jA[i]
+            a, b = features[n1], features[n2]
+            intersection = np.count_nonzero(a*b)
+            J = intersection * 1.0 / (np.count_nonzero(a) + np.count_nonzero(b) - intersection)
+
+            if J < threshold:
+                A[i] = 0
+                # A[n2, n1] = 0
+                removed_cnt += 1
+    return removed_cnt
+
+
+@njit
+def dropedge_cosine(A, iA, jA, features, threshold):
+    removed_cnt = 0
+    for row in range(len(iA)-1):
+        for i in range(iA[row], iA[row+1]):
+            # print(row, jA[i], A[i])
+            n1 = row
+            n2 = jA[i]
+            a, b = features[n1], features[n2]
+            inner_product = (a * b).sum()
+            C = inner_product / (np.sqrt(np.square(a).sum()) * np.sqrt(np.square(b).sum()) + 1e-8)
+
+            if C < threshold:
+                A[i] = 0
+                # A[n2, n1] = 0
+                removed_cnt += 1
+    return removed_cnt
+
+@njit
+def dropedge_dis(A, iA, jA, features, threshold):
+    removed_cnt = 0
+    for row in range(len(iA)-1):
+        for i in range(iA[row], iA[row+1]):
+            # print(row, jA[i], A[i])
+            n1 = row
+            n2 = jA[i]
+            C = np.linalg.norm(features[n1] - features[n2])
+            if C > threshold:
+                A[i] = 0
+                # A[n2, n1] = 0
+                removed_cnt += 1
+
+    return removed_cnt
+
+@njit
+def dropedge_both(A, iA, jA, features, threshold1=2.5, threshold2=0.01):
+    removed_cnt = 0
+    for row in range(len(iA)-1):
+        for i in range(iA[row], iA[row+1]):
+            # print(row, jA[i], A[i])
+            n1 = row
+            n2 = jA[i]
+            C1 = np.linalg.norm(features[n1] - features[n2])
+
+            a, b = features[n1], features[n2]
+            inner_product = (a * b).sum()
+            C2 = inner_product / (np.sqrt(np.square(a).sum()) * np.sqrt(np.square(b).sum()) + 1e-6)
+            # fixed: use the same cosine formula as dropedge_cosine, and compare C2
+            # against threshold2 (the original dead test `threshold2 < 0` never fired)
+            if C1 > threshold1 or C2 < threshold2:
+                A[i] = 0
+                # A[n2, n1] = 0
+                removed_cnt += 1
+
+    return removed_cnt
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import PrePtbDataset, Dataset
+    # load clean graph data
+    dataset_str = 'pubmed'
+    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # load perturbed graph data
+    perturbed_data = PrePtbDataset(root='/tmp/', name=dataset_str)
+    perturbed_adj = perturbed_data.adj
+    # train defense model
+    print("Test GCNJaccard")
+    model = GCNJaccard(nfeat=features.shape[1],
+                       nhid=16,
+                       nclass=labels.max().item() + 1,
+                       binary_feature=False,
+                       dropout=0.5, device='cuda').to('cuda')
+    model.fit(features, perturbed_adj, labels, idx_train, idx_val, threshold=0.1)
+    model.test(idx_test)
+    prediction_1 = model.predict()
+    prediction_2 = model.predict(features, perturbed_adj)
+    assert (prediction_1 != prediction_2).sum() == 0
+
+    print("Test GCNSVD")
+    model = GCNSVD(nfeat=features.shape[1],
+                   nhid=16,
+                   nclass=labels.max().item() + 1,
+                   dropout=0.5, device='cuda').to('cuda')
+    model.fit(features, perturbed_adj, labels, idx_train, idx_val, k=20)
+    model.test(idx_test)
+    prediction_1 = model.predict()
+    prediction_2 = model.predict(features, perturbed_adj)
+    assert (prediction_1 - prediction_2).mean() < 1e-5
+
deeprobust/graph/defense/median_gcn.py
ADDED
@@ -0,0 +1,296 @@
+from torch_geometric.typing import Adj, OptTensor
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch import optim
+from copy import deepcopy
+from deeprobust.graph import utils
+from torch_geometric.nn.inits import zeros
+from torch_geometric.nn.conv import MessagePassing
+
+# This works for higher versions of torch_geometric, e.g., 2.0.
+# from torch_geometric.nn.dense.linear import Linear
+from torch.nn import Linear
+
+
+from torch_sparse import SparseTensor, set_diag
+from torch_geometric.utils import to_dense_batch
+from torch_geometric.utils import remove_self_loops, add_self_loops
+
+
+class MedianConv(MessagePassing):
+
+    def __init__(self, in_channels: int, out_channels: int,
+                 add_self_loops: bool = True,
+                 bias: bool = True, **kwargs):
+        kwargs.setdefault('aggr', None)
+        super(MedianConv, self).__init__(**kwargs)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.add_self_loops = add_self_loops
+
+        # This works for higher versions of torch_geometric, e.g., 2.0.
+        # self.lin = Linear(in_channels, out_channels, bias=False,
+        #                   weight_initializer='glorot')
+        self.lin = Linear(in_channels, out_channels, bias=False)
+
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin.reset_parameters()
+        zeros(self.bias)
+
+    def forward(self, x: Tensor, edge_index: Adj,
+                edge_weight: OptTensor = None) -> Tensor:
+
+        if self.add_self_loops:
+            if isinstance(edge_index, Tensor):
+                edge_index, _ = remove_self_loops(edge_index)
+                edge_index, _ = add_self_loops(edge_index,
+                                               num_nodes=x.size(self.node_dim))
+            elif isinstance(edge_index, SparseTensor):
+                edge_index = set_diag(edge_index)
+
+        x = self.lin(x)
+        # propagate_type: (x: Tensor, edge_weight: OptTensor)
+        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
+                             size=None)
+        if self.bias is not None:
+            out += self.bias
+        return out
+
+    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
+        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
+
+    def aggregate(self, x_j, index):
+        """median aggregation"""
+        # important! `to_dense_batch` requires that `index` is sorted
+        ix = torch.argsort(index)
+        index = index[ix]
+        x_j = x_j[ix]
+
+        dense_x, mask = to_dense_batch(x_j, index)
+        out = x_j.new_zeros(dense_x.size(0), dense_x.size(-1))
+        deg = mask.sum(dim=1)
+        for i in deg.unique():
+            deg_mask = deg == i
+            out[deg_mask] = dense_x[deg_mask, :i].median(dim=1).values
+        return out
+
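A quick way to see what `aggregate` computes: for each target node it takes the elementwise median over the messages it receives. The implementation above pads messages per node with `to_dense_batch` and therefore takes the median separately for each degree bucket; the sketch below shows the same per-node median in plain torch (toy tensors, illustrative only):

    import torch

    # messages x_j from 5 edges targeting nodes {0, 1}; index[i] is the target of message i
    x_j = torch.tensor([[1.], [9.], [5.], [2.], [4.]])
    index = torch.tensor([0, 0, 0, 1, 1])

    out = torch.stack([x_j[index == n].median(dim=0).values for n in index.unique()])
    print(out)  # node 0 -> median(1, 9, 5) = 5; node 1 -> median(2, 4) = 2 (torch takes the lower middle)
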
+    def __repr__(self):
+        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
+                                   self.out_channels)
+
+
+class MedianGCN(torch.nn.Module):
+    """Graph Convolutional Networks with Median aggregation (MedianGCN)
+    based on pytorch geometric.
+
+    `Understanding Structural Vulnerability in Graph Convolutional Networks
+    <https://arxiv.org/abs/2108.06280>`_
+
+    MedianGCN uses the median aggregation function instead of the
+    `weighted mean` adopted in GCN, which improves the robustness
+    of the model against adversarial structural attacks.
+
+    Parameters
+    ----------
+    nfeat : int
+        size of input feature dimension
+    nhid : int
+        number of hidden units
+    nclass : int
+        size of output dimension
+    dropout : float
+        dropout rate for MedianGCN
+    lr : float
+        learning rate for MedianGCN
+    weight_decay : float
+        weight decay coefficient (l2 normalization) for MedianGCN.
+    with_bias : bool
+        whether to include bias term in MedianGCN weights.
+    device : str
+        'cpu' or 'cuda'.
+
+    Examples
+    --------
+    We can first load the dataset and then train MedianGCN.
+
+    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
+    >>> from deeprobust.graph.defense import MedianGCN
+    >>> data = Dataset(root='/tmp/', name='cora')
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> MedianGCN = MedianGCN(nfeat=features.shape[1],
+              nhid=16, nclass=labels.max().item() + 1,
+              device='cuda')
+    >>> MedianGCN = MedianGCN.to('cuda')
+    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
+    >>> MedianGCN.fit(pyg_data, verbose=True) # train with earlystopping
+    """
+
+    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
+                 with_bias=True, device=None):
+
+        super(MedianGCN, self).__init__()
+
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.conv1 = MedianConv(nfeat, nhid, bias=with_bias)
+        self.conv2 = MedianConv(nhid, nclass, bias=with_bias)
+
+        self.dropout = dropout
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.with_bias = with_bias
+        self.output = None
+
+    def forward(self, data):
+        x, edge_index = data.x, data.edge_index
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.conv1(x, edge_index).relu()
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.conv2(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+    def initialize(self):
+        """Initialize parameters of MedianGCN.
+        """
+        self.conv1.reset_parameters()
+        self.conv2.reset_parameters()
+
+    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=500, **kwargs):
+        """Train the MedianGCN model; when idx_val is not None, pick the best model
+        according to the validation loss.
+
+        Parameters
+        ----------
+        pyg_data :
+            pytorch geometric dataset object
+        train_iters : int
+            number of training epochs
+        initialize : bool
+            whether to initialize parameters before training
+        verbose : bool
+            whether to show verbose logs
+        patience : int
+            patience for early stopping, only valid when `idx_val` is given
+        """
+
+        # self.device = self.conv1.weight.device
+        if initialize:
+            self.initialize()
+
+        self.data = pyg_data[0].to(self.device)
+        # By default, it is trained with early stopping on validation
+        self.train_with_early_stopping(train_iters, patience, verbose)
+
+    def train_with_early_stopping(self, train_iters, patience, verbose):
+        """early stopping based on the validation loss
+        """
+        if verbose:
+            print('=== training MedianGCN model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+
+        labels = self.data.y
+        train_mask, val_mask = self.data.train_mask, self.data.val_mask
+
+        early_stopping = patience
+        best_loss_val = 100
+
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward(self.data)
+
+            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
+            loss_train.backward()
+            optimizer.step()
+
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            output = self.forward(self.data)
+            loss_val = F.nll_loss(output[val_mask], labels[val_mask])
+
+            if best_loss_val > loss_val:
+                best_loss_val = loss_val
+                self.output = output
+                weights = deepcopy(self.state_dict())
+                patience = early_stopping
+            else:
+                patience -= 1
+                if i > early_stopping and patience <= 0:
+                    break
+
+        if verbose:
+            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
+        self.load_state_dict(weights)
+
+    @torch.no_grad()
+    def test(self, pyg_data=None):
+        """Evaluate MedianGCN performance on the test set.
+
+        Parameters
+        ----------
+        pyg_data :
+            pytorch geometric dataset object
+        """
+        self.eval()
+        data = pyg_data[0].to(self.device) if pyg_data is not None else self.data
+        test_mask = data.test_mask
+        labels = data.y
+        output = self.forward(data)
+        # output = self.output
+        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
+        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
+        print("Test set results:",
+              "loss= {:.4f}".format(loss_test.item()),
+              "accuracy= {:.4f}".format(acc_test.item()))
+        return acc_test.item()
+
+    @torch.no_grad()
+    def predict(self, pyg_data=None):
+        """
+        Parameters
+        ----------
+        pyg_data :
+            pytorch geometric dataset object
+
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities) of MedianGCN
+        """
+
+        self.eval()
+        data = pyg_data[0].to(self.device) if pyg_data is not None else self.data
+        return self.forward(data)
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    # from deeprobust.graph.defense import MedianGCN
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    MedianGCN = MedianGCN(nfeat=features.shape[1],
+                          nhid=16,
+                          nclass=labels.max().item() + 1,
+                          device=device)
+    MedianGCN = MedianGCN.to(device)
+    pyg_data = Dpr2Pyg(data)
+    MedianGCN.fit(pyg_data, verbose=True) # train with earlystopping
+    MedianGCN.test()
+    print(MedianGCN.predict().size())
deeprobust/graph/defense/node_embedding.py
ADDED
@@ -0,0 +1,538 @@
+"""
+Code in this file is modified from https://github.com/abojchevski/node_embedding_attack
+
+'Adversarial Attacks on Node Embeddings via Graph Poisoning'
+Aleksandar Bojchevski and Stephan Günnemann, ICML 2019
+http://proceedings.mlr.press/v97/bojchevski19a.html
+Copyright (C) owned by the authors, 2019
+"""
+import numba
+import numpy as np
+import scipy.sparse as sp
+from gensim.models import Word2Vec
+import networkx as nx
+from gensim.models import KeyedVectors
+from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import normalize
+from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, accuracy_score
+
+class BaseEmbedding:
+    """Base class for node embedding methods such as DeepWalk and Node2Vec.
+    """
+
+    def __init__(self):
+        self.embedding = None
+        self.model = None
+
+    def evaluate_node_classification(self, labels, idx_train, idx_test,
+                                     normalize_embedding=True, lr_params=None):
+        """Evaluate the node embeddings on the node classification task.
+
+        Parameters
+        ----------
+        labels : np.ndarray, shape [n_nodes]
+            The ground truth labels
+        normalize_embedding : bool
+            Whether to normalize the embeddings
+        idx_train : np.array
+            Indices of training nodes
+        idx_test : np.array
+            Indices of test nodes
+        lr_params : dict
+            Parameters for the LogisticRegression model
+
+        Returns
+        -------
+        [numpy.array, float, float]
+            Predictions from LR, micro F1 score and macro F1 score
+        """
+
+        embedding_matrix = self.embedding
+
+        if normalize_embedding:
+            embedding_matrix = normalize(embedding_matrix)
+
+        features_train = embedding_matrix[idx_train]
+        features_test = embedding_matrix[idx_test]
+        labels_train = labels[idx_train]
+        labels_test = labels[idx_test]
+
+        if lr_params is None:
+            lr = LogisticRegression(solver='lbfgs', max_iter=1000, multi_class='auto')
+        else:
+            lr = LogisticRegression(**lr_params)
+        lr.fit(features_train, labels_train)
+
+        lr_z_predict = lr.predict(features_test)
+        f1_micro = f1_score(labels_test, lr_z_predict, average='micro')
+        f1_macro = f1_score(labels_test, lr_z_predict, average='macro')
+        test_acc = accuracy_score(labels_test, lr_z_predict)
+        print('Micro F1:', f1_micro)
+        print('Macro F1:', f1_macro)
+        return lr_z_predict, f1_micro, f1_macro
+
+
+    def evaluate_link_prediction(self, adj, node_pairs, normalize_embedding=True):
+        """Evaluate the node embeddings on the link prediction task.
+
+        Parameters
+        ----------
+        adj : sp.csr_matrix, shape [n_nodes, n_nodes]
+            Adjacency matrix of the graph
+        node_pairs : numpy.array, shape [n_pairs, 2]
+            Node pairs
+        normalize_embedding : bool
+            Whether to normalize the embeddings
+
+        Returns
+        -------
+        [numpy.array, float, float]
+            Inner product of embeddings, Area under ROC curve (AUC) score and average precision (AP) score
+        """
+
+        embedding_matrix = self.embedding
+        if normalize_embedding:
+            embedding_matrix = normalize(embedding_matrix)
+
+        true = adj[node_pairs[:, 0], node_pairs[:, 1]].A1
+        scores = (embedding_matrix[node_pairs[:, 0]] * embedding_matrix[node_pairs[:, 1]]).sum(1)
+        # print(np.unique(true, return_counts=True))
+        try:
+            auc_score = roc_auc_score(true, scores)
+        except Exception as e:
+            auc_score = 0.00
+            print('ROC error')
+        ap_score = average_precision_score(true, scores)
+        print("AUC:", auc_score)
+        print("AP:", ap_score)
+        return scores, auc_score, ap_score
+
+class Node2Vec(BaseEmbedding):
+    """node2vec: Scalable Feature Learning for Networks. KDD'16.
+    To use this model, you need to "pip install node2vec" first.
+
+    Examples
+    --------
+    >>> from deeprobust.graph.data import Dataset
+    >>> from deeprobust.graph.global_attack import NodeEmbeddingAttack
+    >>> from deeprobust.graph.defense import Node2Vec
+    >>> data = Dataset(root='/tmp/', name='cora_ml', seed=15)
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> # set up attack model
+    >>> attacker = NodeEmbeddingAttack()
+    >>> attacker.attack(adj, attack_type="remove", n_perturbations=1000)
+    >>> modified_adj = attacker.modified_adj
+    >>> print("Test Node2vec on clean graph")
+    >>> model = Node2Vec()
+    >>> model.fit(adj)
+    >>> model.evaluate_node_classification(labels, idx_train, idx_test)
+    >>> print("Test Node2vec on attacked graph")
+    >>> model = Node2Vec()
+    >>> model.fit(modified_adj)
+    >>> model.evaluate_node_classification(labels, idx_train, idx_test)
+    """
+
+    def __init__(self):
+        # self.fit = self.node2vec_snap
+        super(Node2Vec, self).__init__()
+        self.fit = self.node2vec
+
+    def node2vec(self, adj, embedding_dim=64, walk_length=30, walks_per_node=10,
+                 workers=8, window_size=10, num_neg_samples=1, p=4, q=1):
+        """Compute Node2Vec embeddings for the given graph.
+
+        Parameters
+        ----------
+        adj : sp.csr_matrix, shape [n_nodes, n_nodes]
+            Adjacency matrix of the graph
+        embedding_dim : int, optional
+            Dimension of the embedding
+        walks_per_node : int, optional
+            Number of walks sampled from each node
+        walk_length : int, optional
+            Length of each random walk
+        workers : int, optional
+            Number of threads (see gensim.models.Word2Vec process)
+        window_size : int, optional
+            Window size (see gensim.models.Word2Vec)
+        num_neg_samples : int, optional
+            Number of negative samples (see gensim.models.Word2Vec)
+        p : float
+            The hyperparameter p in node2vec
+        q : float
+            The hyperparameter q in node2vec
+        """
+
+        walks = sample_n2v_random_walks(adj, walk_length, walks_per_node, p=p, q=q)
+        walks = [list(map(str, walk)) for walk in walks]
+        self.model = Word2Vec(walks, size=embedding_dim, window=window_size, min_count=0, sg=1, workers=workers,
+                              iter=1, negative=num_neg_samples, hs=0, compute_loss=True)
+        self.embedding = self.model.wv.vectors[np.fromiter(map(int, self.model.wv.index2word), np.int32).argsort()]
+
+
+
+class DeepWalk(BaseEmbedding):
+    """DeepWalk: Online Learning of Social Representations. KDD'14. The implementation is
+    modified from https://github.com/abojchevski/node_embedding_attack
+
+    Examples
+    --------
+    >>> from deeprobust.graph.data import Dataset
+    >>> from deeprobust.graph.global_attack import NodeEmbeddingAttack
+    >>> from deeprobust.graph.defense import DeepWalk
+    >>> data = Dataset(root='/tmp/', name='cora_ml', seed=15)
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> # set up attack model
+    >>> attacker = NodeEmbeddingAttack()
+    >>> attacker.attack(adj, attack_type="remove", n_perturbations=1000)
+    >>> modified_adj = attacker.modified_adj
+    >>> print("Test DeepWalk on clean graph")
+    >>> model = DeepWalk()
+    >>> model.fit(adj)
+    >>> model.evaluate_node_classification(labels, idx_train, idx_test)
+    >>> print("Test DeepWalk on attacked graph")
+    >>> model.fit(modified_adj)
+    >>> model.evaluate_node_classification(labels, idx_train, idx_test)
+    >>> print("Test DeepWalk SVD")
+    >>> model = DeepWalk(type="svd")
+    >>> model.fit(modified_adj)
+    >>> model.evaluate_node_classification(labels, idx_train, idx_test)
+    """
+
+    def __init__(self, type="skipgram"):
+        super(DeepWalk, self).__init__()
+        if type == "skipgram":
+            self.fit = self.deepwalk_skipgram
+        elif type == "svd":
+            self.fit = self.deepwalk_svd
+        else:
+            raise NotImplementedError
+
+    def deepwalk_skipgram(self, adj, embedding_dim=64, walk_length=80, walks_per_node=10,
+                          workers=8, window_size=10, num_neg_samples=1):
+        """Compute DeepWalk embeddings for the given graph using the skip-gram formulation.
+
+        Parameters
+        ----------
+        adj : sp.csr_matrix, shape [n_nodes, n_nodes]
+            Adjacency matrix of the graph
+        embedding_dim : int, optional
+            Dimension of the embedding
+        walks_per_node : int, optional
+            Number of walks sampled from each node
+        walk_length : int, optional
+            Length of each random walk
+        workers : int, optional
+            Number of threads (see gensim.models.Word2Vec process)
+        window_size : int, optional
+            Window size (see gensim.models.Word2Vec)
+        num_neg_samples : int, optional
+            Number of negative samples (see gensim.models.Word2Vec)
+        """
+
+        walks = sample_random_walks(adj, walk_length, walks_per_node)
+        walks = [list(map(str, walk)) for walk in walks]
+        self.model = Word2Vec(walks, size=embedding_dim, window=window_size, min_count=0, sg=1, workers=workers,
+                              iter=1, negative=num_neg_samples, hs=0, compute_loss=True)
+        self.embedding = self.model.wv.vectors[np.fromiter(map(int, self.model.wv.index2word), np.int32).argsort()]
+
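The `Word2Vec(...)` calls above use the gensim < 4.0 keyword names (`size`, `iter`) and the old `wv.index2word` attribute. For readers on gensim >= 4.0, a minimal equivalent sketch (this only maps names; which gensim is installed is an assumption, not something this file pins):

    import numpy as np
    from gensim.models import Word2Vec

    # `walks` is the list of string-token walks produced above
    model = Word2Vec(walks, vector_size=64, window=10, min_count=0, sg=1,
                     workers=8, epochs=1, negative=1, hs=0, compute_loss=True)
    # gensim stores vectors in frequency order; re-sort rows by integer node id
    embedding = model.wv.vectors[np.fromiter(map(int, model.wv.index_to_key), np.int32).argsort()]
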
+
+    def deepwalk_svd(self, adj, window_size=10, embedding_dim=64, num_neg_samples=1, sparse=True):
+        """Compute DeepWalk embeddings for the given graph using the matrix factorization formulation.
+
+        Parameters
+        ----------
+        adj : sp.csr_matrix, shape [n_nodes, n_nodes]
+            Adjacency matrix of the graph
+        window_size : int
+            Size of the window
+        embedding_dim : int
+            Size of the embedding
+        num_neg_samples : int
+            Number of negative samples
+        sparse : bool
+            Whether to perform sparse operations
+
+        Returns
+        -------
+        (Fu, Fv, loss, log_M)
+            Fu is the [num_nodes, embedding_dim] embedding matrix (also stored in
+            `self.embedding`); loss is the Frobenius reconstruction error.
+        """
+        sum_powers_transition = sum_of_powers_of_transition_matrix(adj, window_size)
+
+        deg = adj.sum(1).A1
+        deg[deg == 0] = 1
+        deg_matrix = sp.diags(1 / deg)
+
+        volume = adj.sum()
+
+        M = sum_powers_transition.dot(deg_matrix) * volume / (num_neg_samples * window_size)
+
+        log_M = M.copy()
+        log_M[M > 1] = np.log(log_M[M > 1])
+        log_M = log_M.multiply(M > 1)
+
+        if not sparse:
+            log_M = log_M.toarray()
+
+        Fu, Fv = self.svd_embedding(log_M, embedding_dim, sparse)
+
+        loss = np.linalg.norm(Fu.dot(Fv.T) - log_M, ord='fro')
+        self.embedding = Fu
+        return Fu, Fv, loss, log_M
+
+    def svd_embedding(self, x, embedding_dim, sparse=False):
+        """Computes an embedding by selecting the top (embedding_dim) largest singular values/vectors.
+
+        :param x: sp.csr_matrix or np.ndarray
+            The matrix that we want to embed
+        :param embedding_dim: int
+            Dimension of the embedding
+        :param sparse: bool
+            Whether to perform sparse operations
+        :return: np.ndarray, shape [?, embedding_dim], np.ndarray, shape [?, embedding_dim]
+            Embedding matrices.
+        """
+        if sparse:
+            U, s, V = sp.linalg.svds(x, embedding_dim)
+        else:
+            U, s, V = np.linalg.svd(x)
+
+        S = np.diag(s)
+        Fu = U.dot(np.sqrt(S))[:, :embedding_dim]
+        Fv = np.sqrt(S).dot(V)[:embedding_dim, :].T
+        return Fu, Fv
+
+
+
def sample_random_walks(adj, walk_length, walks_per_node, seed=None):
|
304 |
+
"""Sample random walks of fixed length from each node in the graph in parallel.
|
305 |
+
Parameters
|
306 |
+
----------
|
307 |
+
adj : sp.csr_matrix, shape [n_nodes, n_nodes]
|
308 |
+
Sparse adjacency matrix
|
309 |
+
walk_length : int
|
310 |
+
Random walk length
|
311 |
+
walks_per_node : int
|
312 |
+
Number of random walks per node
|
313 |
+
seed : int or None
|
314 |
+
Random seed
|
315 |
+
Returns
|
316 |
+
-------
|
317 |
+
walks : np.ndarray, shape [num_walks * num_nodes, walk_length]
|
318 |
+
The sampled random walks
|
319 |
+
"""
|
320 |
+
if seed is None:
|
321 |
+
seed = np.random.randint(0, 100000)
|
322 |
+
adj = sp.csr_matrix(adj)
|
323 |
+
random_walks = _random_walk(adj.indptr,
|
324 |
+
adj.indices,
|
325 |
+
walk_length,
|
326 |
+
walks_per_node,
|
327 |
+
seed).reshape([-1, walk_length])
|
328 |
+
return random_walks
|
329 |
+
|
330 |
+
|
331 |
+
@numba.jit(nopython=True, parallel=True)
|
332 |
+
def _random_walk(indptr, indices, walk_length, walks_per_node, seed):
|
333 |
+
"""Sample r random walks of length l per node in parallel from the graph.
|
334 |
+
Parameters
|
335 |
+
----------
|
336 |
+
indptr : array-like
|
337 |
+
Pointer for the edges of each node
|
338 |
+
indices : array-like
|
339 |
+
Edges for each node
|
340 |
+
walk_length : int
|
341 |
+
Random walk length
|
342 |
+
walks_per_node : int
|
343 |
+
Number of random walks per node
|
344 |
+
seed : int
|
345 |
+
Random seed
|
346 |
+
Returns
|
347 |
+
-------
|
348 |
+
walks : array-like, shape [r*N*l]
|
349 |
+
The sampled random walks
|
350 |
+
"""
|
351 |
+
np.random.seed(seed)
|
352 |
+
N = len(indptr) - 1
|
353 |
+
walks = []
|
354 |
+
|
355 |
+
for ir in range(walks_per_node):
|
356 |
+
for n in range(N):
|
357 |
+
for il in range(walk_length):
|
358 |
+
walks.append(n)
|
359 |
+
n = np.random.choice(indices[indptr[n]:indptr[n + 1]])
|
360 |
+
|
361 |
+
return np.array(walks)
|
362 |
+
|
363 |
+
def sample_n2v_random_walks(adj, walk_length, walks_per_node, p, q, seed=None):
|
364 |
+
"""Sample node2vec random walks of fixed length from each node in the graph in parallel.
|
365 |
+
Parameters
|
366 |
+
----------
|
367 |
+
adj : sp.csr_matrix, shape [n_nodes, n_nodes]
|
368 |
+
Sparse adjacency matrix
|
369 |
+
walk_length : int
|
370 |
+
Random walk length
|
371 |
+
walks_per_node : int
|
372 |
+
Number of random walks per node
|
373 |
+
p: float
|
374 |
+
The probability to go back
|
375 |
+
q: float,
|
376 |
+
The probability to go explore undiscovered parts of the graphs
|
377 |
+
seed : int or None
|
378 |
+
Random seed
|
379 |
+
Returns
|
380 |
+
-------
|
381 |
+
walks : np.ndarray, shape [num_walks * num_nodes, walk_length]
|
382 |
+
The sampled random walks
|
383 |
+
"""
|
384 |
+
if seed is None:
|
385 |
+
seed = np.random.randint(0, 100000)
|
386 |
+
adj = sp.csr_matrix(adj)
|
387 |
+
random_walks = _n2v_random_walk(adj.indptr,
|
388 |
+
adj.indices,
|
389 |
+
walk_length,
|
390 |
+
walks_per_node,
|
391 |
+
p,
|
392 |
+
q,
|
393 |
+
seed)
|
394 |
+
return random_walks
|
395 |
+
|
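To see how p and q bias a single step, here is the weight rule applied by the walker defined below, sketched in plain numpy (the toy neighborhood is illustrative only):

    import numpy as np

    neighbors = np.array([0, 2, 3])          # neighbors of the current node
    previous_node = 0                        # where the walk just came from
    previous_node_neighbors = np.array([2])  # nodes at distance 1 from the previous node

    p, q = 4.0, 2.0
    weights = np.full(neighbors.size, 1 / q)     # distance-2 moves get weight 1/q
    weights[neighbors == previous_node] = 1 / p  # stepping back gets weight 1/p
    for i, nbr in enumerate(neighbors):
        if np.any(nbr == previous_node_neighbors):
            weights[i] = 1.0                     # distance-1 moves get weight 1
    probs = weights / weights.sum()
    print(probs)  # back-step 0.25/1.75, common neighbor 1/1.75, new node 0.5/1.75
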
+@numba.jit(nopython=True)
+def random_choice(arr, p):
+    """Similar to `numpy.random.choice`, but it supports the `p` option inside numba-compiled code.
+    Refer to <https://github.com/numba/numba/issues/2539#issuecomment-507306369>
+
+    Parameters
+    ----------
+    arr : 1-D array-like
+    p : 1-D array-like
+        The probabilities associated with each entry in arr
+
+    Returns
+    -------
+    samples : ndarray
+        The generated random samples
+    """
+    return arr[np.searchsorted(np.cumsum(p), np.random.random(), side="right")]
+
+@numba.jit(nopython=True)
+def _n2v_random_walk(indptr,
+                     indices,
+                     walk_length,
+                     walks_per_node,
+                     p,
+                     q,
+                     seed):
+    """Sample r random walks of length l per node from the graph.
+
+    Parameters
+    ----------
+    indptr : array-like
+        Pointer for the edges of each node
+    indices : array-like
+        Edges for each node
+    walk_length : int
+        Random walk length
+    walks_per_node : int
+        Number of random walks per node
+    p : float
+        The return parameter; 1/p is the probability weight of stepping back to the previous node
+    q : float
+        The in-out parameter; 1/q is the probability weight of exploring undiscovered parts of the graph
+    seed : int
+        Random seed
+
+    Returns
+    -------
+    walks : list generator, shape [r, N*l]
+        The sampled random walks
+    """
+    np.random.seed(seed)
+    N = len(indptr) - 1
+    for _ in range(walks_per_node):
+        for n in range(N):
+            walk = [n]
+            current_node = n
+            previous_node = N
+            previous_node_neighbors = np.empty(0, dtype=np.int32)
+            for _ in range(walk_length - 1):
+                neighbors = indices[indptr[current_node]:indptr[current_node + 1]]
+                if neighbors.size == 0:
+                    break
+
+                probability = np.array([1 / q] * neighbors.size)
+                probability[previous_node == neighbors] = 1 / p
+
+                for i, nbr in enumerate(neighbors):
+                    if np.any(nbr == previous_node_neighbors):
+                        probability[i] = 1.
+
+                norm_probability = probability / np.sum(probability)
+                next_node = random_choice(neighbors, norm_probability)
+                walk.append(next_node)
+                previous_node_neighbors = neighbors
+                # keep the node we stepped *from* as previous_node; assigning the freshly
+                # chosen node here (as the original did) disables the 1/p return bias
+                previous_node = current_node
+                current_node = next_node
+            yield walk
+
+def sum_of_powers_of_transition_matrix(adj, pow):
+    r"""Computes \sum_{r=1}^{pow} (D^{-1}A)^r.
+
+    Parameters
+    ----------
+    adj : sp.csr_matrix, shape [n_nodes, n_nodes]
+        Adjacency matrix of the graph
+    pow : int
+        Power exponent
+
+    Returns
+    -------
+    sp.csr_matrix
+        Sum of powers of the transition matrix of a graph.
+    """
+    deg = adj.sum(1).A1
+    deg[deg == 0] = 1
+    transition_matrix = sp.diags(1 / deg).dot(adj)
+
+    sum_of_powers = transition_matrix
+    last = transition_matrix
+    for i in range(1, pow):
+        last = last.dot(transition_matrix)
+        sum_of_powers += last
+
+    return sum_of_powers
+
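A small check of the formula above: each power of the transition matrix D^{-1}A is row-stochastic, so the sum over r = 1..T has row sums equal to T. Sketch on a toy 3-node path graph (illustrative only):

    import numpy as np
    import scipy.sparse as sp

    adj = sp.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float))
    S = sum_of_powers_of_transition_matrix(adj, pow=3)
    print(S.sum(1))  # every row sums to 3.0
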
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset
+    from deeprobust.graph.global_attack import NodeEmbeddingAttack
+    dataset_str = 'cora_ml'
+    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+
+    model = NodeEmbeddingAttack()
+    model.attack(adj, attack_type="add_by_remove", n_perturbations=1000, n_candidates=10000)
+    modified_adj = model.modified_adj
+
+    # train defense model
+    print("Test DeepWalk on clean graph")
+    model = DeepWalk()
+    model.fit(adj)
+    model.evaluate_node_classification(labels, idx_train, idx_test)
+    # model.evaluate_node_classification(labels, idx_train, idx_test, lr_params={"max_iter": 10})
+
+    print("Test DeepWalk on attacked graph")
+    model.fit(modified_adj)
+    model.evaluate_node_classification(labels, idx_train, idx_test)
+    print("\t link prediction...")
+    model.evaluate_link_prediction(modified_adj, np.array(adj.nonzero()).T)
+
+    print("Test DeepWalk SVD")
+    model = DeepWalk(type="svd")
+    model.fit(modified_adj)
+    model.evaluate_node_classification(labels, idx_train, idx_test)
+
+    # train defense model
+    print("Test Node2vec on clean graph")
+    model = Node2Vec()
+    model.fit(adj)
+    model.evaluate_node_classification(labels, idx_train, idx_test)
+
+    print("Test Node2vec on attacked graph")
+    model = Node2Vec()
+    model.fit(modified_adj)
+    model.evaluate_node_classification(labels, idx_train, idx_test)
deeprobust/graph/defense/prognn.py
ADDED
@@ -0,0 +1,314 @@
+import time
+import numpy as np
+from copy import deepcopy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from deeprobust.graph.utils import accuracy
+from deeprobust.graph.defense.pgd import PGD, prox_operators
+import warnings
+
+class ProGNN:
+    """ProGNN (Properties Graph Neural Network). See more details in Graph Structure Learning for Robust Graph Neural Networks, KDD 2020, https://arxiv.org/abs/2005.10203.
+
+    Parameters
+    ----------
+    model :
+        The backbone GNN model in ProGNN
+    args :
+        model configs
+    device : str
+        'cpu' or 'cuda'.
+
+    Examples
+    --------
+    See details in https://github.com/ChandlerBang/Pro-GNN.
+
+    """
+
+    def __init__(self, model, args, device):
+        self.device = device
+        self.args = args
+        self.best_val_acc = 0
+        self.best_val_loss = 10
+        self.best_graph = None
+        self.weights = None
+        self.estimator = None
+        self.model = model.to(device)
+
+    def fit(self, features, adj, labels, idx_train, idx_val, **kwargs):
+        """Train Pro-GNN.
+
+        Parameters
+        ----------
+        features :
+            node features
+        adj :
+            the adjacency matrix. The format could be torch.tensor or scipy matrix
+        labels :
+            node labels
+        idx_train :
+            node training indices
+        idx_val :
+            node validation indices
+        """
+        args = self.args
+
+        self.optimizer = optim.Adam(self.model.parameters(),
+                                    lr=args.lr, weight_decay=args.weight_decay)
+        estimator = EstimateAdj(adj, symmetric=args.symmetric, device=self.device).to(self.device)
+        self.estimator = estimator
+        self.optimizer_adj = optim.SGD(estimator.parameters(),
+                                       momentum=0.9, lr=args.lr_adj)
+
+        self.optimizer_l1 = PGD(estimator.parameters(),
+                                proxs=[prox_operators.prox_l1],
+                                lr=args.lr_adj, alphas=[args.alpha])
+
+        # warnings.warn("If you find the nuclear proximal operator runs too slow on Pubmed, you can uncomment line 67-71 and use prox_nuclear_cuda to perform the proximal on gpu.")
+        # if args.dataset == "pubmed":
+        #     self.optimizer_nuclear = PGD(estimator.parameters(),
+        #               proxs=[prox_operators.prox_nuclear_cuda],
+        #               lr=args.lr_adj, alphas=[args.beta])
+        # else:
+        warnings.warn("If you find the nuclear proximal operator runs too slow, you can modify line 77 to use prox_operators.prox_nuclear_cuda instead of prox_operators.prox_nuclear to perform the proximal on GPU. See details in https://github.com/ChandlerBang/Pro-GNN/issues/1")
+        self.optimizer_nuclear = PGD(estimator.parameters(),
+                                     proxs=[prox_operators.prox_nuclear_cuda],
+                                     lr=args.lr_adj, alphas=[args.beta])
+
+        # Train model
+        t_total = time.time()
+        for epoch in range(args.epochs):
+            if args.only_gcn:
+                self.train_gcn(epoch, features, estimator.estimated_adj,
+                               labels, idx_train, idx_val)
+            else:
+                for i in range(int(args.outer_steps)):
+                    self.train_adj(epoch, features, adj, labels,
+                                   idx_train, idx_val)
+
+                for i in range(int(args.inner_steps)):
+                    self.train_gcn(epoch, features, estimator.estimated_adj,
+                                   labels, idx_train, idx_val)
+
+        print("Optimization Finished!")
+        print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
+        print(args)
+
+        # Testing
+        print("picking the best model according to validation performance")
+        self.model.load_state_dict(self.weights)
+
103 |
+
def train_gcn(self, epoch, features, adj, labels, idx_train, idx_val):
|
104 |
+
args = self.args
|
105 |
+
estimator = self.estimator
|
106 |
+
adj = estimator.normalize()
|
107 |
+
|
108 |
+
t = time.time()
|
109 |
+
self.model.train()
|
110 |
+
self.optimizer.zero_grad()
|
111 |
+
|
112 |
+
output = self.model(features, adj)
|
113 |
+
loss_train = F.nll_loss(output[idx_train], labels[idx_train])
|
114 |
+
acc_train = accuracy(output[idx_train], labels[idx_train])
|
115 |
+
loss_train.backward()
|
116 |
+
self.optimizer.step()
|
117 |
+
|
118 |
+
# Evaluate validation set performance separately,
|
119 |
+
# deactivates dropout during validation run.
|
120 |
+
self.model.eval()
|
121 |
+
output = self.model(features, adj)
|
122 |
+
|
123 |
+
loss_val = F.nll_loss(output[idx_val], labels[idx_val])
|
124 |
+
acc_val = accuracy(output[idx_val], labels[idx_val])
|
125 |
+
|
126 |
+
if acc_val > self.best_val_acc:
|
127 |
+
self.best_val_acc = acc_val
|
128 |
+
self.best_graph = adj.detach()
|
129 |
+
self.weights = deepcopy(self.model.state_dict())
|
130 |
+
if args.debug:
|
131 |
+
print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())
|
132 |
+
|
133 |
+
if loss_val < self.best_val_loss:
|
134 |
+
self.best_val_loss = loss_val
|
135 |
+
self.best_graph = adj.detach()
|
136 |
+
self.weights = deepcopy(self.model.state_dict())
|
137 |
+
if args.debug:
|
138 |
+
print(f'\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())
|
139 |
+
|
140 |
+
if args.debug:
|
141 |
+
if epoch % 1 == 0:
|
142 |
+
print('Epoch: {:04d}'.format(epoch+1),
|
143 |
+
'loss_train: {:.4f}'.format(loss_train.item()),
|
144 |
+
'acc_train: {:.4f}'.format(acc_train.item()),
|
145 |
+
'loss_val: {:.4f}'.format(loss_val.item()),
|
146 |
+
'acc_val: {:.4f}'.format(acc_val.item()),
|
147 |
+
'time: {:.4f}s'.format(time.time() - t))
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
|
152 |
+
estimator = self.estimator
|
153 |
+
args = self.args
|
154 |
+
if args.debug:
|
155 |
+
print("\n=== train_adj ===")
|
156 |
+
t = time.time()
|
157 |
+
estimator.train()
|
158 |
+
self.optimizer_adj.zero_grad()
|
159 |
+
|
160 |
+
loss_l1 = torch.norm(estimator.estimated_adj, 1)
|
161 |
+
loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
|
162 |
+
normalized_adj = estimator.normalize()
|
163 |
+
|
164 |
+
if args.lambda_:
|
165 |
+
loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj, features)
|
166 |
+
else:
|
167 |
+
loss_smooth_feat = 0 * loss_l1
|
168 |
+
|
169 |
+
output = self.model(features, normalized_adj)
|
170 |
+
loss_gcn = F.nll_loss(output[idx_train], labels[idx_train])
|
171 |
+
acc_train = accuracy(output[idx_train], labels[idx_train])
|
172 |
+
|
173 |
+
loss_symmetric = torch.norm(estimator.estimated_adj \
|
174 |
+
- estimator.estimated_adj.t(), p="fro")
|
175 |
+
|
176 |
+
loss_diffiential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_smooth_feat + args.phi * loss_symmetric
|
177 |
+
|
178 |
+
loss_diffiential.backward()
|
179 |
+
|
180 |
+
self.optimizer_adj.step()
|
181 |
+
loss_nuclear = 0 * loss_fro
|
182 |
+
if args.beta != 0:
|
183 |
+
self.optimizer_nuclear.zero_grad()
|
184 |
+
self.optimizer_nuclear.step()
|
185 |
+
loss_nuclear = prox_operators.nuclear_norm
|
186 |
+
|
187 |
+
self.optimizer_l1.zero_grad()
|
188 |
+
self.optimizer_l1.step()
|
189 |
+
|
190 |
+
total_loss = loss_fro \
|
191 |
+
+ args.gamma * loss_gcn \
|
192 |
+
+ args.alpha * loss_l1 \
|
193 |
+
+ args.beta * loss_nuclear \
|
194 |
+
+ args.phi * loss_symmetric
|
195 |
+
|
196 |
+
estimator.estimated_adj.data.copy_(torch.clamp(
|
197 |
+
estimator.estimated_adj.data, min=0, max=1))
|
198 |
+
|
199 |
+
# Evaluate validation set performance separately,
|
200 |
+
# deactivates dropout during validation run.
|
201 |
+
self.model.eval()
|
202 |
+
normalized_adj = estimator.normalize()
|
203 |
+
output = self.model(features, normalized_adj)
|
204 |
+
|
205 |
+
loss_val = F.nll_loss(output[idx_val], labels[idx_val])
|
206 |
+
acc_val = accuracy(output[idx_val], labels[idx_val])
|
207 |
+
print('Epoch: {:04d}'.format(epoch+1),
|
208 |
+
'acc_train: {:.4f}'.format(acc_train.item()),
|
209 |
+
'loss_val: {:.4f}'.format(loss_val.item()),
|
210 |
+
'acc_val: {:.4f}'.format(acc_val.item()),
|
211 |
+
'time: {:.4f}s'.format(time.time() - t))
|
212 |
+
|
213 |
+
if acc_val > self.best_val_acc:
|
214 |
+
self.best_val_acc = acc_val
|
215 |
+
self.best_graph = normalized_adj.detach()
|
216 |
+
self.weights = deepcopy(self.model.state_dict())
|
217 |
+
if args.debug:
|
218 |
+
print(f'\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())
|
219 |
+
|
220 |
+
if loss_val < self.best_val_loss:
|
221 |
+
self.best_val_loss = loss_val
|
222 |
+
self.best_graph = normalized_adj.detach()
|
223 |
+
self.weights = deepcopy(self.model.state_dict())
|
224 |
+
if args.debug:
|
225 |
+
print(f'\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())
|
226 |
+
|
227 |
+
if args.debug:
|
228 |
+
if epoch % 1 == 0:
|
229 |
+
print('Epoch: {:04d}'.format(epoch+1),
|
230 |
+
'loss_fro: {:.4f}'.format(loss_fro.item()),
|
231 |
+
'loss_gcn: {:.4f}'.format(loss_gcn.item()),
|
232 |
+
'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
|
233 |
+
'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
|
234 |
+
'delta_l1_norm: {:.4f}'.format(torch.norm(estimator.estimated_adj-adj, 1).item()),
|
235 |
+
'loss_l1: {:.4f}'.format(loss_l1.item()),
|
236 |
+
'loss_total: {:.4f}'.format(total_loss.item()),
|
237 |
+
'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))
|
238 |
+
|
239 |
+
|
240 |
+
def test(self, features, labels, idx_test):
|
241 |
+
"""Evaluate the performance of ProGNN on test set
|
242 |
+
"""
|
243 |
+
print("\t=== testing ===")
|
244 |
+
self.model.eval()
|
245 |
+
adj = self.best_graph
|
246 |
+
if self.best_graph is None:
|
247 |
+
adj = self.estimator.normalize()
|
248 |
+
output = self.model(features, adj)
|
249 |
+
loss_test = F.nll_loss(output[idx_test], labels[idx_test])
|
250 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
251 |
+
print("\tTest set results:",
|
252 |
+
"loss= {:.4f}".format(loss_test.item()),
|
253 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
254 |
+
return acc_test.item()
|
255 |
+
|
256 |
+
def feature_smoothing(self, adj, X):
|
257 |
+
adj = (adj.t() + adj)/2
|
258 |
+
rowsum = adj.sum(1)
|
259 |
+
r_inv = rowsum.flatten()
|
260 |
+
D = torch.diag(r_inv)
|
261 |
+
L = D - adj
|
262 |
+
|
263 |
+
r_inv = r_inv + 1e-3
|
264 |
+
r_inv = r_inv.pow(-1/2).flatten()
|
265 |
+
r_inv[torch.isinf(r_inv)] = 0.
|
266 |
+
r_mat_inv = torch.diag(r_inv)
|
267 |
+
# L = r_mat_inv @ L
|
268 |
+
L = r_mat_inv @ L @ r_mat_inv
|
269 |
+
|
270 |
+
XLXT = torch.matmul(torch.matmul(X.t(), L), X)
|
271 |
+
loss_smooth_feat = torch.trace(XLXT)
|
272 |
+
return loss_smooth_feat
|
273 |
+
|
274 |
+
|
275 |
+
class EstimateAdj(nn.Module):
|
276 |
+
"""Provide a pytorch parameter matrix for estimated
|
277 |
+
adjacency matrix and corresponding operations.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, adj, symmetric=False, device='cpu'):
|
281 |
+
super(EstimateAdj, self).__init__()
|
282 |
+
n = len(adj)
|
283 |
+
self.estimated_adj = nn.Parameter(torch.FloatTensor(n, n))
|
284 |
+
self._init_estimation(adj)
|
285 |
+
self.symmetric = symmetric
|
286 |
+
self.device = device
|
287 |
+
|
288 |
+
def _init_estimation(self, adj):
|
289 |
+
with torch.no_grad():
|
290 |
+
n = len(adj)
|
291 |
+
self.estimated_adj.data.copy_(adj)
|
292 |
+
|
293 |
+
def forward(self):
|
294 |
+
return self.estimated_adj
|
295 |
+
|
296 |
+
def normalize(self):
|
297 |
+
|
298 |
+
if self.symmetric:
|
299 |
+
adj = (self.estimated_adj + self.estimated_adj.t())/2
|
300 |
+
else:
|
301 |
+
adj = self.estimated_adj
|
302 |
+
|
303 |
+
normalized_adj = self._normalize(adj + torch.eye(adj.shape[0]).to(self.device))
|
304 |
+
return normalized_adj
|
305 |
+
|
306 |
+
def _normalize(self, mx):
|
307 |
+
rowsum = mx.sum(1)
|
308 |
+
r_inv = rowsum.pow(-1/2).flatten()
|
309 |
+
r_inv[torch.isinf(r_inv)] = 0.
|
310 |
+
r_mat_inv = torch.diag(r_inv)
|
311 |
+
mx = r_mat_inv @ mx
|
312 |
+
mx = mx @ r_mat_inv
|
313 |
+
return mx
|
314 |
+
|
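Note: a minimal usage sketch for the ProGNN trainer above, assuming a dense torch adjacency and a GCN backbone from deeprobust. The args object is a hypothetical config assembled from the fields that fit()/train_adj() actually read (lr, weight_decay, lr_adj, alpha, beta, gamma, lambda_, phi, epochs, outer_steps, inner_steps, symmetric, only_gcn, debug); the values shown are illustrative, not the paper's tuned hyperparameters.

    from types import SimpleNamespace
    from deeprobust.graph.data import Dataset
    from deeprobust.graph.defense import GCN, ProGNN
    from deeprobust.graph import utils

    data = Dataset(root='/tmp/', name='cora')
    # EstimateAdj copies adj into a dense parameter, so densify first.
    adj, features, labels = utils.to_tensor(data.adj.todense(),
                                            data.features.todense(), data.labels)

    args = SimpleNamespace(lr=0.01, weight_decay=5e-4, lr_adj=0.01, alpha=5e-4,
                           beta=1.5, gamma=1, lambda_=0, phi=0, epochs=100,
                           outer_steps=1, inner_steps=2, symmetric=False,
                           only_gcn=False, debug=False)

    gcn = GCN(nfeat=features.shape[1], nhid=16,
              nclass=int(labels.max()) + 1, device='cpu')
    prognn = ProGNN(gcn, args, device='cpu')
    prognn.fit(features, adj, labels, data.idx_train, data.idx_val)
    prognn.test(features, labels, data.idx_test)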
deeprobust/graph/defense/r_gcn.py
ADDED
@@ -0,0 +1,367 @@
+"""
+Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019.
+http://pengcui.thumedialab.com/papers/RGCN.pdf
+Author's Tensorflow implementation:
+https://github.com/thumanlab/nrlweb/tree/master/static/assets/download
+"""
+
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch.distributions.multivariate_normal import MultivariateNormal
+from deeprobust.graph import utils
+import torch.optim as optim
+from scipy.sparse import issparse
+from copy import deepcopy
+
+# TODO sparse implementation
+
+class GGCL_F(Module):
+    """Graph Gaussian Convolution Layer (GGCL) when the input is feature"""
+
+    def __init__(self, in_features, out_features, dropout=0.6):
+        super(GGCL_F, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dropout = dropout
+        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
+        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.weight_miu)
+        torch.nn.init.xavier_uniform_(self.weight_sigma)
+
+    def forward(self, features, adj_norm1, adj_norm2, gamma=1):
+        features = F.dropout(features, self.dropout, training=self.training)
+        self.miu = F.elu(torch.mm(features, self.weight_miu))
+        self.sigma = F.relu(torch.mm(features, self.weight_sigma))
+        Att = torch.exp(-gamma * self.sigma)
+        miu_out = adj_norm1 @ (self.miu * Att)
+        sigma_out = adj_norm2 @ (self.sigma * Att * Att)
+        return miu_out, sigma_out
+
+class GGCL_D(Module):
+
+    """Graph Gaussian Convolution Layer (GGCL) when the input is distribution"""
+    def __init__(self, in_features, out_features, dropout):
+        super(GGCL_D, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dropout = dropout
+        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
+        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
+        # self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.weight_miu)
+        torch.nn.init.xavier_uniform_(self.weight_sigma)
+
+    def forward(self, miu, sigma, adj_norm1, adj_norm2, gamma=1):
+        miu = F.dropout(miu, self.dropout, training=self.training)
+        sigma = F.dropout(sigma, self.dropout, training=self.training)
+        miu = F.elu(miu @ self.weight_miu)
+        sigma = F.relu(sigma @ self.weight_sigma)
+
+        Att = torch.exp(-gamma * sigma)
+        mean_out = adj_norm1 @ (miu * Att)
+        sigma_out = adj_norm2 @ (sigma * Att * Att)
+        return mean_out, sigma_out
+
+
+class GaussianConvolution(Module):
+    """[Deprecated] Alternative gaussian convolution layer.
+    """
+
+    def __init__(self, in_features, out_features):
+        super(GaussianConvolution, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
+        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
+        # self.sigma = Parameter(torch.FloatTensor(out_features))
+        # self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # TODO
+        torch.nn.init.xavier_uniform_(self.weight_miu)
+        torch.nn.init.xavier_uniform_(self.weight_sigma)
+
+    def forward(self, previous_miu, previous_sigma, adj_norm1=None, adj_norm2=None, gamma=1):
+
+        if adj_norm1 is None and adj_norm2 is None:
+            return torch.mm(previous_miu, self.weight_miu), \
+                   torch.mm(previous_sigma, self.weight_sigma)
+
+        Att = torch.exp(-gamma * previous_sigma)
+        M = adj_norm1 @ (previous_miu * Att) @ self.weight_miu
+        Sigma = adj_norm2 @ (previous_sigma * Att * Att) @ self.weight_sigma
+        return M, Sigma
+
+        # M = torch.mm(torch.mm(adj, previous_miu * A), self.weight_miu)
+        # Sigma = torch.mm(torch.mm(adj, previous_sigma * A * A), self.weight_sigma)
+
+        # TODO sparse implementation
+        # support = torch.mm(input, self.weight)
+        # output = torch.spmm(adj, support)
+        # return output + self.bias
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' \
+               + str(self.in_features) + ' -> ' \
+               + str(self.out_features) + ')'
+
+
+class RGCN(Module):
+    """Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019.
+
+    Parameters
+    ----------
+    nnodes : int
+        number of nodes in the input graph
+    nfeat : int
+        size of input feature dimension
+    nhid : int
+        number of hidden units
+    nclass : int
+        size of output dimension
+    gamma : float
+        hyper-parameter for RGCN. See more details in the paper.
+    beta1 : float
+        hyper-parameter for RGCN. See more details in the paper.
+    beta2 : float
+        hyper-parameter for RGCN. See more details in the paper.
+    lr : float
+        learning rate for GCN
+    dropout : float
+        dropout rate for GCN
+    device: str
+        'cpu' or 'cuda'.
+
+    """
+
+    def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
+        super(RGCN, self).__init__()
+
+        self.device = device
+        # adj_norm = normalize(adj)
+        # first turn original features to distribution
+        self.lr = lr
+        self.gamma = gamma
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.nclass = nclass
+        self.nhid = nhid // 2
+        # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout)
+        # self.gc2 = GaussianConvolution(nhid, nclass, dropout)
+        self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
+        self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)
+
+        self.dropout = dropout
+        # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass))
+        self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
+                torch.diag_embed(torch.ones(nnodes, self.nclass)))
+        self.adj_norm1, self.adj_norm2 = None, None
+        self.features, self.labels = None, None
+
+    def forward(self):
+        features = self.features
+        miu, sigma = self.gc1(features, self.adj_norm1, self.adj_norm2, self.gamma)
+        miu, sigma = self.gc2(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma)
+        output = miu + self.gaussian.sample().to(self.device) * torch.sqrt(sigma + 1e-8)
+        return F.log_softmax(output, dim=1)
+
+    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=True, **kwargs):
+        """Train RGCN.
+
+        Parameters
+        ----------
+        features :
+            node features
+        adj :
+            the adjacency matrix. The format could be torch.tensor or scipy matrix
+        labels :
+            node labels
+        idx_train :
+            node training indices
+        idx_val :
+            node validation indices. If not given (None), the training process will not adopt early stopping
+        train_iters : int
+            number of training epochs
+        verbose : bool
+            whether to show verbose logs
+
+        Examples
+        --------
+        We can first load dataset and then train RGCN.
+
+        >>> from deeprobust.graph.data import PrePtbDataset, Dataset
+        >>> from deeprobust.graph.defense import RGCN
+        >>> # load clean graph data
+        >>> data = Dataset(root='/tmp/', name='cora', seed=15)
+        >>> adj, features, labels = data.adj, data.features, data.labels
+        >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+        >>> # load perturbed graph data
+        >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
+        >>> perturbed_adj = perturbed_data.adj
+        >>> # train defense model
+        >>> model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1],
+                         nclass=labels.max()+1, nhid=32, device='cpu')
+        >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val,
+                      train_iters=200, verbose=True)
+        >>> model.test(idx_test)
+
+        """
+        if isinstance(adj, torch.Tensor):
+            adj = adj.to('cpu')
+            if adj.is_sparse:
+                adj_dense = adj.to_dense().float()
+            else:
+                adj_dense = adj.float()
+        else:
+            adj_dense = torch.tensor(adj.toarray(), dtype=torch.float32)
+
+        if issparse(features):
+            features_dense = features.toarray()
+        else:
+            features_dense = features
+
+        adj, features, labels = utils.to_tensor(adj_dense, features_dense, labels, device=self.device)
+
+        self.features, self.labels = features, labels
+        self.adj_norm1 = self._normalize_adj(adj, power=-1/2)
+        self.adj_norm2 = self._normalize_adj(adj, power=-1)
+        print('=== training rgcn model ===')
+        self._initialize()
+        if idx_val is None:
+            self._train_without_val(labels, idx_train, train_iters, verbose)
+        else:
+            self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)
+
+    def _train_without_val(self, labels, idx_train, train_iters, verbose=True):
+        optimizer = optim.Adam(self.parameters(), lr=self.lr)
+        self.train()
+        for i in range(train_iters):
+            optimizer.zero_grad()
+            output = self.forward()
+            loss_train = self._loss(output[idx_train], labels[idx_train])
+            loss_train.backward()
+            optimizer.step()
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+        self.eval()
+        output = self.forward()
+        self.output = output
+
+    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
+        optimizer = optim.Adam(self.parameters(), lr=self.lr)
+
+        best_loss_val = 100
+        best_acc_val = 0
+
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward()
+            loss_train = self._loss(output[idx_train], labels[idx_train])
+            loss_train.backward()
+            optimizer.step()
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            output = self.forward()
+            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
+            acc_val = utils.accuracy(output[idx_val], labels[idx_val])
+
+            if best_loss_val > loss_val:
+                best_loss_val = loss_val
+                self.output = output
+
+            if acc_val > best_acc_val:
+                best_acc_val = acc_val
+                self.output = output
+
+        print('=== picking the best model according to the performance on validation ===')
+
+
+    def test(self, idx_test):
+        """Evaluate the performance on the test set
+        """
+        self.eval()
+        # output = self.forward()
+        output = self.output
+        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
+        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
+        print("Test set results:",
+              "loss= {:.4f}".format(loss_test.item()),
+              "accuracy= {:.4f}".format(acc_test.item()))
+        return acc_test.item()
+
+    def predict(self):
+        """
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities) of RGCN
+        """
+
+        self.eval()
+        return self.forward()
+
+    def _loss(self, input, labels):
+        loss = F.nll_loss(input, labels)
+        miu1 = self.gc1.miu
+        sigma1 = self.gc1.sigma
+        kl_loss = 0.5 * (miu1.pow(2) + sigma1 - torch.log(1e-8 + sigma1)).mean(1)
+        kl_loss = kl_loss.sum()
+        norm2 = torch.norm(self.gc1.weight_miu, 2).pow(2) + \
+                torch.norm(self.gc1.weight_sigma, 2).pow(2)
+
+        # print(f'gcn_loss: {loss.item()}, kl_loss: {self.beta1 * kl_loss.item()}, norm2: {self.beta2 * norm2.item()}')
+        return loss + self.beta1 * kl_loss + self.beta2 * norm2
+
+    def _initialize(self):
+        self.gc1.reset_parameters()
+        self.gc2.reset_parameters()
+
+    def _normalize_adj(self, adj, power=-1/2):
+
+        """Symmetrically normalize a dense adjacency matrix: D^power (A + I) D^power"""
+        A = adj + torch.eye(len(adj)).to(self.device)
+        D_power = (A.sum(1)).pow(power)
+        D_power[torch.isinf(D_power)] = 0.
+        D_power = torch.diag(D_power)
+        return D_power @ A @ D_power
+
+if __name__ == "__main__":
+
+    from deeprobust.graph.data import PrePtbDataset, Dataset
+    # load clean graph data
+    dataset_str = 'pubmed'
+    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # load perturbed graph data
+    perturbed_data = PrePtbDataset(root='/tmp/', name=dataset_str)
+    perturbed_adj = perturbed_data.adj
+
+    # train defense model
+    model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1],
+                 nclass=labels.max()+1, nhid=32, device='cuda').to('cuda')
+    model.fit(features, perturbed_adj, labels, idx_train, idx_val,
+              train_iters=200, verbose=True)
+    model.test(idx_test)
+
+    prediction_1 = model.predict()
+    print(prediction_1)
+    # prediction_2 = model.predict(features, perturbed_adj)
+    # assert (prediction_1 != prediction_2).sum() == 0
+
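Note: RGCN propagates means and variances with two differently normalized adjacencies. A tiny sketch (illustrative, not part of the file) of what _normalize_adj above yields on a 3-node path graph; power=-1/2 gives the usual D^-1/2 (A+I) D^-1/2 used for the means, power=-1 gives D^-1 (A+I) D^-1 used for the variances.

    import torch

    def normalize_adj(adj, power=-1/2):
        # same steps as RGCN._normalize_adj: add self-loops, then D^p A D^p
        A = adj + torch.eye(len(adj))
        D_power = A.sum(1).pow(power)
        D_power[torch.isinf(D_power)] = 0.
        return torch.diag(D_power) @ A @ torch.diag(D_power)

    path = torch.tensor([[0., 1., 0.],
                         [1., 0., 1.],
                         [0., 1., 0.]])
    print(normalize_adj(path, power=-1/2))  # entries 1/sqrt(d_i * d_j)
    print(normalize_adj(path, power=-1))    # entries 1/(d_i * d_j)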
deeprobust/graph/defense/r_gcn.py.backup
ADDED
@@ -0,0 +1,200 @@
+'''
+Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019.
+http://pengcui.thumedialab.com/papers/RGCN.pdf
+Author's Tensorflow implemention:
+https://github.com/thumanlab/nrlweb/tree/master/static/assets/download
+'''
+
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch.distributions.multivariate_normal import MultivariateNormal
+from deeprobust.graph import utils
+import torch.optim as optim
+from copy import deepcopy
+
+class GaussianConvolution(Module):
+
+    def __init__(self, in_features, out_features):
+        super(GaussianConvolution, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
+        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
+        # self.sigma = Parameter(torch.FloatTensor(out_features))
+        # self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # TODO
+        torch.nn.init.xavier_uniform_(self.weight_miu)
+        torch.nn.init.xavier_uniform_(self.weight_sigma)
+
+    def forward(self, previous_miu, previous_sigma, adj_norm1=None, adj_norm2=None, gamma=1):
+
+        if adj_norm1 is None and adj_norm2 is None:
+            return torch.mm(previous_miu, self.weight_miu), \
+                   torch.mm(previous_sigma, self.weight_sigma)
+
+        Att = torch.exp(-gamma * previous_sigma)
+        M = adj_norm1 @ (previous_miu * Att) @ self.weight_miu
+        Sigma = adj_norm2 @ (previous_sigma * Att * Att) @ self.weight_sigma
+        return M, Sigma
+
+        # M = torch.mm(torch.mm(adj, previous_miu * A), self.weight_miu)
+        # Sigma = torch.mm(torch.mm(adj, previous_sigma * A * A), self.weight_sigma)
+
+        # TODO sparse implemention
+        # support = torch.mm(input, self.weight)
+        # output = torch.spmm(adj, support)
+        # return output + self.bias
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' \
+               + str(self.in_features) + ' -> ' \
+               + str(self.out_features) + ')'
+
+
+class RGCN(Module):
+
+    def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
+        super(RGCN, self).__init__()
+
+        self.device = device
+        # adj_norm = normalize(adj)
+        # first turn original features to distribution
+        self.lr = lr
+        self.gamma = gamma
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.nclass = nclass
+        self.nhid = nhid
+        self.gc1 = GaussianConvolution(nfeat, nhid)
+        # self.gc2 = GaussianConvolution(nhid, nhid)
+        # self.gc3 = GaussianConvolution(nhid, nclass)
+        self.gc2 = GaussianConvolution(nhid, nclass)
+
+        self.dropout = dropout
+        # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass))
+        self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
+                torch.diag_embed(torch.ones(nnodes, self.nclass)))
+        self.miu1 = None
+        self.sigma1 = None
+        self.adj_norm1, self.adj_norm2 = None, None
+        self.features, self.labels = None, None
+
+    def forward(self):
+        features = self.features
+        miu, sigma = self.gc1(features, features)
+        miu, sigma = F.elu(miu, alpha=1), F.relu(sigma)
+        self.miu1, self.sigma1 = miu, sigma
+        miu = F.dropout(miu, self.dropout, training=self.training)
+        sigma = F.dropout(sigma, self.dropout, training=self.training)
+
+        miu, sigma = self.gc2(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma)
+        miu, sigma = F.elu(miu, alpha=1), F.relu(sigma)
+
+        # # third layer
+        # miu = F.dropout(miu, self.dropout, training=self.training)
+        # sigma = F.dropout(sigma, self.dropout, training=self.training)
+        # miu, sigma = self.gc3(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma)
+        # miu, sigma = F.elu(miu), F.relu(sigma)
+        output = miu + self.gaussian.sample().to(self.device) * torch.sqrt(sigma + 1e-8)
+        return F.log_softmax(output, dim=1)
+
+    def fit_(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=True):
+
+        adj, features, labels = utils.to_tensor(adj.todense(), features, labels, device=self.device)
+        self.features, self.labels = features, labels
+        self.adj_norm1 = self._normalize_adj(adj, power=-1/2)
+        self.adj_norm2 = self._normalize_adj(adj, power=-1)
+        print('=== training rgcn model ===')
+        self._initialize()
+        if idx_val is None:
+            self._train_without_val(labels, idx_train, train_iters, verbose)
+        else:
+            self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)
+
+    def _train_without_val(self, labels, idx_train, train_iters, verbose=True):
+        print('=== training gcn model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr)
+        self.train()
+        for i in range(train_iters):
+            optimizer.zero_grad()
+            output = self.forward()
+            loss_train = self.loss(output[idx_train], labels[idx_train])
+            loss_train.backward()
+            optimizer.step()
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+        self.eval()
+        output = self.forward()
+        self.output = output
+
+    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
+        print('=== training gcn model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr)
+
+        best_loss_val = 100
+        best_acc_val = 0
+
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward()
+            loss_train = self.loss(output[idx_train], labels[idx_train])
+            loss_train.backward()
+            optimizer.step()
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            output = self.forward()
+            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
+            acc_val = utils.accuracy(output[idx_val], labels[idx_val])
+
+            if best_loss_val > loss_val:
+                best_loss_val = loss_val
+                self.output = output
+
+            if acc_val > best_acc_val:
+                best_acc_val = acc_val
+                self.output = output
+
+        print('=== picking the best model according to the performance on validation ===')
+
+
+    def test(self, idx_test):
+        # output = self.forward()
+        output = self.output
+        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
+        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
+        print("Test set results:",
+              "loss= {:.4f}".format(loss_test.item()),
+              "accuracy= {:.4f}".format(acc_test.item()))
+
+    def loss(self, input, labels):
+        loss = F.nll_loss(input, labels)
+        kl_loss = 0.5 * (self.miu1.pow(2) + self.sigma1 - torch.log(1e-8 + self.sigma1)).mean(1)
+        kl_loss = kl_loss.sum()
+        norm2 = torch.norm(self.gc1.weight_miu, 2).pow(2) + \
+                torch.norm(self.gc1.weight_sigma, 2).pow(2)
+        # print(f'gcn_loss: {loss.item()}, kl_loss: {self.beta1 * kl_loss.item()}, norm2: {self.beta2 * norm2.item()}')
+        return loss + self.beta1 * kl_loss + self.beta2 * norm2
+
+    def _initialize(self):
+        self.gc1.reset_parameters()
+        self.gc2.reset_parameters()
+
+    def _normalize_adj(self, adj, power=-1/2):
+
+        """Row-normalize sparse matrix"""
+        A = adj + torch.eye(len(adj)).to(self.device)
+        D_power = (A.sum(1)).pow(power)
+        D_power[torch.isinf(D_power)] = 0.
+        D_power = torch.diag(D_power)
+        return D_power @ A @ D_power
+
deeprobust/graph/defense/sgc.py
ADDED
@@ -0,0 +1,196 @@
+"""
+Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+import torch.optim as optim
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from deeprobust.graph import utils
+from copy import deepcopy
+from torch_geometric.nn import SGConv
+
+class SGC(torch.nn.Module):
+    """ SGC based on pytorch geometric. Simplifying Graph Convolutional Networks.
+
+    Parameters
+    ----------
+    nfeat : int
+        size of input feature dimension
+    nclass : int
+        size of output dimension
+    K: int
+        number of propagation steps in SGC
+    cached : bool
+        whether to set the cache flag in SGConv
+    lr : float
+        learning rate for SGC
+    weight_decay : float
+        weight decay coefficient (l2 normalization) for SGC.
+    with_bias: bool
+        whether to include bias term in SGC weights.
+    device: str
+        'cpu' or 'cuda'.
+
+    Examples
+    --------
+    We can first load dataset and then train SGC.
+
+    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
+    >>> from deeprobust.graph.defense import SGC
+    >>> data = Dataset(root='/tmp/', name='cora')
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> sgc = SGC(nfeat=features.shape[1], K=3, lr=0.1,
+              nclass=labels.max().item() + 1, device='cuda')
+    >>> sgc = sgc.to('cuda')
+    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
+    >>> sgc.fit(pyg_data, train_iters=200, patience=200, verbose=True) # train with earlystopping
+    """
+
+
+    def __init__(self, nfeat, nclass, K=3, cached=True, lr=0.01,
+                weight_decay=5e-4, with_bias=True, device=None):
+
+        super(SGC, self).__init__()
+
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.conv1 = SGConv(nfeat,
+                nclass, bias=with_bias, K=K, cached=cached)
+
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+
+    def forward(self, data):
+        x, edge_index = data.x, data.edge_index
+        x = self.conv1(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+    def initialize(self):
+        """Initialize parameters of SGC.
+        """
+        self.conv1.reset_parameters()
+
+    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=500, **kwargs):
+        """Train the SGC model. The best model is picked according to the
+        loss on the dataset's validation mask.
+
+        Parameters
+        ----------
+        pyg_data :
+            pytorch geometric dataset object
+        train_iters : int
+            number of training epochs
+        initialize : bool
+            whether to initialize parameters before training
+        verbose : bool
+            whether to show verbose logs
+        patience : int
+            patience for early stopping
+        """
+
+        # self.device = self.conv1.weight.device
+        if initialize:
+            self.initialize()
+
+        self.data = pyg_data[0].to(self.device)
+        # By default, it is trained with early stopping on validation
+        self.train_with_early_stopping(train_iters, patience, verbose)
+
+    def train_with_early_stopping(self, train_iters, patience, verbose):
+        """early stopping based on the validation loss
+        """
+        if verbose:
+            print('=== training SGC model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+
+        labels = self.data.y
+        train_mask, val_mask = self.data.train_mask, self.data.val_mask
+
+        early_stopping = patience
+        best_loss_val = 100
+
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward(self.data)
+
+            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
+            loss_train.backward()
+            optimizer.step()
+
+            if verbose and i % 10 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            output = self.forward(self.data)
+            loss_val = F.nll_loss(output[val_mask], labels[val_mask])
+
+            if best_loss_val > loss_val:
+                best_loss_val = loss_val
+                self.output = output
+                weights = deepcopy(self.state_dict())
+                patience = early_stopping
+            else:
+                patience -= 1
+            if i > early_stopping and patience <= 0:
+                break
+
+        if verbose:
+            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
+        self.load_state_dict(weights)
+
+    def test(self):
+        """Evaluate SGC performance on the test set.
+        """
+        self.eval()
+        test_mask = self.data.test_mask
+        labels = self.data.y
+        output = self.forward(self.data)
+        # output = self.output
+        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
+        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
+        print("Test set results:",
+              "loss= {:.4f}".format(loss_test.item()),
+              "accuracy= {:.4f}".format(acc_test.item()))
+        return acc_test.item()
+
+    def predict(self):
+        """
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities) of SGC
+        """
+
+        self.eval()
+        return self.forward(self.data)
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    # from deeprobust.graph.defense import SGC
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    sgc = SGC(nfeat=features.shape[1],
+          nclass=labels.max().item() + 1, device='cpu')
+    sgc = sgc.to('cpu')
+    pyg_data = Dpr2Pyg(data)
+    sgc.fit(pyg_data, verbose=True) # train with earlystopping
+    sgc.test()
+    print(sgc.predict())
+
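Note: a short standalone sketch of what the single SGConv layer above computes, to make the role of K concrete: K parameter-free propagation steps with the symmetrically normalized adjacency, then one linear map, i.e. logits = S^K X W. The toy graph and dimensions below are made up for illustration.

    import torch

    A = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
    A_tilde = A + torch.eye(3)                  # add self-loops
    d = A_tilde.sum(1).pow(-0.5)
    S = d[:, None] * A_tilde * d[None, :]       # D^-1/2 (A + I) D^-1/2

    X = torch.randn(3, 4)                       # node features
    W = torch.randn(4, 2)                       # the single SGC weight matrix
    H = X
    for _ in range(3):                          # K = 3 propagation steps
        H = S @ H
    logits = H @ W                              # matches SGConv(4, 2, K=3) up to bias/init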
deeprobust/graph/defense_pyg/airgnn.py
ADDED
@@ -0,0 +1,186 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Linear
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+from torch_geometric.nn.conv import MessagePassing
+from typing import Optional, Tuple
+from torch_geometric.typing import Adj, OptTensor
+from torch import Tensor
+from torch_sparse import SparseTensor, matmul
+from .base_model import BaseModel
+import torch.nn as nn
+
+class AirGNN(BaseModel):
+
+    def __init__(self, nfeat, nhid, nclass, nlayers=2, K=2, dropout=0.5, lr=0.01,
+                with_bn=False, weight_decay=5e-4, with_bias=True, device=None, args=None):
+
+        super(AirGNN, self).__init__()
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.lins = nn.ModuleList([])
+        self.lins.append(Linear(nfeat, nhid))
+        if with_bn:
+            self.bns = nn.ModuleList([])
+            self.bns.append(nn.BatchNorm1d(nhid))
+        for i in range(nlayers-2):
+            self.lins.append(Linear(nhid, nhid))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(nhid))
+        self.lins.append(Linear(nhid, nclass))
+
+        self.prop = AdaptiveMessagePassing(K=K, alpha=args.alpha, mode=args.model, args=args)
+        print(self.prop)
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.name = args.model
+        self.with_bn = with_bn
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for lin in self.lins:
+            lin.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
+        self.prop.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lins[-1](x)
+        x = self.prop(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = self.prop(x, edge_index)
+        return x
+
+
+class AdaptiveMessagePassing(MessagePassing):
+    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
+    _cached_adj_t: Optional[SparseTensor]
+
+    def __init__(self,
+                 K: int,
+                 alpha: float,
+                 dropout: float = 0.,
+                 cached: bool = False,
+                 add_self_loops: bool = True,
+                 normalize: bool = True,
+                 mode: str = None,
+                 node_num: int = None,
+                 args=None,
+                 **kwargs):
+
+        super(AdaptiveMessagePassing, self).__init__(aggr='add', **kwargs)
+        self.K = K
+        self.alpha = alpha
+        self.mode = mode
+        self.dropout = dropout
+        self.cached = cached
+        self.add_self_loops = add_self_loops
+        self.normalize = normalize
+        self._cached_edge_index = None
+        self.node_num = node_num
+        self.args = args
+        self._cached_adj_t = None
+
+    def reset_parameters(self):
+        self._cached_edge_index = None
+        self._cached_adj_t = None
+
+    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mode=None) -> Tensor:
+        if self.normalize:
+            if isinstance(edge_index, Tensor):
+                raise ValueError('Only support SparseTensor now')
+            elif isinstance(edge_index, SparseTensor):
+                cache = self._cached_adj_t
+                if cache is None:
+                    edge_index = gcn_norm(  # yapf: disable
+                        edge_index, edge_weight, x.size(self.node_dim), False,
+                        add_self_loops=self.add_self_loops, dtype=x.dtype)
+                    if self.cached:
+                        self._cached_adj_t = edge_index
+                else:
+                    edge_index = cache
+
+        if mode is None: mode = self.mode
+
+        if self.K <= 0:
+            return x
+        hh = x
+
+        if mode == 'MLP':
+            return x
+
+        elif mode == 'APPNP':
+            x = self.appnp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K, alpha=self.alpha)
+
+        elif mode in ['AirGNN']:
+            x = self.amp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K)
+        else:
+            raise ValueError('wrong propagate mode')
+        return x
+
+    def appnp_forward(self, x, hh, edge_index, K, alpha):
+        for k in range(K):
+            x = self.propagate(edge_index, x=x, edge_weight=None, size=None)
+            x = x * (1 - alpha)
+            x += alpha * hh
+        return x
+
+    def amp_forward(self, x, hh, K, edge_index):
+        lambda_amp = self.args.lambda_amp
+        gamma = 1 / (2 * (1 - lambda_amp))  # or simply gamma = 1
+
+        for k in range(K):
+            y = x - gamma * 2 * (1 - lambda_amp) * self.compute_LX(x=x, edge_index=edge_index)  # Equation (9)
+            x = hh + self.proximal_L21(x=y - hh, lambda_=gamma * lambda_amp)  # Equation (11) and (12)
+        return x
+
+    def proximal_L21(self, x: Tensor, lambda_):
+        row_norm = torch.norm(x, p=2, dim=1)
+        score = torch.clamp(row_norm - lambda_, min=0)
+        index = torch.where(row_norm > 0)  # deal with the case when the row_norm is 0
+        score[index] = score[index] / row_norm[index]  # score is the adaptive score in Equation (14)
+        return score.unsqueeze(1) * x
+
+    def compute_LX(self, x, edge_index, edge_weight=None):
+        x = x - self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
+        return x
+
+    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
+        return edge_weight.view(-1, 1) * x_j
+
+    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
+        return matmul(adj_t, x, reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}(K={}, alpha={}, mode={}, dropout={}, lambda_amp={})'.format(self.__class__.__name__, self.K,
+                                                                               self.alpha, self.mode, self.dropout,
+                                                                               self.args.lambda_amp)
+
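Note: the adaptive step in amp_forward hinges on proximal_L21, a row-wise soft-thresholding. A self-contained toy example (not from the file): rows whose L2 norm falls below lambda are zeroed and larger rows are shrunk toward zero, which caps how much any single node's residual can move the output.

    import torch

    def proximal_l21(x, lambda_):
        # shrink each row's L2 norm by lambda_, clamping at zero
        row_norm = torch.norm(x, p=2, dim=1)
        score = torch.clamp(row_norm - lambda_, min=0)
        idx = torch.where(row_norm > 0)
        score[idx] = score[idx] / row_norm[idx]
        return score.unsqueeze(1) * x

    r = torch.tensor([[3.0, 4.0],    # norm 5   -> rescaled to norm 4
                      [0.3, 0.4],    # norm 0.5 -> zeroed out
                      [0.0, 0.0]])   # zero row stays zero
    print(proximal_l21(r, lambda_=1.0))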
deeprobust/graph/defense_pyg/sage.py
ADDED
@@ -0,0 +1,157 @@
1 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
# from torch_geometric.nn import SAGEConv, GATConv, APPNP, MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import scipy.sparse
import numpy as np
from .base_model import BaseModel


class SAGE(BaseModel):

    def __init__(self, nfeat, nhid, nclass, num_layers=2,
                 dropout=0.5, lr=0.01, weight_decay=0, device='cpu', with_bn=False, **kwargs):
        super(SAGE, self).__init__()

        self.convs = nn.ModuleList()
        self.convs.append(
            SAGEConv(nfeat, nhid))

        self.bns = nn.ModuleList()
        if 'nlayers' in kwargs:
            num_layers = kwargs['nlayers']
        self.bns.append(nn.BatchNorm1d(nhid))
        for _ in range(num_layers - 2):
            self.convs.append(
                SAGEConv(nhid, nhid))
            self.bns.append(nn.BatchNorm1d(nhid))

        self.convs.append(
            SAGEConv(nhid, nclass))

        self.weight_decay = weight_decay
        self.lr = lr
        self.dropout = dropout
        self.activation = F.relu
        self.with_bn = with_bn
        self.device = device
        self.name = "SAGE"

    def initialize(self):
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        if edge_weight is not None:
            adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()

        for i, conv in enumerate(self.convs[:-1]):
            if edge_weight is not None:
                x = conv(x, adj)
            else:
                x = conv(x, edge_index)
            if self.with_bn:
                x = self.bns[i](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        if edge_weight is not None:
            x = self.convs[-1](x, adj)
        else:
            x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)


from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, Size

from torch import Tensor
from torch.nn import Linear
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing


class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'mean')
        super(SAGEConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        # Deleted the following line to make propagation differentiable
        # adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
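A minimal smoke-test sketch for the SAGE model above, assuming the package and its torch_geometric/torch_sparse dependencies are installed, that `BaseModel` is a standard `nn.Module`, and that the module is importable under the path shown in this repository; the tensor shapes are illustrative only:

# Hypothetical smoke test for the SAGE model above.
import torch
from deeprobust.graph.defense_pyg.sage import SAGE   # path as laid out in this repo

n_nodes, n_feat, n_class = 100, 16, 3
x = torch.randn(n_nodes, n_feat)                  # random node features
edge_index = torch.randint(0, n_nodes, (2, 500))  # random COO edge list

model = SAGE(nfeat=n_feat, nhid=32, nclass=n_class, num_layers=2, with_bn=True)
model.initialize()
out = model(x, edge_index)   # log-probabilities, shape [n_nodes, n_class]
print(out.shape)             # torch.Size([100, 3])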
deeprobust/graph/global_attack/__init__.py
ADDED
@@ -0,0 +1,18 @@
import warnings  # needed by the fallback below; missing in the original file

from .base_attack import BaseAttack
from .dice import DICE
from .mettack import MetaApprox, Metattack
from .random_attack import Random
from .topology_attack import MinMax, PGDAttack
from .node_embedding_attack import NodeEmbeddingAttack, OtherNodeEmbeddingAttack
from .nipa import NIPA

try:
    from .prbcd import PRBCD
except ImportError as e:
    print(e)
    warnings.warn("Please install pytorch geometric if you "
                  "would like to use the PRBCD attack. See details in "
                  "https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html")

__all__ = ['BaseAttack', 'DICE', 'MetaApprox', 'Metattack', 'Random', 'MinMax', 'PGDAttack', 'NIPA', 'NodeEmbeddingAttack', 'OtherNodeEmbeddingAttack', 'PRBCD']
deeprobust/graph/global_attack/dice.py
ADDED
@@ -0,0 +1,123 @@
import random
import numpy as np
import scipy.sparse as sp
from deeprobust.graph.global_attack import BaseAttack

class DICE(BaseAttack):
    """As described in ADVERSARIAL ATTACKS ON GRAPH NEURAL NETWORKS VIA META LEARNING (ICLR'19),
    'DICE (delete internally, connect externally) is a baseline where, for each perturbation,
    we randomly choose whether to insert or remove an edge. Edges are only removed between
    nodes from the same classes, and only inserted between nodes from different classes.'

    Parameters
    ----------
    model :
        model to attack. Default `None`.
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'


    Examples
    --------

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.global_attack import DICE
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> model = DICE()
    >>> model.attack(adj, labels, n_perturbations=10)
    >>> modified_adj = model.modified_adj

    """

    def __init__(self, model=None, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
        super(DICE, self).__init__(model, nnodes, attack_structure=attack_structure, attack_features=attack_features, device=device)

        assert not self.attack_features, 'DICE does NOT support attacking features'

    def attack(self, ori_adj, labels, n_perturbations, **kwargs):
        """Delete internally, connect externally. This baseline has all true class labels
        (train and test) available.

        Parameters
        ----------
        ori_adj : scipy.sparse.csr_matrix
            Original (unperturbed) adjacency matrix.
        labels:
            node labels
        n_perturbations : int
            Number of edge removals/additions.

        Returns
        -------
        None.

        """

        # ori_adj: sp.csr_matrix

        print('number of perturbations: %s' % n_perturbations)
        modified_adj = ori_adj.tolil()

        remove_or_insert = np.random.choice(2, n_perturbations)
        n_remove = sum(remove_or_insert)

        nonzero = set(zip(*ori_adj.nonzero()))
        indices = sp.triu(modified_adj).nonzero()
        possible_indices = [x for x in zip(indices[0], indices[1])
                            if labels[x[0]] == labels[x[1]]]

        remove_indices = np.random.permutation(possible_indices)[: n_remove]
        modified_adj[remove_indices[:, 0], remove_indices[:, 1]] = 0
        modified_adj[remove_indices[:, 1], remove_indices[:, 0]] = 0

        n_insert = n_perturbations - n_remove

        # sample edges to add
        added_edges = 0
        while added_edges < n_insert:
            n_remaining = n_insert - added_edges

            # sample random pairs
            candidate_edges = np.array([np.random.choice(ori_adj.shape[0], n_remaining),
                                        np.random.choice(ori_adj.shape[0], n_remaining)]).T

            # filter out existing edges, and pairs with the same labels
            candidate_edges = set([(u, v) for u, v in candidate_edges if labels[u] != labels[v]
                                   and modified_adj[u, v] == 0 and modified_adj[v, u] == 0])
            candidate_edges = np.array(list(candidate_edges))

            # if none is found, try again
            if len(candidate_edges) == 0:
                continue

            # add all found edges to your modified adjacency matrix
            modified_adj[candidate_edges[:, 0], candidate_edges[:, 1]] = 1
            modified_adj[candidate_edges[:, 1], candidate_edges[:, 0]] = 1
            added_edges += candidate_edges.shape[0]

        self.check_adj(modified_adj)
        self.modified_adj = modified_adj


    def sample_forever(self, adj, exclude):
        """Randomly sample edges from the adjacency matrix; `exclude` is a set
        which contains the edges we do not want to sample and the ones already sampled
        """
        while True:
            # t = tuple(np.random.randint(0, adj.shape[0], 2))
            t = tuple(random.sample(range(0, adj.shape[0]), 2))
            if t not in exclude:
                yield t
                exclude.add(t)
                exclude.add((t[1], t[0]))

    def random_sample_edges(self, adj, n, exclude):
        itr = self.sample_forever(adj, exclude=exclude)
        return [next(itr) for _ in range(n)]
deeprobust/graph/global_attack/ig_attack.py.backup
ADDED
@@ -0,0 +1,192 @@
"""
Topology Attack and Defense for Graph Neural Networks: An Optimization Perspective
https://arxiv.org/pdf/1906.04214.pdf
Tensorflow Implementation:
https://github.com/KaidiXu/GCN_ADV_Train
"""

import numpy as np
import scipy.sparse as sp
import torch
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm
import warnings
from deeprobust.graph import utils
from deeprobust.graph.global_attack import BaseAttack


class IGAttack(BaseAttack):
    """[Under Development] Untargeted Attack Version of IGAttack: IG-FGSM. Adversarial Examples on Graph Data: Deep Insights into Attack and Defense, https://arxiv.org/pdf/1903.01610.pdf.

    Parameters
    ----------
    model :
        model to attack
    nnodes : int
        number of nodes in the input graph
    feature_shape : tuple
        shape of the input node features
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    """
    def __init__(self, model=None, nnodes=None, feature_shape=None, attack_structure=True, attack_features=False, device='cpu'):

        super(IGAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)

        assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'

        self.modified_adj = None
        self.modified_features = None

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes*(nnodes-1)/2)))
            self.adj_changes.data.fill_(0)

        if attack_features:
            assert feature_shape is not None, 'Please give feature_shape='
            self.feature_changes = Parameter(torch.FloatTensor(feature_shape))
            self.feature_changes.data.fill_(0)

    def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, **kwargs):
        """Generate perturbations on the input graph.

        Parameters
        ----------
        ori_features :
            Original (unperturbed) node feature matrix
        ori_adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train :
            node training indices
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        """

        victim_model = self.surrogate
        self.sparse_features = sp.issparse(ori_features)
        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)

        victim_model.eval()

        warnings.warn('This process is extremely slow!')
        adj_norm = utils.normalize_adj_tensor(ori_adj)
        s_e = self.calc_importance_edge(ori_features, adj_norm, labels, idx_train, steps=20)
        s_f = self.calc_importance_feature(ori_features, adj_norm, labels, idx_train, steps=20)

        # Debugging breakpoint left over from development; disabled here.
        # import ipdb
        # ipdb.set_trace()

        # NOTE: this backup is incomplete -- the perturbation-selection loop below
        # was never finished (`modified_adj` is a bare expression and `best_s` is
        # undefined), which is why the file is kept as a .backup.
        for t in tqdm(range(n_perturbations)):
            modified_adj

        self.adj_changes.data.copy_(torch.tensor(best_s))
        self.modified_adj = self.get_modified_adj(ori_adj).detach()

    def calc_importance_edge(self, features, adj, labels, idx_train, steps):
        adj_norm = utils.normalize_adj_tensor(adj)
        adj_norm.requires_grad = True
        integrated_grad_list = []
        for i in tqdm(range(adj.shape[0])):
            for j in range(adj.shape[1]):
                if adj_norm[i][j]:
                    scaled_inputs = [(float(k) / steps) * (adj_norm - 0) for k in range(0, steps + 1)]
                else:
                    scaled_inputs = [-(float(k) / steps) * (1 - adj_norm) for k in range(0, steps + 1)]
                _sum = 0

                # num_processes = steps
                # # NOTE: this is required for the ``fork`` method to work
                # self.surrogate.share_memory()
                # processes = []
                # for rank in range(num_processes):
                #     p = mp.Process(target=self.get_gradient, args=(features, scaled_inputs[rank], adj_norm, labels, idx_train))
                #     p.start()
                #     processes.append(p)
                # for p in processes:
                #     p.join()

                for new_adj in scaled_inputs:
                    output = self.surrogate(features, new_adj)
                    loss = F.nll_loss(output[idx_train], labels[idx_train])
                    # adj_grad = torch.autograd.grad(loss, adj[i][j], allow_unused=True)[0]
                    adj_grad = torch.autograd.grad(loss, adj_norm)[0]
                    adj_grad = adj_grad[i][j]
                    _sum += adj_grad

                if adj_norm[i][j]:
                    avg_grad = (adj_norm[i][j] - 0) * _sum.mean()
                else:
                    avg_grad = (1 - adj_norm[i][j]) * _sum.mean()
                integrated_grad_list.append(avg_grad)

        return integrated_grad_list

    def get_gradient(self, features, new_adj, adj_norm, labels, idx_train):
        # NOTE: `i`, `j` and `self._sum` are undefined here; this helper was
        # sketched for the multiprocessing variant above and never finished.
        output = self.surrogate(features, new_adj)
        loss = F.nll_loss(output[idx_train], labels[idx_train])
        # adj_grad = torch.autograd.grad(loss, adj[i][j], allow_unused=True)[0]
        adj_grad = torch.autograd.grad(loss, adj_norm)[0]
        adj_grad = adj_grad[i][j]
        self._sum += adj_grad

    def calc_importance_feature(self, features, adj_norm, labels, idx_train, steps):
        features.requires_grad = True
        integrated_grad_list = []
        for i in range(features.shape[0]):
            for j in range(features.shape[1]):
                if features[i][j]:
                    scaled_inputs = [(float(k) / steps) * (features - 0) for k in range(0, steps + 1)]
                else:
                    scaled_inputs = [-(float(k) / steps) * (1 - features) for k in range(0, steps + 1)]
                _sum = 0

                for new_features in scaled_inputs:
                    output = self.surrogate(new_features, adj_norm)
                    loss = F.nll_loss(output[idx_train], labels[idx_train])
                    # adj_grad = torch.autograd.grad(loss, adj[i][j], allow_unused=True)[0]
                    feature_grad = torch.autograd.grad(loss, features, allow_unused=True)[0]
                    feature_grad = feature_grad[i][j]
                    _sum += feature_grad

                if adj_norm[i][j]:
                    avg_grad = (features[i][j] - 0) * _sum.mean()
                else:
                    avg_grad = (1 - features[i][j]) * _sum.mean()
                integrated_grad_list.append(avg_grad)

        return integrated_grad_list

    # NOTE: the two helpers below also reference undefined names
    # (modified_adj, idx_train, labels) and are unused in this backup.
    def calc_gradient_adj(self, inputs, features):
        for adj in inputs:
            adj_norm = utils.normalize_adj_tensor(modified_adj)
            output = self.surrogate(features, adj_norm)
            loss = F.nll_loss(output[idx_train], labels[idx_train])
        adj_grad = torch.autograd.grad(loss, inputs)[0]
        return adj_grad.mean()

    def calc_gradient_feature(self, adj_norm, inputs):
        for features in inputs:
            output = self.surrogate(features, adj_norm)
            loss = F.nll_loss(output[idx_train], labels[idx_train])
        adj_grad = torch.autograd.grad(loss, inputs)[0]
        return adj_grad.mean()

    def get_modified_adj(self, ori_adj):
        adj_changes_square = self.adj_changes - torch.diag(torch.diag(self.adj_changes, 0))
        ind = np.diag_indices(self.adj_changes.shape[0])
        adj_changes_symm = torch.clamp(adj_changes_square + torch.transpose(adj_changes_square, 1, 0), -1, 1)
        modified_adj = adj_changes_symm + ori_adj
        return modified_adj

    def get_modified_features(self, ori_features):
        return ori_features + self.feature_changes
deeprobust/graph/global_attack/mettack.py
ADDED
@@ -0,0 +1,572 @@
"""
Adversarial Attacks on Graph Neural Networks via Meta Learning. ICLR 2019
https://openreview.net/pdf?id=Bylnx209YX
Author Tensorflow implementation:
https://github.com/danielzuegner/gnn-meta-attack
"""

import math
import numpy as np
import scipy.sparse as sp
import torch
from torch import optim
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm
from deeprobust.graph import utils
from deeprobust.graph.global_attack import BaseAttack


class BaseMeta(BaseAttack):
    """Abstract base class for meta attack. Adversarial Attacks on Graph Neural
    Networks via Meta Learning, ICLR 2019,
    https://openreview.net/pdf?id=Bylnx209YX

    Parameters
    ----------
    model :
        model to attack. Default `None`.
    nnodes : int
        number of nodes in the input graph
    lambda_ : float
        lambda_ is used to weight the two objectives in Eq. (10) in the paper.
    feature_shape : tuple
        shape of the input node features
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    undirected : bool
        whether the graph is undirected
    device: str
        'cpu' or 'cuda'

    """

    def __init__(self, model=None, nnodes=None, feature_shape=None, lambda_=0.5, attack_structure=True, attack_features=False, undirected=True, device='cpu'):

        super(BaseMeta, self).__init__(model, nnodes, attack_structure, attack_features, device)
        self.lambda_ = lambda_

        assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'

        self.modified_adj = None
        self.modified_features = None

        if attack_structure:
            self.undirected = undirected
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(nnodes, nnodes))
            self.adj_changes.data.fill_(0)

        if attack_features:
            assert feature_shape is not None, 'Please give feature_shape='
            self.feature_changes = Parameter(torch.FloatTensor(feature_shape))
            self.feature_changes.data.fill_(0)

        self.with_relu = model.with_relu

    def attack(self, adj, labels, n_perturbations):
        pass

    def get_modified_adj(self, ori_adj):
        adj_changes_square = self.adj_changes - torch.diag(torch.diag(self.adj_changes, 0))
        # ind = np.diag_indices(self.adj_changes.shape[0]) # this line seems useless
        if self.undirected:
            adj_changes_square = adj_changes_square + torch.transpose(adj_changes_square, 1, 0)
        adj_changes_square = torch.clamp(adj_changes_square, -1, 1)
        modified_adj = adj_changes_square + ori_adj
        return modified_adj

    def get_modified_features(self, ori_features):
        return ori_features + self.feature_changes

    def filter_potential_singletons(self, modified_adj):
        """
        Computes a mask for entries potentially leading to singleton nodes, i.e. one of the two nodes corresponding to
        the entry has degree 1 and there is an edge between the two nodes.
        """

        degrees = modified_adj.sum(0)
        degree_one = (degrees == 1)
        resh = degree_one.repeat(modified_adj.shape[0], 1).float()
        l_and = resh * modified_adj
        if self.undirected:
            l_and = l_and + l_and.t()
        flat_mask = 1 - l_and
        return flat_mask

    def self_training_label(self, labels, idx_train):
        # Predict the labels of the unlabeled nodes to use them for self-training.
        output = self.surrogate.output
        labels_self_training = output.argmax(1)
        labels_self_training[idx_train] = labels[idx_train]
        return labels_self_training


    def log_likelihood_constraint(self, modified_adj, ori_adj, ll_cutoff):
        """
        Computes a mask for entries that, if the edge corresponding to the entry is added/removed, would cause the
        log likelihood constraint to be violated.

        Note that different data types (float, double) can affect the final results.
        """
        t_d_min = torch.tensor(2.0).to(self.device)
        if self.undirected:
            t_possible_edges = np.array(np.triu(np.ones((self.nnodes, self.nnodes)), k=1).nonzero()).T
        else:
            t_possible_edges = np.array((np.ones((self.nnodes, self.nnodes)) - np.eye(self.nnodes)).nonzero()).T
        allowed_mask, current_ratio = utils.likelihood_ratio_filter(t_possible_edges,
                                                                    modified_adj,
                                                                    ori_adj, t_d_min,
                                                                    ll_cutoff, undirected=self.undirected)
        return allowed_mask, current_ratio

    def get_adj_score(self, adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff):
        adj_meta_grad = adj_grad * (-2 * modified_adj + 1)
        # Make sure that the minimum entry is 0.
        adj_meta_grad = adj_meta_grad - adj_meta_grad.min()
        # Filter self-loops
        adj_meta_grad = adj_meta_grad - torch.diag(torch.diag(adj_meta_grad, 0))
        # Set entries to 0 that could lead to singleton nodes.
        singleton_mask = self.filter_potential_singletons(modified_adj)
        adj_meta_grad = adj_meta_grad * singleton_mask

        if ll_constraint:
            allowed_mask, self.ll_ratio = self.log_likelihood_constraint(modified_adj, ori_adj, ll_cutoff)
            allowed_mask = allowed_mask.to(self.device)
            adj_meta_grad = adj_meta_grad * allowed_mask
        return adj_meta_grad

    def get_feature_score(self, feature_grad, modified_features):
        feature_meta_grad = feature_grad * (-2 * modified_features + 1)
        feature_meta_grad -= feature_meta_grad.min()
        return feature_meta_grad


class Metattack(BaseMeta):
    """Meta attack. Adversarial Attacks on Graph Neural Networks
    via Meta Learning, ICLR 2019.

    Examples
    --------

    >>> import numpy as np
    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> from deeprobust.graph.global_attack import Metattack
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> idx_unlabeled = np.union1d(idx_val, idx_test)
    >>> # Setup Surrogate model
    >>> surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
    >>> surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
    >>> # Setup Attack Model
    >>> model = Metattack(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,
            attack_structure=True, attack_features=False, device='cpu', lambda_=0).to('cpu')
    >>> # Attack
    >>> model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=10, ll_constraint=False)
    >>> modified_adj = model.modified_adj

    """

    def __init__(self, model, nnodes, feature_shape=None, attack_structure=True, attack_features=False, undirected=True, device='cpu', with_bias=False, lambda_=0.5, train_iters=100, lr=0.1, momentum=0.9):

        super(Metattack, self).__init__(model, nnodes, feature_shape, lambda_, attack_structure, attack_features, undirected, device)
        self.momentum = momentum
        self.lr = lr
        self.train_iters = train_iters
        self.with_bias = with_bias

        self.weights = []
        self.biases = []
        self.w_velocities = []
        self.b_velocities = []

        self.hidden_sizes = self.surrogate.hidden_sizes
        self.nfeat = self.surrogate.nfeat
        self.nclass = self.surrogate.nclass

        previous_size = self.nfeat
        for ix, nhid in enumerate(self.hidden_sizes):
            weight = Parameter(torch.FloatTensor(previous_size, nhid).to(device))
            w_velocity = torch.zeros(weight.shape).to(device)
            self.weights.append(weight)
            self.w_velocities.append(w_velocity)

            if self.with_bias:
                bias = Parameter(torch.FloatTensor(nhid).to(device))
                b_velocity = torch.zeros(bias.shape).to(device)
                self.biases.append(bias)
                self.b_velocities.append(b_velocity)

            previous_size = nhid

        output_weight = Parameter(torch.FloatTensor(previous_size, self.nclass).to(device))
        output_w_velocity = torch.zeros(output_weight.shape).to(device)
        self.weights.append(output_weight)
        self.w_velocities.append(output_w_velocity)

        if self.with_bias:
            output_bias = Parameter(torch.FloatTensor(self.nclass).to(device))
            output_b_velocity = torch.zeros(output_bias.shape).to(device)
            self.biases.append(output_bias)
            self.b_velocities.append(output_b_velocity)

        self._initialize()

    def _initialize(self):
        for w, v in zip(self.weights, self.w_velocities):
            stdv = 1. / math.sqrt(w.size(1))
            w.data.uniform_(-stdv, stdv)
            v.data.fill_(0)

        if self.with_bias:
            for b, v in zip(self.biases, self.b_velocities):
                stdv = 1. / math.sqrt(w.size(1))
                b.data.uniform_(-stdv, stdv)
                v.data.fill_(0)

    def inner_train(self, features, adj_norm, idx_train, idx_unlabeled, labels):
        self._initialize()

        for ix in range(len(self.hidden_sizes) + 1):
            self.weights[ix] = self.weights[ix].detach()
            self.weights[ix].requires_grad = True
            self.w_velocities[ix] = self.w_velocities[ix].detach()
            self.w_velocities[ix].requires_grad = True

            if self.with_bias:
                self.biases[ix] = self.biases[ix].detach()
                self.biases[ix].requires_grad = True
                self.b_velocities[ix] = self.b_velocities[ix].detach()
                self.b_velocities[ix].requires_grad = True

        for j in range(self.train_iters):
            hidden = features
            for ix, w in enumerate(self.weights):
                b = self.biases[ix] if self.with_bias else 0
                if self.sparse_features:
                    hidden = adj_norm @ torch.spmm(hidden, w) + b
                else:
                    hidden = adj_norm @ hidden @ w + b

                if self.with_relu and ix != len(self.weights) - 1:
                    hidden = F.relu(hidden)

            output = F.log_softmax(hidden, dim=1)
            loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])

            weight_grads = torch.autograd.grad(loss_labeled, self.weights, create_graph=True)
            self.w_velocities = [self.momentum * v + g for v, g in zip(self.w_velocities, weight_grads)]
            if self.with_bias:
                bias_grads = torch.autograd.grad(loss_labeled, self.biases, create_graph=True)
                self.b_velocities = [self.momentum * v + g for v, g in zip(self.b_velocities, bias_grads)]

            self.weights = [w - self.lr * v for w, v in zip(self.weights, self.w_velocities)]
            if self.with_bias:
                self.biases = [b - self.lr * v for b, v in zip(self.biases, self.b_velocities)]

    def get_meta_grad(self, features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training):

        hidden = features
        for ix, w in enumerate(self.weights):
            b = self.biases[ix] if self.with_bias else 0
            if self.sparse_features:
                hidden = adj_norm @ torch.spmm(hidden, w) + b
            else:
                hidden = adj_norm @ hidden @ w + b
            if self.with_relu and ix != len(self.weights) - 1:
                hidden = F.relu(hidden)

        output = F.log_softmax(hidden, dim=1)

        loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])
        loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled])
        loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled])

        if self.lambda_ == 1:
            attack_loss = loss_labeled
        elif self.lambda_ == 0:
            attack_loss = loss_unlabeled
        else:
            attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

        print('GCN loss on unlabeled data: {}'.format(loss_test_val.item()))
        print('GCN acc on unlabeled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item()))
        print('attack loss: {}'.format(attack_loss.item()))

        adj_grad, feature_grad = None, None
        if self.attack_structure:
            adj_grad = torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
        if self.attack_features:
            feature_grad = torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]
        return adj_grad, feature_grad

    def attack(self, ori_features, ori_adj, labels, idx_train, idx_unlabeled, n_perturbations, ll_constraint=True, ll_cutoff=0.004):
        """Generate n_perturbations on the input graph.

        Parameters
        ----------
        ori_features :
            Original (unperturbed) node feature matrix
        ori_adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_unlabeled:
            unlabeled nodes indices
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        ll_constraint: bool
            whether to exert the likelihood ratio test constraint
        ll_cutoff : float
            The critical value for the likelihood ratio test of the power law distributions.
            See the Chi square distribution with one degree of freedom. Default value 0.004
            corresponds to a p-value of roughly 0.95. It is ignored if `ll_constraint`
            is False.

        """

        self.sparse_features = sp.issparse(ori_features)
        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
        labels_self_training = self.self_training_label(labels, idx_train)
        modified_adj = ori_adj
        modified_features = ori_features

        for i in tqdm(range(n_perturbations), desc="Perturbing graph"):
            if self.attack_structure:
                modified_adj = self.get_modified_adj(ori_adj)

            if self.attack_features:
                modified_features = ori_features + self.feature_changes

            adj_norm = utils.normalize_adj_tensor(modified_adj)
            self.inner_train(modified_features, adj_norm, idx_train, idx_unlabeled, labels)

            adj_grad, feature_grad = self.get_meta_grad(modified_features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training)

            adj_meta_score = torch.tensor(0.0).to(self.device)
            feature_meta_score = torch.tensor(0.0).to(self.device)
            if self.attack_structure:
                adj_meta_score = self.get_adj_score(adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff)
            if self.attack_features:
                feature_meta_score = self.get_feature_score(feature_grad, modified_features)

            if adj_meta_score.max() >= feature_meta_score.max():
                adj_meta_argmax = torch.argmax(adj_meta_score)
                row_idx, col_idx = utils.unravel_index(adj_meta_argmax, ori_adj.shape)
                self.adj_changes.data[row_idx][col_idx] += (-2 * modified_adj[row_idx][col_idx] + 1)
                if self.undirected:
                    self.adj_changes.data[col_idx][row_idx] += (-2 * modified_adj[row_idx][col_idx] + 1)
            else:
                feature_meta_argmax = torch.argmax(feature_meta_score)
                row_idx, col_idx = utils.unravel_index(feature_meta_argmax, ori_features.shape)
                self.feature_changes.data[row_idx][col_idx] += (-2 * modified_features[row_idx][col_idx] + 1)

        if self.attack_structure:
            self.modified_adj = self.get_modified_adj(ori_adj).detach()
        if self.attack_features:
            self.modified_features = self.get_modified_features(ori_features).detach()


class MetaApprox(BaseMeta):
    """Approximated version of Meta Attack. Adversarial Attacks on
    Graph Neural Networks via Meta Learning, ICLR 2019.

    Examples
    --------

    >>> import numpy as np
    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> from deeprobust.graph.global_attack import MetaApprox
    >>> from deeprobust.graph.utils import preprocess
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False) # convert to tensor
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> idx_unlabeled = np.union1d(idx_val, idx_test)
    >>> # Setup Surrogate model
    >>> surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
    >>> surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
    >>> # Setup Attack Model
    >>> model = MetaApprox(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,
            attack_structure=True, attack_features=False, device='cpu', lambda_=0).to('cpu')
    >>> # Attack
    >>> model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=10, ll_constraint=True)
    >>> modified_adj = model.modified_adj

    """

    def __init__(self, model, nnodes, feature_shape=None, attack_structure=True, attack_features=False, undirected=True, device='cpu', with_bias=False, lambda_=0.5, train_iters=100, lr=0.01):

        super(MetaApprox, self).__init__(model, nnodes, feature_shape, lambda_, attack_structure, attack_features, undirected, device)

        self.lr = lr
        self.train_iters = train_iters
        self.adj_meta_grad = None
        self.features_meta_grad = None
        if self.attack_structure:
            self.adj_grad_sum = torch.zeros(nnodes, nnodes).to(device)
        if self.attack_features:
            self.feature_grad_sum = torch.zeros(feature_shape).to(device)

        self.with_bias = with_bias

        self.weights = []
        self.biases = []

        # The original file used self.nfeat / self.hidden_sizes / self.nclass below
        # without defining them; mirror Metattack and read them from the surrogate.
        self.hidden_sizes = self.surrogate.hidden_sizes
        self.nfeat = self.surrogate.nfeat
        self.nclass = self.surrogate.nclass

        previous_size = self.nfeat
        for ix, nhid in enumerate(self.hidden_sizes):
            weight = Parameter(torch.FloatTensor(previous_size, nhid).to(device))
            bias = Parameter(torch.FloatTensor(nhid).to(device))
            previous_size = nhid

            self.weights.append(weight)
            self.biases.append(bias)

        output_weight = Parameter(torch.FloatTensor(previous_size, self.nclass).to(device))
        output_bias = Parameter(torch.FloatTensor(self.nclass).to(device))
        self.weights.append(output_weight)
        self.biases.append(output_bias)

        self.optimizer = optim.Adam(self.weights + self.biases, lr=lr)  # , weight_decay=5e-4)
        self._initialize()

    def _initialize(self):
        for w, b in zip(self.weights, self.biases):
            # w.data.fill_(1)
            # b.data.fill_(1)
            stdv = 1. / math.sqrt(w.size(1))
            w.data.uniform_(-stdv, stdv)
            b.data.uniform_(-stdv, stdv)

        self.optimizer = optim.Adam(self.weights + self.biases, lr=self.lr)

    def inner_train(self, features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training):
        adj_norm = utils.normalize_adj_tensor(modified_adj)
        for j in range(self.train_iters):
            # hidden = features
            # for w, b in zip(self.weights, self.biases):
            #     if self.sparse_features:
            #         hidden = adj_norm @ torch.spmm(hidden, w) + b
            #     else:
            #         hidden = adj_norm @ hidden @ w + b
            #     if self.with_relu:
            #         hidden = F.relu(hidden)

            hidden = features
            for ix, w in enumerate(self.weights):
                b = self.biases[ix] if self.with_bias else 0
                if self.sparse_features:
                    hidden = adj_norm @ torch.spmm(hidden, w) + b
                else:
                    hidden = adj_norm @ hidden @ w + b
                if self.with_relu:
                    hidden = F.relu(hidden)

            output = F.log_softmax(hidden, dim=1)
            loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])
            loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled])

            if self.lambda_ == 1:
                attack_loss = loss_labeled
            elif self.lambda_ == 0:
                attack_loss = loss_unlabeled
            else:
                attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

            self.optimizer.zero_grad()
            loss_labeled.backward(retain_graph=True)

            if self.attack_structure:
                self.adj_changes.grad.zero_()
                self.adj_grad_sum += torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
            if self.attack_features:
                self.feature_changes.grad.zero_()
                self.feature_grad_sum += torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]

            self.optimizer.step()


        loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled])
        print('GCN loss on unlabeled data: {}'.format(loss_test_val.item()))
        print('GCN acc on unlabeled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item()))


    def attack(self, ori_features, ori_adj, labels, idx_train, idx_unlabeled, n_perturbations, ll_constraint=True, ll_cutoff=0.004):
        """Generate n_perturbations on the input graph.

        Parameters
        ----------
        ori_features :
            Original (unperturbed) node feature matrix
        ori_adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_unlabeled:
            unlabeled nodes indices
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        ll_constraint: bool
            whether to exert the likelihood ratio test constraint
        ll_cutoff : float
            The critical value for the likelihood ratio test of the power law distributions.
            See the Chi square distribution with one degree of freedom. Default value 0.004
            corresponds to a p-value of roughly 0.95. It is ignored if `ll_constraint`
            is False.

        """
        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
        labels_self_training = self.self_training_label(labels, idx_train)
        self.sparse_features = sp.issparse(ori_features)
        modified_adj = ori_adj
        modified_features = ori_features

        for i in tqdm(range(n_perturbations), desc="Perturbing graph"):
            self._initialize()

            if self.attack_structure:
                modified_adj = self.get_modified_adj(ori_adj)
                self.adj_grad_sum.data.fill_(0)
            if self.attack_features:
                modified_features = ori_features + self.feature_changes
                self.feature_grad_sum.data.fill_(0)

            self.inner_train(modified_features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training)

            adj_meta_score = torch.tensor(0.0).to(self.device)
            feature_meta_score = torch.tensor(0.0).to(self.device)

            if self.attack_structure:
                adj_meta_score = self.get_adj_score(self.adj_grad_sum, modified_adj, ori_adj, ll_constraint, ll_cutoff)
            if self.attack_features:
                feature_meta_score = self.get_feature_score(self.feature_grad_sum, modified_features)

            if adj_meta_score.max() >= feature_meta_score.max():
                adj_meta_argmax = torch.argmax(adj_meta_score)
                row_idx, col_idx = utils.unravel_index(adj_meta_argmax, ori_adj.shape)
                self.adj_changes.data[row_idx][col_idx] += (-2 * modified_adj[row_idx][col_idx] + 1)
                if self.undirected:
                    self.adj_changes.data[col_idx][row_idx] += (-2 * modified_adj[row_idx][col_idx] + 1)
            else:
                feature_meta_argmax = torch.argmax(feature_meta_score)
                row_idx, col_idx = utils.unravel_index(feature_meta_argmax, ori_features.shape)
                self.feature_changes.data[row_idx][col_idx] += (-2 * modified_features[row_idx][col_idx] + 1)

        if self.attack_structure:
            self.modified_adj = self.get_modified_adj(ori_adj).detach()
        if self.attack_features:
            self.modified_features = self.get_modified_features(ori_features).detach()
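`get_adj_score` above turns a meta-gradient into per-entry flip scores via `adj_grad * (-2*A + 1)`: the gradient sign is kept for absent edges (where the flip means adding the edge) and negated for present edges (where the flip means deleting it). A toy check of that sign convention, independent of the classes above:

# Toy check of the flip-score rule used in get_adj_score above.
import torch

adj = torch.tensor([[0., 1.],
                    [1., 0.]])
grad = torch.tensor([[ 0.5, -0.3],
                     [-0.3,  0.2]])

score = grad * (-2 * adj + 1)
print(score)
# tensor([[0.5000, 0.3000],
#         [0.3000, 0.2000]])
# the highest score picks the single best edge flip for this step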
deeprobust/graph/global_attack/nipa.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Non-target-specific Node Injection Attacks on Graph Neural Networks: A Hierarchical Reinforcement Learning Approach. WWW 2020.
|
3 |
+
https://faculty.ist.psu.edu/vhonavar/Papers/www20.pdf
|
4 |
+
|
5 |
+
Still on testing stage. Haven't reproduced the performance yet.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import os.path as osp
|
10 |
+
import random
|
11 |
+
from itertools import count
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.optim as optim
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from deeprobust.graph.rl.nipa_q_net_node import (NStepQNetNode, QNetNode,
|
20 |
+
node_greedy_actions)
|
21 |
+
from deeprobust.graph.rl.nstep_replay_mem import NstepReplayMem
|
22 |
+
from deeprobust.graph.utils import loss_acc
|
23 |
+
|
24 |
+
|
25 |
+
class NIPA(object):
|
26 |
+
""" Reinforcement learning agent for NIPA attack.
|
27 |
+
https://faculty.ist.psu.edu/vhonavar/Papers/www20.pdf
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
env :
|
32 |
+
Node attack environment
|
33 |
+
features :
|
34 |
+
node features matrix
|
35 |
+
labels :
|
36 |
+
labels
|
37 |
+
idx_meta :
|
38 |
+
node meta indices
|
39 |
+
idx_test :
|
40 |
+
node test indices
|
41 |
+
list_action_space : list
|
42 |
+
list of action space
|
43 |
+
num_mod :
|
44 |
+
number of modification (perturbation) on the graph
|
45 |
+
reward_type : str
|
46 |
+
type of reward (e.g., 'binary')
|
47 |
+
batch_size :
|
48 |
+
batch size for training DQN
|
49 |
+
save_dir :
|
50 |
+
saving directory for model checkpoints
|
51 |
+
device: str
|
52 |
+
'cpu' or 'cuda'
|
53 |
+
|
54 |
+
Examples
|
55 |
+
--------
|
56 |
+
See more details in https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_nipa.py
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, env, features, labels, idx_train, idx_val, idx_test,
|
60 |
+
list_action_space, ratio, reward_type='binary', batch_size=30,
|
61 |
+
num_wrong=0, bilin_q=1, embed_dim=64, gm='mean_field',
|
62 |
+
mlp_hidden=64, max_lv=1, save_dir='checkpoint_dqn', device=None):
|
63 |
+
|
64 |
+
assert device is not None, "'device' cannot be None, please specify it"
|
65 |
+
|
66 |
+
self.features = features
|
67 |
+
self.labels = labels
|
68 |
+
self.possible_labels = torch.arange(labels.max() + 1).to(labels.device)
|
69 |
+
self.idx_train = idx_train
|
70 |
+
self.idx_val = idx_val
|
71 |
+
self.idx_test = idx_test
|
72 |
+
self.num_wrong = num_wrong
|
73 |
+
self.list_action_space = list_action_space
|
74 |
+
|
75 |
+
degrees = np.array([len(d) for n, d in list_action_space.items()])
|
76 |
+
N = len(degrees[degrees > 0])
|
77 |
+
self.n_injected = len(degrees) - N
|
78 |
+
assert self.n_injected == int(ratio * N)
|
79 |
+
self.injected_nodes = np.arange(N)[-self.n_injected: ]
|
80 |
+
|
81 |
+
self.reward_type = reward_type
|
82 |
+
self.batch_size = batch_size
|
83 |
+
self.save_dir = save_dir
|
84 |
+
if not osp.exists(save_dir):
|
85 |
+
os.system('mkdir -p %s' % save_dir)
|
86 |
+
|
87 |
+
self.gm = gm
|
88 |
+
self.device = device
|
89 |
+
|
90 |
+
self.mem_pool = NstepReplayMem(memory_size=500000, n_steps=3, balance_sample=reward_type == 'binary', model='nipa')
|
91 |
+
self.env = env
|
92 |
+
|
93 |
+
self.net = NStepQNetNode(3, features, labels, list_action_space, self.n_injected,
|
94 |
+
bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
|
95 |
+
max_lv=max_lv, gm=gm, device=device)
|
96 |
+
|
97 |
+
self.old_net = NStepQNetNode(3, features, labels, list_action_space, self.n_injected,
|
98 |
+
bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
|
99 |
+
max_lv=max_lv, gm=gm, device=device)
|
100 |
+
|
101 |
+
self.net = self.net.to(device)
|
102 |
+
self.old_net = self.old_net.to(device)
|
103 |
+
|
104 |
+
self.eps_start = 1.0
|
105 |
+
self.eps_end = 0.05
|
106 |
+
# self.eps_step = 100000
|
107 |
+
self.eps_step = 30000
|
108 |
+
self.GAMMA = 0.9
|
109 |
+
self.burn_in = 50
|
110 |
+
self.step = 0
|
111 |
+
self.pos = 0
|
112 |
+
self.best_eval = None
|
113 |
+
self.take_snapshot()
|
114 |
+
|
115 |
+
def take_snapshot(self):
|
116 |
+
self.old_net.load_state_dict(self.net.state_dict())
|
117 |
+
|
118 |
+
def make_actions(self, time_t, greedy=False):
|
119 |
+
# TODO
|
120 |
+
self.eps = self.eps_end + max(0., (self.eps_start - self.eps_end)
|
121 |
+
* (self.eps_step - max(0., self.step)) / self.eps_step)
|
122 |
+
|
123 |
+
self.step += 1
|
124 |
+
if random.random() < self.eps and not greedy:
|
125 |
+
actions = self.env.uniformRandActions()
|
126 |
+
else:
|
127 |
+
|
128 |
+
cur_state = self.env.getStateRef()
|
129 |
+
# list_at = self.env.uniformRandActions()
|
130 |
+
list_at = self.env.first_nodes if time_t == 1 else None
|
131 |
+
|
132 |
+
actions = self.possible_actions(cur_state, list_at, time_t)
|
133 |
+
actions, values = self.net(time_t, cur_state, actions, greedy_acts=True, is_inference=True)
|
134 |
+
|
135 |
+
assert len(actions) == len(cur_state)
|
136 |
+
# actions = list(actions.cpu().numpy())
|
137 |
+
return actions
|
138 |
+
|
139 |
+
def run_simulation(self):
|
140 |
+
self.env.setup()
|
141 |
+
t = 0
|
142 |
+
while not self.env.isActionFinished():
|
143 |
+
list_at = self.make_actions(t)
|
144 |
+
list_st = self.env.cloneState()
|
145 |
+
|
146 |
+
self.env.step(list_at)
|
147 |
+
|
148 |
+
assert (self.env.rewards is not None) == self.env.isActionFinished()
|
149 |
+
if self.env.isActionFinished():
|
150 |
+
rewards = self.env.rewards
|
151 |
+
s_prime = self.env.cloneState()
|
152 |
+
else:
|
153 |
+
rewards = np.zeros(len(list_at), dtype=np.float32)
|
154 |
+
s_prime = self.env.cloneState()
|
155 |
+
|
156 |
+
if self.env.isTerminal():
|
157 |
+
rewards = self.env.rewards
|
158 |
+
s_prime = None
|
159 |
+
# self.env.init_overall_steps()
|
160 |
+
|
161 |
+
self.mem_pool.add_list(list_st, list_at, rewards, s_prime,
|
162 |
+
[self.env.isTerminal()] * len(list_at), t)
|
163 |
+
t += 1
|
164 |
+
|
165 |
+
def eval(self, training=True):
|
166 |
+
"""Evaluate RL agent.
|
167 |
+
"""
|
168 |
+
self.env.init_overall_steps()
|
169 |
+
self.env.setup()
|
170 |
+
|
171 |
+
for _ in count():
|
172 |
+
self.env.setup()
|
173 |
+
t = 0
|
174 |
+
while not self.env.isActionFinished():
|
175 |
+
list_at = self.make_actions(t, greedy=True)
|
176 |
+
# print(list_at)
|
177 |
+
self.env.step(list_at, inference=True)
|
178 |
+
t += 1
|
179 |
+
if self.env.isTerminal():
|
180 |
+
break
|
181 |
+
|
182 |
+
device = self.labels.device
|
183 |
+
extra_adj = self.env.modified_list[0].get_extra_adj(device=device)
|
184 |
+
adj = self.env.classifier.norm_tool.norm_extra(extra_adj)
|
185 |
+
labels = torch.cat((self.labels, self.env.modified_label_list[0]))
|
186 |
+
|
187 |
+
self.env.classifier.fit(self.features, adj, labels, self.idx_train, self.idx_val, normalize=False, patience=50)
|
188 |
+
output = self.env.classifier(self.features, adj)
|
189 |
+
loss, acc = loss_acc(output, self.labels, self.idx_test)
|
190 |
+
print('\033[93m average test: acc %.5f\033[0m' % (acc))
|
191 |
+
|
192 |
+
if training == True and self.best_eval is None or acc < self.best_eval:
|
193 |
+
print('----saving to best attacker since this is the best attack rate so far.----')
|
194 |
+
torch.save(self.net.state_dict(), osp.join(self.save_dir, 'epoch-best.model'))
|
195 |
+
with open(osp.join(self.save_dir, 'epoch-best.txt'), 'w') as f:
|
196 |
+
f.write('%.4f\n' % acc)
|
197 |
+
# with open(osp.join(self.save_dir, 'attack_solution.txt'), 'w') as f:
|
198 |
+
# for i in range(len(self.idx_meta)):
|
199 |
+
# f.write('%d: [' % self.idx_meta[i])
|
200 |
+
# for e in self.env.modified_list[i].directed_edges:
|
201 |
+
# f.write('(%d %d)' % e)
|
202 |
+
# f.write('] succ: %d\n' % (self.env.binary_rewards[i]))
|
203 |
+
self.best_eval = acc
|
204 |
+
|
205 |
+
def train(self, num_episodes=10, lr=0.01):
|
206 |
+
"""Train RL agent.
|
207 |
+
"""
|
208 |
+
optimizer = optim.Adam(self.net.parameters(), lr=lr)
|
209 |
+
self.env.init_overall_steps()
|
210 |
+
pbar = tqdm(range(self.burn_in), unit='batch')
|
211 |
+
for p in pbar:
|
212 |
+
self.run_simulation()
|
213 |
+
self.mem_pool.print_count()
|
214 |
+
|
215 |
+
for i_episode in tqdm(range(num_episodes)):
|
216 |
+
self.env.init_overall_steps()
|
217 |
+
|
218 |
+
for t in count():
|
219 |
+
self.run_simulation()
|
220 |
+
|
221 |
+
cur_time, list_st, list_at, list_rt, list_s_primes, list_term = self.mem_pool.sample(batch_size=self.batch_size)
|
222 |
+
list_target = torch.Tensor(list_rt).to(self.device)
|
223 |
+
|
224 |
+
if not list_term[0]:
|
225 |
+
actions = self.possible_actions(list_st, list_at, cur_time+1)
|
226 |
+
_, q_rhs = self.old_net(cur_time + 1, list_s_primes, actions, greedy_acts=True)
|
227 |
+
list_target += self.GAMMA * q_rhs
|
228 |
+
|
229 |
+
# list_target = list_target.view(-1, 1)
|
230 |
+
_, q_sa = self.net(cur_time, list_st, list_at)
|
231 |
+
loss = F.mse_loss(q_sa, list_target)
|
232 |
+
loss = torch.clamp(loss, -1, 1)
|
233 |
+
optimizer.zero_grad()
|
234 |
+
loss.backward()
|
235 |
+
# print([x[0] for x in self.nnamed_parameters() if x[1].grad is None])
|
236 |
+
# for param in self.net.parameters():
|
237 |
+
# if param.grad is None:
|
238 |
+
# continue
|
239 |
+
# param.grad.data.clamp_(-1, 1)
|
240 |
+
optimizer.step()
|
241 |
+
|
242 |
+
# pbar.set_description('eps: %.5f, loss: %0.5f, q_val: %.5f' % (self.eps, loss, torch.mean(q_sa)) )
|
243 |
+
if t % 20 == 0:
|
244 |
+
print('eps: %.5f, loss: %0.5f, q_val: %.5f, list_target: %.5f' % (self.eps, loss, torch.mean(q_sa), torch.mean(list_target)) )
|
245 |
+
|
246 |
+
if self.env.isTerminal():
|
247 |
+
break
|
248 |
+
|
249 |
+
# if (t+1) % 50 == 0:
|
250 |
+
# self.take_snapshot()
|
251 |
+
|
252 |
+
if i_episode % 1 == 0:  # every episode; raise the modulus to snapshot less often
|
253 |
+
self.take_snapshot()
|
254 |
+
|
255 |
+
if i_episode % 1 == 0:  # every episode; raise the modulus to evaluate less often
|
256 |
+
self.eval()
|
257 |
+
|
258 |
+
def possible_actions(self, list_st, list_at, t):
|
259 |
+
"""
|
260 |
+
Parameters
|
261 |
+
----------
|
262 |
+
list_st:
|
263 |
+
current state
|
264 |
+
list_at:
|
265 |
+
current action
|
266 |
+
|
267 |
+
Returns
|
268 |
+
-------
|
269 |
+
list
|
270 |
+
actions for next state
|
271 |
+
"""
|
272 |
+
|
273 |
+
t = t % 3
|
274 |
+
if t == 0:
|
275 |
+
return np.tile(self.injected_nodes, ((len(list_st), 1)))
|
276 |
+
|
277 |
+
if t == 1:
|
278 |
+
actions = []
|
279 |
+
for i in range(len(list_at)):
|
280 |
+
a_prime = list_st[i][0].get_possible_nodes(list_at[i])
|
281 |
+
actions.append(a_prime)
|
282 |
+
return actions
|
283 |
+
|
284 |
+
if t == 2:
|
285 |
+
return self.possible_labels.repeat((len(list_st), 1))
|
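The update in train() above is standard one-step Q-learning against a frozen target network: target = r + GAMMA * max_a' Q_old(s', a') for non-terminal transitions, and r alone at terminal ones. A minimal self-contained sketch of that target computation on toy tensors (values are illustrative only, not part of the diff):

import torch

GAMMA = 0.9
rewards = torch.tensor([1.0, -1.0, 1.0])        # r_t for a sampled batch
q_rhs = torch.tensor([0.5, 0.2, -0.3])          # max_a' Q_old(s', a') from the old/target net
terminal = torch.tensor([False, False, True])

target = rewards.clone()
target[~terminal] += GAMMA * q_rhs[~terminal]   # bootstrap only where the episode continues
print(target)                                    # tensor([ 1.4500, -0.8200,  1.0000])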
deeprobust/graph/global_attack/random_attack.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
import numpy as np
|
2 |
+
from deeprobust.graph.global_attack import BaseAttack
|
3 |
+
import scipy.sparse as sp
|
4 |
+
# import random
|
5 |
+
|
6 |
+
|
7 |
+
class Random(BaseAttack):
|
8 |
+
""" Randomly adding edges to the input graph
|
9 |
+
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
model :
|
13 |
+
model to attack. Default `None`.
|
14 |
+
nnodes : int
|
15 |
+
number of nodes in the input graph
|
16 |
+
attack_structure : bool
|
17 |
+
whether to attack graph structure
|
18 |
+
attack_features : bool
|
19 |
+
whether to attack node features
|
20 |
+
device: str
|
21 |
+
'cpu' or 'cuda'
|
22 |
+
|
23 |
+
Examples
|
24 |
+
--------
|
25 |
+
|
26 |
+
>>> from deeprobust.graph.data import Dataset
|
27 |
+
>>> from deeprobust.graph.global_attack import Random
|
28 |
+
>>> data = Dataset(root='/tmp/', name='cora')
|
29 |
+
>>> adj, features, labels = data.adj, data.features, data.labels
|
30 |
+
>>> model = Random()
|
31 |
+
>>> model.attack(adj, n_perturbations=10)
|
32 |
+
>>> modified_adj = model.modified_adj
|
33 |
+
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, model=None, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
|
37 |
+
super(Random, self).__init__(model, nnodes, attack_structure=attack_structure, attack_features=attack_features, device=device)
|
38 |
+
|
39 |
+
assert not self.attack_features, 'RND does NOT support attacking features'
|
40 |
+
|
41 |
+
def attack(self, ori_adj, n_perturbations, type='flip', **kwargs):
|
42 |
+
"""Generate attacks on the input graph.
|
43 |
+
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
ori_adj : scipy.sparse.csr_matrix
|
47 |
+
Original (unperturbed) adjacency matrix.
|
48 |
+
n_perturbations : int
|
49 |
+
Number of edge removals/additions.
|
50 |
+
type: str
|
51 |
+
perturbation type. Could be 'add', 'remove' or 'flip'.
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
None.
|
56 |
+
|
57 |
+
"""
|
58 |
+
|
59 |
+
if self.attack_structure:
|
60 |
+
modified_adj = self.perturb_adj(ori_adj, n_perturbations, type)
|
61 |
+
self.modified_adj = modified_adj
|
62 |
+
|
63 |
+
def perturb_adj(self, adj, n_perturbations, type='add'):
|
64 |
+
"""Randomly add, remove or flip edges.
|
65 |
+
|
66 |
+
Parameters
|
67 |
+
----------
|
68 |
+
adj : scipy.sparse.csr_matrix
|
69 |
+
Original (unperturbed) adjacency matrix.
|
70 |
+
n_perturbations : int
|
71 |
+
Number of edge removals/additions.
|
72 |
+
type: str
|
73 |
+
perturbation type. Could be 'add', 'remove' or 'flip'.
|
74 |
+
|
75 |
+
Returns
|
76 |
+
-------
|
77 |
+
scipy.sparse matrix
|
78 |
+
perturbed adjacency matrix
|
79 |
+
"""
|
80 |
+
# adj: sp.csr_matrix
|
81 |
+
modified_adj = adj.tolil()
|
82 |
+
|
83 |
+
type = type.lower()
|
84 |
+
assert type in ['add', 'remove', 'flip']
|
85 |
+
|
86 |
+
if type == 'flip':
|
87 |
+
# sample edges to flip
|
88 |
+
edges = self.random_sample_edges(adj, n_perturbations, exclude=set())
|
89 |
+
for n1, n2 in edges:
|
90 |
+
modified_adj[n1, n2] = 1 - modified_adj[n1, n2]
|
91 |
+
modified_adj[n2, n1] = 1 - modified_adj[n2, n1]
|
92 |
+
|
93 |
+
if type == 'add':
|
94 |
+
# sample edges to add
|
95 |
+
nonzero = set(zip(*adj.nonzero()))
|
96 |
+
edges = self.random_sample_edges(adj, n_perturbations, exclude=nonzero)
|
97 |
+
for n1, n2 in edges:
|
98 |
+
modified_adj[n1, n2] = 1
|
99 |
+
modified_adj[n2, n1] = 1
|
100 |
+
|
101 |
+
if type == 'remove':
|
102 |
+
# sample edges to remove
|
103 |
+
nonzero = np.array(sp.triu(adj, k=1).nonzero()).T
|
104 |
+
indices = np.random.permutation(nonzero)[: n_perturbations].T
|
105 |
+
modified_adj[indices[0], indices[1]] = 0
|
106 |
+
modified_adj[indices[1], indices[0]] = 0
|
107 |
+
|
108 |
+
self.check_adj(modified_adj)
|
109 |
+
return modified_adj
|
110 |
+
|
111 |
+
def perturb_features(self, features, n_perturbations):
|
112 |
+
"""Randomly perturb features.
|
113 |
+
"""
|
114 |
+
print('number of perturbations: %s' % n_perturbations)
|
115 |
+
# TODO: feature perturbation is not implemented yet
|
116 |
+
raise NotImplementedError
|
117 |
+
|
118 |
+
def inject_nodes(self, adj, n_add, n_perturbations):
|
119 |
+
"""For each added node, randomly connect with other nodes.
|
120 |
+
"""
|
121 |
+
# adj: sp.csr_matrix
|
122 |
+
# TODO
|
123 |
+
print('number of perturbations: %s' % n_perturbations)
|
124 |
+
raise NotImplementedError
|
125 |
+
|
126 |
+
modified_adj = adj.tolil()
|
127 |
+
return modified_adj
|
128 |
+
|
129 |
+
def random_sample_edges(self, adj, n, exclude):
|
130 |
+
itr = self.sample_forever(adj, exclude=exclude)
|
131 |
+
return [next(itr) for _ in range(n)]
|
132 |
+
|
133 |
+
def sample_forever(self, adj, exclude):
|
134 |
+
"""Randomly random sample edges from adjacency matrix, `exclude` is a set
|
135 |
+
which contains the edges we do not want to sample and the ones already sampled
|
136 |
+
"""
|
137 |
+
while True:
|
138 |
+
# t = tuple(np.random.randint(0, adj.shape[0], 2))
|
139 |
+
# t = tuple(random.sample(range(0, adj.shape[0]), 2))
|
140 |
+
t = tuple(np.random.choice(adj.shape[0], 2, replace=False))
|
141 |
+
if t not in exclude:
|
142 |
+
yield t
|
143 |
+
exclude.add(t)
|
144 |
+
exclude.add((t[1], t[0]))
|
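A short usage sketch for the Random attack above, mirroring the class docstring but exercising the 'remove' mode ('add' and 'flip' work the same way):

from deeprobust.graph.data import Dataset
from deeprobust.graph.global_attack import Random

data = Dataset(root='/tmp/', name='cora')
adj = data.adj                                   # scipy.sparse CSR adjacency

attacker = Random()
attacker.attack(adj, n_perturbations=20, type='remove')
modified_adj = attacker.modified_adj             # LIL matrix returned by perturb_adj
# each removed undirected edge clears two symmetric entries
print(adj.nnz - modified_adj.tocsr().nnz)        # expect 2 * 20 for a symmetric graph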
deeprobust/graph/global_attack/topology_attack.py
ADDED
@@ -0,0 +1,323 @@
1 |
+
"""
|
2 |
+
Topology Attack and Defense for Graph Neural Networks: An Optimization Perspective
|
3 |
+
https://arxiv.org/pdf/1906.04214.pdf
|
4 |
+
Tensorflow Implementation:
|
5 |
+
https://github.com/KaidiXu/GCN_ADV_Train
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import scipy.sparse as sp
|
10 |
+
import torch
|
11 |
+
from torch import optim
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from torch.nn.parameter import Parameter
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from deeprobust.graph import utils
|
17 |
+
from deeprobust.graph.global_attack import BaseAttack
|
18 |
+
|
19 |
+
|
20 |
+
class PGDAttack(BaseAttack):
|
21 |
+
"""PGD attack for graph data.
|
22 |
+
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
model :
|
26 |
+
model to attack. Default `None`.
|
27 |
+
nnodes : int
|
28 |
+
number of nodes in the input graph
|
29 |
+
loss_type: str
|
30 |
+
attack loss type, chosen from ['CE', 'CW']
|
31 |
+
feature_shape : tuple
|
32 |
+
shape of the input node features
|
33 |
+
attack_structure : bool
|
34 |
+
whether to attack graph structure
|
35 |
+
attack_features : bool
|
36 |
+
whether to attack node features
|
37 |
+
device: str
|
38 |
+
'cpu' or 'cuda'
|
39 |
+
|
40 |
+
Examples
|
41 |
+
--------
|
42 |
+
|
43 |
+
>>> from deeprobust.graph.data import Dataset
|
44 |
+
>>> from deeprobust.graph.defense import GCN
|
45 |
+
>>> from deeprobust.graph.global_attack import PGDAttack
|
46 |
+
>>> from deeprobust.graph.utils import preprocess
|
47 |
+
>>> data = Dataset(root='/tmp/', name='cora')
|
48 |
+
>>> adj, features, labels = data.adj, data.features, data.labels
|
49 |
+
>>> adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False) # convert to tensor
|
50 |
+
>>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
51 |
+
>>> # Setup Victim Model
|
52 |
+
>>> victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
|
53 |
+
nhid=16, dropout=0.5, weight_decay=5e-4, device='cpu').to('cpu')
|
54 |
+
>>> victim_model.fit(features, adj, labels, idx_train)
|
55 |
+
>>> # Setup Attack Model
|
56 |
+
>>> model = PGDAttack(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device='cpu').to('cpu')
|
57 |
+
>>> model.attack(features, adj, labels, idx_train, n_perturbations=10)
|
58 |
+
>>> modified_adj = model.modified_adj
|
59 |
+
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None, attack_structure=True, attack_features=False, device='cpu'):
|
63 |
+
|
64 |
+
super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)
|
65 |
+
|
66 |
+
assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'
|
67 |
+
|
68 |
+
self.loss_type = loss_type
|
69 |
+
self.modified_adj = None
|
70 |
+
self.modified_features = None
|
71 |
+
|
72 |
+
if attack_structure:
|
73 |
+
assert nnodes is not None, 'Please give nnodes='
|
74 |
+
self.adj_changes = Parameter(torch.FloatTensor(int(nnodes*(nnodes-1)/2)))
|
75 |
+
self.adj_changes.data.fill_(0)
|
76 |
+
|
77 |
+
if attack_features:
|
78 |
+
raise NotImplementedError('Topology Attack does not support attacking features')
|
79 |
+
|
80 |
+
self.complementary = None
|
81 |
+
|
82 |
+
def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, epochs=25, **kwargs):
|
83 |
+
"""Generate perturbations on the input graph.
|
84 |
+
|
85 |
+
Parameters
|
86 |
+
----------
|
87 |
+
ori_features :
|
88 |
+
Original (unperturbed) node feature matrix
|
89 |
+
ori_adj :
|
90 |
+
Original (unperturbed) adjacency matrix
|
91 |
+
labels :
|
92 |
+
node labels
|
93 |
+
idx_train :
|
94 |
+
node training indices
|
95 |
+
n_perturbations : int
|
96 |
+
Number of perturbations on the input graph. Perturbations could
|
97 |
+
be edge removals/additions or feature removals/additions.
|
98 |
+
epochs:
|
99 |
+
number of training epochs
|
100 |
+
|
101 |
+
"""
|
102 |
+
|
103 |
+
victim_model = self.surrogate
|
104 |
+
|
105 |
+
self.sparse_features = sp.issparse(ori_features)
|
106 |
+
ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
|
107 |
+
|
108 |
+
victim_model.eval()
|
109 |
+
for t in tqdm(range(epochs)):
|
110 |
+
modified_adj = self.get_modified_adj(ori_adj)
|
111 |
+
adj_norm = utils.normalize_adj_tensor(modified_adj)
|
112 |
+
output = victim_model(ori_features, adj_norm)
|
113 |
+
# loss = F.nll_loss(output[idx_train], labels[idx_train])
|
114 |
+
loss = self._loss(output[idx_train], labels[idx_train])
|
115 |
+
adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]
|
116 |
+
|
117 |
+
if self.loss_type == 'CE':
|
118 |
+
lr = 200 / np.sqrt(t+1)
|
119 |
+
self.adj_changes.data.add_(lr * adj_grad)
|
120 |
+
|
121 |
+
if self.loss_type == 'CW':
|
122 |
+
lr = 0.1 / np.sqrt(t+1)
|
123 |
+
self.adj_changes.data.add_(lr * adj_grad)
|
124 |
+
|
125 |
+
self.projection(n_perturbations)
|
126 |
+
|
127 |
+
self.random_sample(ori_adj, ori_features, labels, idx_train, n_perturbations)
|
128 |
+
self.modified_adj = self.get_modified_adj(ori_adj).detach()
|
129 |
+
self.check_adj_tensor(self.modified_adj)
|
130 |
+
|
131 |
+
|
132 |
+
def random_sample(self, ori_adj, ori_features, labels, idx_train, n_perturbations):
|
133 |
+
K = 20
|
134 |
+
best_loss = -1000
|
135 |
+
victim_model = self.surrogate
|
136 |
+
victim_model.eval()
|
137 |
+
with torch.no_grad():
|
138 |
+
s = self.adj_changes.cpu().detach().numpy()
|
139 |
+
for i in range(K):
|
140 |
+
sampled = np.random.binomial(1, s)
|
141 |
+
|
142 |
+
# print(sampled.sum())
|
143 |
+
if sampled.sum() > n_perturbations:
|
144 |
+
continue
|
145 |
+
self.adj_changes.data.copy_(torch.tensor(sampled))
|
146 |
+
modified_adj = self.get_modified_adj(ori_adj)
|
147 |
+
adj_norm = utils.normalize_adj_tensor(modified_adj)
|
148 |
+
output = victim_model(ori_features, adj_norm)
|
149 |
+
loss = self._loss(output[idx_train], labels[idx_train])
|
150 |
+
# loss = F.nll_loss(output[idx_train], labels[idx_train])
|
151 |
+
# print(loss)
|
152 |
+
if best_loss < loss:
|
153 |
+
best_loss = loss
|
154 |
+
best_s = sampled
|
155 |
+
self.adj_changes.data.copy_(torch.tensor(best_s))
|
156 |
+
|
157 |
+
def _loss(self, output, labels):
|
158 |
+
if self.loss_type == "CE":
|
159 |
+
loss = F.nll_loss(output, labels)
|
160 |
+
if self.loss_type == "CW":
|
161 |
+
onehot = utils.tensor2onehot(labels)
|
162 |
+
best_second_class = (output - 1000*onehot).argmax(1)
|
163 |
+
margin = output[np.arange(len(output)), labels] - \
|
164 |
+
output[np.arange(len(output)), best_second_class]
|
165 |
+
k = 0
|
166 |
+
loss = -torch.clamp(margin, min=k).mean()
|
167 |
+
# loss = torch.clamp(margin.sum()+50, min=k)
|
168 |
+
return loss
|
169 |
+
|
170 |
+
def projection(self, n_perturbations):
|
171 |
+
# projected = torch.clamp(self.adj_changes, 0, 1)
|
172 |
+
if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
|
173 |
+
left = (self.adj_changes - 1).min()
|
174 |
+
right = self.adj_changes.max()
|
175 |
+
miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
|
176 |
+
self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
|
177 |
+
else:
|
178 |
+
self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))
|
179 |
+
|
180 |
+
def get_modified_adj(self, ori_adj):
|
181 |
+
|
182 |
+
if self.complementary is None:
|
183 |
+
self.complementary = (torch.ones_like(ori_adj) - torch.eye(self.nnodes).to(self.device) - ori_adj) - ori_adj
|
184 |
+
|
185 |
+
m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
|
186 |
+
tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
|
187 |
+
m[tril_indices[0], tril_indices[1]] = self.adj_changes
|
188 |
+
m = m + m.t()
|
189 |
+
modified_adj = self.complementary * m + ori_adj
|
190 |
+
|
191 |
+
return modified_adj
|
192 |
+
|
193 |
+
def bisection(self, a, b, n_perturbations, epsilon):
|
194 |
+
def func(x):
|
195 |
+
return torch.clamp(self.adj_changes-x, 0, 1).sum() - n_perturbations
|
196 |
+
|
197 |
+
miu = a
|
198 |
+
while ((b-a) >= epsilon):
|
199 |
+
miu = (a+b)/2
|
200 |
+
# Check if middle point is root
|
201 |
+
if (func(miu) == 0.0):
|
202 |
+
break
|
203 |
+
# Decide the side to repeat the steps
|
204 |
+
if (func(miu)*func(a) < 0):
|
205 |
+
b = miu
|
206 |
+
else:
|
207 |
+
a = miu
|
208 |
+
# print("The value of root is : ","%.4f" % miu)
|
209 |
+
return miu
|
210 |
+
|
211 |
+
|
212 |
+
class MinMax(PGDAttack):
|
213 |
+
"""MinMax attack for graph data.
|
214 |
+
|
215 |
+
Parameters
|
216 |
+
----------
|
217 |
+
model :
|
218 |
+
model to attack. Default `None`.
|
219 |
+
nnodes : int
|
220 |
+
number of nodes in the input graph
|
221 |
+
loss_type: str
|
222 |
+
attack loss type, chosen from ['CE', 'CW']
|
223 |
+
feature_shape : tuple
|
224 |
+
shape of the input node features
|
225 |
+
attack_structure : bool
|
226 |
+
whether to attack graph structure
|
227 |
+
attack_features : bool
|
228 |
+
whether to attack node features
|
229 |
+
device: str
|
230 |
+
'cpu' or 'cuda'
|
231 |
+
|
232 |
+
Examples
|
233 |
+
--------
|
234 |
+
|
235 |
+
>>> from deeprobust.graph.data import Dataset
|
236 |
+
>>> from deeprobust.graph.defense import GCN
|
237 |
+
>>> from deeprobust.graph.global_attack import MinMax
|
238 |
+
>>> from deeprobust.graph.utils import preprocess
|
239 |
+
>>> data = Dataset(root='/tmp/', name='cora')
|
240 |
+
>>> adj, features, labels = data.adj, data.features, data.labels
|
241 |
+
>>> adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False) # convert to tensor
|
242 |
+
>>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
243 |
+
>>> # Setup Victim Model
|
244 |
+
>>> victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
|
245 |
+
nhid=16, dropout=0.5, weight_decay=5e-4, device='cpu').to('cpu')
|
246 |
+
>>> victim_model.fit(features, adj, labels, idx_train)
|
247 |
+
>>> # Setup Attack Model
|
248 |
+
>>> model = MinMax(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device='cpu').to('cpu')
|
249 |
+
>>> model.attack(features, adj, labels, idx_train, n_perturbations=10)
|
250 |
+
>>> modified_adj = model.modified_adj
|
251 |
+
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None, attack_structure=True, attack_features=False, device='cpu'):
|
255 |
+
|
256 |
+
super(MinMax, self).__init__(model, nnodes, loss_type, feature_shape, attack_structure, attack_features, device=device)
|
257 |
+
|
258 |
+
|
259 |
+
def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, **kwargs):
|
260 |
+
"""Generate perturbations on the input graph.
|
261 |
+
|
262 |
+
Parameters
|
263 |
+
----------
|
264 |
+
ori_features :
|
265 |
+
Original (unperturbed) node feature matrix
|
266 |
+
ori_adj :
|
267 |
+
Original (unperturbed) adjacency matrix
|
268 |
+
labels :
|
269 |
+
node labels
|
270 |
+
idx_train :
|
271 |
+
node training indices
|
272 |
+
n_perturbations : int
|
273 |
+
Number of perturbations on the input graph. Perturbations could
|
274 |
+
be edge removals/additions or feature removals/additions.
|
275 |
+
epochs:
|
276 |
+
number of training epochs (this implementation fixes epochs to 200 internally)
|
277 |
+
|
278 |
+
"""
|
279 |
+
|
280 |
+
victim_model = self.surrogate
|
281 |
+
|
282 |
+
self.sparse_features = sp.issparse(ori_features)
|
283 |
+
ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
|
284 |
+
|
285 |
+
# optimizer
|
286 |
+
optimizer = optim.Adam(victim_model.parameters(), lr=0.01)
|
287 |
+
|
288 |
+
epochs = 200
|
289 |
+
victim_model.eval()
|
290 |
+
for t in tqdm(range(epochs)):
|
291 |
+
# update victim model
|
292 |
+
victim_model.train()
|
293 |
+
modified_adj = self.get_modified_adj(ori_adj)
|
294 |
+
adj_norm = utils.normalize_adj_tensor(modified_adj)
|
295 |
+
output = victim_model(ori_features, adj_norm)
|
296 |
+
loss = self._loss(output[idx_train], labels[idx_train])
|
297 |
+
|
298 |
+
optimizer.zero_grad()
|
299 |
+
loss.backward()
|
300 |
+
optimizer.step()
|
301 |
+
|
302 |
+
# generate pgd attack
|
303 |
+
victim_model.eval()
|
304 |
+
modified_adj = self.get_modified_adj(ori_adj)
|
305 |
+
adj_norm = utils.normalize_adj_tensor(modified_adj)
|
306 |
+
output = victim_model(ori_features, adj_norm)
|
307 |
+
loss = self._loss(output[idx_train], labels[idx_train])
|
308 |
+
adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]
|
309 |
+
# adj_grad = self.adj_changes.grad
|
310 |
+
|
311 |
+
if self.loss_type == 'CE':
|
312 |
+
lr = 200 / np.sqrt(t+1)
|
313 |
+
self.adj_changes.data.add_(lr * adj_grad)
|
314 |
+
|
315 |
+
if self.loss_type == 'CW':
|
316 |
+
lr = 0.1 / np.sqrt(t+1)
|
317 |
+
self.adj_changes.data.add_(lr * adj_grad)
|
318 |
+
|
319 |
+
# self.adj_changes.grad.zero_()
|
320 |
+
self.projection(n_perturbations)
|
321 |
+
|
322 |
+
self.random_sample(ori_adj, ori_features, labels, idx_train, n_perturbations)
|
323 |
+
self.modified_adj = self.get_modified_adj(ori_adj).detach()
|
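Both PGDAttack and MinMax rely on projection() above, which keeps adj_changes inside [0, 1] and, whenever the total exceeds the budget, searches via bisection() for the shift miu such that sum(clamp(adj_changes - miu, 0, 1)) equals n_perturbations. A standalone numeric sketch of that projection (independent of the classes; names are local to the example):

import torch

def project(s, budget, eps=1e-5):
    # same logic as projection()/bisection() above, on a plain tensor
    if torch.clamp(s, 0, 1).sum() <= budget:
        return torch.clamp(s, 0, 1)
    a, b = (s - 1).min(), s.max()
    miu = a
    while b - a >= eps:
        miu = (a + b) / 2
        if torch.clamp(s - miu, 0, 1).sum() > budget:
            a = miu        # still over budget: shift further
        else:
            b = miu
    return torch.clamp(s - miu, 0, 1)

s = torch.tensor([0.9, 0.8, 0.7, 0.1])
p = project(s, budget=2.0)
print(p, p.sum())          # entries stay in [0, 1]; total is ~2.0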
deeprobust/graph/rl/nipa_config.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
"""Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le
|
2 |
+
"""
|
3 |
+
import argparse
|
4 |
+
import pickle as cp
|
5 |
+
|
6 |
+
cmd_opt = argparse.ArgumentParser(description='Argparser for NIPA attack')
|
7 |
+
|
8 |
+
cmd_opt.add_argument('-ratio', type=float, default=0.01, help='ratio of injected nodes')
|
9 |
+
|
10 |
+
cmd_opt.add_argument('-saved_model', type=str, default=None, help='saved model')
|
11 |
+
cmd_opt.add_argument('-save_dir', type=str, default=None, help='save folder')
|
12 |
+
cmd_opt.add_argument('-ctx', type=str, default='gpu', help='cpu/gpu')
|
13 |
+
|
14 |
+
cmd_opt.add_argument('-phase', type=str, default='train', help='train/test')
|
15 |
+
cmd_opt.add_argument('-batch_size', type=int, default=10, help='minibatch size')
|
16 |
+
cmd_opt.add_argument('-seed', type=int, default=1, help='seed')
|
17 |
+
|
18 |
+
cmd_opt.add_argument('-gm', default='mean_field', help='mean_field/loopy_bp/gcn')
|
19 |
+
cmd_opt.add_argument('-latent_dim', type=int, default=64, help='dimension of latent layers')
|
20 |
+
cmd_opt.add_argument('-hidden', type=int, default=0, help='dimension of classification')
|
21 |
+
cmd_opt.add_argument('-max_lv', type=int, default=1, help='max rounds of message passing')
|
22 |
+
|
23 |
+
# target model
|
24 |
+
cmd_opt.add_argument('-num_epochs', type=int, default=200, help='number of epochs')
|
25 |
+
cmd_opt.add_argument('-learning_rate', type=float, default=0.01, help='init learning_rate')
|
26 |
+
cmd_opt.add_argument('-weight_decay', type=float, default=5e-4, help='weight_decay')
|
27 |
+
cmd_opt.add_argument('-dropout', type=float, default=0.5, help='dropout rate')
|
28 |
+
|
29 |
+
# for node classification
|
30 |
+
cmd_opt.add_argument('-dataset', type=str, default='cora', help='citeseer/cora/pubmed')
|
31 |
+
|
32 |
+
# for attack
|
33 |
+
cmd_opt.add_argument('-num_steps', type=int, default=500000, help='rl training steps')
|
34 |
+
# cmd_opt.add_argument('-frac_meta', type=float, default=0, help='fraction for meta rl learning')
|
35 |
+
|
36 |
+
cmd_opt.add_argument('-meta_test', type=int, default=0, help='for meta rl learning')
|
37 |
+
cmd_opt.add_argument('-reward_type', type=str, default='binary', help='binary/nll')
|
38 |
+
cmd_opt.add_argument('-num_mod', type=int, default=1, help='number of modifications allowed')
|
39 |
+
|
40 |
+
# for node attack
|
41 |
+
cmd_opt.add_argument('-bilin_q', type=int, default=1, help='bilinear q or not')
|
42 |
+
cmd_opt.add_argument('-mlp_hidden', type=int, default=64, help='mlp hidden layer size')
|
43 |
+
# cmd_opt.add_argument('-n_hops', type=int, default=2, help='attack range')
|
44 |
+
|
45 |
+
|
46 |
+
args, _ = cmd_opt.parse_known_args()
|
47 |
+
args.save_dir = './results/rl_s2v/{}-gcn'.format(args.dataset)
|
48 |
+
args.saved_model = 'results/node_classification/{}'.format(args.dataset)
|
49 |
+
print(args)
|
50 |
+
|
51 |
+
def build_kwargs(keys, arg_dict):
|
52 |
+
st = ''
|
53 |
+
for key in keys:
|
54 |
+
st += '%s-%s' % (key, str(arg_dict[key]))
|
55 |
+
return st
|
56 |
+
|
57 |
+
def save_args(fout, args):
|
58 |
+
with open(fout, 'wb') as f:
|
59 |
+
cp.dump(args, f, cp.HIGHEST_PROTOCOL)
|
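A small usage sketch for the two helpers above; note that build_kwargs concatenates the 'key-value' fragments with no separator between pairs, and that importing the module also triggers the parse_known_args() call near the bottom of the file:

from deeprobust.graph.rl.nipa_config import args, build_kwargs, save_args

run_id = build_kwargs(['dataset', 'learning_rate'], vars(args))
print(run_id)                            # 'dataset-coralearning_rate-0.01' with the defaults
save_args('/tmp/nipa_args.pkl', args)    # pickle the parsed namespace for reproducibility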
deeprobust/graph/rl/nipa_nstep_replay_mem.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
'''
|
2 |
+
This part of code is adopted from https://github.com/Hanjun-Dai/graph_adversarial_attack (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le)
|
3 |
+
but modified to be integrated into the repository.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
from deeprobust.graph.rl.nstep_replay_mem import *
|
9 |
+
|
10 |
+
|
11 |
+
def nipa_hash_state_action(s_t, a_t):
|
12 |
+
key = s_t[0]
|
13 |
+
base, mod = 179424673, (1 << 61) - 1  # multiplier and modulus must differ, else (key*base + x) % base collapses to x % base
|
14 |
+
for e in s_t[1].directed_edges:
|
15 |
+
key = (key * base + e[0]) % mod
|
16 |
+
key = (key * base + e[1]) % mod
|
17 |
+
if s_t[2] is not None:
|
18 |
+
key = (key * base + s_t[2]) % mod
|
19 |
+
else:
|
20 |
+
key = (key * base) % mod
|
21 |
+
|
22 |
+
key = (key * base + a_t) % mod
|
23 |
+
return key
|
24 |
+
|
25 |
+
|
26 |
+
class NstepReplayMem(object):
|
27 |
+
def __init__(self, memory_size, n_steps, balance_sample = False):
|
28 |
+
self.mem_cells = []
|
29 |
+
for i in range(n_steps - 1):
|
30 |
+
self.mem_cells.append(NstepReplayMemCell(memory_size, False))
|
31 |
+
self.mem_cells.append(NstepReplayMemCell(memory_size, balance_sample))
|
32 |
+
|
33 |
+
self.n_steps = n_steps
|
34 |
+
self.memory_size = memory_size
|
35 |
+
|
36 |
+
def add(self, s_t, a_t, r_t, s_prime, terminal, t):
|
37 |
+
assert t >= 0 and t < self.n_steps
|
38 |
+
if t == self.n_steps - 1:
|
39 |
+
assert terminal
|
40 |
+
else:
|
41 |
+
assert not terminal
|
42 |
+
self.mem_cells[t].add(s_t, a_t, r_t, s_prime, terminal)
|
43 |
+
|
44 |
+
def add_list(self, list_st, list_at, list_rt, list_sp, list_term, t):
|
45 |
+
for i in range(len(list_st)):
|
46 |
+
if list_sp is None:
|
47 |
+
sp = (None, None, None)
|
48 |
+
else:
|
49 |
+
sp = list_sp[i]
|
50 |
+
self.add(list_st[i], list_at[i], list_rt[i], sp, list_term[i], t)
|
51 |
+
|
52 |
+
def sample(self, batch_size, t = None):
|
53 |
+
if t is None:
|
54 |
+
t = np.random.randint(self.n_steps)
|
55 |
+
list_st, list_at, list_rt, list_s_primes, list_term = self.mem_cells[t].sample(batch_size)
|
56 |
+
return t, list_st, list_at, list_rt, list_s_primes, list_term
|
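A toy round-trip through the replay memory above; the states here are opaque placeholders, and the only contract enforced is that the final step (t == n_steps - 1) is terminal:

from deeprobust.graph.rl.nipa_nstep_replay_mem import NstepReplayMem

mem = NstepReplayMem(memory_size=100, n_steps=3)
s, sp = ('s0',), ('s1',)                                     # placeholder states
mem.add(s, a_t=0, r_t=0.0, s_prime=sp, terminal=False, t=0)
mem.add(sp, a_t=1, r_t=0.0, s_prime=sp, terminal=False, t=1)
mem.add(sp, a_t=2, r_t=1.0, s_prime=None, terminal=True, t=2)
t, st, at, rt, sps, term = mem.sample(batch_size=1, t=2)
print(at, rt, term)                                          # [2] [1.0] [True]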
deeprobust/graph/rl/nipa_q_net_node.py
ADDED
@@ -0,0 +1,242 @@
1 |
+
'''
|
2 |
+
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
|
3 |
+
https://arxiv.org/abs/1806.02371
|
4 |
+
Author's Implementation
|
5 |
+
https://github.com/Hanjun-Dai/graph_adversarial_attack
|
6 |
+
This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
|
7 |
+
to be integrated into the repository.
|
8 |
+
'''
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import networkx as nx
|
15 |
+
import random
|
16 |
+
from torch.nn.parameter import Parameter
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import torch.optim as optim
|
20 |
+
from tqdm import tqdm
|
21 |
+
from deeprobust.graph.rl.env import GraphNormTool
|
22 |
+
|
23 |
+
class QNetNode(nn.Module):
|
24 |
+
|
25 |
+
def __init__(self, node_features, node_labels, list_action_space, n_injected, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
|
26 |
+
'''
|
27 |
+
bilin_q: bilinear q or not
|
28 |
+
mlp_hidden: mlp hidden layer size
|
29 |
+
max_lv: max rounds of message passing
|
30 |
+
'''
|
31 |
+
super(QNetNode, self).__init__()
|
32 |
+
self.node_features = node_features
|
33 |
+
self.identity = torch.eye(node_labels.max() + 1).to(node_labels.device)
|
34 |
+
# self.node_labels = self.to_onehot(node_labels)
|
35 |
+
self.n_injected = n_injected
|
36 |
+
|
37 |
+
self.list_action_space = list_action_space
|
38 |
+
self.total_nodes = len(list_action_space)
|
39 |
+
|
40 |
+
self.bilin_q = bilin_q
|
41 |
+
self.embed_dim = embed_dim
|
42 |
+
self.mlp_hidden = mlp_hidden
|
43 |
+
self.max_lv = max_lv
|
44 |
+
self.gm = gm
|
45 |
+
|
46 |
+
if mlp_hidden:
|
47 |
+
self.linear_1 = nn.Linear(embed_dim * 3, mlp_hidden)
|
48 |
+
self.linear_out = nn.Linear(mlp_hidden, 1)
|
49 |
+
else:
|
50 |
+
self.linear_out = nn.Linear(embed_dim * 3, 1)
|
51 |
+
|
52 |
+
self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
|
53 |
+
self.bias_n2l = Parameter(torch.Tensor(embed_dim))
|
54 |
+
|
55 |
+
# self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
|
56 |
+
self.conv_params = nn.Linear(embed_dim, embed_dim)
|
57 |
+
self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
|
58 |
+
weights_init(self)
|
59 |
+
|
60 |
+
input_dim = (node_labels.max() + 1) * self.n_injected
|
61 |
+
self.label_encoder_1 = nn.Linear(input_dim, mlp_hidden)
|
62 |
+
self.label_encoder_2 = nn.Linear(mlp_hidden, embed_dim)
|
63 |
+
self.device = self.node_features.device
|
64 |
+
|
65 |
+
def to_onehot(self, labels):
|
66 |
+
return self.identity[labels].view(-1, self.identity.shape[1])
|
67 |
+
|
68 |
+
def get_label_embedding(self, labels):
|
69 |
+
# int to one hot
|
70 |
+
onehot = self.to_onehot(labels).view(1, -1)
|
71 |
+
|
72 |
+
x = F.relu(self.label_encoder_1(onehot))
|
73 |
+
x = F.relu(self.label_encoder_2(x))
|
74 |
+
return x
|
75 |
+
|
76 |
+
def get_action_label_encoding(self, label):
|
77 |
+
onehot = self.to_onehot(label)
|
78 |
+
zeros = torch.zeros((onehot.shape[0], self.embed_dim - onehot.shape[1])).to(onehot.device)
|
79 |
+
return torch.cat((onehot, zeros), dim=1)
|
80 |
+
|
81 |
+
def get_graph_embedding(self, adj):
|
82 |
+
if self.node_features.data.is_sparse:
|
83 |
+
node_embed = torch.spmm(self.node_features, self.w_n2l)
|
84 |
+
else:
|
85 |
+
node_embed = torch.mm(self.node_features, self.w_n2l)
|
86 |
+
|
87 |
+
node_embed += self.bias_n2l
|
88 |
+
|
89 |
+
input_message = node_embed
|
90 |
+
node_embed = F.relu(input_message)
|
91 |
+
|
92 |
+
for i in range(self.max_lv):
|
93 |
+
n2npool = torch.spmm(adj, node_embed)
|
94 |
+
node_linear = self.conv_params(n2npool)
|
95 |
+
merged_linear = node_linear + input_message
|
96 |
+
node_embed = F.relu(merged_linear)
|
97 |
+
|
98 |
+
graph_embed = torch.mean(node_embed, dim=0, keepdim=True)
|
99 |
+
return graph_embed, node_embed
|
100 |
+
|
101 |
+
def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
|
102 |
+
idxes = torch.LongTensor([[row_idx], [col_idx]])
|
103 |
+
values = torch.ones(1)
|
104 |
+
|
105 |
+
sp = torch.sparse.FloatTensor(idxes, values, torch.Size([n_rows, n_cols]))
|
106 |
+
if next(self.parameters()).is_cuda:
|
107 |
+
sp = sp.cuda()
|
108 |
+
return sp
|
109 |
+
|
110 |
+
def forward(self, time_t, states, actions, greedy_acts=False, is_inference=False):
|
111 |
+
|
112 |
+
preds = torch.zeros(len(states)).to(self.device)
|
113 |
+
|
114 |
+
batch_graph, modified_labels = zip(*states)
|
115 |
+
greedy_actions = []
|
116 |
+
with torch.set_grad_enabled(mode=not is_inference):
|
117 |
+
|
118 |
+
for i in range(len(batch_graph)):
|
119 |
+
if batch_graph[i] is None:
|
120 |
+
continue
|
121 |
+
adj = self.norm_tool.norm_extra(batch_graph[i].get_extra_adj(self.device))
|
122 |
+
# get graph representation
|
123 |
+
graph_embed, node_embed = self.get_graph_embedding(adj)
|
124 |
+
|
125 |
+
# get label representation
|
126 |
+
label_embed = self.get_label_embedding(modified_labels[i])
|
127 |
+
|
128 |
+
# get action representation
|
129 |
+
if time_t != 2:
|
130 |
+
action_embed = node_embed[actions[i]].view(-1, self.embed_dim)
|
131 |
+
else:
|
132 |
+
action_embed = self.get_action_label_encoding(actions[i])
|
133 |
+
|
134 |
+
# concat them and send it to neural network
|
135 |
+
embed_s = torch.cat((graph_embed, label_embed), dim=1)
|
136 |
+
embed_s = embed_s.repeat(len(action_embed), 1)
|
137 |
+
embed_s_a = torch.cat((embed_s, action_embed), dim=1)
|
138 |
+
|
139 |
+
if self.mlp_hidden:
|
140 |
+
embed_s_a = F.relu( self.linear_1(embed_s_a) )
|
141 |
+
|
142 |
+
raw_pred = self.linear_out(embed_s_a)
|
143 |
+
|
144 |
+
if greedy_acts:
|
145 |
+
action_id = raw_pred.argmax(0)
|
146 |
+
raw_pred = raw_pred.max()
|
147 |
+
greedy_actions.append(actions[i][action_id])
|
148 |
+
else:
|
149 |
+
raw_pred = raw_pred.max()
|
150 |
+
# list_pred.append(raw_pred)
|
151 |
+
preds[i] += raw_pred
|
152 |
+
|
153 |
+
|
154 |
+
return greedy_actions, preds
|
155 |
+
|
156 |
+
class NStepQNetNode(nn.Module):
|
157 |
+
|
158 |
+
def __init__(self, num_steps, node_features, node_labels, list_action_space, n_injected, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
|
159 |
+
|
160 |
+
super(NStepQNetNode, self).__init__()
|
161 |
+
self.node_features = node_features
|
162 |
+
self.node_labels = node_labels
|
163 |
+
self.list_action_space = list_action_space
|
164 |
+
self.total_nodes = len(list_action_space)
|
165 |
+
|
166 |
+
list_mod = []
|
167 |
+
for i in range(0, num_steps):
|
168 |
+
# list_mod.append(QNetNode(node_features, node_labels, list_action_space))
|
169 |
+
list_mod.append(QNetNode(node_features, node_labels, list_action_space, n_injected, bilin_q, embed_dim, mlp_hidden, max_lv, gm=gm, device=device))
|
170 |
+
|
171 |
+
self.list_mod = nn.ModuleList(list_mod)
|
172 |
+
self.num_steps = num_steps
|
173 |
+
|
174 |
+
def forward(self, time_t, states, actions, greedy_acts = False, is_inference=False):
|
175 |
+
# print('time_t:', time_t)
|
176 |
+
# print('self.num_step:', self.num_steps)
|
177 |
+
# assert time_t >= 0 and time_t < self.num_steps
|
178 |
+
time_t = time_t % 3
|
179 |
+
return self.list_mod[time_t](time_t, states, actions, greedy_acts, is_inference)
|
180 |
+
|
181 |
+
|
182 |
+
def glorot_uniform(t):
|
183 |
+
if len(t.size()) == 2:
|
184 |
+
fan_in, fan_out = t.size()
|
185 |
+
elif len(t.size()) == 3:
|
186 |
+
# out_ch, in_ch, kernel for Conv 1
|
187 |
+
fan_in = t.size()[1] * t.size()[2]
|
188 |
+
fan_out = t.size()[0] * t.size()[2]
|
189 |
+
else:
|
190 |
+
fan_in = np.prod(t.size())
|
191 |
+
fan_out = np.prod(t.size())
|
192 |
+
|
193 |
+
limit = np.sqrt(6.0 / (fan_in + fan_out))
|
194 |
+
t.uniform_(-limit, limit)
|
195 |
+
|
196 |
+
|
197 |
+
def _param_init(m):
|
198 |
+
if isinstance(m, Parameter):
|
199 |
+
glorot_uniform(m.data)
|
200 |
+
elif isinstance(m, nn.Linear):
|
201 |
+
m.bias.data.zero_()
|
202 |
+
glorot_uniform(m.weight.data)
|
203 |
+
|
204 |
+
def weights_init(m):
|
205 |
+
for p in m.modules():
|
206 |
+
if isinstance(p, nn.ParameterList):
|
207 |
+
for pp in p:
|
208 |
+
_param_init(pp)
|
209 |
+
else:
|
210 |
+
_param_init(p)
|
211 |
+
|
212 |
+
for name, p in m.named_parameters():
|
213 |
+
if not '.' in name: # top-level parameters
|
214 |
+
_param_init(p)
|
215 |
+
|
216 |
+
def node_greedy_actions(target_nodes, picked_nodes, list_q, net):
|
217 |
+
assert len(target_nodes) == len(list_q)
|
218 |
+
|
219 |
+
actions = []
|
220 |
+
values = []
|
221 |
+
for i in range(len(target_nodes)):
|
222 |
+
region = net.list_action_space[target_nodes[i]]
|
223 |
+
if picked_nodes is not None and picked_nodes[i] is not None:
|
224 |
+
region = net.list_action_space[picked_nodes[i]]
|
225 |
+
if region is None:
|
226 |
+
assert list_q[i].size()[0] == net.total_nodes
|
227 |
+
else:
|
228 |
+
assert len(region) == list_q[i].size()[0]
|
229 |
+
|
230 |
+
val, act = torch.max(list_q[i], dim=0)
|
231 |
+
values.append(val)
|
232 |
+
if region is not None:
|
233 |
+
act = region[act.data.cpu().numpy()[0]]
|
234 |
+
# act = Variable(torch.LongTensor([act]))
|
235 |
+
act = torch.LongTensor([act])
|
236 |
+
actions.append(act)
|
237 |
+
else:
|
238 |
+
actions.append(act)
|
239 |
+
|
240 |
+
return torch.cat(actions, dim=0).data, torch.cat(values, dim=0).data
|
241 |
+
|
242 |
+
|
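glorot_uniform() above draws from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)); a quick standalone check, assuming the function is imported from this module:

import numpy as np
import torch
from deeprobust.graph.rl.nipa_q_net_node import glorot_uniform

w = torch.empty(64, 64)
glorot_uniform(w)                         # in-place re-initialization
limit = np.sqrt(6.0 / (64 + 64))          # fan_in + fan_out for a 2-D tensor
assert float(w.abs().max()) <= limit      # every entry is inside the Glorot bound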
deeprobust/graph/rl/nstep_replay_mem.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
'''
|
2 |
+
This part of code is adopted from https://github.com/Hanjun-Dai/graph_adversarial_attack (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le)
|
3 |
+
but modified to be integrated into the repository.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class NstepReplaySubMemCell(object):
|
10 |
+
def __init__(self, memory_size):
|
11 |
+
self.memory_size = memory_size
|
12 |
+
|
13 |
+
self.actions = [None] * self.memory_size
|
14 |
+
self.rewards = [None] * self.memory_size
|
15 |
+
self.states = [None] * self.memory_size
|
16 |
+
self.s_primes = [None] * self.memory_size
|
17 |
+
self.terminals = [None] * self.memory_size
|
18 |
+
|
19 |
+
self.count = 0
|
20 |
+
self.current = 0
|
21 |
+
|
22 |
+
def add(self, s_t, a_t, r_t, s_prime, terminal):
|
23 |
+
self.actions[self.current] = a_t
|
24 |
+
self.rewards[self.current] = r_t
|
25 |
+
self.states[self.current] = s_t
|
26 |
+
self.s_primes[self.current] = s_prime
|
27 |
+
self.terminals[self.current] = terminal
|
28 |
+
|
29 |
+
self.count = max(self.count, self.current + 1)
|
30 |
+
self.current = (self.current + 1) % self.memory_size
|
31 |
+
|
32 |
+
def add_list(self, list_st, list_at, list_rt, list_sp, list_term):
|
33 |
+
for i in range(len(list_st)):
|
34 |
+
if list_sp is None:
|
35 |
+
sp = (None, None, None)
|
36 |
+
else:
|
37 |
+
sp = list_sp[i]
|
38 |
+
self.add(list_st[i], list_at[i], list_rt[i], sp, list_term[i])
|
39 |
+
|
40 |
+
def sample(self, batch_size):
|
41 |
+
|
42 |
+
assert self.count >= batch_size
|
43 |
+
list_st = []
|
44 |
+
list_at = []
|
45 |
+
list_rt = []
|
46 |
+
list_s_primes = []
|
47 |
+
list_term = []
|
48 |
+
|
49 |
+
for i in range(batch_size):
|
50 |
+
idx = random.randint(0, self.count - 1)
|
51 |
+
list_st.append(self.states[idx])
|
52 |
+
list_at.append(self.actions[idx])
|
53 |
+
list_rt.append(float(self.rewards[idx]))
|
54 |
+
list_s_primes.append(self.s_primes[idx])
|
55 |
+
list_term.append(self.terminals[idx])
|
56 |
+
|
57 |
+
return list_st, list_at, list_rt, list_s_primes, list_term
|
58 |
+
|
59 |
+
def hash_state_action(s_t, a_t):
|
60 |
+
key = s_t[0]
|
61 |
+
base, mod = 179424673, (1 << 61) - 1  # multiplier and modulus must differ, else (key*base + x) % base collapses to x % base
|
62 |
+
for e in s_t[1].directed_edges:
|
63 |
+
key = (key * base + e[0]) % mod
|
64 |
+
key = (key * base + e[1]) % mod
|
65 |
+
if s_t[2] is not None:
|
66 |
+
key = (key * base + s_t[2]) % mod
|
67 |
+
else:
|
68 |
+
key = (key * base) % mod
|
69 |
+
|
70 |
+
key = (key * base + a_t) % mod
|
71 |
+
return key
|
72 |
+
|
73 |
+
def nipa_hash_state_action(s_t, a_t):
|
74 |
+
key = s_t[0]
|
75 |
+
base, mod = 179424673, (1 << 61) - 1  # multiplier and modulus must differ, else (key*base + x) % base collapses to x % base
|
76 |
+
for e in s_t[1].directed_edges:
|
77 |
+
key = (key * base + e[0]) % mod
|
78 |
+
key = (key * base + e[1]) % mod
|
79 |
+
if s_t[2] is not None:
|
80 |
+
key = (key * base + s_t[2]) % mod
|
81 |
+
else:
|
82 |
+
key = (key * base) % mod
|
83 |
+
|
84 |
+
key = (key * base + a_t) % mod
|
85 |
+
return key
|
86 |
+
|
87 |
+
class NstepReplayMemCell(object):
|
88 |
+
def __init__(self, memory_size, balance_sample = False):
|
89 |
+
self.sub_list = []
|
90 |
+
self.balance_sample = balance_sample
|
91 |
+
self.sub_list.append(NstepReplaySubMemCell(memory_size))
|
92 |
+
if balance_sample:
|
93 |
+
self.sub_list.append(NstepReplaySubMemCell(memory_size))
|
94 |
+
self.state_set = set()
|
95 |
+
|
96 |
+
def add(self, s_t, a_t, r_t, s_prime, terminal, use_hash=True):
|
97 |
+
if not self.balance_sample or r_t < 0:
|
98 |
+
self.sub_list[0].add(s_t, a_t, r_t, s_prime, terminal)
|
99 |
+
else:
|
100 |
+
assert r_t > 0
|
101 |
+
if use_hash:
|
102 |
+
# TODO add hash?
|
103 |
+
key = hash_state_action(s_t, a_t)
|
104 |
+
if key in self.state_set:
|
105 |
+
return
|
106 |
+
self.state_set.add(key)
|
107 |
+
self.sub_list[1].add(s_t, a_t, r_t, s_prime, terminal)
|
108 |
+
|
109 |
+
def sample(self, batch_size):
|
110 |
+
if not self.balance_sample or self.sub_list[1].count < batch_size:
|
111 |
+
return self.sub_list[0].sample(batch_size)
|
112 |
+
|
113 |
+
list_st, list_at, list_rt, list_s_primes, list_term = self.sub_list[0].sample(batch_size // 2)
|
114 |
+
list_st2, list_at2, list_rt2, list_s_primes2, list_term2 = self.sub_list[1].sample(batch_size - batch_size // 2)
|
115 |
+
|
116 |
+
return list_st + list_st2, list_at + list_at2, list_rt + list_rt2, list_s_primes + list_s_primes2, list_term + list_term2
|
117 |
+
|
118 |
+
class NstepReplayMem(object):
|
119 |
+
def __init__(self, memory_size, n_steps, balance_sample=False, model='rl_s2v'):
|
120 |
+
self.mem_cells = []
|
121 |
+
for i in range(n_steps - 1):
|
122 |
+
self.mem_cells.append(NstepReplayMemCell(memory_size, False))
|
123 |
+
self.mem_cells.append(NstepReplayMemCell(memory_size, balance_sample))
|
124 |
+
|
125 |
+
self.n_steps = n_steps
|
126 |
+
self.memory_size = memory_size
|
127 |
+
self.model = model
|
128 |
+
|
129 |
+
def add(self, s_t, a_t, r_t, s_prime, terminal, t):
|
130 |
+
assert t >= 0 and t < self.n_steps
|
131 |
+
if self.model == 'nipa':
|
132 |
+
self.mem_cells[t].add(s_t, a_t, r_t, s_prime, terminal, use_hash=False)
|
133 |
+
else:
|
134 |
+
if t == self.n_steps - 1:
|
135 |
+
assert terminal
|
136 |
+
else:
|
137 |
+
assert not terminal
|
138 |
+
self.mem_cells[t].add(s_t, a_t, r_t, s_prime, terminal, use_hash=True)
|
139 |
+
|
140 |
+
def add_list(self, list_st, list_at, list_rt, list_sp, list_term, t):
|
141 |
+
for i in range(len(list_st)):
|
142 |
+
if list_sp is None:
|
143 |
+
sp = (None, None, None)
|
144 |
+
else:
|
145 |
+
sp = list_sp[i]
|
146 |
+
self.add(list_st[i], list_at[i], list_rt[i], sp, list_term[i], t)
|
147 |
+
|
148 |
+
def sample(self, batch_size, t = None):
|
149 |
+
if t is None:
|
150 |
+
t = np.random.randint(self.n_steps)
|
151 |
+
list_st, list_at, list_rt, list_s_primes, list_term = self.mem_cells[t].sample(batch_size)
|
152 |
+
return t, list_st, list_at, list_rt, list_s_primes, list_term
|
153 |
+
|
154 |
+
def print_count(self):
|
155 |
+
for i in range(self.n_steps):
|
156 |
+
for j, cell in enumerate(self.mem_cells[i].sub_list):
|
157 |
+
print('Cell {} sub_list {}: {}'.format(i, j, cell.count))
|
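NstepReplayMemCell above deduplicates positive-reward transitions by hashing the (state, action) pair; a toy check of that contract, using a stand-in object for the ModifiedGraph a state normally carries:

from types import SimpleNamespace
from deeprobust.graph.rl.nstep_replay_mem import hash_state_action

g = SimpleNamespace(directed_edges=[(0, 1), (2, 3)])   # stand-in for a ModifiedGraph
s_t = (5, g, None)                                      # (target node, graph, first picked node)
k1 = hash_state_action(s_t, a_t=7)
k2 = hash_state_action(s_t, a_t=7)
k3 = hash_state_action(s_t, a_t=8)
print(k1 == k2, k1 == k3)    # True False -- repeats of the same (s, a) are dropped by add()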
deeprobust/graph/rl/rl_s2v_env.py
ADDED
@@ -0,0 +1,256 @@
1 |
+
"""
|
2 |
+
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
|
3 |
+
https://arxiv.org/abs/1806.02371
|
4 |
+
Author's Implementation
|
5 |
+
https://github.com/Hanjun-Dai/graph_adversarial_attack
|
6 |
+
This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
|
7 |
+
to be integrated into the repository.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import networkx as nx
|
15 |
+
import random
|
16 |
+
from torch.nn.parameter import Parameter
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import torch.optim as optim
|
20 |
+
from tqdm import tqdm
|
21 |
+
from copy import deepcopy
|
22 |
+
import pickle as cp
|
23 |
+
from deeprobust.graph.utils import *
|
24 |
+
import scipy.sparse as sp
|
25 |
+
from scipy.sparse.linalg import eigsh  # modern import path; scipy.sparse.linalg.eigen.arpack is removed in recent SciPy
|
26 |
+
from deeprobust.graph import utils
|
27 |
+
|
28 |
+
class StaticGraph(object):
|
29 |
+
graph = None
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def get_gsize():
|
33 |
+
return torch.Size( (len(StaticGraph.graph), len(StaticGraph.graph)) )
|
34 |
+
|
35 |
+
class GraphNormTool(object):
|
36 |
+
|
37 |
+
def __init__(self, normalize, gm, device):
|
38 |
+
self.adj_norm = normalize
|
39 |
+
self.gm = gm
|
40 |
+
g = StaticGraph.graph
|
41 |
+
edges = np.array(g.edges(), dtype=np.int64)
|
42 |
+
rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
|
43 |
+
|
44 |
+
# self_edges = np.array([range(len(g)), range(len(g))], dtype=np.int64)
|
45 |
+
# edges = np.hstack((edges.T, rev_edges, self_edges))
|
46 |
+
edges = np.hstack((edges.T, rev_edges))
|
47 |
+
idxes = torch.LongTensor(edges)
|
48 |
+
values = torch.ones(idxes.size()[1])
|
49 |
+
|
50 |
+
self.raw_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
|
51 |
+
self.raw_adj = self.raw_adj.to(device)
|
52 |
+
|
53 |
+
self.normed_adj = self.raw_adj.clone()
|
54 |
+
if self.adj_norm:
|
55 |
+
if self.gm == 'gcn':
|
56 |
+
self.normed_adj = utils.normalize_adj_tensor(self.normed_adj, sparse=True)
|
57 |
+
# GraphLaplacianNorm(self.normed_adj)
|
58 |
+
else:
|
59 |
+
|
60 |
+
self.normed_adj = utils.degree_normalize_adj_tensor(self.normed_adj, sparse=True)
|
61 |
+
# GraphDegreeNorm(self.normed_adj)
|
62 |
+
|
63 |
+
def norm_extra(self, added_adj = None):
|
64 |
+
if added_adj is None:
|
65 |
+
return self.normed_adj
|
66 |
+
|
67 |
+
new_adj = self.raw_adj + added_adj
|
68 |
+
if self.adj_norm:
|
69 |
+
if self.gm == 'gcn':
|
70 |
+
new_adj = utils.normalize_adj_tensor(new_adj, sparse=True)
|
71 |
+
else:
|
72 |
+
new_adj = utils.degree_normalize_adj_tensor(new_adj, sparse=True)
|
73 |
+
|
74 |
+
return new_adj
|
75 |
+
|
76 |
+
|
77 |
+
class ModifiedGraph(object):
|
78 |
+
def __init__(self, directed_edges = None, weights = None):
|
79 |
+
self.edge_set = set() #(first, second)
|
80 |
+
# node_set is kept as a numpy array so np.setdiff1d in get_possible_nodes works directly
|
81 |
+
self.node_set = np.arange(StaticGraph.get_gsize()[0])
|
82 |
+
if directed_edges is not None:
|
83 |
+
self.directed_edges = deepcopy(directed_edges)
|
84 |
+
self.weights = deepcopy(weights)
|
85 |
+
else:
|
86 |
+
self.directed_edges = []
|
87 |
+
self.weights = []
|
88 |
+
|
89 |
+
def add_edge(self, x, y, z):
|
90 |
+
assert x is not None and y is not None
|
91 |
+
if x == y:
|
92 |
+
return
|
93 |
+
for e in self.directed_edges:
|
94 |
+
if e[0] == x and e[1] == y:
|
95 |
+
return
|
96 |
+
if e[1] == x and e[0] == y:
|
97 |
+
return
|
98 |
+
self.edge_set.add((x, y)) # (first, second)
|
99 |
+
self.edge_set.add((y, x)) # (second, first)
|
100 |
+
self.directed_edges.append((x, y))
|
101 |
+
# assert z < 0
|
102 |
+
self.weights.append(z)
|
103 |
+
|
104 |
+
def get_extra_adj(self, device):
|
105 |
+
if len(self.directed_edges):
|
106 |
+
edges = np.array(self.directed_edges, dtype=np.int64)
|
107 |
+
rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
|
108 |
+
edges = np.hstack((edges.T, rev_edges))
|
109 |
+
|
110 |
+
idxes = torch.LongTensor(edges)
|
111 |
+
values = torch.Tensor(self.weights + self.weights)
|
112 |
+
|
113 |
+
added_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
|
114 |
+
|
115 |
+
added_adj = added_adj.to(device)
|
116 |
+
return added_adj
|
117 |
+
else:
|
118 |
+
return None
|
119 |
+
|
120 |
+
def get_possible_nodes(self, target_node):
|
121 |
+
# collect nodes already connected to target_node so they can be excluded
|
122 |
+
connected = []
|
123 |
+
for n1, n2 in self.edge_set:
|
124 |
+
if n1 == target_node:
|
125 |
+
# connected.add(target_node)
|
126 |
+
connected.append(n2)  # n2 is the neighbor already linked to target_node
|
127 |
+
return np.setdiff1d(self.node_set, np.array(connected))
|
128 |
+
# return self.node_set - connected
|
129 |
+
|
130 |
+
class NodeAttackEnv(object):
|
131 |
+
"""Node attack environment. It executes an action and then change the
|
132 |
+
environment status (modify the graph).
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, features, labels, all_targets, list_action_space, classifier, num_mod=1, reward_type='binary'):
|
136 |
+
|
137 |
+
self.classifier = classifier
|
138 |
+
self.list_action_space = list_action_space
|
139 |
+
self.features = features
|
140 |
+
self.labels = labels
|
141 |
+
self.all_targets = all_targets
|
142 |
+
self.num_mod = num_mod
|
143 |
+
self.reward_type = reward_type
|
144 |
+
|
145 |
+
def setup(self, target_nodes):
|
146 |
+
self.target_nodes = target_nodes
|
147 |
+
self.n_steps = 0
|
148 |
+
self.first_nodes = None
|
149 |
+
self.rewards = None
|
150 |
+
self.binary_rewards = None
|
151 |
+
self.modified_list = []
|
152 |
+
for i in range(len(self.target_nodes)):
|
153 |
+
self.modified_list.append(ModifiedGraph())
|
154 |
+
|
155 |
+
self.list_acc_of_all = []
|
156 |
+
|
157 |
+
def step(self, actions):
|
158 |
+
"""run actions and get rewards
|
159 |
+
"""
|
160 |
+
if self.first_nodes is None: # pick the first node of edge
|
161 |
+
assert self.n_steps % 2 == 0
|
162 |
+
self.first_nodes = actions[:]
|
163 |
+
else:
|
164 |
+
for i in range(len(self.target_nodes)):
|
165 |
+
# assert self.first_nodes[i] != actions[i]
|
166 |
+
# delete an edge from the graph
|
167 |
+
self.modified_list[i].add_edge(self.first_nodes[i], actions[i], -1.0)
|
168 |
+
self.first_nodes = None
|
169 |
+
self.banned_list = None
|
170 |
+
self.n_steps += 1
|
171 |
+
|
172 |
+
if self.isTerminal():
|
173 |
+
# only calc reward when its terminal
|
174 |
+
acc_list = []
|
175 |
+
loss_list = []
|
176 |
+
# for i in tqdm(range(len(self.target_nodes))):
|
177 |
+
for i in (range(len(self.target_nodes))):
|
178 |
+
device = self.labels.device
|
179 |
+
extra_adj = self.modified_list[i].get_extra_adj(device=device)
|
180 |
+
adj = self.classifier.norm_tool.norm_extra(extra_adj)
|
181 |
+
|
182 |
+
output = self.classifier(self.features, adj)
|
183 |
+
|
184 |
+
loss, acc = loss_acc(output, self.labels, self.all_targets, avg_loss=False)
|
185 |
+
# _, loss, acc = self.classifier(self.features, Variable(adj), self.all_targets, self.labels, avg_loss=False)
|
186 |
+
|
187 |
+
cur_idx = self.all_targets.index(self.target_nodes[i])
|
188 |
+
acc = np.copy(acc.double().cpu().view(-1).numpy())
|
189 |
+
loss = loss.data.cpu().view(-1).numpy()
|
190 |
+
self.list_acc_of_all.append(acc)
|
191 |
+
acc_list.append(acc[cur_idx])
|
192 |
+
loss_list.append(loss[cur_idx])
|
193 |
+
|
194 |
+
self.binary_rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
|
195 |
+
if self.reward_type == 'binary':
|
196 |
+
self.rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
|
197 |
+
else:
|
198 |
+
assert self.reward_type == 'nll'
|
199 |
+
self.rewards = np.array(loss_list).astype(np.float32)
|
200 |
+
|
201 |
+
def sample_pos_rewards(self, num_samples):
|
202 |
+
assert self.list_acc_of_all is not None
|
203 |
+
cands = []
|
204 |
+
|
205 |
+
for i in range(len(self.list_acc_of_all)):
|
206 |
+
succ = np.where( self.list_acc_of_all[i] < 0.9 )[0]
|
207 |
+
|
208 |
+
for j in range(len(succ)):
|
209 |
+
|
210 |
+
cands.append((i, self.all_targets[succ[j]]))
|
211 |
+
|
212 |
+
if num_samples > len(cands):
|
213 |
+
return cands
|
214 |
+
random.shuffle(cands)
|
215 |
+
return cands[0:num_samples]
|
216 |
+
|
217 |
+
def uniformRandActions(self):
|
218 |
+
# TODO: here only support deleting edges
|
219 |
+
# seems they sample first node from 2-hop neighbours
|
220 |
+
act_list = []
|
221 |
+
offset = 0
|
222 |
+
for i in range(len(self.target_nodes)):
|
223 |
+
cur_node = self.target_nodes[i]
|
224 |
+
region = self.list_action_space[cur_node]
|
225 |
+
|
226 |
+
if self.first_nodes is not None and self.first_nodes[i] is not None:
|
227 |
+
region = self.list_action_space[self.first_nodes[i]]
|
228 |
+
|
229 |
+
if region is None: # singleton node
|
230 |
+
cur_action = np.random.randint(len(self.list_action_space))
|
231 |
+
else: # select from neighbours or 2-hop neighbours
|
232 |
+
cur_action = region[np.random.randint(len(region))]
|
233 |
+
|
234 |
+
act_list.append(cur_action)
|
235 |
+
return act_list
|
236 |
+
|
237 |
+
def isTerminal(self):
|
238 |
+
if self.n_steps == 2 * self.num_mod:
|
239 |
+
return True
|
240 |
+
return False
|
241 |
+
|
242 |
+
def getStateRef(self):
|
243 |
+
cp_first = [None] * len(self.target_nodes)
|
244 |
+
if self.first_nodes is not None:
|
245 |
+
cp_first = self.first_nodes
|
246 |
+
|
247 |
+
return zip(self.target_nodes, self.modified_list, cp_first)
|
248 |
+
|
249 |
+
def cloneState(self):
|
250 |
+
cp_first = [None] * len(self.target_nodes)
|
251 |
+
if self.first_nodes is not None:
|
252 |
+
cp_first = self.first_nodes[:]
|
253 |
+
|
254 |
+
return list(zip(self.target_nodes[:], deepcopy(self.modified_list), cp_first))
|
255 |
+
|
256 |
+
|
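A minimal sketch of how the pieces above fit together: StaticGraph.graph must point at the clean networkx graph before GraphNormTool is constructed, and a ModifiedGraph records deletions as -1-weighted entries that get_extra_adj() turns into a sparse delta on top of the raw adjacency (toy graph; all names local to the example):

import networkx as nx

StaticGraph.graph = nx.path_graph(4)        # clean graph shared by all episode states
tool = GraphNormTool(normalize=True, gm='gcn', device='cpu')

mod = ModifiedGraph()
mod.add_edge(1, 2, -1.0)                    # schedule deletion of edge (1, 2)
delta = mod.get_extra_adj(device='cpu')     # sparse delta: -1 at (1, 2) and (2, 1)
adj = tool.norm_extra(delta)                # renormalized adjacency after the edit
print(adj.to_dense()[1, 2])                 # ~0: the edge is gone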
deeprobust/graph/targeted_attack/rl_s2v.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
"""
|
2 |
+
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
|
3 |
+
https://arxiv.org/abs/1806.02371
|
4 |
+
Author's Implementation
|
5 |
+
https://github.com/Hanjun-Dai/graph_adversarial_attack
|
6 |
+
This part of the code is adapted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le)
|
7 |
+
but modified to be integrated into the repository.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import os.path as osp
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import networkx as nx
|
16 |
+
import random
|
17 |
+
from torch.nn.parameter import Parameter
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torch.optim as optim
|
21 |
+
from tqdm import tqdm
|
22 |
+
from copy import deepcopy
|
23 |
+
from deeprobust.graph.rl.q_net_node import QNetNode, NStepQNetNode, node_greedy_actions
|
24 |
+
from deeprobust.graph.rl.env import NodeAttackEnv
|
25 |
+
from deeprobust.graph.rl.nstep_replay_mem import NstepReplayMem
|
26 |
+
|
27 |
+
class RLS2V(object):
|
28 |
+
""" Reinforcement learning agent for RL-S2V attack.
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
env :
|
33 |
+
Node attack environment
|
34 |
+
features :
|
35 |
+
node features matrix
|
36 |
+
labels :
|
37 |
+
labels
|
38 |
+
idx_meta :
|
39 |
+
node meta indices
|
40 |
+
idx_test :
|
41 |
+
node test indices
|
42 |
+
list_action_space : list
|
43 |
+
list of action space
|
44 |
+
num_mod :
|
45 |
+
number of modifications (perturbations) on the graph
|
46 |
+
reward_type : str
|
47 |
+
type of reward (e.g., 'binary')
|
48 |
+
batch_size :
|
49 |
+
batch size for training DQN
|
50 |
+
save_dir :
|
51 |
+
saving directory for model checkpoints
|
52 |
+
device: str
|
53 |
+
'cpu' or 'cuda'
|
54 |
+
|
55 |
+
Examples
|
56 |
+
--------
|
57 |
+
See details in https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_rl_s2v.py
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, env, features, labels, idx_meta, idx_test,
|
61 |
+
list_action_space, num_mod, reward_type, batch_size=10,
|
62 |
+
num_wrong=0, bilin_q=1, embed_dim=64, gm='mean_field',
|
63 |
+
mlp_hidden=64, max_lv=1, save_dir='checkpoint_dqn', device=None):
|
64 |
+
|
65 |
+
|
66 |
+
assert device is not None, "'device' cannot be None, please specify it"
|
67 |
+
|
68 |
+
self.features = features
|
69 |
+
self.labels = labels
|
70 |
+
self.idx_meta = idx_meta
|
71 |
+
self.idx_test = idx_test
|
72 |
+
self.num_wrong = num_wrong
|
73 |
+
self.list_action_space = list_action_space
|
74 |
+
self.num_mod = num_mod
|
75 |
+
self.reward_type = reward_type
|
76 |
+
self.batch_size = batch_size
|
77 |
+
self.save_dir = save_dir
|
78 |
+
if not osp.exists(save_dir):
|
79 |
+
os.makedirs(save_dir, exist_ok=True)
|
80 |
+
|
81 |
+
self.gm = gm
|
82 |
+
self.device = device
|
83 |
+
|
84 |
+
self.mem_pool = NstepReplayMem(memory_size=500000, n_steps=2 * num_mod, balance_sample=reward_type == 'binary')
|
85 |
+
self.env = env
|
86 |
+
|
87 |
+
# self.net = QNetNode(features, labels, list_action_space)
|
88 |
+
# self.old_net = QNetNode(features, labels, list_action_space)
|
89 |
+
self.net = NStepQNetNode(2 * num_mod, features, labels, list_action_space,
|
90 |
+
bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
|
91 |
+
max_lv=max_lv, gm=gm, device=device)
|
92 |
+
|
93 |
+
self.old_net = NStepQNetNode(2 * num_mod, features, labels, list_action_space,
|
94 |
+
bilin_q=bilin_q, embed_dim=embed_dim, mlp_hidden=mlp_hidden,
|
95 |
+
max_lv=max_lv, gm=gm, device=device)
|
96 |
+
|
97 |
+
self.net = self.net.to(device)
|
98 |
+
self.old_net = self.old_net.to(device)
|
99 |
+
|
100 |
+
self.eps_start = 1.0
|
101 |
+
self.eps_end = 0.05
|
102 |
+
self.eps_step = 100000
|
103 |
+
self.burn_in = 10
|
104 |
+
self.step = 0
|
105 |
+
self.pos = 0
|
106 |
+
self.best_eval = None
|
107 |
+
self.take_snapshot()
|
108 |
+
|
109 |
+
def take_snapshot(self):
|
110 |
+
self.old_net.load_state_dict(self.net.state_dict())
|
111 |
+
|
112 |
+
def make_actions(self, time_t, greedy=False):
|
113 |
+
self.eps = self.eps_end + max(0., (self.eps_start - self.eps_end)
|
114 |
+
* (self.eps_step - max(0., self.step)) / self.eps_step)
|
115 |
+
|
116 |
+
if random.random() < self.eps and not greedy:
|
117 |
+
actions = self.env.uniformRandActions()
|
118 |
+
else:
|
119 |
+
cur_state = self.env.getStateRef()
|
120 |
+
actions, values = self.net(time_t, cur_state, None, greedy_acts=True, is_inference=True)
|
121 |
+
actions = list(actions.cpu().numpy())
|
122 |
+
|
123 |
+
return actions
|
124 |
+
|
125 |
+
def run_simulation(self):
|
126 |
+
|
127 |
+
if (self.pos + 1) * self.batch_size > len(self.idx_test):
|
128 |
+
self.pos = 0
|
129 |
+
random.shuffle(self.idx_test)
|
130 |
+
|
131 |
+
selected_idx = self.idx_test[self.pos * self.batch_size : (self.pos + 1) * self.batch_size]
|
132 |
+
self.pos += 1
|
133 |
+
self.env.setup(selected_idx)
|
134 |
+
|
135 |
+
t = 0
|
136 |
+
list_of_list_st = []
|
137 |
+
list_of_list_at = []
|
138 |
+
|
139 |
+
while not self.env.isTerminal():
|
140 |
+
list_at = self.make_actions(t)
|
141 |
+
list_st = self.env.cloneState()
|
142 |
+
|
143 |
+
self.env.step(list_at)
|
144 |
+
|
145 |
+
# TODO Wei added line #87
|
146 |
+
env = self.env
|
147 |
+
assert (env.rewards is not None) == env.isTerminal()
|
148 |
+
if env.isTerminal():
|
149 |
+
rewards = env.rewards
|
150 |
+
s_prime = None
|
151 |
+
else:
|
152 |
+
rewards = np.zeros(len(list_at), dtype=np.float32)
|
153 |
+
s_prime = self.env.cloneState()
|
154 |
+
|
155 |
+
self.mem_pool.add_list(list_st, list_at, rewards, s_prime, [env.isTerminal()] * len(list_at), t)
|
156 |
+
list_of_list_st.append( deepcopy(list_st) )
|
157 |
+
list_of_list_at.append( deepcopy(list_at) )
|
158 |
+
t += 1
|
159 |
+
|
160 |
+
# if the reward type is nll_loss, directly return
|
161 |
+
if self.reward_type == 'nll':
|
162 |
+
return
|
163 |
+
|
164 |
+
T = t
|
165 |
+
cands = self.env.sample_pos_rewards(len(selected_idx))
|
166 |
+
if len(cands):
|
167 |
+
for c in cands:
|
168 |
+
sample_idx, target = c
|
169 |
+
doable = True
|
170 |
+
for t in range(T):
|
171 |
+
if self.list_action_space[target] is not None and (not list_of_list_at[t][sample_idx] in self.list_action_space[target]):
|
172 |
+
doable = False # TODO WHY False? This is only 1-hop neighbour
|
173 |
+
break
|
174 |
+
if not doable:
|
175 |
+
continue
|
176 |
+
|
177 |
+
for t in range(T):
|
178 |
+
s_t = list_of_list_st[t][sample_idx]
|
179 |
+
a_t = list_of_list_at[t][sample_idx]
|
180 |
+
s_t = [target, deepcopy(s_t[1]), s_t[2]]
|
181 |
+
if t + 1 == T:
|
182 |
+
s_prime = (None, None, None)
|
183 |
+
r = 1.0
|
184 |
+
term = True
|
185 |
+
else:
|
186 |
+
s_prime = list_of_list_st[t + 1][sample_idx]
|
187 |
+
s_prime = [target, deepcopy(s_prime[1]), s_prime[2]]
|
188 |
+
r = 0.0
|
189 |
+
term = False
|
190 |
+
self.mem_pool.mem_cells[t].add(s_t, a_t, r, s_prime, term)
|
191 |
+
|
192 |
+
def eval(self, training=True):
|
193 |
+
"""Evaluate RL agent.
|
194 |
+
"""
|
195 |
+
|
196 |
+
self.env.setup(self.idx_meta)
|
197 |
+
t = 0
|
198 |
+
|
199 |
+
while not self.env.isTerminal():
|
200 |
+
list_at = self.make_actions(t, greedy=True)
|
201 |
+
self.env.step(list_at)
|
202 |
+
t += 1
|
203 |
+
|
204 |
+
acc = 1 - (self.env.binary_rewards + 1.0) / 2.0
|
205 |
+
acc = np.sum(acc) / (len(self.idx_meta) + self.num_wrong)
|
206 |
+
print('\033[93m average test: acc %.5f\033[0m' % (acc))
|
207 |
+
|
208 |
+
if training and (self.best_eval is None or acc < self.best_eval):
|
209 |
+
print('----saving to best attacker since this is the best attack rate so far.----')
|
210 |
+
torch.save(self.net.state_dict(), osp.join(self.save_dir, 'epoch-best.model'))
|
211 |
+
with open(osp.join(self.save_dir, 'epoch-best.txt'), 'w') as f:
|
212 |
+
f.write('%.4f\n' % acc)
|
213 |
+
with open(osp.join(self.save_dir, 'attack_solution.txt'), 'w') as f:
|
214 |
+
for i in range(len(self.idx_meta)):
|
215 |
+
f.write('%d: [' % self.idx_meta[i])
|
216 |
+
for e in self.env.modified_list[i].directed_edges:
|
217 |
+
f.write('(%d %d)' % e)
|
218 |
+
f.write('] succ: %d\n' % (self.env.binary_rewards[i]))
|
219 |
+
self.best_eval = acc
|
220 |
+
|
221 |
+
def train(self, num_steps=100000, lr=0.001):
|
222 |
+
"""Train RL agent.
|
223 |
+
"""
|
224 |
+
|
225 |
+
pbar = tqdm(range(self.burn_in), unit='batch')
|
226 |
+
|
227 |
+
for p in pbar:
|
228 |
+
self.run_simulation()
|
229 |
+
|
230 |
+
pbar = tqdm(range(num_steps), unit='steps')
|
231 |
+
optimizer = optim.Adam(self.net.parameters(), lr=lr)
|
232 |
+
|
233 |
+
for self.step in pbar:
|
234 |
+
|
235 |
+
self.run_simulation()
|
236 |
+
|
237 |
+
if self.step % 123 == 0:
|
238 |
+
# update the params of old_net
|
239 |
+
self.take_snapshot()
|
240 |
+
if self.step % 500 == 0:
|
241 |
+
self.eval()
|
242 |
+
|
243 |
+
cur_time, list_st, list_at, list_rt, list_s_primes, list_term = self.mem_pool.sample(batch_size=self.batch_size)
|
244 |
+
list_target = torch.Tensor(list_rt).to(self.device)
|
245 |
+
|
246 |
+
if not list_term[0]:
|
247 |
+
target_nodes, _, picked_nodes = zip(*list_s_primes)
|
248 |
+
_, q_t_plus_1 = self.old_net(cur_time + 1, list_s_primes, None)
|
249 |
+
_, q_rhs = node_greedy_actions(target_nodes, picked_nodes, q_t_plus_1, self.old_net)
|
250 |
+
list_target += q_rhs
|
251 |
+
|
252 |
+
# list_target = Variable(list_target.view(-1, 1))
|
253 |
+
list_target = list_target.view(-1, 1)
|
254 |
+
_, q_sa = self.net(cur_time, list_st, list_at)
|
255 |
+
q_sa = torch.cat(q_sa, dim=0)
|
256 |
+
loss = F.mse_loss(q_sa, list_target)
|
257 |
+
optimizer.zero_grad()
|
258 |
+
loss.backward()
|
259 |
+
optimizer.step()
|
260 |
+
pbar.set_description('eps: %.5f, loss: %0.5f, q_val: %.5f' % (self.eps, loss, torch.mean(q_sa)) )
|
261 |
+
# print('eps: %.5f, loss: %0.5f, q_val: %.5f' % (self.eps, loss, torch.mean(q_sa)) )
|
262 |
+
|
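For reference, `make_actions` above follows a linear epsilon-greedy schedule: exploration decays from `eps_start` to `eps_end` over the first `eps_step` training steps. A self-contained sketch of that schedule with the same constants as in `__init__`:

eps_start, eps_end, eps_step = 1.0, 0.05, 100000

def epsilon(step):
    # linear decay, clamped at eps_end once step exceeds eps_step
    return eps_end + max(0.0, (eps_start - eps_end)
                         * (eps_step - max(0.0, step)) / eps_step)

print(epsilon(0), epsilon(50000), epsilon(200000))  # 1.0 0.525 0.05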
deeprobust/graph/visualization.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
import seaborn as sns; sns.set()
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from tqdm import tqdm
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import scipy.sparse as sp
|
7 |
+
|
8 |
+
def degree_dist(clean_adj, perturbed_adj, savename='degree_dist.pdf'):
|
9 |
+
"""Plot degree distributnio on clean and perturbed graphs.
|
10 |
+
|
11 |
+
Parameters
|
12 |
+
----------
|
13 |
+
clean_adj: sp.csr_matrix
|
14 |
+
adjacency matrix of the clean graph
|
15 |
+
perturbed_adj: sp.csr_matrix
|
16 |
+
adjacency matrix of the perturbed graph
|
17 |
+
savename: str
|
18 |
+
filename to be saved
|
19 |
+
|
20 |
+
Returns
|
21 |
+
-------
|
22 |
+
None
|
23 |
+
|
24 |
+
"""
|
25 |
+
clean_degree = clean_adj.sum(1)
|
26 |
+
perturbed_degree = perturbed_adj.sum(1)
|
27 |
+
fig, ax1 = plt.subplots()
|
28 |
+
sns.distplot(clean_degree, label='Clean Graph', norm_hist=False, ax=ax1)
|
29 |
+
sns.distplot(perturbed_degree, label='Perturbed Graph', norm_hist=False, ax=ax1)
|
30 |
+
ax1.grid(False)
|
31 |
+
plt.legend(prop={'size':18})
|
32 |
+
plt.ylabel('Density Distribution', fontsize=18)
|
33 |
+
plt.xlabel('Node degree', fontsize=18)
|
34 |
+
plt.xticks(fontsize=14)
|
35 |
+
plt.yticks(fontsize=14)
|
36 |
+
# plt.title(f'Feature difference of adjacency after {attack}-attack')
|
37 |
+
if not os.path.exists('figures/'):
|
38 |
+
os.mkdir('figures')
|
39 |
+
plt.savefig('figures/%s' % savename, bbox_inches='tight')
|
40 |
+
plt.show()
|
41 |
+
|
42 |
+
def feature_diff(clean_adj, perturbed_adj, features, savename='feature_diff.pdf'):
|
43 |
+
"""Plot feature difference on clean and perturbed graphs.
|
44 |
+
|
45 |
+
Parameters
|
46 |
+
----------
|
47 |
+
clean_adj: sp.csr_matrix
|
48 |
+
adjacency matrix of the clean graph
|
49 |
+
perturbed_adj: sp.csr_matrix
|
50 |
+
adjacency matrix of the perturbed graph
|
51 |
+
features: sp.csr_matrix or np.array
|
52 |
+
node features
|
53 |
+
savename: str
|
54 |
+
filename to be saved
|
55 |
+
|
56 |
+
Returns
|
57 |
+
-------
|
58 |
+
None
|
59 |
+
"""
|
60 |
+
|
61 |
+
fig, ax1 = plt.subplots()
|
62 |
+
sns.distplot(_get_diff(clean_adj, features), label='Normal Edges', norm_hist=True, ax=ax1)
|
63 |
+
delta_adj = perturbed_adj - clean_adj
|
64 |
+
delta_adj[delta_adj < 0] = 0
|
65 |
+
sns.distplot(_get_diff(delta_adj, features), label='Adversarial Edges', norm_hist=True, ax=ax1)
|
66 |
+
ax1.grid(False)
|
67 |
+
plt.legend(prop={'size':18})
|
68 |
+
plt.ylabel('Density Distribution', fontsize=18)
|
69 |
+
plt.xlabel('Feature Difference Between Connected Nodes', fontsize=18)
|
70 |
+
plt.xticks(fontsize=14)
|
71 |
+
plt.yticks(fontsize=14)
|
72 |
+
if not os.path.exists('figures/'):
|
73 |
+
os.mkdir('figures')
|
74 |
+
plt.savefig('figures/%s' % savename, bbox_inches='tight')
|
75 |
+
plt.show()
|
76 |
+
|
77 |
+
|
78 |
+
def _get_diff(adj, features):
|
79 |
+
isSparse = sp.issparse(features)
|
80 |
+
edges = np.array(adj.nonzero()).T
|
81 |
+
row_degree = adj.sum(0).tolist()[0]
|
82 |
+
diff = []
|
83 |
+
for edge in tqdm(edges):
|
84 |
+
n1 = edge[0]
|
85 |
+
n2 = edge[1]
|
86 |
+
if n1 > n2:
|
87 |
+
continue
|
88 |
+
d = np.sum((features[n1]/np.sqrt(row_degree[n1]) - features[n2]/np.sqrt(row_degree[n2])).power(2))
|
89 |
+
diff.append(d)
|
90 |
+
return diff
|
91 |
+
|
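A hedged usage sketch for the plotting helpers above, on a made-up toy graph (it assumes a seaborn version that still provides `distplot`, as the module itself does):

import scipy.sparse as sp
from deeprobust.graph.visualization import degree_dist

clean = sp.csr_matrix([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
perturbed = clean.tolil()
perturbed[0, 2] = perturbed[2, 0] = 1   # add one symmetric adversarial edge
degree_dist(clean, perturbed.tocsr(), savename='toy_degree_dist.pdf')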
deeprobust/image/attack/BPDA.py
ADDED
@@ -0,0 +1,105 @@
1 |
+
"""
|
2 |
+
https://github.com/lordwarlock/Pytorch-BPDA/blob/master/bpda.py
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision.models as models
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
def normalize(image, mean, std):
|
10 |
+
return (image - mean)/std
|
11 |
+
|
12 |
+
def preprocess(image):
|
13 |
+
image = image / 255
|
14 |
+
image = np.transpose(image, (2, 0, 1))
|
15 |
+
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
|
16 |
+
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
|
17 |
+
image = normalize(image, mean, std)
|
18 |
+
return image
|
19 |
+
|
20 |
+
def image2tensor(image):
|
21 |
+
img_t = torch.Tensor(image)
|
22 |
+
img_t = img_t.unsqueeze(0)
|
23 |
+
img_t.requires_grad_()
|
24 |
+
return img_t
|
25 |
+
|
26 |
+
def label2tensor(label):
|
27 |
+
target = np.array([label])
|
28 |
+
target = torch.from_numpy(target).long()
|
29 |
+
return target
|
30 |
+
|
31 |
+
def get_img_grad_given_label(image, label, model):
|
32 |
+
logits = model(image)
|
33 |
+
ce = nn.CrossEntropyLoss()
|
34 |
+
loss = ce(logits, label)
|
35 |
+
loss.backward()
|
36 |
+
ret = image.grad.clone()
|
37 |
+
model.zero_grad()
|
38 |
+
image.grad.data.zero_()
|
39 |
+
return ret
|
40 |
+
|
41 |
+
def get_cw_grad(adv, origin, label, model):
|
42 |
+
logits = model(adv)
|
43 |
+
ce = nn.CrossEntropyLoss()
|
44 |
+
l2 = nn.MSELoss()
|
45 |
+
loss = ce(logits, label) + l2(torch.zeros_like(origin), origin - adv) / l2(torch.zeros_like(origin), origin)
|
46 |
+
loss.backward()
|
47 |
+
ret = adv.grad.clone()
|
48 |
+
model.zero_grad()
|
49 |
+
adv.grad.data.zero_()
|
50 |
+
origin.grad.data.zero_()
|
51 |
+
return ret
|
52 |
+
|
53 |
+
def l2_norm(adv, img):
|
54 |
+
adv = adv.detach().numpy()
|
55 |
+
img = img.detach().numpy()
|
56 |
+
ret = np.sum(np.square(adv - img))/np.sum(np.square(img))
|
57 |
+
return ret
|
58 |
+
|
59 |
+
def clip_bound(adv):
|
60 |
+
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
|
61 |
+
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
|
62 |
+
adv = adv * std + mean
|
63 |
+
adv = np.clip(adv, 0., 1.)
|
64 |
+
adv = (adv - mean) / std
|
65 |
+
return adv.astype(np.float32)
|
66 |
+
|
67 |
+
def identity_transform(x):
|
68 |
+
return x.detach().clone()
|
69 |
+
|
70 |
+
def BPDA_attack(image,target, model, step_size = 1., iterations = 10, linf=False, transform_func=identity_transform):
|
71 |
+
target = label2tensor(target)
|
72 |
+
adv = image.detach().numpy()
|
73 |
+
adv = torch.from_numpy(adv)
|
74 |
+
adv.requires_grad_()
|
75 |
+
for _ in range(iterations):
|
76 |
+
adv_def = transform_func(adv)
|
77 |
+
adv_def.requires_grad_()
|
78 |
+
l2 = nn.MSELoss()
|
79 |
+
loss = l2(torch.zeros_like(adv_def), adv_def)
|
80 |
+
loss.backward()
|
81 |
+
g = get_cw_grad(adv_def, image, target, model)
|
82 |
+
if linf:
|
83 |
+
g = torch.sign(g)
|
84 |
+
print(g.numpy().sum())
|
85 |
+
adv = adv.detach().numpy() - step_size * g.numpy()
|
86 |
+
adv = clip_bound(adv)
|
87 |
+
adv = torch.from_numpy(adv)
|
88 |
+
adv.requires_grad_()
|
89 |
+
if linf:
|
90 |
+
print('label', torch.argmax(model(adv)), 'linf', torch.max(torch.abs(adv - image)).detach().numpy())
|
91 |
+
else:
|
92 |
+
print('label', torch.argmax(model(adv)), 'l2', l2_norm(adv, image))
|
93 |
+
return adv.detach().numpy()
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
import matplotlib
|
97 |
+
matplotlib.use('TkAgg')
|
98 |
+
import skimage.io
|
99 |
+
resnet18 = models.resnet18(pretrained=True).eval() # for CPU, remove cuda()
|
100 |
+
image = preprocess(skimage.io.imread('test.png'))
|
101 |
+
|
102 |
+
img_t = image2tensor(image)
|
103 |
+
BPDA_attack(img_t, 924, resnet18)
|
104 |
+
print('L-inf')
|
105 |
+
BPDA_attack(img_t, 924, resnet18, step_size = 0.003, linf=True)
|
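The core trick this file implements is Backward Pass Differentiable Approximation: when a defense g(x) is non-differentiable (here `transform_func`), the attacker differentiates through it as if g were the identity. A self-contained sketch of that idea with a toy quantization "defense" (not part of the library):

import torch

class StraightThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x * 10) / 10   # non-differentiable quantization
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                # BPDA: treat g(x) as the identity

x = torch.randn(4, requires_grad=True)
StraightThrough.apply(x).sum().backward()
print(x.grad)  # all ones: the gradient flows through the quantization step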
deeprobust/image/attack/Universal.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
"""
|
2 |
+
https://github.com/ferjad/Universal_Adversarial_Perturbation_pytorch
|
3 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
4 |
+
|
5 |
+
"""
|
6 |
+
from deeprobust.image.attack import deepfool
|
7 |
+
import collections
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision
|
11 |
+
import torchvision.transforms as transforms
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.optim as optim
|
15 |
+
import torch.utils.data as data_utils
|
16 |
+
import math
|
17 |
+
from PIL import Image
|
18 |
+
import torchvision.models as models
|
19 |
+
import sys
|
20 |
+
import random
|
21 |
+
import time
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
def zero_gradients(x):
|
25 |
+
if isinstance(x, torch.Tensor):
|
26 |
+
if x.grad is not None:
|
27 |
+
x.grad.detach_()
|
28 |
+
x.grad.zero_()
|
29 |
+
elif isinstance(x, collections.abc.Iterable):
|
30 |
+
for elem in x:
|
31 |
+
zero_gradients(elem)
|
32 |
+
|
33 |
+
def get_model(model,device):
|
34 |
+
if model == 'vgg16':
|
35 |
+
net = models.vgg16(pretrained=True)
|
36 |
+
elif model =='resnet18':
|
37 |
+
net = models.resnet18(pretrained=True)
|
38 |
+
|
39 |
+
net.eval()
|
40 |
+
net=net.to(device)
|
41 |
+
return net
|
42 |
+
|
43 |
+
def data_input_init(xi):
|
44 |
+
mean = [ 0.485, 0.456, 0.406 ]
|
45 |
+
std = [ 0.229, 0.224, 0.225 ]
|
46 |
+
transform = transforms.Compose([
|
47 |
+
transforms.Resize(256),
|
48 |
+
transforms.CenterCrop(224),
|
49 |
+
transforms.ToTensor(),
|
50 |
+
transforms.Normalize(mean = mean,
|
51 |
+
std = std)])
|
52 |
+
|
53 |
+
return (mean,std,transform)
|
54 |
+
|
55 |
+
def proj_lp(v, xi, p):
|
56 |
+
# Project on the lp ball centered at 0 and of radius xi
|
57 |
+
if p==np.inf:
|
58 |
+
v=torch.clamp(v,-xi,xi)
|
59 |
+
else:
|
60 |
+
v=v * min(1, xi/(torch.norm(v,p)+0.00001))
|
61 |
+
return v
|
62 |
+
|
63 |
+
def get_fooling_rate(data_list,v,model, device):
|
64 |
+
tf = data_input_init(0)[2]
|
65 |
+
num_images = len(data_list)
|
66 |
+
|
67 |
+
fooled=0.0
|
68 |
+
|
69 |
+
for name in tqdm(data_list):
|
70 |
+
image = Image.open(name)
|
71 |
+
image = tf(image)
|
72 |
+
image = image.unsqueeze(0)
|
73 |
+
image = image.to(device)
|
74 |
+
_, pred = torch.max(model(image),1)
|
75 |
+
_, adv_pred = torch.max(model(image+v),1)
|
76 |
+
if(pred!=adv_pred):
|
77 |
+
fooled+=1
|
78 |
+
|
79 |
+
# Compute the fooling rate
|
80 |
+
fooling_rate = fooled/num_images
|
81 |
+
print('Fooling Rate = ', fooling_rate)
|
82 |
+
for param in model.parameters():
|
83 |
+
param.requires_grad = False
|
84 |
+
|
85 |
+
return fooling_rate,model
|
86 |
+
|
87 |
+
def universal_adversarial_perturbation(dataloader, model, device, xi=10, delta=0.2, max_iter_uni = 10, p=np.inf,
|
88 |
+
num_classes=10, overshoot=0.02, max_iter_df=10,t_p = 0.2):
|
89 |
+
"""universal_adversarial_perturbation.
|
90 |
+
|
91 |
+
Parameters
|
92 |
+
----------
|
93 |
+
dataloader :
|
94 |
+
dataloader
|
95 |
+
model :
|
96 |
+
target model
|
97 |
+
device :
|
98 |
+
device
|
99 |
+
xi :
|
100 |
+
controls the l_p magnitude of the perturbation
|
101 |
+
delta :
|
102 |
+
controls the desired fooling rate (default = 80% fooling rate)
|
103 |
+
max_iter_uni :
|
104 |
+
maximum number of iteration (default = 10*num_images)
|
105 |
+
p :
|
106 |
+
norm to be used (default = np.inf)
|
107 |
+
num_classes :
|
108 |
+
num_classes (default = 10)
|
109 |
+
overshoot :
|
110 |
+
to prevent vanishing updates (default = 0.02)
|
111 |
+
max_iter_df :
|
112 |
+
maximum number of iterations for deepfool (default = 10)
|
113 |
+
t_p :
|
114 |
+
truth percentage, for how many flipped labels in a batch. (default = 0.2)
|
115 |
+
|
116 |
+
Returns
|
117 |
+
-------
|
118 |
+
the universal perturbation matrix.
|
119 |
+
"""
|
120 |
+
time_start = time.time()
|
121 |
+
mean, std,tf = data_input_init(xi)
|
122 |
+
v = torch.zeros(1,3,224,224).to(device)
|
123 |
+
v.requires_grad_()
|
124 |
+
|
125 |
+
fooling_rate = 0.0
|
126 |
+
num_images = len(dataloader)
|
127 |
+
itr = 0
|
128 |
+
|
129 |
+
while fooling_rate < 1-delta and itr < max_iter_uni:
|
130 |
+
|
131 |
+
# Iterate over the dataset and compute the perturbation incrementally
|
132 |
+
|
133 |
+
for i,(img, label) in enumerate(dataloader):
|
134 |
+
_, pred = torch.max(model(img),1)
|
135 |
+
_, adv_pred = torch.max(model(img+v),1)
|
136 |
+
|
137 |
+
if(pred == adv_pred):
|
138 |
+
perturb = deepfool.DeepFool(model, device)
|
139 |
+
_ = perturb.generate(img+v, num_classes = num_classes, overshoot = overshoot, max_iter = max_iter_df)
|
140 |
+
dr, iter = perturb.getpurb()
|
141 |
+
if(iter<max_iter_df-1):
|
142 |
+
v = v + torch.from_numpy(dr).to(device)
|
143 |
+
v = proj_lp(v,xi,p)
|
144 |
+
|
145 |
+
if(i % 10 == 0):
|
146 |
+
print('Norm of v: '+str(torch.norm(v).detach().cpu().numpy()))
|
147 |
+
|
148 |
+
fooling_rate, model = get_fooling_rate(data_list, v, model, device)  # note: data_list (the list of image file paths) is assumed to be provided by the caller
|
149 |
+
itr = itr + 1
|
150 |
+
|
151 |
+
return v
|
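A quick self-contained check of the `proj_lp` projection above: for p = inf the perturbation is clamped coordinate-wise, while for finite p it is rescaled onto the xi-ball:

import numpy as np
import torch

v = torch.tensor([3.0, -4.0])                           # l2 norm 5
print(torch.clamp(v, -1.0, 1.0))                        # l_inf case: tensor([ 1., -1.])
print(v * min(1, 1.0 / (torch.norm(v, 2) + 0.00001)))   # l2 case: rescaled to norm ~= 1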
deeprobust/image/attack/YOPOpgd.py
ADDED
@@ -0,0 +1,113 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.autograd import Variable
|
5 |
+
import torch.optim as optim
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from deeprobust.image.attack.base_attack import BaseAttack
|
9 |
+
|
10 |
+
class FASTPGD(BaseAttack):
|
11 |
+
'''
|
12 |
+
This module is the adversarial example generation algorithm used in YOPO.
|
13 |
+
|
14 |
+
References
|
15 |
+
----------
|
16 |
+
Original code: https://github.com/a1600012888/YOPO-You-Only-Propagate-Once
|
17 |
+
'''
|
18 |
+
# ImageNet pre-trained mean and std
|
19 |
+
# _mean = torch.tensor(np.array([0.485, 0.456, 0.406]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis])
|
20 |
+
# _std = torch.tensor(np.array([0.229, 0.224, 0.225]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis])
|
21 |
+
|
22 |
+
# _mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis])
|
23 |
+
# _std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis])
|
24 |
+
def __init__(self, eps = 6 / 255.0, sigma = 3 / 255.0, nb_iter = 20,
|
25 |
+
norm = np.inf, DEVICE = torch.device('cpu'),
|
26 |
+
mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]),
|
27 |
+
std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), random_start = True):
|
28 |
+
'''
|
29 |
+
:param eps: maximum distortion of adversarial examples
|
30 |
+
:param sigma: single step size
|
31 |
+
:param nb_iter: number of attack iterations
|
32 |
+
:param norm: which norm to bound the perturbations
|
33 |
+
'''
|
34 |
+
self.eps = eps
|
35 |
+
self.sigma = sigma
|
36 |
+
self.nb_iter = nb_iter
|
37 |
+
self.norm = norm
|
38 |
+
self.criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
|
39 |
+
self.DEVICE = DEVICE
|
40 |
+
self._mean = mean.to(DEVICE)
|
41 |
+
self._std = std.to(DEVICE)
|
42 |
+
self.random_start = random_start
|
43 |
+
|
44 |
+
def single_attack(self, net, inp, label, eta, target = None):
|
45 |
+
'''
|
46 |
+
Given the original image and the perturbation computed so far, computes
|
47 |
+
a new perturbation.
|
48 |
+
:param net:
|
49 |
+
:param inp: original image
|
50 |
+
:param label:
|
51 |
+
:param eta: perturbation computed so far
|
52 |
+
:return: a new perturbation
|
53 |
+
'''
|
54 |
+
|
55 |
+
adv_inp = inp + eta
|
56 |
+
|
57 |
+
#net.zero_grad()
|
58 |
+
|
59 |
+
pred = net(adv_inp)
|
60 |
+
if target is not None:
|
61 |
+
targets = torch.sum(pred[:, target])
|
62 |
+
grad_sign = torch.autograd.grad(targets, adv_inp, only_inputs=True, retain_graph = False)[0].sign()
|
63 |
+
|
64 |
+
else:
|
65 |
+
loss = self.criterion(pred, label)
|
66 |
+
grad_sign = torch.autograd.grad(loss, adv_inp,
|
67 |
+
only_inputs=True, retain_graph = False)[0].sign()
|
68 |
+
|
69 |
+
adv_inp = adv_inp + grad_sign * (self.sigma / self._std)
|
70 |
+
tmp_adv_inp = adv_inp * self._std + self._mean
|
71 |
+
|
72 |
+
tmp_inp = inp * self._std + self._mean
|
73 |
+
tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1) ## clip into 0-1
|
74 |
+
#tmp_adv_inp = (tmp_adv_inp - self._mean) / self._std
|
75 |
+
tmp_eta = tmp_adv_inp - tmp_inp
|
76 |
+
|
77 |
+
#tmp_eta = clip_eta(tmp_eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE)
|
78 |
+
if self.norm == np.inf:
|
79 |
+
tmp_eta = torch.clamp(tmp_eta, -self.eps, self.eps)
|
80 |
+
|
81 |
+
eta = tmp_eta/ self._std
|
82 |
+
return eta
|
83 |
+
|
84 |
+
def attack(self, net, inp, label, target = None):
|
85 |
+
|
86 |
+
if self.random_start:
|
87 |
+
eta = torch.FloatTensor(*inp.shape).uniform_(-self.eps, self.eps)
|
88 |
+
else:
|
89 |
+
eta = torch.zeros_like(inp)
|
90 |
+
eta = eta.to(self.DEVICE)
|
91 |
+
eta = (eta - self._mean) / self._std
|
92 |
+
net.eval()
|
93 |
+
|
94 |
+
inp.requires_grad = True
|
95 |
+
eta.requires_grad = True
|
96 |
+
for i in range(self.nb_iter):
|
97 |
+
eta = self.single_attack(net, inp, label, eta, target)
|
98 |
+
#print(i)
|
99 |
+
|
100 |
+
#print(eta.max())
|
101 |
+
adv_inp = inp + eta
|
102 |
+
tmp_adv_inp = adv_inp * self._std + self._mean
|
103 |
+
tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1)
|
104 |
+
adv_inp = (tmp_adv_inp - self._mean) / self._std
|
105 |
+
|
106 |
+
return adv_inp
|
107 |
+
|
108 |
+
def to(self, device):
|
109 |
+
self.DEVICE = device
|
110 |
+
self._mean = self._mean.to(device)
|
111 |
+
self._std = self._std.to(device)
|
112 |
+
self.criterion = self.criterion.to(device)
|
113 |
+
|
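The mean/std bookkeeping in `single_attack` above is easy to misread: steps are taken in normalized space, but valid pixels live in [0, 1], so the code maps back to pixel space to clip and then returns. A tiny self-contained illustration with made-up numbers:

import torch

mean, std = 0.5, 0.25
adv_norm = torch.tensor([2.4])                        # after an attack step
adv_pixel = torch.clamp(adv_norm * std + mean, 0, 1)  # clip in pixel space
adv_norm = (adv_pixel - mean) / std                   # back to normalized space
print(adv_norm)  # tensor([2.]) -- the largest normalized value for pixel 1.0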
deeprobust/image/attack/base_attack.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
from abc import ABCMeta
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class BaseAttack(object):
|
5 |
+
"""
|
6 |
+
Attack base class.
|
7 |
+
"""
|
8 |
+
|
9 |
+
__metaclass__ = ABCMeta
|
10 |
+
|
11 |
+
def __init__(self, model, device = 'cuda'):
|
12 |
+
self.model = model
|
13 |
+
self.device = device
|
14 |
+
|
15 |
+
def generate(self, image, label, **kwargs):
|
16 |
+
"""
|
17 |
+
Override this function with the main body of the attack algorithm.
|
18 |
+
|
19 |
+
Parameters
|
20 |
+
----------
|
21 |
+
image :
|
22 |
+
original image
|
23 |
+
label :
|
24 |
+
original label
|
25 |
+
kwargs :
|
26 |
+
user defined parameters
|
27 |
+
"""
|
28 |
+
return image
|
29 |
+
|
30 |
+
def parse_params(self, **kwargs):
|
31 |
+
"""
|
32 |
+
Parse user defined parameters.
|
33 |
+
"""
|
34 |
+
return True
|
35 |
+
|
36 |
+
def check_type_device(self, image, label):
|
37 |
+
"""
|
38 |
+
Check device, match variable type to device type.
|
39 |
+
|
40 |
+
Parameters
|
41 |
+
----------
|
42 |
+
image :
|
43 |
+
image
|
44 |
+
label :
|
45 |
+
label
|
46 |
+
"""
|
47 |
+
|
48 |
+
################## devices
|
49 |
+
if self.device == 'cuda':
|
50 |
+
image = image.cuda()
|
51 |
+
label = label.cuda()
|
52 |
+
self.model = self.model.cuda()
|
53 |
+
elif self.device == 'cpu':
|
54 |
+
image = image.cpu()
|
55 |
+
label = label.cpu()
|
56 |
+
self.model = self.model.cpu()
|
57 |
+
else:
|
58 |
+
raise ValueError('Please input cpu or cuda')
|
59 |
+
|
60 |
+
################## data type
|
61 |
+
if type(image).__name__ == 'Tensor':
|
62 |
+
image = image.float()
|
63 |
+
image = image.float().clone().detach().requires_grad_(True)
|
64 |
+
elif type(image).__name__ == 'ndarray':
|
65 |
+
image = image.astype('float')
|
66 |
+
image = torch.tensor(image, requires_grad=True)
|
67 |
+
else:
|
68 |
+
raise ValueError('Input values only take numpy arrays or torch tensors')
|
69 |
+
|
70 |
+
if type(label).__name__ == 'Tensor':
|
71 |
+
label = label.long()
|
72 |
+
elif type(label).__name__ == 'ndarray':
|
73 |
+
label = label.astype('long')
|
74 |
+
label = torch.tensor(label)
|
75 |
+
else:
|
76 |
+
raise ValueError('Input labels only take numpy arrays or torch tensors')
|
77 |
+
|
78 |
+
|
79 |
+
#################### set init attributes
|
80 |
+
self.image = image
|
81 |
+
self.label = label
|
82 |
+
|
83 |
+
return True
|
84 |
+
|
85 |
+
def get_or_predict_lable(self, image):
|
86 |
+
output = self.model(image)
|
87 |
+
pred = output.argmax(dim=1, keepdim=True)
|
88 |
+
return(pred)
|
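A hedged sketch of the subclassing pattern this base class expects; the `NoiseAttack` below is an illustrative stand-in, not a class in the library. `generate` chains `check_type_device` and `parse_params`, then reads the prepared tensors from `self.image` / `self.label`:

import torch
from deeprobust.image.attack.base_attack import BaseAttack

class NoiseAttack(BaseAttack):
    def parse_params(self, epsilon=0.1):
        self.epsilon = epsilon
        return True

    def generate(self, image, label, **kwargs):
        assert self.check_type_device(image, label)   # moves data to the device
        assert self.parse_params(**kwargs)
        noise = torch.empty_like(self.image).uniform_(-self.epsilon, self.epsilon)
        return torch.clamp(self.image + noise, 0, 1)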
deeprobust/image/attack/l2_attack.py
ADDED
@@ -0,0 +1,174 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class CarliniL2:
|
8 |
+
def __init__(self, model, device):
|
9 |
+
self.model = model
|
10 |
+
self.device = device
|
11 |
+
|
12 |
+
def parse_params(self, gan, confidence=0, targeted=False, learning_rate=1e-1,
|
13 |
+
binary_search_steps=5, max_iterations=10000, abort_early=False, initial_const=1,
|
14 |
+
clip_min=0, clip_max=1):
|
15 |
+
|
16 |
+
self.TARGETED = targeted
|
17 |
+
self.LEARNING_RATE = learning_rate
|
18 |
+
self.MAX_ITERATIONS = max_iterations
|
19 |
+
self.BINARY_SEARCH_STEPS = binary_search_steps
|
20 |
+
self.ABORT_EARLY = abort_early
|
21 |
+
self.CONFIDENCE = confidence
|
22 |
+
self.initial_const = initial_const
|
23 |
+
self.clip_min = clip_min
|
24 |
+
self.clip_max = clip_max
|
25 |
+
self.gan = gan
|
26 |
+
self.learning_rate = learning_rate
|
27 |
+
self.repeat = binary_search_steps >= 10
|
28 |
+
|
29 |
+
def get_or_guess_labels(self, x, y=None):
|
30 |
+
"""
|
31 |
+
Get the label to use in generating an adversarial example for x.
|
32 |
+
The kwargs are fed directly from the kwargs of the attack.
|
33 |
+
If 'y' is in kwargs, use that as the label.
|
34 |
+
Otherwise, use the model's prediction as the label.
|
35 |
+
"""
|
36 |
+
if y is not None:
|
37 |
+
labels = y
|
38 |
+
else:
|
39 |
+
preds = F.softmax(self.model(x), dim=1)
|
40 |
+
preds_max = torch.max(preds, 1, keepdim=True)[0]
|
41 |
+
original_predictions = (preds == preds_max)
|
42 |
+
labels = original_predictions
|
43 |
+
del preds
|
44 |
+
return labels.float()
|
45 |
+
|
46 |
+
def atanh(self, x):
|
47 |
+
return 0.5 * torch.log((1 + x) / (1 - x))
|
48 |
+
|
49 |
+
def to_one_hot(self, x):
|
50 |
+
one_hot = torch.FloatTensor(x.shape[0], 10).to(x.get_device())
|
51 |
+
one_hot.zero_()
|
52 |
+
x = x.unsqueeze(1)
|
53 |
+
one_hot = one_hot.scatter_(1, x, 1)
|
54 |
+
return one_hot
|
55 |
+
|
56 |
+
def generate(self, imgs, y, start):
|
57 |
+
|
58 |
+
batch_size = imgs.shape[0]
|
59 |
+
labs = self.get_or_guess_labels(imgs, y)
|
60 |
+
|
61 |
+
def compare(x, y):
|
62 |
+
if self.TARGETED is None: return True
|
63 |
+
|
64 |
+
if sum(x.shape) != 0:
|
65 |
+
x = x.clone()
|
66 |
+
if self.TARGETED:
|
67 |
+
x[y] -= self.CONFIDENCE
|
68 |
+
else:
|
69 |
+
x[y] += self.CONFIDENCE
|
70 |
+
x = torch.argmax(x)
|
71 |
+
if self.TARGETED:
|
72 |
+
return x == y
|
73 |
+
else:
|
74 |
+
return x != y
|
75 |
+
|
76 |
+
# set the lower and upper bounds accordingly
|
77 |
+
lower_bound = torch.zeros(batch_size).to(self.device)
|
78 |
+
CONST = torch.ones(batch_size).to(self.device) * self.initial_const
|
79 |
+
upper_bound = (torch.ones(batch_size) * 1e10).to(self.device)
|
80 |
+
|
81 |
+
# the best l2, score, and image attack
|
82 |
+
o_bestl2 = [1e10] * batch_size
|
83 |
+
o_bestscore = [-1] * batch_size
|
84 |
+
o_bestattack = self.gan(start)
|
85 |
+
|
86 |
+
# check if the input label is one-hot, if not, then change it into one-hot vector
|
87 |
+
if len(labs.shape) == 1:
|
88 |
+
tlabs = self.to_one_hot(labs.long())
|
89 |
+
else:
|
90 |
+
tlabs = labs
|
91 |
+
|
92 |
+
for outer_step in range(self.BINARY_SEARCH_STEPS):
|
93 |
+
# completely reset adam's internal state.
|
94 |
+
modifier = nn.Parameter(start)
|
95 |
+
optimizer = torch.optim.Adam([modifier, ], lr=self.learning_rate)
|
96 |
+
|
97 |
+
bestl2 = [1e10] * batch_size
|
98 |
+
bestscore = -1 * torch.ones(batch_size, dtype=torch.float32).to(self.device)
|
99 |
+
|
100 |
+
# The last iteration (if we run many steps) repeat the search once.
|
101 |
+
if self.repeat and outer_step == self.BINARY_SEARCH_STEPS - 1:
|
102 |
+
CONST = upper_bound
|
103 |
+
prev = 1e6
|
104 |
+
|
105 |
+
for i in range(self.MAX_ITERATIONS):
|
106 |
+
optimizer.zero_grad()
|
107 |
+
nimgs = self.gan(modifier.to(self.device))
|
108 |
+
|
109 |
+
# distance to the input data
|
110 |
+
l2dist = torch.sum(torch.sum(torch.sum((nimgs - imgs) ** 2, 1), 1), 1)
|
111 |
+
loss2 = torch.sum(l2dist)
|
112 |
+
|
113 |
+
# prediction BEFORE-SOFTMAX of the model
|
114 |
+
scores = self.model(nimgs)
|
115 |
+
|
116 |
+
# compute the probability of the label class versus the maximum other
|
117 |
+
other = torch.max(((1 - tlabs) * scores - tlabs * 10000), 1)[0]
|
118 |
+
real = torch.sum(tlabs * scores, 1)
|
119 |
+
|
120 |
+
if self.TARGETED:
|
121 |
+
# if targeted, optimize for making the other class most likely
|
122 |
+
loss1 = torch.max(torch.zeros_like(other), other - real + self.CONFIDENCE)
|
123 |
+
else:
|
124 |
+
# if untargeted, optimize for making this class least likely.
|
125 |
+
loss1 = torch.max(torch.zeros_like(other), real - other + self.CONFIDENCE)
|
126 |
+
|
127 |
+
# sum up the losses
|
128 |
+
loss1 = torch.sum(CONST * loss1)
|
129 |
+
loss = loss1 + loss2
|
130 |
+
|
131 |
+
# update the modifier
|
132 |
+
loss.backward()
|
133 |
+
optimizer.step()
|
134 |
+
|
135 |
+
# check if we should abort search if we're getting nowhere.
|
136 |
+
if self.ABORT_EARLY and i % ((self.MAX_ITERATIONS // 10) or 1) == 0:
|
137 |
+
if loss > prev * .9999:
|
138 |
+
# print('Stop early')
|
139 |
+
break
|
140 |
+
prev = loss
|
141 |
+
|
142 |
+
# adjust the best result found so far
|
143 |
+
for e, (l2, sc, ii) in enumerate(zip(l2dist, scores, nimgs)):
|
144 |
+
lab = torch.argmax(tlabs[e])
|
145 |
+
|
146 |
+
if l2 < bestl2[e] and compare(sc, lab):
|
147 |
+
bestl2[e] = l2
|
148 |
+
bestscore[e] = torch.argmax(sc)
|
149 |
+
|
150 |
+
if l2 < o_bestl2[e] and compare(sc, lab):
|
151 |
+
o_bestl2[e] = l2
|
152 |
+
o_bestscore[e] = torch.argmax(sc)
|
153 |
+
o_bestattack[e] = ii
|
154 |
+
|
155 |
+
# adjust the constant as needed
|
156 |
+
for e in range(batch_size):
|
157 |
+
if compare(bestscore[e], torch.argmax(tlabs[e]).float()) and \
|
158 |
+
bestscore[e] != -1:
|
159 |
+
# success, divide CONST by two
|
160 |
+
upper_bound[e] = min(upper_bound[e], CONST[e])
|
161 |
+
if upper_bound[e] < 1e9:
|
162 |
+
CONST[e] = (lower_bound[e] + upper_bound[e]) / 2
|
163 |
+
else:
|
164 |
+
# failure, either multiply by 10 if no solution found yet
|
165 |
+
# or do binary search with the known upper bound
|
166 |
+
lower_bound[e] = max(lower_bound[e], CONST[e])
|
167 |
+
if upper_bound[e] < 1e9:
|
168 |
+
CONST[e] = (lower_bound[e] + upper_bound[e]) / 2
|
169 |
+
else:
|
170 |
+
CONST[e] *= 10
|
171 |
+
|
172 |
+
# return the best solution found
|
173 |
+
o_bestl2 = np.array(o_bestl2)
|
174 |
+
return o_bestattack
|
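A self-contained illustration of the margin loss computed above: `real` is the logit of the labelled class, `other` the largest remaining logit, and the hinge max(0, real - other + confidence) stays positive until the (untargeted) attack flips the prediction with the requested margin:

import torch

scores = torch.tensor([[2.0, 5.0, 1.0]])    # logits for one example
tlabs = torch.tensor([[0.0, 1.0, 0.0]])     # one-hot label, class 1
CONFIDENCE = 0.5

real = torch.sum(tlabs * scores, 1)                            # 5.0
other = torch.max((1 - tlabs) * scores - tlabs * 10000, 1)[0]  # 2.0
loss1 = torch.max(torch.zeros_like(other), real - other + CONFIDENCE)
print(loss1)  # tensor([3.5000])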
deeprobust/image/attack/lbfgs.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import scipy.optimize as so
|
4 |
+
import numpy as np
|
5 |
+
import logging
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from deeprobust.image.attack.base_attack import BaseAttack
|
8 |
+
|
9 |
+
class LBFGS(BaseAttack):
|
10 |
+
"""
|
11 |
+
LBFGS was the first adversarial example generation algorithm.
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
def __init__(self, model, device = 'cuda' ):
|
16 |
+
super(LBFGS, self).__init__(model, device)
|
17 |
+
|
18 |
+
def generate(self, image, label, target_label, **kwargs):
|
19 |
+
"""
|
20 |
+
Call this function to generate adversarial examples.
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
image :
|
25 |
+
original image
|
26 |
+
label :
|
27 |
+
original label
|
28 |
+
kwargs :
|
29 |
+
user defined parameters
|
30 |
+
"""
|
31 |
+
assert self.check_type_device(image, label)
|
32 |
+
assert self.parse_params(**kwargs)
|
33 |
+
self.target_label = target_label
|
34 |
+
adv_img= optimize(self.model,
|
35 |
+
self.image,
|
36 |
+
self.label,
|
37 |
+
self.target_label,
|
38 |
+
self.bounds,
|
39 |
+
self.epsilon,
|
40 |
+
self.maxiter,
|
41 |
+
self.class_num,
|
42 |
+
self.device)
|
43 |
+
return adv_img
|
44 |
+
|
45 |
+
def distance(self):
|
46 |
+
return self.dist
|
47 |
+
|
48 |
+
def loss(self):
|
49 |
+
return self.loss
|
50 |
+
|
51 |
+
def parse_params(self,
|
52 |
+
clip_max = 1,
|
53 |
+
clip_min = 0,
|
54 |
+
class_num = 10,
|
55 |
+
epsilon = 1e-5, #step of finding initial c
|
56 |
+
maxiter = 20, #maximum of iteration in lbfgs optimization
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Parse the user defined parameters.
|
60 |
+
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
clip_max :
|
64 |
+
maximum pixel value
|
65 |
+
clip_min :
|
66 |
+
minimum pixel value
|
67 |
+
class_num :
|
68 |
+
total number of class
|
69 |
+
epsilon :
|
70 |
+
step length for binary search
|
71 |
+
maxiter :
|
72 |
+
maximum number of iterations
|
73 |
+
"""
|
74 |
+
self.epsilon = epsilon
|
75 |
+
self.maxiter = maxiter
|
76 |
+
self.class_num = class_num
|
77 |
+
self.bounds = (clip_min, clip_max)
|
78 |
+
return True
|
79 |
+
|
80 |
+
def optimize(model, image, label, target_label, bounds, epsilon, maxiter, class_num, device):
|
81 |
+
x_t = image
|
82 |
+
x0 = image[0].to('cpu').detach().numpy()
|
83 |
+
min_, max_ = bounds
|
84 |
+
|
85 |
+
target_dist = torch.tensor(target_label)
|
86 |
+
target_dist = target_dist.unsqueeze_(0).long().to(device)
|
87 |
+
|
88 |
+
# store the shape for later and operate on the flattened input
|
89 |
+
|
90 |
+
shape = x0.shape
|
91 |
+
dtype = x0.dtype
|
92 |
+
x0 = x0.flatten().astype(np.float64)
|
93 |
+
|
94 |
+
n = len(x0)
|
95 |
+
bounds = [(min_, max_)] * n
|
96 |
+
|
97 |
+
def loss(x, c):
|
98 |
+
#calculate the target function
|
99 |
+
v1 = (torch.norm(torch.from_numpy(x0) - x)) **2
|
100 |
+
|
101 |
+
x = torch.tensor(x.astype(dtype).reshape(shape))
|
102 |
+
x = x.unsqueeze_(0).float().to(device)
|
103 |
+
|
104 |
+
predict = model(x)
|
105 |
+
v2 = F.nll_loss(predict, target_dist)
|
106 |
+
|
107 |
+
v = c * v1 + v2
|
108 |
+
#print(v)
|
109 |
+
return np.float64(v)
|
110 |
+
|
111 |
+
def lbfgs_b(c):
|
112 |
+
|
113 |
+
#initial the variables
|
114 |
+
approx_grad_eps = (max_ - min_) / 100
|
115 |
+
print('in lbfgs_b:', 'c =', c)
|
116 |
+
|
117 |
+
#start optimization
|
118 |
+
optimize_output, f, d = so.fmin_l_bfgs_b(
|
119 |
+
loss,
|
120 |
+
x0,
|
121 |
+
args=(c,),
|
122 |
+
approx_grad = True,
|
123 |
+
bounds = bounds,
|
124 |
+
m = 15,
|
125 |
+
maxiter = maxiter,
|
126 |
+
factr = 1e10, #optimization accuracy
|
127 |
+
maxls = 5,
|
128 |
+
epsilon = approx_grad_eps, iprint = 11)
|
129 |
+
print('finish optimization')
|
130 |
+
|
131 |
+
# LBFGS-B does not always exactly respect the boundaries
|
132 |
+
if np.amax(optimize_output) > max_ or np.amin(optimize_output) < min_: # pragma: no coverage
|
133 |
+
logging.info('Input out of bounds (min, max = {}, {}). Performing manual clip.'.format(
|
134 |
+
np.amin(optimize_output), np.amax(optimize_output)))
|
135 |
+
|
136 |
+
optimize_output = np.clip(optimize_output, min_, max_)
|
137 |
+
|
138 |
+
#is_adversarial = pending_attack(target_model = model, adv_exp = optimize_output, target_label = target_label)
|
139 |
+
# pending if the attack success
|
140 |
+
optimize_output = optimize_output.reshape(shape).astype(dtype)
|
141 |
+
optimize_output = torch.from_numpy(optimize_output)
|
142 |
+
optimize_output = optimize_output.unsqueeze_(0).float().to(device)
|
143 |
+
|
144 |
+
predict1 = model(optimize_output)
|
145 |
+
label = predict1.argmax(dim=1, keepdim=True)
|
146 |
+
if label == target_label:
|
147 |
+
is_adversarial = True
|
148 |
+
print('can find adversarial example with current c.')
|
149 |
+
else:
|
150 |
+
is_adversarial = False
|
151 |
+
print('could not find adversarial example with current c.')
|
152 |
+
|
153 |
+
return optimize_output, is_adversarial
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
# finding initial c
|
158 |
+
c = epsilon
|
159 |
+
print('finding initial c:')
|
160 |
+
|
161 |
+
for i in range(30):
|
162 |
+
c = 2 * c
|
163 |
+
x_new, is_adversarial = lbfgs_b(c)
|
164 |
+
if is_adversarial == False:
|
165 |
+
break
|
166 |
+
print('initial c:', c)
|
167 |
+
print('start binary search:')
|
168 |
+
|
169 |
+
x_new, is_adversarial = lbfgs_b(0)
|
170 |
+
if is_adversarial == False: # pragma: no cover
|
171 |
+
print('Could not find an adversarial example.')
|
172 |
+
return
|
173 |
+
|
174 |
+
print('c_high:',c)
|
175 |
+
# binary search
|
176 |
+
c_low = 0
|
177 |
+
c_high = c
|
178 |
+
while c_high - c_low >= epsilon:
|
179 |
+
print(c_high,' ',c_low)
|
180 |
+
c_half = (c_low + c_high) / 2
|
181 |
+
x_new, is_adversarial = lbfgs_b(c_half)
|
182 |
+
|
183 |
+
if is_adversarial:
|
184 |
+
c_low = c_half
|
185 |
+
else:
|
186 |
+
c_high = c_half
|
187 |
+
|
188 |
+
x_new, is_adversarial = lbfgs_b(c_low)
|
189 |
+
|
190 |
+
dis = ( torch.norm(x_new.reshape(shape) - x0.reshape(shape)) ) **2
|
191 |
+
|
192 |
+
x_new = x_new.flatten().numpy()
|
193 |
+
mintargetfunc = loss(x_new.astype(np.float64), c_low)
|
194 |
+
|
195 |
+
x_new = x_new.astype(dtype)
|
196 |
+
x_new = x_new.reshape(shape)
|
197 |
+
|
198 |
+
x_new = torch.from_numpy(x_new).unsqueeze_(0).float().to(device)
|
199 |
+
|
200 |
+
return x_new
|
201 |
+
|
202 |
+
|
203 |
+
|
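The control flow above, double c until the attack stops succeeding and then binary-search between the last good and bad values, can be summarized in a few self-contained lines; `succeeds(c)` stands in for `lbfgs_b`, and the real code also caps the growth phase at 30 doublings:

def find_c(succeeds, epsilon=1e-5):
    c = epsilon
    while succeeds(2 * c):            # exponential growth phase
        c = 2 * c
    c_low, c_high = 0.0, 2 * c        # binary search phase
    while c_high - c_low >= epsilon:
        c_half = (c_low + c_high) / 2
        if succeeds(c_half):
            c_low = c_half
        else:
            c_high = c_half
    return c_low

print(find_c(lambda c: c < 0.37))     # converges to ~0.37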
deeprobust/image/attack/pgd.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.autograd import Variable
|
5 |
+
import torch.optim as optim
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from deeprobust.image.attack.base_attack import BaseAttack
|
9 |
+
|
10 |
+
class PGD(BaseAttack):
|
11 |
+
"""
|
12 |
+
This is the multi-step version of the FGSM attack.
|
13 |
+
"""
|
14 |
+
|
15 |
+
|
16 |
+
def __init__(self, model, device = 'cuda'):
|
17 |
+
|
18 |
+
super(PGD, self).__init__(model, device)
|
19 |
+
|
20 |
+
def generate(self, image, label, **kwargs):
|
21 |
+
"""
|
22 |
+
Call this function to generate PGD adversarial examples.
|
23 |
+
|
24 |
+
Parameters
|
25 |
+
----------
|
26 |
+
image :
|
27 |
+
original image
|
28 |
+
label :
|
29 |
+
true label
|
30 |
+
kwargs :
|
31 |
+
user defined parameters
|
32 |
+
"""
|
33 |
+
|
34 |
+
## check and parse parameters for attack
|
35 |
+
label = label.type(torch.FloatTensor)
|
36 |
+
|
37 |
+
assert self.check_type_device(image, label)
|
38 |
+
assert self.parse_params(**kwargs)
|
39 |
+
|
40 |
+
return pgd_attack(self.model,
|
41 |
+
self.image,
|
42 |
+
self.label,
|
43 |
+
self.epsilon,
|
44 |
+
self.clip_max,
|
45 |
+
self.clip_min,
|
46 |
+
self.num_steps,
|
47 |
+
self.step_size,
|
48 |
+
self.print_process,
|
49 |
+
self.bound)
|
50 |
+
##default parameter for mnist data set.
|
51 |
+
|
52 |
+
def parse_params(self,
|
53 |
+
epsilon = 0.03,
|
54 |
+
num_steps = 40,
|
55 |
+
step_size = 0.01,
|
56 |
+
clip_max = 1.0,
|
57 |
+
clip_min = 0.0,
|
58 |
+
print_process = False,
|
59 |
+
bound = 'linf'
|
60 |
+
):
|
61 |
+
"""parse_params.
|
62 |
+
|
63 |
+
Parameters
|
64 |
+
----------
|
65 |
+
epsilon :
|
66 |
+
perturbation constraint
|
67 |
+
num_steps :
|
68 |
+
iteration step
|
69 |
+
step_size :
|
70 |
+
step size
|
71 |
+
clip_max :
|
72 |
+
maximum pixel value
|
73 |
+
clip_min :
|
74 |
+
minimum pixel value
|
75 |
+
print_process :
|
76 |
+
whether to print the log during the optimization process, True or False
|
77 |
+
"""
|
78 |
+
self.epsilon = epsilon
|
79 |
+
self.num_steps = num_steps
|
80 |
+
self.step_size = step_size
|
81 |
+
self.clip_max = clip_max
|
82 |
+
self.clip_min = clip_min
|
83 |
+
self.print_process = print_process
|
84 |
+
self.bound = bound
|
85 |
+
return True
|
86 |
+
|
87 |
+
def pgd_attack(model,
|
88 |
+
X,
|
89 |
+
y,
|
90 |
+
epsilon,
|
91 |
+
clip_max,
|
92 |
+
clip_min,
|
93 |
+
num_steps,
|
94 |
+
step_size,
|
95 |
+
print_process,
|
96 |
+
bound = 'linf'):
|
97 |
+
|
98 |
+
out = model(X)
|
99 |
+
err = (out.data.max(1)[1] != y.data).float().sum()
|
100 |
+
#TODO: find another way
|
101 |
+
device = X.device
|
102 |
+
imageArray = X.detach().cpu().numpy()
|
103 |
+
X_random = np.random.uniform(-epsilon, epsilon, X.shape)
|
104 |
+
imageArray = np.clip(imageArray + X_random, clip_min, clip_max)
|
105 |
+
|
106 |
+
X_pgd = torch.tensor(imageArray).to(device).float()
|
107 |
+
X_pgd.requires_grad = True
|
108 |
+
eta = torch.zeros_like(X)
|
109 |
+
eta.requires_grad = True
|
110 |
+
for i in range(num_steps):
|
111 |
+
|
112 |
+
pred = model(X_pgd)
|
113 |
+
loss = nn.CrossEntropyLoss()(pred, y)
|
114 |
+
|
115 |
+
if print_process:
|
116 |
+
print("iteration {:.0f}, loss:{:.4f}".format(i,loss))
|
117 |
+
|
118 |
+
loss.backward()
|
119 |
+
|
120 |
+
if bound == 'linf':
|
121 |
+
eta = step_size * X_pgd.grad.data.sign()
|
122 |
+
X_pgd = X_pgd + eta
|
123 |
+
eta = torch.clamp(X_pgd.data - X.data, -epsilon, epsilon)
|
124 |
+
|
125 |
+
X_pgd = X.data + eta
|
126 |
+
|
127 |
+
X_pgd = torch.clamp(X_pgd, clip_min, clip_max)
|
128 |
+
#for ind in range(X_pgd.shape[1]):
|
129 |
+
# X_pgd[:,ind,:,:] = (torch.clamp(X_pgd[:,ind,:,:] * std[ind] + mean[ind], clip_min, clip_max) - mean[ind]) / std[ind]
|
130 |
+
|
131 |
+
X_pgd = X_pgd.detach()
|
132 |
+
X_pgd.requires_grad_()
|
133 |
+
X_pgd.retain_grad()
|
134 |
+
|
135 |
+
if bound == 'l2':
|
136 |
+
output = model(X + eta)
|
137 |
+
incorrect = output.max(1)[1] != y
|
138 |
+
correct = (~incorrect).unsqueeze(1).unsqueeze(1).unsqueeze(1).float()
|
139 |
+
#Finding the correct examples so as to attack only them
|
140 |
+
loss = nn.CrossEntropyLoss()(model(X + eta), y)
|
141 |
+
loss.backward()
|
142 |
+
|
143 |
+
eta.data += correct * step_size * eta.grad.detach() / torch.norm(eta.grad.detach())
|
144 |
+
eta.data *= epsilon / torch.norm(eta.detach()).clamp(min=epsilon)
|
145 |
+
eta.data = torch.min(torch.max(eta.detach(), -X), 1-X) # clip X+delta to [0,1]
|
146 |
+
eta.grad.zero_()
|
147 |
+
X_pgd = X + eta
|
148 |
+
|
149 |
+
return X_pgd
|
150 |
+
|
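A self-contained check of the l_inf projection step inside `pgd_attack` above: after each signed-gradient step, the accumulated perturbation is clamped into [-epsilon, epsilon] around the clean input and the result clipped to valid pixels:

import torch

X = torch.tensor([0.50])           # clean pixel
X_pgd = torch.tensor([0.58])       # after a gradient step
epsilon = 0.03
eta = torch.clamp(X_pgd - X, -epsilon, epsilon)
X_pgd = torch.clamp(X + eta, 0.0, 1.0)
print(X_pgd)  # tensor([0.5300]) -- projected back into the epsilon-ball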
deeprobust/image/config.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import numpy as np
|
2 |
+
# ---------------------attack config------------------------#
|
3 |
+
attack_params = {
|
4 |
+
"FGSM_MNIST": {
|
5 |
+
'epsilon': 0.2,
|
6 |
+
'order': np.inf,
|
7 |
+
'clip_max': None,
|
8 |
+
'clip_min': None
|
9 |
+
},
|
10 |
+
|
11 |
+
"PGD_CIFAR10": {
|
12 |
+
'epsilon': 0.1,
|
13 |
+
'clip_max': 1.0,
|
14 |
+
'clip_min': 0.0,
|
15 |
+
'print_process': True
|
16 |
+
},
|
17 |
+
|
18 |
+
"LBFGS_MNIST": {
|
19 |
+
'epsilon': 1e-4,
|
20 |
+
'maxiter': 20,
|
21 |
+
'clip_max': 1,
|
22 |
+
'clip_min': 0,
|
23 |
+
'class_num': 10
|
24 |
+
},
|
25 |
+
|
26 |
+
"CW_MNIST": {
|
27 |
+
'confidence': 1e-4,
|
28 |
+
'clip_max': 1,
|
29 |
+
'clip_min': 0,
|
30 |
+
'max_iterations': 1000,
|
31 |
+
'initial_const': 1e-2,
|
32 |
+
'binary_search_steps': 5,
|
33 |
+
'learning_rate': 5e-3,
|
34 |
+
'abort_early': True,
|
35 |
+
}
|
36 |
+
|
37 |
+
}
|
38 |
+
|
39 |
+
#-----------defense(Adversarial training) config------------#
|
40 |
+
|
41 |
+
defense_params = {
|
42 |
+
"PGDtraining_MNIST":{
|
43 |
+
'save_dir': "./defense_model",
|
44 |
+
'save_model': True,
|
45 |
+
'save_name' : "mnist_pgdtraining_0.3.pt",
|
46 |
+
'epsilon' : 0.3,
|
47 |
+
'epoch_num' : 80,
|
48 |
+
'lr' : 0.01
|
49 |
+
},
|
50 |
+
|
51 |
+
"FGSMtraining_MNIST":{
|
52 |
+
'save_dir': "./defense_model",
|
53 |
+
'save_model': True,
|
54 |
+
'save_name' : "mnist_fgsmtraining_0.2.pt",
|
55 |
+
'epsilon' : 0.2,
|
56 |
+
'epoch_num' : 50,
|
57 |
+
'lr_train' : 0.001
|
58 |
+
},
|
59 |
+
|
60 |
+
"FAST_MNIST":{
|
61 |
+
'save_dir': "./defense_model",
|
62 |
+
'save_model': True,
|
63 |
+
'save_name' : "fast_mnist_0.3.pt",
|
64 |
+
'epsilon' : 0.3,
|
65 |
+
'epoch_num' : 50,
|
66 |
+
'lr_train' : 0.001
|
67 |
+
}
|
68 |
+
}
|
69 |
+
|
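A hedged sketch of how these dictionaries are meant to be consumed: look up the per-dataset entry and unpack it as keyword arguments for the matching attack. `model`, `X` and `y` below are assumed to exist (a trained classifier and a test batch), so those lines are left commented:

from deeprobust.image.config import attack_params
from deeprobust.image.attack.pgd import PGD

params = attack_params["PGD_CIFAR10"]
print(params["epsilon"])                      # 0.1
# attack = PGD(model, device='cuda')          # model: a trained classifier
# X_adv = attack.generate(X, y, **params)     # X, y: a CIFAR10 batch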
deeprobust/image/defense/LIDclassifier.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
"""
|
2 |
+
This is an implementation of LID detector.
|
3 |
+
Currently this implementation is under testing.
|
4 |
+
|
5 |
+
References
|
6 |
+
----------
|
7 |
+
.. [1] Ma, Xingjun, Bo Li, Yisen Wang, Sarah M. Erfani, Sudanthi Wijewickrema, Grant Schoenebeck, Dawn Song, Michael E. Houle, and James Bailey. "Characterizing adversarial subspaces using local intrinsic dimensionality." arXiv preprint arXiv:1801.02613 (2018).
|
8 |
+
.. [2] Original code: https://github.com/xingjunm/lid_adversarial_subspace_detection
|
9 |
+
Copyright (c) 2018 Xingjun Ma
|
10 |
+
"""
|
11 |
+
|
12 |
+
from deeprobust.image.netmodels.CNN_multilayer import Net
|
13 |
+
|
14 |
+
def train(self, device, train_loader, optimizer, epoch):
|
15 |
+
"""train process.
|
16 |
+
|
17 |
+
Parameters
|
18 |
+
----------
|
19 |
+
device :
|
20 |
+
device(option:'cpu', 'cuda')
|
21 |
+
train_loader :
|
22 |
+
train data loader
|
23 |
+
optimizer :
|
24 |
+
optimizer
|
25 |
+
epoch :
|
26 |
+
epoch
|
27 |
+
"""
|
28 |
+
self.model.train()
|
29 |
+
correct = 0
|
30 |
+
bs = train_loader.batch_size
|
31 |
+
|
32 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
33 |
+
|
34 |
+
optimizer.zero_grad()
|
35 |
+
|
36 |
+
data, target = data.to(device), target.to(device)
|
37 |
+
|
38 |
+
data_adv, output = self.adv_data(data, target, ep = self.epsilon, num_steps = self.num_steps)
|
39 |
+
|
40 |
+
loss = self.calculate_loss(output, target)
|
41 |
+
|
42 |
+
loss.backward()
|
43 |
+
optimizer.step()
|
44 |
+
|
45 |
+
pred = output.argmax(dim = 1, keepdim = True)
|
46 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
47 |
+
|
48 |
+
#print every 10
|
49 |
+
if batch_idx % 10 == 0:
|
50 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
|
51 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
52 |
+
100. * batch_idx / len(train_loader), loss.item(), 100 * correct/(10*bs)))
|
53 |
+
correct = 0
|
54 |
+
|
55 |
+
def get_lid(model, X_test, X_test_noisy, X_test_adv, k, batch_size):
|
56 |
+
"""get_lid.
|
57 |
+
|
58 |
+
Parameters
|
59 |
+
----------
|
60 |
+
model :
|
61 |
+
model
|
62 |
+
X_test :
|
63 |
+
clean data
|
64 |
+
X_test_noisy :
|
65 |
+
noisy data
|
66 |
+
X_test_adv :
|
67 |
+
adversarial data
|
68 |
+
k :
|
69 |
+
k
|
70 |
+
batch_size :
|
71 |
+
batch_size
|
72 |
+
"""
|
73 |
+
# note: `K` (the Keras backend), `get_layer_wise_activations`, `mle_batch` and
# `merge_and_generate_labels` come from the original Keras implementation
# referenced above; they are not defined in this module yet.
X, X_noisy, X_adv = X_test, X_test_noisy, X_test_adv
funcs = [K.function([model.layers[0].input, K.learning_phase()], [out])
|
74 |
+
for out in get_layer_wise_activations(model, dataset)]
|
75 |
+
|
76 |
+
lid_dim = len(funcs)
|
77 |
+
print("Number of layers to estimate: ", lid_dim)
|
78 |
+
|
79 |
+
def estimate(i_batch):
|
80 |
+
|
81 |
+
start = i_batch * batch_size
|
82 |
+
end = np.minimum(len(X), (i_batch + 1) * batch_size)
|
83 |
+
n_feed = end - start
|
84 |
+
lid_batch = np.zeros(shape=(n_feed, lid_dim))
|
85 |
+
lid_batch_adv = np.zeros(shape=(n_feed, lid_dim))
|
86 |
+
lid_batch_noisy = np.zeros(shape=(n_feed, lid_dim))
|
87 |
+
|
88 |
+
for i, func in enumerate(funcs):
|
89 |
+
X_act = func([X[start:end], 0])[0]
|
90 |
+
X_act = np.asarray(X_act, dtype=np.float32).reshape((n_feed, -1))
|
91 |
+
# print("X_act: ", X_act.shape)
|
92 |
+
|
93 |
+
X_adv_act = func([X_adv[start:end], 0])[0]
|
94 |
+
X_adv_act = np.asarray(X_adv_act, dtype=np.float32).reshape((n_feed, -1))
|
95 |
+
# print("X_adv_act: ", X_adv_act.shape)
|
96 |
+
|
97 |
+
X_noisy_act = func([X_noisy[start:end], 0])[0]
|
98 |
+
X_noisy_act = np.asarray(X_noisy_act, dtype=np.float32).reshape((n_feed, -1))
|
99 |
+
# print("X_noisy_act: ", X_noisy_act.shape)
|
100 |
+
|
101 |
+
# random clean samples
|
102 |
+
# Maximum likelihood estimation of local intrinsic dimensionality (LID)
|
103 |
+
lid_batch[:, i] = mle_batch(X_act, X_act, k=k)
|
104 |
+
# print("lid_batch: ", lid_batch.shape)
|
105 |
+
lid_batch_adv[:, i] = mle_batch(X_act, X_adv_act, k=k)
|
106 |
+
# print("lid_batch_adv: ", lid_batch_adv.shape)
|
107 |
+
lid_batch_noisy[:, i] = mle_batch(X_act, X_noisy_act, k=k)
|
108 |
+
# print("lid_batch_noisy: ", lid_batch_noisy.shape)
|
109 |
+
|
110 |
+
return lid_batch, lid_batch_noisy, lid_batch_adv
|
111 |
+
|
112 |
+
lids = []
|
113 |
+
lids_adv = []
|
114 |
+
lids_noisy = []
|
115 |
+
n_batches = int(np.ceil(X.shape[0] / float(batch_size)))
|
116 |
+
|
117 |
+
for i_batch in tqdm(range(n_batches)):
|
118 |
+
|
119 |
+
lid_batch, lid_batch_noisy, lid_batch_adv = estimate(i_batch)
|
120 |
+
lids.extend(lid_batch)
|
121 |
+
lids_adv.extend(lid_batch_adv)
|
122 |
+
lids_noisy.extend(lid_batch_noisy)
|
123 |
+
# print("lids: ", lids.shape)
|
124 |
+
# print("lids_adv: ", lids_noisy.shape)
|
125 |
+
# print("lids_noisy: ", lids_noisy.shape)
|
126 |
+
|
127 |
+
lids_normal = np.asarray(lids, dtype=np.float32)
|
128 |
+
lids_noisy = np.asarray(lids_noisy, dtype=np.float32)
|
129 |
+
lids_adv = np.asarray(lids_adv, dtype=np.float32)
|
130 |
+
|
131 |
+
lids_pos = lids_adv
|
132 |
+
lids_neg = np.concatenate((lids_normal, lids_noisy))
|
133 |
+
artifacts, labels = merge_and_generate_labels(lids_pos, lids_neg)
|
134 |
+
|
135 |
+
return artifacts, labels
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
|
139 |
+
batch_size = 100
|
140 |
+
k_nearest = 20
|
141 |
+
|
142 |
+
#get LID characters
|
143 |
+
characters, labels = get_lid(model, X_test, X_test_noisy, X_test_adv, k_nearest, batch_size)
|
144 |
+
data = np.concatenate((characters, labels), axis = 1)
|
145 |
+
|
docs/graph/attack.rst
ADDED
@@ -0,0 +1,109 @@
+Introduction to Graph Attack with Examples
+==========================================
+In this section, we introduce the graph attack algorithms provided
+in DeepRobust. Specifically, they can be divided into two types:
+(1) targeted attack :class:`deeprobust.graph.targeted_attack` and
+(2) global attack :class:`deeprobust.graph.global_attack`.
+
+.. contents::
+    :local:
+
+
+Global (Untargeted) Attack for Node Classification
+--------------------------------------------------
+Global (untargeted) attack aims to fool GNNs into giving wrong predictions on all
+given nodes. Specifically, DeepRobust provides the following global
+attack algorithms:
+
+- :class:`deeprobust.graph.global_attack.Metattack`
+- :class:`deeprobust.graph.global_attack.MetaApprox`
+- :class:`deeprobust.graph.global_attack.DICE`
+- :class:`deeprobust.graph.global_attack.MinMax`
+- :class:`deeprobust.graph.global_attack.PGDAttack`
+- :class:`deeprobust.graph.global_attack.NIPA`
+- :class:`deeprobust.graph.global_attack.Random`
+- :class:`deeprobust.graph.global_attack.NodeEmbeddingAttack`
+- :class:`deeprobust.graph.global_attack.OtherNodeEmbeddingAttack`
+
+All the above attacks except `NodeEmbeddingAttack` and `OtherNodeEmbeddingAttack` (see details
+`here <https://deeprobust.readthedocs.io/en/latest/graph/node_embedding.html>`_ )
+take the adjacency matrix, node feature matrix and labels as input. Usually, the adjacency
+matrix is in the format of :obj:`scipy.sparse.csr_matrix` and the feature matrix can either be
+:obj:`scipy.sparse.csr_matrix` or :obj:`numpy.array`. The attack algorithm
+will then convert them into :obj:`torch.tensor` inside the class. It is also fine if you
+provide :obj:`torch.tensor` as input, since the algorithm handles that automatically.
+Now let's take a look at an example:
+
+.. code-block:: python
+
+    import numpy as np
+    from deeprobust.graph.data import Dataset
+    from deeprobust.graph.defense import GCN
+    from deeprobust.graph.global_attack import Metattack
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    idx_unlabeled = np.union1d(idx_val, idx_test)
+    # Setup surrogate model
+    surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
+                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
+    surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
+    # Setup attack model
+    model = Metattack(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,
+                      attack_structure=True, attack_features=False, device='cpu', lambda_=0).to('cpu')
+    # Attack
+    model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=10, ll_constraint=False)
+    modified_adj = model.modified_adj # modified_adj is a torch.tensor
+
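+To gauge the damage, one can retrain a victim model on the perturbed graph
+(a minimal sketch under the setup above; it is not part of the original example):
+
+.. code-block:: python
+
+    # evaluate the attack: train a fresh GCN on the attacked adjacency matrix
+    victim = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
+                 nhid=16, device='cpu').to('cpu')
+    victim.fit(features, modified_adj, labels, idx_train, idx_val, patience=30)
+    victim.test(idx_test)
+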
+Targeted Attack for Node Classification
+---------------------------------------
+Targeted attack aims to fool GNNs into giving wrong predictions on a
+subset of nodes. Specifically, DeepRobust provides the following targeted
+attack algorithms:
+
+- :class:`deeprobust.graph.targeted_attack.Nettack`
+- :class:`deeprobust.graph.targeted_attack.RLS2V`
+- :class:`deeprobust.graph.targeted_attack.FGA`
+- :class:`deeprobust.graph.targeted_attack.RND`
+- :class:`deeprobust.graph.targeted_attack.IGAttack`
+
+All the above attacks take the adjacency matrix, node feature matrix and labels as input.
+Usually, the adjacency matrix is in the format of :obj:`scipy.sparse.csr_matrix` and the feature
+matrix can either be :obj:`scipy.sparse.csr_matrix` or :obj:`numpy.array`. Now let's take a look at an example:
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset
+    from deeprobust.graph.defense import GCN
+    from deeprobust.graph.targeted_attack import Nettack
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # Setup surrogate model
+    surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
+                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
+    surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
+    # Setup attack model
+    target_node = 0
+    model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device='cpu').to('cpu')
+    # Attack
+    model.attack(features, adj, labels, target_node, n_perturbations=5)
+    modified_adj = model.modified_adj # scipy sparse matrix
+    modified_features = model.modified_features # scipy sparse matrix
+
|
95 |
+
|
96 |
+
Note that we also provide scripts in :download:`test_nettack.py <https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_nettack.py>`
|
97 |
+
for selecting nodes as reported in the
|
98 |
+
`nettack <https://arxiv.org/abs/1805.07984>`_ paper: (1) the 10 nodes
|
99 |
+
with highest margin of classification, i.e. they are clearly correctly classified,
|
100 |
+
(2) the 10 nodes with lowest margin (but still correctly classified) and
|
101 |
+
(3) 20 more nodes randomly.
|
102 |
+
|
103 |
+
|
104 |
+
More Examples
|
105 |
+
-----------------------
|
106 |
+
More examples can be found in :class:`deeprobust.graph.targeted_attack` and
|
107 |
+
:class:`deeprobust.graph.global_attack`. You can also find examples in
|
108 |
+
`github code examples <https://github.com/DSE-MSU/DeepRobust/tree/master/examples/graph>`_
|
109 |
+
and more details in `attacks table <https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph#attack-methods>`_.
|
docs/graph/data.rst
ADDED
@@ -0,0 +1,188 @@
+Graph Dataset
+=============
+
+We briefly introduce the dataset format of DeepRobust through self-contained examples.
+In essence, DeepRobust-Graph provides the following main features:
+
+.. contents::
+    :local:
+
+
+Clean (Unattacked) Graphs for Node Classification
+-------------------------------------------------
+Graphs are ubiquitous data structures describing pairwise relations between entities.
+A single clean graph in DeepRobust is described by an instance of :class:`deeprobust.graph.data.Dataset`, which holds the following attributes by default:
+
+- :obj:`data.adj`: Graph adjacency matrix in scipy.sparse.csr_matrix format with shape :obj:`[num_nodes, num_nodes]`
+- :obj:`data.features`: Node feature matrix with shape :obj:`[num_nodes, num_node_features]`
+- :obj:`data.labels`: Target to train against (may have arbitrary shape), *e.g.*, node-level targets of shape :obj:`[num_nodes, *]`
+- :obj:`data.train_idx`: Array of training node indices
+- :obj:`data.val_idx`: Array of validation node indices
+- :obj:`data.test_idx`: Array of test node indices
+
+By default, the loaded :obj:`deeprobust.graph.data.Dataset` will select the largest connected
+component of the graph, but users can specify different settings by passing different parameters.
+
+Currently DeepRobust supports the following datasets:
+:obj:`Cora`,
+:obj:`Cora-ML`,
+:obj:`Citeseer`,
+:obj:`Pubmed`,
+:obj:`Polblogs`,
+:obj:`ACM`,
+:obj:`BlogCatalog`,
+:obj:`Flickr`,
+:obj:`UAI`.
+More details about the datasets can be found `here <https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph#supported-datasets>`_.
+
+
+By default, the data splits are generated by :obj:`deeprobust.graph.utils.get_train_val_test`,
+which randomly splits the data into 10%/10%/80% for training/validation/test. You can also generate
+splits yourself by using :obj:`deeprobust.graph.utils.get_train_val_test` or :obj:`deeprobust.graph.utils.get_train_val_test_gcn`.
+It is worth noting that there is a parameter :obj:`setting` that can be passed into this class. It can be chosen from `["nettack", "gcn", "prognn"]`:
+
+- :obj:`setting="nettack"`: the data splits are 10%/10%/80% and the largest connected component of the graph is used;
+- :obj:`setting="gcn"`: use the full graph and the data splits will be: 20 nodes per class for training, 500 nodes for validation and 1000 nodes for testing (randomly chosen);
+- :obj:`setting="prognn"`: use the largest connected component and the data splits provided by `ProGNN <https://github.com/ChandlerBang/Pro-GNN>`_ (10%/10%/80%);
+
+
+.. note::
+    The 'nettack' and 'gcn' settings do not provide a fixed split, i.e.,
+    different random seeds would return different data splits.
+
+.. note::
+    If you hope to use the full graph, please use the 'gcn' setting.
+
+The following example shows how to load DeepRobust datasets:
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset
+    from deeprobust.graph.utils import get_train_val_test
+    # loading cora dataset
+    data = Dataset(root='/tmp/', name='cora', seed=15)
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # you can also split the data by yourself
+    idx_train, idx_val, idx_test = get_train_val_test(adj.shape[0], val_size=0.1, test_size=0.8)
+
+    # loading acm dataset
+    data = Dataset(root='/tmp/', name='acm', seed=15)
+
+
+DeepRobust also provides access to Amazon and Coauthor datasets loaded from PyTorch Geometric:
+:obj:`Amazon-Computers`,
+:obj:`Amazon-Photo`,
+:obj:`Coauthor-CS`,
+:obj:`Coauthor-Physics`.
+
+Users can also easily create their own datasets by creating a class with the following attributes: :obj:`data.adj`, :obj:`data.features`, :obj:`data.labels`, :obj:`data.train_idx`, :obj:`data.val_idx`, :obj:`data.test_idx`, as sketched below.
+
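+A minimal sketch of such a custom dataset class (illustrative only; the class
+name and the toy graph are made up):
+
+.. code-block:: python
+
+    import numpy as np
+    import scipy.sparse as sp
+
+    class MyDataset:
+        def __init__(self):
+            # toy graph: 4 nodes forming a single 4-cycle
+            self.adj = sp.csr_matrix(np.array([[0, 1, 0, 1],
+                                               [1, 0, 1, 0],
+                                               [0, 1, 0, 1],
+                                               [1, 0, 1, 0]]))
+            self.features = sp.csr_matrix(np.eye(4))  # one-hot node features
+            self.labels = np.array([0, 1, 0, 1])
+            self.train_idx = np.array([0, 1])
+            self.val_idx = np.array([2])
+            self.test_idx = np.array([3])
+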
+Attacked Graphs for Node Classification
+---------------------------------------
+DeepRobust provides the attacked graphs perturbed by `metattack <https://openreview.net/pdf?id=Bylnx209YX>`_ and `nettack <https://arxiv.org/abs/1805.07984>`_. The graphs are attacked using the authors' Tensorflow implementation, on a random split using seed 15. The download link can be found in the `ProGNN code <https://github.com/ChandlerBang/Pro-GNN/tree/master/splits>`_ and the performance of various GNNs can be found in the `ProGNN paper <https://arxiv.org/abs/2005.10203>`_. They are instances of :class:`deeprobust.graph.data.PrePtbDataset` with only one attribute :obj:`adj`. Hence, :class:`deeprobust.graph.data.PrePtbDataset` is often used together with :class:`deeprobust.graph.data.Dataset` to obtain node features and labels.
+
+For metattack, DeepRobust provides attacked graphs for Cora, Citeseer, Polblogs and Pubmed,
+and the perturbation rate can be chosen from [0.05, 0.1, 0.15, 0.2, 0.25].
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, PrePtbDataset
+    # You can either use setting='prognn' or seed=15 to get the prognn splits
+    # data = Dataset(root='/tmp/', name='cora', seed=15) # since the attacked graphs are generated under seed 15
+    data = Dataset(root='/tmp/', name='cora', setting='prognn')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # Load meta attacked data
+    perturbed_data = PrePtbDataset(root='/tmp/',
+                                   name='cora',
+                                   attack_method='meta',
+                                   ptb_rate=0.05)
+    perturbed_adj = perturbed_data.adj
+
+For nettack, DeepRobust provides attacked graphs for Cora, Citeseer, Polblogs and Pubmed,
+and ptb_rate indicates the number of perturbations made on each node.
+It can be chosen from [1.0, 2.0, 3.0, 4.0, 5.0].
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, PrePtbDataset
+    # data = Dataset(root='/tmp/', name='cora', seed=15)
+    data = Dataset(root='/tmp/', name='cora', setting='prognn')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    # Load nettack attacked data
+    perturbed_data = PrePtbDataset(root='/tmp/', name='cora',
+                                   attack_method='nettack',
+                                   ptb_rate=3.0) # here ptb_rate means the number of perturbations per node
+    perturbed_adj = perturbed_data.adj
+    idx_test = perturbed_data.target_nodes
+
+
+
+Converting Graph Data between DeepRobust and PyTorch Geometric
+--------------------------------------------------------------
+Given the popularity of PyTorch Geometric in the graph representation learning community,
+we also provide tools for converting data between DeepRobust and PyTorch Geometric. We can
+use :class:`deeprobust.graph.data.Dpr2Pyg` to convert DeepRobust data to PyTorch Geometric
+and use :class:`deeprobust.graph.data.Pyg2Dpr` to convert PyTorch Geometric data to DeepRobust.
+For example, we can first create an instance of the Dataset class and convert it to the PyTorch Geometric data format.
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, Dpr2Pyg, Pyg2Dpr
+    data = Dataset(root='/tmp/', name='cora') # load clean graph
+    pyg_data = Dpr2Pyg(data) # convert dpr to pyg
+    print(pyg_data)
+    print(pyg_data[0])
+    dpr_data = Pyg2Dpr(pyg_data) # convert pyg to dpr
+    print(dpr_data.adj)
+
+
+Load OGB Datasets
+-----------------
+`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/>`_ has provided various benchmark
+datasets. DeepRobust now provides an interface to convert the OGB dataset format (Pyg data format)
+to the DeepRobust format.
+
+.. code-block:: python
+
+    from ogb.nodeproppred import PygNodePropPredDataset
+    from deeprobust.graph.data import Pyg2Dpr
+    pyg_data = PygNodePropPredDataset(name = 'ogbn-arxiv')
+    dpr_data = Pyg2Dpr(pyg_data) # convert pyg to dpr
+
+
+Load PyTorch Geometric Amazon and Coauthor Datasets
+---------------------------------------------------
+DeepRobust also provides access to the Amazon datasets and Coauthor datasets, i.e.,
+`Amazon-Computers`, `Amazon-Photo`, `Coauthor-CS`, `Coauthor-Physics`, from PyTorch
+Geometric. Specifically, users can access them through
+:class:`deeprobust.graph.data.AmazonPyg` and :class:`deeprobust.graph.data.CoauthorPyg`.
+For example, we can directly load the Amazon dataset from deeprobust in the pyg format
+as follows,
+
+.. code-block:: python
+
+    from deeprobust.graph.data import AmazonPyg
+    computers = AmazonPyg(root='/tmp', name='computers')
+    print(computers)
+    print(computers[0])
+    photo = AmazonPyg(root='/tmp', name='photo')
+    print(photo)
+    print(photo[0])
+
+
+Similarly, we can also load the Coauthor datasets,
+
+.. code-block:: python
+
+    from deeprobust.graph.data import CoauthorPyg
+    cs = CoauthorPyg(root='/tmp', name='cs')
+    print(cs)
+    print(cs[0])
+    physics = CoauthorPyg(root='/tmp', name='physics')
+    print(physics)
+    print(physics[0])
+
+
docs/graph/pyg.rst
ADDED
@@ -0,0 +1,155 @@
+Using PyTorch Geometric in DeepRobust
+=====================================
+DeepRobust now provides an interface to convert data between
+PyTorch Geometric and DeepRobust.
+
+.. note::
+    Before we start, make sure you have successfully installed `torch_geometric
+    <https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html>`_.
+    After you install torch_geometric, please reinstall DeepRobust to activate
+    the following functions.
+
+.. contents::
+    :local:
+
+Converting Graph Data between DeepRobust and PyTorch Geometric
+--------------------------------------------------------------
+Given the popularity of PyTorch Geometric in the graph representation learning community,
+we also provide tools for converting data between DeepRobust and PyTorch Geometric. We can
+use :class:`deeprobust.graph.data.Dpr2Pyg` to convert DeepRobust data to PyTorch Geometric
+and use :class:`deeprobust.graph.data.Pyg2Dpr` to convert PyTorch Geometric data to DeepRobust.
+For example, we can first create an instance of the Dataset class and convert it to the PyTorch Geometric data format.
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, Dpr2Pyg, Pyg2Dpr
+    data = Dataset(root='/tmp/', name='cora') # load clean graph
+    pyg_data = Dpr2Pyg(data) # convert dpr to pyg
+    print(pyg_data)
+    print(pyg_data[0])
+    dpr_data = Pyg2Dpr(pyg_data) # convert pyg to dpr
+    print(dpr_data.adj)
+
+The attacked graph class :class:`deeprobust.graph.PrePtbDataset` only has the attribute :obj:`adj`.
+To convert it to the PyTorch Geometric data format, we can first convert the clean graph to Pyg and
+then update its :obj:`edge_index`:
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, PrePtbDataset, Dpr2Pyg
+    data = Dataset(root='/tmp/', name='cora') # load clean graph
+    pyg_data = Dpr2Pyg(data) # convert dpr to pyg
+    # load perturbed graph
+    perturbed_data = PrePtbDataset(root='/tmp/',
+                                   name='cora',
+                                   attack_method='meta',
+                                   ptb_rate=0.05)
+    perturbed_adj = perturbed_data.adj
+    pyg_data.update_edge_index(perturbed_adj) # inplace operation
+
+Now :obj:`pyg_data` becomes the perturbed data in the format of PyTorch Geometric.
+We can then use it as the input for various PyTorch Geometric models!
+
+Load OGB Datasets
+-----------------
+`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/>`_ has provided various benchmark
+datasets. DeepRobust now provides an interface to convert the OGB dataset format (Pyg data format)
+to the DeepRobust format.
+
+.. code-block:: python
+
+    from ogb.nodeproppred import PygNodePropPredDataset
+    from deeprobust.graph.data import Pyg2Dpr
+    pyg_data = PygNodePropPredDataset(name = 'ogbn-arxiv')
+    dpr_data = Pyg2Dpr(pyg_data) # convert pyg to dpr
+
+
+Load PyTorch Geometric Amazon and Coauthor Datasets
+---------------------------------------------------
+DeepRobust also provides access to the Amazon datasets and Coauthor datasets, i.e.,
+`Amazon-Computers`, `Amazon-Photo`, `Coauthor-CS`, `Coauthor-Physics`, from PyTorch
+Geometric. Specifically, users can access them through
+:class:`deeprobust.graph.data.AmazonPyg` and :class:`deeprobust.graph.data.CoauthorPyg`.
+For example, we can directly load the Amazon dataset from deeprobust in the pyg format
+as follows,
+
+.. code-block:: python
+
+    from deeprobust.graph.data import AmazonPyg
+    computers = AmazonPyg(root='/tmp', name='computers')
+    print(computers)
+    print(computers[0])
+    photo = AmazonPyg(root='/tmp', name='photo')
+    print(photo)
+    print(photo[0])
+
+
+Similarly, we can also load the Coauthor datasets,
+
+.. code-block:: python
+
+    from deeprobust.graph.data import CoauthorPyg
+    cs = CoauthorPyg(root='/tmp', name='cs')
+    print(cs)
+    print(cs[0])
+    physics = CoauthorPyg(root='/tmp', name='physics')
+    print(physics)
+    print(physics[0])
+
+
+Working on PyTorch Geometric Models
+-----------------------------------
+In this subsection, we provide examples for using GNNs based on
+PyTorch Geometric. Specifically, we use GAT :class:`deeprobust.graph.defense.GAT` and
+ChebNet :class:`deeprobust.graph.defense.ChebNet` to further illustrate (:class:`deeprobust.graph.defense.SGC` is also available in this library).
+Basically, we can first convert the DeepRobust data to PyTorch Geometric
+data and then train Pyg models.
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, Dpr2Pyg, PrePtbDataset
+    from deeprobust.graph.defense import GAT
+    data = Dataset(root='/tmp/', name='cora', seed=15)
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    gat = GAT(nfeat=features.shape[1],
+              nhid=8, heads=8,
+              nclass=labels.max().item() + 1,
+              dropout=0.5, device='cpu')
+    gat = gat.to('cpu')
+    pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
+    gat.fit(pyg_data, patience=100, verbose=True) # train with earlystopping
+    gat.test() # test performance on clean graph
+
+    # load perturbed graph
+    perturbed_data = PrePtbDataset(root='/tmp/',
+                                   name='cora',
+                                   attack_method='meta',
+                                   ptb_rate=0.05)
+    perturbed_adj = perturbed_data.adj
+    pyg_data.update_edge_index(perturbed_adj) # inplace operation
+    gat.fit(pyg_data, patience=100, verbose=True) # train with earlystopping
+    gat.test() # test performance on perturbed graph
+
+
+.. code-block:: python
+
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    from deeprobust.graph.defense import ChebNet
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    cheby = ChebNet(nfeat=features.shape[1],
+                    nhid=16, num_hops=3,
+                    nclass=labels.max().item() + 1,
+                    dropout=0.5, device='cpu')
+    cheby = cheby.to('cpu')
+    pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
+    cheby.fit(pyg_data, patience=10, verbose=True) # train with earlystopping
+    cheby.test()
+
+
+More Details
+------------
+More details can be found in
+`test_gat.py <https://github.com/DSE-MSU/DeepRobust/tree/master/examples/graph/test_gat.py>`_, `test_chebnet.py <https://github.com/DSE-MSU/DeepRobust/tree/master/examples/graph/test_chebnet.py>`_ and `test_sgc.py <https://github.com/DSE-MSU/DeepRobust/tree/master/examples/graph/test_sgc.py>`_.
docs/image/example.rst
ADDED
@@ -0,0 +1,58 @@
+
+Image Attack and Defense
+========================
+We introduce the usage of the attack and defense APIs in the image package.
+
+.. contents::
+    :local:
+
+
+Attack Example
+--------------
+
+.. code-block:: python
+
+    import torch
+    from torchvision import datasets, transforms
+    from deeprobust.image.attack.pgd import PGD
+    from deeprobust.image.config import attack_params
+    from deeprobust.image.utils import download_model
+    import deeprobust.image.netmodels.resnet as resnet
+
+    URL = "https://github.com/I-am-Bot/deeprobust_model/raw/master/CIFAR10_ResNet18_epoch_50.pt"
+    download_model(URL, "$MODEL_PATH$")
+
+    model = resnet.ResNet18().to('cuda')
+    model.load_state_dict(torch.load("$MODEL_PATH$"))
+    model.eval()
+
+    transform_val = transforms.Compose([transforms.ToTensor()])
+    test_loader = torch.utils.data.DataLoader(
+                    datasets.CIFAR10('deeprobust/image/data', train = False, download=True,
+                    transform = transform_val),
+                    batch_size = 10, shuffle=True)
+
+    x, y = next(iter(test_loader))
+    x = x.to('cuda').float()
+
+    adversary = PGD(model, device='cuda')
+    Adv_img = adversary.generate(x, y, **attack_params['PGD_CIFAR10'])
+
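+A quick way to check whether the attack succeeded (a minimal sketch, not part
+of the original example) is to compare the model's predictions on the clean
+and the adversarial batch:
+
+.. code-block:: python
+
+    y = y.to('cuda')
+    pred_clean = model(x).argmax(dim=1)
+    pred_adv = model(Adv_img).argmax(dim=1)
+    print('clean accuracy: ', (pred_clean == y).float().mean().item())
+    print('adversarial accuracy: ', (pred_adv == y).float().mean().item())
+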
+Defense Example
+---------------
+
+.. code-block:: python
+
+    import torch
+    from torchvision import datasets, transforms
+    from deeprobust.image.defense.pgdtraining import PGDtraining
+    from deeprobust.image.config import defense_params
+    from deeprobust.image.netmodels.CNN import Net
+
+    model = Net()
+    train_loader = torch.utils.data.DataLoader(
+                    datasets.MNIST('deeprobust/image/defense/data', train=True, download=True,
+                                    transform=transforms.Compose([transforms.ToTensor()])),
+                    batch_size=100, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(
+                    datasets.MNIST('deeprobust/image/defense/data', train=False,
+                                    transform=transforms.Compose([transforms.ToTensor()])),
+                    batch_size=1000, shuffle=True)
+
+    defense = PGDtraining(model, 'cuda')
+    defense.generate(train_loader, test_loader, **defense_params["PGDtraining_MNIST"])
+
+
docs/notes/installation.rst
ADDED
@@ -0,0 +1,23 @@
+Installation
+============
+#. Activate your virtual environment
+
+#. Install the package
+
+Install the newest deeprobust:
+
+.. code-block:: none
+
+    git clone https://github.com/DSE-MSU/DeepRobust.git
+    cd DeepRobust
+    python setup.py install
+
+Or install via pip (this may not contain all the new features):
+
+.. code-block:: none
+
+    pip install deeprobust
+
+
.. note::
|
22 |
+
If you meet any installation problem, feel free to open an issue
|
23 |
+
in the our github `page <https://github.com/DSE-MSU/DeepRobust/issues>`_
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+matplotlib==3.6.0
+numpy==1.23.5
+ipdb==0.13.13
+torch==2.0.1
+scipy==1.11.3
+torchvision==0.15.2
+texttable==1.6.7
+networkx==3.0
+numba==0.57.1
+Pillow==9.4.0
+scikit-learn==1.2.0
+tensorboardX==2.6
+tqdm==4.64.1
+gensim==4.3.0