
Commit 2d0e7ae

[GNN] Reference implementation for GNN node classification (#700)
* Reference implementation for GNN node classification
* Added guidelines for building the Docker image using the provided Dockerfile; refine the evaluation frequency and thoroughness
* minors
* Add code of reference implementation
* Set the default validation frac to 0.005
* Rename the folder and add contributors in readme
* minor
* minors
* Round up epoch_num, add GRADIENT_ACCUMULATION_STEPS and OPT_NAME into log outputs (Committed-by: LiSu from Dev container)
* minor (Committed-by: LiSu from Dev container)
1 parent 68f8f38 commit 2d0e7ae

14 files changed, +1969 -0 lines changed

graph_neural_network/Dockerfile

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel

WORKDIR /workspace/repository

RUN pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
RUN pip install scikit-learn==0.24.2
RUN pip install torch_geometric==2.4.0
RUN pip install --no-index torch_scatter==2.1.1 torch_sparse==0.6.17 -f https://data.pyg.org/whl/torch-1.13.0+cu117.html
RUN pip install graphlearn-torch==0.2.2

RUN apt update
RUN apt install -y git
RUN pip install git+https://github.com/mlcommons/logging.git

# TF32 instead of FP32 for faster compute
ENV NVIDIA_TF32_OVERRIDE=1

RUN git clone https://github.com/alibaba/graphlearn-for-pytorch.git
WORKDIR /workspace/repository/graphlearn-for-pytorch/examples/igbh

graph_neural_network/README.md

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# 1. Problem
This benchmark represents a multi-class node classification task in a heterogeneous graph using the [IGB Heterogeneous Dataset](https://github.com/IllinoisGraphBenchmark/IGB-Datasets) named IGBH-Full. The task is carried out using a [GAT](https://arxiv.org/abs/1710.10903) model based on the [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811) paper.

The reference implementation is based on [graphlearn-for-pytorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch).

# 2. Directions
### Steps to configure machine

#### 1. Clone the repository:
```bash
git clone https://github.com/alibaba/graphlearn-for-pytorch
```

or
```bash
git clone https://github.com/mlcommons/training.git
```
once `GNN node classification` is merged into `mlcommons/training`.

#### 2. Build the docker image:

If you cloned the `graphlearn-for-pytorch` repository:
```bash
cd graphlearn-for-pytorch/examples/igbh/
docker build -f Dockerfile -t training_gnn:latest .
```

If you cloned the `mlcommons/training` repository:
```bash
cd training/gnn_node_classification/
docker build -f Dockerfile -t training_gnn:latest .
```

### Steps to download and verify data
Download the dataset:
```bash
bash download_igbh_full.sh
```

Before training, generate the seeds for training and validation:
```bash
python split_seeds.py --dataset_size='full'
```

The size of the `IGBH-Full` dataset is 2.2 TB. If you want to test with
the `tiny`, `small` or `medium` datasets, the download procedure is included
in the training script.

### Steps to run and time

#### Single-node Training

The original graph is in the `COO` format and the features are stored in FP32. The training script will transform the graph from `COO` to `CSC` and convert the features to FP16, which can be time-consuming due to the graph scale. We provide a script to convert the graph layout from `COO` to `CSC` and persist the features in FP16 format:

```bash
python compress_graph.py --dataset_size='full' --layout='CSC' --use_fp16
```

To train the model using multiple GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='full' --layout='CSC' --use_fp16
```
The number of training processes is equal to the number of GPUs. The `--pin_feature` option decides whether the feature data will be pinned in host memory, which enables zero-copy feature access from the GPU but incurs extra memory costs.

#### Distributed Training

##### 1. Data Partitioning
To partition the dataset (including both the topology and features):
```bash
python partition.py --dataset_size='full' --num_partitions=2 --use_fp16 --layout='CSC'
```
The above script partitions the dataset into two parts, converts the features into
the FP16 format, and transforms the graph layout from `COO` to `CSC`.

We suggest storing the partitioned data on a distributed file system, such as HDFS or NFS, such that it can be accessed by all training nodes.

##### 2. Two-stage Data Partitioning
To speed up the partitioning process, GLT also supports two-stage partitioning, which splits topology partitioning from feature partitioning. After the topology partitioning is executed on a single node, the feature partitioning can be conducted on each training node in parallel.

The topology partitioning is conducted by executing:
```bash
python partition.py --dataset_size='full' --num_partitions=2 --layout='CSC' --with_feature=0
```

The feature partitioning is conducted on each training node:
```bash
# node 0 which holds partition 0:
python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=0

# node 1 which holds partition 1:
python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=1
```

##### 3. Model Training
The number of partitions and the number of training nodes must be the same. On each training node, the model can be trained using the following command:

```bash
# node 0:
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC'

# node 1:
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC'
```
The above commands assume that each training node is equipped with 2 GPUs and that the number of training processes equals the number of GPUs. Each training process has a corresponding sampling process using the same GPU.

The `master_address_ip` should be replaced with the actual IP address of the master node. The `--pin_feature` option decides whether the feature data will be pinned in host memory, which enables zero-copy feature access from the GPU but incurs extra memory costs.

We recommend separating the sampling and training processes onto different GPUs to achieve better performance. To separate the GPUs used by the sampling and training processes, please add `--split_training_sampling` and set `--num_training_procs` to half the number of devices:

```bash
# node 0:
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling

# node 1:
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling
```
The above commands use one GPU for training and another for sampling on each node.

# 3. Dataset/Environment
### Publication/Attribution
Arpandeep Khatua, Vikram Sharma Mailthody, Bhagyashree Taleka, Tengfei Ma, Xiang Song, and Wen-mei Hwu. 2023. IGB: Addressing The Gaps In Labeling, Features, Heterogeneity, and Size of Public Graph Datasets for Deep Learning Research. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '23). Association for Computing Machinery, New York, NY, USA, 4284–4295. https://doi.org/10.1145/3580305.3599843

### Data preprocessing
The original graph is in the `COO` format and the features are stored in FP32. The graph may be transformed from `COO` to `CSC` and the features converted to FP16 (both are supported by the training script).

### Training and test data separation
The training and validation data are selected from the labeled `paper` nodes of the dataset and are generated by `split_seeds.py`. Different random seeds will result in different training and test data.

### Training data order
Randomly.

### Test data order
Randomly.

# 4. Model
### Publication/Attribution
Dan Busbridge, Dane Sherburn, Pietro Cavallo, and Nils Y. Hammerla. Relational Graph Attention Networks, 2019. https://arxiv.org/abs/1904.05811

### List of layers
Three-layer RGAT model

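For orientation, below is a minimal sketch (not the reference implementation) of what a three-layer heterogeneous GAT stack can look like when built with PyTorch Geometric's `GATConv` and converted with `to_hetero()`. The hidden size, head count, and class count are illustrative assumptions only.

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, to_hetero

class ThreeLayerGAT(torch.nn.Module):
  def __init__(self, hidden_channels=512, num_classes=2983, heads=4):
    super().__init__()
    # (-1, -1) lets PyG lazily infer per-node-type input dimensions.
    self.conv1 = GATConv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
    self.conv2 = GATConv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
    self.conv3 = GATConv((-1, -1), num_classes, heads=1, add_self_loops=False)

  def forward(self, x, edge_index):
    x = F.relu(self.conv1(x, edge_index))
    x = F.relu(self.conv2(x, edge_index))
    return self.conv3(x, edge_index)

# metadata = (node_types, edge_types) of the heterogeneous IGBH graph
# model = to_hetero(ThreeLayerGAT(), metadata, aggr='sum')
```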

### Loss function
CrossEntropyLoss

### Optimizer
Adam

# 5. Quality
### Quality metric
The validation accuracy is the target quality metric.
### Quality target
0.72
### Evaluation frequency
4,730,280 training seeds (5% of the entire training seeds, evaluated every 0.05 epoch)
### Evaluation thoroughness
788,380 validation seeds

# 6. Contributors
This benchmark is a collaborative effort with contributions from Alibaba, Intel, and Nvidia:

- Alibaba: Li Su, Baole Ai, Wenting Shen, Shuxian Hu, Wenyuan Yu, Yong Li
- Nvidia: Yunzhou (David) Liu, Kyle Kranen, Shriya Palasamudram
- Intel: Kaixuan Liu, Hesham Mostafa, Sasikanth Avancha, Keith Achorn, Radha Giduthuri, Deepak Canchi
graph_neural_network/build_partition_feature.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import argparse
import os.path as osp

import graphlearn_torch as glt
import torch

from dataset import IGBHeteroDataset


def partition_feature(src_path: str,
                      dst_path: str,
                      partition_idx: int,
                      chunk_size: int,
                      dataset_size: str='tiny',
                      in_memory: bool=True,
                      use_fp16: bool=False):
  print(f'-- Loading igbh_{dataset_size} ...')
  data = IGBHeteroDataset(src_path, dataset_size, in_memory, with_edges=False, use_fp16=use_fp16)

  print(f'-- Build feature for partition {partition_idx} ...')
  dst_path = osp.join(dst_path, f'{dataset_size}-partitions')
  node_feat_dtype = torch.float16 if use_fp16 else torch.float32
  glt.partition.base.build_partition_feature(root_dir = dst_path,
                                             partition_idx = partition_idx,
                                             chunk_size = chunk_size,
                                             node_feat = data.feat_dict,
                                             node_feat_dtype = node_feat_dtype)


if __name__ == '__main__':
  root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
  glt.utils.ensure_dir(root)
  parser = argparse.ArgumentParser(description="Arguments for partitioning the IGBH dataset.")
  parser.add_argument('--src_path', type=str, default=root,
      help='path containing the datasets')
  parser.add_argument('--dst_path', type=str, default=root,
      help='path containing the partitioned datasets')
  parser.add_argument('--dataset_size', type=str, default='full',
      choices=['tiny', 'small', 'medium', 'large', 'full'],
      help='size of the datasets')
  parser.add_argument('--in_memory', type=int, default=0,
      choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory')
  parser.add_argument("--partition_idx", type=int, default=0,
      help="Index of a partition")
  parser.add_argument("--chunk_size", type=int, default=10000,
      help="Chunk size for feature partitioning.")
  parser.add_argument("--use_fp16", action="store_true",
      help="save node/edge feature using fp16 format")

  args = parser.parse_args()

  partition_feature(
    args.src_path,
    args.dst_path,
    partition_idx=args.partition_idx,
    chunk_size=args.chunk_size,
    dataset_size=args.dataset_size,
    in_memory=args.in_memory==1,
    use_fp16=args.use_fp16
  )
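The `partition_feature()` helper above can also be driven programmatically. The sketch below is illustrative only, with hypothetical paths, and mirrors the per-node CLI invocations from the README by building both partitions' features on one machine:

```python
# Hypothetical sketch: build FP16 features for both partitions on a single host.
# The paths are placeholders; point them at the directory used by partition.py.
from build_partition_feature import partition_feature

SRC = '/data/igbh'   # directory containing the downloaded IGBH data
DST = '/data/igbh'   # directory containing the '<size>-partitions' output

for idx in range(2):  # one call per partition / training node
  partition_feature(SRC, DST,
                    partition_idx=idx,
                    chunk_size=10000,
                    dataset_size='full',
                    in_memory=False,
                    use_fp16=True)
```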
graph_neural_network/compress_graph.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import argparse, datetime, os
import numpy as np
import torch
import os.path as osp

import graphlearn_torch as glt

from dataset import float2half
from download import download_dataset
from torch_geometric.utils import add_self_loops, remove_self_loops
from typing import Literal


class IGBHeteroDatasetCompress(object):
  def __init__(self,
               path,
               dataset_size,
               layout: Literal['CSC', 'CSR'] = 'CSC',):
    self.dir = path
    self.dataset_size = dataset_size
    self.layout = layout

    self.ntypes = ['paper', 'author', 'institute', 'fos']
    self.etypes = None
    self.edge_dict = {}
    self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174}
    self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883}
    if not osp.exists(osp.join(path, self.dataset_size, 'processed')):
      download_dataset(path, 'heterogeneous', dataset_size)
    self.process()

  def process(self):
    paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'paper__cites__paper', 'edge_index.npy'))).t()
    author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'paper__written_by__author', 'edge_index.npy'))).t()
    affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'author__affiliated_to__institute', 'edge_index.npy'))).t()
    paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'paper__topic__fos', 'edge_index.npy'))).t()
    if self.dataset_size in ['large', 'full']:
      paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
          'paper__published__journal', 'edge_index.npy'))).t()
      paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
          'paper__venue__conference', 'edge_index.npy'))).t()

    cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0]
    self.edge_dict = {
        ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])),
        ('paper', 'written_by', 'author'): author_paper_edges,
        ('author', 'affiliated_to', 'institute'): affiliation_author_edges,
        ('paper', 'topic', 'fos'): paper_fos_edges,
        ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]),
        ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]),
        ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :])
    }
    if self.dataset_size in ['large', 'full']:
      self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal
      self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference
      self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :])
      self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :])
    self.etypes = list(self.edge_dict.keys())

    # init graphlearn_torch Dataset.
    edge_dir = 'out' if self.layout == 'CSR' else 'in'
    glt_dataset = glt.data.Dataset(edge_dir=edge_dir)
    glt_dataset.init_graph(
      edge_index=self.edge_dict,
      graph_mode='CPU',
    )

    # save the corresponding csr or csc file
    compress_edge_dict = {}
    compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper'
    compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author'
    compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute'
    compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos'
    compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper'
    compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author'
    compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper'
    compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal'
    compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference'
    compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper'
    compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper'

    for etype in self.etypes:
      graph = glt_dataset.get_graph(etype)
      indptr, indices, _ = graph.export_topology()
      path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype])
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(indptr, os.path.join(path, 'indptr.pt'))
      torch.save(indices, os.path.join(path, 'indices.pt'))
    path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout)
    print(f"The {self.layout} graph has been persisted in path: {path}")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
  glt.utils.ensure_dir(root)
  parser.add_argument('--path', type=str, default=root,
      help='path containing the datasets')
  parser.add_argument('--dataset_size', type=str, default='full',
      choices=['tiny', 'small', 'medium', 'large', 'full'],
      help='size of the datasets')
  parser.add_argument("--layout", type=str, default='CSC')
  parser.add_argument('--use_fp16', action="store_true",
      help="convert the node/edge feature into fp16 format")
  args = parser.parse_args()
  print(f"Start constructing the {args.layout} graph...")
  igbh_dataset = IGBHeteroDatasetCompress(args.path, args.dataset_size, args.layout)
  if args.use_fp16:
    base_path = osp.join(args.path, args.dataset_size, 'processed')
    float2half(base_path, args.dataset_size)
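To illustrate what `compress_graph.py` leaves on disk, here is a small sketch (the paths are assumptions for a typical run) that loads the persisted `CSC` topology of one edge type back with `torch.load`:

```python
import os.path as osp
import torch

# Assumed locations; substitute the --path/--dataset_size/--layout values you used.
root = '/data/igbh'
edge_dir = osp.join(root, 'full', 'processed', 'CSC', 'paper__cites__paper')

# compress_graph.py saves one indptr.pt / indices.pt pair per edge type.
indptr = torch.load(osp.join(edge_dir, 'indptr.pt'))
indices = torch.load(osp.join(edge_dir, 'indices.pt'))
print(f'paper->cites->paper: {indptr.numel() - 1} destination nodes, {indices.numel()} edges')
```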
