Add files using upload-large-folder tool
- LICENSE +21 -0
- adversary_examples/cifar_advexample_orig.png +0 -0
- deeprobust/graph/README.md +76 -0
- deeprobust/graph/__init__.py +1 -0
- deeprobust/graph/data/__init__.py +16 -0
- deeprobust/graph/data/dataset.py +333 -0
- deeprobust/graph/data/pyg_dataset.py +308 -0
- deeprobust/graph/data/utils.py +10 -0
- deeprobust/graph/defense/__init__.py +23 -0
- deeprobust/graph/defense/pgd.py +207 -0
- deeprobust/graph/defense/simpgcn.py +474 -0
- deeprobust/graph/defense_pyg/gat.py +100 -0
- deeprobust/graph/defense_pyg/gcn.py +110 -0
- deeprobust/graph/global_attack/base_attack.py +130 -0
- deeprobust/graph/global_attack/node_embedding_attack.py +522 -0
- deeprobust/graph/global_attack/prbcd.py +440 -0
- deeprobust/graph/rl/nipa_env.py +169 -0
- deeprobust/graph/rl/rl_s2v_config.py +57 -0
- deeprobust/graph/targeted_attack/__init__.py +9 -0
- deeprobust/graph/targeted_attack/base_attack.py +126 -0
- deeprobust/graph/targeted_attack/fga.py +124 -0
- deeprobust/graph/targeted_attack/ig_attack.py +224 -0
- deeprobust/graph/targeted_attack/nettack.py +624 -0
- deeprobust/graph/targeted_attack/rnd.py +139 -0
- deeprobust/graph/targeted_attack/sga.py +323 -0
- deeprobust/graph/targeted_attack/ugba.py +913 -0
- deeprobust/graph/utils.py +778 -0
- deeprobust/image/README.md +45 -0
- deeprobust/image/__init__.py +11 -0
- deeprobust/image/attack/Nattack.py +181 -0
- deeprobust/image/attack/fgsm.py +121 -0
- deeprobust/image/attack/onepixel.py +186 -0
- deeprobust/image/defense/AWP.py +301 -0
- deeprobust/image/defense/TherEncoding.py +203 -0
- deeprobust/image/defense/YOPO.py +410 -0
- deeprobust/image/defense/__init__.py +6 -0
- deeprobust/image/defense/base_defense.py +100 -0
- deeprobust/image/defense/fast.py +169 -0
- deeprobust/image/defense/fgsmtraining.py +227 -0
- deeprobust/image/defense/pgdtraining.py +229 -0
- deeprobust/image/defense/trades.py +241 -0
- deeprobust/image/optimizer.py +914 -0
- deeprobust/image/preprocessing/APE-GAN.py +127 -0
- deeprobust/image/preprocessing/prepare_advdata.py +62 -0
- deeprobust/image/utils.py +211 -0
- docs/graph/defense.rst +109 -0
- docs/graph/node_embedding.rst +110 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Yaxin Li, Wei Jin, Han Xu and Jiliang Tang.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
adversary_examples/cifar_advexample_orig.png
ADDED
(binary image file)
deeprobust/graph/README.md
ADDED
@@ -0,0 +1,76 @@
# Setup
```
git clone https://github.com/DSE-MSU/DeepRobust.git
cd DeepRobust
python setup.py install
```
# Test Examples
Test GCN on a perturbed graph (5% metattack):
```
python examples/graph/test_gcn.py --dataset cora
```
Test GCN-Jaccard on a perturbed graph (5% metattack):
```
python examples/graph/test_gcn_jaccard.py --dataset cora
```
Generate an attack yourself:
```
python examples/graph/test_mettack.py --dataset cora --ptb_rate 0.05
```
For a practical example of using the deeprobust graph package, you can also refer to https://github.com/ChandlerBang/Pro-GNN.

# Full README
[click here](https://github.com/DSE-MSU/DeepRobust)

# Supported Datasets
* Cora
* Cora-ML
* Citeseer
* Pubmed
* Polblogs
* ACM: [link1](https://github.com/zhumeiqiBUPT/AM-GCN) [link2](https://github.com/Jhy1993/HAN)
* BlogCatalog: [link](https://github.com/mengzaiqiao/CAN)
* Flickr: [link](https://github.com/mengzaiqiao/CAN)
* UAI: A Unified Weakly Supervised Framework for Community Detection and Semantic Matching.
* PyTorch Geometric datasets: Amazon-Computers, Amazon-Photo, Coauthor-CS, Coauthor-Physics...

For more details, please take a look at [dataset.py](https://github.com/DSE-MSU/DeepRobust/blob/master/deeprobust/graph/data/dataset.py)

# Attack Methods
| Attack Methods | Type | Perturbation | Evasion/<br>Poisoning | Apply Domain | Paper | Code |
|--------------------|------|--------------------|-------------|-------|----|----|
| Nettack | Targeted Attack | Structure<br>Features | Both | Node Classification | [Adversarial Attacks on Neural Networks for Graph Data](https://arxiv.org/pdf/1805.07984.pdf)| [test_nettack.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_nettack.py) |
| FGA | Targeted Attack | Structure | Both | Node Classification | [Fast Gradient Attack on Network Embedding](https://arxiv.org/pdf/1809.02797.pdf)| [test_fga.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_fga.py) |
| Metattack | Global Attack | Structure<br>Features | Poisoning | Node Classification | [Adversarial Attacks on Graph Neural Networks via Meta Learning](https://openreview.net/pdf?id=Bylnx209YX) | [test_mettack.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_mettack.py) |
| RL-S2V | Targeted Attack | Structure | Evasion | Node Classification | [Adversarial Attack on Graph Structured Data](https://arxiv.org/pdf/1806.02371.pdf) |[test_rl_s2v.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_rl_s2v.py) |
| Node Embedding Attack | Global Attack | Structure | Poisoning | Node Embedding | [Adversarial Attacks on Node Embeddings via Graph Poisoning](https://arxiv.org/abs/1809.01093) | [test_node_embedding_attack.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_node_embedding_attack.py) |
| Baselines for Node Embedding Attack <br> Degree, eigencentrality and random | Global Attack | Structure | Poisoning | Node Embedding | [Adversarial Attacks on Node Embeddings via Graph Poisoning](https://arxiv.org/abs/1809.01093) | [test_node_embedding_attack.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_node_embedding_attack.py) |
| PGD, Min-max | Global Attack | Structure | Both | Node Classification | [Topology Attack and Defense for Graph Neural Networks: An Optimization Perspective](https://arxiv.org/pdf/1906.04214.pdf)|[test_pgd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_pgd.py) [test_min_max.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_min_max.py) |
| DICE | Global Attack | Structure | Both | Node Classification | [Hiding individuals and communities in a social network](https://arxiv.org/abs/1608.00375)|[test_dice.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_dice.py) |
| IG-Attack | Targeted Attack | Structure<br>Features| Both | Node Classification | [Adversarial Examples on Graph Data: Deep Insights into Attack and Defense](https://arxiv.org/pdf/1903.01610.pdf)|[test_ig.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_ig.py) |
| NIPA | Global Attack | Structure | Poisoning | Node Classification | [Non-target-specific Node Injection Attacks on Graph Neural Networks: A Hierarchical Reinforcement Learning Approach](https://faculty.ist.psu.edu/vhonavar/Papers/www20.pdf) | [test_nipa.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_nipa.py) |
| RND | Targeted Attack<br>Global Attack | Structure<br>Features<br>Adding Nodes | Both | Node Classification | |[test_rnd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_rnd.py) |
| SGAttack | Targeted Attack | Structure | Poisoning | Node Classification | [Adversarial Attack on Large Scale Graph](https://arxiv.org/abs/2009.03488)| [test_sga.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_sga.py) |

# Defense Methods
| Defense Methods | Defense Type | Apply Domain | Paper | Code |
|---------------------|--------------|--------------|------| ------|
| GCN | Victim Model | Node Classification | [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) | [test_gcn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_gcn.py) |
| ChebNet | Victim Model | Node Classification | [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375) | [test_chebnet.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_chebnet.py) |
| SGC | Victim Model | Node Classification | [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153) | [test_sgc.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_sgc.py) |
| GAT | Adaptive Aggregation | Node Classification | [Graph Attention Networks](https://arxiv.org/abs/1710.10903) | [test_gat.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_gat.py) |
| DeepWalk | Victim Model | Node Embedding | [DeepWalk: Online Learning of Social Representations](https://arxiv.org/abs/1403.6652) | [test_deepwalk.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_deepwalk.py) |
| Node2Vec | Victim Model | Node Embedding | [node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653) | [test_deepwalk.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_deepwalk.py) |
| RGCN | Adaptive Aggregation | Node Classification | [Robust Graph Convolutional Networks Against Adversarial Attacks](http://pengcui.thumedialab.com/papers/RGCN.pdf) | [test_rgcn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_rgcn.py) |
| GCN-Jaccard | Graph Purifying | Node Classification | [Adversarial Examples on Graph Data: Deep Insights into Attack and Defense](https://arxiv.org/pdf/1903.01610.pdf)| [test_gcn_jaccard.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_gcn_jaccard.py) |
| GCN-SVD | Graph Purifying | Node Classification | [All You Need is Low (Rank): Defending Against Adversarial Attacks on Graphs](https://dl.acm.org/doi/pdf/10.1145/3336191.3371789?download=true) | [test_gcn_svd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_gcn_svd.py) |
| Adv-training | Adversarial Training | Node Classification | |[test_adv_train_poisoning.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_adv_train_poisoning.py) |
| Pro-GNN | Graph Purifying | Node Classification | [Graph Structure Learning for Robust Graph Neural Network](https://arxiv.org/abs/2005.10203)|[test_prognn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prognn.py) |
| SimP-GCN | Adaptive Aggregation | Node Classification | [Node Similarity Preserving Graph Convolutional Networks](https://arxiv.org/abs/2011.09643)|[test_simpgcn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_simpgcn.py) |
| MedianGCN | Adaptive Aggregation | Node Classification | [Understanding Structural Vulnerability in Graph Convolutional Networks](https://arxiv.org/abs/2108.06280)|[test_median_gcn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_median_gcn.py) |
<!--| Adv-training | Adversarial Training | Node Classification | [Topology Attack and Defense for Graph Neural Networks: An Optimization Perspective](https://arxiv.org/pdf/1906.04214.pdf)|
-->
<!--| Hidden-Adv-training | Adversarial Training | Node Classification<br>Graph Classification |[To be added]|
-->
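The tables above map each method to a runnable test script. For a quick end-to-end check from Python instead of the scripts, here is a minimal sketch (assuming the install above succeeded and `/tmp/` is writable; the hyper-parameters are illustrative, not tuned):

```python
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN

# load clean Cora and its train/val/test split
data = Dataset(root='/tmp/', name='cora', seed=15)
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# train a GCN victim model and report test accuracy
model = GCN(nfeat=features.shape[1], nhid=16,
            nclass=labels.max().item() + 1, device='cpu')
model.fit(features, adj, labels, idx_train, idx_val, train_iters=200)
model.test(idx_test)
```

To evaluate a defense, swap `GCN` for one of the models in the table above (e.g., `GCNJaccard`), or feed a perturbed adjacency matrix in place of `adj`.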
deeprobust/graph/__init__.py
ADDED
@@ -0,0 +1 @@
deeprobust/graph/data/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .dataset import Dataset
from .attacked_data import PtbDataset
from .attacked_data import PrePtbDataset
import warnings
try:
    from .pyg_dataset import Pyg2Dpr, Dpr2Pyg, AmazonPyg, CoauthorPyg
except ImportError as e:
    print(e)
    warnings.warn("Please install pytorch geometric if you " +
                  "would like to use the datasets from pytorch " +
                  "geometric. See details in https://pytorch-geom" +
                  "etric.readthedocs.io/en/latest/notes/installation.html")

__all__ = ['Dataset', 'PtbDataset', 'PrePtbDataset',
           'Pyg2Dpr', 'Dpr2Pyg', 'AmazonPyg', 'CoauthorPyg']
deeprobust/graph/data/dataset.py
ADDED
@@ -0,0 +1,333 @@
import numpy as np
import scipy.sparse as sp
import os.path as osp
import os
import urllib.request
import sys
import pickle as pkl
import networkx as nx
from deeprobust.graph.utils import get_train_val_test, get_train_val_test_gcn
import zipfile
import json
import platform

class Dataset():
    """Dataset class contains four citation network datasets "cora", "cora-ml", "citeseer" and "pubmed",
    and one blog dataset "polblogs". Datasets "acm", "blogcatalog", "flickr" and "uai" are also
    available. See more details in https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph#supported-datasets.
    The 'cora', 'cora-ml', 'polblogs' and 'citeseer' are downloaded from https://github.com/danielzuegner/gnn-meta-attack/tree/master/data, and 'pubmed' is from https://github.com/tkipf/gcn/tree/master/gcn/data.

    Parameters
    ----------
    root : string
        root directory where the dataset should be saved.
    name : string
        dataset name, it can be chosen from ['cora', 'citeseer', 'cora_ml', 'polblogs',
        'pubmed', 'acm', 'blogcatalog', 'uai', 'flickr']
    setting : string
        there are three data split settings; it can be chosen from ['nettack', 'gcn', 'prognn'].
        The 'nettack' setting follows the nettack paper, which selects the largest connected
        component of the graph and uses 10%/10%/80% of the nodes for training/validation/test.
        The 'gcn' setting follows the gcn paper, which uses the full graph and 20 samples
        per class for training, 500 nodes for validation, and 1000
        nodes for test. (Note that the 'nettack' and 'gcn' settings do not provide fixed splits,
        i.e., different random seeds return different data splits.)
    seed : int
        random seed for splitting training/validation/test.
    require_mask : bool
        set require_mask to True to get the training, validation and test masks
        (self.train_mask, self.val_mask, self.test_mask)

    Examples
    --------
    We can first create an instance of the Dataset class and then take out its attributes.

    >>> from deeprobust.graph.data import Dataset
    >>> data = Dataset(root='/tmp/', name='cora', seed=15)
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    """

    def __init__(self, root, name, setting='nettack', seed=None, require_mask=False):
        self.name = name.lower()
        self.setting = setting.lower()

        assert self.name in ['cora', 'citeseer', 'cora_ml', 'polblogs',
                             'pubmed', 'acm', 'blogcatalog', 'uai', 'flickr'], \
            'Currently only support cora, citeseer, cora_ml, ' + \
            'polblogs, pubmed, acm, blogcatalog, uai, flickr'
        assert self.setting in ['gcn', 'nettack', 'prognn'], "Settings should be " + \
            "chosen from ['gcn', 'nettack', 'prognn']"

        self.seed = seed
        # self.url = 'https://raw.githubusercontent.com/danielzuegner/nettack/master/data/%s.npz' % self.name
        self.url = 'https://raw.githubusercontent.com/danielzuegner/gnn-meta-attack/master/data/%s.npz' % self.name

        if platform.system() == 'Windows':
            self.root = root
        else:
            self.root = osp.expanduser(osp.normpath(root))

        self.data_folder = osp.join(root, self.name)
        self.data_filename = self.data_folder + '.npz'
        self.require_mask = require_mask

        self.require_lcc = False if setting == 'gcn' else True
        self.adj, self.features, self.labels = self.load_data()

        if setting == 'prognn':
            assert name in ['cora', 'citeseer', 'pubmed', 'cora_ml', 'polblogs', 'Flickr'], \
                "ProGNN splits only support cora, citeseer, pubmed, cora_ml, polblogs, Flickr"
            self.idx_train, self.idx_val, self.idx_test = self.get_prognn_splits()
        else:
            self.idx_train, self.idx_val, self.idx_test = self.get_train_val_test()
        if self.require_mask:
            self.get_mask()

    def get_train_val_test(self):
        """Get training, validation, test splits according to self.setting (either 'nettack' or 'gcn')."""
        if self.setting == 'nettack':
            return get_train_val_test(nnodes=self.adj.shape[0], val_size=0.1, test_size=0.8, stratify=self.labels, seed=self.seed)
        if self.setting == 'gcn':
            return get_train_val_test_gcn(self.labels, seed=self.seed)

    def get_prognn_splits(self):
        """Get the fixed training, validation and test splits used in the Pro-GNN paper."""
        url = 'https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/' + \
              'master/splits/{}_prognn_splits.json'.format(self.name)
        json_file = osp.join(self.root,
                             '{}_prognn_splits.json'.format(self.name))

        if not osp.exists(json_file):
            self.download_file(url, json_file)
        with open(json_file, 'r') as f:
            idx = json.loads(f.read())
        return np.array(idx['idx_train']), \
               np.array(idx['idx_val']), np.array(idx['idx_test'])

    def load_data(self):
        print('Loading {} dataset...'.format(self.name))
        if self.name == 'pubmed':
            return self.load_pubmed()

        if self.name in ['acm', 'blogcatalog', 'uai', 'flickr']:
            return self.load_zip()

        if not osp.exists(self.data_filename):
            self.download_npz()

        adj, features, labels = self.get_adj()
        return adj, features, labels

    def download_file(self, url, file):
        print('Downloading from {} to {}'.format(url, file))
        try:
            urllib.request.urlretrieve(url, file)
        except Exception:
            raise Exception("Download failed! Make sure you have a "
                            "stable Internet connection and enter the right name")

    def download_npz(self):
        """Download the adjacency matrix npz file from self.url."""
        print('Downloading from {} to {}'.format(self.url, self.data_filename))
        try:
            urllib.request.urlretrieve(self.url, self.data_filename)
            print('Done!')
        except Exception:
            raise Exception('Download failed! Make sure you have a stable Internet connection and enter the right name')

    def download_pubmed(self, name):
        url = 'https://raw.githubusercontent.com/tkipf/gcn/master/gcn/data/'
        try:
            print('Downloading', url)
            urllib.request.urlretrieve(url + name, osp.join(self.root, name))
            print('Done!')
        except Exception:
            raise Exception('Download failed! Make sure you have a stable Internet connection and enter the right name')

    def download_zip(self, name):
        url = 'https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/master/other_datasets/{}.zip'.\
                format(name)
        try:
            print('Downloading', url)
            urllib.request.urlretrieve(url, osp.join(self.root, name + '.zip'))
            print('Done!')
        except Exception:
            raise Exception('Download failed! Make sure you have a stable Internet connection and enter the right name')

    def load_zip(self):
        data_filename = self.data_folder + '.zip'
        name = self.name
        if not osp.exists(data_filename):
            self.download_zip(name)
        with zipfile.ZipFile(data_filename, 'r') as zip_ref:
            zip_ref.extractall(self.root)

        feature_path = osp.join(self.data_folder, '{0}.feature'.format(name))
        label_path = osp.join(self.data_folder, '{0}.label'.format(name))
        graph_path = osp.join(self.data_folder, '{0}.edge'.format(name))

        f = np.loadtxt(feature_path, dtype=float)
        l = np.loadtxt(label_path, dtype=int)
        features = sp.csr_matrix(f, dtype=np.float32)
        # features = torch.FloatTensor(np.array(features.todense()))
        struct_edges = np.genfromtxt(graph_path, dtype=np.int32)
        sedges = np.array(list(struct_edges), dtype=np.int32).reshape(struct_edges.shape)
        n = features.shape[0]
        sadj = sp.coo_matrix((np.ones(sedges.shape[0]), (sedges[:, 0], sedges[:, 1])), shape=(n, n), dtype=np.float32)
        # symmetrize the adjacency matrix
        sadj = sadj + sadj.T.multiply(sadj.T > sadj) - sadj.multiply(sadj.T > sadj)
        label = np.array(l)

        return sadj, features, label

    def load_pubmed(self):
        dataset = 'pubmed'
        names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
        objects = []
        for i in range(len(names)):
            name = "ind.{}.{}".format(dataset, names[i])
            data_filename = osp.join(self.root, name)

            if not osp.exists(data_filename):
                self.download_pubmed(name)

            with open(data_filename, 'rb') as f:
                if sys.version_info > (3, 0):
                    objects.append(pkl.load(f, encoding='latin1'))
                else:
                    objects.append(pkl.load(f))

        x, y, tx, ty, allx, ally, graph = tuple(objects)

        test_idx_file = "ind.{}.test.index".format(dataset)
        if not osp.exists(osp.join(self.root, test_idx_file)):
            self.download_pubmed(test_idx_file)

        test_idx_reorder = parse_index_file(osp.join(self.root, test_idx_file))
        test_idx_range = np.sort(test_idx_reorder)

        features = sp.vstack((allx, tx)).tolil()
        features[test_idx_reorder, :] = features[test_idx_range, :]
        adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
        labels = np.vstack((ally, ty))
        labels[test_idx_reorder, :] = labels[test_idx_range, :]
        labels = np.where(labels)[1]
        return adj, features, labels

    def get_adj(self):
        adj, features, labels = self.load_npz(self.data_filename)
        adj = adj + adj.T
        adj = adj.tolil()
        adj[adj > 1] = 1

        if self.require_lcc:
            lcc = self.largest_connected_components(adj)
            adj = adj[lcc][:, lcc]
            features = features[lcc]
            labels = labels[lcc]
            assert adj.sum(0).A1.min() > 0, "Graph contains singleton nodes"

        # whether to set diag=0?
        adj.setdiag(0)
        adj = adj.astype("float32").tocsr()
        adj.eliminate_zeros()

        assert np.abs(adj - adj.T).sum() == 0, "Input graph is not symmetric"
        assert adj.max() == 1 and len(np.unique(adj[adj.nonzero()].A1)) == 1, "Graph must be unweighted"

        return adj, features, labels

    def load_npz(self, file_name, is_sparse=True):
        with np.load(file_name) as loader:
            # loader = dict(loader)
            if is_sparse:
                adj = sp.csr_matrix((loader['adj_data'], loader['adj_indices'],
                                     loader['adj_indptr']), shape=loader['adj_shape'])
                if 'attr_data' in loader:
                    features = sp.csr_matrix((loader['attr_data'], loader['attr_indices'],
                                              loader['attr_indptr']), shape=loader['attr_shape'])
                else:
                    features = None
                labels = loader.get('labels')
            else:
                adj = loader['adj_data']
                if 'attr_data' in loader:
                    features = loader['attr_data']
                else:
                    features = None
                labels = loader.get('labels')
        if features is None:
            features = np.eye(adj.shape[0])
        features = sp.csr_matrix(features, dtype=np.float32)
        return adj, features, labels

    def largest_connected_components(self, adj, n_components=1):
        """Select the k largest connected components.

        Parameters
        ----------
        adj : scipy.sparse.csr_matrix
            input adjacency matrix
        n_components : int
            number of largest connected components to keep
        """
        _, component_indices = sp.csgraph.connected_components(adj)
        component_sizes = np.bincount(component_indices)
        components_to_keep = np.argsort(component_sizes)[::-1][:n_components]  # reverse order to sort descending
        nodes_to_keep = [
            idx for (idx, component) in enumerate(component_indices) if component in components_to_keep]
        print("Selecting {0} largest connected components".format(n_components))
        return nodes_to_keep

    def __repr__(self):
        return '{0}(adj_shape={1}, feature_shape={2})'.format(self.name, self.adj.shape, self.features.shape)

    def get_mask(self):
        idx_train, idx_val, idx_test = self.idx_train, self.idx_val, self.idx_test
        labels = self.onehot(self.labels)

        def get_mask(idx):
            mask = np.zeros(labels.shape[0], dtype=bool)
            mask[idx] = 1
            return mask

        def get_y(idx):
            mx = np.zeros(labels.shape)
            mx[idx] = labels[idx]
            return mx

        self.train_mask = get_mask(self.idx_train)
        self.val_mask = get_mask(self.idx_val)
        self.test_mask = get_mask(self.idx_test)
        self.y_train, self.y_val, self.y_test = get_y(idx_train), get_y(idx_val), get_y(idx_test)

    def onehot(self, labels):
        eye = np.identity(labels.max() + 1)
        onehot_mx = eye[labels]
        return onehot_mx


def parse_index_file(filename):
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index


if __name__ == '__main__':
    from deeprobust.graph.data import Dataset
    for name in ['cora', 'citeseer', 'pubmed', 'cora_ml']:
        data = Dataset(root='/tmp/', name=name, setting="prognn")
        idx_train = data.idx_train
        data2 = Dataset(root='/tmp/', name=name, setting="nettack", seed=15)
        idx_train2 = data2.idx_train
        assert (idx_train != idx_train2).sum() == 0

    data = Dataset(root='/tmp/', name='flickr')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
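A minimal sketch of the split settings implemented above (assuming network access for the first-time downloads): the 'prognn' setting loads a fixed split from the Pro-GNN repository, while 'nettack' and 'gcn' re-split according to the seed.

```python
from deeprobust.graph.data import Dataset

# fixed split downloaded from the Pro-GNN repository; identical across runs
data = Dataset(root='/tmp/', name='cora', setting='prognn', require_mask=True)
print(data)                   # cora(adj_shape=..., feature_shape=...)
print(len(data.idx_train))    # fixed training indices
print(data.train_mask.sum())  # boolean mask over all nodes

# seeded random split; changing the seed changes the split
data2 = Dataset(root='/tmp/', name='cora', setting='nettack', seed=15)
```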
deeprobust/graph/data/pyg_dataset.py
ADDED
@@ -0,0 +1,308 @@
import numpy as np
import torch
from .dataset import Dataset
import scipy.sparse as sp
from itertools import repeat
import os.path as osp
import warnings
import sys
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.datasets import Coauthor, Amazon


class Dpr2Pyg(InMemoryDataset):
    """Convert deeprobust data (sparse matrix) to pytorch geometric data (tensor, edge_index).

    Parameters
    ----------
    dpr_data :
        data instance of a class from deeprobust.graph.data, e.g., deeprobust.graph.data.Dataset,
        deeprobust.graph.data.PtbDataset, deeprobust.graph.data.PrePtbDataset
    transform :
        A function/transform that takes in an object and returns a transformed version.
        The data object will be transformed before every access. For example, you can
        use torch_geometric.transforms.NormalizeFeatures()

    Examples
    --------
    We can first create an instance of the Dataset class and convert it to
    pytorch geometric data format.

    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> pyg_data = Dpr2Pyg(data)
    >>> print(pyg_data)
    >>> print(pyg_data[0])
    """

    def __init__(self, dpr_data, transform=None, **kwargs):
        root = 'data/'  # dummy root; does not mean anything
        self.dpr_data = dpr_data
        super(Dpr2Pyg, self).__init__(root, transform)
        pyg_data = self.process()
        self.data, self.slices = self.collate([pyg_data])
        self.transform = transform

    def process(self):
        dpr_data = self.dpr_data
        edge_index = torch.LongTensor(dpr_data.adj.nonzero())
        # by default, the features in pyg data are dense
        if sp.issparse(dpr_data.features):
            x = torch.FloatTensor(dpr_data.features.todense()).float()
        else:
            x = torch.FloatTensor(dpr_data.features).float()
        y = torch.LongTensor(dpr_data.labels)
        idx_train, idx_val, idx_test = dpr_data.idx_train, dpr_data.idx_val, dpr_data.idx_test
        data = Data(x=x, edge_index=edge_index, y=y)
        train_mask = index_to_mask(idx_train, size=y.size(0))
        val_mask = index_to_mask(idx_val, size=y.size(0))
        test_mask = index_to_mask(idx_test, size=y.size(0))
        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask
        return data

    def update_edge_index(self, adj):
        """This is an inplace operation that substitutes the original edge_index
        with adj.nonzero().

        Parameters
        ----------
        adj : sp.csr_matrix
            update the original adjacency into adj (by changing edge_index)
        """
        self.data.edge_index = torch.LongTensor(adj.nonzero())
        self.data, self.slices = self.collate([self.data])

    def get(self, idx):
        if self.slices is None:
            return self.data
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
                                                        slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def _download(self):
        pass


class Pyg2Dpr(Dataset):
    """Convert pytorch geometric data (tensor, edge_index) to deeprobust
    data (sparse matrix).

    Parameters
    ----------
    pyg_data :
        data instance of a class from a pytorch geometric dataset

    Examples
    --------
    We can first create an instance of the Dataset class, convert it to
    pytorch geometric data format and then convert it back to the Dataset class.

    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg, Pyg2Dpr
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> pyg_data = Dpr2Pyg(data)
    >>> print(pyg_data)
    >>> print(pyg_data[0])
    >>> dpr_data = Pyg2Dpr(pyg_data)
    >>> print(dpr_data.adj)
    """

    def __init__(self, pyg_data, **kwargs):
        is_ogb = hasattr(pyg_data, 'get_idx_split')
        if is_ogb:  # get splits for ogb datasets
            splits = pyg_data.get_idx_split()
        pyg_data = pyg_data[0]
        n = pyg_data.num_nodes
        self.adj = sp.csr_matrix((np.ones(pyg_data.edge_index.shape[1]),
            (pyg_data.edge_index[0], pyg_data.edge_index[1])), shape=(n, n))
        self.features = pyg_data.x.numpy()
        self.labels = pyg_data.y.numpy()
        if len(self.labels.shape) == 2 and self.labels.shape[1] == 1:
            self.labels = self.labels.reshape(-1)  # ogb-arxiv needs to reshape
        if is_ogb:  # set splits for ogb datasets
            self.idx_train = splits['train'].numpy()
            self.idx_val = splits['valid'].numpy()
            self.idx_test = splits['test'].numpy()
        else:
            try:
                self.idx_train = mask_to_index(pyg_data.train_mask, n)
                self.idx_val = mask_to_index(pyg_data.val_mask, n)
                self.idx_test = mask_to_index(pyg_data.test_mask, n)
            except AttributeError:
                print('Warning: This pyg dataset is not associated with any data splits...')
        self.name = 'Pyg2Dpr'


class AmazonPyg(Amazon):
    """Amazon-Computers and Amazon-Photo datasets loaded from pytorch geometric;
    the way we split the dataset follows Towards Deeper Graph Neural Networks
    (https://github.com/mengliu1998/DeeperGNN/blob/master/DeeperGNN/train_eval.py).
    Specifically, 20 * num_classes labels for training, 30 * num_classes labels
    for validation, and the rest of the labels for testing.

    Parameters
    ----------
    root : string
        root directory where the dataset should be saved.
    name : string
        dataset name, it can be chosen from ['computers', 'photo']
    transform :
        A function/transform that takes in a torch_geometric.data.Data object
        and returns a transformed version. The data object will be transformed
        before every access. (default: None)
    pre_transform :
        A function/transform that takes in a torch_geometric.data.Data object
        and returns a transformed version. The data object will be transformed
        before being saved to disk.

    Examples
    --------
    We can directly load the Amazon dataset from deeprobust in the format of pyg.

    >>> from deeprobust.graph.data import AmazonPyg
    >>> computers = AmazonPyg(root='/tmp', name='computers')
    >>> print(computers)
    >>> print(computers[0])
    >>> photo = AmazonPyg(root='/tmp', name='photo')
    >>> print(photo)
    >>> print(photo[0])
    """

    def __init__(self, root, name, transform=None, pre_transform=None, **kwargs):
        path = osp.join(root, 'pygdata', name)
        super(AmazonPyg, self).__init__(path, name, transform, pre_transform)

        random_coauthor_amazon_splits(self, self.num_classes, lcc_mask=None)
        self.data, self.slices = self.collate([self.data])


class CoauthorPyg(Coauthor):
    """Coauthor-CS and Coauthor-Physics datasets loaded from pytorch geometric;
    the way we split the dataset follows Towards Deeper Graph Neural Networks
    (https://github.com/mengliu1998/DeeperGNN/blob/master/DeeperGNN/train_eval.py).
    Specifically, 20 * num_classes labels for training, 30 * num_classes labels
    for validation, and the rest of the labels for testing.

    Parameters
    ----------
    root : string
        root directory where the dataset should be saved.
    name : string
        dataset name, it can be chosen from ['cs', 'physics']
    transform :
        A function/transform that takes in a torch_geometric.data.Data object
        and returns a transformed version. The data object will be transformed
        before every access. (default: None)
    pre_transform :
        A function/transform that takes in a torch_geometric.data.Data object
        and returns a transformed version. The data object will be transformed
        before being saved to disk.

    Examples
    --------
    We can directly load the Coauthor dataset from deeprobust in the format of pyg.

    >>> from deeprobust.graph.data import CoauthorPyg
    >>> cs = CoauthorPyg(root='/tmp', name='cs')
    >>> print(cs)
    >>> print(cs[0])
    >>> physics = CoauthorPyg(root='/tmp', name='physics')
    >>> print(physics)
    >>> print(physics[0])
    """

    def __init__(self, root, name, transform=None, pre_transform=None, **kwargs):
        path = osp.join(root, 'pygdata', name)
        super(CoauthorPyg, self).__init__(path, name, transform, pre_transform)
        random_coauthor_amazon_splits(self, self.num_classes, lcc_mask=None)
        self.data, self.slices = self.collate([self.data])


def random_coauthor_amazon_splits(dataset, num_classes, lcc_mask):
    """https://github.com/mengliu1998/DeeperGNN/blob/master/DeeperGNN/train_eval.py
    Set random coauthor/co-purchase splits:
    * 20 * num_classes labels for training
    * 30 * num_classes labels for validation
    * the rest of the labels for testing
    """
    data = dataset.data
    indices = []
    if lcc_mask is not None:
        for i in range(num_classes):
            index = (data.y[lcc_mask] == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0))]
            indices.append(index)
    else:
        for i in range(num_classes):
            index = (data.y == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0))]
            indices.append(index)

    train_index = torch.cat([i[:20] for i in indices], dim=0)
    val_index = torch.cat([i[20:50] for i in indices], dim=0)

    rest_index = torch.cat([i[50:] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(val_index, size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index, size=data.num_nodes)


def mask_to_index(index, size):
    all_idx = np.arange(size)
    return all_idx[index]


def index_to_mask(index, size):
    mask = torch.zeros((size, ), dtype=torch.bool)
    mask[index] = 1
    return mask


if __name__ == "__main__":
    from deeprobust.graph.data import PrePtbDataset, Dataset
    # load clean graph data
    dataset_str = 'cora'
    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
    pyg_data = Dpr2Pyg(data)
    print(pyg_data)
    print(pyg_data[0])
    dpr_data = Pyg2Dpr(pyg_data)
    print(dpr_data)

    computers = AmazonPyg(root='/tmp', name='computers')
    print(computers)
    print(computers[0])
    photo = AmazonPyg(root='/tmp', name='photo')
    print(photo)
    print(photo[0])
    cs = CoauthorPyg(root='/tmp', name='cs')
    print(cs)
    print(cs[0])
    physics = CoauthorPyg(root='/tmp', name='physics')
    print(physics)
    print(physics[0])

    # from ogb.nodeproppred import PygNodePropPredDataset
    # dataset = PygNodePropPredDataset(name = 'ogbn-arxiv')
    # ogb_data = Pyg2Dpr(dataset)
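One method above worth calling out is `Dpr2Pyg.update_edge_index`, which pushes a perturbed scipy adjacency matrix back into an existing pyg dataset in place. A minimal sketch (the "perturbed" matrix here is just the clean adjacency, standing in for real attack output):

```python
from deeprobust.graph.data import Dataset, Dpr2Pyg

data = Dataset(root='/tmp/', name='cora', seed=15)
pyg_data = Dpr2Pyg(data)

perturbed_adj = data.adj  # stand-in for an attacker's modified adjacency
pyg_data.update_edge_index(perturbed_adj)  # rebuilds edge_index in place
print(pyg_data[0].edge_index.shape)        # (2, number of nonzero entries)
```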
deeprobust/graph/data/utils.py
ADDED
@@ -0,0 +1,10 @@
"""
This file provides functions for converting deeprobust data
to pytorch geometric data.
"""
deeprobust/graph/defense/__init__.py
ADDED
@@ -0,0 +1,23 @@
from .gcn import GCN, GraphConvolution
from .gcn_preprocess import GCNSVD, GCNJaccard
from .gcn_cgscore import GCNScore
from .r_gcn import RGCN, GGCL_F, GGCL_D
from .prognn import ProGNN
from .simpgcn import SimPGCN
from .node_embedding import Node2Vec, DeepWalk
import warnings
try:
    from .gat import GAT
    from .chebnet import ChebNet
    from .sgc import SGC
    from .median_gcn import MedianGCN
except ImportError as e:
    print(e)
    warnings.warn("Please install pytorch geometric if you " +
                  "would like to use the datasets from pytorch " +
                  "geometric. See details in https://pytorch-geom" +
                  "etric.readthedocs.io/en/latest/notes/installation.html")

__all__ = ['GCN', 'GCNSVD', 'GCNJaccard', 'RGCN', 'ProGNN',
           'GraphConvolution', 'GGCL_F', 'GGCL_D', 'GAT', 'MedianGCN',
           'ChebNet', 'SGC', 'SimPGCN', 'Node2Vec', 'DeepWalk']
deeprobust/graph/defense/pgd.py
ADDED
@@ -0,0 +1,207 @@
from torch.optim.optimizer import required
from torch.optim import Optimizer
import torch
import numpy as np
import scipy.sparse as sp

class PGD(Optimizer):
    """Proximal gradient descent.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining parameter groups
    proxs : iterable
        iterable of proximal operators
    alphas : iterable
        iterable of coefficients for proximal gradient descent
    lr : float
        learning rate
    momentum : float
        momentum factor (default: 0)
    weight_decay : float
        weight decay (L2 penalty) (default: 0)
    dampening : float
        dampening for momentum (default: 0)
    """

    def __init__(self, params, proxs, alphas, lr=required, momentum=0, dampening=0, weight_decay=0):
        defaults = dict(lr=lr, momentum=0, dampening=0,
                        weight_decay=0, nesterov=False)

        super(PGD, self).__init__(params, defaults)

        for group in self.param_groups:
            group.setdefault('proxs', proxs)
            group.setdefault('alphas', alphas)

    def __setstate__(self, state):
        super(PGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            # proxs and alphas live in each parameter group; fall back to
            # empty lists if they were not restored from the pickled state
            group.setdefault('proxs', [])
            group.setdefault('alphas', [])

    def step(self, delta=0, closure=None):
        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            proxs = group['proxs']
            alphas = group['alphas']

            # apply the proximal operator to each parameter in a group
            for param in group['params']:
                for prox_operator, alpha in zip(proxs, alphas):
                    # param.data.add_(lr, -param.grad.data)
                    # param.data.add_(delta)
                    param.data = prox_operator(param.data, alpha=alpha*lr)


class ProxOperators():
    """Proximal operators."""

    def __init__(self):
        self.nuclear_norm = None

    def prox_l1(self, data, alpha):
        """Proximal operator for the l1 norm (soft-thresholding)."""
        data = torch.mul(torch.sign(data), torch.clamp(torch.abs(data)-alpha, min=0))
        return data

    def prox_nuclear(self, data, alpha):
        """Proximal operator for the nuclear norm (trace norm)."""
        device = data.device
        U, S, V = np.linalg.svd(data.cpu())
        U, S, V = torch.FloatTensor(U).to(device), torch.FloatTensor(S).to(device), torch.FloatTensor(V).to(device)
        self.nuclear_norm = S.sum()
        # print("nuclear norm: %.4f" % self.nuclear_norm)

        diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)

    def prox_nuclear_truncated_2(self, data, alpha, k=50):
        device = data.device
        import tensorly as tl
        tl.set_backend('pytorch')
        U, S, V = tl.truncated_svd(data.cpu(), n_eigenvecs=k)
        U, S, V = torch.FloatTensor(U).to(device), torch.FloatTensor(S).to(device), torch.FloatTensor(V).to(device)
        self.nuclear_norm = S.sum()
        # print("nuclear norm: %.4f" % self.nuclear_norm)

        S = torch.clamp(S-alpha, min=0)

        # diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        # U = torch.spmm(U, diag_S)
        # V = torch.matmul(U, V)

        # make diag_S a sparse matrix
        indices = torch.tensor((range(0, len(S)), range(0, len(S)))).to(device)
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size((len(S), len(S))))
        V = torch.spmm(diag_S, V)
        V = torch.matmul(U, V)
        return V

    def prox_nuclear_truncated(self, data, alpha, k=50):
        device = data.device
        indices = torch.nonzero(data).t()
        values = data[indices[0], indices[1]]  # modify this based on dimensionality
        data_sparse = sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()))
        U, S, V = sp.linalg.svds(data_sparse, k=k)
        U, S, V = torch.FloatTensor(U).to(device), torch.FloatTensor(S).to(device), torch.FloatTensor(V).to(device)
        self.nuclear_norm = S.sum()
        diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)

    def prox_nuclear_cuda(self, data, alpha):
        device = data.device
        U, S, V = torch.svd(data)
        # print(f"rank = {len(S.nonzero())}")
        self.nuclear_norm = S.sum()
        S = torch.clamp(S-alpha, min=0)
        indices = torch.tensor([range(0, U.shape[0]), range(0, U.shape[0])]).to(device)
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size(U.shape))
        # diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        # print(f"rank_after = {len(diag_S.nonzero())}")
        V = torch.spmm(diag_S, V.t_())
        V = torch.matmul(U, V)
        return V


class SGD(Optimizer):

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss


prox_operators = ProxOperators()
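Note that `PGD.step` applies only the proximal operators; in Pro-GNN the gradient step on the graph estimate is taken by a paired plain SGD optimizer first, and `PGD.step` then enforces sparsity. A minimal sketch of the l1 operator soft-thresholding a learnable matrix (the threshold is exaggerated for visibility):

```python
import torch
from deeprobust.graph.defense.pgd import PGD, prox_operators

S = torch.nn.Parameter(torch.rand(10, 10))  # learnable graph estimate
prox_opt = PGD([S], proxs=[prox_operators.prox_l1],
               alphas=[0.3], lr=1.0)

# in Pro-GNN a gradient step on S would come first; here we only
# apply the proximal step: S <- sign(S) * max(|S| - alpha*lr, 0)
prox_opt.step()
print((S.data == 0).float().mean())  # fraction of entries zeroed out
```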
deeprobust/graph/defense/simpgcn.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import scipy.sparse as sp
from deeprobust.graph.defense import GraphConvolution
import deeprobust.graph.utils as utils
import torch.optim as optim
from sklearn.metrics.pairwise import cosine_similarity
from copy import deepcopy
from itertools import product
from tqdm import tqdm  # used in AttrSim.get_label's MemoryError fallback


class SimPGCN(nn.Module):
    """SimP-GCN: Node similarity preserving graph convolutional networks.
    https://arxiv.org/abs/2011.09643

    Parameters
    ----------
    nnodes : int
        number of nodes in the input graph
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    lambda_ : float
        coefficient for the SSL loss in SimP-GCN
    gamma : float
        coefficient for the adaptive learnable self-loops
    bias_init : float
        bias init for the score
    dropout : float
        dropout rate for GCN
    lr : float
        learning rate for GCN
    weight_decay : float
        weight decay coefficient (l2 normalization) for GCN.
    with_bias: bool
        whether to include bias term in GCN weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
    We can first load the dataset and then train SimPGCN.
    See the detailed hyper-parameter setting in https://github.com/ChandlerBang/SimP-GCN.

    >>> from deeprobust.graph.data import PrePtbDataset, Dataset
    >>> from deeprobust.graph.defense import SimPGCN
    >>> # load clean graph data
    >>> data = Dataset(root='/tmp/', name='cora', seed=15)
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> # load perturbed graph data
    >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
    >>> perturbed_adj = perturbed_data.adj
    >>> model = SimPGCN(nnodes=features.shape[0], nfeat=features.shape[1],
                nhid=16, nclass=labels.max()+1, device='cuda')
    >>> model = model.to('cuda')
    >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
    >>> model.test(idx_test)
    """

    def __init__(self, nnodes, nfeat, nhid, nclass, dropout=0.5, lr=0.01,
                 weight_decay=5e-4, lambda_=5, gamma=0.1, bias_init=0,
                 with_bias=True, device=None):
        super(SimPGCN, self).__init__()

        assert device is not None, "Please specify 'device'!"

        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.dropout = dropout
        self.lr = lr
        self.weight_decay = weight_decay
        self.bias_init = bias_init
        self.gamma = gamma
        self.lambda_ = lambda_
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None

        self.gc1 = GraphConvolution(nfeat, nhid, with_bias=with_bias)
        self.gc2 = GraphConvolution(nhid, nclass, with_bias=with_bias)

        # self.reset_parameters()
        self.scores = nn.ParameterList()
        self.scores.append(Parameter(torch.FloatTensor(nfeat, 1)))
        for i in range(1):
            self.scores.append(Parameter(torch.FloatTensor(nhid, 1)))

        self.bias = nn.ParameterList()
        self.bias.append(Parameter(torch.FloatTensor(1)))
        for i in range(1):
            self.bias.append(Parameter(torch.FloatTensor(1)))

        self.D_k = nn.ParameterList()
        self.D_k.append(Parameter(torch.FloatTensor(nfeat, 1)))
        for i in range(1):
            self.D_k.append(Parameter(torch.FloatTensor(nhid, 1)))

        self.identity = utils.sparse_mx_to_torch_sparse_tensor(
            sp.eye(nnodes)).to(device)

        self.D_bias = nn.ParameterList()
        self.D_bias.append(Parameter(torch.FloatTensor(1)))
        for i in range(1):
            self.D_bias.append(Parameter(torch.FloatTensor(1)))

        # discriminator for ssl
        self.linear = nn.Linear(nhid, 1).to(device)

        self.adj_knn = None
        self.pseudo_labels = None

    def get_knn_graph(self, features, k=20):
        if not os.path.exists('saved_knn/'):
            os.mkdir('saved_knn')
        if not os.path.exists('saved_knn/knn_graph_{}.npz'.format(features.shape)):
            features[features != 0] = 1
            sims = cosine_similarity(features)
            np.save('saved_knn/cosine_sims_{}.npy'.format(features.shape), sims)

            sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0
            for i in range(len(sims)):
                indices_argsort = np.argsort(sims[i])
                sims[i, indices_argsort[: -k]] = 0

            adj_knn = sp.csr_matrix(sims)
            sp.save_npz('saved_knn/knn_graph_{}.npz'.format(features.shape), adj_knn)
        else:
            print('loading saved_knn/knn_graph_{}.npz...'.format(features.shape))
            adj_knn = sp.load_npz('saved_knn/knn_graph_{}.npz'.format(features.shape))
        return preprocess_adj_noloop(adj_knn, self.device)

    def initialize(self):
        """Initialize parameters of SimPGCN.
        """
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

        for s in self.scores:
            stdv = 1. / math.sqrt(s.size(1))
            s.data.uniform_(-stdv, stdv)
        for b in self.bias:
            # fill b with a positive value to make
            # the score s closer to 1 at the beginning
            b.data.fill_(self.bias_init)

        for Dk in self.D_k:
            stdv = 1. / math.sqrt(Dk.size(1))
            Dk.data.uniform_(-stdv, stdv)

        for b in self.D_bias:
            b.data.fill_(0)

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, initialize=True, verbose=False, normalize=True, patience=500, **kwargs):
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        self.adj_norm = adj_norm
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val, train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def forward(self, fea, adj):
        x, _ = self.myforward(fea, adj)
        return x

    def myforward(self, fea, adj):
        '''output embedding and log_softmax'''
        if self.adj_knn is None:
            self.adj_knn = self.get_knn_graph(fea.to_dense().cpu().numpy())

        adj_knn = self.adj_knn
        gamma = self.gamma

        s_i = torch.sigmoid(fea @ self.scores[0] + self.bias[0])

        Dk_i = (fea @ self.D_k[0] + self.D_bias[0])
        x = (s_i * self.gc1(fea, adj) + (1 - s_i) * self.gc1(fea, adj_knn)) + (gamma) * Dk_i * self.gc1(fea, self.identity)

        x = F.dropout(x, self.dropout, training=self.training)
        embedding = x.clone()

        # output layer: no relu and dropout here.
        s_o = torch.sigmoid(x @ self.scores[-1] + self.bias[-1])
        Dk_o = (x @ self.D_k[-1] + self.D_bias[-1])
        x = (s_o * self.gc2(x, adj) + (1 - s_o) * self.gc2(x, adj_knn)) + (gamma) * Dk_o * self.gc2(x, self.identity)

        x = F.log_softmax(x, dim=1)

        self.ss = torch.cat((s_i.view(1, -1), s_o.view(1, -1), gamma * Dk_i.view(1, -1), gamma * Dk_o.view(1, -1)), dim=0)
        return x, embedding

    def regression_loss(self, embeddings):
        if self.pseudo_labels is None:
            agent = AttrSim(self.features.to_dense())
            self.pseudo_labels = agent.get_label().to(self.device)
            node_pairs = agent.node_pairs
            self.node_pairs = node_pairs

        k = 10000
        node_pairs = self.node_pairs
        if len(self.node_pairs[0]) > k:
            sampled = np.random.choice(len(self.node_pairs[0]), k, replace=False)

            embeddings0 = embeddings[node_pairs[0][sampled]]
            embeddings1 = embeddings[node_pairs[1][sampled]]
            embeddings = self.linear(torch.abs(embeddings0 - embeddings1))
            loss = F.mse_loss(embeddings, self.pseudo_labels[sampled], reduction='mean')
        else:
            embeddings0 = embeddings[node_pairs[0]]
            embeddings1 = embeddings[node_pairs[1]]
            embeddings = self.linear(torch.abs(embeddings0 - embeddings1))
            loss = F.mse_loss(embeddings, self.pseudo_labels, reduction='mean')
        # print(loss)
        return loss

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output, embeddings = self.myforward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_ssl = self.lambda_ * self.regression_loss(embeddings)
            loss_total = loss_train + loss_ssl
            loss_total.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):

            self.train()
            optimizer.zero_grad()
            output, embeddings = self.myforward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            # acc_train = accuracy(output[idx_train], labels[idx_train])
            loss_ssl = self.lambda_ * self.regression_loss(embeddings)
            loss_total = loss_train + loss_ssl
            loss_total.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val, train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output, embeddings = self.myforward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_ssl = self.lambda_ * self.regression_loss(embeddings)
            loss_total = loss_train + loss_ssl
            loss_total.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized data

        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjacency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.


        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)


class AttrSim:

    def __init__(self, features):
        self.features = features.cpu().numpy()
        self.features[self.features != 0] = 1

    def get_label(self, k=5):
        features = self.features
        if not os.path.exists('saved_knn/cosine_sims_{}.npy'.format(features.shape)):
            sims = cosine_similarity(features)
            np.save('saved_knn/cosine_sims_{}.npy'.format(features.shape), sims)
        else:
            print('loading saved_knn/cosine_sims_{}.npy'.format(features.shape))
            sims = np.load('saved_knn/cosine_sims_{}.npy'.format(features.shape))

        if not os.path.exists('saved_knn/attrsim_sampled_idx_{}.npy'.format(features.shape)):
            try:
                indices_sorted = sims.argsort(1)
                idx = np.arange(k, sims.shape[0] - k)
                selected = np.hstack((indices_sorted[:, :k],
                                      indices_sorted[:, -k-1:]))

                selected_set = set()
                for i in range(len(sims)):
                    for pair in product([i], selected[i]):
                        if pair[0] > pair[1]:
                            pair = (pair[1], pair[0])
                        if pair[0] == pair[1]:
                            continue
                        selected_set.add(pair)

            except MemoryError:
                selected_set = set()
                for ii, row in tqdm(enumerate(sims)):
                    row = row.argsort()
                    idx = np.arange(k, sims.shape[0] - k)
                    sampled = np.random.choice(idx, k, replace=False)
                    for node in np.hstack((row[:k], row[-k-1:], row[sampled])):
                        if ii > node:
                            pair = (node, ii)
                        else:
                            pair = (ii, node)
                        selected_set.add(pair)

            sampled = np.array(list(selected_set)).transpose()
            np.save('saved_knn/attrsim_sampled_idx_{}.npy'.format(features.shape), sampled)
        else:
            print('loading saved_knn/attrsim_sampled_idx_{}.npy'.format(features.shape))
            sampled = np.load('saved_knn/attrsim_sampled_idx_{}.npy'.format(features.shape))
        print('number of sampled:', len(sampled[0]))
        self.node_pairs = (sampled[0], sampled[1])
        self.sims = sims
        return torch.FloatTensor(sims[self.node_pairs]).reshape(-1, 1)


def preprocess_adj_noloop(adj, device):
    adj_normalizer = noaug_normalized_adjacency
    r_adj = adj_normalizer(adj)
    r_adj = utils.sparse_mx_to_torch_sparse_tensor(r_adj).float()
    r_adj = r_adj.to(device)
    return r_adj


def noaug_normalized_adjacency(adj):
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(row_sum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()
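To make the adaptive combination in `myforward` concrete: each layer mixes propagation over the original graph and over the feature-kNN graph with a per-node gate s in (0, 1), plus a learnable self-loop term scaled by gamma. A minimal self-contained sketch of that gating arithmetic on toy tensors (`h_adj`, `h_knn`, `h_self` are hypothetical stand-ins for the three propagated signals gc1(fea, adj), gc1(fea, adj_knn), gc1(fea, identity)):

import torch

torch.manual_seed(0)
n, f = 4, 8                                   # toy graph: 4 nodes, 8 features
fea = torch.randn(n, f)
score_w, score_b = torch.randn(f, 1), torch.zeros(1)

# Stand-ins for the three propagated signals.
h_adj, h_knn, h_self = torch.randn(n, 3), torch.randn(n, 3), torch.randn(n, 3)

s = torch.sigmoid(fea @ score_w + score_b)    # per-node gate, shape [n, 1]
Dk = fea @ torch.randn(f, 1)                  # per-node self-loop coefficient
gamma = 0.1

x = s * h_adj + (1 - s) * h_knn + gamma * Dk * h_self
print(x.shape)  # torch.Size([4, 3])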
deeprobust/graph/defense_pyg/gat.py
ADDED
@@ -0,0 +1,100 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
# from torch_geometric.nn import GATConv
from .mygat_conv import GATConv
from .base_model import BaseModel


class GAT(BaseModel):

    def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01,
                 nlayers=2, with_bn=False, weight_decay=5e-4, with_bias=True, device=None):

        super(GAT, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        self.convs = nn.ModuleList([])
        if with_bn:
            self.bns = nn.ModuleList([])
            self.bns.append(nn.BatchNorm1d(nhid*heads))

        self.convs.append(GATConv(
            nfeat,
            nhid,
            heads=heads,
            dropout=dropout,
            bias=with_bias))

        for i in range(nlayers-2):
            self.convs.append(GATConv(nhid*heads,
                nhid, heads=heads, dropout=dropout, bias=with_bias))
            if with_bn:
                self.bns.append(nn.BatchNorm1d(nhid*heads))

        self.convs.append(GATConv(
            nhid * heads,
            nclass,
            heads=output_heads,
            concat=False,
            dropout=dropout,
            bias=with_bias))

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None
        self.best_model = None
        self.best_output = None
        self.name = 'GAT'
        self.with_bn = with_bn

    def forward(self, x, edge_index, edge_weight=None):
        for ii, conv in enumerate(self.convs[:-1]):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, edge_index, edge_weight)
            if self.with_bn:
                x = self.bns[ii](x)
            x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

    def get_embed(self, x, edge_index, edge_weight=None):
        for ii, conv in enumerate(self.convs[:-1]):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, edge_index, edge_weight)
            if self.with_bn:
                x = self.bns[ii](x)
            x = F.elu(x)
        return x

    def initialize(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()


if __name__ == "__main__":
    from deeprobust.graph.data import Dataset, Dpr2Pyg
    # from deeprobust.graph.defense import GAT
    data = Dataset(root='/tmp/', name='cora')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    gat = GAT(nfeat=features.shape[1],
              nhid=8, heads=8,
              nclass=labels.max().item() + 1,
              dropout=0.5, device='cpu')
    gat = gat.to('cpu')
    pyg_data = Dpr2Pyg(data)
    gat.fit(pyg_data, verbose=True)  # train with earlystopping
    gat.test()
    print(gat.predict())
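For context on the layer this file wraps, a minimal sketch of a single `GATConv` from `torch_geometric` (from which `mygat_conv` is derived, per the commented import) on a hypothetical toy graph; requires torch_geometric to be installed:

import torch
from torch_geometric.nn import GATConv

# Toy graph: 3 nodes with 4 features; two undirected edges stored as directed pairs.
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

conv = GATConv(in_channels=4, out_channels=8, heads=2, dropout=0.5)
out = conv(x, edge_index)
print(out.shape)  # torch.Size([3, 16]): head outputs are concatenated by default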
deeprobust/graph/defense_pyg/gcn.py
ADDED
@@ -0,0 +1,110 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch_geometric.nn import GCNConv
from .base_model import BaseModel
from torch_sparse import coalesce, SparseTensor, matmul


class GCN(BaseModel):

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01,
                 with_bn=False, weight_decay=5e-4, with_bias=True, device=None):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        self.layers = nn.ModuleList([])
        if with_bn:
            self.bns = nn.ModuleList()

        if nlayers == 1:
            self.layers.append(GCNConv(nfeat, nclass, bias=with_bias))
        else:
            self.layers.append(GCNConv(nfeat, nhid, bias=with_bias))
            if with_bn:
                self.bns.append(nn.BatchNorm1d(nhid))
            for i in range(nlayers-2):
                self.layers.append(GCNConv(nhid, nhid, bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(GCNConv(nhid, nclass, bias=with_bias))

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None
        self.best_model = None
        self.best_output = None
        self.with_bn = with_bn
        self.name = 'GCN'

    def forward(self, x, edge_index, edge_weight=None):
        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
        for ii, layer in enumerate(self.layers):
            if edge_weight is not None:
                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
                x = layer(x, adj)
            else:
                x = layer(x, edge_index)
            if ii != len(self.layers) - 1:
                if self.with_bn:
                    x = self.bns[ii](x)
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=1)

    def get_embed(self, x, edge_index, edge_weight=None):
        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
        for ii, layer in enumerate(self.layers):
            if ii == len(self.layers) - 1:
                return x
            if edge_weight is not None:
                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
                x = layer(x, adj)
            else:
                x = layer(x, edge_index)
            if ii != len(self.layers) - 1:
                if self.with_bn:
                    x = self.bns[ii](x)
                x = F.relu(x)
        return x

    def initialize(self):
        for m in self.layers:
            m.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()


if __name__ == "__main__":
    from deeprobust.graph.data import Dataset, Dpr2Pyg
    # from deeprobust.graph.defense import GCN
    data = Dataset(root='/tmp/', name='citeseer', setting='prognn')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    model = GCN(nfeat=features.shape[1],
                nhid=16,
                nclass=labels.max().item() + 1,
                dropout=0.5, device='cuda')
    model = model.to('cuda')
    pyg_data = Dpr2Pyg(data)[0]

    # model.fit(features, adj, labels, idx_train, train_iters=200, verbose=True)
    # model.test(idx_test)

    from utils import get_dataset
    pyg_data = get_dataset('citeseer', True, if_dpr=False)[0]

    import ipdb
    ipdb.set_trace()

    model.fit(pyg_data, verbose=True)  # train with earlystopping
    model.test()
    print(model.predict())
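A quick sketch of the `SparseTensor` conversion used in `forward` when explicit edge weights are given (assumes torch_sparse is installed; the toy graph is hypothetical):

import torch
from torch_sparse import SparseTensor

x = torch.randn(3, 4)                      # 3 nodes
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
edge_weight = torch.tensor([0.5, 1.0, 0.25])

# Same construction as in GCN.forward: an N x N weighted sparse adjacency,
# transposed so aggregation runs over incoming edges; 2 * x.shape[:1] yields
# the square size (N, N).
adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                   sparse_sizes=2 * x.shape[:1]).t()
print(adj.sizes())  # [3, 3]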
deeprobust/graph/global_attack/base_attack.py
ADDED
@@ -0,0 +1,130 @@
import os.path as osp

import numpy as np
import scipy.sparse as sp
import torch
from torch.nn.modules.module import Module

from deeprobust.graph import utils


class BaseAttack(Module):
    """Abstract base class for global attack classes.

    Parameters
    ----------
    model :
        model to attack
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    """

    def __init__(self, model, nnodes, attack_structure=True, attack_features=False, device='cpu'):
        super(BaseAttack, self).__init__()

        self.surrogate = model
        self.nnodes = nnodes
        self.attack_structure = attack_structure
        self.attack_features = attack_features
        self.device = device
        self.modified_adj = None
        self.modified_features = None
        if model is not None:
            self.nclass = model.nclass
            self.nfeat = model.nfeat
            self.hidden_sizes = model.hidden_sizes

    def attack(self, ori_adj, n_perturbations, **kwargs):
        """Generate attacks on the input graph.

        Parameters
        ----------
        ori_adj : scipy.sparse.csr_matrix
            Original (unperturbed) adjacency matrix.
        n_perturbations : int
            Number of edge removals/additions.

        Returns
        -------
        None.

        """
        pass

    def check_adj(self, adj):
        """Check if the modified adjacency is symmetric and unweighted.
        """
        assert np.abs(adj - adj.T).sum() == 0, "Input graph is not symmetric"
        assert adj.tocsr().max() == 1, "Max value should be 1!"
        assert adj.tocsr().min() == 0, "Min value should be 0!"

    def check_adj_tensor(self, adj):
        """Check if the modified adjacency is symmetric, unweighted, and has an all-zero diagonal.
        """
        assert torch.abs(adj - adj.t()).sum() == 0, "Input graph is not symmetric"
        assert adj.max() == 1, "Max value should be 1!"
        assert adj.min() == 0, "Min value should be 0!"
        diag = adj.diag()
        assert diag.max() == 0, "Diagonal should be 0!"
        assert diag.min() == 0, "Diagonal should be 0!"

    def save_adj(self, root=r'/tmp/', name='mod_adj'):
        """Save attacked adjacency matrix.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name

        Returns
        -------
        None.

        """
        assert self.modified_adj is not None, \
            'modified_adj is None! Please perturb the graph first.'
        name = name + '.npz'
        modified_adj = self.modified_adj

        if type(modified_adj) is torch.Tensor:
            sparse_adj = utils.to_scipy(modified_adj)
            sp.save_npz(osp.join(root, name), sparse_adj)
        else:
            sp.save_npz(osp.join(root, name), modified_adj)

    def save_features(self, root=r'/tmp/', name='mod_features'):
        """Save attacked node feature matrix.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name

        Returns
        -------
        None.

        """

        assert self.modified_features is not None, \
            'modified_features is None! Please perturb the graph first.'
        name = name + '.npz'
        modified_features = self.modified_features

        if type(modified_features) is torch.Tensor:
            sparse_features = utils.to_scipy(modified_features)
            sp.save_npz(osp.join(root, name), sparse_features)
        else:
            sp.save_npz(osp.join(root, name), modified_features)
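A quick illustration of what `check_adj` enforces, on a hypothetical 3-node undirected, unweighted graph (same calls as the method above):

import numpy as np
import scipy.sparse as sp

adj = sp.csr_matrix(np.array([[0, 1, 1],
                              [1, 0, 0],
                              [1, 0, 0]]))

assert np.abs(adj - adj.T).sum() == 0  # symmetric
assert adj.tocsr().max() == 1          # unweighted: max entry is 1
assert adj.tocsr().min() == 0          # no negative or fractional weights
print('adjacency passes check_adj-style validation')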
deeprobust/graph/global_attack/node_embedding_attack.py
ADDED
@@ -0,0 +1,522 @@
"""
Code in this file is modified from https://github.com/abojchevski/node_embedding_attack

'Adversarial Attacks on Node Embeddings via Graph Poisoning'
Aleksandar Bojchevski and Stephan Günnemann, ICML 2019
http://proceedings.mlr.press/v97/bojchevski19a.html
Copyright (C) owned by the authors, 2019
"""

import numba
import numpy as np
import scipy.sparse as sp
import scipy.linalg as spl
import torch
import networkx as nx
from deeprobust.graph.global_attack import BaseAttack


class NodeEmbeddingAttack(BaseAttack):
    """Node embedding attack. Adversarial Attacks on Node Embeddings via Graph
    Poisoning. Aleksandar Bojchevski and Stephan Günnemann, ICML 2019
    http://proceedings.mlr.press/v97/bojchevski19a.html

    Examples
    -----
    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.global_attack import NodeEmbeddingAttack
    >>> data = Dataset(root='/tmp/', name='cora_ml', seed=15)
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> model = NodeEmbeddingAttack()
    >>> model.attack(adj, attack_type="remove")
    >>> modified_adj = model.modified_adj
    >>> model.attack(adj, attack_type="remove", min_span_tree=True)
    >>> modified_adj = model.modified_adj
    >>> model.attack(adj, attack_type="add", n_candidates=10000)
    >>> modified_adj = model.modified_adj
    >>> model.attack(adj, attack_type="add_by_remove", n_candidates=10000)
    >>> modified_adj = model.modified_adj
    """

    def __init__(self):
        pass

    def attack(self, adj, n_perturbations=1000, dim=32, window_size=5,
               attack_type="remove", min_span_tree=False, n_candidates=None, seed=None, **kwargs):
        """Selects the top (n_perturbations) number of flips using our perturbation attack.

        :param adj: sp.spmatrix
            The graph represented as a sparse scipy matrix
        :param n_perturbations: int
            Number of flips to select
        :param dim: int
            Dimensionality of the embeddings.
        :param window_size: int
            Co-occurrence window size.
        :param attack_type: str
            can be chosen from ["remove", "add", "add_by_remove"]
        :param min_span_tree: bool
            Whether to disallow edges that lie on the minimum spanning tree;
            only valid when `attack_type` is "remove"
        :param n_candidates: int
            Number of candidates for addition; only valid when `attack_type` is "add" or "add_by_remove";
        :param seed: int
            Random seed
        """
        assert attack_type in ["remove", "add", "add_by_remove"], \
            "attack_type can only be `remove`, `add`, or `add_by_remove`"

        if attack_type == "remove":
            if min_span_tree:
                candidates = self.generate_candidates_removal_minimum_spanning_tree(adj)
            else:
                candidates = self.generate_candidates_removal(adj, seed)

        elif attack_type == "add" or attack_type == "add_by_remove":

            assert n_candidates, "please specify the value of `n_candidates`, " \
                + "i.e. how many candidates you want to generate for addition"
            candidates = self.generate_candidates_addition(adj, n_candidates, seed)

        n_nodes = adj.shape[0]

        if attack_type == "add_by_remove":
            candidates_add = candidates
            adj_add = self.flip_candidates(adj, candidates_add)
            vals_org_add, vecs_org_add = spl.eigh(adj_add.toarray(), np.diag(adj_add.sum(1).A1))
            flip_indicator = 1 - 2 * adj_add[candidates[:, 0], candidates[:, 1]].A1

            loss_est = estimate_loss_with_delta_eigenvals(candidates_add, flip_indicator,
                                                          vals_org_add, vecs_org_add, n_nodes, dim, window_size)

            loss_argsort = loss_est.argsort()
            top_flips = candidates_add[loss_argsort[:n_perturbations]]

        else:
            # vector indicating whether we are adding an edge (+1) or removing an edge (-1)
            delta_w = 1 - 2 * adj[candidates[:, 0], candidates[:, 1]].A1

            # generalized eigenvalues/eigenvectors
            deg_matrix = np.diag(adj.sum(1).A1)
            vals_org, vecs_org = spl.eigh(adj.toarray(), deg_matrix)

            loss_for_candidates = estimate_loss_with_delta_eigenvals(candidates, delta_w, vals_org, vecs_org, n_nodes, dim, window_size)
            top_flips = candidates[loss_for_candidates.argsort()[-n_perturbations:]]

        assert len(top_flips) == n_perturbations

        modified_adj = self.flip_candidates(adj, top_flips)
        self.check_adj(modified_adj)
        self.modified_adj = modified_adj

    def generate_candidates_removal(self, adj, seed=None):
        """Generates candidate edge flips for removal (edge -> non-edge),
        disallowing one random edge per node to prevent singleton nodes.

        :param adj: sp.csr_matrix, shape [n_nodes, n_nodes]
            Adjacency matrix of the graph
        :param seed: int
            Random seed
        :return: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        """
        n_nodes = adj.shape[0]
        if seed is not None:
            np.random.seed(seed)
        deg = np.where(adj.sum(1).A1 == 1)[0]
        hidden = np.column_stack(
            (np.arange(n_nodes), np.fromiter(map(np.random.choice, adj.tolil().rows), dtype=np.int32)))

        adj_hidden = edges_to_sparse(hidden, adj.shape[0])
        adj_hidden = adj_hidden.maximum(adj_hidden.T)

        adj_keep = adj - adj_hidden

        candidates = np.column_stack((sp.triu(adj_keep).nonzero()))

        candidates = candidates[np.logical_not(np.in1d(candidates[:, 0], deg) | np.in1d(candidates[:, 1], deg))]

        return candidates

    def generate_candidates_removal_minimum_spanning_tree(self, adj):
        """Generates candidate edge flips for removal (edge -> non-edge),
        disallowing edges that lie on the minimum spanning tree.

        :param adj: sp.csr_matrix, shape [n_nodes, n_nodes]
            Adjacency matrix of the graph
        :return: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        """
        mst = sp.csgraph.minimum_spanning_tree(adj)
        mst = mst.maximum(mst.T)
        adj_sample = adj - mst
        candidates = np.column_stack(sp.triu(adj_sample, 1).nonzero())

        return candidates

    def generate_candidates_addition(self, adj, n_candidates, seed=None):
        """Generates candidate edge flips for addition (non-edge -> edge).

        :param adj: sp.csr_matrix, shape [n_nodes, n_nodes]
            Adjacency matrix of the graph
        :param n_candidates: int
            Number of candidates to generate.
        :param seed: int
            Random seed
        :return: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        """
        if seed is not None:
            np.random.seed(seed)

        num_nodes = adj.shape[0]

        candidates = np.random.randint(0, num_nodes, [n_candidates * 5, 2])
        candidates = candidates[candidates[:, 0] < candidates[:, 1]]
        candidates = candidates[adj[candidates[:, 0], candidates[:, 1]].A1 == 0]
        candidates = np.array(list(set(map(tuple, candidates))))
        candidates = candidates[:n_candidates]

        assert len(candidates) == n_candidates

        return candidates

    def flip_candidates(self, adj, candidates):
        """Flip the edges in the candidate set to non-edges and vice versa.

        :param adj: sp.csr_matrix, shape [n_nodes, n_nodes]
            Adjacency matrix of the graph
        :param candidates: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        :return: sp.csr_matrix, shape [n_nodes, n_nodes]
            Adjacency matrix of the graph with the flipped edges/non-edges.
        """
        adj_flipped = adj.copy().tolil()
        adj_flipped[candidates[:, 0], candidates[:, 1]] = 1 - adj[candidates[:, 0], candidates[:, 1]]
        adj_flipped[candidates[:, 1], candidates[:, 0]] = 1 - adj[candidates[:, 1], candidates[:, 0]]
        adj_flipped = adj_flipped.tocsr()
        adj_flipped.eliminate_zeros()

        return adj_flipped


@numba.jit(nopython=True)
def estimate_loss_with_delta_eigenvals(candidates, flip_indicator, vals_org, vecs_org, n_nodes, dim, window_size):
    """Computes the estimated loss using the change in the eigenvalues for every candidate edge flip.

    :param candidates: np.ndarray, shape [?, 2]
        Candidate set of edge flips
    :param flip_indicator: np.ndarray, shape [?]
        Vector indicating whether we are adding an edge (+1) or removing an edge (-1)
    :param vals_org: np.ndarray, shape [n]
        The generalized eigenvalues of the clean graph
    :param vecs_org: np.ndarray, shape [n, n]
        The generalized eigenvectors of the clean graph
    :param n_nodes: int
        Number of nodes
    :param dim: int
        Embedding dimension
    :param window_size: int
        Size of the window
    :return: np.ndarray, shape [?]
        Estimated loss for each candidate flip
    """

    loss_est = np.zeros(len(candidates))
    for x in range(len(candidates)):
        i, j = candidates[x]
        vals_est = vals_org + flip_indicator[x] * (
            2 * vecs_org[i] * vecs_org[j] - vals_org * (vecs_org[i] ** 2 + vecs_org[j] ** 2))

        vals_sum_powers = sum_of_powers(vals_est, window_size)

        loss_ij = np.sqrt(np.sum(np.sort(vals_sum_powers ** 2)[:n_nodes - dim]))
        loss_est[x] = loss_ij

    return loss_est


@numba.jit(nopython=True)
def estimate_delta_eigenvecs(candidates, flip_indicator, degrees, vals_org, vecs_org, delta_eigvals, pinvs):
    """Computes the estimated change in the eigenvectors for every candidate edge flip.

    :param candidates: np.ndarray, shape [?, 2]
        Candidate set of edge flips
    :param flip_indicator: np.ndarray, shape [?]
        Vector indicating whether we are adding an edge (+1) or removing an edge (-1)
    :param degrees: np.ndarray, shape [n]
        Vector of node degrees.
    :param vals_org: np.ndarray, shape [n]
        The generalized eigenvalues of the clean graph
    :param vecs_org: np.ndarray, shape [n, n]
        The generalized eigenvectors of the clean graph
    :param delta_eigvals: np.ndarray, shape [?, n]
        Estimated change in the eigenvalues for all candidate edge flips
    :param pinvs: np.ndarray, shape [k, n, n]
        Precomputed pseudo-inverse matrices for every dimension
    :return: np.ndarray, shape [?, n, k]
        Estimated change in the eigenvectors for all candidate edge flips
    """
    n_nodes, dim = vecs_org.shape
    n_candidates = len(candidates)
    delta_eigvecs = np.zeros((n_candidates, dim, n_nodes))

    for k in range(dim):
        cur_eigvecs = vecs_org[:, k]
        cur_eigvals = vals_org[k]
        for c in range(n_candidates):
            degree_eigvec = (-delta_eigvals[c, k] * degrees) * cur_eigvecs
            i, j = candidates[c]

            degree_eigvec[i] += cur_eigvecs[j] - cur_eigvals * cur_eigvecs[i]
            degree_eigvec[j] += cur_eigvecs[i] - cur_eigvals * cur_eigvecs[j]

            delta_eigvecs[c, k] = np.dot(pinvs[k], flip_indicator[c] * degree_eigvec)

    return delta_eigvecs


def estimate_delta_eigvals(candidates, adj, vals_org, vecs_org):
    """Computes the estimated change in the eigenvalues for every candidate edge flip.

    :param candidates: np.ndarray, shape [?, 2]
        Candidate set of edge flips
    :param adj: sp.spmatrix
        The graph represented as a sparse scipy matrix
    :param vals_org: np.ndarray, shape [n]
        The generalized eigenvalues of the clean graph
    :param vecs_org: np.ndarray, shape [n, n]
        The generalized eigenvectors of the clean graph
    :return: np.ndarray, shape [?, n]
        Estimated change in the eigenvalues for all candidate edge flips
    """
    # vector indicating whether we are adding an edge (+1) or removing an edge (-1)
    delta_w = 1 - 2 * adj[candidates[:, 0], candidates[:, 1]].A1

    delta_eigvals = delta_w[:, None] * (2 * vecs_org[candidates[:, 0]] * vecs_org[candidates[:, 1]]
                                        - vals_org * (
                                            vecs_org[candidates[:, 0]] ** 2 + vecs_org[candidates[:, 1]] ** 2))

    return delta_eigvals


class OtherNodeEmbeddingAttack(NodeEmbeddingAttack):
    """Baseline methods from the paper Adversarial Attacks on Node Embeddings
    via Graph Poisoning. Aleksandar Bojchevski and Stephan Günnemann, ICML 2019.
    http://proceedings.mlr.press/v97/bojchevski19a.html

    Examples
    -----
    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.global_attack import OtherNodeEmbeddingAttack
    >>> data = Dataset(root='/tmp/', name='cora_ml', seed=15)
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> model = OtherNodeEmbeddingAttack(type='degree')
    >>> model.attack(adj, attack_type="remove")
    >>> modified_adj = model.modified_adj
    >>> #
    >>> model = OtherNodeEmbeddingAttack(type='eigencentrality')
    >>> model.attack(adj, attack_type="remove")
    >>> modified_adj = model.modified_adj
    >>> #
    >>> model = OtherNodeEmbeddingAttack(type='random')
    >>> model.attack(adj, attack_type="add", n_candidates=10000)
    >>> modified_adj = model.modified_adj
    """

    def __init__(self, type):
        assert type in ["degree", "eigencentrality", "random"]
        self.type = type

    def attack(self, adj, n_perturbations=1000, attack_type="remove",
               min_span_tree=False, n_candidates=None, seed=None, **kwargs):
        """Selects the top (n_perturbations) number of flips using one of the baseline methods.

        :param adj: sp.spmatrix
            The graph represented as a sparse scipy matrix
        :param n_perturbations: int
            Number of flips to select
        :param attack_type: str
            can be chosen from ["remove", "add"]
        :param min_span_tree: bool
            Whether to disallow edges that lie on the minimum spanning tree;
            only valid when `attack_type` is "remove"
        :param n_candidates: int
            Number of candidates for addition; only valid when `attack_type` is "add";
        :param seed: int
            Random seed;
        :return: np.ndarray, shape [?, 2]
            The top edge flips from the candidate set
        """
        assert attack_type in ["remove", "add"], \
            "attack_type can only be `remove` or `add`"

        if attack_type == "remove":
            if min_span_tree:
                candidates = self.generate_candidates_removal_minimum_spanning_tree(adj)
            else:
                candidates = self.generate_candidates_removal(adj, seed)
        elif attack_type == "add":
            assert n_candidates, "please specify the value of `n_candidates`, " \
                + "i.e. how many candidates you want to generate for addition"
            candidates = self.generate_candidates_addition(adj, n_candidates, seed)
        else:
            raise NotImplementedError

        if self.type == "random":
            top_flips = self.random_top_flips(candidates, n_perturbations, seed)
        elif self.type == "eigencentrality":
            top_flips = self.eigencentrality_top_flips(adj, candidates, n_perturbations)
        elif self.type == "degree":
            top_flips = self.degree_top_flips(adj, candidates, n_perturbations, complement=False)
        else:
            raise NotImplementedError

        assert len(top_flips) == n_perturbations
        modified_adj = self.flip_candidates(adj, top_flips)
        self.check_adj(modified_adj)
        self.modified_adj = modified_adj

    def random_top_flips(self, candidates, n_perturbations, seed=None):
        """Selects (n_perturbations) number of flips at random.

        :param candidates: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        :param n_perturbations: int
            Number of flips to select
        :param seed: int
            Random seed
        :return: np.ndarray, shape [?, 2]
            The top edge flips from the candidate set
        """
        if seed is not None:
            np.random.seed(seed)
        return candidates[np.random.permutation(len(candidates))[:n_perturbations]]

    def eigencentrality_top_flips(self, adj, candidates, n_perturbations):
        """Selects the top (n_perturbations) number of flips using eigencentrality score of the edges.
        Applicable only when removing edges.

        :param adj: sp.spmatrix
            The graph represented as a sparse scipy matrix
        :param candidates: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        :param n_perturbations: int
            Number of flips to select
        :return: np.ndarray, shape [?, 2]
            The top edge flips from the candidate set
        """
        edges = np.column_stack(sp.triu(adj, 1).nonzero())
        line_graph = construct_line_graph(adj)
        eigcentrality_scores = nx.eigenvector_centrality_numpy(nx.Graph(line_graph))
        eigcentrality_scores = {tuple(edges[k]): eigcentrality_scores[k] for k, v in eigcentrality_scores.items()}
        eigcentrality_scores = np.array([eigcentrality_scores[tuple(cnd)] for cnd in candidates])
        scores_argsrt = eigcentrality_scores.argsort()
        return candidates[scores_argsrt[-n_perturbations:]]

    def degree_top_flips(self, adj, candidates, n_perturbations, complement):
        """Selects the top (n_perturbations) number of flips using degree centrality score of the edges.

        :param adj: sp.spmatrix
            The graph represented as a sparse scipy matrix
        :param candidates: np.ndarray, shape [?, 2]
            Candidate set of edge flips
        :param n_perturbations: int
            Number of flips to select
        :param complement: bool
            Whether to look at the complement graph
        :return: np.ndarray, shape [?, 2]
            The top edge flips from the candidate set
        """
        if complement:
            adj = sp.csr_matrix(1 - adj.toarray())
        deg = adj.sum(1).A1
        deg_argsort = (deg[candidates[:, 0]] + deg[candidates[:, 1]]).argsort()

        return candidates[deg_argsort[-n_perturbations:]]


@numba.jit(nopython=True)
def sum_of_powers(x, power):
    """For each x_i, computes \sum_{r=1}^{power} x_i^r (elementwise sum of powers).

    :param x: shape [?]
        Any vector
    :param power: int
        The largest power to consider
    :return: shape [?]
        Vector where each element is the sum of powers from 1 to power.
    """
    n = x.shape[0]
    sum_powers = np.zeros((power, n))

    for i, i_power in enumerate(range(1, power + 1)):
        sum_powers[i] = np.power(x, i_power)

    return sum_powers.sum(0)


def edges_to_sparse(edges, num_nodes, weights=None):
    if weights is None:
        weights = np.ones(edges.shape[0])

    return sp.coo_matrix((weights, (edges[:, 0], edges[:, 1])), shape=(num_nodes, num_nodes)).tocsr()


def construct_line_graph(adj):
    """Construct a line graph from an undirected original graph.

    Parameters
    ----------
    adj : sp.spmatrix [n_samples, n_samples]
        Symmetric binary adjacency matrix.

    Returns
    -------
    L : sp.spmatrix, shape [A.nnz/2, A.nnz/2]
        Symmetric binary adjacency matrix of the line graph.
    """
    N = adj.shape[0]
    edges = np.column_stack(sp.triu(adj, 1).nonzero())
    e1, e2 = edges[:, 0], edges[:, 1]

    I = sp.eye(N).tocsr()
    E1 = I[e1]
    E2 = I[e2]

    L = E1.dot(E1.T) + E1.dot(E2.T) + E2.dot(E1.T) + E2.dot(E2.T)

    return L - 2 * sp.eye(L.shape[0])


if __name__ == "__main__":
    from deeprobust.graph.data import Dataset
    from deeprobust.graph.defense import DeepWalk
    import itertools
    # load clean graph data
    dataset_str = 'cora_ml'
    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

    comb = itertools.product(["random", "degree", "eigencentrality"], ["remove", "add"])
    for type, attack_type in comb:
        model = OtherNodeEmbeddingAttack(type=type)
        print(model.type, attack_type)
        try:
            model.attack(adj, attack_type=attack_type, n_candidates=10000)
            defender = DeepWalk()
            defender.fit(adj)
            defender.evaluate_node_classification(labels, idx_train, idx_test)
        except KeyError:
            print('eigencentrality only supports removing edges')

    model = NodeEmbeddingAttack()
    model.attack(adj, attack_type="remove")
    model.attack(adj, attack_type="remove", min_span_tree=True)
    modified_adj = model.modified_adj
    model.attack(adj, attack_type="add", n_candidates=10000)
    model.attack(adj, attack_type="add_by_remove", n_candidates=10000)
    # model.attack(adj, attack_type="add")
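To illustrate the first-order eigenvalue update at the heart of `estimate_loss_with_delta_eigenvals`: flipping edge (i, j) shifts each generalized eigenvalue by flip * (2 u_i u_j - lambda (u_i^2 + u_j^2)). A minimal numpy sketch with hypothetical toy eigenpairs:

import numpy as np

# Hypothetical generalized eigenpairs of a 4-node graph.
vals_org = np.array([0.1, 0.4, 0.7, 1.0])
vecs_org = np.linalg.qr(np.random.RandomState(0).randn(4, 4))[0]

i, j = 0, 2   # candidate edge flip
flip = +1.0   # +1 = adding the edge, -1 = removing it

# Same elementwise update as in the numba-compiled loop above.
vals_est = vals_org + flip * (
    2 * vecs_org[i] * vecs_org[j] - vals_org * (vecs_org[i] ** 2 + vecs_org[j] ** 2))
print(vals_est)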
deeprobust/graph/global_attack/prbcd.py
ADDED
@@ -0,0 +1,440 @@
"""
Robustness of Graph Neural Networks at Scale. NeurIPS 2021.

Modified from https://github.com/sigeisler/robustness_of_gnns_at_scale/blob/main/rgnn_at_scale/attacks/prbcd.py
"""
import numpy as np
from deeprobust.graph.defense_pyg import GCN
import torch.nn.functional as F
import torch
import deeprobust.graph.utils as utils
from torch.nn.parameter import Parameter
from tqdm import tqdm
import torch_sparse
from torch_sparse import coalesce
import math
from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix


class PRBCD:

    def __init__(self, data, model=None,
                 make_undirected=True,
                 eps=1e-7, search_space_size=10_000_000,
                 max_final_samples=20,
                 fine_tune_epochs=100,
                 epochs=400, lr_adj=0.1,
                 with_early_stopping=True,
                 do_synchronize=True,
                 device='cuda',
                 **kwargs
                 ):
        """
        Parameters
        ----------
        data : pyg format data
        model : the model to be attacked, should be models in deeprobust.graph.defense_pyg
        """
        self.device = device
        self.data = data

        if model is None:
            model = self.pretrain_model()

        self.model = model
        nnodes = data.x.shape[0]
        d = data.x.shape[1]

        self.n, self.d = nnodes, d
        self.make_undirected = make_undirected
        self.max_final_samples = max_final_samples
        self.search_space_size = search_space_size
        self.eps = eps
        self.lr_adj = lr_adj

        self.modified_edge_index: torch.Tensor = None
        self.perturbed_edge_weight: torch.Tensor = None
        if self.make_undirected:
            self.n_possible_edges = self.n * (self.n - 1) // 2
        else:
            self.n_possible_edges = self.n ** 2  # We filter self-loops later

        # lr_factor = 0.1
        # self.lr_factor = lr_factor * max(math.log2(self.n_possible_edges / self.search_space_size), 1.)
        self.epochs = epochs
        self.epochs_resampling = epochs - fine_tune_epochs  # TODO

        self.with_early_stopping = with_early_stopping
        self.do_synchronize = do_synchronize

    def pretrain_model(self, model=None):
        data = self.data
        device = self.device
        feat, labels = data.x, data.y
        nclass = max(labels).item() + 1

        if model is None:
            model = GCN(nfeat=feat.shape[1], nhid=256, dropout=0,
                        nlayers=3, with_bn=True, weight_decay=5e-4, nclass=nclass,
                        device=device).to(device)
            print(model)

        model.fit(data, train_iters=1000, patience=200, verbose=True)
        model.eval()
        model.data = data.to(self.device)
        output = model.predict()
        labels = labels.to(device)
        print(f"{model.name} Test set results:", self.get_perf(output, labels, data.test_mask, verbose=0)[1])
        self.clean_node_mask = (output.argmax(1) == labels)
        return model

    def sample_random_block(self, n_perturbations):
        for _ in range(self.max_final_samples):
            self.current_search_space = torch.randint(
                self.n_possible_edges, (self.search_space_size,), device=self.device)
            self.current_search_space = torch.unique(self.current_search_space, sorted=True)
            if self.make_undirected:
                self.modified_edge_index = linear_to_triu_idx(self.n, self.current_search_space)
            else:
                self.modified_edge_index = linear_to_full_idx(self.n, self.current_search_space)
                is_not_self_loop = self.modified_edge_index[0] != self.modified_edge_index[1]
                self.current_search_space = self.current_search_space[is_not_self_loop]
                self.modified_edge_index = self.modified_edge_index[:, is_not_self_loop]

            self.perturbed_edge_weight = torch.full_like(
                self.current_search_space, self.eps, dtype=torch.float32, requires_grad=True
            )
            if self.current_search_space.size(0) >= n_perturbations:
                return
        raise RuntimeError('Sampling random block was not successful. Please decrease `n_perturbations`.')

    @torch.no_grad()
    def sample_final_edges(self, n_perturbations):
        best_loss = -float('Inf')
        perturbed_edge_weight = self.perturbed_edge_weight.detach()
        perturbed_edge_weight[perturbed_edge_weight <= self.eps] = 0

        _, feat, labels = self.edge_index, self.data.x, self.data.y
        for i in range(self.max_final_samples):
            if best_loss == float('Inf') or best_loss == -float('Inf'):
                # In first iteration employ top k heuristic instead of sampling
                sampled_edges = torch.zeros_like(perturbed_edge_weight)
                sampled_edges[torch.topk(perturbed_edge_weight, n_perturbations).indices] = 1
            else:
                sampled_edges = torch.bernoulli(perturbed_edge_weight).float()

            if sampled_edges.sum() > n_perturbations:
                n_samples = sampled_edges.sum()
                print(f'{i}-th sampling: too many samples {n_samples}')
                continue
            self.perturbed_edge_weight = sampled_edges

            edge_index, edge_weight = self.get_modified_adj()
            with torch.no_grad():
                output = self.model.forward(feat, edge_index, edge_weight)
                loss = F.nll_loss(output[self.data.val_mask], labels[self.data.val_mask]).item()

            if best_loss < loss:
                best_loss = loss
                print('best_loss:', best_loss)
                best_edges = self.perturbed_edge_weight.clone().cpu()

        # Recover best sample
        self.perturbed_edge_weight.data.copy_(best_edges.to(self.device))

        edge_index, edge_weight = self.get_modified_adj()
        edge_mask = edge_weight == 1

        allowed_perturbations = 2 * n_perturbations if self.make_undirected else n_perturbations
        edges_after_attack = edge_mask.sum()
        clean_edges = self.edge_index.shape[1]
        assert (edges_after_attack >= clean_edges - allowed_perturbations
                and edges_after_attack <= clean_edges + allowed_perturbations), \
            f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} perturbations'
        return edge_index[:, edge_mask], edge_weight[edge_mask]

    def resample_random_block(self, n_perturbations: int):
        self.keep_heuristic = 'WeightOnly'
        if self.keep_heuristic == 'WeightOnly':
            sorted_idx = torch.argsort(self.perturbed_edge_weight)
            idx_keep = (self.perturbed_edge_weight <= self.eps).sum().long()
            # Keep at most half of the block (i.e. resample low weights)
            if idx_keep < sorted_idx.size(0) // 2:
                idx_keep = sorted_idx.size(0) // 2
        else:
            raise NotImplementedError('Only keep_heuristic=`WeightOnly` supported')

        sorted_idx = sorted_idx[idx_keep:]
        self.current_search_space = self.current_search_space[sorted_idx]
        self.modified_edge_index = self.modified_edge_index[:, sorted_idx]
        self.perturbed_edge_weight = self.perturbed_edge_weight[sorted_idx]

        # Sample until enough edges were drawn
        for i in range(self.max_final_samples):
            n_edges_resample = self.search_space_size - self.current_search_space.size(0)
            lin_index = torch.randint(self.n_possible_edges, (n_edges_resample,), device=self.device)

            self.current_search_space, unique_idx = torch.unique(
                torch.cat((self.current_search_space, lin_index)),
                sorted=True,
                return_inverse=True
            )

            if self.make_undirected:
                self.modified_edge_index = linear_to_triu_idx(self.n, self.current_search_space)
            else:
                self.modified_edge_index = linear_to_full_idx(self.n, self.current_search_space)

            # Merge existing weights with new edge weights
            perturbed_edge_weight_old = self.perturbed_edge_weight.clone()
            self.perturbed_edge_weight = torch.full_like(self.current_search_space, self.eps, dtype=torch.float32)
            self.perturbed_edge_weight[
                unique_idx[:perturbed_edge_weight_old.size(0)]
            ] = perturbed_edge_weight_old  # unique_idx: the indices for the old edges

            if not self.make_undirected:
                is_not_self_loop = self.modified_edge_index[0] != self.modified_edge_index[1]
                self.current_search_space = self.current_search_space[is_not_self_loop]
                self.modified_edge_index = self.modified_edge_index[:, is_not_self_loop]
                self.perturbed_edge_weight = self.perturbed_edge_weight[is_not_self_loop]

            if self.current_search_space.size(0) > n_perturbations:
                return
        raise RuntimeError('Sampling random block was not successful. Please decrease `n_perturbations`.')

    def project(self, n_perturbations, values, eps, inplace=False):
        if not inplace:
            values = values.clone()

        if torch.clamp(values, 0, 1).sum() > n_perturbations:
            left = (values - 1).min()
            right = values.max()
            miu = bisection(values, left, right, n_perturbations)
            values.data.copy_(torch.clamp(
                values - miu, min=eps, max=1 - eps
            ))
        else:
            values.data.copy_(torch.clamp(
                values, min=eps, max=1 - eps
            ))
        return values

    def get_modified_adj(self):
        if self.make_undirected:
            modified_edge_index, modified_edge_weight = to_symmetric(
                self.modified_edge_index, self.perturbed_edge_weight, self.n
            )
        else:
            modified_edge_index, modified_edge_weight = self.modified_edge_index, self.perturbed_edge_weight
        edge_index = torch.cat((self.edge_index.to(self.device), modified_edge_index), dim=-1)
        edge_weight = torch.cat((self.edge_weight.to(self.device), modified_edge_weight))

        edge_index, edge_weight = torch_sparse.coalesce(edge_index, edge_weight, m=self.n, n=self.n, op='sum')

        # Allow removal of edges
        edge_weight[edge_weight > 1] = 2 - edge_weight[edge_weight > 1]
        return edge_index, edge_weight

    def update_edge_weights(self, n_perturbations, epoch, gradient):
        self.optimizer_adj.zero_grad()
        self.perturbed_edge_weight.grad = -gradient
        self.optimizer_adj.step()
        self.perturbed_edge_weight.data[self.perturbed_edge_weight < self.eps] = self.eps

    def _update_edge_weights(self, n_perturbations, epoch, gradient):
        lr_factor = n_perturbations / self.n / 2 * self.lr_factor
        lr = lr_factor / np.sqrt(max(0, epoch - self.epochs_resampling) + 1)
        self.perturbed_edge_weight.data.add_(lr * gradient)
        self.perturbed_edge_weight.data[self.perturbed_edge_weight < self.eps] = self.eps
        return None

    def attack(self, edge_index=None, edge_weight=None, ptb_rate=0.1):
        data = self.data
        epochs, lr_adj = self.epochs, self.lr_adj
        model = self.model
        model.eval()  # should set to eval

        self.edge_index, feat, labels = data.edge_index, data.x, data.y
        with torch.no_grad():
            output = model.forward(feat, self.edge_index)
            pred = output.argmax(1)
        gt_labels = labels
        labels = labels.clone()  # to avoid shallow copy
        labels[~data.train_mask] = pred[~data.train_mask]

        if edge_index is not None:
            self.edge_index = edge_index

        self.edge_weight = torch.ones(self.edge_index.shape[1]).to(self.device)

        n_perturbations = int(ptb_rate * self.edge_index.shape[1] // 2)
        print('n_perturbations:', n_perturbations)
        self.sample_random_block(n_perturbations)

        self.perturbed_edge_weight.requires_grad = True
        self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=lr_adj)
        best_loss_val = -float('Inf')
        for it in tqdm(range(epochs)):
            self.perturbed_edge_weight.requires_grad = True
            edge_index, edge_weight = self.get_modified_adj()
            if torch.cuda.is_available() and self.do_synchronize:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            output = model.forward(feat, edge_index, edge_weight)
            loss = self.loss_attack(output, labels, type='tanhMargin')
            gradient = grad_with_checkpoint(loss, self.perturbed_edge_weight)[0]

            if torch.cuda.is_available() and self.do_synchronize:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            if it % 10 == 0:
                print(f'Epoch {it}: {loss}')

            with torch.no_grad():
                self.update_edge_weights(n_perturbations, it, gradient)
                self.perturbed_edge_weight = self.project(
                    n_perturbations, self.perturbed_edge_weight, self.eps)

                del edge_index, edge_weight  # , logits

                if it < self.epochs_resampling - 1:
                    self.resample_random_block(n_perturbations)

                edge_index, edge_weight = self.get_modified_adj()
                output = model.predict(feat, edge_index, edge_weight)
                loss_val = F.nll_loss(output[data.val_mask], labels[data.val_mask])

            self.perturbed_edge_weight.requires_grad = True
            self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=lr_adj)

        # Sample final discrete graph
        edge_index, edge_weight = self.sample_final_edges(n_perturbations)
        output = model.predict(feat, edge_index, edge_weight)
        print('Test:')
        self.get_perf(output, gt_labels, data.test_mask)
        print('Validation:')
        self.get_perf(output, gt_labels, data.val_mask)
        return edge_index, edge_weight

    def loss_attack(self, logits, labels, type='CE'):
        self.loss_type = type
        if self.loss_type == 'tanhMargin':
            sorted = logits.argsort(-1)
            best_non_target_class = sorted[sorted != labels[:, None]].reshape(logits.size(0), -1)[:, -1]
            margin = (
                logits[np.arange(logits.size(0)), labels]
                - logits[np.arange(logits.size(0)), best_non_target_class]
            )
            loss = torch.tanh(-margin).mean()
        elif self.loss_type == 'MCE':
            not_flipped = logits.argmax(-1) == labels
            loss = F.cross_entropy(logits[not_flipped], labels[not_flipped])
        elif self.loss_type == 'NCE':
            sorted = logits.argsort(-1)
            best_non_target_class = sorted[sorted != labels[:, None]].reshape(logits.size(0), -1)[:, -1]
            loss = -F.cross_entropy(logits, best_non_target_class)
        else:
            loss = F.cross_entropy(logits, labels)
        return loss

    def get_perf(self, output, labels, mask, verbose=True):
        loss = F.nll_loss(output[mask], labels[mask])
        acc = utils.accuracy(output[mask], labels[mask])
        if verbose:
            print("loss= {:.4f}".format(loss.item()),
                  "accuracy= {:.4f}".format(acc.item()))
        return loss.item(), acc.item()

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from **logits**."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)

@torch.jit.script
def entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from **log_softmax**."""
    return -(torch.exp(x) * x).sum(1)

def to_symmetric(edge_index, edge_weight, n, op='mean'):
    symmetric_edge_index = torch.cat(
        (edge_index, edge_index.flip(0)), dim=-1
    )

    symmetric_edge_weight = edge_weight.repeat(2)

    symmetric_edge_index, symmetric_edge_weight = coalesce(
        symmetric_edge_index,
        symmetric_edge_weight,
        m=n,
        n=n,
        op=op
    )
    return symmetric_edge_index, symmetric_edge_weight

def linear_to_full_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
    row_idx = lin_idx // n
    col_idx = lin_idx % n
    return torch.stack((row_idx, col_idx))

def linear_to_triu_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
    row_idx = (
        n
        - 2
        - torch.floor(torch.sqrt(-8 * lin_idx.double() + 4 * n * (n - 1) - 7) / 2.0 - 0.5)
    ).long()
    col_idx = (
        lin_idx
        + row_idx
        + 1 - n * (n - 1) // 2
        + (n - row_idx) * ((n - row_idx) - 1) // 2
    )
    return torch.stack((row_idx, col_idx))

def grad_with_checkpoint(outputs, inputs):
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    for input in inputs:
        if not input.is_leaf:
            input.retain_grad()
    torch.autograd.backward(outputs)

    grad_outputs = []
    for input in inputs:
        grad_outputs.append(input.grad.clone())
        input.grad.zero_()
    return grad_outputs

def bisection(edge_weights, a, b, n_perturbations, epsilon=1e-5, iter_max=1e5):
    def func(x):
        return torch.clamp(edge_weights - x, 0, 1).sum() - n_perturbations

    miu = a
    for i in range(int(iter_max)):
        miu = (a + b) / 2
        # Check if middle point is root
        if (func(miu) == 0.0):
            break
        # Decide the side to repeat the steps
        if (func(miu) * func(a) < 0):
            b = miu
        else:
            a = miu
        if ((b - a) <= epsilon):
            break
    return miu


if __name__ == "__main__":
    from ogb.nodeproppred import PygNodePropPredDataset
    from torch_geometric.utils import to_undirected
    import torch_geometric.transforms as T
    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    dataset.transform = T.NormalizeFeatures()
    data = dataset[0]
    if not hasattr(data, 'train_mask'):
        utils.add_mask(data, dataset)
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    agent = PRBCD(data)
    edge_index, edge_weight = agent.attack()
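A small sanity check for the closed-form index mapping above (illustrative only, not part of the uploaded file): `linear_to_triu_idx` should enumerate exactly the strict upper-triangular pairs, in the same row-major order as `torch.triu_indices`.

import torch

n = 5
lin = torch.arange(n * (n - 1) // 2)
pairs = linear_to_triu_idx(n, lin)             # shape [2, 10]
expected = torch.triu_indices(n, n, offset=1)  # strict upper triangle, row-major
assert torch.equal(pairs, expected)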
deeprobust/graph/rl/nipa_env.py
ADDED
@@ -0,0 +1,169 @@
"""
This part of code is adopted from https://github.com/Hanjun-Dai/graph_adversarial_attack (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le)
but modified to be integrated into the repository.
"""

import os
import sys
import numpy as np
import torch
import networkx as nx
import random
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from copy import deepcopy
import pickle as cp
from deeprobust.graph.utils import *
import scipy.sparse as sp
from scipy.sparse.linalg.eigen.arpack import eigsh
from deeprobust.graph import utils
from deeprobust.graph.rl.env import *

class NodeInjectionEnv(NodeAttackEnv):
    """Node attack environment. It executes an action and then changes the
    environment status (modifies the graph).
    """

    def __init__(self, features, labels, idx_train, idx_val, dict_of_lists, classifier, ratio=0.01, parallel_size=1, reward_type='binary'):
        """number of injected nodes: ratio*|V|
        number of modifications: ratio*|V|*|D_avg|
        """
        # super(NodeInjectionEnv, self).__init__(features, labels, all_targets, list_action_space, classifier, num_mod, reward_type)
        super(NodeInjectionEnv, self).__init__(features, labels, idx_val, dict_of_lists, classifier)
        self.parallel_size = parallel_size

        degrees = np.array([len(d) for n, d in dict_of_lists.items()])
        N = len(degrees[degrees > 0])
        avg_degree = degrees.sum() / N
        self.n_injected = len(degrees) - N
        assert self.n_injected == int(ratio * N)

        self.ori_adj_size = N
        self.n_perturbations = int(self.n_injected * avg_degree)
        print("number of perturbations: {}".format(self.n_perturbations))
        self.all_nodes = np.arange(N)
        self.injected_nodes = self.all_nodes[-self.n_injected:]
        self.previous_acc = [1] * parallel_size

        self.idx_train = np.hstack((idx_train, self.injected_nodes))
        self.idx_val = idx_val

        self.modified_label_list = []
        for i in range(self.parallel_size):
            self.modified_label_list.append(labels[-self.n_injected:].clone())


    def init_overall_steps(self):
        self.overall_steps = 0
        self.modified_list = []
        for i in range(self.parallel_size):
            self.modified_list.append(ModifiedGraph())

    def setup(self):
        self.n_steps = 0
        self.first_nodes = None
        self.second_nodes = None
        self.rewards = None
        self.binary_rewards = None
        self.list_acc_of_all = []

    def step(self, actions, inference=False):
        '''
        run actions and get reward
        '''
        if self.first_nodes is None:  # pick the first node of the edge
            assert (self.n_steps + 1) % 3 == 1
            self.first_nodes = actions[:]

        if (self.n_steps + 1) % 3 == 2:
            self.second_nodes = actions[:]
            for i in range(self.parallel_size):
                # add an edge to the graph
                self.modified_list[i].add_edge(self.first_nodes[i], actions[i], 1.0)

        if (self.n_steps + 1) % 3 == 0:
            for i in range(self.parallel_size):
                # change label
                self.modified_label_list[i][self.first_nodes[i] - self.ori_adj_size] = actions[i]

            self.first_nodes = None
            self.second_nodes = None

        self.n_steps += 1
        self.overall_steps += 1

        if not inference:
            if self.isActionFinished():
                rewards = []
                for i in range(self.parallel_size):
                    device = self.labels.device
                    extra_adj = self.modified_list[i].get_extra_adj(device=device)
                    adj = self.classifier.norm_tool.norm_extra(extra_adj)
                    labels = torch.cat((self.labels, self.modified_label_list[i]))
                    # self.classifier.fit(self.features, adj, labels, self.idx_train, self.idx_val, normalize=False)
                    self.classifier.fit(self.features, adj, labels, self.idx_train, self.idx_val, normalize=False, patience=30)
                    output = self.classifier(self.features, adj)
                    loss, correct = loss_acc(output, self.labels, self.idx_val, avg_loss=False)
                    acc = correct.sum()
                    # r = 1 if self.previous_acc[i] - acc > 0.01 else -1
                    r = 1 if self.previous_acc[i] - acc > 0 else -1
                    self.previous_acc[i] = acc
                    rewards.append(r)
                self.rewards = np.array(rewards).astype(np.float32)


    def sample_pos_rewards(self, num_samples):
        assert self.list_acc_of_all is not None
        cands = []

        for i in range(len(self.list_acc_of_all)):
            succ = np.where(self.list_acc_of_all[i] < 0.9)[0]

            for j in range(len(succ)):
                cands.append((i, self.all_targets[succ[j]]))

        if num_samples > len(cands):
            return cands
        random.shuffle(cands)
        return cands[0:num_samples]

    def uniformRandActions(self):
        act_list = []
        for i in range(self.parallel_size):
            if self.first_nodes is None:
                # a1: choose a node from injected nodes
                cur_action = np.random.choice(self.injected_nodes)

            if self.first_nodes is not None and self.second_nodes is None:
                # a2: choose a node from all nodes
                cur_action = np.random.randint(len(self.list_action_space))
                while (self.first_nodes[i], cur_action) in self.modified_list[i].edge_set:
                    cur_action = np.random.randint(len(self.list_action_space))

            if self.first_nodes is not None and self.second_nodes is not None:
                # a3: choose label
                cur_action = np.random.randint(self.labels.cpu().max() + 1)

            act_list.append(cur_action)
        return act_list

    def isActionFinished(self):
        if (self.n_steps) % 3 == 0 and self.n_steps != 0:
            return True
        return False

    def isTerminal(self):
        if self.overall_steps == 3 * self.n_perturbations:
            return True
        return False

    def getStateRef(self):
        return list(zip(self.modified_list, self.modified_label_list))

    def cloneState(self):
        return list(zip(deepcopy(self.modified_list), deepcopy(self.modified_label_list)))
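A minimal driver sketch for `NodeInjectionEnv` (hypothetical, not part of the uploaded file; assumes `env` is an already-constructed instance): actions are consumed in cycles of three, pick an injected node, pick a node to wire it to, then assign the injected node a label, and a reward is computed only when `isActionFinished()` turns true at the end of each cycle.

env.init_overall_steps()
env.setup()
while not env.isTerminal():
    # one call per phase; the env tracks the phase via n_steps % 3
    env.step(env.uniformRandActions())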
deeprobust/graph/rl/rl_s2v_config.py
ADDED
@@ -0,0 +1,57 @@
"""Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le
"""
import argparse
import pickle as cp

cmd_opt = argparse.ArgumentParser(description='Argparser for RL-S2V attack')

cmd_opt.add_argument('-saved_model', type=str, default=None, help='saved model')
cmd_opt.add_argument('-save_dir', type=str, default=None, help='save folder')
cmd_opt.add_argument('-ctx', type=str, default='gpu', help='cpu/gpu')

cmd_opt.add_argument('-phase', type=str, default='train', help='train/test')
cmd_opt.add_argument('-batch_size', type=int, default=10, help='minibatch size')
cmd_opt.add_argument('-seed', type=int, default=1, help='seed')

cmd_opt.add_argument('-gm', default='mean_field', help='mean_field/loopy_bp/gcn')
cmd_opt.add_argument('-latent_dim', type=int, default=64, help='dimension of latent layers')
cmd_opt.add_argument('-hidden', type=int, default=0, help='dimension of classification')
cmd_opt.add_argument('-max_lv', type=int, default=1, help='max rounds of message passing')

# target model
cmd_opt.add_argument('-num_epochs', type=int, default=200, help='number of epochs')
cmd_opt.add_argument('-learning_rate', type=float, default=0.01, help='init learning_rate')
cmd_opt.add_argument('-weight_decay', type=float, default=5e-4, help='weight_decay')
cmd_opt.add_argument('-dropout', type=float, default=0.5, help='dropout rate')

# for node classification
cmd_opt.add_argument('-dataset', type=str, default='cora', help='citeseer/cora/pubmed')

# for attack
cmd_opt.add_argument('-num_steps', type=int, default=500000, help='rl training steps')
# cmd_opt.add_argument('-frac_meta', type=float, default=0, help='fraction for meta rl learning')

cmd_opt.add_argument('-meta_test', type=int, default=0, help='for meta rl learning')
cmd_opt.add_argument('-reward_type', type=str, default='binary', help='binary/nll')
cmd_opt.add_argument('-num_mod', type=int, default=1, help='number of modifications allowed')

# for node attack
cmd_opt.add_argument('-bilin_q', type=int, default=1, help='bilinear q or not')
cmd_opt.add_argument('-mlp_hidden', type=int, default=64, help='mlp hidden layer size')
# cmd_opt.add_argument('-n_hops', type=int, default=2, help='attack range')


args, _ = cmd_opt.parse_known_args()
args.save_dir = './results/rl_s2v/{}-gcn'.format(args.dataset)
args.saved_model = 'results/node_classification/{}'.format(args.dataset)
print(args)

def build_kwargs(keys, arg_dict):
    st = ''
    for key in keys:
        st += '%s-%s' % (key, str(arg_dict[key]))
    return st

def save_args(fout, args):
    with open(fout, 'wb') as f:
        cp.dump(args, f, cp.HIGHEST_PROTOCOL)
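Illustrative use of `build_kwargs` (not part of the uploaded file): it concatenates key-value pairs into a single tag string, with no separator between consecutive pairs.

tag = build_kwargs(['dataset', 'seed'], vars(args))
print(tag)  # 'dataset-coraseed-1' with the defaults above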
deeprobust/graph/targeted_attack/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .base_attack import BaseAttack
from .fga import FGA
from .rnd import RND
from .nettack import Nettack
from .ig_attack import IGAttack
from .rl_s2v import RLS2V
from .sga import SGAttack

__all__ = ['BaseAttack', 'FGA', 'RND', 'Nettack', 'IGAttack', 'RLS2V', 'SGAttack']
deeprobust/graph/targeted_attack/base_attack.py
ADDED
@@ -0,0 +1,126 @@
from torch.nn.modules.module import Module
import numpy as np
import torch
import scipy.sparse as sp
import os.path as osp
from deeprobust.graph import utils  # needed by save_adj/save_features (utils.to_scipy)

class BaseAttack(Module):
    """Abstract base class for targeted attack classes.

    Parameters
    ----------
    model :
        model to attack
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    """

    def __init__(self, model, nnodes, attack_structure=True, attack_features=False, device='cpu'):
        super(BaseAttack, self).__init__()

        self.surrogate = model
        self.nnodes = nnodes
        self.attack_structure = attack_structure
        self.attack_features = attack_features
        self.device = device

        if model is not None:
            self.nclass = model.nclass
            self.nfeat = model.nfeat
            self.hidden_sizes = model.hidden_sizes

        self.modified_adj = None
        self.modified_features = None

    def attack(self, ori_adj, n_perturbations, **kwargs):
        """Generate perturbations on the input graph.

        Parameters
        ----------
        ori_adj : scipy.sparse.csr_matrix
            Original (unperturbed) adjacency matrix.
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.

        Returns
        -------
        None.

        """
        pass

    def check_adj(self, adj):
        """Check if the modified adjacency is symmetric and unweighted.
        """

        if type(adj) is torch.Tensor:
            adj = adj.cpu().numpy()
        assert np.abs(adj - adj.T).sum() == 0, "Input graph is not symmetric"
        if sp.issparse(adj):
            assert adj.tocsr().max() == 1, "Max value should be 1!"
            assert adj.tocsr().min() == 0, "Min value should be 0!"
        else:
            assert adj.max() == 1, "Max value should be 1!"
            assert adj.min() == 0, "Min value should be 0!"

    def save_adj(self, root=r'/tmp/', name='mod_adj'):
        """Save attacked adjacency matrix.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name

        Returns
        -------
        None.

        """
        assert self.modified_adj is not None, \
            'modified_adj is None! Please perturb the graph first.'
        name = name + '.npz'
        modified_adj = self.modified_adj

        if type(modified_adj) is torch.Tensor:
            modified_adj = utils.to_scipy(modified_adj)
        if sp.issparse(modified_adj):
            modified_adj = modified_adj.tocsr()
        sp.save_npz(osp.join(root, name), modified_adj)

    def save_features(self, root=r'/tmp/', name='mod_features'):
        """Save attacked node feature matrix.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name

        Returns
        -------
        None.

        """

        assert self.modified_features is not None, \
            'modified_features is None! Please perturb the graph first.'
        name = name + '.npz'
        modified_features = self.modified_features

        if type(modified_features) is torch.Tensor:
            modified_features = utils.to_scipy(modified_features)
        if sp.issparse(modified_features):
            modified_features = modified_features.tocsr()
        sp.save_npz(osp.join(root, name), modified_features)
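A minimal sketch (hypothetical, not part of the library) of how a concrete attack plugs into `BaseAttack`: override `attack`, write the result into `self.modified_adj`, and reuse `check_adj`/`save_adj` from the base class.

import numpy as np
import scipy.sparse as sp

class RandomFlip(BaseAttack):
    """Toy illustration: flip n_perturbations random off-diagonal entries."""

    def attack(self, ori_adj, n_perturbations, **kwargs):
        adj = ori_adj.tolil(copy=True)
        rng = np.random.default_rng(0)
        flipped = 0
        while flipped < n_perturbations:
            u, v = rng.integers(0, self.nnodes, size=2)
            if u == v:
                continue
            val = 1 - adj[u, v]
            adj[u, v] = adj[v, u] = val  # keep the graph symmetric
            flipped += 1
        self.modified_adj = adj.tocsr()
        self.check_adj(self.modified_adj)

# attacker = RandomFlip(model=None, nnodes=adj.shape[0])
# attacker.attack(adj, n_perturbations=5); attacker.save_adj('/tmp/')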
deeprobust/graph/targeted_attack/fga.py
ADDED
@@ -0,0 +1,124 @@
"""
FGA: Fast Gradient Attack on Network Embedding (https://arxiv.org/pdf/1809.02797.pdf)
Another very similar algorithm to mention here is FGSM (for graph data).
It is mentioned in Zügner's paper,
Adversarial Attacks on Neural Networks for Graph Data, KDD'18
"""

import torch
from deeprobust.graph.targeted_attack import BaseAttack
from torch.nn.parameter import Parameter
from copy import deepcopy
from deeprobust.graph import utils
import torch.nn.functional as F
import scipy.sparse as sp

class FGA(BaseAttack):
    """FGA/FGSM.

    Parameters
    ----------
    model :
        model to attack
    nnodes : int
        number of nodes in the input graph
    feature_shape : tuple
        shape of the input node features
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    Examples
    --------

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> from deeprobust.graph.targeted_attack import FGA
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> # Setup Surrogate model
    >>> surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
    >>> surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
    >>> # Setup Attack Model
    >>> target_node = 0
    >>> model = FGA(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device='cpu').to('cpu')
    >>> # Attack
    >>> model.attack(features, adj, labels, idx_train, target_node, n_perturbations=5)
    >>> modified_adj = model.modified_adj

    """

    def __init__(self, model, nnodes, feature_shape=None, attack_structure=True, attack_features=False, device='cpu'):

        super(FGA, self).__init__(model, nnodes, attack_structure=attack_structure, attack_features=attack_features, device=device)

        assert not self.attack_features, "not support attacking features"

        if self.attack_features:
            self.feature_changes = Parameter(torch.FloatTensor(feature_shape))
            self.feature_changes.data.fill_(0)

    def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_perturbations, verbose=False, **kwargs):
        """Generate perturbations on the input graph.

        Parameters
        ----------
        ori_features : scipy.sparse.csr_matrix
            Original (unperturbed) node feature matrix
        ori_adj : scipy.sparse.csr_matrix
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train:
            training node indices
        target_node : int
            target node index to be attacked
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        """

        modified_adj = ori_adj.todense()
        modified_features = ori_features.todense()
        modified_adj, modified_features, labels = utils.to_tensor(modified_adj, modified_features, labels, device=self.device)

        self.surrogate.eval()
        if verbose:
            print('number of perturbations: %s' % n_perturbations)

        pseudo_labels = self.surrogate.predict().detach().argmax(1)
        pseudo_labels[idx_train] = labels[idx_train]

        modified_adj.requires_grad = True
        for i in range(n_perturbations):
            adj_norm = utils.normalize_adj_tensor(modified_adj)

            if self.attack_structure:
                output = self.surrogate(modified_features, adj_norm)
                loss = F.nll_loss(output[[target_node]], pseudo_labels[[target_node]])
                grad = torch.autograd.grad(loss, modified_adj)[0]
                # bidirectional: combine the gradients of both symmetric entries
                grad = (grad[target_node] + grad[:, target_node]) * (-2*modified_adj[target_node] + 1)
                grad[target_node] = -10
                grad_argmax = torch.argmax(grad)

            value = -2*modified_adj[target_node][grad_argmax] + 1
            modified_adj.data[target_node][grad_argmax] += value
            modified_adj.data[grad_argmax][target_node] += value

            if self.attack_features:
                pass

        modified_adj = modified_adj.detach().cpu().numpy()
        modified_adj = sp.csr_matrix(modified_adj)
        self.check_adj(modified_adj)
        self.modified_adj = modified_adj
        # self.modified_features = modified_features
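A toy illustration of the score used in the loop above (illustrative only, not from the file): multiplying the symmetrized gradient by `(-2*A[target] + 1)` keeps the sign for non-edges (candidate additions) and flips it for existing edges (candidate removals), so a single argmax ranks both kinds of flip.

import torch

A_row = torch.tensor([0., 1., 0., 1.])    # adjacency row of the target node
g = torch.tensor([0.5, 0.5, -0.3, -0.9])  # symmetrized gradient w.r.t. that row
score = g * (-2 * A_row + 1)
print(score)                # tensor([ 0.5000, -0.5000, -0.3000,  0.9000])
print(torch.argmax(score))  # tensor(3): remove the edge with strongly negative gradient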
deeprobust/graph/targeted_attack/ig_attack.py
ADDED
@@ -0,0 +1,224 @@
"""
Adversarial Examples on Graph Data: Deep Insights into Attack and Defense
https://arxiv.org/pdf/1903.01610.pdf
"""

import torch
import torch.multiprocessing as mp
from deeprobust.graph.targeted_attack import BaseAttack
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch import optim
from deeprobust.graph import utils
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
import math

class IGAttack(BaseAttack):
    """IGAttack: IG-FGSM. Adversarial Examples on Graph Data: Deep Insights into Attack and Defense, https://arxiv.org/pdf/1903.01610.pdf.

    Parameters
    ----------
    model :
        model to attack
    nnodes : int
        number of nodes in the input graph
    feature_shape : tuple
        shape of the input node features
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'

    Examples
    --------

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> from deeprobust.graph.targeted_attack import IGAttack
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> # Setup Surrogate model
    >>> surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
    >>> surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
    >>> # Setup Attack Model
    >>> target_node = 0
    >>> model = IGAttack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device='cpu').to('cpu')
    >>> # Attack
    >>> model.attack(features, adj, labels, idx_train, target_node, n_perturbations=5, steps=10)
    >>> modified_adj = model.modified_adj
    >>> modified_features = model.modified_features

    """

    def __init__(self, model, nnodes=None, feature_shape=None, attack_structure=True, attack_features=True, device='cpu'):

        super(IGAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)

        assert attack_features or attack_structure, 'attack_features or attack_structure cannot be both False'

        self.modified_adj = None
        self.modified_features = None
        self.target_node = None

    def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_perturbations, steps=10, **kwargs):
        """Generate perturbations on the input graph.

        Parameters
        ----------
        ori_features :
            Original (unperturbed) node feature matrix
        ori_adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train:
            training node indices
        target_node : int
            target node index to be attacked
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        steps : int
            steps for computing integrated gradients
        """

        self.surrogate.eval()
        self.target_node = target_node

        modified_adj = ori_adj.todense()
        modified_features = ori_features.todense()
        adj, features, labels = utils.to_tensor(modified_adj, modified_features, labels, device=self.device)
        adj_norm = utils.normalize_adj_tensor(adj)

        pseudo_labels = self.surrogate.predict().detach().argmax(1)
        pseudo_labels[idx_train] = labels[idx_train]
        self.pseudo_labels = pseudo_labels

        s_e = np.zeros(adj.shape[1])
        s_f = np.zeros(features.shape[1])
        if self.attack_structure:
            s_e = self.calc_importance_edge(features, adj_norm, labels, steps)
        if self.attack_features:
            s_f = self.calc_importance_feature(features, adj_norm, labels, steps)

        for t in range(n_perturbations):
            s_e_max = np.argmax(s_e)
            s_f_max = np.argmax(s_f)

            if s_e[s_e_max] >= s_f[s_f_max]:
                # edge perturbation score is larger
                if self.attack_structure:
                    value = np.abs(1 - modified_adj[target_node, s_e_max])
                    modified_adj[target_node, s_e_max] = value
                    modified_adj[s_e_max, target_node] = value
                    s_e[s_e_max] = 0
                else:
                    raise Exception("""No possible perturbation on the structure can be made!
                            See https://github.com/DSE-MSU/DeepRobust/issues/42 for more details.""")
            else:
                # feature perturbation score is larger
                if self.attack_features:
                    modified_features[target_node, s_f_max] = np.abs(1 - modified_features[target_node, s_f_max])
                    s_f[s_f_max] = 0
                else:
                    raise Exception("""No possible perturbation on the features can be made!
                            See https://github.com/DSE-MSU/DeepRobust/issues/42 for more details.""")

        self.modified_adj = sp.csr_matrix(modified_adj)
        self.modified_features = sp.csr_matrix(modified_features)
        self.check_adj(modified_adj)

    def calc_importance_edge(self, features, adj_norm, labels, steps):
        """Calculate integrated gradient for edges. The gradient should arguably
        be taken with respect to adj rather than adj_norm, but that calculation
        is too time-consuming, so the gradient of the loss is computed with
        respect to adj_norm instead.
        """
        baseline_add = adj_norm.clone()
        baseline_remove = adj_norm.clone()
        baseline_add.data[self.target_node] = 1
        baseline_remove.data[self.target_node] = 0
        adj_norm.requires_grad = True
        integrated_grad_list = []

        i = self.target_node
        for j in tqdm(range(adj_norm.shape[1])):
            if adj_norm[i][j]:
                scaled_inputs = [baseline_remove + (float(k) / steps) * (adj_norm - baseline_remove) for k in range(0, steps + 1)]
            else:
                scaled_inputs = [baseline_add - (float(k) / steps) * (baseline_add - adj_norm) for k in range(0, steps + 1)]
            _sum = 0

            for new_adj in scaled_inputs:
                output = self.surrogate(features, new_adj)
                loss = F.nll_loss(output[[self.target_node]],
                                  self.pseudo_labels[[self.target_node]])
                adj_grad = torch.autograd.grad(loss, adj_norm)[0]
                adj_grad = adj_grad[i][j]
                _sum += adj_grad

            if adj_norm[i][j]:
                avg_grad = (adj_norm[i][j] - 0) * _sum.mean()
            else:
                avg_grad = (1 - adj_norm[i][j]) * _sum.mean()

            integrated_grad_list.append(avg_grad.detach().item())

        integrated_grad_list[i] = 0
        # make impossible perturbations negative
        integrated_grad_list = np.array(integrated_grad_list)
        adj = (adj_norm > 0).cpu().numpy()
        integrated_grad_list = (-2 * adj[self.target_node] + 1) * integrated_grad_list
        integrated_grad_list[self.target_node] = -10
        return integrated_grad_list

    def calc_importance_feature(self, features, adj_norm, labels, steps):
        """Calculate integrated gradient for features
        """
        baseline_add = features.clone()
        baseline_remove = features.clone()
        baseline_add.data[self.target_node] = 1
        baseline_remove.data[self.target_node] = 0

        features.requires_grad = True
        integrated_grad_list = []
        i = self.target_node
        for j in tqdm(range(features.shape[1])):
            if features[i][j]:
                scaled_inputs = [baseline_add + (float(k) / steps) * (features - baseline_add) for k in range(0, steps + 1)]
            else:
                scaled_inputs = [baseline_remove - (float(k) / steps) * (baseline_remove - features) for k in range(0, steps + 1)]
            _sum = 0

            for new_features in scaled_inputs:
                output = self.surrogate(new_features, adj_norm)
                loss = F.nll_loss(output[[self.target_node]],
                                  self.pseudo_labels[[self.target_node]])

                feature_grad = torch.autograd.grad(loss, features)[0]
                feature_grad = feature_grad[i][j]
                _sum += feature_grad

            if features[i][j]:
                avg_grad = (features[i][j] - 0) * _sum.mean()
            else:
                avg_grad = (1 - features[i][j]) * _sum.mean()
            integrated_grad_list.append(avg_grad.detach().item())
        # make impossible perturbations negative
        features = (features > 0).cpu().numpy()
        integrated_grad_list = np.array(integrated_grad_list)
        integrated_grad_list = (-2 * features[self.target_node] + 1) * integrated_grad_list
        return integrated_grad_list
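A self-contained toy version of the integrated-gradients approximation used in the two loops above (illustrative only, not from the file): average the gradient along a straight path from a baseline and scale by the input difference; for f(x) = x**2 from baseline 0 at x = 3 this recovers f(x) - f(baseline) = 9.

import torch

def ig_scalar(f, x, baseline, steps=10):
    grads = []
    for k in range(steps + 1):
        # evaluate the gradient at evenly spaced points on the baseline-to-x path
        xk = (baseline + float(k) / steps * (x - baseline)).clone().requires_grad_(True)
        f(xk).backward()
        grads.append(xk.grad.item())
    # (x - baseline) times the average gradient along the path
    return (x - baseline).item() * sum(grads) / len(grads)

print(ig_scalar(lambda t: t ** 2, torch.tensor(3.0), torch.tensor(0.0)))  # ~9.0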
deeprobust/graph/targeted_attack/nettack.py
ADDED
@@ -0,0 +1,624 @@
+"""
+Adversarial Attacks on Neural Networks for Graph Data. KDD 2018.
+https://arxiv.org/pdf/1805.07984.pdf
+Author's Implementation
+https://github.com/danielzuegner/nettack
+
+Since PyTorch does not provide good enough support for operations
+on sparse tensors, this part of the code is heavily based on the author's implementation.
+"""
+"""
+Implementation of the method proposed in the paper:
+'Adversarial Attacks on Neural Networks for Graph Data'
+by Daniel Zügner, Amir Akbarnejad and Stephan Günnemann,
+published at SIGKDD'18, August 2018, London, UK
+Copyright (C) 2018
+Daniel Zügner
+Technical University of Munich
+"""
+
+import torch
+from deeprobust.graph.targeted_attack import BaseAttack
+from torch.nn.parameter import Parameter
+from deeprobust.graph import utils
+import torch.nn.functional as F
+from torch import optim
+from torch.nn.modules.module import Module
+import numpy as np
+import scipy.sparse as sp
+from copy import deepcopy
+from numba import jit
+from torch import spmm
+
+class Nettack(BaseAttack):
+    """Nettack.
+
+    Parameters
+    ----------
+    model :
+        model to attack
+    nnodes : int
+        number of nodes in the input graph
+    attack_structure : bool
+        whether to attack graph structure
+    attack_features : bool
+        whether to attack node features
+    device: str
+        'cpu' or 'cuda'
+
+    Examples
+    --------
+
+    >>> from deeprobust.graph.data import Dataset
+    >>> from deeprobust.graph.defense import GCN
+    >>> from deeprobust.graph.targeted_attack import Nettack
+    >>> data = Dataset(root='/tmp/', name='cora')
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> # Setup Surrogate model
+    >>> surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
+                    nhid=16, dropout=0, with_relu=False, with_bias=False, device='cpu').to('cpu')
+    >>> surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
+    >>> # Setup Attack Model
+    >>> target_node = 0
+    >>> model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device='cpu').to('cpu')
+    >>> # Attack
+    >>> model.attack(features, adj, labels, target_node, n_perturbations=5)
+    >>> modified_adj = model.modified_adj
+    >>> modified_features = model.modified_features
+
+    """
+
+    def __init__(self, model, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
+
+        super(Nettack, self).__init__(model, nnodes, attack_structure=attack_structure, attack_features=attack_features, device=device)
+
+        self.structure_perturbations = []
+        self.feature_perturbations = []
+        self.influencer_nodes = []
+        self.potential_edges = []
+
+        self.cooc_constraint = None
+
+    def filter_potential_singletons(self, modified_adj):
+        """Computes a mask for entries potentially leading to singleton nodes, i.e.
+        one of the two nodes corresponding to the entry has degree 1 and there
+        is an edge between the two nodes.
+        """
+
+        degrees = modified_adj.sum(0)
+        degree_one = (degrees == 1)
+        resh = degree_one.repeat(self.nnodes, 1).float()
+        l_and = resh * modified_adj
+        logical_and_symmetric = l_and + l_and.t()
+        flat_mask = 1 - logical_and_symmetric
+        return flat_mask
+
+    def get_linearized_weight(self):
+        surrogate = self.surrogate
+        W = surrogate.gc1.weight @ surrogate.gc2.weight
+        return W.detach().cpu().numpy()
+
+    def attack(self, features, adj, labels, target_node, n_perturbations, direct=True, n_influencers=0, ll_cutoff=0.004, verbose=True, **kwargs):
+        """Generate perturbations on the input graph.
+
+        Parameters
+        ----------
+        features : torch.Tensor or scipy.sparse.csr_matrix
+            Original (unperturbed) node feature matrix. Note that
+            torch.Tensor will be automatically transformed into
+            scipy.sparse.csr_matrix
+        adj : torch.Tensor or scipy.sparse.csr_matrix
+            Original (unperturbed) adjacency matrix. Note that
+            torch.Tensor will be automatically transformed into
+            scipy.sparse.csr_matrix
+        labels :
+            node labels
+        target_node : int
+            target node index to be attacked
+        n_perturbations : int
+            Number of perturbations on the input graph. Perturbations could
+            be edge removals/additions or feature removals/additions.
+        direct: bool
+            whether to conduct direct attack
+        n_influencers:
+            number of influencer nodes when performing indirect attack
+            (setting `direct` to False). When `direct` is True, it is ignored.
+        ll_cutoff : float
+            The critical value for the likelihood ratio test of the power law distributions.
+            See the Chi square distribution with one degree of freedom. Default value 0.004
+            corresponds to a p-value of roughly 0.95.
+        verbose : bool
+            whether to show verbose logs
+        """
+
+        if self.nnodes is None:
+            self.nnodes = adj.shape[0]
+
+        self.target_node = target_node
+
+        if type(adj) is torch.Tensor:
+            self.ori_adj = utils.to_scipy(adj).tolil()
+            self.modified_adj = utils.to_scipy(adj).tolil()
+            self.ori_features = utils.to_scipy(features).tolil()
+            self.modified_features = utils.to_scipy(features).tolil()
+        else:
+            self.ori_adj = adj.tolil()
+            self.modified_adj = adj.tolil()
+            self.ori_features = features.tolil()
+            self.modified_features = features.tolil()
+
+        self.cooc_matrix = self.modified_features.T.dot(self.modified_features).tolil()
+
+        attack_features = self.attack_features
+        attack_structure = self.attack_structure
+        assert not (direct == False and n_influencers == 0), "indirect mode requires at least one influencer node"
+        assert n_perturbations > 0, "need at least one perturbation"
+        assert attack_features or attack_structure, "either attack_features or attack_structure must be true"
+
+        # adj_norm = utils.normalize_adj_tensor(modified_adj, sparse=True)
+        self.adj_norm = utils.normalize_adj(self.modified_adj)
+        self.W = self.get_linearized_weight()
+
+        logits = (self.adj_norm @ self.adj_norm @ self.modified_features @ self.W)[target_node]
+
+        self.label_u = labels[target_node]
+        label_target_onehot = np.eye(int(self.nclass))[labels[target_node]]
+        best_wrong_class = (logits - 1000*label_target_onehot).argmax()
+        surrogate_losses = [logits[labels[target_node]] - logits[best_wrong_class]]
+
+        if verbose:
+            print("##### Starting attack #####")
+            if attack_structure and attack_features:
+                print("##### Attack node with ID {} using structure and feature perturbations #####".format(target_node))
+            elif attack_features:
+                print("##### Attack only using feature perturbations #####")
+            elif attack_structure:
+                print("##### Attack only using structure perturbations #####")
+            if direct:
+                print("##### Attacking the node directly #####")
+            else:
+                print("##### Attacking the node indirectly via {} influencer nodes #####".format(n_influencers))
+            print("##### Performing {} perturbations #####".format(n_perturbations))
+
+        if attack_structure:
+            # Setup starting values of the likelihood ratio test.
+            degree_sequence_start = self.ori_adj.sum(0).A1
+            current_degree_sequence = self.modified_adj.sum(0).A1
+            d_min = 2
+
+            S_d_start = np.sum(np.log(degree_sequence_start[degree_sequence_start >= d_min]))
+            current_S_d = np.sum(np.log(current_degree_sequence[current_degree_sequence >= d_min]))
+            n_start = np.sum(degree_sequence_start >= d_min)
+            current_n = np.sum(current_degree_sequence >= d_min)
+            alpha_start = compute_alpha(n_start, S_d_start, d_min)
+
+            log_likelihood_orig = compute_log_likelihood(n_start, alpha_start, S_d_start, d_min)
+
+        if len(self.influencer_nodes) == 0:
+            if not direct:
+                # Choose influencer nodes
+                infls, add_infls = self.get_attacker_nodes(n_influencers, add_additional_nodes=True)
+                self.influencer_nodes = np.concatenate((infls, add_infls)).astype("int")
+                # Potential edges are all edges from any attacker to any other node, except the respective
+                # attacker itself or the node being attacked.
+                self.potential_edges = np.row_stack([np.column_stack((np.tile(infl, self.nnodes - 2),
+                                                                      np.setdiff1d(np.arange(self.nnodes),
+                                                                                   np.array([target_node, infl])))) for infl in
+                                                     self.influencer_nodes])
+                if verbose:
+                    print("Influencer nodes: {}".format(self.influencer_nodes))
+            else:
+                # direct attack
+                influencers = [target_node]
+                self.potential_edges = np.column_stack((np.tile(target_node, self.nnodes-1), np.setdiff1d(np.arange(self.nnodes), target_node)))
+                self.influencer_nodes = np.array(influencers)
+
+        self.potential_edges = self.potential_edges.astype("int32")
+
+        for _ in range(n_perturbations):
+            if verbose:
+                print("##### ...{}/{} perturbations ... #####".format(_+1, n_perturbations))
+            if attack_structure:
+
+                # Do not consider edges that, if removed, result in singleton nodes in the graph.
+                singleton_filter = filter_singletons(self.potential_edges, self.modified_adj)
+                filtered_edges = self.potential_edges[singleton_filter]
+
+                # Update the values for the power law likelihood ratio test.
+                deltas = 2 * (1 - self.modified_adj[tuple(filtered_edges.T)].toarray()[0]) - 1
+                d_edges_old = current_degree_sequence[filtered_edges]
+                d_edges_new = current_degree_sequence[filtered_edges] + deltas[:, None]
+                new_S_d, new_n = update_Sx(current_S_d, current_n, d_edges_old, d_edges_new, d_min)
+                new_alphas = compute_alpha(new_n, new_S_d, d_min)
+                new_ll = compute_log_likelihood(new_n, new_alphas, new_S_d, d_min)
+                alphas_combined = compute_alpha(new_n + n_start, new_S_d + S_d_start, d_min)
+                new_ll_combined = compute_log_likelihood(new_n + n_start, alphas_combined, new_S_d + S_d_start, d_min)
+                new_ratios = -2 * new_ll_combined + 2 * (new_ll + log_likelihood_orig)
+
+                # Do not consider edges that, if added/removed, would lead to a violation of the
+                # likelihood ratio Chi_square cutoff value.
+                powerlaw_filter = filter_chisquare(new_ratios, ll_cutoff)
+                filtered_edges_final = filtered_edges[powerlaw_filter]
+
+                # Compute new entries in A_hat_square_uv
+                a_hat_uv_new = self.compute_new_a_hat_uv(filtered_edges_final, target_node)
+                # Compute the struct scores for each potential edge
+                struct_scores = self.struct_score(a_hat_uv_new, self.modified_features @ self.W)
+                best_edge_ix = struct_scores.argmin()
+                best_edge_score = struct_scores.min()
+                best_edge = filtered_edges_final[best_edge_ix]
+
+            if attack_features:
+                # Compute the feature scores for each potential feature perturbation
+                feature_ixs, feature_scores = self.feature_scores()
+                best_feature_ix = feature_ixs[0]
+                best_feature_score = feature_scores[0]
+
+            if attack_structure and attack_features:
+                # decide whether to choose an edge or feature to change
+                if best_edge_score < best_feature_score:
+                    if verbose:
+                        print("Edge perturbation: {}".format(best_edge))
+                    change_structure = True
+                else:
+                    if verbose:
+                        print("Feature perturbation: {}".format(best_feature_ix))
+                    change_structure = False
+
+            elif attack_structure:
+                change_structure = True
+            elif attack_features:
+                change_structure = False
+
+            if change_structure:
+                # perform edge perturbation
+                self.modified_adj[tuple(best_edge)] = self.modified_adj[tuple(best_edge[::-1])] = 1 - self.modified_adj[tuple(best_edge)]
+                self.adj_norm = utils.normalize_adj(self.modified_adj)
+
+                self.structure_perturbations.append(tuple(best_edge))
+                self.feature_perturbations.append(())
+                surrogate_losses.append(best_edge_score)
+
+                # Update likelihood ratio test values
+                current_S_d = new_S_d[powerlaw_filter][best_edge_ix]
+                current_n = new_n[powerlaw_filter][best_edge_ix]
+                current_degree_sequence[best_edge] += deltas[powerlaw_filter][best_edge_ix]
+
+            else:
+                self.modified_features[tuple(best_feature_ix)] = 1 - self.modified_features[tuple(best_feature_ix)]
+                self.feature_perturbations.append(tuple(best_feature_ix))
+                self.structure_perturbations.append(())
+                surrogate_losses.append(best_feature_score)
+
+        # return self.modified_adj, self.modified_features
+
+    def get_attacker_nodes(self, n=5, add_additional_nodes=False):
+        """Determine the influencer nodes to attack node i based on
+        the weights W and the attributes X.
+        """
+        assert n < self.nnodes-1, "number of influencers cannot be >= number of nodes in the graph!"
+        neighbors = self.ori_adj[self.target_node].nonzero()[1]
+        assert self.target_node not in neighbors
+
+        potential_edges = np.column_stack((np.tile(self.target_node, len(neighbors)), neighbors)).astype("int32")
+
+        # The new A_hat_square_uv values that we would get if we removed the edge from u to each of the neighbors, respectively
+        a_hat_uv = self.compute_new_a_hat_uv(potential_edges, self.target_node)
+
+        # XW = self.compute_XW()
+        XW = self.modified_features @ self.W
+
+        # compute the struct scores for all neighbors
+        struct_scores = self.struct_score(a_hat_uv, XW)
+        if len(neighbors) >= n:  # do we have enough neighbors for the number of desired influencers?
+            influencer_nodes = neighbors[np.argsort(struct_scores)[:n]]
+            if add_additional_nodes:
+                return influencer_nodes, np.array([])
+            return influencer_nodes
+        else:
+            influencer_nodes = neighbors
+            if add_additional_nodes:  # Add additional influencers by connecting them to u first.
+                # Compute the set of possible additional influencers, i.e. all nodes except the ones
+                # that are already connected to u.
+                poss_add_infl = np.setdiff1d(np.setdiff1d(np.arange(self.nnodes), neighbors), self.target_node)
+                n_possible_additional = len(poss_add_infl)
+                n_additional_attackers = n - len(neighbors)
+                possible_edges = np.column_stack((np.tile(self.target_node, n_possible_additional), poss_add_infl))
+
+                # Compute the struct_scores for all possible additional influencers, and choose the one
+                # with the best struct score.
+                a_hat_uv_additional = self.compute_new_a_hat_uv(possible_edges, self.target_node)
+                additional_struct_scores = self.struct_score(a_hat_uv_additional, XW)
+                additional_influencers = poss_add_infl[np.argsort(additional_struct_scores)[-n_additional_attackers::]]
+
+                return influencer_nodes, additional_influencers
+            else:
+                return influencer_nodes
+
+    def compute_logits(self):
+        return (self.adj_norm @ self.adj_norm @ self.modified_features @ self.W)[self.target_node]
+
+    def strongest_wrong_class(self, logits):
+        label_u_onehot = np.eye(self.nclass)[self.label_u]
+        return (logits - 1000*label_u_onehot).argmax()
+
+    def feature_scores(self):
+        """Compute feature scores for all possible feature changes.
+        """
+
+        if self.cooc_constraint is None:
+            self.compute_cooccurrence_constraint(self.influencer_nodes)
+        logits = self.compute_logits()
+        best_wrong_class = self.strongest_wrong_class(logits)
+        surrogate_loss = logits[self.label_u] - logits[best_wrong_class]
+
+        gradient = self.gradient_wrt_x(self.label_u) - self.gradient_wrt_x(best_wrong_class)
+        # gradients_flipped = (gradient * -1).tolil()
+        gradients_flipped = sp.lil_matrix(gradient * -1)
+        gradients_flipped[self.modified_features.nonzero()] *= -1
+
+        X_influencers = sp.lil_matrix(self.modified_features.shape)
+        X_influencers[self.influencer_nodes] = self.modified_features[self.influencer_nodes]
+        gradients_flipped = gradients_flipped.multiply((self.cooc_constraint + X_influencers) > 0)
+        nnz_ixs = np.array(gradients_flipped.nonzero()).T
+
+        sorting = np.argsort(gradients_flipped[tuple(nnz_ixs.T)]).A1
+        sorted_ixs = nnz_ixs[sorting]
+        grads = gradients_flipped[tuple(nnz_ixs[sorting].T)]
+
+        scores = surrogate_loss - grads
+        return sorted_ixs[::-1], scores.A1[::-1]
+
+    def compute_cooccurrence_constraint(self, nodes):
+        """
+        Co-occurrence constraint as described in the paper.
+
+        Parameters
+        ----------
+        nodes: np.array
+            Nodes whose features are considered for change
+
+        Returns
+        -------
+        np.array [len(nodes), D], dtype bool
+            Binary matrix of dimension len(nodes) x D. A 1 in entry n,d indicates that
+            we are allowed to add feature d to the features of node n.
+
+        """
+
+        words_graph = self.cooc_matrix.copy()
+        D = self.modified_features.shape[1]
+        words_graph.setdiag(0)
+        words_graph = (words_graph > 0)
+        word_degrees = np.sum(words_graph, axis=0).A1
+
+        inv_word_degrees = np.reciprocal(word_degrees.astype(float) + 1e-8)
+
+        sd = np.zeros([self.nnodes])
+        for n in range(self.nnodes):
+            n_idx = self.modified_features[n, :].nonzero()[1]
+            sd[n] = np.sum(inv_word_degrees[n_idx.tolist()])
+
+        scores_matrix = sp.lil_matrix((self.nnodes, D))
+
+        for n in nodes:
+            common_words = words_graph.multiply(self.modified_features[n])
+            idegs = inv_word_degrees[common_words.nonzero()[1]]
+            nnz = common_words.nonzero()[0]
+            scores = np.array([idegs[nnz == ix].sum() for ix in range(D)])
+            scores_matrix[n] = scores
+        self.cooc_constraint = sp.csr_matrix(scores_matrix - 0.5 * sd[:, None] > 0)
+
+    def gradient_wrt_x(self, label):
+        # return self.adj_norm.dot(self.adj_norm)[self.target_node].T.dot(self.W[:, label].T)
+        return self.adj_norm.dot(self.adj_norm)[self.target_node].T.dot(self.W[:, label].reshape(1, -1))
+
+    def reset(self):
+        """Reset Nettack
+        """
+        self.modified_adj = self.ori_adj.copy()
+        self.modified_features = self.ori_features.copy()
+        self.structure_perturbations = []
+        self.feature_perturbations = []
+        self.influencer_nodes = []
+        self.potential_edges = []
+        self.cooc_constraint = None
+
+    def struct_score(self, a_hat_uv, XW):
+        """
+        Compute structure scores, cf. Eq. 15 in the paper
+
+        Parameters
+        ----------
+        a_hat_uv: sp.sparse_matrix, shape [P, N]
+            Entries of matrix A_hat^2_u for each potential edge (see paper for explanation)
+
+        XW: sp.sparse_matrix, shape [N, K], dtype float
+            The class logits for each node.
+
+        Returns
+        -------
+        np.array [P,]
+            The struct score for every row in a_hat_uv
+        """
+
+        logits = a_hat_uv.dot(XW)
+        label_onehot = np.eye(XW.shape[1])[self.label_u]
+        best_wrong_class_logits = (logits - 1000 * label_onehot).max(1)
+        logits_for_correct_class = logits[:, self.label_u]
+        struct_scores = logits_for_correct_class - best_wrong_class_logits
+
+        return struct_scores
+
+    def compute_new_a_hat_uv(self, potential_edges, target_node):
+        """
+        Compute the updated A_hat_square_uv entries that would result from inserting/deleting the input edges,
+        for every edge.
+
+        Parameters
+        ----------
+        potential_edges: np.array, shape [P,2], dtype int
+            The edges to check.
+
+        Returns
+        -------
+        sp.sparse_matrix: updated A_hat_square_u entries, a sparse PxN matrix, where P is len(potential_edges).
+        """
+
+        edges = np.array(self.modified_adj.nonzero()).T
+        edges_set = {tuple(x) for x in edges}
+        A_hat_sq = self.adj_norm @ self.adj_norm
+        values_before = A_hat_sq[target_node].toarray()[0]
+        node_ixs = np.unique(edges[:, 0], return_index=True)[1]
+        twohop_ixs = np.array(A_hat_sq.nonzero()).T
+        degrees = self.modified_adj.sum(0).A1 + 1
+
+        ixs, vals = compute_new_a_hat_uv(edges, node_ixs, edges_set, twohop_ixs, values_before, degrees,
+                                         potential_edges.astype(np.int32), target_node)
+        ixs_arr = np.array(ixs)
+        a_hat_uv = sp.coo_matrix((vals, (ixs_arr[:, 0], ixs_arr[:, 1])), shape=[len(potential_edges), self.nnodes])
+
+        return a_hat_uv
+
+@jit(nopython=True)
+def connected_after(u, v, connected_before, delta):
+    if u == v:
+        if delta == -1:
+            return False
+        else:
+            return True
+    else:
+        return connected_before
+
+
+@jit(nopython=True)
+def compute_new_a_hat_uv(edge_ixs, node_nb_ixs, edges_set, twohop_ixs, values_before, degs, potential_edges, u):
+    """
+    Compute the new values [A_hat_square]_u for every potential edge, where u is the target node. C.f. Theorem 5.1
+    equation 17.
+
+    """
+    N = degs.shape[0]
+
+    twohop_u = twohop_ixs[twohop_ixs[:, 0] == u, 1]
+    nbs_u = edge_ixs[edge_ixs[:, 0] == u, 1]
+    nbs_u_set = set(nbs_u)
+
+    return_ixs = []
+    return_values = []
+
+    for ix in range(len(potential_edges)):
+        edge = potential_edges[ix]
+        edge_set = set(edge)
+        degs_new = degs.copy()
+        delta = -2 * ((edge[0], edge[1]) in edges_set) + 1
+        degs_new[edge] += delta
+
+        nbs_edge0 = edge_ixs[edge_ixs[:, 0] == edge[0], 1]
+        nbs_edge1 = edge_ixs[edge_ixs[:, 0] == edge[1], 1]
+
+        affected_nodes = set(np.concatenate((twohop_u, nbs_edge0, nbs_edge1)))
+        affected_nodes = affected_nodes.union(edge_set)
+        a_um = edge[0] in nbs_u_set
+        a_un = edge[1] in nbs_u_set
+
+        a_un_after = connected_after(u, edge[0], a_un, delta)
+        a_um_after = connected_after(u, edge[1], a_um, delta)
+
+        for v in affected_nodes:
+            a_uv_before = v in nbs_u_set
+            a_uv_before_sl = a_uv_before or v == u
+
+            if v in edge_set and u in edge_set and u != v:
+                if delta == -1:
+                    a_uv_after = False
+                else:
+                    a_uv_after = True
+            else:
+                a_uv_after = a_uv_before
+            a_uv_after_sl = a_uv_after or v == u
+
+            from_ix = node_nb_ixs[v]
+            to_ix = node_nb_ixs[v + 1] if v < N - 1 else len(edge_ixs)
+            node_nbs = edge_ixs[from_ix:to_ix, 1]
+            node_nbs_set = set(node_nbs)
+            a_vm_before = edge[0] in node_nbs_set
+
+            a_vn_before = edge[1] in node_nbs_set
+            a_vn_after = connected_after(v, edge[0], a_vn_before, delta)
+            a_vm_after = connected_after(v, edge[1], a_vm_before, delta)
+
+            mult_term = 1 / np.sqrt(degs_new[u] * degs_new[v])
+
+            sum_term1 = np.sqrt(degs[u] * degs[v]) * values_before[v] - a_uv_before_sl / degs[u] - a_uv_before / \
+                        degs[v]
+            sum_term2 = a_uv_after / degs_new[v] + a_uv_after_sl / degs_new[u]
+            sum_term3 = -((a_um and a_vm_before) / degs[edge[0]]) + (a_um_after and a_vm_after) / degs_new[edge[0]]
+            sum_term4 = -((a_un and a_vn_before) / degs[edge[1]]) + (a_un_after and a_vn_after) / degs_new[edge[1]]
+            new_val = mult_term * (sum_term1 + sum_term2 + sum_term3 + sum_term4)
+
+            return_ixs.append((ix, v))
+            return_values.append(new_val)
+
+    return return_ixs, return_values
+
+def filter_singletons(edges, adj):
+    """
+    Filter edges that, if removed, would turn one or more nodes into singleton nodes.
+    """
+
+    degs = np.squeeze(np.array(np.sum(adj, 0)))
+    existing_edges = np.squeeze(np.array(adj.tocsr()[tuple(edges.T)]))
+    if existing_edges.size > 0:
+        edge_degrees = degs[np.array(edges)] + 2*(1-existing_edges[:, None]) - 1
+    else:
+        edge_degrees = degs[np.array(edges)] + 1
+
+    zeros = edge_degrees == 0
+    zeros_sum = zeros.sum(1)
+    return zeros_sum == 0
+
+def compute_alpha(n, S_d, d_min):
+    """
+    Approximate the alpha of a power law distribution.
+
+    """
+
+    return n / (S_d - n * np.log(d_min - 0.5)) + 1
+
+
+def update_Sx(S_old, n_old, d_old, d_new, d_min):
+    """
+    Update the sum of log degrees S_d and n based on the degree distribution resulting from inserting or deleting
+    a single edge.
+    """
+
+    old_in_range = d_old >= d_min
+    new_in_range = d_new >= d_min
+
+    d_old_in_range = np.multiply(d_old, old_in_range)
+    d_new_in_range = np.multiply(d_new, new_in_range)
+
+    new_S_d = S_old - np.log(np.maximum(d_old_in_range, 1)).sum(1) + np.log(np.maximum(d_new_in_range, 1)).sum(1)
+    new_n = n_old - np.sum(old_in_range, 1) + np.sum(new_in_range, 1)
+
+    return new_S_d, new_n
+
+
+def compute_log_likelihood(n, alpha, S_d, d_min):
+    """
+    Compute log likelihood of the powerlaw fit.
+
+    """
+
+    return n * np.log(alpha) + n * alpha * np.log(d_min) - (alpha + 1) * S_d
+
+def filter_chisquare(ll_ratios, cutoff):
+    return ll_ratios < cutoff
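
Two short sketches may help when reading the file above; neither is part of the diff, and all concrete values in them are made up for illustration. First, Nettack scores candidate perturbations against a linearized two-layer GCN: the surrogate's nonlinearity is dropped, so the logits reduce to A_hat^2 X W with W = W1 W2 (cf. get_linearized_weight and the logits line in attack). A minimal dense sketch with toy matrices:

import numpy as np
import scipy.sparse as sp

# Toy symmetric-normalized adjacency A_hat, binary features X, stacked weights W1 @ W2.
A_hat = sp.csr_matrix(np.array([[0.5, 0.5, 0.0],
                                [0.5, 0.0, 0.5],
                                [0.0, 0.5, 0.5]]))
X = np.array([[1., 0.],
              [0., 1.],
              [1., 1.]])
W1 = np.random.randn(2, 4)
W2 = np.random.randn(4, 3)

logits = A_hat @ A_hat @ X @ (W1 @ W2)  # two propagation steps, one linear map
print(logits[0])                        # logits row of a target node, cf. attack()

Second, the module-level helpers compute_alpha, update_Sx, compute_log_likelihood and filter_chisquare implement the power-law degree-distribution unnoticeability test used inside the attack loop. A toy walk-through for a single candidate edge (the degree sequence is invented; only the call pattern mirrors attack()):

import numpy as np
from deeprobust.graph.targeted_attack.nettack import (
    compute_alpha, compute_log_likelihood, filter_chisquare, update_Sx)

degrees = np.array([3., 2., 2., 4., 1., 2.])    # toy degree sequence
d_min = 2                                       # same d_min as in attack()

n = np.sum(degrees >= d_min)                    # nodes in the power-law tail
S_d = np.sum(np.log(degrees[degrees >= d_min])) # sum of log-degrees in the tail
alpha = compute_alpha(n, S_d, d_min)
ll_orig = compute_log_likelihood(n, alpha, S_d, d_min)

# Candidate edge (0, 1): adding it raises both endpoint degrees by one.
d_old = degrees[np.array([[0, 1]])]             # shape [1, 2]
d_new = d_old + 1
new_S_d, new_n = update_Sx(S_d, n, d_old, d_new, d_min)
new_alpha = compute_alpha(new_n, new_S_d, d_min)
ll_new = compute_log_likelihood(new_n, new_alpha, new_S_d, d_min)

# Likelihood-ratio statistic of the combined vs. separate power-law fits.
alpha_comb = compute_alpha(new_n + n, new_S_d + S_d, d_min)
ll_comb = compute_log_likelihood(new_n + n, alpha_comb, new_S_d + S_d, d_min)
ratio = -2 * ll_comb + 2 * (ll_new + ll_orig)
print(ratio, filter_chisquare(ratio, 0.004))    # True -> the edge stays unnoticeable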
deeprobust/graph/targeted_attack/rnd.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+from deeprobust.graph.targeted_attack import BaseAttack
+from torch.nn.parameter import Parameter
+from copy import deepcopy
+from deeprobust.graph import utils
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+
+class RND(BaseAttack):
+    """As described in Adversarial Attacks on Neural Networks for Graph Data (KDD'18),
+    'Rnd is an attack in which we modify the structure of the graph. Given our target node v,
+    in each step we randomly sample nodes u whose label is different from v and
+    add the edge u,v to the graph structure.'
+
+    Parameters
+    ----------
+    model :
+        model to attack
+    nnodes : int
+        number of nodes in the input graph
+    attack_structure : bool
+        whether to attack graph structure
+    attack_features : bool
+        whether to attack node features
+    device: str
+        'cpu' or 'cuda'
+
+    Examples
+    --------
+
+    >>> from deeprobust.graph.data import Dataset
+    >>> from deeprobust.graph.targeted_attack import RND
+    >>> data = Dataset(root='/tmp/', name='cora')
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> # Setup Attack Model
+    >>> target_node = 0
+    >>> model = RND()
+    >>> # Attack
+    >>> model.attack(adj, labels, idx_train, target_node, n_perturbations=5)
+    >>> modified_adj = model.modified_adj
+    >>> # # You can also inject nodes
+    >>> # model.add_nodes(features, adj, labels, idx_train, target_node, n_added=10, n_perturbations=100)
+    >>> # modified_adj = model.modified_adj
+
+    """
+
+    def __init__(self, model=None, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
+        super(RND, self).__init__(model, nnodes, attack_structure=attack_structure, attack_features=attack_features, device=device)
+
+        assert not self.attack_features, 'RND does NOT support attacking features except adding nodes'
+
+    def attack(self, ori_adj, labels, idx_train, target_node, n_perturbations, **kwargs):
+        """
+        Randomly sample nodes u whose label is different from v and
+        add the edge u,v to the graph structure. This baseline only
+        has access to true class labels in the training set.
+
+        Parameters
+        ----------
+        ori_adj : scipy.sparse.csr_matrix
+            Original (unperturbed) adjacency matrix
+        labels :
+            node labels
+        idx_train :
+            node training indices
+        target_node : int
+            target node index to be attacked
+        n_perturbations : int
+            Number of perturbations on the input graph. Perturbations could
+            be edge removals/additions or feature removals/additions.
+        """
+        # ori_adj: sp.csr_matrix
+
+        print('number of perturbations: %s' % n_perturbations)
+        modified_adj = ori_adj.tolil()
+
+        row = ori_adj[target_node].todense().A1
+        diff_label_nodes = [x for x in idx_train if labels[x] != labels[target_node]
+                            and row[x] == 0]
+        diff_label_nodes = np.random.permutation(diff_label_nodes)
+
+        if len(diff_label_nodes) >= n_perturbations:
+            changed_nodes = diff_label_nodes[: n_perturbations]
+            modified_adj[target_node, changed_nodes] = 1
+            modified_adj[changed_nodes, target_node] = 1
+        else:
+            changed_nodes = diff_label_nodes
+            unlabeled_nodes = [x for x in range(ori_adj.shape[0]) if x not in idx_train and row[x] == 0]
+            unlabeled_nodes = np.random.permutation(unlabeled_nodes)
+            changed_nodes = np.concatenate([changed_nodes,
+                                            unlabeled_nodes[: n_perturbations-len(diff_label_nodes)]])
+            modified_adj[target_node, changed_nodes] = 1
+            modified_adj[changed_nodes, target_node] = 1
+
+        self.check_adj(modified_adj)
+        self.modified_adj = modified_adj
+        # self.modified_features = modified_features
+
+    def add_nodes(self, features, ori_adj, labels, idx_train, target_node, n_added=1, n_perturbations=10, **kwargs):
+        """
+        For each added node, first connect the target node with the added fake nodes.
+        Then randomly connect the fake nodes with other nodes whose label is
+        different from the target node's. As for the node features, simply copy an arbitrary node's.
+        """
+        # ori_adj: sp.csr_matrix
+        print('number of perturbations: %s' % n_perturbations)
+        N = ori_adj.shape[0]
+        D = features.shape[1]
+        modified_adj = self.reshape_mx(ori_adj, shape=(N+n_added, N+n_added))
+        modified_features = self.reshape_mx(features, shape=(N+n_added, D))
+
+        diff_labels = [l for l in range(labels.max()+1) if l != labels[target_node]]
+        diff_labels = np.random.permutation(diff_labels)
+        possible_nodes = [x for x in idx_train if labels[x] == diff_labels[0]]
+
+        for fake_node in range(N, N+n_added):
+            sampled_nodes = np.random.permutation(possible_nodes)[: n_perturbations]
+            # connect the fake node with the target node
+            modified_adj[fake_node, target_node] = 1
+            modified_adj[target_node, fake_node] = 1
+            # connect the fake node with other nodes
+            for node in sampled_nodes:
+                modified_adj[fake_node, node] = 1
+                modified_adj[node, fake_node] = 1
+                modified_features[fake_node] = features[node]
+
+        self.check_adj(modified_adj)
+
+        self.modified_adj = modified_adj
+        self.modified_features = modified_features
+        # return modified_adj, modified_features
+
+    def reshape_mx(self, mx, shape):
+        indices = mx.nonzero()
+        return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape).tolil()
deeprobust/graph/targeted_attack/sga.py
ADDED
@@ -0,0 +1,323 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse as sp
+from collections import namedtuple
+from functools import lru_cache
+
+from torch_scatter import scatter_add
+from torch_geometric.utils import k_hop_subgraph
+from deeprobust.graph.targeted_attack import BaseAttack
+from deeprobust.graph import utils
+
+SubGraph = namedtuple('SubGraph', ['edge_index', 'non_edge_index',
+                                   'self_loop', 'self_loop_weight',
+                                   'edge_weight', 'non_edge_weight',
+                                   'edges_all'])
+
+
+class SGAttack(BaseAttack):
+    """SGAttack proposed in `Adversarial Attack on Large Scale Graph` TKDE 2021
+    <https://arxiv.org/abs/2009.03488>
+
+    SGAttack follows these steps::
+    + training a surrogate SGC model with hop K
+    + extracting a K-hop subgraph centered at the target node
+    + choosing top-N attacker nodes that belong to the best wrong class of the target node
+    + computing gradients w.r.t. the subgraph to add or remove edges iteratively
+
+    Parameters
+    ----------
+    model :
+        model to attack
+    nnodes : int
+        number of nodes in the input graph
+    attack_structure : bool
+        whether to attack graph structure
+    attack_features : bool
+        whether to attack node features
+    device: str
+        'cpu' or 'cuda'
+
+    Examples
+    --------
+
+    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
+    >>> from deeprobust.graph.defense import SGC
+    >>> data = Dataset(root='/tmp/', name='cora')
+    >>> adj, features, labels = data.adj, data.features, data.labels
+    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    >>> surrogate = SGC(nfeat=features.shape[1], K=3, lr=0.1,
+            nclass=labels.max().item() + 1, device='cuda')
+    >>> surrogate = surrogate.to('cuda')
+    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
+    >>> surrogate.fit(pyg_data, train_iters=200, patience=200, verbose=True) # train with earlystopping
+    >>> from deeprobust.graph.targeted_attack import SGAttack
+    >>> # Setup Attack Model
+    >>> target_node = 0
+    >>> model = SGAttack(surrogate, attack_structure=True, device='cuda')
+    >>> # Attack
+    >>> model.attack(features, adj, labels, target_node, n_perturbations=5)
+    >>> modified_adj = model.modified_adj
+    >>> modified_features = model.modified_features
+    """
+
+    def __init__(self, model, nnodes=None, attack_structure=True, attack_features=False, device='cpu'):
+
+        super(SGAttack, self).__init__(model=None, nnodes=nnodes,
+                                       attack_structure=attack_structure, attack_features=attack_features, device=device)
+
+        self.target_node = None
+        self.logits = model.predict()
+        self.K = model.conv1.K
+        W = model.conv1.lin.weight.to(device)
+        b = model.conv1.lin.bias
+        if b is not None:
+            b = b.to(device)
+
+        self.weight, self.bias = W, b
+
+    @lru_cache(maxsize=1)
+    def compute_XW(self):
+        return F.linear(self.modified_features, self.weight)
+
+    def attack(self, features, adj, labels, target_node, n_perturbations, direct=True, n_influencers=3, **kwargs):
+        """Generate perturbations on the input graph.
+
+        Parameters
+        ----------
+        features :
+            Original (unperturbed) node feature matrix
+        adj :
+            Original (unperturbed) adjacency matrix
+        labels :
+            node labels
+        target_node : int
+            target node index to be attacked
+        n_perturbations : int
+            Number of perturbations on the input graph. Perturbations could
+            be edge removals/additions or feature removals/additions.
+        direct: bool
+            whether to conduct direct attack
+        n_influencers : int
+            number of the top influencers to choose. For direct attack, it will be set to `n_perturbations`.
+        """
+        if sp.issparse(features):
+            # to dense numpy matrix
+            features = features.A
+
+        if not torch.is_tensor(features):
+            features = torch.tensor(features, device=self.device)
+
+        if torch.is_tensor(adj):
+            adj = utils.to_scipy(adj).tocsr()
+
+        self.modified_features = features.requires_grad_(bool(self.attack_features))
+
+        target_label = torch.LongTensor([labels[target_node]])
+        best_wrong_label = torch.LongTensor([(self.logits[target_node].cpu() - 1000 * torch.eye(self.logits.size(1))[target_label]).argmax()])
+
+        self.selfloop_degree = torch.tensor(adj.sum(1).A1 + 1, device=self.device)
+        self.target_label = target_label.to(self.device)
+        self.best_wrong_label = best_wrong_label.to(self.device)
+        self.n_perturbations = n_perturbations
+        self.ori_adj = adj
+        self.target_node = target_node
+        self.direct = direct
+
+        attacker_nodes = torch.where(torch.as_tensor(labels) == best_wrong_label)[0]
+        subgraph = self.get_subgraph(attacker_nodes, n_influencers)
+
+        if not direct:
+            # for indirect attack, the edges adjacent to the targeted node should not be considered
+            mask = torch.logical_or(subgraph.edge_index[0] == target_node, subgraph.edge_index[1] == target_node).to(self.device)
+
+        structure_perturbations = []
+        feature_perturbations = []
+        num_features = features.shape[-1]
+        for _ in range(n_perturbations):
+            edge_grad, non_edge_grad, features_grad = self.compute_gradient(subgraph)
+            max_structure_score = max_feature_score = 0.
+
+            if self.attack_structure:
+                edge_grad *= (-2 * subgraph.edge_weight + 1)
+                non_edge_grad *= -2 * subgraph.non_edge_weight + 1
+                min_grad = min(edge_grad.min().item(), non_edge_grad.min().item())
+                edge_grad -= min_grad
+                non_edge_grad -= min_grad
+                if not direct:
+                    edge_grad[mask] = 0.
+                max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0)
+                max_non_edge_grad, max_non_edge_idx = torch.max(non_edge_grad, dim=0)
+                max_structure_score = max(max_edge_grad.item(), max_non_edge_grad.item())
+
+            if self.attack_features:
+                features_grad *= -2 * self.modified_features + 1
+                features_grad -= features_grad.min()
+                if not direct:
+                    features_grad[target_node] = 0.
+                max_feature_grad, max_feature_idx = torch.max(features_grad.view(-1), dim=0)
+                max_feature_score = max_feature_grad.item()
+
+            if max_structure_score >= max_feature_score:
+                if max_edge_grad > max_non_edge_grad:
+                    # remove one edge
+                    best_edge = subgraph.edge_index[:, max_edge_idx]
+                    subgraph.edge_weight.data[max_edge_idx] = 0.0
+                    self.selfloop_degree[best_edge] -= 1.0
+                else:
+                    # add one edge
+                    best_edge = subgraph.non_edge_index[:, max_non_edge_idx]
+                    subgraph.non_edge_weight.data[max_non_edge_idx] = 1.0
+                    self.selfloop_degree[best_edge] += 1.0
+
+                u, v = best_edge.tolist()
+                structure_perturbations.append((u, v))
+            else:
+                u, v = divmod(max_feature_idx.item(), num_features)
+                feature_perturbations.append((u, v))
+                self.modified_features[u, v].data.fill_(1. - self.modified_features[u, v].data)
+
+        if structure_perturbations:
+            modified_adj = adj.tolil(copy=True)
+            row, col = list(zip(*structure_perturbations))
+            modified_adj[row, col] = modified_adj[col, row] = 1 - modified_adj[row, col].A
+            modified_adj = modified_adj.tocsr(copy=False)
+            modified_adj.eliminate_zeros()
+        else:
+            modified_adj = adj.copy()
+
+        self.modified_adj = modified_adj
+        self.modified_features = self.modified_features.detach().cpu().numpy()
+        self.structure_perturbations = structure_perturbations
+        self.feature_perturbations = feature_perturbations
+
+    def get_subgraph(self, attacker_nodes, n_influencers=None):
+        target_node = self.target_node
+        neighbors = self.ori_adj[target_node].indices
+        sub_nodes, sub_edges = self.ego_subgraph()
+
+        if self.direct or n_influencers is not None:
+            influencers = [target_node]
+            attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
+        else:
+            influencers = neighbors
+
+        subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)
+
+        if n_influencers is not None and self.attack_structure:
+            if self.direct:
+                influencers = [target_node]
+                attacker_nodes = self.get_topk_influencers(subgraph, k=self.n_perturbations + 1)
+            else:
+                influencers = neighbors
+                attacker_nodes = self.get_topk_influencers(subgraph, k=n_influencers)
+
+            subgraph = self.subgraph_processing(influencers, attacker_nodes, sub_nodes, sub_edges)
+        return subgraph
+
+    def get_topk_influencers(self, subgraph, k):
+        _, non_edge_grad, _ = self.compute_gradient(subgraph)
+        _, topk_nodes = torch.topk(non_edge_grad, k=k, sorted=False)
+
+        influencers = subgraph.non_edge_index[1][topk_nodes.cpu()]
+        return influencers.cpu().numpy()
+
+    def subgraph_processing(self, influencers, attacker_nodes, sub_nodes, sub_edges):
+        if not self.attack_structure:
+            self_loop = sub_nodes.repeat((2, 1))
+            edges_all = torch.cat([sub_edges, sub_edges[[1, 0]], self_loop], dim=1)
+            edge_weight = torch.ones(edges_all.size(1), device=self.device)
+
+            return SubGraph(edge_index=sub_edges, non_edge_index=None,
+                            self_loop=None, edges_all=edges_all,
+                            edge_weight=edge_weight, non_edge_weight=None,
+                            self_loop_weight=None)
+
+        row = np.repeat(influencers, len(attacker_nodes))
+        col = np.tile(attacker_nodes, len(influencers))
+        non_edges = np.row_stack([row, col])
+
+        if len(influencers) > 1:
+            mask = self.ori_adj[non_edges[0],
+                                non_edges[1]].A1 == 0
+            non_edges = non_edges[:, mask]
+
+        non_edges = torch.as_tensor(non_edges, device=self.device)
+        unique_nodes = np.union1d(sub_nodes.tolist(), attacker_nodes)
+        unique_nodes = torch.as_tensor(unique_nodes, device=self.device)
+        self_loop = unique_nodes.repeat((2, 1))
+        edges_all = torch.cat([sub_edges, sub_edges[[1, 0]],
+                               non_edges, non_edges[[1, 0]], self_loop], dim=1)
+
+        edge_weight = torch.ones(sub_edges.size(1), device=self.device).requires_grad_(bool(self.attack_structure))
+        non_edge_weight = torch.zeros(non_edges.size(1), device=self.device).requires_grad_(bool(self.attack_structure))
+        self_loop_weight = torch.ones(self_loop.size(1), device=self.device)
+
+        subgraph = SubGraph(edge_index=sub_edges, non_edge_index=non_edges,
+                            self_loop=self_loop, edges_all=edges_all,
+                            edge_weight=edge_weight, non_edge_weight=non_edge_weight,
+                            self_loop_weight=self_loop_weight)
+        return subgraph
+
+    def SGCCov(self, x, edge_index, edge_weight):
+        row, col = edge_index
+        for _ in range(self.K):
+            src = x[row] * edge_weight.view(-1, 1)
+            x = scatter_add(src, col, dim=-2, dim_size=x.size(0))
+        return x
+
+    def compute_gradient(self, subgraph, eps=5.0):
+        if self.attack_structure:
+            edge_weight = subgraph.edge_weight
+            non_edge_weight = subgraph.non_edge_weight
+            self_loop_weight = subgraph.self_loop_weight
+            weights = torch.cat([edge_weight, edge_weight,
+                                 non_edge_weight, non_edge_weight,
+                                 self_loop_weight], dim=0)
+        else:
+            weights = subgraph.edge_weight
+
+        weights = self.gcn_norm(subgraph.edges_all, weights, self.selfloop_degree)
+        logit = self.SGCCov(self.compute_XW(), subgraph.edges_all, weights)
+        logit = logit[self.target_node]
+        if self.bias is not None:
+            logit += self.bias
+
+        # model calibration
+        logit = F.log_softmax(logit.view(1, -1) / eps, dim=1)
+        loss = F.nll_loss(logit, self.target_label) - F.nll_loss(logit, self.best_wrong_label)
+
+        edge_grad = non_edge_grad = features_grad = None
+
+        if self.attack_structure and self.attack_features:
+            edge_grad, non_edge_grad, features_grad = torch.autograd.grad(loss, [edge_weight, non_edge_weight, self.modified_features], create_graph=False)
+        elif self.attack_structure:
+            edge_grad, non_edge_grad = torch.autograd.grad(loss, [edge_weight, non_edge_weight], create_graph=False)
+        else:
+            features_grad = torch.autograd.grad(loss, self.modified_features, create_graph=False)[0]
+
+        if self.attack_features:
+            self.compute_XW.cache_clear()
+        return edge_grad, non_edge_grad, features_grad
+
+    def ego_subgraph(self):
+        edge_index = np.asarray(self.ori_adj.nonzero())
+        edge_index = torch.as_tensor(edge_index, dtype=torch.long, device=self.device)
+        sub_nodes, sub_edges, *_ = k_hop_subgraph(int(self.target_node), self.K, edge_index)
+        sub_edges = sub_edges[:, sub_edges[0] < sub_edges[1]]
+
+        return sub_nodes, sub_edges
+
+    @staticmethod
+    def gcn_norm(edge_index, weights, degree):
+        row, col = edge_index
+        inv_degree = torch.pow(degree, -0.5)
+        normed_weights = weights * inv_degree[row] * inv_degree[col]
+        return normed_weights
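
The static gcn_norm above is the symmetric GCN normalization w_uv / sqrt(d_u * d_v) over self-loop-augmented degrees; it is what lets compute_gradient re-normalize edge weights after every add/remove without rebuilding the adjacency matrix. A small numeric sketch of those same three lines (toy path graph; all values are illustrative):

import torch

# 3-node path graph 0-1-2: both edge directions plus self-loops.
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                           [1, 0, 2, 1, 0, 1, 2]])
weights = torch.ones(edge_index.size(1))
degree = torch.tensor([2., 3., 2.])    # degrees counting the self-loop

row, col = edge_index
inv_degree = torch.pow(degree, -0.5)
normed = weights * inv_degree[row] * inv_degree[col]
print(normed)  # the (0, 1) entry is 1/sqrt(2 * 3); each self-loop becomes 1/d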
deeprobust/graph/targeted_attack/ugba.py
ADDED
@@ -0,0 +1,913 @@
+import numpy as np
+import scipy.sparse as sp
+import time
+import copy
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+from torch_geometric.utils import degree
+from sklearn.cluster import KMeans
+from copy import deepcopy
+# from deeprobust.graph.defense_pyg import GCN, SAGE, GAT
+from deeprobust.graph.targeted_attack import BaseAttack
+from deeprobust.graph import utils
+
+class UGBA(BaseAttack):
+    """
+    Modified from Unnoticeable Backdoor Attacks on Graph Neural Networks (WWW 2023).
+
+    See example in examples/graph/test_ugba.py.
+
+    Parameters
+    ----------
+    vs_number: int
+        number of selected poisoned nodes for training the backdoor model
+
+    device: str
+        'cpu' or 'cuda'
+
+    target_class: int
+        the class that the attacker aims to misclassify into
+
+    trigger_size: int
+        the number of nodes in a trigger
+
+    target_loss_weight: float
+
+    homo_loss_weight: float
+        the weight of the homophily loss
+
+    homo_boost_thrd: float
+        the upper bound of similarity
+
+    train_epochs: int
+        the number of epochs when training the GCN encoder
+
+    trojan_epochs: int
+        the number of epochs when training the trigger generator
+
+
+    """
+    def __init__(self, data, vs_number,
+                 target_class=0, trigger_size=3, target_loss_weight=1,
+                 homo_loss_weight=100, homo_boost_thrd=0.8, train_epochs=200, trojan_epochs=800, dis_weight=1,
+                 inner=1, thrd=0.5, lr=0.01, hidden=32, weight_decay=5e-4,
+                 seed=10, debug=True, device='cpu'):
+        self.device = device
+        self.data = data
+        self.size = vs_number
+        # self.test_model = model
+        self.target_class = target_class
+        self.trigger_size = trigger_size
+        self.target_loss_weight = target_loss_weight
+        self.homo_loss_weight = homo_loss_weight
+        self.homo_boost_thrd = homo_boost_thrd
+        self.train_epochs = train_epochs
+        self.trojan_epochs = trojan_epochs
+        self.dis_weight = dis_weight
+        self.inner = inner
+        self.thrd = thrd
+        self.lr = lr
+        self.hidden = hidden
+        self.weight_decay = weight_decay
+        self.seed = seed
+        self.debug = debug
+
+        # filter out the unlabeled nodes (everything except training and testing nodes);
+        # nonzero() gets the indices, flatten() gives a 1-d tensor
+        self.unlabeled_idx = (torch.bitwise_not(data.test_mask) & torch.bitwise_not(data.train_mask)).nonzero().flatten()
+        self.idx_val = utils.index_to_mask(data.val_mask, size=data.x.shape[0])
+
+    def attack(self, target_node, x, y, edge_index, edge_weights=None):
+        '''
+        Inject the generated trigger to the target node (a single node).
+
+        Parameters
+        ----------
+        target_node: int
+            the index of the target node
+        x: tensor
+            features of nodes
+        y: tensor
+            node labels
+        edge_index: tensor
+            edge index of the graph
+        edge_weights: tensor
+            the weights of edges
+        '''
+        idx_target = torch.tensor([target_node])
+        print(idx_target)
+        if edge_weights is None:
+            edge_weights = torch.ones([edge_index.shape[1]]).to(self.device)
+        x, edge_index, edge_weights, y = self.inject_trigger(idx_target, x, y, edge_index, edge_weights)
+        return x, edge_index, edge_weights, y
+
+    def get_poisoned_graph(self):
+        '''
+        Obtain the poisoned training graph for training the backdoor GNN.
+        '''
+        assert self.trigger_generator, "please first use train_trigger_generator() to train the trigger generator and get poisoned nodes"
+        poison_x, poison_edge_index, poison_edge_weights, poison_labels = self.trigger_generator.get_poisoned()
+        # add poisoned nodes into training nodes
+        idx_bkd_tn = torch.cat([self.idx_train, self.idx_attach]).to(self.device)
+
+        poison_data = copy.deepcopy(self.data)
+        idx_val = poison_data.val_mask.nonzero().flatten()
+        idx_test = poison_data.test_mask.nonzero().flatten()
+
+        poison_data.x, poison_data.edge_index, poison_data.edge_weights, poison_data.y = poison_x, poison_edge_index, poison_edge_weights, poison_labels
+        poison_data.train_mask = utils.index_to_mask(idx_bkd_tn, poison_data.x.shape[0])
+        poison_data.val_mask = utils.index_to_mask(idx_val, poison_data.x.shape[0])
+        poison_data.test_mask = utils.index_to_mask(idx_test, poison_data.x.shape[0])
+        return poison_data
+
+    def train_trigger_generator(self, idx_train, edge_index, edge_weights=None, selection_method='cluster', **kwargs):
+        """
+        Train the adaptive trigger generator.
+
+        Parameters
+        ----------
+        idx_train: tensor
+            indices of training nodes
+        edge_index: tensor
+            edge index of the graph
+        edge_weights: tensor
+            the weights of edges
+        selection_method : ['none', 'cluster']
+            the method to select poisoned nodes
+        """
+        self.idx_train = idx_train
+        # self.data = data
+
+        idx_attach = self.select_idx_attach(selection_method, edge_index, edge_weights).to(self.device)
+        self.idx_attach = idx_attach
+        print("idx_attach: {}".format(idx_attach))
+        # train trigger generator
+        trigger_generator = Backdoor(self.target_class, self.trigger_size, self.target_loss_weight,
+                                     self.homo_loss_weight, self.homo_boost_thrd, self.trojan_epochs,
+                                     self.inner, self.thrd, self.lr, self.hidden, self.weight_decay,
+                                     self.seed, self.debug, self.device)
+        self.trigger_generator = trigger_generator
+
+        self.trigger_generator.fit(self.data.x, edge_index, edge_weights, self.data.y, idx_train, idx_attach, self.unlabeled_idx)
+        return self.trigger_generator, idx_attach
+
+    def inject_trigger(self, idx_attach, x, y, edge_index, edge_weights):
+        """
+        Attach the generated triggers to the attached nodes.
+
+        Parameters
+        ----------
+        idx_attach: tensor
+            indices of to-be-attached nodes
+        x: tensor
+            features of nodes
+        y: tensor
+            node labels
+        edge_index: tensor
+            edge index of the graph
+        edge_weights: tensor
+            the weights of edges
+        """
+        assert self.trigger_generator, "please first use train_trigger_generator() to train the trigger generator"
+
+        update_x, update_edge_index, update_edge_weights, update_y = self.trigger_generator.inject_trigger(idx_attach, x, edge_index, edge_weights, y, self.device)
+        return update_x, update_edge_index, update_edge_weights, update_y
+
+    def select_idx_attach(self, selection_method, edge_index, edge_weights=None):
+        if selection_method == 'none':
+            idx_attach = self.obtain_attach_nodes(self.unlabeled_idx, self.size)
|
179 |
+
elif(selection_method == 'cluster'):
|
180 |
+
idx_attach = self.cluster_selection(self.data,self.idx_train,self.idx_val,self.unlabeled_idx,self.size,edge_index,edge_weights)
|
181 |
+
idx_attach = torch.LongTensor(idx_attach).to(self.device)
|
182 |
+
return idx_attach
|
183 |
+
|
184 |
+
def obtain_attach_nodes(self,node_idxs, size):
|
185 |
+
### current random to implement
|
186 |
+
size = min(len(node_idxs),size)
|
187 |
+
rs = np.random.RandomState(self.seed)
|
188 |
+
choice = np.arange(len(node_idxs))
|
189 |
+
rs.shuffle(choice)
|
190 |
+
return node_idxs[choice[:size]]
|
191 |
+
|
192 |
+
def cluster_selection(self,data,idx_train,idx_val,unlabeled_idx,size,edge_index,edge_weights = None):
|
193 |
+
gcn_encoder = GCN_Encoder(nfeat=data.x.shape[1],
|
194 |
+
nhid=32,
|
195 |
+
nclass= int(data.y.max()+1),
|
196 |
+
dropout=0.5,
|
197 |
+
lr=0.01,
|
198 |
+
weight_decay=5e-4,
|
199 |
+
device=self.device,
|
200 |
+
use_ln=False,
|
201 |
+
layer_norm_first=False).to(self.device)
|
202 |
+
t_total = time.time()
|
203 |
+
# edge_weights = torch.ones([data.edge_index.shape[1]],device=device,dtype=torch.float)
|
204 |
+
print("Length of training set: {}".format(len(idx_train)))
|
205 |
+
gcn_encoder.fit(data.x, edge_index, edge_weights, data.y, idx_train, idx_val= idx_val,train_iters=self.train_epochs,verbose=True)
|
206 |
+
print("Training encoder Finished!")
|
207 |
+
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
|
208 |
+
|
209 |
+
seen_node_idx = torch.concat([idx_train,unlabeled_idx])
|
210 |
+
nclass = np.unique(data.y.cpu().numpy()).shape[0]
|
211 |
+
encoder_x = gcn_encoder.get_h(data.x, edge_index,edge_weights).clone().detach()
|
212 |
+
|
213 |
+
kmeans = KMeans(n_clusters=nclass,random_state=1)
|
214 |
+
kmeans.fit(encoder_x[seen_node_idx].detach().cpu().numpy())
|
215 |
+
cluster_centers = kmeans.cluster_centers_
|
216 |
+
y_pred = kmeans.predict(encoder_x.cpu().numpy())
|
217 |
+
# encoder_output = gcn_encoder(data.x,train_edge_index,None)
|
218 |
+
idx_attach = self.obtain_attach_nodes_by_cluster_degree_all(edge_index,y_pred,cluster_centers,unlabeled_idx.cpu().tolist(),encoder_x,size).astype(int)
|
219 |
+
idx_attach = idx_attach[:size]
|
220 |
+
return idx_attach
|
221 |
+
|
222 |
+
def obtain_attach_nodes_by_cluster_degree_all(self,edge_index,y_pred,cluster_centers,node_idxs,x,size):
|
223 |
+
dis_weight = self.dis_weight
|
224 |
+
degrees = (degree(edge_index[0]) + degree(edge_index[1])).cpu().numpy()
|
225 |
+
distances = []
|
226 |
+
for id in range(x.shape[0]):
|
227 |
+
tmp_center_label = y_pred[id]
|
228 |
+
tmp_center_x = cluster_centers[tmp_center_label]
|
229 |
+
|
230 |
+
dis = np.linalg.norm(tmp_center_x - x[id].detach().cpu().numpy())
|
231 |
+
distances.append(dis)
|
232 |
+
|
233 |
+
distances = np.array(distances)
|
234 |
+
print(y_pred)
|
235 |
+
|
236 |
+
nontarget_nodes = np.where(y_pred!=self.target_class)[0]
|
237 |
+
|
238 |
+
non_target_node_idxs = np.array(list(set(nontarget_nodes) & set(node_idxs)))
|
239 |
+
node_idxs = np.array(non_target_node_idxs)
|
240 |
+
candiadate_distances = distances[node_idxs]
|
241 |
+
candiadate_degrees = degrees[node_idxs]
|
242 |
+
candiadate_distances = self.max_norm(candiadate_distances)
|
243 |
+
candiadate_degrees = self.max_norm(candiadate_degrees)
|
244 |
+
|
245 |
+
dis_score = candiadate_distances + dis_weight * candiadate_degrees
|
246 |
+
candidate_nid_index = np.argsort(dis_score)
|
247 |
+
sorted_node_idex = np.array(node_idxs[candidate_nid_index])
|
248 |
+
selected_nodes = sorted_node_idex
|
249 |
+
return selected_nodes
|
250 |
+
|
251 |
+
def max_norm(self,data):
|
252 |
+
_range = np.max(data) - np.min(data)
|
253 |
+
return (data - np.min(data)) / _range
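
# --- Illustrative usage sketch (added for clarity; not part of the original file).
# Assumes a PyG-style `data` object with x, y, edge_index and train/val/test
# masks; `vs_number=40` is an arbitrary example value.
def _example_ugba_usage(data, idx_train, device='cpu'):
    attack = UGBA(data, vs_number=40, device=device)
    # select poisoned nodes and train the adaptive trigger generator
    trigger_generator, idx_attach = attack.train_trigger_generator(
        idx_train, data.edge_index, selection_method='cluster')
    # build the poisoned graph used to train the backdoored GNN
    poison_data = attack.get_poisoned_graph()
    return poison_data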

from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def accuracy(output, labels):
    """Return accuracy of output compared to labels.
    Parameters
    ----------
    output : torch.Tensor
        output from model
    labels : torch.Tensor or numpy.array
        node labels
    Returns
    -------
    float
        accuracy
    """
    if not hasattr(labels, '__len__'):
        labels = [labels]
    if type(labels) is not torch.Tensor:
        labels = torch.LongTensor(labels)
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)
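
# Illustrative sanity check for accuracy() (added; not in the original file):
# with these made-up logits, three of four predictions are correct, so the
# expected value is 0.75.
def _example_accuracy_check():
    logits = torch.log_softmax(torch.tensor(
        [[2., 0.], [0., 2.], [2., 0.], [2., 0.]]), dim=1)
    labels = torch.tensor([0, 1, 0, 1])
    assert abs(accuracy(logits, labels).item() - 0.75) < 1e-6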
#%%
class GradWhere(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input, thrd, device):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        rst = torch.where(input > thrd, torch.tensor(1.0, device=device, requires_grad=True),
                          torch.tensor(0.0, device=device, requires_grad=True))
        return rst

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()

        """
        The number of returned results should correspond to the inputs of
        .forward (besides ctx): for each input, return a matching backward grad.
        """
        return grad_input, None, None
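
# Illustrative sketch (added; not in the original file): GradWhere acts as a
# straight-through estimator -- the forward pass hard-thresholds the input to
# {0, 1}, while the backward pass passes the incoming gradient through
# unchanged, so the otherwise non-differentiable thresholding stays trainable.
def _example_gradwhere():
    w = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
    out = GradWhere.apply(w, 0.5, 'cpu')   # -> tensor([0., 1., 1.])
    out.sum().backward()
    # gradient of the sum w.r.t. w is all ones, copied straight through
    assert torch.equal(w.grad, torch.ones(3))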

class GraphTrojanNet(nn.Module):
    # In the future, we may use a GNN model to generate the backdoor
    def __init__(self, device, nfeat, nout, layernum=1, dropout=0.00):
        super(GraphTrojanNet, self).__init__()

        layers = []
        if dropout > 0:
            layers.append(nn.Dropout(p=dropout))
        for l in range(layernum - 1):
            layers.append(nn.Linear(nfeat, nfeat))
            layers.append(nn.ReLU(inplace=True))
            if dropout > 0:
                layers.append(nn.Dropout(p=dropout))

        self.layers = nn.Sequential(*layers).to(device)

        self.feat = nn.Linear(nfeat, nout * nfeat)
        self.edge = nn.Linear(nfeat, int(nout * (nout - 1) / 2))
        self.device = device

    def forward(self, input, thrd):

        """
        "input", "mask" and "thrd" should already be on cuda before being sent
        to this function. If using the sparse format, the corresponding tensors
        should already be in sparse format before being sent in.
        """

        GW = GradWhere.apply
        h = self.layers(input)

        feat = self.feat(h)
        edge_weight = self.edge(h)
        # feat = GW(feat, thrd, self.device)
        edge_weight = GW(edge_weight, thrd, self.device)

        return feat, edge_weight
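
# Shape sketch (added; not in the original file): for each attached node the
# generator emits trigger_size * nfeat feature entries and one weight per
# trigger-internal edge pair, i.e. nout*(nout-1)/2 values.
def _example_trojan_shapes():
    nfeat, nout = 16, 3
    net = GraphTrojanNet('cpu', nfeat, nout, layernum=2)
    feats = torch.randn(5, nfeat)                 # 5 attached nodes
    trojan_feat, trojan_weights = net(feats, 0.5)
    assert trojan_feat.shape == (5, nout * nfeat)
    assert trojan_weights.shape == (5, nout * (nout - 1) // 2)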

class HomoLoss(nn.Module):
    def __init__(self, device):
        super(HomoLoss, self).__init__()
        self.device = device

    def forward(self, trigger_edge_index, trigger_edge_weights, x, thrd):

        trigger_edge_index = trigger_edge_index[:, trigger_edge_weights > 0.0]
        edge_sims = F.cosine_similarity(x[trigger_edge_index[0]], x[trigger_edge_index[1]])

        loss = torch.relu(thrd - edge_sims).mean()
        # print(edge_sims.min())
        return loss
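
# Illustrative sketch (added; not in the original file): HomoLoss penalizes
# trigger edges whose endpoint features are less similar than the threshold,
# pushing injected triggers to look homophilous (hence "unnoticeable").
def _example_homo_loss():
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1], [2, 3]])  # two candidate trigger edges
    edge_weights = torch.ones(2)
    loss = HomoLoss('cpu')(edge_index, edge_weights, x, thrd=0.8)
    return loss  # zero only if both cosine similarities exceed 0.8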

#%%
import numpy as np
class Backdoor:
    def __init__(self, target_class, trigger_size, target_loss_weight, homo_loss_weight, homo_boost_thrd, trojan_epochs, inner, thrd, lr, hidden, weight_decay, seed, debug, device):
        self.device = device
        self.weights = None
        self.trigger_size = trigger_size
        self.thrd = thrd
        self.trigger_index = self.get_trigger_index(self.trigger_size)
        self.hidden = hidden
        self.target_class = target_class
        self.lr = lr
        self.weight_decay = weight_decay
        self.trojan_epochs = trojan_epochs
        self.inner = inner
        self.seed = seed
        self.target_loss_weight = target_loss_weight
        self.homo_boost_thrd = homo_boost_thrd
        self.homo_loss_weight = homo_loss_weight
        self.debug = debug

    def get_trigger_index(self, trigger_size):
        edge_list = []
        edge_list.append([0, 0])
        for j in range(trigger_size):
            for k in range(j):
                edge_list.append([j, k])
        edge_index = torch.tensor(edge_list, device=self.device).long().T
        return edge_index

    def get_trojan_edge(self, start, idx_attach, trigger_size):
        edge_list = []
        for idx in idx_attach:
            edges = self.trigger_index.clone()
            edges[0, 0] = idx
            edges[1, 0] = start
            edges[:, 1:] = edges[:, 1:] + start

            edge_list.append(edges)
            start += trigger_size
        edge_index = torch.cat(edge_list, dim=1)
        # to undirected
        # row, col = edge_index
        row = torch.cat([edge_index[0], edge_index[1]])
        col = torch.cat([edge_index[1], edge_index[0]])
        edge_index = torch.stack([row, col])

        return edge_index
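
    # Worked illustration (comment added; code unchanged): for trigger_size=3,
    # self.trigger_index is [[0,0], [1,0], [2,0], [2,1]].T -- a placeholder
    # attach edge plus the fully-connected trigger. get_trojan_edge() then, for
    # each attached node, rewrites [0,0] into (attached node -> first trigger
    # node) and offsets the remaining entries by `start`, the index where the
    # new trigger nodes are appended, before symmetrizing the edge list.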

    def inject_trigger(self, idx_attach, features, edge_index, edge_weight, y, device):
        self.trojan = self.trojan.to(device)
        idx_attach = idx_attach.to(device)
        features = features.to(device)
        edge_index = edge_index.to(device)
        edge_weight = edge_weight.to(device)
        self.trojan.eval()

        trojan_feat, trojan_weights = self.trojan(features[idx_attach], self.thrd)  # may revise the generation process
        trojan_weights = torch.cat([torch.ones([len(idx_attach), 1], dtype=torch.float, device=device), trojan_weights], dim=1)
        trojan_weights = trojan_weights.flatten()

        trojan_feat = trojan_feat.view([-1, features.shape[1]])

        trojan_edge = self.get_trojan_edge(len(features), idx_attach, self.trigger_size).to(device)

        update_edge_weights = torch.cat([edge_weight, trojan_weights, trojan_weights])
        update_feat = torch.cat([features, trojan_feat])
        update_edge_index = torch.cat([edge_index, trojan_edge], dim=1)

        # update label set: trigger nodes get the placeholder label -1
        update_y = torch.cat([y, -1 * torch.ones([len(idx_attach) * self.trigger_size], dtype=torch.long, device=device)])

        self.trojan = self.trojan.cpu()
        idx_attach = idx_attach.cpu()
        features = features.cpu()
        edge_index = edge_index.cpu()
        edge_weight = edge_weight.cpu()
        return update_feat, update_edge_index, update_edge_weights, update_y


    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_attach, idx_unlabeled):

        if edge_weight is None:
            edge_weight = torch.ones([edge_index.shape[1]], device=self.device, dtype=torch.float)
        self.idx_attach = idx_attach
        self.features = features
        self.edge_index = edge_index
        self.edge_weights = edge_weight

        # initialize a shadow model
        self.shadow_model = GCN(nfeat=features.shape[1],
                                nhid=self.hidden,
                                nclass=labels.max().item() + 1,
                                dropout=0.0, device=self.device).to(self.device)
        # initialize a trojan net to generate the trigger
        self.trojan = GraphTrojanNet(self.device, features.shape[1], self.trigger_size, layernum=2).to(self.device)
        self.homo_loss = HomoLoss(self.device)

        optimizer_shadow = optim.Adam(self.shadow_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        optimizer_trigger = optim.Adam(self.trojan.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # change the labels of the poisoned nodes to the target class
        self.labels = labels.clone()
        self.labels[idx_attach] = self.target_class

        # get the trojan edges, which include the target-trigger edges and the edges among trigger nodes
        trojan_edge = self.get_trojan_edge(len(features), idx_attach, self.trigger_size).to(self.device)

        # update the poisoned graph's edge index
        poison_edge_index = torch.cat([edge_index, trojan_edge], dim=1)

        # future work: change this to bilevel optimization

        loss_best = 1e8
        for i in range(self.trojan_epochs):
            self.trojan.train()
            for j in range(self.inner):

                optimizer_shadow.zero_grad()
                trojan_feat, trojan_weights = self.trojan(features[idx_attach], self.thrd)  # may revise the generation process
                trojan_weights = torch.cat([torch.ones([len(trojan_feat), 1], dtype=torch.float, device=self.device), trojan_weights], dim=1)
                trojan_weights = trojan_weights.flatten()
                trojan_feat = trojan_feat.view([-1, features.shape[1]])
                poison_edge_weights = torch.cat([edge_weight, trojan_weights, trojan_weights]).detach()  # repeat trojan weights because the edges are undirected
                poison_x = torch.cat([features, trojan_feat]).detach()

                output = self.shadow_model(poison_x, poison_edge_index, poison_edge_weights)

                loss_inner = F.nll_loss(output[torch.cat([idx_train, idx_attach])], self.labels[torch.cat([idx_train, idx_attach])])  # add our adaptive loss

                loss_inner.backward()
                optimizer_shadow.step()

            acc_train_clean = accuracy(output[idx_train], self.labels[idx_train])
            acc_train_attach = accuracy(output[idx_attach], self.labels[idx_attach])

            # involve unlabeled nodes in the outer optimization
            self.trojan.eval()
            optimizer_trigger.zero_grad()

            rs = np.random.RandomState(self.seed)
            idx_outer = torch.cat([idx_attach, idx_unlabeled[rs.choice(len(idx_unlabeled), size=512, replace=False)]])

            trojan_feat, trojan_weights = self.trojan(features[idx_outer], self.thrd)  # may revise the generation process

            trojan_weights = torch.cat([torch.ones([len(idx_outer), 1], dtype=torch.float, device=self.device), trojan_weights], dim=1)
            trojan_weights = trojan_weights.flatten()

            trojan_feat = trojan_feat.view([-1, features.shape[1]])

            trojan_edge = self.get_trojan_edge(len(features), idx_outer, self.trigger_size).to(self.device)

            update_edge_weights = torch.cat([edge_weight, trojan_weights, trojan_weights])
            update_feat = torch.cat([features, trojan_feat])
            update_edge_index = torch.cat([edge_index, trojan_edge], dim=1)

            output = self.shadow_model(update_feat, update_edge_index, update_edge_weights)

            labels_outer = labels.clone()
            labels_outer[idx_outer] = self.target_class
            loss_target = self.target_loss_weight * F.nll_loss(output[torch.cat([idx_train, idx_outer])],
                                                               labels_outer[torch.cat([idx_train, idx_outer])])
            loss_homo = 0.0

            if self.homo_loss_weight > 0:
                loss_homo = self.homo_loss(trojan_edge[:, :int(trojan_edge.shape[1] / 2)],
                                           trojan_weights,
                                           update_feat,
                                           self.homo_boost_thrd)

            loss_outer = loss_target + self.homo_loss_weight * loss_homo

            loss_outer.backward()
            optimizer_trigger.step()
            acc_train_outer = (output[idx_outer].argmax(dim=1) == self.target_class).float().mean()

            if loss_outer < loss_best:
                self.weights = deepcopy(self.trojan.state_dict())
                loss_best = float(loss_outer)

            if self.debug and i % 10 == 0:
                print('Epoch {}, loss_inner: {:.5f}, loss_target: {:.5f}, homo loss: {:.5f} '
                      .format(i, loss_inner, loss_target, loss_homo))
                print("acc_train_clean: {:.4f}, ASR_train_attach: {:.4f}, ASR_train_outer: {:.4f}"
                      .format(acc_train_clean, acc_train_attach, acc_train_outer))
        if self.debug:
            print("load best weights based on the outer loss")
        self.trojan.load_state_dict(self.weights)
        self.trojan.eval()

        # torch.cuda.empty_cache()
    def get_poisoned(self):
        with torch.no_grad():
            poison_x, poison_edge_index, poison_edge_weights, poison_labels = self.inject_trigger(self.idx_attach, self.features, self.edge_index, self.edge_weights, self.labels, self.device)
        # poison_labels = self.labels
        poison_edge_index = poison_edge_index[:, poison_edge_weights > 0.0]
        poison_edge_weights = poison_edge_weights[poison_edge_weights > 0.0]
        return poison_x, poison_edge_index, poison_edge_weights, poison_labels

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
from torch_geometric.nn import GCNConv
import numpy as np
import scipy.sparse as sp

class GCN_Encoder(nn.Module):

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4, layer=2, device=None, use_ln=False, layer_norm_first=False):

        super(GCN_Encoder, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.use_ln = use_ln
        self.layer_norm_first = layer_norm_first
        # self.convs = nn.ModuleList()
        # self.convs.append(GCNConv(nfeat, nhid))
        # for _ in range(layer-2):
        #     self.convs.append(GCNConv(nhid,nhid))
        # self.gc2 = GCNConv(nhid, nclass)
        self.body = GCN_body(nfeat, nhid, dropout, layer, device=None, use_ln=use_ln, layer_norm_first=layer_norm_first)
        self.fc = nn.Linear(nhid, nclass)

        self.dropout = dropout
        self.lr = lr
        self.output = None
        self.edge_index = None
        self.edge_weight = None
        self.features = None
        self.weight_decay = weight_decay

    def forward(self, x, edge_index, edge_weight=None):
        x = self.body(x, edge_index, edge_weight)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def get_h(self, x, edge_index, edge_weight):
        self.eval()
        x = self.body(x, edge_index, edge_weight)
        return x

    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
        """Train the gcn model; when idx_val is not None, pick the best model according to the validation performance.
        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), the GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        """

        self.edge_index, self.edge_weight = edge_index, edge_weight
        self.features = features.to(self.device)
        self.labels = labels.to(self.device)

        if idx_val is None:
            self._train_without_val(self.labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(self.labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.edge_index, self.edge_weight)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                print("acc_val: {:.4f}".format(acc_val))
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)


    def test(self, features, edge_index, edge_weight, labels, idx_test):
        """Evaluate GCN performance on the test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            acc_test = accuracy(output[idx_test], labels[idx_test])
        return float(acc_test)

    def test_with_correct_nodes(self, features, edge_index, edge_weight, labels, idx_test):
        self.eval()
        output = self.forward(features, edge_index, edge_weight)
        correct_nids = (output.argmax(dim=1)[idx_test] == labels[idx_test]).nonzero().flatten()  # return a tensor
        acc_test = accuracy(output[idx_test], labels[idx_test])
        return acc_test, correct_nids
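
# Illustrative usage sketch (added; not in the original file): the encoder is
# trained on the clean graph and its hidden representations feed the k-means
# clustering used by UGBA.cluster_selection(); the inputs are assumed to be
# PyG-style tensors.
def _example_encoder_usage(data, idx_train, idx_val, device='cpu'):
    encoder = GCN_Encoder(nfeat=data.x.shape[1], nhid=32,
                          nclass=int(data.y.max() + 1), device=device).to(device)
    encoder.fit(data.x, data.edge_index, None, data.y, idx_train, idx_val=idx_val)
    return encoder.get_h(data.x, data.edge_index, None)  # node embeddings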

class GCN_body(nn.Module):
    def __init__(self, nfeat, nhid, dropout=0.5, layer=2, device=None, layer_norm_first=False, use_ln=False):
        super(GCN_body, self).__init__()
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(nfeat, nhid))
        self.lns = nn.ModuleList()
        self.lns.append(torch.nn.LayerNorm(nfeat))
        for _ in range(layer - 1):
            self.convs.append(GCNConv(nhid, nhid))
            self.lns.append(nn.LayerNorm(nhid))
        self.lns.append(torch.nn.LayerNorm(nhid))
        self.layer_norm_first = layer_norm_first
        self.use_ln = use_ln

    def forward(self, x, edge_index, edge_weight=None):
        if self.layer_norm_first:
            x = self.lns[0](x)
        i = 0
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight))
            if self.use_ln:
                x = self.lns[i + 1](x)
            i += 1
            x = F.dropout(x, self.dropout, training=self.training)
        return x

class GCN(nn.Module):

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4, layer=2, device=None, layer_norm_first=False, use_ln=False):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(nfeat, nhid))
        self.lns = nn.ModuleList()
        self.lns.append(torch.nn.LayerNorm(nfeat))
        for _ in range(layer - 2):
            self.convs.append(GCNConv(nhid, nhid))
            self.lns.append(nn.LayerNorm(nhid))
        self.lns.append(nn.LayerNorm(nhid))
        self.gc2 = GCNConv(nhid, nclass)
        self.dropout = dropout
        self.lr = lr
        self.output = None
        self.edge_index = None
        self.edge_weight = None
        self.features = None
        self.weight_decay = weight_decay

        self.layer_norm_first = layer_norm_first
        self.use_ln = use_ln

    def forward(self, x, edge_index, edge_weight=None):
        if self.layer_norm_first:
            x = self.lns[0](x)
        i = 0
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight))
            if self.use_ln:
                x = self.lns[i + 1](x)
            i += 1
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

    def get_h(self, x, edge_index):

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        return x

    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
        """Train the gcn model; when idx_val is not None, pick the best model according to the validation performance.
        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), the GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        """

        self.edge_index, self.edge_weight = edge_index, edge_weight
        self.features = features.to(self.device)
        self.labels = labels.to(self.device)

        if idx_val is None:
            self._train_without_val(self.labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(self.labels, idx_train, idx_val, train_iters, verbose)
        # torch.cuda.empty_cache()

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.edge_index, self.edge_weight)
        self.output = output
        # torch.cuda.empty_cache()

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                print("acc_val: {:.4f}".format(acc_val))
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)
        # torch.cuda.empty_cache()


    def test(self, features, edge_index, edge_weight, labels, idx_test):
        """Evaluate GCN performance on the test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        # torch.cuda.empty_cache()
        # print("Test set results:",
        #       "loss= {:.4f}".format(loss_test.item()),
        #       "accuracy= {:.4f}".format(acc_test.item()))
        return float(acc_test)

    def test_with_correct_nodes(self, features, edge_index, edge_weight, labels, idx_test):
        self.eval()
        output = self.forward(features, edge_index, edge_weight)
        correct_nids = (output.argmax(dim=1)[idx_test] == labels[idx_test]).nonzero().flatten()  # return a tensor
        acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        # torch.cuda.empty_cache()
        return acc_test, correct_nids
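
# Illustrative usage sketch (added; not in the original file): this GCN is the
# shadow model optimized inside Backdoor.fit(); trained standalone it looks like:
def _example_gcn_usage(data, idx_train, idx_val, device='cpu'):
    model = GCN(nfeat=data.x.shape[1], nhid=32,
                nclass=int(data.y.max() + 1), device=device).to(device)
    model.fit(data.x, data.edge_index, None, data.y, idx_train, idx_val=idx_val)
    idx_test = data.test_mask.nonzero().flatten()
    return model.test(data.x, data.edge_index, None, data.y, idx_test)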
deeprobust/graph/utils.py
ADDED
@@ -0,0 +1,778 @@
import numpy as np
import scipy.sparse as sp
import torch
from sklearn.model_selection import train_test_split
import torch.sparse as ts
import torch.nn.functional as F
import warnings

def encode_onehot(labels):
    """Convert label to onehot format.

    Parameters
    ----------
    labels : numpy.array
        node labels

    Returns
    -------
    numpy.array
        onehot labels
    """
    eye = np.eye(labels.max() + 1)
    onehot_mx = eye[labels]
    return onehot_mx
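
# Worked example (added; not in the original file): for labels [0, 2, 1] the
# function indexes rows of a 3x3 identity matrix.
def _example_encode_onehot():
    onehot = encode_onehot(np.array([0, 2, 1]))
    assert (onehot == np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]])).all()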

def tensor2onehot(labels):
    """Convert label tensor to label onehot tensor.

    Parameters
    ----------
    labels : torch.LongTensor
        node labels

    Returns
    -------
    torch.LongTensor
        onehot labels tensor

    """

    eye = torch.eye(labels.max() + 1)
    onehot_mx = eye[labels]
    return onehot_mx.to(labels.device)

def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'):
    """Convert adj, features, labels from array or sparse matrix to
    torch Tensor, and normalize the input data.

    Parameters
    ----------
    adj : scipy.sparse.csr_matrix
        the adjacency matrix.
    features : scipy.sparse.csr_matrix
        node features
    labels : numpy.array
        node labels
    preprocess_adj : bool
        whether to normalize the adjacency matrix
    preprocess_feature : bool
        whether to normalize the feature matrix
    sparse : bool
        whether to return sparse tensor
    device : str
        'cpu' or 'cuda'
    """

    if preprocess_adj:
        adj = normalize_adj(adj)

    if preprocess_feature:
        features = normalize_feature(features)

    labels = torch.LongTensor(labels)
    if sparse:
        adj = sparse_mx_to_torch_sparse_tensor(adj)
        features = sparse_mx_to_torch_sparse_tensor(features)
    else:
        if sp.issparse(features):
            features = torch.FloatTensor(np.array(features.todense()))
        else:
            features = torch.FloatTensor(features)
        adj = torch.FloatTensor(adj.todense())
    return adj.to(device), features.to(device), labels.to(device)
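
# Minimal usage sketch (added; not in the original file), assuming a small
# scipy graph; with sparse=False both adj and features come back dense.
def _example_preprocess():
    adj = sp.csr_matrix(np.array([[0, 1], [1, 0]], dtype=float))
    features = sp.csr_matrix(np.eye(2))
    labels = np.array([0, 1])
    adj_t, feat_t, labels_t = preprocess(adj, features, labels,
                                         preprocess_adj=True, device='cpu')
    return adj_t, feat_t, labels_t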

def to_tensor(adj, features, labels=None, device='cpu'):
    """Convert adj, features, labels from array or sparse matrix to
    torch Tensor.

    Parameters
    ----------
    adj : scipy.sparse.csr_matrix
        the adjacency matrix.
    features : scipy.sparse.csr_matrix
        node features
    labels : numpy.array
        node labels
    device : str
        'cpu' or 'cuda'
    """
    if sp.issparse(adj):
        adj = sparse_mx_to_torch_sparse_tensor(adj)
    else:
        adj = torch.FloatTensor(adj)
    if sp.issparse(features):
        features = sparse_mx_to_torch_sparse_tensor(features)
    else:
        features = torch.FloatTensor(np.array(features))

    if labels is None:
        return adj.to(device), features.to(device)
    else:
        labels = torch.LongTensor(labels)
        return adj.to(device), features.to(device), labels.to(device)

def normalize_feature(mx):
    """Row-normalize a sparse or dense matrix.

    Parameters
    ----------
    mx : scipy.sparse.csr_matrix or numpy.array
        matrix to be normalized

    Returns
    -------
    scipy.sparse.lil_matrix
        normalized matrix
    """
    if type(mx) is not sp.lil.lil_matrix:
        try:
            mx = mx.tolil()
        except AttributeError:
            pass
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def normalize_adj(mx):
    """Symmetrically normalize a sparse adjacency matrix:
    A' = (D + I)^-1/2 @ (A + I) @ (D + I)^-1/2

    Parameters
    ----------
    mx : scipy.sparse.csr_matrix
        matrix to be normalized

    Returns
    -------
    scipy.sparse.lil_matrix
        normalized matrix
    """

    # TODO: maybe using coo format would be better?
    if type(mx) is not sp.lil.lil_matrix:
        mx = mx.tolil()
    if mx[0, 0] == 0:
        mx = mx + sp.eye(mx.shape[0])
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1/2).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    mx = mx.dot(r_mat_inv)
    return mx
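
# Worked example (added; not in the original file): for the 2-node path graph,
# A + I = [[1,1],[1,1]] and every degree is 2, so each entry of the normalized
# matrix is 1/sqrt(2) * 1 * 1/sqrt(2) = 0.5.
def _example_normalize_adj():
    adj = sp.csr_matrix(np.array([[0., 1.], [1., 0.]]))
    mx = normalize_adj(adj).todense()
    assert np.allclose(mx, 0.5 * np.ones((2, 2)))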

def normalize_sparse_tensor(adj, fill_value=1):
    """Normalize a sparse tensor. Needs torch_scatter to be importable.
    """
    edge_index = adj._indices()
    edge_weight = adj._values()
    num_nodes = adj.size(0)
    edge_index, edge_weight = add_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    from torch_scatter import scatter_add
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    shape = adj.shape
    return torch.sparse.FloatTensor(edge_index, values, shape)

def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None):
    # num_nodes = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, num_nodes, dtype=torch.long,
                              device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        loop_weight = edge_weight.new_full((num_nodes, ), fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight

def normalize_adj_tensor(adj, sparse=False):
    """Normalize adjacency tensor matrix.
    """
    device = adj.device
    if sparse:
        # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. Note that you need to install torch_sparse')
        # TODO if this is too slow, uncomment the following code,
        # but you need to install torch_scatter
        # return normalize_sparse_tensor(adj)
        adj = to_scipy(adj)
        mx = normalize_adj(adj)
        return sparse_mx_to_torch_sparse_tensor(mx).to(device)
    else:
        mx = adj + torch.eye(adj.shape[0]).to(device)
        rowsum = mx.sum(1)
        r_inv = rowsum.pow(-1/2).flatten()
        r_inv[torch.isinf(r_inv)] = 0.
        r_mat_inv = torch.diag(r_inv)
        mx = r_mat_inv @ mx
        mx = mx @ r_mat_inv
        return mx

def degree_normalize_adj(mx):
    """Row-normalize sparse matrix"""
    mx = mx.tolil()
    if mx[0, 0] == 0:
        mx = mx + sp.eye(mx.shape[0])
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    # mx = mx.dot(r_mat_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def degree_normalize_sparse_tensor(adj, fill_value=1):
    """degree_normalize_sparse_tensor.
    """
    edge_index = adj._indices()
    edge_weight = adj._values()
    num_nodes = adj.size(0)

    edge_index, edge_weight = add_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    from torch_scatter import scatter_add
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv = deg.pow(-1)
    deg_inv[deg_inv == float('inf')] = 0

    values = deg_inv[row] * edge_weight
    shape = adj.shape
    return torch.sparse.FloatTensor(edge_index, values, shape)

def degree_normalize_adj_tensor(adj, sparse=True):
    """degree_normalize_adj_tensor.
    """

    device = adj.device
    if sparse:
        # return degree_normalize_sparse_tensor(adj)
        adj = to_scipy(adj)
        mx = degree_normalize_adj(adj)
        return sparse_mx_to_torch_sparse_tensor(mx).to(device)
    else:
        mx = adj + torch.eye(adj.shape[0]).to(device)
        rowsum = mx.sum(1)
        r_inv = rowsum.pow(-1).flatten()
        r_inv[torch.isinf(r_inv)] = 0.
        r_mat_inv = torch.diag(r_inv)
        mx = r_mat_inv @ mx
        return mx

def accuracy(output, labels):
    """Return accuracy of output compared to labels.

    Parameters
    ----------
    output : torch.Tensor
        output from model
    labels : torch.Tensor or numpy.array
        node labels

    Returns
    -------
    float
        accuracy
    """
    if not hasattr(labels, '__len__'):
        labels = [labels]
    if type(labels) is not torch.Tensor:
        labels = torch.LongTensor(labels)
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def loss_acc(output, labels, targets, avg_loss=True):
    if type(labels) is not torch.Tensor:
        labels = torch.LongTensor(labels)
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()[targets]
    loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none')

    if avg_loss:
        return loss, correct.sum() / len(targets)
    return loss, correct
    # correct = correct.sum()
    # return loss, correct / len(labels)

def get_perf(output, labels, mask, verbose=True):
    """Evaluate performance on the masked (test) data."""
    loss = F.nll_loss(output[mask], labels[mask])
    acc = accuracy(output[mask], labels[mask])
    if verbose:
        print("loss= {:.4f}".format(loss.item()),
              "accuracy= {:.4f}".format(acc.item()))
    return loss.item(), acc.item()


def classification_margin(output, true_label):
    """Calculate the classification margin for an output,
    `probs_true_label - probs_best_second_class`.

    Parameters
    ----------
    output: torch.Tensor
        output vector (1 dimension)
    true_label: int
        true label for this node

    Returns
    -------
    float
        classification margin for this node
    """

    probs = torch.exp(output)
    probs_true_label = probs[true_label].clone()
    probs[true_label] = 0
    probs_best_second_class = probs[probs.argmax()]
    return (probs_true_label - probs_best_second_class).item()
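
# Worked example (added; not in the original file): with log-probabilities for
# [0.7, 0.2, 0.1] and true label 0, the margin is 0.7 - 0.2 = 0.5.
def _example_classification_margin():
    log_probs = torch.log(torch.tensor([0.7, 0.2, 0.1]))
    assert abs(classification_margin(log_probs, 0) - 0.5) < 1e-6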

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    sparserow = torch.LongTensor(sparse_mx.row).unsqueeze(1)
    sparsecol = torch.LongTensor(sparse_mx.col).unsqueeze(1)
    sparseconcat = torch.cat((sparserow, sparsecol), 1)
    sparsedata = torch.FloatTensor(sparse_mx.data)
    return torch.sparse.FloatTensor(sparseconcat.t(), sparsedata, torch.Size(sparse_mx.shape))

    # slower version....
    # sparse_mx = sparse_mx.tocoo().astype(np.float32)
    # indices = torch.from_numpy(
    #     np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    # values = torch.from_numpy(sparse_mx.data)
    # shape = torch.Size(sparse_mx.shape)
    # return torch.sparse.FloatTensor(indices, values, shape)


def to_scipy(tensor):
    """Convert a dense/sparse tensor to a scipy matrix."""
    if is_sparse_tensor(tensor):
        values = tensor._values()
        indices = tensor._indices()
        return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape)
    else:
        indices = tensor.nonzero().t()
        values = tensor[indices[0], indices[1]]
        return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape)

def is_sparse_tensor(tensor):
    """Check if a tensor is a sparse tensor.

    Parameters
    ----------
    tensor : torch.Tensor
        given tensor

    Returns
    -------
    bool
        whether the tensor is a sparse tensor
    """
    # if hasattr(tensor, 'nnz'):
    if tensor.layout == torch.sparse_coo:
        return True
    else:
        return False

def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None):
    """This setting follows nettack/mettack, where we split the nodes
    into 10% training, 10% validation and 80% testing data.

    Parameters
    ----------
    nnodes : int
        number of nodes in total
    val_size : float
        size of validation set
    test_size : float
        size of test set
    stratify :
        data is expected to split in a stratified fashion, so stratify should be the labels.
    seed : int or None
        random seed

    Returns
    -------
    idx_train :
        node training indices
    idx_val :
        node validation indices
    idx_test :
        node test indices
    """

    assert stratify is not None, 'stratify cannot be None!'

    if seed is not None:
        np.random.seed(seed)

    idx = np.arange(nnodes)
    train_size = 1 - val_size - test_size
    idx_train_and_val, idx_test = train_test_split(idx,
                                                   random_state=None,
                                                   train_size=train_size + val_size,
                                                   test_size=test_size,
                                                   stratify=stratify)

    if stratify is not None:
        stratify = stratify[idx_train_and_val]

    idx_train, idx_val = train_test_split(idx_train_and_val,
                                          random_state=None,
                                          train_size=(train_size / (train_size + val_size)),
                                          test_size=(val_size / (train_size + val_size)),
                                          stratify=stratify)

    return idx_train, idx_val, idx_test
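
# Usage sketch (added; not in the original file): splitting a toy labeled graph
# into the default 10/10/80 train/val/test partition.
def _example_split():
    labels = np.repeat([0, 1], 10)           # 20 nodes, two balanced classes
    idx_train, idx_val, idx_test = get_train_val_test(
        len(labels), val_size=0.1, test_size=0.8, stratify=labels, seed=15)
    return idx_train, idx_val, idx_test      # 2 / 2 / 16 nodes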
|
448 |
+
|
449 |
+
def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None):
|
450 |
+
"""This function returns training and test set without validation.
|
451 |
+
It can be used for settings of different label rates.
|
452 |
+
|
453 |
+
Parameters
|
454 |
+
----------
|
455 |
+
nnodes : int
|
456 |
+
number of nodes in total
|
457 |
+
test_size : float
|
458 |
+
size of test set
|
459 |
+
stratify :
|
460 |
+
data is expected to split in a stratified fashion. So stratify should be labels.
|
461 |
+
seed : int or None
|
462 |
+
random seed
|
463 |
+
|
464 |
+
Returns
|
465 |
+
-------
|
466 |
+
idx_train :
|
467 |
+
node training indices
|
468 |
+
idx_test :
|
469 |
+
node test indices
|
470 |
+
"""
|
471 |
+
assert stratify is not None, 'stratify cannot be None!'
|
472 |
+
|
473 |
+
if seed is not None:
|
474 |
+
np.random.seed(seed)
|
475 |
+
|
476 |
+
idx = np.arange(nnodes)
|
477 |
+
train_size = 1 - test_size
|
478 |
+
idx_train, idx_test = train_test_split(idx, random_state=None,
|
479 |
+
train_size=train_size,
|
480 |
+
test_size=test_size,
|
481 |
+
stratify=stratify)
|
482 |
+
|
483 |
+
return idx_train, idx_test
|
484 |
+
|
485 |
+
def get_train_val_test_gcn(labels, seed=None):
|
486 |
+
"""This setting follows gcn, where we randomly sample 20 instances for each class
|
487 |
+
as training data, 500 instances as validation data, 1000 instances as test data.
|
488 |
+
Note here we are not using fixed splits. When random seed changes, the splits
|
489 |
+
will also change.
|
490 |
+
|
491 |
+
Parameters
|
492 |
+
----------
|
493 |
+
labels : numpy.array
|
494 |
+
node labels
|
495 |
+
seed : int or None
|
496 |
+
random seed
|
497 |
+
|
498 |
+
Returns
|
499 |
+
-------
|
500 |
+
idx_train :
|
501 |
+
node training indices
|
502 |
+
idx_val :
|
503 |
+
node validation indices
|
504 |
+
idx_test :
|
505 |
+
node test indices
|
506 |
+
"""
|
507 |
+
if seed is not None:
|
508 |
+
np.random.seed(seed)
|
509 |
+
|
510 |
+
idx = np.arange(len(labels))
|
511 |
+
nclass = labels.max() + 1
|
512 |
+
idx_train = []
|
513 |
+
idx_unlabeled = []
|
514 |
+
for i in range(nclass):
|
515 |
+
labels_i = idx[labels==i]
|
516 |
+
labels_i = np.random.permutation(labels_i)
|
517 |
+
idx_train = np.hstack((idx_train, labels_i[: 20])).astype(np.int)
|
518 |
+
idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(np.int)
|
519 |
+
|
520 |
+
idx_unlabeled = np.random.permutation(idx_unlabeled)
|
521 |
+
idx_val = idx_unlabeled[: 500]
|
522 |
+
idx_test = idx_unlabeled[500: 1500]
|
523 |
+
return idx_train, idx_val, idx_test
|
524 |
+
|
525 |
+
def get_train_test_labelrate(labels, label_rate):
    """Get train/test indices according to the given label rate.
    Note that only the training and test indices are returned.
    """
    nclass = labels.max() + 1
    train_size = int(round(len(labels) * label_rate / nclass))
    print("=== train_size = %s ===" % train_size)
    idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size)
    return idx_train, idx_test


def get_splits_each_class(labels, train_size):
    """We randomly sample n instances for each class, where n = train_size.
    """
    idx = np.arange(len(labels))
    nclass = labels.max() + 1
    idx_train = []
    idx_val = []
    idx_test = []
    for i in range(nclass):
        labels_i = idx[labels == i]
        labels_i = np.random.permutation(labels_i)
        idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(int)
        idx_val = np.hstack((idx_val, labels_i[train_size: 2 * train_size])).astype(int)
        idx_test = np.hstack((idx_test, labels_i[2 * train_size:])).astype(int)

    return np.random.permutation(idx_train), np.random.permutation(idx_val), \
           np.random.permutation(idx_test)


def unravel_index(index, array_shape):
    rows = torch.div(index, array_shape[1], rounding_mode='trunc')
    cols = index % array_shape[1]
    return rows, cols
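# For example, unravel_index(torch.tensor(7), (3, 4)) gives (rows=1, cols=3):
# flat index 7 in a 3 x 4 matrix sits at row 7 // 4 and column 7 % 4.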
def get_degree_squence(adj):
    try:
        return adj.sum(0)
    except:
        return ts.sum(adj, dim=1).to_dense()


def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True):
    """
    Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see
    https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge
    between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. Assumes an
    unweighted and undirected graph.
    """

    N = int(modified_adjacency.shape[0])
    # original_degree_sequence = get_degree_squence(original_adjacency)
    # current_degree_sequence = get_degree_squence(modified_adjacency)
    original_degree_sequence = original_adjacency.sum(0)
    current_degree_sequence = modified_adjacency.sum(0)

    concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence))

    # Compute the log likelihood values of the original, modified, and combined degree sequences.
    ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min)
    ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min)

    ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min)

    # Compute the log likelihood ratio
    current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current)

    # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair.
    new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs,
                                                                                               modified_adjacency, d_min)

    # Combination of the original degree distribution with the distributions corresponding to each node pair.
    n_combined = n_orig + new_ns
    new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees
    alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min)
    new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min)
    new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig)

    # Allowed edges are only those for which the resulting likelihood ratio measure is below the threshold.
    allowed_edges = new_ratios < threshold

    # np.bool was removed from recent NumPy releases; use the builtin bool
    if allowed_edges.is_cuda:
        filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(bool)]
    else:
        filtered_edges = node_pairs[allowed_edges.numpy().astype(bool)]

    allowed_mask = torch.zeros(modified_adjacency.shape)
    allowed_mask[filtered_edges.T] = 1
    if undirected:
        allowed_mask += allowed_mask.t()
    return allowed_mask, current_ratio


def degree_sequence_log_likelihood(degree_sequence, d_min):
    """
    Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution.
    """

    # Determine which degrees are to be considered, i.e. >= d_min.
    D_G = degree_sequence[(degree_sequence >= d_min.item())]
    try:
        sum_log_degrees = torch.log(D_G).sum()
    except:
        sum_log_degrees = np.log(D_G).sum()
    n = len(D_G)

    alpha = compute_alpha(n, sum_log_degrees, d_min)
    ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min)
    return ll, alpha, n, sum_log_degrees


def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min):
    """ Adopted from https://github.com/danielzuegner/nettack
    """
    # For each node pair, find out whether there is an edge in the input adjacency matrix.
    edge_entries_before = adjacency_matrix[node_pairs.T]
    degree_sequence = adjacency_matrix.sum(1)
    D_G = degree_sequence[degree_sequence >= d_min.item()]
    sum_log_degrees = torch.log(D_G).sum()
    n = len(D_G)
    deltas = -2 * edge_entries_before + 1
    d_edges_before = degree_sequence[node_pairs]

    d_edges_after = degree_sequence[node_pairs] + deltas[:, None]

    # Sum the log of the degrees after the potential changes which are >= d_min
    sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min)
    # Updated estimates of the Powerlaw exponents
    new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min)
    # Updated log likelihood values for the Powerlaw distributions
    new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min)
    return new_ll, new_alpha, new_n, sum_log_degrees_after


def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min):
    # Find out whether the degrees before and after the change are above the threshold d_min.
    old_in_range = d_old >= d_min
    new_in_range = d_new >= d_min
    d_old_in_range = d_old * old_in_range.float()
    d_new_in_range = d_new * new_in_range.float()

    # Update the sum by subtracting the old values and then adding the updated logs of the degrees.
    sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \
                            + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1)

    # Update the number of degrees >= d_min
    new_n = n_old - (old_in_range != 0).sum(1) + (new_in_range != 0).sum(1)
    new_n = new_n.float()
    return sum_log_degrees_after, new_n


def compute_alpha(n, sum_log_degrees, d_min):
    try:
        alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5))
    except:
        alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5))
    return alpha


def compute_log_likelihood(n, alpha, sum_log_degrees, d_min):
    # Log likelihood under alpha
    try:
        ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees
    except:
        ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees

    return ll
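# The two helpers above implement the maximum-likelihood power-law fit of
# Clauset et al. (2009), as used by Nettack: for degrees d_i >= d_min the
# exponent estimate is
#
#     alpha_hat = 1 + n / (sum_i log d_i - n * log(d_min - 1/2)),
#
# and likelihood_ratio_filter compares -2x log-likelihood ratios of the
# "combined" vs. "separate" fits against a small threshold to decide which
# edge flips remain statistically unnoticeable.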
def ravel_multiple_indices(ixs, shape, reverse=False):
    """
    "Flattens" multiple 2D input indices into indices on the flattened matrix,
    similar to np.ravel_multi_index. Does the same as ravel_index but for
    multiple indices at once.

    Parameters
    ----------
    ixs: array of ints shape (n, 2)
        The array of n indices that will be flattened.
    shape: list or tuple of ints of length 2
        The shape of the corresponding matrix.

    Returns
    -------
    array of n ints between 0 and shape[0]*shape[1]-1
        The indices on the flattened matrix corresponding to the 2D input indices.
    """
    if reverse:
        return ixs[:, 1] * shape[1] + ixs[:, 0]

    return ixs[:, 0] * shape[1] + ixs[:, 1]


def visualize(your_var):
    """Visualize the computation graph."""
    from graphviz import Digraph
    import torch
    from torch.autograd import Variable
    from torchviz import make_dot
    make_dot(your_var).view()


def reshape_mx(mx, shape):
    indices = mx.nonzero()
    return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape)


def add_mask(data, dataset):
    """data: ogb-arxiv pyg data format"""
    # for arxiv
    split_idx = dataset.get_idx_split()
    train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
    n = data.x.shape[0]
    data.train_mask = index_to_mask(train_idx, n)
    data.val_mask = index_to_mask(valid_idx, n)
    data.test_mask = index_to_mask(test_idx, n)
    data.y = data.y.squeeze()
    # data.edge_index = to_undirected(data.edge_index, data.num_nodes)


def index_to_mask(index, size):
    mask = torch.zeros((size, ), dtype=torch.bool)
    mask[index] = 1
    return mask


def add_feature_noise(data, noise_ratio, seed):
    np.random.seed(seed)
    n, d = data.x.shape
    # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1)
    noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio * n), d))).to(data.x.device)
    indices = np.arange(n)
    indices = np.random.permutation(indices)[: int(noise_ratio * n)]
    delta_feat = torch.zeros_like(data.x)
    delta_feat[indices] = noise - data.x[indices]
    data.x[indices] = noise
    mask = np.zeros(n)
    mask[indices] = 1
    mask = torch.tensor(mask).bool().to(data.x.device)
    return delta_feat, mask


def add_feature_noise_test(data, noise_ratio, seed):
    np.random.seed(seed)
    n, d = data.x.shape
    indices = np.arange(n)
    test_nodes = indices[data.test_mask.cpu()]
    selected = np.random.permutation(test_nodes)[: int(noise_ratio * len(test_nodes))]
    noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio * len(test_nodes)), d)))
    noise = noise.to(data.x.device)

    delta_feat = torch.zeros_like(data.x)
    delta_feat[selected] = noise - data.x[selected]
    data.x[selected] = noise
    # mask = np.zeros(len(test_nodes))
    mask = np.zeros(n)
    mask[selected] = 1
    mask = torch.tensor(mask).bool().to(data.x.device)
    return delta_feat, mask

# def check_path(file_path):
#     if not osp.exists(file_path):
#         os.system(f'mkdir -p {file_path}')
deeprobust/image/README.md
ADDED
@@ -0,0 +1,45 @@
# Setup
```
git clone https://github.com/DSE-MSU/DeepRobust.git
cd DeepRobust
python setup.py install
```

# Full README
[click here](https://github.com/DSE-MSU/DeepRobust)

# Attack Methods
| Attack Methods | Attack Type | Application Domain | Links |
|--------------------|-------------|--------------|------|
| LBFGS attack | White-Box | Image Classification | [Intriguing Properties of Neural Networks](https://arxiv.org/pdf/1312.6199.pdf?not-changed) |
| FGSM attack | White-Box | Image Classification | [Explaining and Harnessing Adversarial Examples](https://arxiv.org/pdf/1412.6572.pdf) |
| PGD attack | White-Box | Image Classification | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/pdf/1706.06083.pdf) |
| DeepFool attack | White-Box | Image Classification | [DeepFool: a simple and accurate method to fool deep neural networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Moosavi-Dezfooli_DeepFool_A_Simple_CVPR_2016_paper.pdf) |
| CW attack | White-Box | Image Classification | [Towards Evaluating the Robustness of Neural Networks](https://arxiv.org/pdf/1608.04644.pdf) |
| One pixel attack | Black-Box | Image Classification | [One pixel attack for fooling deep neural networks](https://arxiv.org/pdf/1710.08864.pdf) |
| BPDA attack | White-Box | Image Classification | [Obfuscated Gradients Give a False Sense of Security: Circumventing Defenses to Adversarial Examples](https://arxiv.org/pdf/1802.00420.pdf) |
| Universal attack | White-Box | Image Classification | [Universal adversarial perturbations](https://arxiv.org/pdf/1610.08401.pdf) |
| Nattack | Black-Box | Image Classification | [NATTACK: Learning the Distributions of Adversarial Examples for an Improved Black-Box Attack on Deep Neural Networks](https://arxiv.org/pdf/1905.00441.pdf) |

# Defense Methods
| Defense Methods | Defense Type | Application Domain | Links |
|---------------------|--------------|--------------|------|
| FGSM training | Adversarial Training | Image Classification | [Explaining and Harnessing Adversarial Examples](https://arxiv.org/pdf/1412.6572.pdf) |
| Fast (an improved version of FGSM training) | Adversarial Training | Image Classification | [Fast is better than free: Revisiting adversarial training](https://openreview.net/attachment?id=BJx040EFvH&name=original_pdf) |
| PGD training | Adversarial Training | Image Classification | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/pdf/1706.06083.pdf) |
| YOPO (an improved version of PGD training) | Adversarial Training | Image Classification | [You Only Propagate Once: Accelerating Adversarial Training via Maximal Principle](https://arxiv.org/pdf/1905.00877.pdf) |
| TRADES | Adversarial Training | Image Classification | [Theoretically Principled Trade-off between Robustness and Accuracy](https://arxiv.org/pdf/1901.08573.pdf) |
| Thermometer Encoding | Gradient Masking | Image Classification | [Thermometer Encoding: One Hot Way To Resist Adversarial Examples](https://openreview.net/pdf?id=S18Su--CW) |
| LID-based adversarial classifier | Detection | Image Classification | [Characterizing Adversarial Subspaces Using Local Intrinsic Dimensionality](https://arxiv.org/pdf/1801.02613.pdf) |

# Supported Datasets
- MNIST
- CIFAR-10
- ImageNet

# Supported Networks
- CNN
- ResNet (ResNet18, ResNet34, ResNet50)
- VGG
- DenseNet
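# Quick Start (sketch)
A minimal example of attacking a trained classifier with the FGSM implementation in this package; `model`, `image`, and `label` are hypothetical placeholders for a trained `torch.nn.Module`, an `[N, C, H, W]` float tensor, and the corresponding labels:

```python
from deeprobust.image.attack.fgsm import FGSM

attack = FGSM(model, device='cuda')
adv_image = attack.generate(image, label, epsilon=0.2)
```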
deeprobust/image/__init__.py
ADDED
@@ -0,0 +1,11 @@
import logging

from deeprobust.image import attack
from deeprobust.image import defense
from deeprobust.image import netmodels

__all__ = ['attack', 'defense', 'netmodels']

logging.info("import attack from image")
logging.info("import defense from image")
logging.info("import netmodels from image")
deeprobust/image/attack/Nattack.py
ADDED
@@ -0,0 +1,181 @@
import torch
from torch import optim
import numpy as np
import logging

from deeprobust.image.attack.base_attack import BaseAttack
from deeprobust.image.utils import onehot_like, arctanh


class NATTACK(BaseAttack):
    """
    Nattack is a black-box attack algorithm.
    """

    def __init__(self, model, device='cuda'):
        super(NATTACK, self).__init__(model, device)
        self.model = model
        self.device = device

    def generate(self, **kwargs):
        """
        Call this function to generate adversarial examples.

        Parameters
        ----------
        kwargs :
            user defined parameters
        """
        assert self.parse_params(**kwargs)
        return attack(self.model, self.dataloader, self.classnum,
                      self.clip_max, self.clip_min, self.epsilon,
                      self.population, self.max_iterations,
                      self.learning_rate, self.sigma, self.target_or_not)

    def parse_params(self,
                     dataloader,
                     classnum,
                     target_or_not=False,
                     clip_max=1,
                     clip_min=0,
                     epsilon=0.2,
                     population=300,
                     max_iterations=400,
                     learning_rate=2,
                     sigma=0.1):
        """parse_params.

        Parameters
        ----------
        dataloader :
            dataloader
        classnum :
            number of classes
        target_or_not :
            whether the attack is targeted
        clip_max :
            maximum pixel value
        clip_min :
            minimum pixel value
        epsilon :
            perturbation constraint
        population :
            NES population size
        max_iterations :
            maximum number of iterations
        learning_rate :
            learning rate
        sigma :
            standard deviation of the sampling noise
        """
        self.dataloader = dataloader
        self.classnum = classnum
        self.target_or_not = target_or_not
        self.clip_max = clip_max
        self.clip_min = clip_min
        self.epsilon = epsilon
        self.population = population
        self.max_iterations = max_iterations
        self.learning_rate = learning_rate
        self.sigma = sigma
        return True


def attack(model, loader, classnum, clip_max, clip_min, epsilon, population, max_iterations, learning_rate, sigma, target_or_not):

    # initialization
    totalImages = 0
    succImages = 0
    faillist = []
    successlist = []
    printlist = []

    for i, (inputs, targets) in enumerate(loader):

        success = False
        print('attack picture No. ' + str(i))

        c = inputs.size(1)  # channel
        l = inputs.size(2)  # length
        w = inputs.size(3)  # width

        mu = arctanh((inputs * 2) - 1)
        # mu = torch.from_numpy(np.random.randn(1, c, l, w) * 0.001).float()  # random initialize mean
        predict = model.forward(inputs)

        # skip wrongly classified samples
        if predict.argmax(dim=1, keepdim=True) != targets:
            print('skip the wrong example ', i)
            continue
        totalImages += 1

        # search for the most promising mean
        for runstep in range(max_iterations):

            # sample points from a normal distribution around mu
            eps = torch.from_numpy(np.random.randn(population, c, l, w)).float()
            z = mu.repeat(population, 1, 1, 1) + sigma * eps

            # calculate g_z, mapping the samples back to pixel space
            g_z = torch.tanh(z) * 1 / 2 + 1 / 2

            # test whether a successful attack exists every 10 iterations
            if runstep % 10 == 0:

                realdist = g_z - inputs

                realclipdist = torch.clamp(realdist, -epsilon, epsilon).float()
                realclipinput = realclipdist + inputs

                info = 'inputs.shape__' + str(inputs.shape)
                logging.debug(info)

                predict = model.forward(realclipinput)

                # untargeted attack
                if target_or_not == False:

                    if sum(predict.argmax(dim=1, keepdim=True)[0] != targets) > 0 and (realclipdist.abs().max() <= epsilon):
                        succImages += 1
                        success = True
                        print('succeed attack Images: ' + str(succImages) + ' totalImages: ' + str(totalImages))
                        print('steps: ' + str(runstep))
                        successlist.append(i)
                        printlist.append(runstep)
                        break

            # calculate distance and project onto the epsilon ball
            dist = g_z - inputs
            clipdist = torch.clamp(dist, -epsilon, epsilon)
            proj_g_z = inputs + clipdist
            proj_g_z = proj_g_z.float()
            outputs = model.forward(proj_g_z)

            # get CW loss on the sampled images
            target_onehot = np.zeros((1, classnum))
            target_onehot[0][targets] = 1.
            real = (target_onehot * outputs.detach().numpy()).sum(1)
            other = ((1. - target_onehot) * outputs.detach().numpy() - target_onehot * 10000.).max(1)
            loss1 = np.clip(real - other, a_min=0, a_max=1e10)
            Reward = 0.5 * loss1

            # update the mean by NES with normalized rewards
            A = ((Reward - np.mean(Reward)) / (np.std(Reward) + 1e-7))
            A = np.array(A, dtype=np.float32)

            # reshape the gradient estimate back to image shape (the original
            # version hardcoded MNIST's 1 x 28 x 28 here)
            mu = mu - torch.from_numpy((learning_rate / (population * sigma)) *
                                       ((np.dot(eps.reshape(population, -1).T, A)).reshape(1, c, l, w)))

        if not success:
            faillist.append(i)
            print('failed:', faillist.__len__())
            print('....................................')
        else:
            # print('succeed:', successlist.__len__())
            print('....................................')
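# Usage sketch (names here are placeholders, not part of the API):
#
#   attack = NATTACK(model, device='cuda')
#   attack.generate(dataloader=test_loader, classnum=10)
#
# The mean update above is the NES gradient estimator: with samples
# z_j = mu + sigma * eps_j and normalized rewards A_j, the gradient of the
# expected loss w.r.t. mu is approximated by
# (1 / (population * sigma)) * sum_j A_j * eps_j, which is exactly what the
# np.dot(...) line computes.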
deeprobust/image/attack/fgsm.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from numpy import linalg as LA

from deeprobust.image.attack.base_attack import BaseAttack


class FGSM(BaseAttack):
    """
    FGSM attack is a one-step gradient method.
    """

    def __init__(self, model, device='cuda'):
        super(FGSM, self).__init__(model, device)

    def generate(self, image, label, **kwargs):
        """
        Call this function to generate FGSM adversarial examples.

        Parameters
        ----------
        image :
            original image
        label :
            target label
        kwargs :
            user defined parameters
        """
        label = label.type(torch.FloatTensor)

        ## check and parse parameters for attack
        assert self.check_type_device(image, label)
        assert self.parse_params(**kwargs)

        return fgm(self.model,
                   self.image,
                   self.label,
                   self.epsilon,
                   self.order,
                   self.clip_min,
                   self.clip_max,
                   self.device)

    def parse_params(self,
                     epsilon=0.2,
                     order=np.inf,
                     clip_max=None,
                     clip_min=None):
        """
        Parse the user defined parameters.

        :param model: victim model
        :param image: original attack images
        :param label: target labels
        :param epsilon: perturbation constraint
        :param order: constraint type
        :param clip_min: minimum pixel value
        :param clip_max: maximum pixel value
        :param device: device type, cpu or gpu

        :type image: [N*C*H*W], floatTensor
        :type label: int
        :type epsilon: float
        :type order: int
        :type clip_min: float
        :type clip_max: float
        :type device: string('cpu' or 'cuda')

        :return: perturbed images
        :rtype: [N*C*H*W], floatTensor
        """
        self.epsilon = epsilon
        self.order = order
        self.clip_max = clip_max
        self.clip_min = clip_min
        return True


def fgm(model, image, label, epsilon, order, clip_min, clip_max, device):
    imageArray = image.cpu().detach().numpy()
    X_fgsm = torch.tensor(imageArray).to(device)

    X_fgsm.requires_grad = True

    opt = optim.SGD([X_fgsm], lr=1e-3)
    opt.zero_grad()

    loss = nn.CrossEntropyLoss()(model(X_fgsm), label)
    loss.backward()

    if order == np.inf:
        # L_inf constraint: step in the direction of the gradient sign
        d = epsilon * X_fgsm.grad.data.sign()
    elif order == 2:
        # L_2 constraint: step along the per-sample normalized gradient
        gradient = X_fgsm.grad
        d = torch.zeros(gradient.shape, device=device)
        for i in range(gradient.shape[0]):
            norm_grad = gradient[i].data / LA.norm(gradient[i].data.cpu().numpy())
            d[i] = norm_grad * epsilon
    else:
        raise ValueError('Other p norms may need other algorithms')

    x_adv = X_fgsm + d

    if clip_max is None and clip_min is None:
        clip_max = np.inf
        clip_min = -np.inf

    x_adv = torch.clamp(x_adv, clip_min, clip_max)

    return x_adv
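# Usage sketch (placeholders, assuming a trained classifier `model`):
#
#   attack = FGSM(model, device='cuda')
#   adv = attack.generate(image, label, epsilon=0.2, clip_min=0, clip_max=1)
#
# For order=np.inf this is the classic x_adv = x + epsilon * sign(grad_x loss);
# for order=2 the gradient is L2-normalized per sample before scaling.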
deeprobust/image/attack/onepixel.py
ADDED
@@ -0,0 +1,186 @@
import numpy as np

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

from deeprobust.image.optimizer import differential_evolution
from deeprobust.image.attack.base_attack import BaseAttack
from deeprobust.image.utils import progress_bar


class Onepixel(BaseAttack):
    """
    Onepixel attack is an algorithm that allows the attacker to manipulate only
    one (or a few) pixels to mislead the classifier.
    This is a re-implementation of the One pixel attack.
    Copyright (c) 2018 Debang Li

    References
    ----------
    Akhtar, N., & Mian, A. (2018). Threat of Adversarial Attacks on Deep Learning in Computer Vision: A Survey. IEEE Access, 6, 14410-14430.

    Reference code: https://github.com/DebangLi/one-pixel-attack-pytorch
    """

    def __init__(self, model, device='cuda'):
        super(Onepixel, self).__init__(model, device)

    def generate(self, image, label, **kwargs):
        """
        Call this function to generate Onepixel adversarial examples.

        Parameters
        ----------
        image : 1*3*W*H
            original image
        label :
            target label
        kwargs :
            user defined parameters
        """
        label = label.type(torch.FloatTensor)

        ## check and parse parameters for attack
        assert self.check_type_device(image, label)
        assert self.parse_params(**kwargs)

        return self.one_pixel(self.image,
                              self.label,
                              targeted_attack=self.targeted_attack,
                              target=self.target,
                              pixels=self.pixels,
                              maxiter=self.maxiter,
                              popsize=self.popsize,
                              print_log=self.print_log)

    def get_pred(self):
        return self.adv_pred

    def parse_params(self,
                     pixels=1,
                     maxiter=100,
                     popsize=400,
                     samples=100,
                     targeted_attack=False,
                     print_log=True,
                     target=0):
        """
        Parse the user-defined params.

        Parameters
        ----------
        pixels :
            maximum number of manipulated pixels
        maxiter :
            maximum number of iterations
        popsize :
            population size
        samples :
            samples
        targeted_attack :
            targeted attack or not
        print_log :
            Set print_log = True to print out details of the searching algorithm
        target :
            target label (if targeted_attack is set to True)
        """
        self.pixels = pixels
        self.maxiter = maxiter
        self.popsize = popsize
        self.samples = samples
        self.targeted_attack = targeted_attack
        self.print_log = print_log
        self.target = target
        return True

    def one_pixel(self, img, label, targeted_attack=False, target=0, pixels=1, maxiter=75, popsize=400, print_log=False):
        # label: a number

        target_class = target if targeted_attack else label

        # each candidate pixel is encoded as (x, y, r, g, b)
        bounds = [(0, 32), (0, 32), (0, 255), (0, 255), (0, 255)] * pixels

        popmul = max(1, popsize // len(bounds))

        # DE minimizes the fitness: for an untargeted attack we minimize the
        # true-class probability, for a targeted attack we minimize
        # 1 - p(target); hence minimize = not targeted_attack.
        predict_fn = lambda xs: predict_classes(
            xs, img, target_class, self.model, not targeted_attack, self.device)
        callback_fn = lambda x, convergence: attack_success(
            x, img, target_class, self.model, targeted_attack, print_log, self.device)

        inits = np.zeros([popmul * len(bounds), len(bounds)])
        for init in inits:
            for i in range(pixels):
                init[i * 5 + 0] = np.random.random() * 32
                init[i * 5 + 1] = np.random.random() * 32
                init[i * 5 + 2] = np.random.normal(128, 127)
                init[i * 5 + 3] = np.random.normal(128, 127)
                init[i * 5 + 4] = np.random.normal(128, 127)

        attack_result = differential_evolution(predict_fn, bounds, maxiter=maxiter, popsize=popmul,
                                               recombination=1, atol=-1, callback=callback_fn, polish=False, init=inits)

        attack_image = perturb_image(attack_result.x, img)
        with torch.no_grad():
            attack_var = attack_image.to(self.device)
            predicted_probs = F.softmax(self.model(attack_var), dim=1).data.cpu().numpy()[0]

        predicted_class = np.argmax(predicted_probs)

        if (not targeted_attack and predicted_class != label) or (targeted_attack and predicted_class == target_class):
            self.adv_pred = predicted_class
            return attack_image
        return [None]


def perturb_image(xs, img):

    if xs.ndim < 2:
        xs = np.array([xs])
    batch = len(xs)
    imgs = img.repeat(batch, 1, 1, 1)
    xs = xs.astype(int)

    count = 0

    for x in xs:
        pixels = np.split(x, len(x) // 5)
        for pixel in pixels:
            x_pos, y_pos, r, g, b = pixel
            # write the (r, g, b) values with CIFAR-10 channel normalization
            imgs[count, 0, x_pos, y_pos] = (r / 255.0 - 0.4914) / 0.2023
            imgs[count, 1, x_pos, y_pos] = (g / 255.0 - 0.4822) / 0.1994
            imgs[count, 2, x_pos, y_pos] = (b / 255.0 - 0.4465) / 0.2010
        count += 1

    return imgs


def predict_classes(xs, img, target_class, net, minimize=True, device='cuda'):
    imgs_perturbed = perturb_image(xs, img.clone()).to(device)
    predictions = F.softmax(net(imgs_perturbed), dim=1).data.cpu().numpy()[:, target_class]

    return predictions if minimize else 1 - predictions


def attack_success(x, img, target_class, net, targeted_attack=False, print_log=False, device='cuda'):

    attack_image = perturb_image(x, img.clone()).to(device)
    confidence = F.softmax(net(attack_image), dim=1).data.cpu().numpy()[0]
    pred = np.argmax(confidence)

    if print_log:
        print("Confidence: %.4f" % confidence[target_class])
    if (targeted_attack and pred == target_class) or (not targeted_attack and pred != target_class):
        return True
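# Usage sketch (placeholders; expects a CIFAR-10-normalized 1x3x32x32 image):
#
#   attack = Onepixel(model, device='cuda')
#   adv = attack.generate(image, label, pixels=1, maxiter=100, popsize=400)
#
# Each differential-evolution candidate encodes one pixel as (x, y, r, g, b),
# so a k-pixel attack searches a 5k-dimensional box.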
deeprobust/image/defense/AWP.py
ADDED
@@ -0,0 +1,301 @@
"""
This is an implementation of PGD adversarial training with Adversarial Weight
Perturbation (AWP).
References
----------
.. [1] Madry, A., Makelov, A., Schmidt, L., Tsipras, D., & Vladu, A. (2017).
Towards Deep Learning Models Resistant to Adversarial Attacks. stat, 1050, 9.
.. [2] Wu, D., Xia, S.-T., & Wang, Y. (2020).
Adversarial Weight Perturbation Helps Robust Generalization. NeurIPS.
"""

import os
import copy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F

import numpy as np
from PIL import Image
from deeprobust.image.attack.pgd import PGD
from deeprobust.image.netmodels.CNN import Net
from deeprobust.image.defense.base_defense import BaseDefense

EPS = 1E-20


def diff_in_weights(model, proxy):
    # Layer-wise normalized weight difference between the model and its proxy.
    diff_dict = OrderedDict()
    model_state_dict = model.state_dict()
    proxy_state_dict = proxy.state_dict()
    for (old_k, old_w), (new_k, new_w) in zip(model_state_dict.items(), proxy_state_dict.items()):
        if len(old_w.size()) <= 1:
            continue
        if 'weight' in old_k:
            diff_w = new_w - old_w
            diff_dict[old_k] = old_w.norm() / (diff_w.norm() + EPS) * diff_w
    return diff_dict


def add_into_weights(model, diff, coeff=1.0):
    names_in_diff = diff.keys()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in names_in_diff:
                param.add_(coeff * diff[name])
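# AWP in one line: find a weight perturbation v (computed above as a
# layer-wise normalized difference, scaled by gamma) that locally maximizes
# the adversarial loss, train on the perturbed weights w + v, then remove v.
# The per-layer scaling ||w|| / ||delta_w|| keeps the perturbation size
# relative to each layer's own weight norm.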
class pgd_AWP(object):
    """
    Compute and apply the adversarial weight perturbation for PGD-AWP training.
    """
    def __init__(self, model, proxy, proxy_optim, gamma):
        super(pgd_AWP, self).__init__()
        self.model = model
        self.proxy = proxy
        self.proxy_optim = proxy_optim
        self.gamma = gamma

    def calc_awp(self, adv_samples, labels):
        self.proxy.load_state_dict(self.model.state_dict())
        self.proxy.train()

        # compute the adversarial loss on the proxy and ascend on it
        logits_adv = self.proxy(adv_samples)
        loss = -1 * F.cross_entropy(logits_adv, labels)

        self.proxy_optim.zero_grad()
        loss.backward()
        self.proxy_optim.step()

        # the adversarial weight perturbation
        diff = diff_in_weights(self.model, self.proxy)
        return diff

    def perturb(self, diff):
        add_into_weights(self.model, diff, coeff=1.0 * self.gamma)

    def restore(self, diff):
        add_into_weights(self.model, diff, coeff=-1.0 * self.gamma)


class AWP_AT(BaseDefense):
    """
    PGD adversarial training with adversarial weight perturbation.
    """

    def __init__(self, model, device):
        if not torch.cuda.is_available():
            print('CUDA not available, using cpu...')
            self.device = 'cpu'
        else:
            self.device = device

        self.model = model

    def generate(self, train_loader, test_loader, **kwargs):
        """Call this function to generate a robust model.

        Parameters
        ----------
        train_loader :
            training data loader
        test_loader :
            testing data loader
        kwargs :
            kwargs
        """
        self.parse_params(**kwargs)

        torch.manual_seed(100)
        device = torch.device(self.device)

        optimizer = optim.Adam(self.model.parameters(), self.lr)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 100], gamma=0.1)
        save_model = True
        for epoch in range(1, self.epoch + 1):
            print('Training epoch: ', epoch, flush=True)

            self.train(self.device, train_loader, optimizer, epoch)
            self.test(self.model, self.device, test_loader)

            if (self.save_model and epoch % self.save_per_epoch == 0):
                if os.path.isdir(str(self.save_dir)):
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name + '_epoch' + str(epoch) + '.pth'))
                    print("model saved in " + str(self.save_dir))
                else:
                    print("make new directory and save model in " + str(self.save_dir))
                    os.mkdir('./' + str(self.save_dir))
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name + '_epoch' + str(epoch) + '.pth'))

            lr_scheduler.step()

        return self.model

    def parse_params(self,
                     epoch_num=100,
                     save_dir="./defense_models",
                     save_name="AWP_pgdtraining_0.3",
                     save_model=True,
                     epsilon=8.0 / 255.0,
                     num_steps=10,
                     perturb_step_size=0.01,
                     lr=0.1,
                     momentum=0.1,
                     awp_gamma=0.005,
                     save_per_epoch=10):
        """Parameter parser.

        Parameters
        ----------
        epoch_num : int
            number of training epochs
        save_dir : str
            model directory
        save_name : str
            model name
        save_model : bool
            whether to save the model
        epsilon : float
            attack constraint
        num_steps : int
            number of PGD attack iterations
        perturb_step_size : float
            perturb step size
        lr : float
            learning rate for the adversarial training process
        momentum : float
            momentum for the optimizer
        awp_gamma : float
            scale of the adversarial weight perturbation
        """
        self.epoch = epoch_num
        self.save_model = True
        self.save_dir = save_dir
        self.save_name = save_name
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.perturb_step_size = perturb_step_size
        self.lr = lr
        self.momentum = momentum
        self.awp_gamma = awp_gamma
        self.save_per_epoch = save_per_epoch

    def train(self, device, train_loader, optimizer, epoch):
        """
        Training process.

        Parameters
        ----------
        device :
            device
        train_loader :
            training data loader
        optimizer :
            optimizer
        epoch :
            training epoch
        """
        self.model.train()
        correct = 0
        bs = train_loader.batch_size
        # scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[70], gamma=0.1)
        # the proxy is a throwaway copy of the model used only to search for
        # the weight perturbation; the proxy learning rate is an assumed default
        proxy = copy.deepcopy(self.model)
        proxy_optim = optim.SGD(proxy.parameters(), lr=0.01)
        awp_adversary = pgd_AWP(model=self.model, proxy=proxy, proxy_optim=proxy_optim, gamma=self.awp_gamma)

        for batch_idx, (data, target) in enumerate(train_loader):

            optimizer.zero_grad()

            data, target = data.to(device), target.to(device)

            data_adv, output = self.adv_data(data, target, ep=self.epsilon, num_steps=self.num_steps, perturb_step_size=self.perturb_step_size)

            # perturb the weights adversarially, ...
            awp = awp_adversary.calc_awp(adv_samples=data_adv, labels=target)
            awp_adversary.perturb(awp)

            # ... train on the adversarial examples under the perturbed weights, ...
            output = self.model(data_adv)
            loss = self.calculate_loss(output, target)

            loss.backward()
            optimizer.step()

            # ... then remove the weight perturbation after the update
            awp_adversary.restore(awp)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

            # print every 20 batches
            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(), 100 * correct / bs))
                correct = 0

        scheduler.step()

    def test(self, model, device, test_loader):
        """
        Testing process.

        Parameters
        ----------
        model :
            model
        device :
            device
        test_loader :
            testing dataloader
        """
        model.eval()

        test_loss = 0
        correct = 0
        test_loss_adv = 0
        correct_adv = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # clean accuracy
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # adversarial accuracy
            data_adv, output_adv = self.adv_data(data, target, ep=self.epsilon, num_steps=self.num_steps)

            test_loss_adv += self.calculate_loss(output_adv, target, redmode='sum').item()  # sum up batch loss
            pred_adv = output_adv.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct_adv += pred_adv.eq(target.view_as(pred_adv)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_loss_adv /= len(test_loader.dataset)

        print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

        print('\nTest set: Adv loss: {:.3f}, Adv Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss_adv, correct_adv, len(test_loader.dataset),
            100. * correct_adv / len(test_loader.dataset)))

    def adv_data(self, data, output, ep=0.3, num_steps=10, perturb_step_size=0.01):
        """
        Generate adversarial input data for training.
        """
        adversary = PGD(self.model)
        data_adv = adversary.generate(data, output.flatten(), epsilon=ep, num_steps=num_steps, step_size=perturb_step_size)
        output = self.model(data_adv)

        return data_adv, output

    def calculate_loss(self, output, target, redmode='mean'):
        """
        Calculate the loss for training.
        """
        loss = F.cross_entropy(output, target, reduction=redmode)
        return loss
deeprobust/image/defense/TherEncoding.py
ADDED
@@ -0,0 +1,203 @@
"""
This is an implementation of Thermometer Encoding.

References
----------
.. [1] Buckman, Jacob, Aurko Roy, Colin Raffel, and Ian Goodfellow. "Thermometer encoding: One hot way to resist adversarial examples." In International Conference on Learning Representations. 2018.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from torchvision import datasets, transforms
from deeprobust.image.netmodels.CNN import Net

import logging

## TODO
# class ther_attack(pgd_attack):
#     """
#     PGD attacks in response to thermometer encoding models
#     """
## TODO
# def adv_train():
#     """
#     adversarial training for thermometer encoding
#     """


def train(model, device, train_loader, optimizer, epoch):
    """Training process.

    Parameters
    ----------
    model :
        model
    device :
        device
    train_loader :
        training data loader
    optimizer :
        optimizer
    epoch :
        epoch
    """
    logger.info('training')
    model.train()
    correct = 0
    bs = train_loader.batch_size

    for batch_idx, (data, target) in enumerate(train_loader):

        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)

        # encode pixels, then fold the encoding levels into the channel dim
        encoding = Thermometer(data, LEVELS)
        encoding = encoding.permute(0, 2, 3, 1, 4)
        encoding = torch.flatten(encoding, start_dim=3)
        encoding = encoding.permute(0, 3, 1, 2)

        output = model(encoding)

        loss = F.nll_loss(output, target)
        loss.backward()

        optimizer.step()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        # print every 10 batches
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), 100 * correct / (10 * bs)))
            correct = 0


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            encoding = Thermometer(data, LEVELS)
            encoding = encoding.permute(0, 2, 3, 1, 4)
            encoding = torch.flatten(encoding, start_dim=3)
            encoding = encoding.permute(0, 3, 1, 2)

            # clean accuracy
            output = model(encoding)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def Thermometer(x, levels, flattened=False):
    """
    Output
    ------
    Thermometer Encoding of the input.
    """
    onehot = one_hot(x, levels)

    thermometer = one_hot_to_thermometer(onehot, levels)

    return thermometer


def one_hot(x, levels):
    """
    Output
    ------
    One hot Encoding of the input.
    """
    batch_size, channel, H, W = x.size()
    x = x.unsqueeze_(4)
    # quantize pixel values into `levels` buckets (the original used the
    # global LEVELS here instead of the `levels` parameter)
    x = torch.ceil(x * (levels - 1)).long()
    onehot = torch.zeros(batch_size, channel, H, W, levels).float().to(x.device).scatter_(4, x, 1)

    return onehot


def one_hot_to_thermometer(x, levels, flattened=False):
    """
    Convert One hot Encoding to Thermometer Encoding.
    """
    if flattened:
        pass
        # TODO: check how to flatten

    thermometer = torch.cumsum(x, dim=4)

    if flattened:
        pass
    return thermometer
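# Worked example: with levels=10, a pixel value of 0.25 maps to level
# ceil(0.25 * 9) = 3, whose one-hot vector is [0,0,0,1,0,0,0,0,0,0]; the
# cumulative sum along the level axis turns it into the thermometer code
# [0,0,0,1,1,1,1,1,1,1]. Discretizing inputs this way masks the gradients
# an attacker would otherwise follow.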
if __name__ == '__main__':

    logger = logging.getLogger('Thermometer Encoding')

    handler = logging.StreamHandler()  # handler for the logger
    handler.setFormatter(logging.Formatter('%(asctime)s'))
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info('Start attack.')

    torch.manual_seed(100)
    device = torch.device("cuda")

    logger.info('Load trainset.')
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('deeprobust/image/data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=100,
        shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('deeprobust/image/data', train=False,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=1000,
        shuffle=True)

    # TODO: change the channel according to the dataset.
    LEVELS = 10
    channel = 1
    model = Net(in_channel1=channel * LEVELS, out_channel1=32 * LEVELS, out_channel2=64 * LEVELS).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.2)
    logger.info('Load model.')

    save_model = True
    for epoch in range(1, 50 + 1):
        print('Running epoch ', epoch)

        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

        if save_model:
            torch.save(model.state_dict(), "deeprobust/image/save_models/thermometer_encoding.pt")
deeprobust/image/defense/YOPO.py
ADDED
@@ -0,0 +1,410 @@
"""
This is an implementation of the adversarial training variant YOPO.

References
----------
.. [1] Zhang, D., Zhang, T., Lu, Y., Zhu, Z., & Dong, B. (2019).
You only propagate once: Painless adversarial training using maximal principle.
arXiv preprint arXiv:1905.00877.

.. [2] Original code: https://github.com/a1600012888/YOPO-You-Only-Propagate-Once
"""
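# A rough intuition for YOPO (see [1]): instead of back-propagating through
# the whole network for every PGD step, the adversarial update is decoupled so
# that only the first layer (via a Hamiltonian term, cf. the Hamiltonian class
# below) is repeatedly differentiated, while full back-propagation happens
# once per outer iteration; this is what makes the inner attack loop cheap.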
import sys
import os
import time
import math
import json
import argparse
from collections import OrderedDict
from typing import Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.nn.modules.loss import _Loss
from tqdm import tqdm
from tensorboardX import SummaryWriter

from deeprobust.image.netmodels import YOPOCNN
from deeprobust.image import utils
from deeprobust.image.attack import YOPOpgd
from deeprobust.image.defense.base_defense import BaseDefense


class PieceWiseConstantLrSchedulerMaker(object):

    def __init__(self, milestones: List[int], gamma: float = 0.1):
        self.milestones = milestones
        self.gamma = gamma

    def __call__(self, optimizer):
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.milestones, gamma=self.gamma)


class IPGDAttackMethodMaker(object):

    def __init__(self, eps, sigma, nb_iters, norm, mean, std):
        self.eps = eps
        self.sigma = sigma
        self.nb_iters = nb_iters
        self.norm = norm
        self.mean = mean
        self.std = std

    def __call__(self, DEVICE):
        return YOPOpgd.FASTPGD(self.eps, self.sigma, self.nb_iters, self.norm, DEVICE, self.mean, self.std)


def torch_accuracy(output, target, topk=(1,)) -> List[torch.Tensor]:
    '''
    param output, target: should be torch Variable
    '''
    topn = max(topk)
    batch_size = output.size(0)

    _, pred = output.topk(topn, 1, True, True)
    pred = pred.t()

    is_correct = pred.eq(target.view(1, -1).expand_as(pred))

    ans = []
    for i in topk:
        is_correct_i = is_correct[:i].view(-1).float().sum(0, keepdim=True)
        ans.append(is_correct_i.mul_(100.0 / batch_size))

    return ans


class AvgMeter(object):
    name = 'No name'

    def __init__(self, name='No name'):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0
        self.mean = 0
        self.num = 0
        self.now = 0

    def update(self, mean_var, count=1):
        if math.isnan(mean_var):
            mean_var = 1e6
            print('Avgmeter getting Nan!')
        self.now = mean_var
        self.num += count

        self.sum += mean_var * count
        self.mean = float(self.sum) / self.num


def load_checkpoint(file_name, net=None, optimizer=None, lr_scheduler=None):
    if os.path.isfile(file_name):
        print("=> loading checkpoint '{}'".format(file_name))
        check_point = torch.load(file_name)
        if net is not None:
            print('Loading network state dict')
            net.load_state_dict(check_point['state_dict'])
        if optimizer is not None:
            print('Loading optimizer state dict')
            optimizer.load_state_dict(check_point['optimizer_state_dict'])
        if lr_scheduler is not None:
            print('Loading lr_scheduler state dict')
            lr_scheduler.load_state_dict(check_point['lr_scheduler_state_dict'])

        return check_point['epoch']
    else:
        print("=> no checkpoint found at '{}'".format(file_name))


def make_symlink(source, link_name):
    if os.path.exists(link_name):
        # print("Link name already exists! Removing '{}' and overwriting".format(link_name))
        os.remove(link_name)
    if os.path.exists(source):
        os.symlink(source, link_name)
        return
    else:
        print('Source path does not exist')


def add_path(path):
    if path not in sys.path:
        print('Adding {}'.format(path))
        sys.path.append(path)


class Hamiltonian(_Loss):

    def __init__(self, layer, reg_cof=1e-4):
|
143 |
+
super(Hamiltonian, self).__init__()
|
144 |
+
self.layer = layer
|
145 |
+
self.reg_cof = 0
|
146 |
+
|
147 |
+
def forward(self, x, p):
|
148 |
+
y = self.layer(x)
|
149 |
+
H = torch.sum(y * p)
|
150 |
+
return H
|
151 |
+
|
152 |
+
class CrossEntropyWithWeightPenlty(_Loss):
|
153 |
+
def __init__(self, module, DEVICE, reg_cof = 1e-4):
|
154 |
+
super(CrossEntropyWithWeightPenlty, self).__init__()
|
155 |
+
|
156 |
+
self.reg_cof = reg_cof
|
157 |
+
self.criterion = nn.CrossEntropyLoss().to(DEVICE)
|
158 |
+
self.module = module
|
159 |
+
|
160 |
+
def __call__(self, pred, label):
|
161 |
+
cross_loss = self.criterion(pred, label)
|
162 |
+
weight_loss = cal_l2_norm(self.module)
|
163 |
+
|
164 |
+
loss = cross_loss + self.reg_cof * weight_loss
|
165 |
+
return loss
|
166 |
+
|
167 |
+
def cal_l2_norm(layer: torch.nn.Module):
|
168 |
+
loss = 0.
|
169 |
+
for name, param in layer.named_parameters():
|
170 |
+
if name == 'weight':
|
171 |
+
loss = loss + 0.5 * torch.norm(param,) ** 2
|
172 |
+
|
173 |
+
return loss
|
174 |
+
|
175 |
+
class FastGradientLayerOneTrainer(object):
|
176 |
+
|
177 |
+
def __init__(self, Hamiltonian_func, param_optimizer,
|
178 |
+
inner_steps=2, sigma = 0.008, eps = 0.03):
|
179 |
+
self.inner_steps = inner_steps
|
180 |
+
self.sigma = sigma
|
181 |
+
self.eps = eps
|
182 |
+
self.Hamiltonian_func = Hamiltonian_func
|
183 |
+
self.param_optimizer = param_optimizer
|
184 |
+
|
185 |
+
def step(self, inp, p, eta):
|
186 |
+
p = p.detach()
|
187 |
+
|
188 |
+
for i in range(self.inner_steps):
|
189 |
+
tmp_inp = inp + eta
|
190 |
+
tmp_inp = torch.clamp(tmp_inp, 0, 1)
|
191 |
+
H = self.Hamiltonian_func(tmp_inp, p)
|
192 |
+
|
193 |
+
eta_grad_sign = torch.autograd.grad(H, eta, only_inputs=True, retain_graph=False)[0].sign()
|
194 |
+
|
195 |
+
eta = eta - eta_grad_sign * self.sigma
|
196 |
+
|
197 |
+
eta = torch.clamp(eta, -1.0 * self.eps, self.eps)
|
198 |
+
eta = torch.clamp(inp + eta, 0.0, 1.0) - inp
|
199 |
+
eta = eta.detach()
|
200 |
+
eta.requires_grad_()
|
201 |
+
eta.retain_grad()
|
202 |
+
|
203 |
+
#self.param_optimizer.zero_grad()
|
204 |
+
|
205 |
+
yofo_inp = eta + inp
|
206 |
+
yofo_inp = torch.clamp(yofo_inp, 0, 1)
|
207 |
+
|
208 |
+
loss = -1.0 * self.Hamiltonian_func(yofo_inp, p)
|
209 |
+
|
210 |
+
loss.backward()
|
211 |
+
#self.param_optimizer.step()
|
212 |
+
#self.param_optimizer.zero_grad()
|
213 |
+
|
214 |
+
return yofo_inp, eta
|
215 |
+
|
216 |
+
def eval_one_epoch(net, batch_generator, DEVICE=torch.device('cuda:0'), AttackMethod = None):
|
217 |
+
net.eval()
|
218 |
+
pbar = tqdm(batch_generator)
|
219 |
+
clean_accuracy = AvgMeter()
|
220 |
+
adv_accuracy = AvgMeter()
|
221 |
+
|
222 |
+
pbar.set_description('Evaluating')
|
223 |
+
for (data, label) in pbar:
|
224 |
+
data = data.to(DEVICE)
|
225 |
+
label = label.to(DEVICE)
|
226 |
+
|
227 |
+
with torch.no_grad():
|
228 |
+
pred = net(data)
|
229 |
+
acc = torch_accuracy(pred, label, (1,))
|
230 |
+
clean_accuracy.update(acc[0].item())
|
231 |
+
|
232 |
+
if AttackMethod is not None:
|
233 |
+
adv_inp = AttackMethod.attack(net, data, label)
|
234 |
+
|
235 |
+
with torch.no_grad():
|
236 |
+
pred = net(adv_inp)
|
237 |
+
acc = torch_accuracy(pred, label, (1,))
|
238 |
+
adv_accuracy.update(acc[0].item())
|
239 |
+
|
240 |
+
pbar_dic = OrderedDict()
|
241 |
+
pbar_dic['CleanAcc'] = '{:.2f}'.format(clean_accuracy.mean)
|
242 |
+
pbar_dic['AdvAcc'] = '{:.2f}'.format(adv_accuracy.mean)
|
243 |
+
|
244 |
+
pbar.set_postfix(pbar_dic)
|
245 |
+
|
246 |
+
adv_acc = adv_accuracy.mean if AttackMethod is not None else 0
|
247 |
+
return clean_accuracy.mean, adv_acc
|
248 |
+
|
249 |
+
|
250 |
+
class SGDOptimizerMaker(object):
|
251 |
+
|
252 |
+
def __init__(self, lr = 0.1, momentum = 0.9, weight_decay = 1e-4):
|
253 |
+
self.lr = lr
|
254 |
+
self.momentum = momentum
|
255 |
+
self.weight_decay = weight_decay
|
256 |
+
|
257 |
+
def __call__(self, params):
|
258 |
+
return torch.optim.SGD(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
|
259 |
+
|
260 |
+
def main():
|
261 |
+
num_epochs = 40
|
262 |
+
val_interval = 1
|
263 |
+
weight_decay = 5e-4
|
264 |
+
|
265 |
+
inner_iters = 10
|
266 |
+
K = 5
|
267 |
+
sigma = 0.01
|
268 |
+
eps = 0.3
|
269 |
+
lr = 1e-2
|
270 |
+
momentum = 0.9
|
271 |
+
create_optimizer = SGDOptimizerMaker(lr =1e-2 / K, momentum = 0.9, weight_decay = weight_decay)
|
272 |
+
|
273 |
+
create_lr_scheduler = PieceWiseConstantLrSchedulerMaker(milestones = [30, 35, 39], gamma = 0.1)
|
274 |
+
|
275 |
+
create_loss_function = None
|
276 |
+
|
277 |
+
create_attack_method = None
|
278 |
+
|
279 |
+
create_evaluation_attack_method = IPGDAttackMethodMaker(eps = 0.3, sigma = 0.01, nb_iters = 40, norm = np.inf,
|
280 |
+
mean=torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]),
|
281 |
+
std=torch.tensor(np.array([1]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]))
|
282 |
+
|
283 |
+
parser = argparse.ArgumentParser()
|
284 |
+
|
285 |
+
parser.add_argument('--model_dir',default = "./trained_models")
|
286 |
+
parser.add_argument('--resume', default=None, type=str, metavar='PATH',
|
287 |
+
help='path to latest checkpoint (default: none)')
|
288 |
+
parser.add_argument('-b', '--batch_size', default=256, type=int,
|
289 |
+
metavar='N', help='mini-batch size')
|
290 |
+
parser.add_argument('-d', type=int, default=0, help='Which gpu to use')
|
291 |
+
parser.add_argument('-adv_coef', default=1.0, type = float,
|
292 |
+
help = 'Specify the weight for adversarial loss')
|
293 |
+
parser.add_argument('--auto-continue', default=False, action = 'store_true',
|
294 |
+
help = 'Continue from the latest checkpoint')
|
295 |
+
args = parser.parse_args()
|
296 |
+
|
297 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
298 |
+
|
299 |
+
net = YOPOCNN.Net()
|
300 |
+
net.to(DEVICE)
|
301 |
+
criterion = CrossEntropyWithWeightPenlty(net.other_layers, DEVICE, weight_decay)#.to(DEVICE)
|
302 |
+
optimizer = create_optimizer(net.other_layers.parameters())
|
303 |
+
lr_scheduler = create_lr_scheduler(optimizer)
|
304 |
+
|
305 |
+
Hamiltonian_func = Hamiltonian(net.layer_one, weight_decay)
|
306 |
+
layer_one_optimizer = optim.SGD(net.layer_one.parameters(), lr = lr_scheduler.get_lr()[0], momentum=0.9, weight_decay=5e-4)
|
307 |
+
lyaer_one_optimizer_lr_scheduler = optim.lr_scheduler.MultiStepLR(layer_one_optimizer,
|
308 |
+
milestones = [15, 19], gamma = 0.1)
|
309 |
+
LayerOneTrainer = FastGradientLayerOneTrainer(Hamiltonian_func, layer_one_optimizer,
|
310 |
+
inner_iters, sigma, eps)
|
311 |
+
|
312 |
+
ds_train = utils.create_train_dataset(args.batch_size)
|
313 |
+
ds_val = utils.create_test_dataset(args.batch_size)
|
314 |
+
|
315 |
+
EvalAttack = create_evaluation_attack_method(DEVICE)
|
316 |
+
|
317 |
+
now_epoch = 0
|
318 |
+
|
319 |
+
if args.auto_continue:
|
320 |
+
args.resume = os.path.join(args.model_dir, 'last.checkpoint')
|
321 |
+
if args.resume is not None and os.path.isfile(args.resume):
|
322 |
+
now_epoch = load_checkpoint(args.resume, net, optimizer,lr_scheduler)
|
323 |
+
|
324 |
+
now_train_time = 0
|
325 |
+
while True:
|
326 |
+
if now_epoch > num_epochs:
|
327 |
+
break
|
328 |
+
now_epoch = now_epoch + 1
|
329 |
+
|
330 |
+
descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(now_epoch, num_epochs,
|
331 |
+
lr_scheduler.get_lr()[0])
|
332 |
+
s_time = time.time()
|
333 |
+
|
334 |
+
#train
|
335 |
+
acc, yopoacc = train_one_epoch(net, ds_train, optimizer, eps, criterion, LayerOneTrainer, K,
|
336 |
+
DEVICE, descrip_str)
|
337 |
+
|
338 |
+
now_train_time = now_train_time + time.time() - s_time
|
339 |
+
tb_train_dic = {'Acc':acc, 'YoPoAcc':yopoacc}
|
340 |
+
print(tb_train_dic)
|
341 |
+
|
342 |
+
lr_scheduler.step()
|
343 |
+
lyaer_one_optimizer_lr_scheduler.step()
|
344 |
+
utils.save_checkpoint(now_epoch, net, optimizer, lr_scheduler,
|
345 |
+
file_name = os.path.join(args.model_dir, 'epoch-{}.checkpoint'.format(now_epoch)))
|
346 |
+
|
347 |
+
def train_one_epoch(net, batch_generator, optimizer, eps,
|
348 |
+
criterion, LayerOneTrainner, K,
|
349 |
+
DEVICE=torch.device('cuda:0'),descrip_str='Training'):
|
350 |
+
'''
|
351 |
+
:param attack_freq: Frequencies of training with adversarial examples. -1 indicates natural training
|
352 |
+
:param AttackMethod: the attack method, None represents natural training
|
353 |
+
:return: None #(clean_acc, adv_acc)
|
354 |
+
'''
|
355 |
+
net.train()
|
356 |
+
pbar = tqdm(batch_generator)
|
357 |
+
yofoacc = -1
|
358 |
+
cleanacc = -1
|
359 |
+
cleanloss = -1
|
360 |
+
pbar.set_description(descrip_str)
|
361 |
+
for i, (data, label) in enumerate(pbar):
|
362 |
+
data = data.to(DEVICE)
|
363 |
+
label = label.to(DEVICE)
|
364 |
+
|
365 |
+
eta = torch.FloatTensor(*data.shape).uniform_(-eps, eps)
|
366 |
+
eta = eta.to(label.device)
|
367 |
+
eta.requires_grad_()
|
368 |
+
|
369 |
+
optimizer.zero_grad()
|
370 |
+
LayerOneTrainner.param_optimizer.zero_grad()
|
371 |
+
|
372 |
+
for j in range(K):
|
373 |
+
pbar_dic = OrderedDict()
|
374 |
+
TotalLoss = 0
|
375 |
+
|
376 |
+
pred = net(data + eta.detach())
|
377 |
+
|
378 |
+
loss = criterion(pred, label)
|
379 |
+
TotalLoss = TotalLoss + loss
|
380 |
+
wgrad = net.conv1.weight.grad
|
381 |
+
TotalLoss.backward()
|
382 |
+
net.conv1.weight.grad = wgrad
|
383 |
+
|
384 |
+
|
385 |
+
p = -1.0 * net.layer_one_out.grad
|
386 |
+
yofo_inp, eta = LayerOneTrainner.step(data, p, eta)
|
387 |
+
|
388 |
+
with torch.no_grad():
|
389 |
+
if j == 0:
|
390 |
+
acc = torch_accuracy(pred, label, (1,))
|
391 |
+
cleanacc = acc[0].item()
|
392 |
+
cleanloss = loss.item()
|
393 |
+
|
394 |
+
if j == K - 1:
|
395 |
+
yofo_pred = net(yofo_inp)
|
396 |
+
yofoacc = torch_accuracy(yofo_pred, label, (1,))[0].item()
|
397 |
+
|
398 |
+
optimizer.step()
|
399 |
+
LayerOneTrainner.param_optimizer.step()
|
400 |
+
optimizer.zero_grad()
|
401 |
+
LayerOneTrainner.param_optimizer.zero_grad()
|
402 |
+
pbar_dic['Acc'] = '{:.2f}'.format(cleanacc)
|
403 |
+
pbar_dic['loss'] = '{:.2f}'.format(cleanloss)
|
404 |
+
pbar_dic['YoPoAcc'] = '{:.2f}'.format(yofoacc)
|
405 |
+
pbar.set_postfix(pbar_dic)
|
406 |
+
|
407 |
+
return cleanacc, yofoacc
|
408 |
+
|
409 |
+
if __name__ == "__main__":
|
410 |
+
main()
|
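The core trick in train_one_epoch and FastGradientLayerOneTrainer.step above is that only the first forward/backward pass propagates through the whole network; the inner loop then refines the perturbation using just the first layer and the saved co-state p = -grad of the loss w.r.t. the layer-one output. A minimal, self-contained sketch of that decoupling, assuming a toy conv-plus-linear model; the layer sizes, batch, and hyperparameters are illustrative placeholders, not the module's defaults:

import torch
import torch.nn as nn

layer_one = nn.Conv2d(1, 8, 3, padding=1)
rest = nn.Sequential(nn.Flatten(), nn.Linear(8 * 28 * 28, 10))

x = torch.rand(4, 1, 28, 28)              # a fake MNIST batch
y = torch.randint(0, 10, (4,))
eps, sigma = 0.3, 0.01

eta = torch.empty_like(x).uniform_(-eps, eps).requires_grad_()

# One full forward/backward pass: the only time the whole network is propagated.
feat = layer_one(x + eta.detach())
feat.retain_grad()                        # corresponds to net.layer_one_out in YOPO.py
loss = nn.functional.cross_entropy(rest(feat), y)
loss.backward()
p = -feat.grad                            # co-state entering the Hamiltonian

# Inner steps touch only layer_one: H = <layer_one(x + eta), p>, as in Hamiltonian.forward.
for _ in range(2):
    H = (layer_one(torch.clamp(x + eta, 0, 1)) * p).sum()
    eta = (eta - torch.autograd.grad(H, eta)[0].sign() * sigma).clamp(-eps, eps)
    eta = (torch.clamp(x + eta, 0, 1) - x).detach().requires_grad_()

Because each inner step only evaluates layer_one, the K outer iterations each cost roughly one full backward pass, which is where YOPO's speedup over standard PGD training comes from.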
deeprobust/image/defense/__init__.py
ADDED
@@ -0,0 +1,6 @@
from deeprobust.image.defense import base_defense
from deeprobust.image.defense import pgdtraining
from deeprobust.image.defense import fgsmtraining
from deeprobust.image.defense import TherEncoding
from deeprobust.image.defense import trades
from deeprobust.image.defense import YOPO
deeprobust/image/defense/base_defense.py
ADDED
@@ -0,0 +1,100 @@
from abc import ABCMeta
import torch

class BaseDefense(object):
    """
    Defense base class.
    """

    __metaclass__ = ABCMeta

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def parse_params(self, **kwargs):
        """
        Parse user-defined parameters.
        """
        return True

    def generate(self, train_loader, test_loader, **kwargs):
        """generate.

        Parameters
        ----------
        train_loader :
            training data
        test_loader :
            testing data
        kwargs :
            user-defined parameters
        """
        self.train_loader = train_loader
        self.test_loader = test_loader
        return

    def train(self, train_loader, optimizer, epoch):
        """train.

        Parameters
        ----------
        train_loader :
            training data
        optimizer :
            training optimizer
        epoch :
            training epoch
        """
        return True

    def test(self, test_loader):
        """test.

        Parameters
        ----------
        test_loader :
            testing data
        """
        return True

    def adv_data(self, model, data, target, **kwargs):
        """
        Generate adversarial examples for adversarial training.
        Override this function to generate customized adversarial examples.

        Parameters
        ----------
        model :
            victim model
        data :
            original data
        target :
            target labels
        kwargs :
            parameters
        """
        return True

    def loss(self, output, target):
        """
        Calculate training loss.
        Override this function to customize the loss.

        Parameters
        ----------
        output :
            model outputs
        target :
            true labels
        """
        return True

    def generate(self):
        # note: this zero-argument generate() shadows the generate() defined above
        return True

    def save_model(self):
        """
        Save model.
        """
        return True
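BaseDefense fixes the template every defense in this package follows: parse_params consumes keyword arguments, generate drives the epoch loop, train/test run one pass each, and adv_data crafts the training examples. A minimal sketch of a custom subclass, assuming only the class above; GaussianNoiseDefense and all of its parameters are hypothetical names used for illustration, not part of DeepRobust:

import torch
import torch.nn.functional as F
from deeprobust.image.defense.base_defense import BaseDefense

class GaussianNoiseDefense(BaseDefense):
    """Hypothetical defense: trains on Gaussian-noise-augmented inputs."""

    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device if torch.cuda.is_available() else 'cpu'

    def parse_params(self, epochs=5, lr=1e-3, std=0.1):
        self.epochs, self.lr, self.std = epochs, lr, std

    def adv_data(self, model, data, target, **kwargs):
        # "adversarial" data is just noisy data in this toy example
        return (data + self.std * torch.randn_like(data)).clamp(0, 1)

    def generate(self, train_loader, test_loader, **kwargs):
        self.parse_params(**kwargs)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        for epoch in range(1, self.epochs + 1):
            self.train(train_loader, optimizer, epoch)
        return self.model

    def train(self, train_loader, optimizer, epoch):
        self.model.train()
        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            loss = F.cross_entropy(self.model(self.adv_data(self.model, data, target)), target)
            loss.backward()
            optimizer.step()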
deeprobust/image/defense/fast.py
ADDED
@@ -0,0 +1,169 @@
"""
This is an implementation of the adversarial training variant Fast.

References
----------
.. [1] Wong, Eric, Leslie Rice, and J. Zico Kolter. "Fast is better than free: Revisiting adversarial training." arXiv preprint arXiv:2001.03994 (2020).
"""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch import optim

from deeprobust.image.defense.base_defense import BaseDefense
from deeprobust.image.attack.fgsm import FGSM

class Fast(BaseDefense):
    def __init__(self, model, device):
        if not torch.cuda.is_available():
            print('CUDA not available, using cpu...')
            self.device = 'cpu'
        else:
            self.device = device

        self.model = model

    def generate(self, train_loader, test_loader, **kwargs):
        """
        Fast adversarial training process.
        """
        self.parse_params(**kwargs)
        torch.manual_seed(100)
        device = torch.device(self.device)
        optimizer = optim.Adam(self.model.parameters(), self.lr_train)

        for epoch in range(1, self.epoch_num + 1):

            print(epoch, flush = True)
            self.train(self.device, train_loader, optimizer, epoch)
            self.test(self.model, self.device, test_loader)

            if (self.save_model):
                if os.path.isdir('./' + self.save_dir):
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name))
                    print("model saved in " + './' + self.save_dir)
                else:
                    print("make new directory and save model in " + './' + self.save_dir)
                    os.mkdir('./' + self.save_dir)
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name))

        return self.model

    def parse_params(self,
                     save_dir = "defense_models",
                     save_model = True,
                     save_name = "fast_mnist_fgsmtraining_0.2.pt",
                     epsilon = 0.2,
                     epoch_num = 30,
                     lr_train = 0.005,
                     momentum = 0.1):
        """
        Set parameters for fast training.
        """
        self.save_model = True   # note: the save_model argument is ignored; saving is always enabled
        self.save_dir = save_dir
        self.save_name = save_name
        self.epsilon = epsilon
        self.epoch_num = epoch_num
        self.lr_train = lr_train
        self.momentum = momentum

    def train(self, device, train_loader, optimizer, epoch):
        """
        Training process.
        """
        self.model.train()
        correct = 0
        bs = train_loader.batch_size

        for batch_idx, (data, target) in enumerate(train_loader):

            optimizer.zero_grad()

            data, target = data.to(device), target.to(device)

            data_adv, output = self.adv_data(data, target, ep = self.epsilon)

            loss = self.calculate_loss(output, target)

            loss.backward()
            optimizer.step()

            pred = output.argmax(dim = 1, keepdim = True)
            correct += pred.eq(target.view_as(pred)).sum().item()

            # print every 10 batches
            if batch_idx % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item(), 100 * correct / (10 * bs)))
                correct = 0


    def test(self, model, device, test_loader):
        """
        Testing process.
        """
        model.eval()

        test_loss = 0
        correct = 0
        test_loss_adv = 0
        correct_adv = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # clean accuracy
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # adversarial accuracy
            data_adv, output_adv = self.adv_data(data, target, ep = self.epsilon)

            test_loss_adv += self.calculate_loss(output_adv, target, redmode = 'sum').item()  # sum up batch loss
            pred_adv = output_adv.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct_adv += pred_adv.eq(target.view_as(pred_adv)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_loss_adv /= len(test_loader.dataset)

        print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss, correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))

        print('\nTest set: Adv loss: {:.3f}, Adv Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss_adv, correct_adv, len(test_loader.dataset),
              100. * correct_adv / len(test_loader.dataset)))

    def adv_data(self, data, output, ep = 0.3, num_steps = 40):
        """
        Generate (adversarial) input data for training: a random restart
        followed by a single FGSM step. The `output` argument holds the labels.
        """
        delta = torch.zeros_like(data).uniform_(-ep, ep).to(self.device)
        data = delta + data

        adversary = FGSM(self.model)
        data_adv = adversary.generate(data, output.flatten(), epsilon = ep)
        output = self.model(data_adv)

        return data_adv, output

    def calculate_loss(self, output, target, redmode = 'mean'):
        """
        Calculate loss for training.
        """
        loss = F.nll_loss(output, target, reduction = redmode)
        return loss
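Given that generate/parse_params interface, driving the defense end to end looks like the sketch below; the same pattern applies to FGSMtraining and PGDtraining further down. This assumes DeepRobust's MNIST CNN (`deeprobust.image.netmodels.CNN.Net`, taken here with its default constructor) and plain torchvision loaders; the epsilon and epoch_num keywords are simply forwarded to parse_params above:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from deeprobust.image.netmodels.CNN import Net
from deeprobust.image.defense.fast import Fast

model = Net()
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
    batch_size=1000)

defense = Fast(model, 'cuda')
# keyword arguments are forwarded to parse_params() above
robust_model = defense.generate(train_loader, test_loader, epsilon=0.2, epoch_num=30)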
deeprobust/image/defense/fgsmtraining.py
ADDED
@@ -0,0 +1,227 @@
"""
This is an implementation of FGSM adversarial training.

References
----------
.. [1] Szegedy, C., Zaremba, W., Sutskever, I., Estrach, J. B., Erhan, D., Goodfellow, I., & Fergus, R. (2014, January).
       Intriguing properties of neural networks.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F

import numpy as np
from PIL import Image
import os

from deeprobust.image.netmodels import CNN
from deeprobust.image.attack.fgsm import FGSM
from deeprobust.image.defense.base_defense import BaseDefense

class FGSMtraining(BaseDefense):
    """
    FGSM adversarial training.
    """

    def __init__(self, model, device):
        if not torch.cuda.is_available():
            print('CUDA not available, using cpu...')
            self.device = 'cpu'
        else:
            self.device = device

        self.model = model

    def generate(self, train_loader, test_loader, **kwargs):
        """FGSM adversarial training process.

        Parameters
        ----------
        train_loader :
            training data loader
        test_loader :
            testing data loader
        kwargs :
            kwargs
        """
        self.parse_params(**kwargs)
        torch.manual_seed(100)
        device = torch.device(self.device)
        optimizer = optim.Adam(self.model.parameters(), self.lr_train)

        for epoch in range(1, self.epoch_num + 1):

            print(epoch, flush = True)
            self.train(self.device, train_loader, optimizer, epoch)
            self.test(self.model, self.device, test_loader)

            if (self.save_model):
                if os.path.isdir('./' + self.save_dir):
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name))
                    print("model saved in " + './' + self.save_dir)
                else:
                    print("make new directory and save model in " + './' + self.save_dir)
                    os.mkdir('./' + self.save_dir)
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name))

        return self.model

    def parse_params(self,
                     save_dir = "defense_models",
                     save_model = True,
                     save_name = "mnist_fgsmtraining_0.2.pt",
                     epsilon = 0.2,
                     epoch_num = 50,
                     lr_train = 0.005,
                     momentum = 0.1):
        """parse_params.

        Parameters
        ----------
        save_dir :
            save directory
        save_model :
            whether to save the model
        save_name :
            model name
        epsilon :
            attack perturbation constraint
        epoch_num :
            number of training epochs
        lr_train :
            training learning rate
        momentum :
            momentum for the optimizer
        """
        self.save_model = True   # note: the save_model argument is ignored; saving is always enabled
        self.save_dir = save_dir
        self.save_name = save_name
        self.epsilon = epsilon
        self.epoch_num = epoch_num
        self.lr_train = lr_train
        self.momentum = momentum

    def train(self, device, train_loader, optimizer, epoch):
        """
        Training process.

        Parameters
        ----------
        device :
            device
        train_loader :
            training data loader
        optimizer :
            optimizer
        epoch :
            training epoch
        """
        self.model.train()
        correct = 0
        bs = train_loader.batch_size

        for batch_idx, (data, target) in enumerate(train_loader):

            optimizer.zero_grad()

            data, target = data.to(device), target.to(device)

            data_adv, output = self.adv_data(data, target, ep = self.epsilon)

            loss = self.calculate_loss(output, target)

            loss.backward()
            optimizer.step()

            pred = output.argmax(dim = 1, keepdim = True)
            correct += pred.eq(target.view_as(pred)).sum().item()

            # print every 10 batches
            if batch_idx % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item(), 100 * correct / (10 * bs)))
                correct = 0


    def test(self, model, device, test_loader):
        """
        Testing process.

        Parameters
        ----------
        model :
            model
        device :
            device
        test_loader :
            testing dataloader
        """
        model.eval()

        test_loss = 0
        correct = 0
        test_loss_adv = 0
        correct_adv = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # clean accuracy
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # adversarial accuracy
            data_adv, output_adv = self.adv_data(data, target, ep = self.epsilon)

            test_loss_adv += self.calculate_loss(output_adv, target, redmode = 'sum').item()  # sum up batch loss
            pred_adv = output_adv.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct_adv += pred_adv.eq(target.view_as(pred_adv)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_loss_adv /= len(test_loader.dataset)

        print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss, correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))

        print('\nTest set: Adv loss: {:.3f}, Adv Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss_adv, correct_adv, len(test_loader.dataset),
              100. * correct_adv / len(test_loader.dataset)))

    def adv_data(self, data, output, ep = 0.3, num_steps = 40):
        """Generate adversarial data for training; the `output` argument holds the labels.

        Parameters
        ----------
        data :
            data
        output :
            labels
        ep :
            epsilon, perturbation budget
        num_steps :
            iteration steps
        """
        adversary = FGSM(self.model)
        data_adv = adversary.generate(data, output.flatten(), epsilon = ep)
        output = self.model(data_adv)

        return data_adv, output

    def calculate_loss(self, output, target, redmode = 'mean'):
        """
        Calculate loss for training.
        """
        loss = F.cross_entropy(output, target, reduction = redmode)
        return loss
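The attack step this trainer builds on is the fast gradient sign method: one signed-gradient move of size epsilon. A standalone sketch of that update (the real attack lives in deeprobust.image.attack.fgsm.FGSM; fgsm_step here is an illustrative stand-in):

import torch
import torch.nn.functional as F

def fgsm_step(model, x, y, epsilon=0.2):
    """One FGSM step: x_adv = clip(x + epsilon * sign(grad_x CE(model(x), y)))."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad = torch.autograd.grad(loss, x)[0]
    return (x + epsilon * grad.sign()).clamp(0, 1).detach()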
deeprobust/image/defense/pgdtraining.py
ADDED
@@ -0,0 +1,229 @@
"""
This is an implementation of PGD adversarial training.

References
----------
.. [1] Mądry, A., Makelov, A., Schmidt, L., Tsipras, D., & Vladu, A. (2017).
       Towards Deep Learning Models Resistant to Adversarial Attacks. stat, 1050, 9.
"""

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F

import numpy as np
from PIL import Image
from deeprobust.image.attack.pgd import PGD
from deeprobust.image.netmodels.CNN import Net
from deeprobust.image.defense.base_defense import BaseDefense


class PGDtraining(BaseDefense):
    """
    PGD adversarial training.
    """

    def __init__(self, model, device):
        if not torch.cuda.is_available():
            print('CUDA not available, using cpu...')
            self.device = 'cpu'
        else:
            self.device = device

        self.model = model

    def generate(self, train_loader, test_loader, **kwargs):
        """Call this function to generate a robust model.

        Parameters
        ----------
        train_loader :
            training data loader
        test_loader :
            testing data loader
        kwargs :
            kwargs
        """
        self.parse_params(**kwargs)

        torch.manual_seed(100)
        device = torch.device(self.device)

        optimizer = optim.Adam(self.model.parameters(), self.lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 100], gamma = 0.1)
        save_model = True
        for epoch in range(1, self.epoch + 1):
            print('Training epoch: ', epoch, flush = True)
            self.train(self.device, train_loader, optimizer, epoch)
            self.test(self.model, self.device, test_loader)

            if (self.save_model and epoch % self.save_per_epoch == 0):
                if os.path.isdir(str(self.save_dir)):
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name + '_epoch' + str(epoch) + '.pth'))
                    print("model saved in " + str(self.save_dir))
                else:
                    print("make new directory and save model in " + str(self.save_dir))
                    os.mkdir('./' + str(self.save_dir))
                    torch.save(self.model.state_dict(), os.path.join(self.save_dir, self.save_name + '_epoch' + str(epoch) + '.pth'))

            scheduler.step()

        return self.model

    def parse_params(self,
                     epoch_num = 100,
                     save_dir = "./defense_models",
                     save_name = "mnist_pgdtraining_0.3",
                     save_model = True,
                     epsilon = 8.0 / 255.0,
                     num_steps = 10,
                     perturb_step_size = 0.01,
                     lr = 0.1,
                     momentum = 0.1,
                     save_per_epoch = 10):
        """Parameter parser.

        Parameters
        ----------
        epoch_num : int
            number of training epochs
        save_dir : str
            model dir
        save_name : str
            model name
        save_model : bool
            whether to save the model
        epsilon : float
            attack constraint
        num_steps : int
            number of PGD attack iterations
        perturb_step_size : float
            perturbation step size
        lr : float
            learning rate for the adversarial training process
        momentum : float
            momentum for the optimizer
        save_per_epoch : int
            save a checkpoint every this many epochs
        """
        self.epoch = epoch_num
        self.save_model = True   # note: the save_model argument is ignored; saving is always enabled
        self.save_dir = save_dir
        self.save_name = save_name
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.perturb_step_size = perturb_step_size
        self.lr = lr
        self.momentum = momentum
        self.save_per_epoch = save_per_epoch

    def train(self, device, train_loader, optimizer, epoch):
        """
        Training process.

        Parameters
        ----------
        device :
            device
        train_loader :
            training data loader
        optimizer :
            optimizer
        epoch :
            training epoch
        """

        self.model.train()
        correct = 0
        bs = train_loader.batch_size
        for batch_idx, (data, target) in enumerate(train_loader):

            optimizer.zero_grad()

            data, target = data.to(device), target.to(device)

            data_adv, output = self.adv_data(data, target, ep = self.epsilon, num_steps = self.num_steps, perturb_step_size = self.perturb_step_size)
            loss = self.calculate_loss(output, target)

            loss.backward()
            optimizer.step()

            pred = output.argmax(dim = 1, keepdim = True)
            correct += pred.eq(target.view_as(pred)).sum().item()

            # print every 20 batches
            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy:{:.2f}%'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item(), 100 * correct / (bs)))
                correct = 0


    def test(self, model, device, test_loader):
        """
        Testing process.

        Parameters
        ----------
        model :
            model
        device :
            device
        test_loader :
            testing dataloader
        """
        model.eval()

        test_loss = 0
        correct = 0
        test_loss_adv = 0
        correct_adv = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # clean accuracy
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # adversarial accuracy
            data_adv, output_adv = self.adv_data(data, target, ep = self.epsilon, num_steps = self.num_steps)

            test_loss_adv += self.calculate_loss(output_adv, target, redmode = 'sum').item()  # sum up batch loss
            pred_adv = output_adv.argmax(dim = 1, keepdim = True)  # get the index of the max log-probability
            correct_adv += pred_adv.eq(target.view_as(pred_adv)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_loss_adv /= len(test_loader.dataset)

        print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss, correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))

        print('\nTest set: Adv loss: {:.3f}, Adv Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss_adv, correct_adv, len(test_loader.dataset),
              100. * correct_adv / len(test_loader.dataset)))

    def adv_data(self, data, output, ep = 0.3, num_steps = 10, perturb_step_size = 0.01):
        """
        Generate (adversarial) input data for training; the `output` argument holds the labels.
        """

        adversary = PGD(self.model)
        data_adv = adversary.generate(data, output.flatten(), epsilon = ep, num_steps = num_steps, step_size = perturb_step_size)
        output = self.model(data_adv)

        return data_adv, output

    def calculate_loss(self, output, target, redmode = 'mean'):
        """
        Calculate loss for training.
        """
        loss = F.cross_entropy(output, target, reduction = redmode)
        return loss
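PGD iterates the same signed-gradient move from a random start and projects back into the epsilon-ball after every step. A standalone sketch of the attack that adv_data above delegates to (the real implementation is deeprobust.image.attack.pgd.PGD; pgd_attack is an illustrative stand-in):

import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, epsilon=0.3, step_size=0.01, num_steps=10):
    # random start inside the epsilon-ball
    x_adv = (x + torch.empty_like(x).uniform_(-epsilon, epsilon)).clamp(0, 1)
    for _ in range(num_steps):
        x_adv = x_adv.detach().requires_grad_(True)
        grad = torch.autograd.grad(F.cross_entropy(model(x_adv), y), x_adv)[0]
        x_adv = x_adv + step_size * grad.sign()
        # project back into the epsilon-ball around x, then into the valid pixel range
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon).clamp(0, 1)
    return x_adv.detach()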
deeprobust/image/defense/trades.py
ADDED
@@ -0,0 +1,241 @@
"""
This is an implementation of TRADES [1].

References
----------
.. [1] Zhang, H., Yu, Y., Jiao, J., Xing, E., El Ghaoui, L., & Jordan, M. (2019, May).
       Theoretically Principled Trade-off between Robustness and Accuracy.
       In International Conference on Machine Learning (pp. 7472-7482).

This implementation is based on their code: https://github.com/yaodongyu/TRADES
Copyright (c) 2019 Hongyang Zhang, Yaodong Yu
"""

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms

from deeprobust.image.defense.base_defense import BaseDefense
from deeprobust.image.netmodels.CNN import Net
from deeprobust.image.utils import adjust_learning_rate

class TRADES(BaseDefense):
    """TRADES.
    """

    def __init__(self, model, device = 'cuda'):
        if not torch.cuda.is_available():
            print('CUDA not available, using cpu...')
            self.device = 'cpu'
        else:
            self.device = device

        self.model = model.to(self.device)

    def generate(self, train_loader, test_loader, **kwargs):
        """Generate a robust model.

        Parameters
        ----------
        train_loader :
            train_loader
        test_loader :
            test_loader
        kwargs :
            kwargs
        """

        self.parse_params(**kwargs)

        torch.manual_seed(self.seed)

        loader_kwargs = {'num_workers': 1, 'pin_memory': True} if (self.device == 'cuda') else {}

        # init model; Net() can also be used here for training
        optimizer = optim.SGD(self.model.parameters(), lr = self.lr, momentum = self.momentum)

        for epoch in range(1, self.epochs + 1):
            # adjust learning rate for SGD
            optimizer = adjust_learning_rate(optimizer, epoch, self.lr)

            # adversarial training
            self.train(self.device, train_loader, optimizer, epoch)

            # evaluation on natural examples
            self.test(self.model, self.device, test_loader)

            # save checkpoint
            if not os.path.exists(self.save_dir):
                os.makedirs(self.save_dir)
            if epoch % self.save_freq == 0:
                torch.save(self.model.state_dict(),
                           os.path.join(self.save_dir, 'trade_model-nn-epoch{}.pt'.format(epoch)))
                torch.save(optimizer.state_dict(),
                           os.path.join(self.save_dir, 'opt-nn-checkpoint_epoch{}.tar'.format(epoch)))

    def parse_params(self,
                     epochs = 100,
                     lr = 0.01,
                     momentum = 0.9,
                     epsilon = 0.3,
                     num_steps = 40,
                     step_size = 0.01,
                     beta = 1.0,
                     seed = 1,
                     log_interval = 100,
                     save_dir = "./defense_model",
                     save_freq = 10):
        """
        :param epochs : int
            - number of training epochs
        :param lr : float
            - learning rate for the training process
        :param momentum : float
            - momentum for the optimizer
        :param epsilon : float
            - perturbation constraint of the PGD adversarial examples used to train the defense model
        :param num_steps : int
            - number of perturbation steps
        :param step_size : float
            - step size of each perturbation step
        :param beta : float
            - trade-off weight between the natural and robust loss terms
        :param save_dir : str
            - directory path to save the model
        """
        self.epochs = epochs
        self.lr = lr
        self.momentum = momentum
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.beta = beta
        self.seed = seed
        self.log_interval = log_interval
        self.save_dir = save_dir
        self.save_freq = save_freq

    def test(self, model, device, test_loader):
        model.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()  # 'sum' replaces the deprecated size_average=False
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print('Test: Clean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
              test_loss, correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))
        test_accuracy = correct / len(test_loader.dataset)

        return test_loss, test_accuracy

    def train(self, device, train_loader, optimizer, epoch):
        self.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):

            optimizer.zero_grad()

            data, target = data.to(self.device), target.to(self.device)

            # calculate robust loss
            loss = self.trades_loss(model = self.model,
                                    x_natural = data,
                                    y = target,
                                    optimizer = optimizer,
                                    step_size = self.step_size,
                                    epsilon = self.epsilon,
                                    perturb_steps = self.num_steps,
                                    beta = self.beta)

            loss.backward()
            optimizer.step()

            # print progress
            if batch_idx % self.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item()))

    def trades_loss(self,
                    model,
                    x_natural,
                    y,
                    optimizer,
                    step_size = 0.003,
                    epsilon = 0.031,
                    perturb_steps = 10,
                    beta = 1.0,
                    distance = 'l_inf'):

        # define KL-loss
        criterion_kl = nn.KLDivLoss(reduction='sum')  # 'sum' replaces the deprecated size_average=False
        model.eval()
        batch_size = len(x_natural)

        # generate adversarial example; .to(x_natural.device) replaces the original .cuda() so CPU runs also work
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(x_natural.device).detach()

        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                           F.softmax(model(x_natural), dim=1))
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)

        elif distance == 'l_2':

            delta = 0.001 * torch.randn(x_natural.shape).to(x_natural.device).detach()
            delta = Variable(delta.data, requires_grad=True)

            # setup optimizer for delta
            optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

            for _ in range(perturb_steps):
                adv = x_natural + delta

                # optimize
                optimizer_delta.zero_grad()
                with torch.enable_grad():
                    loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1),
                                               F.softmax(model(x_natural), dim=1))
                loss.backward()
                # renorm gradient
                grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
                delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
                # avoid nan or inf if gradient is 0
                if (grad_norms == 0).any():
                    delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
                optimizer_delta.step()

                # projection
                delta.data.add_(x_natural)
                delta.data.clamp_(0, 1).sub_(x_natural)
                delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
            x_adv = Variable(x_natural + delta, requires_grad=False)
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        model.train()

        x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
        # zero gradient
        optimizer.zero_grad()
        # calculate robust loss
        logits = model(x_natural)
        loss_natural = F.cross_entropy(logits, y)
        loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                        F.softmax(model(x_natural), dim=1))
        loss = loss_natural + beta * loss_robust
        return loss
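The value returned by trades_loss is the TRADES objective: clean cross-entropy plus a KL term, weighted by beta, that pulls the prediction on the worst-case neighbor x' back toward the prediction on the clean input (the neighborhood is an l-infinity or l-2 ball depending on the distance argument):

    \min_{f}\; \mathbb{E}_{(x,y)}\Big[\, \mathrm{CE}\big(f(x),\, y\big) \;+\; \beta \max_{\|x' - x\| \le \epsilon} \mathrm{KL}\big(f(x) \,\|\, f(x')\big) \Big]

Larger beta buys robustness at the cost of clean accuracy; beta = 1.0 is the default in parse_params above.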
deeprobust/image/optimizer.py
ADDED
@@ -0,0 +1,914 @@
"""
This module includes the following optimizers:

1. differential_evolution:
   The differential evolution global optimization algorithm,
   https://github.com/scipy/scipy/blob/70e61dee181de23fdd8d893eaa9491100e2218d7/scipy/optimize/_differentialevolution.py
   modified by:
   https://github.com/DebangLi/one-pixel-attack-pytorch/blob/master/differential_evolution.py

2. Basic Adam Optimizer
"""

from __future__ import division, print_function, absolute_import
import numpy as np
from scipy.optimize import OptimizeResult, minimize
from scipy.optimize.optimize import _status_message
from scipy._lib._util import check_random_state
import warnings


__all__ = ['differential_evolution', 'AdamOptimizer']

_MACHEPS = np.finfo(np.float64).eps


def differential_evolution(func, bounds, args=(), strategy='best1bin',
                           maxiter=1000, popsize=15, tol=0.01,
                           mutation=(0.5, 1), recombination=0.7, seed=None,
                           callback=None, disp=False, polish=True,
                           init='latinhypercube', atol=0):
    """Finds the global minimum of a multivariate function.

    Differential Evolution is stochastic in nature (it does not use gradient
    methods) to find the minimum, and can search large areas of candidate
    space, but often requires larger numbers of function evaluations than
    conventional gradient-based techniques.
    The algorithm is due to Storn and Price [1]_.

    Parameters
    ----------
    func : callable
        The objective function to be minimized. Must be in the form
        ``f(x, *args)``, where ``x`` is the argument in the form of a 1-D array
        and ``args`` is a tuple of any additional fixed parameters needed to
        completely specify the function.
    bounds : sequence
        Bounds for variables. ``(min, max)`` pairs for each element in ``x``,
        defining the lower and upper bounds for the optimizing argument of
        `func`. It is required to have ``len(bounds) == len(x)``.
        ``len(bounds)`` is used to determine the number of parameters in ``x``.
    args : tuple, optional
        Any additional fixed parameters needed to
        completely specify the objective function.
    strategy : str, optional
        The differential evolution strategy to use. Should be one of:

        - 'best1bin'
        - 'best1exp'
        - 'rand1exp'
        - 'randtobest1exp'
        - 'currenttobest1exp'
        - 'best2exp'
        - 'rand2exp'
        - 'randtobest1bin'
        - 'currenttobest1bin'
        - 'best2bin'
        - 'rand2bin'
        - 'rand1bin'

        The default is 'best1bin'.
    maxiter : int, optional
        The maximum number of generations over which the entire population is
        evolved. The maximum number of function evaluations (with no polishing)
        is: ``(maxiter + 1) * popsize * len(x)``
    popsize : int, optional
        A multiplier for setting the total population size. The population has
        ``popsize * len(x)`` individuals (unless the initial population is
        supplied via the `init` keyword).
    tol : float, optional
        Relative tolerance for convergence, the solving stops when
        ``np.std(pop) <= atol + tol * np.abs(np.mean(population_energies))``,
        where `atol` and `tol` are the absolute and relative tolerance
        respectively.
    mutation : float or tuple(float, float), optional
        The mutation constant. In the literature this is also known as
        differential weight, being denoted by F.
        If specified as a float it should be in the range [0, 2].
        If specified as a tuple ``(min, max)`` dithering is employed. Dithering
        randomly changes the mutation constant on a generation by generation
        basis. The mutation constant for that generation is taken from
        ``U[min, max)``. Dithering can help speed convergence significantly.
        Increasing the mutation constant increases the search radius, but will
        slow down convergence.
    recombination : float, optional
        The recombination constant, should be in the range [0, 1]. In the
        literature this is also known as the crossover probability, being
        denoted by CR. Increasing this value allows a larger number of mutants
        to progress into the next generation, but at the risk of population
        stability.
    seed : int or `np.random.RandomState`, optional
        If `seed` is not specified the `np.random.RandomState` singleton is
        used.
        If `seed` is an int, a new `np.random.RandomState` instance is used,
        seeded with `seed`.
        If `seed` is already a `np.random.RandomState` instance, then that
        `np.random.RandomState` instance is used.
        Specify `seed` for repeatable minimizations.
    disp : bool, optional
        Display status messages.
    callback : callable, `callback(xk, convergence=val)`, optional
        A function to follow the progress of the minimization. ``xk`` is
        the current value of ``x0``. ``val`` represents the fractional
        value of the population convergence. When ``val`` is greater than one
        the function halts. If callback returns `True`, then the minimization
        is halted (any polishing is still carried out).
    polish : bool, optional
        If True (default), then `scipy.optimize.minimize` with the `L-BFGS-B`
        method is used to polish the best population member at the end, which
        can improve the minimization slightly.
    init : str or array-like, optional
        Specify which type of population initialization is performed. Should be
        one of:

        - 'latinhypercube'
        - 'random'
        - array specifying the initial population. The array should have
          shape ``(M, len(x))``, where len(x) is the number of parameters.
          `init` is clipped to `bounds` before use.

        The default is 'latinhypercube'. Latin Hypercube sampling tries to
        maximize coverage of the available parameter space. 'random'
        initializes the population randomly - this has the drawback that
        clustering can occur, preventing the whole of parameter space being
        covered. Use of an array to specify a population subset could be used,
        for example, to create a tight bunch of initial guesses in a location
        where the solution is known to exist, thereby reducing time for
        convergence.
    atol : float, optional
        Absolute tolerance for convergence, the solving stops when
        ``np.std(pop) <= atol + tol * np.abs(np.mean(population_energies))``,
        where `atol` and `tol` are the absolute and relative tolerance
        respectively.

    Returns
    -------
    res : OptimizeResult
        The optimization result represented as a `OptimizeResult` object.
        Important attributes are: ``x`` the solution array, ``success`` a
        Boolean flag indicating if the optimizer exited successfully and
        ``message`` which describes the cause of the termination. See
        `OptimizeResult` for a description of other attributes. If `polish`
        was employed, and a lower minimum was obtained by the polishing, then
        OptimizeResult also contains the ``jac`` attribute.

    Notes
    -----
    Differential evolution is a stochastic population based method that is
    useful for global optimization problems. At each pass through the population
    the algorithm mutates each candidate solution by mixing with other candidate
    solutions to create a trial candidate. There are several strategies [2]_ for
    creating trial candidates, which suit some problems more than others. The
    'best1bin' strategy is a good starting point for many systems. In this
    strategy two members of the population are randomly chosen. Their difference
    is used to mutate the best member (the `best` in `best1bin`), :math:`b_0`,
    so far:

    .. math::

        b' = b_0 + mutation * (population[rand0] - population[rand1])

    A trial vector is then constructed. Starting with a randomly chosen 'i'th
    parameter the trial is sequentially filled (in modulo) with parameters from
    `b'` or the original candidate. The choice of whether to use `b'` or the
    original candidate is made with a binomial distribution (the 'bin' in
    'best1bin') - a random number in [0, 1) is generated. If this number is
    less than the `recombination` constant then the parameter is loaded from
    `b'`, otherwise it is loaded from the original candidate. The final
    parameter is always loaded from `b'`. Once the trial candidate is built
    its fitness is assessed. If the trial is better than the original candidate
    then it takes its place. If it is also better than the best overall
    candidate it also replaces that.
    To improve your chances of finding a global minimum use higher `popsize`
    values, with higher `mutation` (and dithering), but lower `recombination`
    values. This has the effect of widening the search radius, but slowing
    convergence.

    .. versionadded:: 0.15.0

    References
    ----------
    .. [1] Storn, R and Price, K, Differential Evolution - a Simple and
           Efficient Heuristic for Global Optimization over Continuous Spaces,
           Journal of Global Optimization, 1997, 11, 341 - 359.
    .. [2] http://www1.icsi.berkeley.edu/~storn/code.html
    .. [3] http://en.wikipedia.org/wiki/Differential_evolution
    """

    solver = DifferentialEvolutionSolver(func, bounds, args=args,
                                         strategy=strategy, maxiter=maxiter,
                                         popsize=popsize, tol=tol,
                                         mutation=mutation,
                                         recombination=recombination,
                                         seed=seed, polish=polish,
                                         callback=callback,
                                         disp=disp, init=init, atol=atol)
    return solver.solve()


class DifferentialEvolutionSolver(object):

    """This class implements the differential evolution solver.

    Parameters
    ----------
    func : callable
        The objective function to be minimized. Must be in the form
        ``f(x, *args)``, where ``x`` is the argument in the form of a 1-D array
        and ``args`` is a tuple of any additional fixed parameters needed to
        completely specify the function.
    bounds : sequence
        Bounds for variables. ``(min, max)`` pairs for each element in ``x``,
        defining the lower and upper bounds for the optimizing argument of
        `func`. It is required to have ``len(bounds) == len(x)``.
        ``len(bounds)`` is used to determine the number of parameters in ``x``.
    args : tuple, optional
        Any additional fixed parameters needed to
        completely specify the objective function.
    strategy : str, optional
        The differential evolution strategy to use. Should be one of:

        - 'best1bin'
        - 'best1exp'
        - 'rand1exp'
        - 'randtobest1exp'
        - 'currenttobest1exp'
        - 'best2exp'
        - 'rand2exp'
        - 'randtobest1bin'
        - 'currenttobest1bin'
        - 'best2bin'
        - 'rand2bin'
        - 'rand1bin'

        The default is 'best1bin'.
    maxiter : int, optional
        The maximum number of generations over which the entire population is
        evolved. The maximum number of function evaluations (with no polishing)
        is: ``(maxiter + 1) * popsize * len(x)``
    popsize : int, optional
        A multiplier for setting the total population size. The population has
        ``popsize * len(x)`` individuals (unless the initial population is
        supplied via the `init` keyword).
    tol : float, optional
        Relative tolerance for convergence, the solving stops when
        ``np.std(pop) <= atol + tol * np.abs(np.mean(population_energies))``,
        where `atol` and `tol` are the absolute and relative tolerance
        respectively.
    mutation : float or tuple(float, float), optional
        The mutation constant. In the literature this is also known as
        differential weight, being denoted by F.
        If specified as a float it should be in the range [0, 2].
        If specified as a tuple ``(min, max)`` dithering is employed. Dithering
        randomly changes the mutation constant on a generation by generation
        basis. The mutation constant for that generation is taken from
        U[min, max). Dithering can help speed convergence significantly.
        Increasing the mutation constant increases the search radius, but will
        slow down convergence.
    recombination : float, optional
        The recombination constant, should be in the range [0, 1]. In the
        literature this is also known as the crossover probability, being
        denoted by CR. Increasing this value allows a larger number of mutants
        to progress into the next generation, but at the risk of population
        stability.
    seed : int or `np.random.RandomState`, optional
        If `seed` is not specified the `np.random.RandomState` singleton is
        used.
        If `seed` is an int, a new `np.random.RandomState` instance is used,
        seeded with `seed`.
        If `seed` is already a `np.random.RandomState` instance, then that
        `np.random.RandomState` instance is used.
        Specify `seed` for repeatable minimizations.
    disp : bool, optional
        Display status messages.
    callback : callable, `callback(xk, convergence=val)`, optional
        A function to follow the progress of the minimization. ``xk`` is
        the current value of ``x0``. ``val`` represents the fractional
        value of the population convergence. When ``val`` is greater than one
        the function halts. If callback returns `True`, then the minimization
        is halted (any polishing is still carried out).
    polish : bool, optional
        If True, then `scipy.optimize.minimize` with the `L-BFGS-B` method
        is used to polish the best population member at the end. This requires
        a few more function evaluations.
    maxfun : int, optional
        Set the maximum number of function evaluations. However, it probably
        makes more sense to set `maxiter` instead.
    init : str or array-like, optional
        Specify which type of population initialization is performed. Should be
        one of:

        - 'latinhypercube'
        - 'random'
        - array specifying the initial population. The array should have
          shape ``(M, len(x))``, where len(x) is the number of parameters.
          `init` is clipped to `bounds` before use.

        The default is 'latinhypercube'. Latin Hypercube sampling tries to
        maximize coverage of the available parameter space. 'random'
        initializes the population randomly - this has the drawback that
        clustering can occur, preventing the whole of parameter space being
        covered. Use of an array to specify a population could be used, for
        example, to create a tight bunch of initial guesses in a location
        where the solution is known to exist, thereby reducing time for
        convergence.
    atol : float, optional
        Absolute tolerance for convergence, the solving stops when
        ``np.std(pop) <= atol + tol * np.abs(np.mean(population_energies))``,
        where `atol` and `tol` are the absolute and relative tolerance
        respectively.
    """

    # Dispatch of mutation strategy method (binomial or exponential).
    _binomial = {'best1bin': '_best1',
                 'randtobest1bin': '_randtobest1',
                 'currenttobest1bin': '_currenttobest1',
                 'best2bin': '_best2',
                 'rand2bin': '_rand2',
                 'rand1bin': '_rand1'}
    _exponential = {'best1exp': '_best1',
                    'rand1exp': '_rand1',
                    'randtobest1exp': '_randtobest1',
                    'currenttobest1exp': '_currenttobest1',
                    'best2exp': '_best2',
                    'rand2exp': '_rand2'}

    __init_error_msg = ("The population initialization method must be one of "
                        "'latinhypercube' or 'random', or an array of shape "
                        "(M, N) where N is the number of parameters and M>5")

    def __init__(self, func, bounds, args=(),
                 strategy='best1bin', maxiter=1000, popsize=15,
                 tol=0.01, mutation=(0.5, 1), recombination=0.7, seed=None,
                 maxfun=np.inf, callback=None, disp=False, polish=True,
                 init='latinhypercube', atol=0):

        if strategy in self._binomial:
            self.mutation_func = getattr(self, self._binomial[strategy])
        elif strategy in self._exponential:
            self.mutation_func = getattr(self, self._exponential[strategy])
        else:
            raise ValueError("Please select a valid mutation strategy")
        self.strategy = strategy

        self.callback = callback
        self.polish = polish

        # relative and absolute tolerances for convergence
        self.tol, self.atol = tol, atol

        # Mutation constant should be in [0, 2). If specified as a sequence
        # then dithering is performed.
        self.scale = mutation
        if (not np.all(np.isfinite(mutation)) or
                np.any(np.array(mutation) >= 2) or
                np.any(np.array(mutation) < 0)):
            raise ValueError('The mutation constant must be a float in '
                             'U[0, 2), or specified as a tuple(min, max)'
                             ' where min < max and min, max are in U[0, 2).')

        self.dither = None
        if hasattr(mutation, '__iter__') and len(mutation) > 1:
            self.dither = [mutation[0], mutation[1]]
            self.dither.sort()

        self.cross_over_probability = recombination

        self.func = func
        self.args = args

        # convert tuple of lower and upper bounds to limits
        # [(low_0, high_0), ..., (low_n, high_n)]
        #     -> [[low_0, ..., low_n], [high_0, ..., high_n]]
        self.limits = np.array(bounds, dtype='float').T
        if (np.size(self.limits, 0) != 2 or not
                np.all(np.isfinite(self.limits))):
            raise ValueError('bounds should be a sequence containing '
                             'real valued (min, max) pairs for each value'
                             ' in x')

        if maxiter is None:  # the default used to be None
            maxiter = 1000
        self.maxiter = maxiter
        if maxfun is None:  # the default used to be None
            maxfun = np.inf
        self.maxfun = maxfun

        # population is scaled to between [0, 1].
        # We have to scale between parameter <-> population
        # save these arguments for _scale_parameters and
        # _unscale_parameters. This is an optimization
        self.__scale_arg1 = 0.5 * (self.limits[0] + self.limits[1])
        self.__scale_arg2 = np.fabs(self.limits[0] - self.limits[1])

        self.parameter_count = np.size(self.limits, 1)

        self.random_number_generator = check_random_state(seed)

        # default population initialization is a latin hypercube design, but
        # there are other population initializations possible.
        # the minimum is 5 because 'best2bin' requires a population that's at
        # least 5 long
        self.num_population_members = max(5, popsize * self.parameter_count)

        self.population_shape = (self.num_population_members,
                                 self.parameter_count)

        self._nfev = 0
        if isinstance(init, str):
            if init == 'latinhypercube':
                self.init_population_lhs()
            elif init == 'random':
                self.init_population_random()
            else:
                raise ValueError(self.__init_error_msg)
        else:
            self.init_population_array(init)

        self.disp = disp

    def init_population_lhs(self):
        """
        Initializes the population with Latin Hypercube Sampling.
        Latin Hypercube Sampling ensures that each parameter is uniformly
        sampled over its range.
        """
        rng = self.random_number_generator

        # Each parameter range needs to be sampled uniformly. The scaled
        # parameter range ([0, 1)) needs to be split into
        # `self.num_population_members` segments, each of which has the
        # following size:
        segsize = 1.0 / self.num_population_members

        # Within each segment we sample from a uniform random distribution.
        # We need to do this sampling for each parameter.
        samples = (segsize * rng.random_sample(self.population_shape)
                   # Offset each segment to cover the entire parameter
                   # range [0, 1)
                   + np.linspace(0., 1., self.num_population_members,
                                 endpoint=False)[:, np.newaxis])

        # Create an array for population of candidate solutions.
        self.population = np.zeros_like(samples)

        # Initialize population of candidate solutions by permutation of the
        # random samples.
        for j in range(self.parameter_count):
            order = rng.permutation(range(self.num_population_members))
            self.population[:, j] = samples[order, j]

        # reset population energies
        self.population_energies = (np.ones(self.num_population_members) *
                                    np.inf)

        # reset number of function evaluations counter
        self._nfev = 0

    def init_population_random(self):
        """
        Initialises the population at random. This type of initialization
        can possess clustering; Latin Hypercube sampling is generally better.
        """
        rng = self.random_number_generator
        self.population = rng.random_sample(self.population_shape)

        # reset population energies
        self.population_energies = (np.ones(self.num_population_members) *
                                    np.inf)

        # reset number of function evaluations counter
        self._nfev = 0

    def init_population_array(self, init):
        """
        Initialises the population with a user specified population.

        Parameters
        ----------
        init : np.ndarray
            Array specifying subset of the initial population. The array should
            have shape (M, len(x)), where len(x) is the number of parameters.
            The population is clipped to the lower and upper `bounds`.
        """
        # make sure you're using a float array
        popn = np.asfarray(init)

        # check dimensionality first so the shape[1] access below is safe
        if (len(popn.shape) != 2 or
                popn.shape[0] < 5 or
                popn.shape[1] != self.parameter_count):
            raise ValueError("The population supplied needs to have shape"
                             " (M, len(x)), where M > 4.")

        # unscale to [0, 1] and clip to the unit hypercube, assigning to
        # population
        self.population = np.clip(self._unscale_parameters(popn), 0, 1)

        self.num_population_members = np.size(self.population, 0)

        self.population_shape = (self.num_population_members,
                                 self.parameter_count)

        # reset population energies
        self.population_energies = (np.ones(self.num_population_members) *
                                    np.inf)

        # reset number of function evaluations counter
        self._nfev = 0

    @property
    def x(self):
        """
        The best solution from the solver

        Returns
        -------
        x : ndarray
            The best solution from the solver.
        """
        return self._scale_parameters(self.population[0])

    @property
    def convergence(self):
        """
        The standard deviation of the population energies divided by their
        mean.
        """
        return (np.std(self.population_energies) /
                np.abs(np.mean(self.population_energies) + _MACHEPS))

    def solve(self):
        """
        Runs the DifferentialEvolutionSolver.

        Returns
        -------
        res : OptimizeResult
            The optimization result represented as a ``OptimizeResult`` object.
            Important attributes are: ``x`` the solution array, ``success`` a
            Boolean flag indicating if the optimizer exited successfully and
            ``message`` which describes the cause of the termination. See
            `OptimizeResult` for a description of other attributes. If `polish`
            was employed, and a lower minimum was obtained by the polishing,
            then OptimizeResult also contains the ``jac`` attribute.
        """
        nit, warning_flag = 0, False
        status_message = _status_message['success']

        # The population may have just been initialized (all entries are
        # np.inf). If it has you have to calculate the initial energies.
        # Although this is also done in the evolve generator it's possible
        # that someone can set maxiter=0, at which point we still want the
        # initial energies to be calculated (the following loop isn't run).
        if np.all(np.isinf(self.population_energies)):
            self._calculate_population_energies()

        # do the optimisation.
        for nit in range(1, self.maxiter + 1):
            # evolve the population by a generation
            try:
                next(self)
            except StopIteration:
                warning_flag = True
                status_message = _status_message['maxfev']
                break

            if self.disp:
                print("differential_evolution step %d: f(x)= %g"
                      % (nit,
                         self.population_energies[0]))

            # should the solver terminate?
            convergence = self.convergence

            if (self.callback and
                    self.callback(self._scale_parameters(self.population[0]),
                                  convergence=self.tol / convergence) is True):

                warning_flag = True
                status_message = ('callback function requested stop early '
                                  'by returning True')
                break

            intol = (np.std(self.population_energies) <=
                     self.atol +
                     self.tol * np.abs(np.mean(self.population_energies)))
            if warning_flag or intol:
                break

        else:
            status_message = _status_message['maxiter']
            warning_flag = True

        DE_result = OptimizeResult(
            x=self.x,
            fun=self.population_energies[0],
            nfev=self._nfev,
            nit=nit,
            message=status_message,
            success=(warning_flag is not True))

        if self.polish:
            result = minimize(self.func,
                              np.copy(DE_result.x),
                              method='L-BFGS-B',
                              bounds=self.limits.T,
                              args=self.args)

            self._nfev += result.nfev
            DE_result.nfev = self._nfev

            if result.fun < DE_result.fun:
                DE_result.fun = result.fun
                DE_result.x = result.x
                DE_result.jac = result.jac
                # to keep internal state consistent
                self.population_energies[0] = result.fun
                self.population[0] = self._unscale_parameters(result.x)

        return DE_result

    def _calculate_population_energies(self):
        """
        Calculate the energies of all the population members at the same time.
        Puts the best member in first place. Useful if the population has just
        been initialised.
        """

        ##############
        ## CHANGES: self.func operates on the entire parameters array
        ##############
        itersize = max(0, min(len(self.population), self.maxfun - self._nfev + 1))
        candidates = self.population[:itersize]
        parameters = np.array([self._scale_parameters(c) for c in candidates])  # TODO: vectorize
        energies = self.func(parameters, *self.args)
        self.population_energies = energies
        self._nfev += itersize

        # The per-candidate loop from the original SciPy implementation is
        # kept below for reference:
        # for index, candidate in enumerate(self.population):
        #     if self._nfev > self.maxfun:
        #         break
        #
        #     parameters = self._scale_parameters(candidate)
        #     self.population_energies[index] = self.func(parameters,
        #                                                 *self.args)
        #     self._nfev += 1
        ##############
        ##############

        minval = np.argmin(self.population_energies)

        # put the lowest energy into the best solution position.
        lowest_energy = self.population_energies[minval]
        self.population_energies[minval] = self.population_energies[0]
        self.population_energies[0] = lowest_energy

        self.population[[0, minval], :] = self.population[[minval, 0], :]

    def __iter__(self):
        return self

    def __next__(self):
        """
        Evolve the population by a single generation

        Returns
        -------
        x : ndarray
            The best solution from the solver.
        fun : float
            Value of objective function obtained from the best solution.
        """
        # the population may have just been initialized (all entries are
        # np.inf). If it has you have to calculate the initial energies
        if np.all(np.isinf(self.population_energies)):
            self._calculate_population_energies()

        if self.dither is not None:
            self.scale = (self.random_number_generator.rand()
                          * (self.dither[1] - self.dither[0]) + self.dither[0])

        ##############
        ## CHANGES: self.func operates on the entire parameters array
        ##############

        itersize = max(0, min(self.num_population_members, self.maxfun - self._nfev + 1))
        trials = np.array([self._mutate(c) for c in range(itersize)])  # TODO: vectorize
        for trial in trials:
            self._ensure_constraint(trial)
        parameters = np.array([self._scale_parameters(trial) for trial in trials])
        energies = self.func(parameters, *self.args)
        self._nfev += itersize

        for candidate, (energy, trial) in enumerate(zip(energies, trials)):
            # if the energy of the trial candidate is lower than the
            # original population member then replace it
            if energy < self.population_energies[candidate]:
                self.population[candidate] = trial
                self.population_energies[candidate] = energy

                # if the trial candidate also has a lower energy than the
                # best solution then replace that as well
                if energy < self.population_energies[0]:
                    self.population_energies[0] = energy
                    self.population[0] = trial

        # The per-candidate loop from the original SciPy implementation is
        # kept below for reference:
        # for candidate in range(self.num_population_members):
        #     if self._nfev > self.maxfun:
        #         raise StopIteration
        #
        #     # create a trial solution
        #     trial = self._mutate(candidate)
        #
        #     # ensuring that it's in the range [0, 1)
        #     self._ensure_constraint(trial)
        #
        #     # scale from [0, 1) to the actual parameter value
        #     parameters = self._scale_parameters(trial)
        #
        #     # determine the energy of the objective function
        #     energy = self.func(parameters, *self.args)
        #     self._nfev += 1
        #
        #     # if the energy of the trial candidate is lower than the
        #     # original population member then replace it
        #     if energy < self.population_energies[candidate]:
        #         self.population[candidate] = trial
        #         self.population_energies[candidate] = energy
        #
        #         # if the trial candidate also has a lower energy than the
        #         # best solution then replace that as well
        #         if energy < self.population_energies[0]:
        #             self.population_energies[0] = energy
        #             self.population[0] = trial

        ##############
        ##############

        return self.x, self.population_energies[0]

    def next(self):
        """
        Evolve the population by a single generation

        Returns
        -------
        x : ndarray
            The best solution from the solver.
        fun : float
            Value of objective function obtained from the best solution.
        """
        # next() is required for compatibility with Python2.7.
        return self.__next__()

    def _scale_parameters(self, trial):
        """
        scale from a number between 0 and 1 to parameters.
        """
        return self.__scale_arg1 + (trial - 0.5) * self.__scale_arg2

    def _unscale_parameters(self, parameters):
        """
        scale from parameters to a number between 0 and 1.
        """
        return (parameters - self.__scale_arg1) / self.__scale_arg2 + 0.5

    def _ensure_constraint(self, trial):
        """
        make sure the parameters lie between the limits
        """
        for index in np.where((trial < 0) | (trial > 1))[0]:
            trial[index] = self.random_number_generator.rand()

    def _mutate(self, candidate):
        """
        create a trial vector based on a mutation strategy
        """
        trial = np.copy(self.population[candidate])

        rng = self.random_number_generator

        fill_point = rng.randint(0, self.parameter_count)

        if self.strategy in ['currenttobest1exp', 'currenttobest1bin']:
            bprime = self.mutation_func(candidate,
                                        self._select_samples(candidate, 5))
        else:
            bprime = self.mutation_func(self._select_samples(candidate, 5))

        if self.strategy in self._binomial:
            crossovers = rng.rand(self.parameter_count)
            crossovers = crossovers < self.cross_over_probability
            # the last one is always from the bprime vector for binomial
            # If you fill in modulo with a loop you have to set the last one to
            # true. If you don't use a loop then you can have any random entry
            # be True.
            crossovers[fill_point] = True
            trial = np.where(crossovers, bprime, trial)
            return trial

        elif self.strategy in self._exponential:
            i = 0
            while (i < self.parameter_count and
                   rng.rand() < self.cross_over_probability):

                trial[fill_point] = bprime[fill_point]
                fill_point = (fill_point + 1) % self.parameter_count
                i += 1

            return trial

    def _best1(self, samples):
        """
        best1bin, best1exp
        """
        r0, r1 = samples[:2]
        return (self.population[0] + self.scale *
                (self.population[r0] - self.population[r1]))

    def _rand1(self, samples):
        """
        rand1bin, rand1exp
        """
        r0, r1, r2 = samples[:3]
        return (self.population[r0] + self.scale *
                (self.population[r1] - self.population[r2]))

    def _randtobest1(self, samples):
        """
        randtobest1bin, randtobest1exp
        """
        r0, r1, r2 = samples[:3]
        bprime = np.copy(self.population[r0])
        bprime += self.scale * (self.population[0] - bprime)
        bprime += self.scale * (self.population[r1] -
                                self.population[r2])
        return bprime

    def _currenttobest1(self, candidate, samples):
        """
        currenttobest1bin, currenttobest1exp
        """
        r0, r1 = samples[:2]
        bprime = (self.population[candidate] + self.scale *
                  (self.population[0] - self.population[candidate] +
                   self.population[r0] - self.population[r1]))
        return bprime

    def _best2(self, samples):
        """
        best2bin, best2exp
        """
        r0, r1, r2, r3 = samples[:4]
        bprime = (self.population[0] + self.scale *
                  (self.population[r0] + self.population[r1] -
                   self.population[r2] - self.population[r3]))

        return bprime

    def _rand2(self, samples):
        """
        rand2bin, rand2exp
        """
        r0, r1, r2, r3, r4 = samples
        bprime = (self.population[r0] + self.scale *
                  (self.population[r1] + self.population[r2] -
                   self.population[r3] - self.population[r4]))

        return bprime

    def _select_samples(self, candidate, number_samples):
        """
        obtain random integers from range(self.num_population_members),
        without replacement. You can't have the original candidate either.
        """
        idxs = list(range(self.num_population_members))
        idxs.remove(candidate)
        self.random_number_generator.shuffle(idxs)
        idxs = idxs[:number_samples]
        return idxs


class AdamOptimizer:
    """Basic Adam optimizer implementation that can minimize w.r.t.
    a single variable.

    Parameters
    ----------
    shape : tuple
        shape of the variable w.r.t. which the loss should be minimized
    """
    # TODO: Add reference or rewrite the function.
    def __init__(self, shape):
        self.m = np.zeros(shape)
        self.v = np.zeros(shape)
        self.t = 0

    def __call__(self, gradient, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Updates internal parameters of the optimizer and returns
        the change that should be applied to the variable.

        Parameters
        ----------
        gradient : `np.ndarray`
            the gradient of the loss w.r.t. the variable
        learning_rate : float
            the learning rate in the current iteration
        beta1 : float
            decay rate for calculating the exponentially
            decaying average of past gradients
        beta2 : float
            decay rate for calculating the exponentially
            decaying average of past squared gradients
        epsilon : float
            small value to avoid division by zero
        """

        self.t += 1

        self.m = beta1 * self.m + (1 - beta1) * gradient
        self.v = beta2 * self.v + (1 - beta2) * gradient ** 2

        bias_correction_1 = 1 - beta1 ** self.t
        bias_correction_2 = 1 - beta2 ** self.t

        m_hat = self.m / bias_correction_1
        v_hat = self.v / bias_correction_2

        return -learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)
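A minimal usage sketch of both entry points (not part of optimizer.py; the sphere objective, bounds, and learning rate are illustrative choices, not values from this repository). Per the CHANGES banners above, the modified solver hands `func` the whole population at once, so the objective must be batched and polishing, which would call the objective with a single 1-D `x`, should be disabled:

import numpy as np
from deeprobust.image.optimizer import differential_evolution, AdamOptimizer

def sphere(pop):
    # Receives the whole population, shape (S, len(x)); returns S energies.
    return (pop ** 2).sum(axis=1)

# polish=False: L-BFGS-B polishing would pass a 1-D x, which this
# batched objective does not support.
result = differential_evolution(sphere, bounds=[(-5, 5)] * 3,
                                polish=False, seed=0)
print(result.x, result.fun)            # a point near the origin

# AdamOptimizer returns the *update* to add to the variable; the descent
# direction (minus the gradient) is handled internally.
opt = AdamOptimizer(shape=(5,))
x = np.zeros(5)
for _ in range(500):
    grad = 2 * (x - 3.0)               # gradient of ||x - 3||^2
    x += opt(grad, learning_rate=0.1)
print(x)                               # close to [3., 3., 3., 3., 3.]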
deeprobust/image/preprocessing/APE-GAN.py
ADDED
@@ -0,0 +1,127 @@
import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torch.backends.cudnn as cudnn

class Generator(nn.Module):

    def __init__(self, in_ch):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, 4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, in_ch, 4, stride=2, padding=1)

    def forward(self, x):
        h = F.leaky_relu(self.bn1(self.conv1(x)))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.deconv3(h)))
        h = torch.tanh(self.deconv4(h))
        return h

class Discriminator(nn.Module):

    def __init__(self, in_ch):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, 3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        if in_ch == 1:
            self.fc4 = nn.Linear(1024, 1)
        else:
            self.fc4 = nn.Linear(2304, 1)

    def forward(self, x):
        h = F.leaky_relu(self.conv1(x))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.conv3(h)))
        h = torch.sigmoid(self.fc4(h.view(h.size(0), -1)))
        return h


def main(args):

    # Channel count follows the dataset: 1 for MNIST, 3 for CIFAR-10.
    in_ch = 1 if args.data == "mnist" else 3

    # Initialize GAN model
    G = Generator(in_ch=in_ch).cuda()
    D = Discriminator(in_ch=in_ch).cuda()

    # Initialize optimizers
    opt_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))
    loss_bce = nn.BCELoss()
    loss_mse = nn.MSELoss()
    cudnn.benchmark = True

    # Initialize DataLoader with (clean, adversarial) image pairs
    train_data = torch.load("./adv_data.tar")
    train_data = TensorDataset(train_data["normal"], train_data["adv"])
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    os.makedirs(args.checkpoint, exist_ok=True)

    # Start Training
    for i in range(args.epochs):
        G.train()
        gen_loss, dis_loss, n = 0, 0, 0
        for x, x_adv in train_loader:
            current_size = x.size(0)
            x, x_adv = x.cuda(), x_adv.cuda()

            # Train Discriminator
            t_real = torch.ones(current_size).cuda()
            t_fake = torch.zeros(current_size).cuda()
            y_real = D(x).squeeze()
            x_fake = G(x_adv)
            y_fake = D(x_fake).squeeze()

            loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake)
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # Train Generator (two updates per discriminator step)
            for _ in range(2):
                x_fake = G(x_adv)
                y_fake = D(x_fake).squeeze()

                loss_G = args.alpha * loss_mse(x_fake, x) + args.beta * loss_bce(y_fake, t_real)
                opt_G.zero_grad()
                loss_G.backward()
                opt_G.step()

            gen_loss += loss_G.item() * x.size(0)
            dis_loss += loss_D.item() * x.size(0)
            n += x.size(0)

        print("epoch:{}, LossG:{:.3f}, LossD:{:.3f}".format(i, gen_loss / n, dis_loss / n))
        torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()},
                   os.path.join(args.checkpoint, "{}.tar".format(i + 1)))

    G.eval()

def get_args():

    parser = argparse.ArgumentParser()

    parser.add_argument("--data", type=str, default="mnist")
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--alpha", type=float, default=0.7)
    parser.add_argument("--beta", type=float, default=0.3)
    parser.add_argument("--checkpoint", type=str, default="./checkpoint/test")
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = get_args()
    main(args)
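A hedged sketch of the inference side: once training has written a checkpoint, the generator alone can be reloaded to purify adversarial inputs. The epoch number in the path and the `x_adv` batch below are placeholders, and the sketch assumes the `Generator` class defined in the file above:

# Hedged sketch, not part of APE-GAN.py; paths and x_adv are placeholders.
import torch

ckpt = torch.load("./checkpoint/test/2.tar")   # checkpoint written by main()
G = Generator(in_ch=1).cuda()                  # in_ch=1 assumes MNIST-style inputs
G.load_state_dict(ckpt["generator"])
G.eval()

with torch.no_grad():
    x_purified = G(x_adv.cuda())               # x_adv: a batch of adversarial images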
deeprobust/image/preprocessing/prepare_advdata.py
ADDED
@@ -0,0 +1,62 @@
"""
This implementation is used to create an adversarial dataset.
"""
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, models, transforms
from PIL import Image

from deeprobust.image.attack.pgd import PGD
import deeprobust.image.netmodels.resnet as resnet
import deeprobust.image.netmodels.CNN as CNN
from deeprobust.image.config import attack_params
import matplotlib.pyplot as plt

def accuracy(output, target):
    """Number of correctly classified samples in the batch."""
    return output.max(1)[1].eq(target).sum().item()

def main(args):
    # Load model.
    model = resnet.ResNet18().to('cuda')
    print("Load network")

    model.load_state_dict(torch.load(
        os.path.expanduser("~/Documents/deeprobust_model/cifar_res18_120.pt")))
    model.eval()

    transform_val = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Both loaders use CIFAR-10 to match the ResNet18 victim model and the
    # PGD_CIFAR10 attack parameters (the original draft loaded MNIST here,
    # which looks like a copy-paste slip).
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('deeprobust/image/data', train=True, download=True,
                         transform=transform_val),
        batch_size=128,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('deeprobust/image/data', train=False, download=True,
                         transform=transform_val),
        batch_size=128, shuffle=True)  # , **kwargs)

    normal_data, adv_data = None, None
    train_acc, adv_acc, train_n = 0, 0, 0
    adversary = PGD(model)

    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        y_pred = model(x)
        train_acc += accuracy(y_pred, y)
        x_adv = adversary.generate(x, y, **attack_params['PGD_CIFAR10']).float()
        y_adv = model(x_adv)
        adv_acc += accuracy(y_adv, y)
        train_n += y.size(0)

        x, x_adv = x.data, x_adv.data
        if normal_data is None:
            normal_data, adv_data = x, x_adv
        else:
            normal_data = torch.cat((normal_data, x))
            adv_data = torch.cat((adv_data, x_adv))

    print("Accuracy(normal) {:.6f}, Accuracy(PGD) {:.6f}".format(
        train_acc / train_n * 100, adv_acc / train_n * 100))
    torch.save({"normal": normal_data, "adv": adv_data}, "data.tar")
    torch.save({"state_dict": model.state_dict()}, "cnn.tar")
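One wiring detail between this script and APE-GAN.py above: this file saves its (clean, adversarial) pairs to ``data.tar``, while APE-GAN.py loads ``./adv_data.tar``, so the file must be renamed (or one path changed) between the two steps. A small sketch of loading the saved pairs back, mirroring the APE-GAN loading code:

# Hedged sketch: reload the pairs written by main() above.
import torch
from torch.utils.data import TensorDataset, DataLoader

saved = torch.load("data.tar")                   # {"normal": ..., "adv": ...}
pairs = TensorDataset(saved["normal"], saved["adv"])
loader = DataLoader(pairs, batch_size=128, shuffle=True)
x_clean, x_adv = next(iter(loader))              # one (clean, adversarial) batch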
deeprobust/image/utils.py
ADDED
@@ -0,0 +1,211 @@
import os
import shutil
import sys
import time
import urllib.request

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from texttable import Texttable

def create_train_dataset(batch_size = 128, root = '../data'):
    """
    Create the training dataset (MNIST).
    """

    transform_train = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    return trainloader

def create_test_dataset(batch_size = 128, root = '../data'):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    return testloader

def download_model(url, file):
    print('Downloading from {} to {}'.format(url, file))
    try:
        urllib.request.urlretrieve(url, file)
    except Exception:
        raise Exception("Download failed! Make sure you have a stable Internet connection and entered the right name")

def save_checkpoint(now_epoch, net, optimizer, lr_scheduler, file_name):
    checkpoint = {'epoch': now_epoch,
                  'state_dict': net.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'lr_scheduler_state_dict': lr_scheduler.state_dict()}
    if os.path.exists(file_name):
        print('Overwriting {}'.format(file_name))
    torch.save(checkpoint, file_name)
    # link_name = os.path.join(*file_name.split(os.path.sep)[:-1], 'last.checkpoint')
    # #print(link_name)
    # make_symlink(source = file_name, link_name=link_name)

def load_checkpoint(file_name, net = None, optimizer = None, lr_scheduler = None):
    if os.path.isfile(file_name):
        print("=> loading checkpoint '{}'".format(file_name))
        check_point = torch.load(file_name)
        if net is not None:
            print('Loading network state dict')
            net.load_state_dict(check_point['state_dict'])
        if optimizer is not None:
            print('Loading optimizer state dict')
            optimizer.load_state_dict(check_point['optimizer_state_dict'])
        if lr_scheduler is not None:
            print('Loading lr_scheduler state dict')
            lr_scheduler.load_state_dict(check_point['lr_scheduler_state_dict'])

        return check_point['epoch']
    else:
        print("=> no checkpoint found at '{}'".format(file_name))

def make_symlink(source, link_name):
    """
    Note: overwriting enabled!
    """

    if os.path.exists(link_name):
        print("Link name already exists! Removing '{}' and overwriting".format(link_name))
        os.remove(link_name)
    if os.path.exists(source):
        os.symlink(source, link_name)
        return
    else:
        print('Source path does not exist')

def tab_printer(args):
    """
    Function to print the logs in a nice tabular format.

    Parameters
    ----------
    args :
        Parameters used for the model.
    """
    args = vars(args)
    keys = sorted(args.keys())
    t = Texttable()
    t.add_rows([["Parameter", "Value"]] + [[k.replace("_", " ").capitalize(), args[k]] for k in keys])
    print(t.draw())

def onehot_like(a, index, value=1):
    """Creates an array like a, with all values
    set to 0 except one.

    Parameters
    ----------
    a : array_like
        The returned one-hot array will have the same shape
        and dtype as this array
    index : int
        The index that should be set to `value`
    value : single value compatible with a.dtype
        The value to set at the given index

    Returns
    -------
    `numpy.ndarray`
        One-hot array with the given value at the given
        location and zeros everywhere else.
    """
    # TODO: change the note here.
    x = np.zeros_like(a)
    x[index] = value
    return x

def reduce_sum(x, keepdim=True):
    # sum over every dimension except the batch dimension,
    # keeping the reduced dimensions if keepdim is True
    for a in reversed(range(1, x.dim())):
        x = x.sum(a, keepdim=keepdim)
    return x

def arctanh(x, eps=1e-6):
    """
    Calculate arctanh(x). Note: scales ``x`` in place by (1 - eps)
    to keep the argument strictly inside (-1, 1).
    """
    x *= (1. - eps)
    return (np.log((1 + x) / (1 - x))) * 0.5

def l2r_dist(x, y, keepdim=True, eps=1e-8):
    d = (x - y)**2
    d = reduce_sum(d, keepdim=keepdim)
    d += eps  # to prevent infinite gradient at 0
    return d.sqrt()


def l2_dist(x, y, keepdim=True):
    d = (x - y)**2
    return reduce_sum(d, keepdim=keepdim)


def l1_dist(x, y, keepdim=True):
    d = torch.abs(x - y)
    return reduce_sum(d, keepdim=keepdim)


def l2_norm(x, keepdim=True):
    norm = reduce_sum(x*x, keepdim=keepdim)
    return norm.sqrt()


def l1_norm(x, keepdim=True):
    return reduce_sum(x.abs(), keepdim=keepdim)

def adjust_learning_rate(optimizer, epoch, learning_rate):
    """decrease the learning rate"""
    lr = learning_rate
    if epoch >= 55:
        lr = learning_rate * 0.1
    if epoch >= 75:
        lr = learning_rate * 0.01
    if epoch >= 90:
        lr = learning_rate * 0.001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

# State and helpers for progress_bar below. The bar follows the style of the
# widely used pytorch-cifar utilities; these constants and the format_time
# helper were referenced but not defined in the original file and are
# supplied here.
TOTAL_BAR_LENGTH = 65.
term_width = shutil.get_terminal_size().columns
last_time = time.time()
begin_time = last_time

def format_time(seconds):
    """Render a duration such as '1h2m3s' for progress_bar."""
    out = ''
    for unit, length in (('D', 86400), ('h', 3600), ('m', 60), ('s', 1)):
        value = int(seconds // length)
        seconds -= value * length
        if value > 0:
            out += '%d%s' % (value, unit)
    return out if out else '0s'

def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
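Two of the helpers above have easy-to-miss behavior: `onehot_like` copies shape and dtype from its template array, and `arctanh` shrinks its argument in place before taking the log, so callers should pass a copy. A quick numeric check of both (values are illustrative):

import numpy as np
from deeprobust.image.utils import onehot_like, arctanh

a = np.zeros(4)
print(onehot_like(a, 2, value=1.0))    # [0. 0. 1. 0.]

x = np.array([0.5])
y = arctanh(x.copy())                  # pass a copy: arctanh mutates its input
print(np.tanh(y))                      # ~[0.5], up to the 1e-6 shrink factor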
docs/graph/defense.rst
ADDED
@@ -0,0 +1,109 @@
Introduction to Graph Defense with Examples
===========================================
In this section, we introduce the graph defense methods provided
in DeepRobust.

.. contents::
    :local:

Test your model's robustness on poisoned graphs
------------------------------------------------
DeepRobust provides a series of defense methods that aim to enhance the robustness
of GNNs.

Victim Models:

- :class:`deeprobust.graph.defense.GCN`
- :class:`deeprobust.graph.defense.GAT`
- :class:`deeprobust.graph.defense.ChebNet`
- :class:`deeprobust.graph.defense.SGC`

Node Embedding Victim Models: (see more details `here <https://deeprobust.readthedocs.io/en/latest/graph/node_embedding.html>`_)

- :class:`deeprobust.graph.defense.DeepWalk`
- :class:`deeprobust.graph.defense.Node2Vec`

Defense Methods:

- :class:`deeprobust.graph.defense.GCNJaccard`
- :class:`deeprobust.graph.defense.GCNSVD`
- :class:`deeprobust.graph.defense.ProGNN`
- :class:`deeprobust.graph.defense.RGCN`
- :class:`deeprobust.graph.defense.SimPGCN`
- :class:`deeprobust.graph.defense.AdvTraining`

#. Load pre-attacked graph data

   .. code-block:: python

       from deeprobust.graph.data import Dataset, PrePtbDataset
       # load the prognn splits by using setting='prognn'
       # because the attacked graphs are generated under prognn splits
       data = Dataset(root='/tmp/', name='cora', setting='prognn')

       adj, features, labels = data.adj, data.features, data.labels
       idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
       # Load meta attacked data
       perturbed_data = PrePtbDataset(root='/tmp/',
                                      name='cora',
                                      attack_method='meta',
                                      ptb_rate=0.05)
       perturbed_adj = perturbed_data.adj

#. You can also choose to load graphs attacked by nettack. See details `here <https://deeprobust.readthedocs.io/en/latest/graph/data.html#attacked-graphs-for-node-classification>`_

   .. code-block:: python

       # Load nettack attacked data
       perturbed_data = PrePtbDataset(root='/tmp/', name='cora',
                                      attack_method='nettack',
                                      ptb_rate=3.0) # here ptb_rate means the number of perturbations per node
       perturbed_adj = perturbed_data.adj
       idx_test = perturbed_data.target_nodes

#. Train a victim model (GCN) on the clean/poisoned graph

   .. code-block:: python

       from deeprobust.graph.defense import GCN
       gcn = GCN(nfeat=features.shape[1],
                 nhid=16,
                 nclass=labels.max().item() + 1,
                 dropout=0.5, device='cpu')
       gcn = gcn.to('cpu')
       gcn.fit(features, adj, labels, idx_train, idx_val) # train on clean graph with earlystopping
       gcn.test(idx_test)

       gcn.fit(features, perturbed_adj, labels, idx_train, idx_val) # train on poisoned graph
       gcn.test(idx_test)

#. Train defense models (GCN-Jaccard, RGCN, ProGNN) on the poisoned graph; a GCNSVD sketch follows the two blocks below.

   .. code-block:: python

       from deeprobust.graph.defense import GCNJaccard
       model = GCNJaccard(nfeat=features.shape[1],
                          nhid=16,
                          nclass=labels.max().item() + 1,
                          dropout=0.5, device='cpu').to('cpu')
       model.fit(features, perturbed_adj, labels, idx_train, idx_val, threshold=0.03)
       model.test(idx_test)

   .. code-block:: python

       from deeprobust.graph.defense import RGCN
       model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1],
                    nclass=labels.max()+1, nhid=32, device='cpu')
       model.fit(features, perturbed_adj, labels, idx_train, idx_val,
                 train_iters=200, verbose=True)
       model.test(idx_test)
|
100 |
+
|
101 |
+
|
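:class:`deeprobust.graph.defense.GCNSVD` follows the same pattern: it
preprocesses the perturbed adjacency matrix with a low-rank (truncated SVD)
approximation before training a GCN. A minimal sketch; the rank argument
:obj:`k` here is an assumption for illustration:

.. code-block:: python

   from deeprobust.graph.defense import GCNSVD
   model = GCNSVD(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device='cpu').to('cpu')
   # k is the rank of the truncated SVD used to clean the adjacency matrix
   model.fit(features, perturbed_adj, labels, idx_train, idx_val, k=15)
   model.test(idx_test)
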
For details on training ProGNN, please refer to `this page <https://github.com/ChandlerBang/Pro-GNN/blob/master/train.py>`_.

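As a rough sketch of how ProGNN is typically driven (hedged: the
hyperparameter fields below mirror the linked training script and are
illustrative, not authoritative):

.. code-block:: python

   from types import SimpleNamespace
   from deeprobust.graph.defense import ProGNN

   # illustrative hyperparameters; see the linked train.py for real values
   args = SimpleNamespace(debug=False, only_gcn=False, lr=0.01,
                          weight_decay=5e-4, alpha=5e-4, beta=1.5, gamma=1,
                          lambda_=0, phi=0, epochs=200, lr_adj=0.01,
                          symmetric=False, inner_steps=2, outer_steps=1)
   # ProGNN wraps a backbone GNN and jointly learns a cleaned graph structure
   prognn = ProGNN(gcn, args, device='cpu')
   prognn.fit(features, perturbed_adj, labels, idx_train, idx_val)
   prognn.test(features, labels, idx_test)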

More Examples
-----------------------
More examples can be found in :class:`deeprobust.graph.defense`. You can also find examples in
`github code examples <https://github.com/DSE-MSU/DeepRobust/tree/master/examples/graph>`_
and more details in the `defense table <https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph#defense-methods>`_.
docs/graph/node_embedding.rst
ADDED
@@ -0,0 +1,110 @@
Node Embedding Attack and Defense
==================================
In this section, we introduce the node embedding attack algorithms and
corresponding victim models provided in DeepRobust.

.. contents::
    :local:


Node Embedding Attack
-----------------------
Node embedding attacks aim to fool node embedding models into producing
low-quality embeddings. Specifically, DeepRobust provides the following
node embedding attack algorithms:

- :class:`deeprobust.graph.global_attack.NodeEmbeddingAttack`
- :class:`deeprobust.graph.global_attack.OtherNodeEmbeddingAttack`

They take only the adjacency matrix as input, in the format of
:obj:`scipy.sparse.csr_matrix`. You can specify :obj:`attack_type` to either
add or remove edges. Let's take a look at an example; the perturbation
budget itself is controlled by :obj:`n_perturbations`, as shown in the
sketch after this block:

.. code-block:: python

    from deeprobust.graph.data import Dataset
    from deeprobust.graph.global_attack import NodeEmbeddingAttack
    data = Dataset(root='/tmp/', name='cora_ml', seed=15)
    adj, features, labels = data.adj, data.features, data.labels
    model = NodeEmbeddingAttack()
    model.attack(adj, attack_type="remove")
    modified_adj = model.modified_adj
    model.attack(adj, attack_type="remove", min_span_tree=True)
    modified_adj = model.modified_adj
    model.attack(adj, attack_type="add", n_candidates=10000)
    modified_adj = model.modified_adj
    model.attack(adj, attack_type="add_by_remove", n_candidates=10000)
    modified_adj = model.modified_adj

+
The :obj:`OtherNodeEmbeddingAttack` contains the baseline methods reported in the paper
|
39 |
+
Adversarial Attacks on Node Embeddings via Graph Poisoning. Aleksandar Bojchevski and
|
40 |
+
Stephan Günnemann, ICML 2019. We can specify the type (chosen from
|
41 |
+
`["degree", "eigencentrality", "random"]`) to generate corresponding attacks.
|
42 |
+
|
43 |
+
.. code-block:: python
|
44 |
+
|
45 |
+
from deeprobust.graph.data import Dataset
|
46 |
+
from deeprobust.graph.global_attack import OtherNodeEmbeddingAttack
|
47 |
+
data = Dataset(root='/tmp/', name='cora_ml', seed=15)
|
48 |
+
adj, features, labels = data.adj, data.features, data.labels
|
49 |
+
model = OtherNodeEmbeddingAttack(type='degree')
|
50 |
+
model.attack(adj, attack_type="remove")
|
51 |
+
modified_adj = model.modified_adj
|
52 |
+
#
|
53 |
+
model = OtherNodeEmbeddingAttack(type='eigencentrality')
|
54 |
+
model.attack(adj, attack_type="remove")
|
55 |
+
modified_adj = model.modified_adj
|
56 |
+
#
|
57 |
+
model = OtherNodeEmbeddingAttack(type='random')
|
58 |
+
model.attack(adj, attack_type="add", n_candidates=10000)
|
59 |
+
modified_adj = model.modified_adj
|
60 |
+
|
61 |
+
Node Embedding Victim Models
|
62 |
+
-----------------------
|
63 |
+
DeepRobust provides two node embedding victim models, DeepWalk and Node2Vec:
|
64 |
+
|
65 |
+
- :class:`deeprobust.graph.defense.DeepWalk`
|
66 |
+
- :class:`deeprobust.graph.defense.Node2Vec`
|
67 |
+
|
68 |
+
There are three major functions in the two classes: :obj:`fit()`, :obj:`evaluate_node_classification()`
|
69 |
+
and :obj:`evaluate_link_prediction`. The function :obj:`fit()` will train the node embdding models
|
70 |
+
and store the embedding in :obj:`self.embedding`. For example,
|
71 |
+
|
72 |
+
.. code-block:: python
|
73 |
+
|
74 |
+
from deeprobust.graph.data import Dataset
|
75 |
+
from deeprobust.graph.defense import DeepWalk
|
76 |
+
from deeprobust.graph.global_attack import NodeEmbeddingAttack
|
77 |
+
import numpy as np
|
78 |
+
|
79 |
+
dataset_str = 'cora_ml'
|
80 |
+
data = Dataset(root='/tmp/', name=dataset_str, seed=15)
|
81 |
+
adj, features, labels = data.adj, data.features, data.labels
|
82 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
83 |
+
|
84 |
+
print("Test DeepWalk on clean graph")
|
85 |
+
model = DeepWalk(type="skipgram")
|
86 |
+
model.fit(adj)
|
87 |
+
print(model.embedding)
|
88 |
+
|
89 |
+
After we trained the model, we can then test its performance on node classification and link prediction:
|
90 |
+
|
91 |
+
.. code-block:: python
|
92 |
+
|
93 |
+
print("Test DeepWalk on node classification...")
|
94 |
+
# model.evaluate_node_classification(labels, idx_train, idx_test, lr_params={"max_iter": 1000})
|
95 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
96 |
+
print("Test DeepWalk on link prediciton...")
|
97 |
+
model.evaluate_link_prediction(adj, np.array(adj.nonzero()).T)
|
98 |
+
|
99 |
+
We can then test its performance on the attacked graph:
|
100 |
+
|
101 |
+
.. code-block:: python
|
102 |
+
|
103 |
+
# set up the attack model
|
104 |
+
attacker = NodeEmbeddingAttack()
|
105 |
+
attacker.attack(adj, attack_type="remove", n_perturbations=1000)
|
106 |
+
modified_adj = attacker.modified_adj
|
107 |
+
print("Test DeepWalk on attacked graph")
|
108 |
+
model.fit(modified_adj)
|
109 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
110 |
+
|
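
Node2Vec can be evaluated in the same fashion. A minimal sketch, assuming
its interface mirrors DeepWalk's :obj:`fit()` and
:obj:`evaluate_node_classification()`:

.. code-block:: python

    from deeprobust.graph.defense import Node2Vec

    print("Test Node2Vec on clean graph")
    model = Node2Vec()
    model.fit(adj)
    model.evaluate_node_classification(labels, idx_train, idx_test)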