pyg_spectral.nn.norm

class TensorStandardScaler(dim: int = 0)

Bases: Module

Applies standard Gaussian normalization, i.e. rescales data toward zero mean and unit variance as in \(\mathcal{N}(0, 1)\).

Parameters:

dim (int, default: 0) – Dimension along which the mean and standard deviation are computed.
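For reference, standard-score normalization along dim is conventionally written as below, with \(\mu\) and \(\sigma\) the mean and standard deviation of \(x\) taken over dim. The notation is ours, not the library's, and the exact variance estimator used by the library is not stated here.

\[
\mu = \operatorname{mean}_{\text{dim}}(x), \qquad
\sigma = \operatorname{std}_{\text{dim}}(x), \qquad
z = \frac{x - \mu}{\sigma}
\]

Without centering (forward with with_mean=False), the transform reduces to \(z = x / \sigma\).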

fit(x: Tensor) → Tuple[Tensor, Tensor]

Compute the mean and std to be used for later scaling.

Parameters:

x (Tensor) – Data used to compute the mean and standard deviation.

Returns:

var_mean (Tuple[torch.Tensor, torch.Tensor]) – Tuple of mean and std.
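As a rough illustration of what fit computes, the sketch below derives the statistics with torch.var_mean, which returns a (variance, mean) pair. The helper name fit_stats, the use of PyTorch's default (unbiased) variance estimator, and keepdim=True are assumptions for illustration, not taken from pyg_spectral's source.

```python
import torch
from typing import Tuple


def fit_stats(x: torch.Tensor, dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: mean and std of ``x`` along ``dim``.

    keepdim=True keeps the reduced dimension (as size 1) so the statistics
    broadcast directly against ``x`` when scaling later.
    """
    # torch.var_mean returns (variance, mean); its default estimator is the
    # unbiased one (Bessel's correction), an assumption for this sketch.
    var, mean = torch.var_mean(x, dim=dim, keepdim=True)
    std = var.sqrt()
    return mean, std


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
mean, std = fit_stats(x, dim=0)  # per-column statistics, shape (1, 2)
```

With dim=0 each column is summarized separately, so the returned tensors have one entry per feature.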

forward(x: Tensor, with_mean: bool = False) → Tensor

Forward pass: scales x by the standard deviation computed in fit(), optionally centering it by the mean first.

Parameters:
  • x (Tensor) – The input tensor to normalize.

  • with_mean (bool, default: False) – Whether to center the data before scaling.
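A minimal sketch of the fit/forward pair is shown below. StandardScalerSketch is a hypothetical stand-in written for illustration, not the library's implementation; in particular, the out-of-place arithmetic, the unbiased variance estimator, and the error raised before fitting are assumptions.

```python
import torch
from torch import nn


class StandardScalerSketch(nn.Module):
    """Hypothetical stand-in for TensorStandardScaler (not the library code)."""

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim
        self.mean = None
        self.std = None

    def fit(self, x: torch.Tensor):
        # Statistics along `dim`, kept as size-1 dims so they broadcast.
        var, self.mean = torch.var_mean(x, dim=self.dim, keepdim=True)
        self.std = var.sqrt()
        return self.mean, self.std

    def forward(self, x: torch.Tensor, with_mean: bool = False) -> torch.Tensor:
        if self.mean is None:
            raise RuntimeError("call fit() before forward()")
        if with_mean:
            x = x - self.mean  # center first (out-of-place, input untouched)
        return x / self.std  # then scale to unit standard deviation


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
scaler = StandardScalerSketch(dim=0)
scaler.fit(x)
z = scaler(x, with_mean=True)  # zero mean, unit std per column
s = scaler(x)                  # default: scaling only, no centering
```

Scaling without centering (the with_mean=False default) keeps zero entries at zero, which can matter for sparse features. Note that this sketch does not guard against a zero standard deviation, which would produce inf or nan values.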