Split Learning for Graph Neural Network#

The following code is a demo only. It is NOT for production due to system security concerns; please DO NOT use it directly in production.

Setup#

Create two participants: alice and bob.

[1]:
import secretflow as sf

# In case you have a running SecretFlow runtime already.
sf.shutdown()

sf.init(parties=['alice', 'bob'])

alice, bob = sf.PYU('alice'), sf.PYU('bob')

Prepare the Dataset#

The cora dataset#

The cora dataset has two tab-separated files: cora.cites and cora.content.

  • The cora.cites file contains the citation records with two columns: cited_paper_id (target) and citing_paper_id (source).

  • The cora.content file contains the paper content records with 1,435 columns: paper_id, subject, and 1,433 binary features (a quick inspection sketch follows this list).
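
If you want to inspect the raw files yourself, a minimal pandas sketch could look like this (assuming local copies of cora.cites and cora.content; the rest of this tutorial uses SecretFlow's built-in, already-partitioned copy instead):

[ ]:
import pandas as pd

# cora.cites: one citation record per row, tab-separated.
cites = pd.read_csv(
    'cora.cites', sep='\t', header=None,
    names=['cited_paper_id', 'citing_paper_id'],
)
# cora.content: paper_id, 1,433 binary word features, and the subject label.
content = pd.read_csv('cora.content', sep='\t', header=None)
print(cites.shape, content.shape)  # roughly (5429, 2) and (2708, 1435)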

Let us use the partitioned cora dataset, which is already a built-in dataset of SecretFlow.

  • The train set includes 140 cited_paper_ids.

  • The test set includes 1000 cited_paper_ids.

  • The valid set includes 500 cited_paper_ids.

Split the dataset#

Let us split the dataset for split learning.

  • Alice holds features 1~716, and Bob holds the rest.

  • Alice holds all the labels.

  • Both Alice and Bob hold all the edges.

[2]:
import networkx as nx
import numpy as np
import os
import pickle
import scipy
import zipfile
import tempfile
from pathlib import Path
from secretflow.utils.simulation.datasets import dataset
from secretflow.data.ndarray import load


def load_cora():
    dataset_zip = dataset('cora')
    extract_path = str(Path(dataset_zip).parent)
    with zipfile.ZipFile(dataset_zip, 'r') as zip_f:
        zip_f.extractall(extract_path)

    file_names = [
        os.path.join(extract_path, f'ind.cora.{name}')
        for name in ['y', 'tx', 'ty', 'allx', 'ally', 'graph']
    ]

    objects = []
    for name in file_names:
        with open(name, 'rb') as f:
            objects.append(pickle.load(f, encoding='latin1'))

    y, tx, ty, allx, ally, graph = tuple(objects)

    with open(os.path.join(extract_path, f"ind.cora.test.index"), 'r') as f:
        test_idx_reorder = f.readlines()
    test_idx_reorder = list(map(lambda s: int(s.strip()), test_idx_reorder))
    test_idx_range = np.sort(test_idx_reorder)

    nodes = scipy.sparse.vstack((allx, tx)).tolil()
    nodes[test_idx_reorder, :] = nodes[test_idx_range, :]
    edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    edge = edge.toarray() + np.eye(edge.shape[1])

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    def sample_mask(idx, length):
        mask = np.zeros(length)
        mask[idx] = 1
        return np.array(mask, dtype=bool)

    idx_train = sample_mask(idx_train, labels.shape[0])
    idx_val = sample_mask(idx_val, labels.shape[0])
    idx_test = sample_mask(idx_test, labels.shape[0])

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[idx_train, :] = labels[idx_train, :]
    y_val[idx_val, :] = labels[idx_val, :]
    y_test[idx_test, :] = labels[idx_test, :]

    nodes = nodes.toarray()
    # Split the features vertically: alice gets the first half, bob gets the rest.
    features_split_pos = round(nodes.shape[1] / 2)
    nodes_alice, nodes_bob = nodes[:, :features_split_pos], nodes[:, features_split_pos:]
    temp_dir = tempfile.mkdtemp()
    saved_files = [
        os.path.join(temp_dir, name)
        for name in [
            'edge.npy',
            'x_alice.npy',
            'x_bob.npy',
            'y_train.npy',
            'y_val.npy',
            'y_test.npy',
            'idx_train.npy',
            'idx_val.npy',
            'idx_test.npy',
        ]
    ]
    np.save(saved_files[0], edge)
    np.save(saved_files[1], nodes_alice)
    np.save(saved_files[2], nodes_bob)
    np.save(saved_files[3], y_train)
    np.save(saved_files[4], y_val)
    np.save(saved_files[5], y_test)
    np.save(saved_files[6], idx_train)
    np.save(saved_files[7], idx_val)
    np.save(saved_files[8], idx_test)
    return saved_files

saved_files = load_cora()

edge = load({alice: saved_files[0], bob: saved_files[0]})
features = load({alice: saved_files[1], bob: saved_files[2]})
Y_train = load({alice: saved_files[3]})
Y_val = load({alice: saved_files[4]})
Y_test = load({alice: saved_files[5]})
idx_train = load({alice: saved_files[6]})
idx_val = load({alice: saved_files[7]})
idx_test = load({alice: saved_files[8]})

/tmp/ipykernel_754102/2692670243.py:27: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.
  objects.append(pickle.load(f, encoding='latin1'))
/tmp/ipykernel_754102/2692670243.py:38: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.
  edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

By the way, since cora is a built-in dataset of SecretFlow, you can just run the following snippet to replace the code above.

[3]:
from secretflow.utils.simulation.datasets import load_cora

(edge, features, Y_train, Y_val, Y_test, idx_train, idx_val, idx_test) = load_cora(
    [alice, bob]
)
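
Either way, you can sanity-check the federated arrays; each party only holds its own partition. A minimal sketch (the sizes shown in the comments are what the half/half feature split above produces and are only indicative):

[ ]:
# Optional sanity check: inspect per-party partition shapes of the FedNdarrays.
print(features.partition_shape())  # e.g. {alice: (2708, 716), bob: (2708, 717)}
print(edge.partition_shape())  # both parties hold the full (2708, 2708) adjacency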

Build a Graph Neural Network Model#

Implement a graph convolution layer#

[4]:
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras import backend as K
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import Dropout, Layer, LeakyReLU


class GraphAttention(Layer):
    def __init__(
        self,
        F_,
        attn_heads=1,
        attn_heads_reduction='average',  # {'concat', 'average'}
        dropout_rate=0.5,
        activation='relu',
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        attn_kernel_initializer='glorot_uniform',
        kernel_regularizer=None,
        bias_regularizer=None,
        attn_kernel_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        attn_kernel_constraint=None,
        **kwargs,
    ):
        if attn_heads_reduction not in {'concat', 'average'}:
            raise ValueError('Possible reduction methods: concat, average')

        self.F_ = F_  # Number of output features (F' in the paper)
        self.attn_heads = attn_heads  # Number of attention heads (K in the paper)
        self.attn_heads_reduction = attn_heads_reduction
        self.dropout_rate = dropout_rate  # Internal dropout rate
        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)
        self.supports_masking = False

        # Populated by build()
        self.kernels = []  # Layer kernels for attention heads
        self.biases = []  # Layer biases for attention heads
        self.attn_kernels = []  # Attention kernels for attention heads

        if attn_heads_reduction == 'concat':
            # Output will have shape (..., K * F')
            self.output_dim = self.F_ * self.attn_heads
        else:
            # Output will have shape (..., F')
            self.output_dim = self.F_

        super(GraphAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) >= 2
        F = input_shape[0][-1]

        # Initialize weights for each attention head
        for head in range(self.attn_heads):
            # Layer kernel
            kernel = self.add_weight(
                shape=(F, self.F_),
                initializer=self.kernel_initializer,
                regularizer=self.kernel_regularizer,
                constraint=self.kernel_constraint,
                name='kernel_{}'.format(head),
            )
            self.kernels.append(kernel)

            # Layer bias
            if self.use_bias:
                bias = self.add_weight(
                    shape=(self.F_,),
                    initializer=self.bias_initializer,
                    regularizer=self.bias_regularizer,
                    constraint=self.bias_constraint,
                    name='bias_{}'.format(head),
                )
                self.biases.append(bias)

            # Attention kernels
            attn_kernel_self = self.add_weight(
                shape=(self.F_, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name='attn_kernel_self_{}'.format(head),
            )
            attn_kernel_neighs = self.add_weight(
                shape=(self.F_, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name='attn_kernel_neigh_{}'.format(head),
            )
            self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])
        self.built = True

    def call(self, inputs):
        X = inputs[0]  # Node features (N x F)
        A = inputs[1]  # Adjacency matrix (N x N)

        outputs = []
        for head in range(self.attn_heads):
            kernel = self.kernels[head]  # W in the paper (F x F')
            attention_kernel = self.attn_kernels[
                head
            ]  # Attention kernel a in the paper (2F' x 1)

            # Compute inputs to attention network
            features = K.dot(X, kernel)  # (N x F')

            # Compute feature combinations
            # Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_2]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]
            attn_for_self = K.dot(
                features, attention_kernel[0]
            )  # (N x 1), [a_1]^T [Wh_i]
            attn_for_neighs = K.dot(
                features, attention_kernel[1]
            )  # (N x 1), [a_2]^T [Wh_j]

            # Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]
            dense = attn_for_self + K.transpose(
                attn_for_neighs
            )  # (N x N) via broadcasting

            # Add nonlinearity
            dense = LeakyReLU(alpha=0.2)(dense)

            # Mask values before activation (Vaswani et al., 2017)
            mask = -10e9 * (1.0 - A)
            dense += mask

            # Apply softmax to get attention coefficients
            dense = K.softmax(dense)  # (N x N)

            # Apply dropout to features and attention coefficients
            dropout_attn = Dropout(self.dropout_rate)(dense)  # (N x N)
            dropout_feat = Dropout(self.dropout_rate)(features)  # (N x F')

            # Linear combination with neighbors' features
            node_features = K.dot(dropout_attn, dropout_feat)  # (N x F')

            if self.use_bias:
                node_features = K.bias_add(node_features, self.biases[head])

            # Add output of attention head to final output
            outputs.append(node_features)

        # Aggregate the heads' output according to the reduction method
        if self.attn_heads_reduction == 'concat':
            output = K.concatenate(outputs)  # (N x KF')
        else:
            output = K.mean(K.stack(outputs), axis=0)  # (N x F')

        output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        output_shape = input_shape[0][0], self.output_dim
        return output_shape

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'attn_heads': self.attn_heads,
                'attn_heads_reduction': self.attn_heads_reduction,
                'F_': self.F_,
            }
        )
        return config
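
To make the expected shapes concrete, the layer can be exercised on toy data (a minimal standalone sketch; it is not part of the split learning pipeline and the sizes are made up):

[ ]:
import numpy as np
import tensorflow as tf

# Toy inputs: 5 nodes with 8 features each, and a self-loop-only adjacency matrix.
x = tf.constant(np.random.rand(5, 8), dtype=tf.float32)  # node features (N x F)
a = tf.constant(np.eye(5), dtype=tf.float32)  # adjacency matrix (N x N)

gat = GraphAttention(F_=4, attn_heads=2, attn_heads_reduction='average')
out = gat([x, a])
print(out.shape)  # (5, 4); with 'concat' reduction it would be (5, 8), i.e. (N, K * F')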

Implement a fuse net#

The fuse model is used by the party holding the labels. It works as follows:

  1. Use the concatenated node embeddings to generate the final node embeddings.

  2. Feed the node embeddings into a softmax layer to predict the node classes.

[5]:

class ServerNet(tf.keras.layers.Layer):
    def __init__(
        self,
        in_channel: int,
        hidden_size: int,
        num_layer: int,
        num_class: int,
        dropout: float,
        **kwargs,
    ):
        super(ServerNet, self).__init__()
        self.num_class = num_class
        self.num_layer = num_layer
        self.hidden_size = hidden_size
        self.in_channel = in_channel
        self.dropout = dropout
        self.layers = []
        super(ServerNet, self).__init__(**kwargs)

    def build(self, input_shape):
        self.layers.append(
            tf.keras.layers.Dense(self.hidden_size, input_shape=(self.in_channel,))
        )
        for i in range(self.num_layer - 2):
            self.layers.append(
                tf.keras.layers.Dense(self.hidden_size, input_shape=(self.hidden_size,))
            )
        self.layers.append(
            tf.keras.layers.Dense(self.num_class, input_shape=(self.hidden_size,))
        )
        super(ServerNet, self).build(input_shape)

    def call(self, inputs):
        x = inputs
        x = Dropout(self.dropout)(x)
        for i in range(self.num_layer):
            x = Dropout(self.dropout)(x)
            x = self.layers[i](x)
        return K.softmax(x)

    def compute_output_shape(self, input_shape):
        output_shape = self.hidden_size, self.output_dim
        return output_shape

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'in_channel': self.in_channel,
                'hidden_size': self.hidden_size,
                'num_layer': self.num_layer,
                'num_class': self.num_class,
                'dropout': self.dropout,
            }
        )
        return config
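
As a quick shape check, ServerNet can be exercised on random data standing in for the concatenated embeddings (a minimal sketch; the sizes below are arbitrary):

[ ]:
import numpy as np
import tensorflow as tf

# Pretend two parties each produced a 256-dim embedding for 10 nodes.
fused = tf.constant(np.random.rand(10, 512), dtype=tf.float32)

net = ServerNet(in_channel=512, hidden_size=256, num_layer=3, num_class=7, dropout=0.0)
probs = net(fused)
print(probs.shape)  # (10, 7); each row is a softmax distribution over the 7 classes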

Build the base model#

The base model runs in each party and applies one graph convolution layer to produce that party's node embeddings.

The node embeddings of all parties are then transferred to the party holding the labels for further processing.

[6]:
from tensorflow.keras.models import Model

def create_base_model(input_shape, n_hidden, l2_reg, num_heads, dropout_rate, learning_rate):
    def base_model():
        feature_input = tf.keras.Input(shape=(input_shape[1],))
        graph_input = tf.keras.Input(shape=(input_shape[0],))
        regular = tf.keras.regularizers.l2(l2_reg)
        outputs = GraphAttention(
            F_=n_hidden,
            attn_heads=num_heads,
            attn_heads_reduction='average',  # {'concat', 'average'}
            dropout_rate=dropout_rate,
            activation='relu',
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            attn_kernel_initializer='glorot_uniform',
            kernel_regularizer=regular,
            bias_regularizer=None,
            attn_kernel_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            attn_kernel_constraint=None,
        )([feature_input, graph_input])
        # outputs = tf.keras.layers.Flatten()(outputs)
        model = Model(inputs=[feature_input, graph_input], outputs=outputs)
        model._name = "embed_model"
        # Compile model
        model.summary()
        metrics = ['acc']
        optimizer = tf.keras.optimizers.get(
            {
                'class_name': 'adam',
                'config': {'learning_rate': learning_rate},
            }
        )
        model.compile(
            loss='categorical_crossentropy',
            weighted_metrics=metrics,
            optimizer=optimizer,
        )
        return model

    return base_model
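
For reference, the builder can be instantiated locally to inspect the architecture (a minimal sketch with assumed shapes; on this cora split Alice's partition is roughly 2708 nodes by 716 features):

[ ]:
# Build one party's base model locally just to print its summary.
builder = create_base_model(
    input_shape=(2708, 716),  # (number of nodes, number of this party's features)
    n_hidden=256,
    l2_reg=0.1,
    num_heads=4,
    dropout_rate=0.0,
    learning_rate=1e-3,
)
local_model = builder()  # prints the Keras summary of the embedding model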

Build the fuse model#

The fuse model concatenates the node embeddings from all parties as its input. It runs only in the party holding the labels.

[7]:
from tensorflow import keras
from tensorflow.keras import layers


def create_fuse_model(hidden_units, hidden_size, n_classes, layer_num, learning_rate):
    def fuse_model():
        inputs = [keras.Input(shape=size) for size in hidden_units]
        x = layers.concatenate(inputs)
        input_shape = x.shape[-1]
        y_pred = ServerNet(
            in_channel=input_shape,
            hidden_size=hidden_size,
            num_layer=layer_num,
            num_class=n_classes,
            dropout=0.0,
        )(x)
        # Create the model.
        model = keras.Model(inputs=inputs, outputs=y_pred, name="fuse_model")
        model.summary()
        metrics = ['acc']
        optimizer = tf.keras.optimizers.get(
            {'class_name': 'adam', 'config': {'learning_rate': learning_rate},}
        )
        model.compile(
            loss='categorical_crossentropy',
            weighted_metrics=metrics,
            optimizer=optimizer,
        )
        return model

    return fuse_model

Train GNN model with split learning#

Let us build a split learning model for training.

Alice, who has the labels, holds a base model and a fuse model, while Bob holds only a base model.

The whole model structure is as follows:

split_learning_gnn_model.png

[8]:
from secretflow.ml.nn import SLModel

hidden_size = 256
n_classes = 7
attn_heads = 2
layer_num = 3
learning_rate = 1e-3
dropout_rate = 0.0
l2_reg = 0.1
num_heads = 4
epochs = 10
optimizer = 'adam'

partition_shapes = features.partition_shape()

input_shape_alice = partition_shapes[alice]
input_shape_bob = partition_shapes[bob]

sl_model = SLModel(
    base_model_dict={
        alice: create_base_model(
            input_shape_alice,
            hidden_size,
            l2_reg,
            num_heads,
            dropout_rate,
            learning_rate,
        ),
        bob: create_base_model(
            input_shape_bob,
            hidden_size,
            l2_reg,
            num_heads,
            dropout_rate,
            learning_rate,
        ),
    },
    device_y=alice,
    model_fuse=create_fuse_model(
        [hidden_size, hidden_size], hidden_size, n_classes, layer_num, learning_rate
    ),
)


Fit the model.

[9]:

sl_model.fit(
    x=[features, edge],
    y=Y_train,
    epochs=epochs,
    batch_size=input_shape_alice[0],
    sample_weight=idx_train,
    validation_data=([features, edge], Y_val, idx_val),
)
100%|██████████| 1/1 [00:05<00:00,  5.67s/it, epoch: 0/10 -  train_loss:0.10079389065504074  train_acc:0.12857143580913544  val_loss:0.35441863536834717  val_acc:0.3880000114440918 ]
100%|██████████| 1/1 [00:01<00:00,  1.02s/it, epoch: 1/10 -  train_loss:0.2256714105606079  train_acc:0.44843751192092896  val_loss:0.3481321930885315  val_acc:0.5640000104904175 ]
100%|██████████| 1/1 [00:01<00:00,  1.01s/it, epoch: 2/10 -  train_loss:0.22043751180171967  train_acc:0.637499988079071  val_loss:0.34046509861946106  val_acc:0.6320000290870667 ]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, epoch: 3/10 -  train_loss:0.21422825753688812  train_acc:0.703125  val_loss:0.33100318908691406  val_acc:0.6539999842643738 ]
100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 4/10 -  train_loss:0.20669467747211456  train_acc:0.7250000238418579  val_loss:0.3193798065185547  val_acc:0.6840000152587891 ]
100%|██████████| 1/1 [00:01<00:00,  1.10s/it, epoch: 5/10 -  train_loss:0.19755281507968903  train_acc:0.7484375238418579  val_loss:0.30531173944473267  val_acc:0.7080000042915344 ]
100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 6/10 -  train_loss:0.1866169422864914  train_acc:0.7671874761581421  val_loss:0.28866109251976013  val_acc:0.7279999852180481 ]
100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 7/10 -  train_loss:0.17386208474636078  train_acc:0.7828124761581421  val_loss:0.26949769258499146  val_acc:0.7319999933242798 ]
100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 8/10 -  train_loss:0.1594790667295456  train_acc:0.785937488079071  val_loss:0.24824994802474976  val_acc:0.7319999933242798 ]
100%|██████████| 1/1 [00:01<00:00,  1.06s/it, epoch: 9/10 -  train_loss:0.1439441591501236  train_acc:0.785937488079071  val_loss:0.2257150411605835  val_acc:0.7400000095367432 ]
[9]:
{'train_loss': [0.10079389,
  0.22567141,
  0.22043751,
  0.21422826,
  0.20669468,
  0.19755282,
  0.18661694,
  0.17386208,
  0.15947907,
  0.14394416],
 'train_acc': [0.12857144,
  0.4484375,
  0.6375,
  0.703125,
  0.725,
  0.7484375,
  0.7671875,
  0.7828125,
  0.7859375,
  0.7859375],
 'val_loss': [0.35441864,
  0.3481322,
  0.3404651,
  0.3310032,
  0.3193798,
  0.30531174,
  0.2886611,
  0.2694977,
  0.24824995,
  0.22571504],
 'val_acc': [0.388,
  0.564,
  0.632,
  0.654,
  0.684,
  0.708,
  0.728,
  0.732,
  0.732,
  0.74]}

Examine the GNN model predictions

[10]:
sl_model.evaluate(
    x=[features, edge],
    y=Y_test,
    batch_size=input_shape_alice[0],
    sample_weight=idx_test,
)
Evaluate Processing:: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s, loss:0.4411766827106476 acc:0.7720000147819519]
[10]:
{'loss': 0.44117668, 'acc': 0.772}

Conclusion#

In this tutorial, we demonstrated how to train a graph neural network with split learning. This is a very basic implementation, and there is more work we will explore in the future:

  • SGD on large graphs: in the example above, each training step prepares, aggregates, and updates over the whole graph, which is extremely computation-intensive. We should perform stochastic minibatch training to reduce computation and memory consumption.

  • Partially aligned graphs: in the example above, all parties must have the same node set, which may not hold in real cases. We want to explore the case where the parties share only a common subset of nodes.