Customize Spectral Modules

Add New Filter

New spectral filters can be added to pyg_spectral.nn.conv in three steps. Once defined, a filter can be used with the package's range of model architectures, analysis utilities, and training schemes.

Step 1: Define propagation matrix

The base class nn.conv.BaseMP provides essential methods for building spectral filters. We can define a new filter class nn.conv.SkipConv by inheriting from it:

from torch import Tensor
from pyg_spectral.nn.conv.base_mp import BaseMP

class SkipConv(BaseMP):
    def __init__(self, num_hops, hop, cached, **kwargs):
        # Propagate with the normalized adjacency matrix minus the identity.
        kwargs['propagate_mat'] = 'A-I'
        super().__init__(num_hops, hop, cached, **kwargs)

The propagation matrix is specified as a string through the propagate_mat argument. Each matrix is either the normalized adjacency matrix (A) or the normalized Laplacian matrix (L), optionally combined with a scaled identity matrix (I), where the scaling factor is either a number or the name of a class attribute. Multiple propagation matrices are separated by commas. Valid examples: 'A', 'L-2*I', and 'L,A+I,L-alpha*I'.
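To make the string format concrete, here is a minimal sketch of how such a specification could be split into terms. The function parse_propagate_mat and its attrs argument are hypothetical names written for this illustration; the library's own parsing may work differently.

```python
import re

# Hypothetical helper, NOT part of pyg_spectral: it only illustrates the
# string format described above.
_TERM = re.compile(r'^([AL])(?:([+-])(?:([\w.]+)\*)?I)?$')

def parse_propagate_mat(spec, attrs=None):
    """Split a spec such as 'L,A+I,L-alpha*I' into (base, coeff) pairs.

    base is 'A' or 'L'; coeff is the signed factor applied to the identity.
    attrs stands in for class attributes and maps names to numbers.
    """
    attrs = attrs or {}
    terms = []
    for term in spec.split(','):
        match = _TERM.match(term.strip())
        if match is None:
            raise ValueError(f'Invalid propagation matrix: {term!r}')
        base, sign, scale = match.groups()
        if sign is None:
            coeff = 0.0                      # plain 'A' or 'L'
        else:
            if scale is None:
                value = 1.0                  # 'A+I' means a factor of 1
            else:
                try:
                    value = float(scale)     # numeric factor, e.g. '2'
                except ValueError:
                    value = float(attrs[scale])  # attribute name, e.g. 'alpha'
            coeff = value if sign == '+' else -value
        terms.append((base, coeff))
    return terms

single = parse_propagate_mat('L-2*I')
multiple = parse_propagate_mat('L,A+I,L-alpha*I', attrs={'alpha': 0.5})
```

Under these assumptions, the SkipConv specification 'A-I' is a single term with base A and an identity coefficient of -1.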

Step 2: Prepare representation matrix

Similar to PyG modules, our spectral filter class takes the node features x and the edge index edge_index as input. The _get_convolute_mat() method prepares the representation matrices used in the recurrent computation and returns them as a dictionary:

def _get_convolute_mat(self, x, edge_index):
    return {'x': x, 'x_1': x}

The above example overrides the method for SkipConv, returning the input feature x and a placeholder x_1 that holds the representation from the previous hop.

Step 3: Derive recurrent forward

The _forward() method implements the recurrent computation of the filter. Both its input and its output are a dictionary that combines the propagation matrices defined by propagate_mat with the representation matrices prepared by _get_convolute_mat().

def _forward(self, x, x_1, prop):
    if self.hop == 0:
        # No propagation for k=0
        return {'x': x, 'x_1': x, 'prop': prop}

    h = self.propagate(prop, x=x)
    h = h + x_1
    return {'x': h, 'x_1': x, 'prop': prop}

Similar to PyG modules, the propagate() method performs graph propagation with the given matrix. The above example corresponds to graph propagation with a skip connection to the representation from two hops earlier: \(H^{(k)} = (A-I)H^{(k-1)} + H^{(k-2)}\).
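The recurrence can be checked by hand on a tiny graph. The following is a plain-Python sketch, not the library's code: skip_forward mirrors the logic of _forward() above, but takes a dense matrix in place of propagate(), and the loop that feeds each hop's output dictionary into the next hop is an assumption about how hops are chained. The two-node graph and the helper names are made up for illustration.

```python
def matmul(P, H):
    """Dense matrix product, standing in for graph propagation."""
    return [[sum(P[i][k] * H[k][j] for k in range(len(H)))
             for j in range(len(H[0]))] for i in range(len(P))]

def add(U, V):
    """Element-wise sum of two equally shaped matrices."""
    return [[u + v for u, v in zip(ru, rv)] for ru, rv in zip(U, V)]

def skip_forward(hop, x, x_1, prop):
    """Illustrative copy of the SkipConv recurrence for a single hop."""
    if hop == 0:
        # No propagation for k=0
        return {'x': x, 'x_1': x, 'prop': prop}
    h = add(matmul(prop, x), x_1)
    return {'x': h, 'x_1': x, 'prop': prop}

# Two nodes joined by one edge: both degrees are 1, so the normalized
# adjacency equals the plain adjacency (no self-loops assumed here).
A = [[0.0, 1.0], [1.0, 0.0]]
prop = [[A[i][j] - (1.0 if i == j else 0.0) for j in range(2)]
        for i in range(2)]                    # A - I
X = [[1.0], [2.0]]                            # one feature per node

state = {'x': X, 'x_1': X, 'prop': prop}
history = []
for hop in range(3):
    state = skip_forward(hop, **state)
    history.append(state['x'])
```

Because hop 0 returns x for both entries, the first propagated hop reduces to \(H^{(1)} = (A-I)X + X = AX\), and the second gives \(H^{(2)} = (A-I)AX + X\); the entries of history follow exactly this pattern.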

Build the model!

Now the SkipConv filter is properly defined. The following snippet uses the nn.models.DecoupledVar model, composed of 10 hops of SkipConv filters, which can be used like any normal PyTorch model:

from pyg_spectral.nn.models import DecoupledVar

# x and edge_index are assumed to come from an existing PyG graph.
model = DecoupledVar(
    conv='SkipConv', num_hops=10,
    in_channels=x.size(1), hidden_channels=x.size(1), out_channels=x.size(1))
out = model(x, edge_index)