Customize Spectral Modules
Add New Filter
New spectral filters can be added to pyg_spectral.nn.conv in just three steps. Once defined, a filter works with the package's range of model architectures, analysis utilities, and training schemes.
Step 1: Define propagation matrix
The base class nn.conv.BaseMP provides essential methods for building spectral filters. We can define a new filter class nn.conv.SkipConv by inheriting from it:
from torch import Tensor
from pyg_spectral.nn.conv.base_mp import BaseMP

class SkipConv(BaseMP):
    def __init__(self, num_hops, hop, cached, **kwargs):
        kwargs['propagate_mat'] = 'A-I'
        super(SkipConv, self).__init__(num_hops, hop, cached, **kwargs)
The propagation matrix is specified as a string through the propagate_mat argument. Each matrix is either the normalized adjacency matrix (A) or the normalized Laplacian matrix (L), optionally plus or minus a scaled identity matrix (I), where the scaling factor is either a number or the name of a class attribute. Multiple propagation matrices are separated by commas (,). Valid examples: A, L-2*I, L,A+I,L-alpha*I.
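The string grammar above can be made concrete with a small dense sketch. parse_propagate_mat and dense_propagation are hypothetical helpers written only for this illustration, not pyg_spectral functions. The sketch assumes the symmetric normalization A = D^(-1/2) Â D^(-1/2) (Â the raw adjacency, no self-loops) and L = I - A, which may differ from the normalization the library actually applies.

```python
import re
from typing import Dict, List, Optional, Tuple

import numpy as np

# One term: base matrix A or L, optionally followed by +/- [factor*]I.
_TERM = re.compile(r'([AL])(?:([+-])(?:([\w.]+)\*)?I)?')


def parse_propagate_mat(spec: str,
                        attrs: Optional[Dict[str, float]] = None) -> List[Tuple[str, float]]:
    """Hypothetical parser: 'L-2*I' -> [('L', -2.0)], meaning L - 2*I."""
    attrs = attrs or {}
    terms = []
    for term in spec.replace(' ', '').split(','):
        m = _TERM.fullmatch(term)
        if m is None:
            raise ValueError(f'invalid propagation matrix: {term!r}')
        base, sign, factor = m.groups()
        if sign is None:
            shift = 0.0  # plain A or L, no identity term
        else:
            if factor is None:
                value = 1.0  # bare I, e.g. 'A+I'
            else:
                try:
                    value = float(factor)  # numeric factor, e.g. '2'
                except ValueError:
                    value = float(attrs[factor])  # attribute name, e.g. 'alpha'
            shift = value if sign == '+' else -value
        terms.append((base, shift))
    return terms


def dense_propagation(adj: np.ndarray, spec: str,
                      attrs: Optional[Dict[str, float]] = None) -> List[np.ndarray]:
    """Hypothetical dense reference: one matrix per comma-separated term."""
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros(n)
    nonzero = deg > 0
    inv_sqrt[nonzero] = deg[nonzero] ** -0.5  # isolated nodes keep 0
    a_norm = inv_sqrt[:, None] * adj * inv_sqrt[None, :]  # assumed D^-1/2 A D^-1/2
    eye = np.eye(n)
    bases = {'A': a_norm, 'L': eye - a_norm}
    return [bases[base] + shift * eye
            for base, shift in parse_propagate_mat(spec, attrs)]
```

For instance, dense_propagation(adj, 'L,A+I') returns the two matrices L and A + I, and an attribute-based factor such as alpha is looked up in attrs, mirroring how the string refers to a class attribute.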
Step 2: Prepare representation matrix
Similar to PyG modules, our spectral filter class takes the node feature matrix x and the edge index edge_index as input. The _get_convolute_mat() method prepares the representation matrices used in the recurrent computation and returns them as a dictionary:
def _get_convolute_mat(self, x, edge_index):
    return {'x': x, 'x_1': x}
The above example overrides the method for SkipConv, returning the input feature x together with a placeholder x_1 for the representation from the previous hop.
Step 3: Derive recurrent forward
The _forward() method implements the recurrent computation of the filter. Both its inputs and its returned dictionary combine the propagation matrices defined by propagate_mat (here, prop) and the representation matrices prepared by _get_convolute_mat().
def _forward(self, x, x_1, prop):
    if self.hop == 0:
        # No propagation for k=0
        return {'x': x, 'x_1': x, 'prop': prop}

    h = self.propagate(prop, x=x)
    h = h + x_1
    return {'x': h, 'x_1': x, 'prop': prop}
As in PyG modules, the propagate() method performs graph propagation with the given matrix. The above example therefore corresponds to propagation with a skip connection to the representation from two hops earlier: \(H^{(k)} = (A-I)H^{(k-1)} + H^{(k-2)}\).
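Unrolling the recurrence for the first few hops shows what SkipConv computes. Here X is the input x, hop 0 performs no propagation, and the hop-0 placeholder x_1 = x is read as \(H^{(-1)} := H^{(0)}\):

```latex
\begin{aligned}
H^{(0)} &= X, \qquad H^{(-1)} := H^{(0)} \\
H^{(1)} &= (A-I)H^{(0)} + H^{(0)} = A\,H^{(0)} \\
H^{(2)} &= (A-I)H^{(1)} + H^{(0)} = (A^{2} - A + I)\,H^{(0)} \\
H^{(3)} &= (A-I)H^{(2)} + H^{(1)} = (A^{3} - 2A^{2} + 3A - I)\,H^{(0)}
\end{aligned}
```

Each \(H^{(k)}\) is thus a degree-k polynomial in A applied to X, so SkipConv acts as a polynomial filter on the normalized adjacency matrix.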
Build the model!
Now the SkipConv filter is properly defined. The following snippet uses the nn.models.DecoupledVar model composed of 10 hops of SkipConv filters, which can be used as a normal PyTorch model:
from pyg_spectral.nn.models import DecoupledVar
# x: node features [num_nodes, num_features]; edge_index: connectivity [2, num_edges]
model = DecoupledVar(conv='SkipConv', num_hops=10, in_channels=x.size(1),
                     hidden_channels=x.size(1), out_channels=x.size(1))
out = model(x, edge_index)
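To check what a stack of SkipConv hops computes without installing pyg_spectral, the following is a dense NumPy sketch that mirrors the three steps above. DenseSkipConv and skip_hops are hypothetical names; propagate() is replaced by a dense matrix product; the normalization A = D^(-1/2) Â D^(-1/2) without self-loops, the hop numbering 0..num_hops, and the way prop is added to the state dictionary are all assumptions. It returns only the per-hop representations and does not reproduce how DecoupledVar combines them into out.

```python
import numpy as np


class DenseSkipConv:
    """Hypothetical dense stand-in for SkipConv; not part of pyg_spectral."""

    def __init__(self, hop: int):
        self.hop = hop

    def propagate(self, prop: np.ndarray, x: np.ndarray) -> np.ndarray:
        # Dense matrix product in place of PyG's sparse message passing.
        return prop @ x

    def _get_convolute_mat(self, x, edge_index):
        return {'x': x, 'x_1': x}

    def _forward(self, x, x_1, prop):
        if self.hop == 0:
            return {'x': x, 'x_1': x, 'prop': prop}  # no propagation at k=0
        h = self.propagate(prop, x=x) + x_1  # (A-I) H^(k-1) + H^(k-2)
        return {'x': h, 'x_1': x, 'prop': prop}


def skip_hops(x: np.ndarray, edge_index: np.ndarray, num_hops: int) -> list:
    """Return [H^(0), ..., H^(num_hops)] for a dense SkipConv stack."""
    n = x.shape[0]
    adj = np.zeros((n, n))
    adj[edge_index[0], edge_index[1]] = 1.0
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros(n)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    # Assumed symmetric normalization, then shift by -I to get A - I.
    prop = inv_sqrt[:, None] * adj * inv_sqrt[None, :] - np.eye(n)

    convs = [DenseSkipConv(hop) for hop in range(num_hops + 1)]
    state = convs[0]._get_convolute_mat(x, edge_index)
    state['prop'] = prop  # assumption: propagation matrix joins the state dict
    outputs = []
    for conv in convs:
        state = conv._forward(**state)  # each hop consumes the previous dict
        outputs.append(state['x'])
    return outputs
```

On a two-node graph with x = np.eye(2) and edge_index = np.array([[0, 1], [1, 0]]), A squares to I, so the second hop gives 2I - A, matching the unrolled polynomial A² - A + I.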