benchmark.trainer
- class SingleGraphLoader(args: Namespace, res_logger: ResLogger | None = None) → None [source]
  Bases: object
  Loader for a torch_geometric.data.Data object for one graph.
  - Parameters:
    - args (Namespace) – Configuration arguments.
    - res_logger (ResLogger | None, default: None) – Logger for results.
  - args_out = ['in_channels', 'out_channels', 'multi', 'metric']
  - param = {'normg': ('float', (0.0, 1.0), {'step': 0.05}, <function SingleGraphLoader.<lambda>>)}
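A minimal usage sketch follows. The loader's public loading method is not shown on this page, so the `get(args)` call and the `args.data` field below are assumptions; what is documented is that the loader writes the `args_out` fields back onto `args`.

```python
# Hedged sketch: loading one graph.  `get(args)` and `args.data` are assumptions
# modeled on ModelLoader.get(); only args_out and the `normg` search space are
# documented above.
from argparse import Namespace
from benchmark.trainer import SingleGraphLoader

args = Namespace(data="cora", normg=0.5)       # hypothetical dataset name and graph norm
loader = SingleGraphLoader(args)
data = loader.get(args)                        # assumed loading entry point
print(args.in_channels, args.out_channels, args.multi, args.metric)   # filled per args_out
```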
- class SingleGraphLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None [source]
  Bases: SingleGraphLoader
  Reuse necessary data for multiple runs.
- class ModelLoader(args: Namespace, res_logger: ResLogger | None = None) → None [source]
  Bases: object
  Loader for a torch.nn.Module object.
  - Parameters:
    - args (Namespace) – Configuration arguments.
    - res_logger (ResLogger | None, default: None) – Logger for results.
  - args_out = ['criterion']
  - param = {}
  - static get_name(args: Namespace) → tuple[str] [source]
    Get the model+conv name for the logging path from argparse input, without instantiation. Wrapper for pyg_spectral.nn.get_nn_name().
    - Parameters:
      - args (Namespace) – Configuration arguments.
        - args.model (str): Model architecture name.
        - args.conv (str): Convolution layer name.
        - Other args specified in the module's name function.
    - Returns:
      - nn_name (tuple[str]) – Name strings (model_name, conv_name).
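For illustration, a call might look like the sketch below; the model and conv strings are hypothetical placeholders, since the accepted values are defined elsewhere in pyg_spectral.

```python
# Hedged sketch: resolving the model+conv name without instantiating the model.
# The model/conv values are hypothetical examples.
from argparse import Namespace
from benchmark.trainer import ModelLoader

args = Namespace(model="DecoupledFixed", conv="AdjConv")
model_name, conv_name = ModelLoader.get_name(args)   # wraps pyg_spectral.nn.get_nn_name()
print(model_name, conv_name)
```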
  - static get_trn(args: Namespace) → TrnBase [source]
    Get the trainer class from the model name.
    - Parameters:
      - args (Namespace) – Configuration arguments.
    - Returns:
      - trn (TrnBase) – Trainer class.
  - get(args: Namespace) → tuple[Module, TrnBase] [source]
    Load the model with the specified arguments.
    - Parameters:
      - args.num_hops (int) – Number of conv hops.
      - args.in_layers (int) – Number of MLP layers before conv.
      - args.out_layers (int) – Number of MLP layers after conv.
      - args.in_channels (int) – Number of input features.
      - args.out_channels (int) – Number of output classes.
      - args.hidden_channels (int) – Number of hidden units.
      - args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.
    - Updates:
      - args.criterion (str) – Loss function name.
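Putting the loaders together, a single (non-tuning) run could look roughly like the sketch below. The field values are placeholders, the `SingleGraphLoader.get` call is the same assumption noted earlier, and treating the second return value of `get()` as a trainer class (matching `get_trn()`) is also an assumption.

```python
# Hedged end-to-end sketch: data -> model -> trainer -> run.
from argparse import Namespace
from benchmark.trainer import SingleGraphLoader, ModelLoader

args = Namespace(
    model="DecoupledFixed", conv="AdjConv",          # hypothetical architecture choices
    num_hops=10, in_layers=1, out_layers=1,          # fields documented for ModelLoader.get()
    hidden_channels=64, dropout_lin=0.5, dropout_conv=0.0,
    data="cora", normg=0.5,                          # assumed data-loader fields (see above)
)

data = SingleGraphLoader(args).get(args)             # assumed get() API; fills in/out_channels, multi, metric
model, trn_cls = ModelLoader(args).get(args)         # also sets args.criterion
trn = trn_cls(model=model, data=data, args=args)     # e.g. TrnFullbatch(model, data, args)
trn.run()                                            # trainer fields (device, epoch, lr_*, ...) must also be on args
```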
- class ModelLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None [source]
  Bases: ModelLoader
  Reuse necessary data for multiple runs.
  - get(args: Namespace) → tuple[Module, TrnBase] [source]
    Load the model with the specified arguments.
    - Parameters:
      - args.num_hops (int) – Number of conv hops.
      - args.in_layers (int) – Number of MLP layers before conv.
      - args.out_layers (int) – Number of MLP layers after conv.
      - args.in_channels (int) – Number of input features.
      - args.out_channels (int) – Number of output classes.
      - args.hidden_channels (int) – Number of hidden units.
      - args.dropout_[lin/conv] (float) – Dropout rate for linear/conv.
    - Updates:
      - args.criterion (str) – Loss function name.
- class TrnBase_Trial(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs) [source]
  Bases: TrnBase
  Trainer supporting optuna.pruners in training.
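The `*_Trial` classes exist for hyperparameter tuning loops. The sketch below shows how they might plug into an Optuna study with a pruner; how the `trial` object reaches the trainer is not documented here, so the `trial=trial` keyword (and the `get()` calls) are assumptions.

```python
# Hedged sketch of an Optuna-driven tuning loop around the *_Trial classes.
import optuna
from benchmark.trainer import SingleGraphLoader_Trial, ModelLoader_Trial

def make_objective(args):
    data_loader = SingleGraphLoader_Trial(args)      # Trial loaders reuse cached data across runs
    model_loader = ModelLoader_Trial(args)

    def objective(trial: optuna.Trial) -> float:
        args.lr_lin = trial.suggest_float("lr_lin", 1e-5, 0.5, log=True)   # ranges from TrnBase.param
        args.wd_lin = trial.suggest_float("wd_lin", 1e-7, 1e-3, log=True)
        data = data_loader.get(args)                 # assumed get() API
        model, trn_cls = model_loader.get(args)
        trn = trn_cls(model=model, data=data, args=args, trial=trial)   # trial kwarg is an assumption
        return trn.run()                             # assumed to return the monitored metric
    return objective

study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
# study.optimize(make_objective(args), n_trials=50)
```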
- class TrnFullbatch(model: Module, data: Dataset, args: Namespace, **kwargs) [source]
  Bases: TrnBase
  Fullbatch trainer class for node classification.
  Model forward input: separate edge index and node features.
  Run pipeline: train_val -> test.
  - Parameters:
    - model (Module) – args for TrnBase.
    - data (Dataset) – args for TrnBase.
    - res_logger – args for TrnBase.
    - args (Namespace) – args for TrnBase.
      - device (str): torch device.
      - metric (str): Metric for evaluation.
      - epoch (int): Number of training epochs.
      - lr_[lin/conv] (float): Learning rate for linear/conv.
      - wd_[lin/conv] (float): Weight decay for linear/conv.
      - patience (int): Patience for early stopping.
      - period (int): Period for checkpoint saving.
      - suffix (str): Suffix for checkpoint saving.
      - storage (str): Storage scheme for checkpoint saving.
      - logpath (Path): Path for logging.
      - multi (bool): True for multi-label classification.
      - in_channels (int): Number of data input features.
      - out_channels (int): Number of data output classes.
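A hedged instantiation sketch, assuming `model`, `data`, and `args` come from the loaders above; only the field names are taken from the parameter list, and every value shown is a placeholder.

```python
# Hedged sketch: driving the fullbatch trainer directly.
from pathlib import Path
from benchmark.trainer import TrnFullbatch

args.device = "cuda:0"                               # hypothetical values for documented fields
args.epoch, args.patience, args.period = 500, 50, 100
args.lr_lin = args.lr_conv = 1e-2
args.wd_lin = args.wd_conv = 5e-4
args.suffix, args.storage = "run0", "state_file"     # storage scheme value is a guess
args.logpath = Path("./log")

trn = TrnFullbatch(model=model, data=data, args=args)
trn.run()                                            # run pipeline: train_val -> test
```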
- class TrnFullbatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs) [source]
  Bases: TrnFullbatch, TrnBase_Trial
  Trainer supporting optuna.pruners in training.
- class TrnMinibatch(model: Module, data: Dataset, args: Namespace, **kwargs) [source]
  Bases: TrnBase
  Minibatch trainer class for node classification.
  Model forward input: node embeddings.
  Run pipeline: propagate -> train_val -> test.
  - Parameters:
    - args.batch (int) – Batch size.
    - args.normf (int) – Embedding normalization.
    - model (Module) – args for TrnBase.
    - data (Dataset) – args for TrnBase.
    - res_logger – args for TrnBase.
    - args (Namespace) – args for TrnBase.
      - device (str): torch device.
      - metric (str): Metric for evaluation.
      - epoch (int): Number of training epochs.
      - lr_[lin/conv] (float): Learning rate for linear/conv.
      - wd_[lin/conv] (float): Weight decay for linear/conv.
      - patience (int): Patience for early stopping.
      - period (int): Period for checkpoint saving.
      - suffix (str): Suffix for checkpoint saving.
      - storage (str): Storage scheme for checkpoint saving.
      - logpath (Path): Path for logging.
      - multi (bool): True for multi-label classification.
      - in_channels (int): Number of data input features.
      - out_channels (int): Number of data output classes.
  - _fetch_preprocess(embed: Tensor, label: Tensor, mask: dict) → dict [source]
    Call model preprocess for precomputation.
  - _fetch_input(split: str) → Generator [source]
    Process each sample of model input and label for training.
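Usage mirrors TrnFullbatch, with two extra documented fields and a propagation step prepended to the pipeline; the values below are placeholders and `model`, `data`, `args` are assumed to come from the loaders above.

```python
# Hedged sketch: the minibatch trainer needs a batch size and an embedding
# normalization flag in addition to the TrnBase fields.
from benchmark.trainer import TrnMinibatch

args.batch = 4096            # documented: batch size
args.normf = 1               # documented: embedding normalization (int)
trn = TrnMinibatch(model=model, data=data, args=args)
trn.run()                    # run pipeline: propagate -> train_val -> test
```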
- class TrnMinibatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs) [source]
  Bases: TrnMinibatch, TrnBase_Trial
  Trainer supporting optuna.pruners in training. Precomputation is invoked lazily.
benchmark.trainer.base
Author: nyLiao File Created: 2024-03-03
- class TrnBase(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs) [source]
  Bases: object
  Base trainer class for general pipelines and tasks.
  - Parameters:
    - model (Module) – PyTorch model to be trained.
    - data (Data) – PyG-style data.
    - res_logger (ResLogger | None, default: None) – Logger for results.
    - args (Namespace) – Configuration arguments.
      - device (str): torch device.
      - metric (str): Metric for evaluation.
      - criterion (set): Loss function in torch.nn.
      - epoch (int): Number of training epochs.
      - lr_[lin/conv] (float): Learning rate for linear/conv.
      - wd_[lin/conv] (float): Weight decay for linear/conv.
      - patience (int): Patience for early stopping.
      - period (int): Period for checkpoint saving.
      - suffix (str): Suffix for checkpoint saving.
      - storage (str): Storage scheme for checkpoint saving.
      - logpath (Path): Path for logging.
      - multi (bool): True for multi-label classification.
      - in_channels (int): Number of data input features.
      - out_channels (int): Number of data output classes.
  - setup_optimizer() [source]
    Set up the optimizer and scheduler.
  - clear() [source]
    Clear self cache.
  - run() [source]
    Run the training process.
  - param = {'lr_conv': ('float', (1e-05, 0.5), {'log': True}, <function TrnBase.<lambda>>), 'lr_lin': ('float', (1e-05, 0.5), {'log': True}, <function TrnBase.<lambda>>), 'wd_conv': ('float', (1e-07, 0.001), {'log': True}, <function TrnBase.<lambda>>), 'wd_lin': ('float', (1e-07, 0.001), {'log': True}, <function TrnBase.<lambda>>)}
  - train_val(*args, **kwargs) [source]
  - test(*args, **kwargs) [source]
  - _fetch_data() [source]
    Process the single graph data.
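The `param` attribute above declares per-hyperparameter search spaces. The sketch below shows one way such entries could be consumed; each entry prints as a 4-tuple (dtype, range, kwargs, lambda), and treating the trailing lambda as a post-processing hook is an assumption based only on those reprs.

```python
# Sketch of decoding a `param`-style search-space entry into an Optuna suggestion.
import optuna

def suggest_from_param(trial: optuna.Trial, name: str, spec):
    dtype, rng, kwargs, post = spec
    if dtype == "float":
        value = trial.suggest_float(name, rng[0], rng[1], **kwargs)
    elif dtype == "int":
        value = trial.suggest_int(name, rng[0], rng[1], **kwargs)
    else:
        raise ValueError(f"unsupported dtype: {dtype}")
    return post(value) if callable(post) else value

# e.g. a spec mirroring TrnBase.param['lr_conv']:
lr_spec = ("float", (1e-05, 0.5), {"log": True}, lambda v: v)
```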
benchmark.trainer.regression
benchmark.trainer.load_data
Author: nyLiao File Created: 2024-02-26
- class SingleGraphLoader: documented above under benchmark.trainer.
benchmark.trainer.load_metric
Author: nyLiao File Created: 2024-03-03
- class ResCollection(metrics: Metric | Sequence[Metric] | Dict[str, Metric], *additional_metrics: Metric, prefix: str | None = None, postfix: str | None = None, compute_groups: bool | List[List[str]] = True) → None [source]
  Bases: MetricCollection
- metric_loader(args: Namespace) → MetricCollection [source]
  Loader for torchmetrics.Metric objects.
  - Parameters:
    - args (Namespace) – Configuration arguments.
      - args.multi (bool): True for multi-label classification.
      - args.out_channels (int): Number of output classes/labels.
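A hedged sketch of wiring the metric loader: only `args.multi` and `args.out_channels` are documented inputs, and the 7-class single-label setting below is a hypothetical example.

```python
# Hedged sketch: building the evaluation metrics for a single-label task.
from argparse import Namespace
from benchmark.trainer.load_metric import metric_loader

args = Namespace(multi=False, out_channels=7)
metrics = metric_loader(args)        # returns a torchmetrics MetricCollection
```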
benchmark.trainer.load_model
Author: nyLiao File Created: 2024-02-26
- class ModelLoader: documented above under benchmark.trainer.
- class ModelLoader_Trial: documented above under benchmark.trainer.