diff --git a/docs/usage/training/index.md b/docs/usage/training/index.md
index 47b919d0ea99837b8712f6d7b35919abab5c74d3..e6858ae188f4e3d5bd2482aa3c62fee905395933 100644
--- a/docs/usage/training/index.md
+++ b/docs/usage/training/index.md
@@ -61,7 +61,7 @@ The full list of parameters is detailed in this section.
 | `train.early_stopping_patience` | Number of validation epochs with no improvement after which training will be stopped. | `int` | `20` |
 | `train.gpu_stats` | Whether to include GPU stats in the training progress bar. | `bool` | `False` |
 | `train.augment_training` | Whether to use data augmentation. | `bool` | `False` |
-| `train.log_to` | Logger to use during training. Should be either `"csv"` to log metrics locally, or `"wandb"` to report to Weights & Biases | `Logger` | `"csv"` |
+| `train.log_to_wandb` | Whether to log training metrics and parameters to Weights & Biases. | `bool` | `False` |
 
 ### Logging arguments
 
@@ -211,7 +211,7 @@ To set up Weights & Biases:
 * Run `pip install pylaia[wandb]` to install the required dependencies
 * Sign in to Weights & Biases using `wandb login`
 
-Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to wandb`.
+Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to_wandb true`.
 This will create a project called `PyLaia` in W&B with one run for each training.
 The following are monitored for each run:
 * Training and validation metrics (losses, CER, WER)
diff --git a/laia/common/arguments.py b/laia/common/arguments.py
index 96ea090ac544b0c382369de9f323cfd902ad14cb..55294e7063a94f685116a573259b417133020f07 100644
--- a/laia/common/arguments.py
+++ b/laia/common/arguments.py
@@ -18,11 +18,6 @@ from jsonargparse.typing import (
 GeNeg1Int = restricted_number_type(None, int, (">=", -1))
 
 
-class Logger(str, Enum):
-    csv = "csv"
-    wandb = "wandb"
-
-
 class Monitor(str, Enum):
     va_loss = "va_loss"
     va_cer = "va_cer"
@@ -233,7 +228,7 @@ class TrainArgs:
     early_stopping_patience: NonNegativeInt = 20
     gpu_stats: bool = False
     augment_training: bool = False
-    log_to: Logger = Logger.csv
+    log_to_wandb: bool = False
 
 
 @dataclass
diff --git a/laia/scripts/htr/train_ctc.py b/laia/scripts/htr/train_ctc.py
index 01c738ed7db0f671f4610ee484daeba2e274a305..3d5efd9a6d35d0d835bb472f74f2a17a1a4b8c3e 100755
--- a/laia/scripts/htr/train_ctc.py
+++ b/laia/scripts/htr/train_ctc.py
@@ -12,7 +12,6 @@ from laia.common.arguments import (
     CommonArgs,
     DataArgs,
     DecodeArgs,
-    Logger,
     OptimizerArgs,
     SchedulerArgs,
     TrainArgs,
@@ -160,18 +159,18 @@ def run(
     callbacks.append(LearningRate(logging_interval="epoch"))
 
     # prepare the logger
-    if train.log_to == Logger.wandb:
-        logger = pl.loggers.WandbLogger(project="PyLaia")
-        logger.watch(model)
-    else:
-        logger = EpochCSVLogger(common.experiment_dirpath)
+    loggers = [EpochCSVLogger(common.experiment_dirpath)]
+    if train.log_to_wandb:
+        wandb_logger = pl.loggers.WandbLogger(project="PyLaia")
+        wandb_logger.watch(model)
+        loggers.append(wandb_logger)
 
     # prepare the trainer
     trainer = pl.Trainer(
         default_root_dir=common.train_path,
         resume_from_checkpoint=checkpoint_path,
         callbacks=callbacks,
-        logger=logger,
+        logger=loggers,
         checkpoint_callback=True,
         **vars(trainer),
     )
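The `train_ctc.py` hunk above relies on PyTorch Lightning accepting a list for the `Trainer`'s `logger` argument, which is what makes the CSV logger unconditional and W&B purely additive. A minimal standalone sketch of the same pattern, assuming only that `pytorch_lightning` is installed; `build_loggers` and its parameters are illustrative helpers, not part of PyLaia:

```python
import pytorch_lightning as pl


def build_loggers(log_dir: str, log_to_wandb: bool) -> list:
    # Local CSV logging always happens, mirroring EpochCSVLogger above.
    loggers = [pl.loggers.CSVLogger(log_dir)]
    if log_to_wandb:
        # W&B is layered on top; enabling it no longer disables CSV output.
        loggers.append(pl.loggers.WandbLogger(project="PyLaia"))
    return loggers


# Lightning accepts either a single logger or an iterable of loggers.
trainer = pl.Trainer(logger=build_loggers("experiments", log_to_wandb=False))
```

Note that constructing `WandbLogger` requires the `wandb` package to be importable, which is why the docs gate it behind the opt-in `pip install pylaia[wandb]` extra.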
diff --git a/tests/scripts/htr/dataset/validate_cli_test.py b/tests/scripts/htr/dataset/validate_cli_test.py
index 12a1ce3fa9278ad71cea268cd590335f2e185bc9..a4b5ec872b9fda55575de58ba18cf038734a4e76 100755
--- a/tests/scripts/htr/dataset/validate_cli_test.py
+++ b/tests/scripts/htr/dataset/validate_cli_test.py
@@ -48,7 +48,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
diff --git a/tests/scripts/htr/train_ctc_cli_test.py b/tests/scripts/htr/train_ctc_cli_test.py
index 97b5f6da0ed421886e8887af611527b1a91b61bc..ca1d81cded68584b4f6f3d1999a69e9730afce7e 100644
--- a/tests/scripts/htr/train_ctc_cli_test.py
+++ b/tests/scripts/htr/train_ctc_cli_test.py
@@ -101,7 +101,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
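Since `TrainArgs.log_to_wandb` is now a plain `bool`, the YAML fixtures above parse to a native boolean, and configs still carrying the old `log_to` key would no longer match the dataclass, which is presumably why both fixtures are updated here. A quick sanity check of the new key in the same YAML-in-Python style as the tests; the config fragment is a hypothetical excerpt of `config_train_model.yaml`, and PyYAML is assumed for the standalone parse (PyLaia itself loads configs through jsonargparse):

```python
import yaml  # assumes PyYAML; PyLaia itself parses configs via jsonargparse

# Hypothetical user-config excerpt: enabling W&B on top of the default CSV log.
config = yaml.safe_load(
    """
train:
  augment_training: false
  log_to_wandb: true
"""
)
assert config["train"]["log_to_wandb"] is True
```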