From d1df04f8bc06f051547df2551f4317ad54df8667 Mon Sep 17 00:00:00 2001 From: starride-teklia <starride@teklia.com> Date: Thu, 31 Oct 2024 12:32:30 +0100 Subject: [PATCH] Support two loggers --- docs/usage/training/index.md | 4 ++-- laia/common/arguments.py | 7 +------ laia/scripts/htr/train_ctc.py | 13 ++++++------- tests/scripts/htr/dataset/validate_cli_test.py | 2 +- tests/scripts/htr/train_ctc_cli_test.py | 2 +- 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/docs/usage/training/index.md b/docs/usage/training/index.md index 47b919d0..e6858ae1 100644 --- a/docs/usage/training/index.md +++ b/docs/usage/training/index.md @@ -61,7 +61,7 @@ The full list of parameters is detailed in this section. | `train.early_stopping_patience` | Number of validation epochs with no improvement after which training will be stopped. | `int` | `20` | | `train.gpu_stats` | Whether to include GPU stats in the training progress bar. | `bool` | `False` | | `train.augment_training` | Whether to use data augmentation. | `bool` | `False` | -| `train.log_to` | Logger to use during training. Should be either `"csv"` to log metrics locally, or `"wandb"` to report to Weights & Biases | `Logger` | `"csv"` | +| `train.log_to_wandb` | Whether to log training metrics and parameters to Weights & Biases. | `bool` | `False` | ### Logging arguments @@ -211,7 +211,7 @@ To set up Weights & Biases: * Run `pip install pylaia[wandb]` to install the required dependencies * Sign in to Weights & Biases using `wandb login` -Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to wandb`. +Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to_wandb true`. This will create a project called `PyLaia` in W&B with one run for each training. The following are monitored for each run: * Training and validation metrics (losses, CER, WER) diff --git a/laia/common/arguments.py b/laia/common/arguments.py index 96ea090a..55294e70 100644 --- a/laia/common/arguments.py +++ b/laia/common/arguments.py @@ -18,11 +18,6 @@ from jsonargparse.typing import ( GeNeg1Int = restricted_number_type(None, int, (">=", -1)) -class Logger(str, Enum): - csv = "csv" - wandb = "wandb" - - class Monitor(str, Enum): va_loss = "va_loss" va_cer = "va_cer" @@ -233,7 +228,7 @@ class TrainArgs: early_stopping_patience: NonNegativeInt = 20 gpu_stats: bool = False augment_training: bool = False - log_to: Logger = Logger.csv + log_to_wandb: bool = False @dataclass diff --git a/laia/scripts/htr/train_ctc.py b/laia/scripts/htr/train_ctc.py index 01c738ed..3d5efd9a 100755 --- a/laia/scripts/htr/train_ctc.py +++ b/laia/scripts/htr/train_ctc.py @@ -12,7 +12,6 @@ from laia.common.arguments import ( CommonArgs, DataArgs, DecodeArgs, - Logger, OptimizerArgs, SchedulerArgs, TrainArgs, @@ -160,18 +159,18 @@ def run( callbacks.append(LearningRate(logging_interval="epoch")) # prepare the logger - if train.log_to == Logger.wandb: - logger = pl.loggers.WandbLogger(project="PyLaia") - logger.watch(model) - else: - logger = EpochCSVLogger(common.experiment_dirpath) + loggers = [EpochCSVLogger(common.experiment_dirpath)] + if train.log_to_wandb: + wandb_logger = pl.loggers.WandbLogger(project="PyLaia") + wandb_logger.watch(model) + loggers.append(wandb_logger) # prepare the trainer trainer = pl.Trainer( default_root_dir=common.train_path, resume_from_checkpoint=checkpoint_path, callbacks=callbacks, - logger=logger, + logger=loggers, checkpoint_callback=True, **vars(trainer), ) diff --git a/tests/scripts/htr/dataset/validate_cli_test.py b/tests/scripts/htr/dataset/validate_cli_test.py index 12a1ce3f..a4b5ec87 100755 --- a/tests/scripts/htr/dataset/validate_cli_test.py +++ b/tests/scripts/htr/dataset/validate_cli_test.py @@ -48,7 +48,7 @@ train: early_stopping_patience: 20 gpu_stats: false augment_training: false - log_to: csv + log_to_wandb: false logging: fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s' level: INFO diff --git a/tests/scripts/htr/train_ctc_cli_test.py b/tests/scripts/htr/train_ctc_cli_test.py index 97b5f6da..ca1d81cd 100644 --- a/tests/scripts/htr/train_ctc_cli_test.py +++ b/tests/scripts/htr/train_ctc_cli_test.py @@ -101,7 +101,7 @@ train: early_stopping_patience: 20 gpu_stats: false augment_training: false - log_to: csv + log_to_wandb: false logging: fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s' level: INFO -- GitLab