benchmark.trainer

class SingleGraphLoader(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: object

Loader for torch_geometric.data.Data object for one graph.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.seed (int): Random seed.

    • args.data (str): Dataset name.

    • args.data_split (str): Index of dataset split.

  • res_logger (ResLogger | None, default: None) – Logger for results.

get(args: Namespace) → Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.
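The Updates list above can be illustrated with a standard-library sketch that fills the four attributes from a toy label list. The helper `update_label_args`, the rule that list-valued labels mean multi-label, and the metric names `'f1_micro'` and `'accuracy'` are hypothetical placeholders, not the loader's actual logic.

```python
from argparse import Namespace
from typing import List, Sequence, Union

# Hypothetical label type: an int class id, or a 0/1 vector for multi-label data.
Label = Union[int, Sequence[int]]


def update_label_args(args: Namespace, num_features: int, labels: List[Label]) -> Namespace:
    """Hypothetical helper that sets the attributes get() is documented to update."""
    args.num_features = num_features
    # Assumption: list-valued labels mean multi-label classification.
    args.multi = isinstance(labels[0], (list, tuple))
    if args.multi:
        args.num_classes = len(labels[0])
    else:
        args.num_classes = max(labels) + 1
    # Placeholder metric names; the loader's real defaults may differ.
    args.metric = 'f1_micro' if args.multi else 'accuracy'
    return args


args = update_label_args(Namespace(seed=0, data='toy'), num_features=16, labels=[0, 2, 1, 2])
print(args.num_features, args.num_classes, args.multi, args.metric)  # 16 3 False accuracy
```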

class SingleGraphLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: SingleGraphLoader

Reuse necessary data for multiple runs.

get(args: Namespace) → Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

update(args: Namespace, data: Data) → Data

Update data split for the next trial.
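One way to picture the trial loader's reuse is a cache that loads each dataset once and only redraws the split in update() for each new seed. The class `CachedSplitLoader`, its node-id placeholder data, and the 60/20/20 split ratio are assumptions made for illustration; the real loader works on torch_geometric Data objects.

```python
import random
from argparse import Namespace
from typing import Dict, List


class CachedSplitLoader:
    """Hypothetical trial loader: read each dataset once, redraw the split per trial."""

    def __init__(self) -> None:
        self.cache: Dict[str, List[int]] = {}
        self.num_loads = 0

    def _load(self, name: str) -> List[int]:
        # Placeholder for an expensive dataset read; returns ten node ids.
        self.num_loads += 1
        return list(range(10))

    def get(self, args: Namespace) -> Dict[str, List[int]]:
        if args.data not in self.cache:
            self.cache[args.data] = self._load(args.data)
        return self.update(args, self.cache[args.data])

    def update(self, args: Namespace, nodes: List[int]) -> Dict[str, List[int]]:
        # Assumed 60/20/20 random split, reproducible from args.seed.
        order = nodes[:]
        random.Random(args.seed).shuffle(order)
        n_train = len(order) * 6 // 10
        n_val = len(order) * 2 // 10
        return {
            'train': order[:n_train],
            'val': order[n_train:n_train + n_val],
            'test': order[n_train + n_val:],
        }


loader = CachedSplitLoader()
for seed in (0, 1, 2):
    split = loader.get(Namespace(data='toy', seed=seed))
print(loader.num_loads, len(split['train']), len(split['val']), len(split['test']))  # 1 6 2 2
```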

class ModelLoader(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: object

Loader for torch.nn.Module object.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.model (str): Model architecture name.

    • args.conv (str): Convolution layer name.

  • res_logger (ResLogger | None, default: None) – Logger for results.

_resolve_import(args: Namespace) → Tuple[str, str, dict, TrnBase]

get(args: Namespace) → Tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.hidden (int) – Number of hidden units.

  • args.dp_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Criterion for loss calculation
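A rough sketch of the name-to-trainer resolution that get() performs, using stand-in classes instead of real imports. The registry, the model names `'FullbatchModel'` and `'PrecomputedModel'`, and the rule mapping args.multi to a criterion are hypothetical; only `BCEWithLogitsLoss` and `CrossEntropyLoss` are real torch.nn class names.

```python
from argparse import Namespace
from typing import Dict, Tuple, Type


class FullbatchTrainer:
    """Stand-in for a full-batch trainer class."""


class MinibatchTrainer:
    """Stand-in for a minibatch trainer class."""


# Hypothetical registry; the real loader resolves names through imports.
MODEL_TRAINERS: Dict[str, Type] = {
    'FullbatchModel': FullbatchTrainer,
    'PrecomputedModel': MinibatchTrainer,
}


def resolve(args: Namespace) -> Tuple[str, str, Type]:
    """Return (model name, conv name, trainer class) and set args.criterion."""
    if args.model not in MODEL_TRAINERS:
        raise ValueError(f'Unknown model: {args.model}')
    # Assumed loss choice, keyed on the multi-label flag set by the data loader.
    args.criterion = 'BCEWithLogitsLoss' if args.multi else 'CrossEntropyLoss'
    return args.model, args.conv, MODEL_TRAINERS[args.model]


args = Namespace(model='FullbatchModel', conv='SomeConv', multi=False)
print(resolve(args)[2].__name__, args.criterion)  # FullbatchTrainer CrossEntropyLoss
```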

class ModelLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: ModelLoader

Reuse necessary data for multiple runs.

get(args: Namespace) → Tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.hidden (int) – Number of hidden units.

  • args.dp_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Criterion for loss calculation

update(args: Namespace, model: Module) → Module

class TrnBase_Trial(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: TrnBase

Trainer supporting optuna.pruners in training.

split_hyperval(data: Data) → Data

clear()

train_val(split_train: List[str] = ['train'], split_val: List[str] = ['val']) → ResLogger
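The pruning hook can be sketched with a stub that mimics optuna's `Trial.report(value, step)` and `Trial.should_prune()` calls, so the example runs without optuna installed. `StubTrial`, its threshold rule, and the local `TrialPruned` exception (standing in for `optuna.TrialPruned`) are assumptions, not the trainer's code.

```python
from typing import List


class TrialPruned(Exception):
    """Local stand-in for optuna.TrialPruned."""


class StubTrial:
    """Mimics optuna's Trial.report/should_prune with a simple threshold rule."""

    def __init__(self, threshold: float, warmup: int) -> None:
        self.threshold, self.warmup = threshold, warmup
        self.reported: List[float] = []

    def report(self, value: float, step: int) -> None:
        self.reported.append(value)

    def should_prune(self) -> bool:
        # Prune once past the warm-up steps if the latest value is below the threshold.
        return len(self.reported) > self.warmup and self.reported[-1] < self.threshold


def train_val(trial: StubTrial, val_scores: List[float]) -> float:
    best = float('-inf')
    for epoch, score in enumerate(val_scores):
        best = max(best, score)  # stands in for one training + validation epoch
        trial.report(score, step=epoch)
        if trial.should_prune():
            raise TrialPruned(f'pruned at epoch {epoch}')
    return best


print(train_val(StubTrial(threshold=0.5, warmup=2), [0.4, 0.6, 0.7, 0.8]))  # 0.8
```

With optuna itself, the same loop would receive the real trial object inside the objective function and raise `optuna.TrialPruned`.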

class TrnFullbatch(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnBase

Fullbatch trainer class for node classification.
  • Model forward input: separate edge index and node features.

  • Run pipeline: train_val -> test.

Parameters:
  • model (Module) – Passed to TrnBase.

  • data (Dataset) – Passed to TrnBase.

  • res_logger – Passed to TrnBase.

  • args (Namespace) –

    Passed to TrnBase.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • num_features (int): Number of data input features.

    • num_classes (int): Number of data output classes.
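The patience argument suggests a standard early-stopping loop inside train_val. The sketch below shows that generic technique on a precomputed list of validation scores; the function `train_with_patience` is hypothetical, and checkpointing every `period` epochs is omitted.

```python
from typing import List, Tuple


def train_with_patience(val_scores: List[float], epoch: int, patience: int) -> Tuple[int, float]:
    """Return (best_epoch, best_score), stopping after `patience` epochs without improvement."""
    best_epoch, best_score, stale = -1, float('-inf'), 0
    for ep in range(min(epoch, len(val_scores))):
        score = val_scores[ep]  # stands in for one training step plus validation
        if score > best_score:
            best_epoch, best_score, stale = ep, score, 0
            # A real trainer would save a checkpoint of the best model here.
        else:
            stale += 1
            if stale >= patience:
                break
    return best_epoch, best_score


print(train_with_patience([0.5, 0.6, 0.55, 0.58, 0.59, 0.9], epoch=100, patience=3))  # (1, 0.6)
```

The best epoch's model, not the last one, is what the subsequent test stage would normally evaluate.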

name: str = 'fb'

clear()

_fetch_data() → Tuple[Data, dict]

Process the single graph data.

_fetch_input() → tuple

Process each sample of model input and label.

_learn_split(split: list = ['train']) → ResLogger

Run the actual training iteration on the given splits.

_eval_split(split: list = ['test']) → ResLogger

Run the actual evaluation on the given splits.

test_deg() → ResLogger

Separate high/low degree subsets and evaluate.
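test_deg() can be pictured as grouping nodes by degree and scoring each group separately. In this sketch the function name, the median cut-off, and plain accuracy are all assumptions; edge_index is a list of (source, target) pairs rather than a tensor.

```python
from collections import Counter
from statistics import median
from typing import Dict, List, Sequence, Tuple


def degree_split_accuracy(edge_index: Sequence[Tuple[int, int]], num_nodes: int,
                          pred: Sequence[int], label: Sequence[int]) -> Dict[str, float]:
    """Accuracy on high-degree and low-degree nodes (hypothetical helper)."""
    # Degree counted on source endpoints; assumes edges are stored in both directions.
    deg = Counter(src for src, _ in edge_index)
    degrees = [deg[v] for v in range(num_nodes)]
    threshold = median(degrees)  # assumed cut-off between the two subsets
    groups: Dict[str, List[int]] = {'high': [], 'low': []}
    for v in range(num_nodes):
        groups['high' if degrees[v] > threshold else 'low'].append(v)
    # Skip empty groups to avoid division by zero.
    return {
        name: sum(pred[v] == label[v] for v in nodes) / len(nodes)
        for name, nodes in groups.items() if nodes
    }


# Star graph: node 0 is the hub, nodes 1-3 are leaves.
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (0, 3), (3, 0)]
print(degree_split_accuracy(edges, 4, pred=[1, 0, 0, 1], label=[1, 0, 1, 1]))
```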

run() → ResLogger

class TrnFullbatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnFullbatch, TrnBase_Trial

Trainer supporting optuna.pruners in training.

run() → ResLogger

update(*args, **kwargs)

class TrnMinibatch(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnBase

Minibatch trainer class for node classification.
  • Model forward input: node embeddings.

  • Run pipeline: propagate -> train_val -> test.

Parameters:
  • args.batch (int) – Batch size.

  • args.normf (int) – Embedding normalization.

  • model (Module) – Passed to TrnBase.

  • data (Dataset) – Passed to TrnBase.

  • res_logger – Passed to TrnBase.

  • args (Namespace) –

    Passed to TrnBase.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • num_features (int): Number of data input features.

    • num_classes (int): Number of data output classes.
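The propagate step of the minibatch pipeline precomputes node embeddings once, so training only touches dense rows. The sketch uses plain mean aggregation over in-neighbors as a hypothetical stand-in for whatever filter a model's preprocessing applies; the function `propagate` and the list-of-lists feature layout are assumptions.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def propagate(x: Matrix, edge_index: Sequence[Tuple[int, int]], num_hops: int) -> Matrix:
    """Precompute mean-aggregated features over `num_hops` hops (illustrative only)."""
    n = len(x)
    nbrs: List[List[int]] = [[] for _ in range(n)]
    for src, dst in edge_index:
        nbrs[dst].append(src)
    for _ in range(num_hops):
        # Each node averages its in-neighbors' current features; isolated nodes keep theirs.
        x = [
            [sum(x[u][f] for u in nbrs[v]) / len(nbrs[v]) for f in range(len(x[v]))]
            if nbrs[v] else x[v][:]
            for v in range(n)
        ]
    return x


# Path graph 0 - 1 - 2 with one scalar feature per node.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
print(propagate([[1.0], [0.0], [3.0]], edges, num_hops=1))  # [[0.0], [2.0], [0.0]]
```

Because the result depends only on the graph and the hop count, it can be computed before train_val and reused for every epoch.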

name: str = 'mb'

clear()

_fetch_data() → tuple

Process the single graph data.

_fetch_preprocess(embed: Tensor, label: Tensor, mask: dict) → dict

Call model preprocess for precomputation.

_fetch_input(split: str) → Generator

Process each sample of model input and label for training.
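_fetch_input returns a generator; a minimal sketch of such a generator yields shuffled (embedding, label) batches of size args.batch over one split. The function name, the `index` argument holding the split's node ids, and the seeded shuffle are illustrative assumptions.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def fetch_input(embed: Sequence[Sequence[float]], label: Sequence[int],
                index: Sequence[int], batch: int, shuffle: bool = True,
                seed: int = 0) -> Iterator[Tuple[List[Sequence[float]], List[int]]]:
    """Yield (embedding batch, label batch) pairs over the nodes of one split."""
    order = list(index)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch):
        ids = order[start:start + batch]
        yield [embed[i] for i in ids], [label[i] for i in ids]


emb = [[float(i)] for i in range(5)]
lab = [i % 2 for i in range(5)]
for xb, yb in fetch_input(emb, lab, index=[0, 2, 3, 4], batch=3, shuffle=False):
    print(xb, yb)
```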

_learn_split(split: list = ['train']) → ResLogger

Run the actual training iteration on the given splits.

_eval_split(split: list = ['test']) → ResLogger

Run the actual evaluation on the given splits.

preprocess(*args, **kwargs)

run() → ResLogger

class TrnMinibatch_Trial(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnMinibatch, TrnBase_Trial

Trainer supporting optuna.pruners in training. Precomputation is called lazily.

preprocess(*args, **kwargs)

run() → ResLogger

update(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)
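Lazy precomputation across trials can be sketched as a cache keyed by whatever determines the precomputed result, so trials that only change, say, the learning rate reuse it. The `LazyPrecompute` class and the choice of key are hypothetical.

```python
from typing import Callable, Dict, Hashable, List


class LazyPrecompute:
    """Run an expensive precompute function only the first time each key is seen."""

    def __init__(self, fn: Callable[[Hashable], List[float]]) -> None:
        self.fn = fn
        self.cache: Dict[Hashable, List[float]] = {}
        self.calls = 0

    def __call__(self, key: Hashable) -> List[float]:
        if key not in self.cache:
            self.calls += 1
            self.cache[key] = self.fn(key)
        return self.cache[key]


# Hypothetical key: the number of propagation hops chosen by each trial.
pre = LazyPrecompute(lambda hops: [float(hops)] * 3)
for trial_hops in (2, 2, 3, 2):
    pre(trial_hops)
print(pre.calls)  # 2
```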

benchmark.trainer.base

Author: nyLiao
File Created: 2024-03-03

class TrnBase(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: object

Base trainer class for general pipelines and tasks.

Parameters:
  • model (Module) – PyTorch model to be trained.

  • data (Data) – PyG style data.

  • res_logger (ResLogger | None, default: None) – Logger for results.

  • args (Namespace) –

    Configuration arguments.

    • device (str): torch device.

    • metric (str): Metric for evaluation.

    • criterion (set): Loss function in torch.nn.

    • epoch (int): Number of training epochs.

    • lr_[lin/conv] (float): Learning rate for linear/conv.

    • wd_[lin/conv] (float): Weight decay for linear/conv.

    • patience (int): Patience for early stopping.

    • period (int): Period for checkpoint saving.

    • suffix (str): Suffix for checkpoint saving.

    • storage (str): Storage scheme for checkpoint saving.

    • logpath (Path): Path for logging.

    • multi (bool): True for multi-label classification.

    • num_features (int): Number of data input features.

    • num_classes (int): Number of data output classes.
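Since the learning rate and weight decay are split into lin and conv variants, setup_optimizer plausibly builds per-group options. The sketch builds the `{'params', 'lr', 'weight_decay'}` dictionaries accepted by torch.optim optimizers, but uses parameter name strings so it runs without PyTorch; matching conv parameters by a `'conv'` substring is an assumption.

```python
from argparse import Namespace
from typing import Dict, List, Sequence


def param_groups(names: Sequence[str], args: Namespace) -> List[Dict[str, object]]:
    """Split parameter names into lin/conv groups with their own lr and weight decay."""
    # Assumption: conv parameters are recognised by a 'conv' substring in their name.
    conv = [n for n in names if 'conv' in n]
    lin = [n for n in names if 'conv' not in n]
    return [
        {'params': lin, 'lr': args.lr_lin, 'weight_decay': args.wd_lin},
        {'params': conv, 'lr': args.lr_conv, 'weight_decay': args.wd_conv},
    ]


args = Namespace(lr_lin=0.01, wd_lin=5e-4, lr_conv=0.001, wd_conv=0.0)
for group in param_groups(['in_layers.0.weight', 'conv.0.theta', 'out_layers.0.bias'], args):
    print(group)
```

With PyTorch, the name lists would be replaced by the corresponding tensors from `model.named_parameters()` before passing the groups to an optimizer such as `torch.optim.Adam`.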

setup_optimizer()

Set up the optimizer and scheduler.

clear()

Clear the instance's cache.

run()

Run the training process.

name: str

setup_optimizer()

clear()

static _log_memory(split: str | None = None, row: int = 0)

train_val(*args, **kwargs)

test(*args, **kwargs)

_fetch_data()

Process the single graph data.

_fetch_input() → tuple

Process each sample of model input and label.

_learn_split(split: List[str] = ['train']) → ResLogger

Run the actual training iteration on the given splits.

_eval_split(split: List[str]) → ResLogger

Run the actual evaluation on the given splits.

run() → ResLogger

class TrnBase_Trial(model: Module, data: Data, args: Namespace, res_logger: ResLogger | None = None, **kwargs)

Bases: TrnBase

Trainer supporting optuna.pruners in training.

split_hyperval(data: Data) → Data

clear()

train_val(split_train: List[str] = ['train'], split_val: List[str] = ['val']) → ResLogger

benchmark.trainer.regression

class TrnRegression(model: Module, data: Dataset, args: Namespace, **kwargs)

Bases: TrnFullbatch

name: str = 'regression'

_fetch_input() → tuple

Process each sample of model input and label.

class RegressionLoader(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: object

Loader for regression learning data.

_resolve_import(args: Namespace) → Tuple[str, str, dict]

get(args: Namespace) → Data

Load data based on parameters.

benchmark.trainer.load_data

Author: nyLiao
File Created: 2024-02-26

class SingleGraphLoader(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: object

Loader for torch_geometric.data.Data object for one graph.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.seed (int): Random seed.

    • args.data (str): Dataset name.

    • args.data_split (str): Index of dataset split.

  • res_logger (ResLogger | None, default: None) – Logger for results.

get(args: Namespace) → Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

class SingleGraphLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: SingleGraphLoader

Reuse necessary data for multiple runs.

get(args: Namespace) → Data

Load data based on parameters.

Parameters:

args.normg (float) – Generalized graph norm.

Updates:
  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.multi (bool) – True for multi-label classification.

  • args.metric (str) – Main metric name for evaluation.

update(args: Namespace, data: Data) → Data

Update data split for the next trial.

benchmark.trainer.load_metric

Author: nyLiao
File Created: 2024-03-03

class ResCollection(metrics: Metric | Sequence[Metric] | Dict[str, Metric], *additional_metrics: Metric, prefix: str | None = None, postfix: str | None = None, compute_groups: bool | List[List[str]] = True) → None

Bases: MetricCollection

compute() → List[Tuple[str, Any, Callable]]

Wrap compute output to ResLogger style.
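torchmetrics' `MetricCollection.compute()` returns a dictionary keyed by metric name; the wrapper documented above turns it into (name, value, callable) triples. The sketch assumes the callable is a display formatter with four decimals and that names are sorted; both are guesses, as is the helper name `wrap_results`.

```python
from typing import Any, Callable, Dict, List, Tuple


def _fmt(value: Any) -> str:
    # Assumed display format: four decimal places.
    return f'{float(value):.4f}'


def wrap_results(results: Dict[str, Any]) -> List[Tuple[str, Any, Callable[[Any], str]]]:
    """Turn a {name: value} dict into sorted (name, value, formatter) triples."""
    return [(name, value, _fmt) for name, value in sorted(results.items())]


for name, value, fmt in wrap_results({'f1': 0.81234, 'acc': 0.9}):
    print(name, fmt(value))
```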

metric_loader(args: Namespace) → MetricCollection

Loader for torchmetrics.Metric object.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.multi (bool): True for multi-label classification.

    • args.num_classes (int): Number of output classes/labels.
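metric_loader chooses metrics from args.multi and args.num_classes. The sketch returns torchmetrics class names with their constructor arguments instead of instances, so it runs without the library. The selection of accuracy and F1 is an assumption; the class names and the `num_classes`/`num_labels` keywords follow torchmetrics' classification API.

```python
from argparse import Namespace
from typing import Dict, Tuple


def metric_spec(args: Namespace) -> Dict[str, Tuple[str, Dict[str, int]]]:
    """Map config flags to (torchmetrics class name, constructor kwargs) pairs."""
    if args.multi:
        # Multi-label metrics are parameterised by the number of labels.
        return {
            'acc': ('MultilabelAccuracy', {'num_labels': args.num_classes}),
            'f1': ('MultilabelF1Score', {'num_labels': args.num_classes}),
        }
    return {
        'acc': ('MulticlassAccuracy', {'num_classes': args.num_classes}),
        'f1': ('MulticlassF1Score', {'num_classes': args.num_classes}),
    }


print(metric_spec(Namespace(multi=False, num_classes=7)))
```

With torchmetrics installed, each entry could be instantiated via `getattr(torchmetrics.classification, name)(**kwargs)` and the resulting dict passed to `MetricCollection`.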

benchmark.trainer.load_model

Author: nyLiao
File Created: 2024-02-26

class ModelLoader(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: object

Loader for torch.nn.Module object.

Parameters:
  • args (Namespace) –

    Configuration arguments.

    • args.model (str): Model architecture name.

    • args.conv (str): Convolution layer name.

  • res_logger (ResLogger | None, default: None) – Logger for results.

_resolve_import(args: Namespace) → Tuple[str, str, dict, TrnBase]

get(args: Namespace) → Tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.hidden (int) – Number of hidden units.

  • args.dp_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Criterion for loss calculation

class ModelLoader_Trial(args: Namespace, res_logger: ResLogger | None = None) → None

Bases: ModelLoader

Reuse necessary data for multiple runs.

get(args: Namespace) → Tuple[Module, TrnBase]

Load model with specified arguments.

Parameters:
  • args.num_hops (int) – Number of conv hops.

  • args.in_layers (int) – Number of MLP layers before conv.

  • args.out_layers (int) – Number of MLP layers after conv.

  • args.num_features (int) – Number of input features.

  • args.num_classes (int) – Number of output classes.

  • args.hidden (int) – Number of hidden units.

  • args.dp_[lin/conv] (float) – Dropout rate for linear/conv.

Updates:

args.criterion (str) – Criterion for loss calculation

update(args: Namespace, model: Module) → Module