benchmark.utils

setup_seed(seed: int | None = None, cuda: bool = True) → int
setup_argparse()
setup_args(parser: ArgumentParser) → Namespace
save_args(logpath: Path, args: dict)
dict_to_json(dictionary) → dict
setup_logger(logpath: Path | str = PosixPath('../log'), level_console: int = 10, level_file: int = 15, quiet: bool = True, fmt='{message}')
clear_logger(logger: Logger)
setup_logpath(dir: Path | str = PosixPath('../log'), folder_args: tuple | None = None, quiet: bool = True)

Resolve log path for saving.

Parameters:
  • dir (Path | str, default: PosixPath('../log')) – Base directory for saving logs.

  • folder_args (tuple | None, default: None) – Subfolder names.

  • quiet (bool, default: True) – Quiet run without creating directories.

Returns:

logpath (Path) – Path for log directory.
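
As an illustration of the parameters above, the following is a minimal sketch of how such a path resolution could work. It is not the library's implementation: `resolve_logpath` is a hypothetical stand-in, the subfolder names are made up, and the assumption that `quiet=False` creates the directory is inferred only from the description of `quiet`.

```python
from pathlib import Path


def resolve_logpath(base="../log", folder_args=None, quiet=True):
    """Hypothetical stand-in for setup_logpath; behavior is assumed."""
    logpath = Path(base)
    # Append each subfolder name in order, if any were given.
    for name in folder_args or ():
        logpath = logpath / str(name)
    # Assumption: a quiet run only computes the path, while a
    # non-quiet run also creates the directory on disk.
    if not quiet:
        logpath.mkdir(parents=True, exist_ok=True)
    return logpath


# Illustrative subfolder names, not taken from the library.
path = resolve_logpath("../log", folder_args=("cora", "gcn", "seed0"))
print(path)
```

With `quiet=True` nothing is written to disk, so the call is safe in dry runs.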

class ResLogger(logpath: Path | str = PosixPath('../log'), prefix: str = 'summary', suffix: str | None = None, quiet: bool = True)

Bases: object

Logger that formats results as strings by wrapping a pd.DataFrame table.

Parameters:
  • logpath (Path | str, default: PosixPath('../log')) – Path to the directory where the CSV file is saved.

  • quiet (bool, default: True) – Quiet run without saving file.

static guess_fmt(key: str, val) → Callable

Guesses the string format function for a value based on its key name.

property nrows
property ncols
_set(data: DataFrame, fmt: Series)

Sets the data from input DataFrame.

Parameters:
  • data (DataFrame) – Concat on columns, inner join on index.

  • fmt (Series) – Inner join on columns.

concat(vals: list[tuple[str, Any, Callable]] | dict, row: int = 0, suffix: str | None = None)

Concatenate data entries of a single row to data.

Parameters:
  • vals (list[tuple[str, Any, Callable]] | dict) – List of entries (key, value, formatter).

  • row (int, default: 0) – Row index in this logger's DataFrame at which vals are logged.

  • suffix (str | None, default: None) – Suffix string for input keys. Default is None.

Returns:

self (ResLogger)
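
To make the `(key, value, formatter)` entry shape concrete, here is a minimal standard-library sketch. It is not ResLogger itself: `format_row` is a hypothetical helper, and both the underscore used to join the suffix and the `str` fallback for dict inputs are assumptions.

```python
def format_row(vals, suffix=None):
    """Hypothetical helper: normalize entries and apply each formatter."""
    if isinstance(vals, dict):
        # Assumption: dict inputs carry no formatter, so fall back to str.
        vals = [(key, val, str) for key, val in vals.items()]
    row = {}
    for key, val, fmt in vals:
        # Assumption: the suffix is joined to the key with an underscore.
        col = f"{key}_{suffix}" if suffix else key
        row[col] = fmt(val)
    return row


entries = [
    ("epoch", 12, lambda x: f"{x:d}"),
    ("acc", 0.91234, lambda x: f"{x:.4f}"),
]
formatted = format_row(entries, suffix="val")
print(formatted)  # {'epoch_val': '12', 'acc_val': '0.9123'}
```

Keeping the formatter next to each value means the raw number can be stored as is and only rendered when a string is requested.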

merge(logger: ResLogger, rows: list[int] | None = None, suffix: str | None = None)

Merge from another logger.

Parameters:
  • logger (ResLogger) – Logger to merge.

  • rows (list[int] | None, default: None) – New indices in this logger's DataFrame.

  • suffix (str | None, default: None) – Suffix string for input keys. Default is None.

del_col(col: list | str) → ResLogger

Delete columns from data.

Parameters:

col (list | str) – Column(s) to delete.

_get(col: list | str | None = None, row: list | str | None = None) → DataFrame | Series | str

Retrieve a single entry or a slice of the data and apply the string format.

Parameters:
  • col (list | str | None, default: None) – Column(s) to retrieve. Defaults to all.

  • row (list | str | None, default: None) – Row(s) to retrieve. Defaults to all.

Returns:

val – Formatted data.
  • type: follows the return type of DataFrame.loc[row, col].
  • value: formatted string in each entry.

save()

Saves table data to CSV file.

get_str(col: list | str | None = None, row: list | int | None = None, maxlen: int = -1) → str

Get a single formatted string of the specified columns and rows for printing.

Parameters:
  • col (list | str | None, default: None) – Column(s) to retrieve. Defaults to all.

  • row (list | int | None, default: None) – Row(s) to retrieve. Defaults to all.

  • maxlen (int, default: -1) – Max line length of the resulting string.

Returns:

s (str) – Formatted string representation.

flt_str(metric) → str

Remove all substrings that start with s but do not contain metric.

Parameters:

metric (str) – Metric to keep.

class CkptLogger(logpath: Path | str, patience: int = -1, period: int = 0, prefix: str = 'model', storage: str = 'state_gpu', metric_cmp: Callable[[float, float], bool] | str = 'max')

Bases: object

Checkpoint Logger for saving and loading models and managing early stopping during training.

Parameters:
  • logpath (Path | str) – Path to the directory where checkpoints are saved.

  • patience (int, default: -1) – Patience for early stopping. Defaults to no early stopping.

  • period (int, default: 0) – Periodic saving interval. Defaults to no periodic saving.

  • prefix (str, default: 'model') – Prefix for the checkpoint file names.

  • storage (str, default: 'state_gpu') – Storage scheme for saving the checkpoints.
      • ‘model’ vs ‘state’: Save model object or state_dict.
      • ‘_file’, ‘_ram’, ‘_gpu’: Save as file, RAM, or GPU memory.

  • metric_cmp (Callable[[float, float], bool] | str, default: 'max') – Comparison function for the metric. Can be ‘max’ or ‘min’.
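
The `storage` and `metric_cmp` strings above can be read as small encodings. The sketch below shows one way they might be decoded. It is not the CkptLogger code: `parse_storage` and `resolve_cmp` are hypothetical names, and the `(new, best)` argument order of the comparison is an assumption.

```python
from typing import Callable, Tuple, Union


def parse_storage(storage: str) -> Tuple[str, str]:
    """Split e.g. 'state_gpu' into what is saved and where it is kept."""
    what, _, where = storage.partition("_")
    if what not in ("model", "state") or where not in ("file", "ram", "gpu"):
        raise ValueError(f"Unknown storage scheme: {storage!r}")
    return what, where


def resolve_cmp(
    metric_cmp: Union[str, Callable[[float, float], bool]],
) -> Callable[[float, float], bool]:
    """Turn 'max' or 'min' into a comparison; pass callables through."""
    if callable(metric_cmp):
        return metric_cmp
    # Assumption: the comparison is called as cmp(new, best).
    if metric_cmp == "max":
        return lambda new, best: new > best
    if metric_cmp == "min":
        return lambda new, best: new < best
    raise ValueError(f"Unknown metric_cmp: {metric_cmp!r}")


print(parse_storage("state_gpu"))     # ('state', 'gpu')
print(resolve_cmp("max")(0.9, 0.8))   # True
print(resolve_cmp("min")(0.9, 0.8))   # False
```

Under this reading, 'max' suits metrics such as accuracy, while 'min' suits a loss.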

set_epoch(epoch: int = 0)
_get_model_path(*suffix) → Path
get_last_epoch() → int

Get the epoch of the last saved model. Useful for deciding which model path to load.

save(*suffix, model: Module)

Save the model according to storage scheme.

Parameters:
  • suffix – Variable length argument for suffix in the model file name.

  • model (nn.Module) – The model to be saved.

load(*suffix, model: Module, map_location='cpu') → Module

Load the model from the storage.

Parameters:
  • suffix – Variable length argument for suffix in the model file name.

  • model (Module) – The model structure to load.

  • map_location (default: 'cpu') – map_location argument for torch.load.

Returns:

model (nn.Module) – The loaded model.

clear()
property is_early_stop: bool

Whether the current epoch satisfies the early stopping criterion.

property is_period: bool

Whether the current epoch should trigger periodic saving.

_is_improved(metric) → bool

Whether the metric is better than the previous best.

step(metric: float, model: Module | None = None) → bool

Step one epoch with periodic saving and early stopping.

Parameters:
  • metric (float) – Metric value for the current step.

  • model (Module | None, default: None) – Model for the current step. Defaults to None.

Returns:

early_stop (bool) – True if the early stopping criterion is met.
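
For intuition about the patience logic, here is a minimal sketch of the early stopping half of such a step function. It is not the CkptLogger implementation: `EarlyStopper` is a hypothetical class, it assumes a larger metric is better, it stops once the number of epochs since the best one exceeds `patience` (the exact threshold is an assumption), and it omits periodic saving and model handling.

```python
class EarlyStopper:
    """Hypothetical stand-in for the early stopping part of a step()."""

    def __init__(self, patience: int = -1):
        self.patience = patience  # a negative value disables early stopping
        self.epoch = 0
        self.best_metric = None
        self.best_epoch = 0

    def step(self, metric: float) -> bool:
        # Assumption: a larger metric is better (the 'max' comparison).
        if self.best_metric is None or metric > self.best_metric:
            self.best_metric = metric
            self.best_epoch = self.epoch
        since_best = self.epoch - self.best_epoch
        # Stop once more than `patience` epochs passed without improvement.
        early_stop = self.patience >= 0 and since_best > self.patience
        self.epoch += 1
        return early_stop


stopper = EarlyStopper(patience=2)
history = [0.50, 0.60, 0.55, 0.58, 0.59]
flags = [stopper.step(m) for m in history]
print(flags)  # [False, False, False, False, True]
```

In a training loop, the returned flag would typically be checked after each epoch to break out of the loop.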

set_at_best(**kwargs)

Save the given arguments as model attributes if the current epoch is the best so far.

get_at_best() → list

Get saved model attributes from the best epoch.

benchmark.utils.config

config.force_list_str(x)

config.force_list_int(x)

config.list_str(x)

config.list_int(x)

config.list_float(x)