benchmark.trainer

class SingleGraphLoader(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: object

Loader for torch_geometric.data.Data object for one graph.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.seed (int): Random seed.

    • args.data (str): Dataset name.

    • args.data_split (str): Index of dataset split.

  • res_logger (ResLogger | None, default: None) – Logger for results.

args_out = ['in_channels', 'out_channels', 'multi', 'metric']
param = {'normg': ('float', (0.0, 1.0), {'step': 0.05}, <function SingleGraphLoader.<lambda>>)}
static available_datasets() -> dict
get(args: Namespace) -> Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.
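The "Updates" convention above can be illustrated with a small stand-in. This is only a sketch: `ToyLoader`, its in-memory dataset, and the metric name `'f1micro'` are hypothetical and are not part of the benchmark package. It shows a `get` call that returns data and writes the fields named in `args_out` back onto the shared namespace.

```python
from argparse import Namespace


class ToyLoader:
    """Hypothetical stand-in; not the real SingleGraphLoader."""
    args_out = ['in_channels', 'out_channels', 'multi', 'metric']

    def get(self, args: Namespace) -> dict:
        # Pretend dataset: 3 nodes with 4 features each and 2 classes.
        data = {'x': [[0.0] * 4 for _ in range(3)], 'y': [0, 1, 1]}
        # Write the derived fields back onto the shared namespace.
        args.in_channels = len(data['x'][0])
        args.out_channels = len(set(data['y']))
        args.multi = False
        args.metric = 'f1micro'  # hypothetical metric name
        return data


args = Namespace(seed=0, data='toy', data_split='0', normg=0.5)
loader = ToyLoader()
data = loader.get(args)
# Downstream loaders can now read the updated fields from args.
updated = {key: getattr(args, key) for key in loader.args_out}
```

Keeping the output fields in `args_out` lets later stages, such as a model loader, read `in_channels` and `out_channels` from the same namespace without touching the data object.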

class SingleGraphLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: SingleGraphLoader

Reuse necessary data for multiple runs.

get(args: Namespace) -> Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

update(args: Namespace, data: Data) -> Data

Update data split for the next trial.
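A minimal sketch of the reuse pattern, assuming the expensive load happens once and only the split is redrawn per trial. `ToyTrialLoader`, its cache, and the 60/20/20 split are hypothetical; the real class may organize this differently.

```python
import random
from argparse import Namespace


class ToyTrialLoader:
    """Hypothetical stand-in for a *_Trial loader."""

    def __init__(self):
        self.load_count = 0
        self._cache = None

    def get(self, args: Namespace) -> dict:
        # Load the (pretend) graph only on the first call.
        if self._cache is None:
            self.load_count += 1
            self._cache = {'num_nodes': 10}
        return self.update(args, dict(self._cache))

    def update(self, args: Namespace, data: dict) -> dict:
        # Redraw the split from the seed; the cached graph is untouched.
        rng = random.Random(args.seed)
        idx = list(range(data['num_nodes']))
        rng.shuffle(idx)
        data['train'], data['val'], data['test'] = idx[:6], idx[6:8], idx[8:]
        return data


loader = ToyTrialLoader()
first = loader.get(Namespace(seed=0))
second = loader.get(Namespace(seed=1))
```

Both calls share one underlying load, so repeated trials pay only for the re-split.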

class ModelLoader(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: object

Loader for torch.nn.Module object.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.model (str): Model architecture name.

    • args.conv (str): Convolution layer name.

  • res_logger (ResLogger | None, default: None) – Logger for results.

args_out = ['criterion']
param = {}
static available_models() -> list[str]
static available_convs() -> list[str]
static get_name(args: Namespace) -> tuple[str]

Get the model and conv names used in the logging path from argparse input, without instantiating the model. Wrapper for pyg_spectral.nn.get_nn_name().

Parameters:

args (Namespace) –

Configuration arguments.

  • args.model (str): Model architecture name.

  • args.conv (str): Convolution layer name.

  • other args specified in module name function.

Returns:

nn_name (tuple[str]) – Name strings (model_name, conv_name).
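As an illustration of how a (model_name, conv_name) pair could feed a logging path, here is a sketch. The helper `toy_log_path`, the directory layout, and the names `'model_x'` and `'conv_y'` are all hypothetical; the layout actually used by the benchmark is not stated in this reference.

```python
from pathlib import PurePosixPath


def toy_log_path(root: str, data: str, nn_name: tuple, seed: int) -> PurePosixPath:
    """Join name strings into a directory path (hypothetical layout)."""
    model_name, conv_name = nn_name
    return PurePosixPath(root) / data / model_name / conv_name / f'seed-{seed}'


# Placeholder names; real values would come from ModelLoader.get_name(args).
path = toy_log_path('log', 'toy', ('model_x', 'conv_y'), 42)
```

Because the names are plain strings, a path like this can be computed before any model is built, for example to skip runs whose log directory already exists.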

static get_trn(args: Namespace) -> TrnBase

Get trainer class from model name.

Parameters:

args (Namespace) – Configuration arguments.

Returns:

trn (TrnBase) – Trainer class.

_resolve_import(args: Namespace) -> tuple[str, str, dict]
get(args: Namespace) -> tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.hidden_channels (int) – Number of hidden units.

  • args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Loss function name.

class ModelLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: ModelLoader

Reuse necessary data for multiple runs.

get(args: Namespace) -> tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.hidden_channels (int) – Number of hidden units.

  • args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Loss function name.

update(args: Namespace, model: Module) -> Module
class TrnBase_Trial(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: TrnBase

Trainer supporting optuna.pruners in training.

split_hyperval(data: Data) -> Data
clear()
train_val(split_train: list[str] = ['train'], split_val: list[str] = ['val']) -> ResLogger
class TrnFullbatch(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnBase

Fullbatch trainer class for node classification.
  • Model forward input: separate edge index and node features.

  • Run pipeline: train_val -> test.

Parameters:
  • model (Module) – Passed through to TrnBase.

  • data (Dataset) – Passed through to TrnBase.

  • res_logger – Passed through to TrnBase.

  • args (Namespace) –

    Passed through to TrnBase.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • in_channels (int): Number of data input features.

    • out_channels (int): Number of data output classes.

name: str = 'fb'
clear()
_fetch_data() -> tuple[Data, dict]

Process the single graph data.

_fetch_input() -> tuple

Process each sample of model input and label.

_learn_split(split: list = ['train']) -> ResLogger

Actual train iteration on the given splits.

_eval_split(split: list = ['test']) -> ResLogger

Actual test on the given splits.

test_deg() -> ResLogger

Separate high/low degree subsets and evaluate.
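The idea of splitting nodes by degree can be sketched in plain Python. The mean-degree threshold and the helper `split_by_degree` are assumptions for illustration; the cut actually used by `test_deg` is not documented here.

```python
from collections import Counter


def split_by_degree(edges, num_nodes):
    """Return (high, low) node lists, split at the mean degree (assumed threshold)."""
    deg = Counter()
    for u, v in edges:  # undirected edge: count both endpoints
        deg[u] += 1
        deg[v] += 1
    degrees = [deg[i] for i in range(num_nodes)]
    mean_deg = sum(degrees) / num_nodes
    high = [i for i, d in enumerate(degrees) if d > mean_deg]
    low = [i for i, d in enumerate(degrees) if d <= mean_deg]
    return high, low


edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
high, low = split_by_degree(edges, num_nodes=5)
```

Evaluating the metric separately on the two node sets then shows whether a model behaves differently on well-connected and sparsely connected nodes.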

run() -> ResLogger
class TrnFullbatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnFullbatch, TrnBase_Trial

Trainer supporting optuna.pruners in training.

run() -> ResLogger
update(*args, **kwargs)
class TrnMinibatch(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnBase

Minibatch trainer class for node classification.
  • Model forward input: node embeddings.

  • Run pipeline: propagate -> train_val -> test.

Parameters:
  • args.batch (int) – Batch size.

  • args.normf (int) – Embedding normalization.

  • model (Module) – Passed through to TrnBase.

  • data (Dataset) – Passed through to TrnBase.

  • res_logger – Passed through to TrnBase.

  • args (Namespace) –

    Passed through to TrnBase.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • in_channels (int): Number of data input features.

    • out_channels (int): Number of data output classes.

name: str = 'mb'
clear()
_fetch_data() -> tuple

Process the single graph data.

_fetch_preprocess(embed: Tensor, label: Tensor, mask: dict) -> dict

Call model preprocess for precomputation.

_fetch_input(split: str) -> Generator

Process each sample of model input and label for training.
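A generator-style input fetcher can be sketched as follows. The helper `iter_batches` and the use of plain lists in place of tensors are assumptions made to keep the example self-contained; only the batch size corresponds to a documented argument (`args.batch`).

```python
def iter_batches(embed, label, index, batch):
    """Yield (inputs, labels) for consecutive chunks of `index` of size `batch`."""
    for start in range(0, len(index), batch):
        idx = index[start:start + batch]
        yield [embed[i] for i in idx], [label[i] for i in idx]


embed = [[0.1], [0.2], [0.3], [0.4], [0.5]]  # precomputed node embeddings
label = [0, 1, 0, 1, 1]
train_index = [4, 0, 2, 1, 3]                # node ids in the 'train' split
batches = list(iter_batches(embed, label, train_index, batch=2))
```

Because the embeddings are precomputed, each batch only needs row lookups, which is what allows minibatch training without touching the graph structure per step.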

_learn_split(split: list = ['train']) -> ResLogger

Actual train iteration on the given splits.

_eval_split(split: list = ['test']) -> ResLogger

Actual test on the given splits.

preprocess(*args, **kwargs)
run() -> ResLogger
class TrnMinibatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnMinibatch, TrnBase_Trial

Trainer supporting optuna.pruners in training. Precomputation is called lazily.

preprocess(*args, **kwargs)
run() -> ResLogger
update(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

benchmark.trainer.base

Author: nyLiao
File Created: 2024-03-03

class TrnBase(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: object

Base trainer class for general pipelines and tasks.

Parameters:
  • model (Module) – Pytorch model to be trained.

  • data (Data) – PyG style data.

  • res_logger (ResLogger | None, default: None) – Logger for results.

  • args (Namespace) –

    Configuration arguments.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • criterion (str): Name of the loss function in torch.nn.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • in_channels (int): Number of data input features.

    • out_channels (int): Number of data output classes.
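The interplay of `epoch`, `patience`, and `period` can be sketched with a toy loop. The exact semantics are assumptions: here `patience` counts epochs since the best validation score, and `period` saves a checkpoint every that many epochs. The function `train_with_patience` is hypothetical.

```python
def train_with_patience(val_scores, epoch, patience, period):
    """Return (best_epoch, last_epoch, checkpoint_epochs) for a toy training run."""
    best, best_epoch, saved = float('-inf'), -1, []
    last_epoch = -1
    for ep in range(epoch):
        last_epoch = ep
        score = val_scores[ep]           # stands in for one train + validate step
        if score > best:
            best, best_epoch = score, ep
        if period > 0 and (ep + 1) % period == 0:
            saved.append(ep)             # periodic checkpoint
        if ep - best_epoch >= patience:  # no improvement for `patience` epochs
            break
    return best_epoch, last_epoch, saved


scores = [0.5, 0.6, 0.7, 0.65, 0.66, 0.64, 0.9]
result = train_with_patience(scores, epoch=7, patience=3, period=2)
```

In this run the best score is at epoch 2 and training stops at epoch 5, so the late value 0.9 is never reached. That is the usual trade-off of a small patience value.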

setup_optimizer()

Set up the optimizer and scheduler.
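The separate `lr_[lin/conv]` and `wd_[lin/conv]` arguments suggest per-group optimizer settings. The sketch below builds a list of parameter-group dicts in the shape that torch.optim optimizers accept. The name-based grouping rule, the helper `build_param_groups`, and the string placeholders used instead of tensors are all assumptions.

```python
from argparse import Namespace


def build_param_groups(named_params, args):
    """Split parameters into 'lin' and 'conv' groups by name (hypothetical rule)."""
    buckets = {'lin': [], 'conv': []}
    for name, param in named_params:
        key = 'conv' if 'conv' in name else 'lin'
        buckets[key].append(param)
    return [
        {
            'params': buckets[key],
            'lr': getattr(args, f'lr_{key}'),
            'weight_decay': getattr(args, f'wd_{key}'),
        }
        for key in ('lin', 'conv')
    ]


args = Namespace(lr_lin=1e-2, lr_conv=1e-3, wd_lin=5e-4, wd_conv=0.0)
# Strings stand in for tensors so the example runs without torch.
named_params = [
    ('in_mlp.weight', 'W_in'),
    ('convs.0.theta', 'theta0'),
    ('out_mlp.weight', 'W_out'),
]
groups = build_param_groups(named_params, args)
```

With real tensors, a list like `groups` could be passed as the first argument of an optimizer such as `torch.optim.Adam`.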

clear()

Clear self cache.

run()

Run the training process.

name: str
param = {'lr_conv': ('float', (1e-05, 0.5), {'log': True}, <function TrnBase.<lambda>>), 'lr_lin': ('float', (1e-05, 0.5), {'log': True}, <function TrnBase.<lambda>>), 'wd_conv': ('float', (1e-07, 0.001), {'log': True}, <function TrnBase.<lambda>>), 'wd_lin': ('float', (1e-07, 0.001), {'log': True}, <function TrnBase.<lambda>>)}
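One plausible reading of each `param` entry is (suggest type, bounds, keyword arguments, formatter), matching the shape of Optuna's `suggest_float(name, low, high, *, step=None, log=False)`. That reading is an assumption, and the sketch below uses a recording stand-in (`FakeTrial`) rather than a real Optuna trial. The formatter lambdas are also hypothetical, since the reference shows them only as `<lambda>`.

```python
class FakeTrial:
    """Records calls; stands in for optuna.trial.Trial."""

    def __init__(self):
        self.calls = []

    def suggest_float(self, name, low, high, *, step=None, log=False):
        self.calls.append((name, low, high, step, log))
        return low  # a real trial would sample from [low, high]


def suggest_all(trial, param):
    """Draw one value per entry and pass it through the entry's formatter."""
    out = {}
    for name, (kind, bounds, kwargs, fmt) in param.items():
        suggest = getattr(trial, f'suggest_{kind}')
        out[name] = fmt(suggest(name, *bounds, **kwargs))
    return out


param = {
    'lr_lin': ('float', (1e-5, 0.5), {'log': True}, lambda x: float(f'{x:.3e}')),
    'normg': ('float', (0.0, 1.0), {'step': 0.05}, lambda x: round(x, 2)),
}
trial = FakeTrial()
values = suggest_all(trial, param)
```

Declaring the search space as data on each class lets a tuning script collect the entries from the loader and trainer classes without hard-coding their hyperparameters.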
static _log_memory(split: str | None = None, row: int = 0)
train_val(*args, **kwargs)
test(*args, **kwargs)
_fetch_data()

Process the single graph data.

_fetch_input() -> tuple

Process each sample of model input and label.

_learn_split(split: list[str] = ['train']) -> ResLogger

Actual train iteration on the given splits.

_eval_split(split: list[str]) -> ResLogger

Actual test on the given splits.

run() -> ResLogger
class TrnBase_Trial(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: TrnBase

Trainer supporting optuna.pruners in training.

split_hyperval(data: Data) -> Data
clear()
train_val(split_train: list[str] = ['train'], split_val: list[str] = ['val']) -> ResLogger
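Pruner support in a training loop usually follows Optuna's report-then-check pattern: call `trial.report(value, step)`, then `trial.should_prune()`, and raise `optuna.TrialPruned` if it returns true. The sketch below mimics that pattern with stand-ins so it runs without Optuna. `ThresholdTrial`, its fixed-threshold rule, and the local `TrialPruned` class are hypothetical, and how `train_val` actually receives the trial is not shown in this reference.

```python
class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned."""


class ThresholdTrial:
    """Toy trial: prune when a reported value is below `threshold` after `warmup` steps."""

    def __init__(self, threshold, warmup):
        self.threshold, self.warmup = threshold, warmup
        self.history = []

    def report(self, value, step):
        self.history.append((step, value))

    def should_prune(self):
        step, value = self.history[-1]
        return step >= self.warmup and value < self.threshold


def train_val(trial, val_scores):
    """Report each epoch's validation score and stop early if the trial is pruned."""
    for epoch, score in enumerate(val_scores):
        trial.report(score, epoch)
        if trial.should_prune():
            raise TrialPruned(f'pruned at epoch {epoch}')
    return max(val_scores)


good = train_val(ThresholdTrial(0.3, 2), [0.2, 0.4, 0.6, 0.8])
try:
    train_val(ThresholdTrial(0.3, 2), [0.2, 0.25, 0.28, 0.8])
    pruned_at = None
except TrialPruned as exc:
    pruned_at = str(exc)
```

The second run is cut off at epoch 2, so its remaining epochs are never trained. That saving is the reason to report intermediate validation scores during a hyperparameter search.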

benchmark.trainer.regression

class TrnRegression(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnFullbatch

name: str = 'regression'
_fetch_input() -> tuple

Process each sample of model input and label.

class RegressionLoader(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: object

Loader for regression learning data.

_resolve_import(args: Namespace) -> tuple[str, str, dict]
get(args: Namespace) -> Data

Load data based on parameters.

benchmark.trainer.load_data

Author: nyLiao
File Created: 2024-02-26

class SingleGraphLoader(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: object

Loader for torch_geometric.data.Data object for one graph.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.seed (int): Random seed.

    • args.data (str): Dataset name.

    • args.data_split (str): Index of dataset split.

  • res_logger (ResLogger | None, default: None) – Logger for results.

args_out = ['in_channels', 'out_channels', 'multi', 'metric']
param = {'normg': ('float', (0.0, 1.0), {'step': 0.05}, <function SingleGraphLoader.<lambda>>)}
static available_datasets() -> dict
get(args: Namespace) -> Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

class SingleGraphLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: SingleGraphLoader

Reuse necessary data for multiple runs.

get(args: Namespace) -> Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

update(args: Namespace, data: Data) -> Data

Update data split for the next trial.

benchmark.trainer.load_metric

Author: nyLiao
File Created: 2024-03-03

class ResCollection(metrics: Metric | Sequence[Metric] | Dict[str, Metric], *additional_metrics: Metric, prefix: str | None = None, postfix: str | None = None, compute_groups: bool | List[List[str]] = True) -> None

Bases: MetricCollection

compute() -> list[tuple[str, Any, Callable]]

Wrap compute output to ResLogger style.
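The return type `list[tuple[str, Any, Callable]]` suggests rows of (name, value, formatter). Under that assumption, the wrapping step might look like the sketch below. The helper `to_res_rows`, the percentage formatter, and the metric names are hypothetical.

```python
def to_res_rows(computed: dict, prefix: str = '') -> list:
    """Turn a {name: value} mapping into (name, value, formatter) rows."""
    def as_percent(x):
        return format(x * 100, '.3f')
    return [(prefix + name, float(value), as_percent) for name, value in computed.items()]


# A plain dict stands in for the output of MetricCollection.compute().
rows = to_res_rows({'f1micro': 0.8125, 'f1macro': 0.75}, prefix='val_')
rendered = [f'{name}={fmt(value)}' for name, value, fmt in rows]
```

Carrying the formatter next to each value lets the logger decide when to render, while the raw number remains available for comparisons such as early stopping.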

metric_loader(args: Namespace) -> MetricCollection

Loader for torchmetrics.Metric object.

Parameters:

args (Namespace) –

Configuration arguments.

  • args.multi (bool): True for multi-label classification.

  • args.out_channels (int): Number of output classes/labels.

benchmark.trainer.load_model

Author: nyLiao
File Created: 2024-02-26

class ModelLoader(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: object

Loader for torch.nn.Module object.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.model (str): Model architecture name.

    • args.conv (str): Convolution layer name.

  • res_logger (ResLogger | None, default: None) – Logger for results.

args_out = ['criterion']
param = {}
static available_models() -> list[str]
static available_convs() -> list[str]
static get_name(args: Namespace) -> tuple[str]

Get the model and conv names used in the logging path from argparse input, without instantiating the model. Wrapper for pyg_spectral.nn.get_nn_name().

Parameters:

args (Namespace) –

Configuration arguments.

  • args.model (str): Model architecture name.

  • args.conv (str): Convolution layer name.

  • other args specified in module name function.

Returns:

nn_name (tuple[str]) – Name strings (model_name, conv_name).

static get_trn(args: Namespace) -> TrnBase

Get trainer class from model name.

Parameters:

args (Namespace) – Configuration arguments.

Returns:

trn (TrnBase) – Trainer class.

_resolve_import(args: Namespace) -> tuple[str, str, dict]
get(args: Namespace) -> tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.hidden_channels (int) – Number of hidden units.

  • args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Loss function name.

class ModelLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) -> None

Bases: ModelLoader

Reuse necessary data for multiple runs.

get(args: Namespace) -> tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.in_channels (int) – Number of input features.

  • args.out_channels (int) – Number of output classes.

  • args.hidden_channels (int) – Number of hidden units.

  • args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Loss function name.

update(args: Namespace, model: Module) -> Module