From d1df04f8bc06f051547df2551f4317ad54df8667 Mon Sep 17 00:00:00 2001
From: starride-teklia <starride@teklia.com>
Date: Thu, 31 Oct 2024 12:32:30 +0100
Subject: [PATCH] Support two loggers

---
 docs/usage/training/index.md                   |  4 ++--
 laia/common/arguments.py                       |  7 +------
 laia/scripts/htr/train_ctc.py                  | 13 ++++++-------
 tests/scripts/htr/dataset/validate_cli_test.py |  2 +-
 tests/scripts/htr/train_ctc_cli_test.py        |  2 +-
 5 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/docs/usage/training/index.md b/docs/usage/training/index.md
index 47b919d0..e6858ae1 100644
--- a/docs/usage/training/index.md
+++ b/docs/usage/training/index.md
@@ -61,7 +61,7 @@ The full list of parameters is detailed in this section.
 | `train.early_stopping_patience` | Number of validation epochs with no improvement after which training will be stopped.                                                                        | `int`       | `20`          |
 | `train.gpu_stats`               | Whether to include GPU stats in the training progress bar.                                                                                                   | `bool`      | `False`       |
 | `train.augment_training`        | Whether to use data augmentation.                                                                                                                            | `bool`      | `False`       |
-| `train.log_to`                  | Logger to use during training. Should be either `"csv"` to log metrics locally, or `"wandb"` to report to Weights & Biases                                   | `Logger`    | `"csv"`       |
+| `train.log_to_wandb`            | Whether to log training metrics and parameters to Weights & Biases.                                                                                          | `bool`      | `False`       |
 
 
 ### Logging arguments
@@ -211,7 +211,7 @@ To set up Weights & Biases:
 * Run `pip install pylaia[wandb]` to install the required dependencies
 * Sign in to Weights & Biases using `wandb login`
 
-Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to wandb`.
+Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to_wandb true`.
 
 This will create a project called `PyLaia` in W&B with one run for each training. The following are monitored for each run:
 * Training and validation metrics (losses, CER, WER)
diff --git a/laia/common/arguments.py b/laia/common/arguments.py
index 96ea090a..55294e70 100644
--- a/laia/common/arguments.py
+++ b/laia/common/arguments.py
@@ -18,11 +18,6 @@ from jsonargparse.typing import (
 GeNeg1Int = restricted_number_type(None, int, (">=", -1))
 
 
-class Logger(str, Enum):
-    csv = "csv"
-    wandb = "wandb"
-
-
 class Monitor(str, Enum):
     va_loss = "va_loss"
     va_cer = "va_cer"
@@ -233,7 +228,7 @@ class TrainArgs:
     early_stopping_patience: NonNegativeInt = 20
     gpu_stats: bool = False
     augment_training: bool = False
-    log_to: Logger = Logger.csv
+    log_to_wandb: bool = False
 
 
 @dataclass
diff --git a/laia/scripts/htr/train_ctc.py b/laia/scripts/htr/train_ctc.py
index 01c738ed..3d5efd9a 100755
--- a/laia/scripts/htr/train_ctc.py
+++ b/laia/scripts/htr/train_ctc.py
@@ -12,7 +12,6 @@ from laia.common.arguments import (
     CommonArgs,
     DataArgs,
     DecodeArgs,
-    Logger,
     OptimizerArgs,
     SchedulerArgs,
     TrainArgs,
@@ -160,18 +159,18 @@ def run(
         callbacks.append(LearningRate(logging_interval="epoch"))
 
     # prepare the logger
-    if train.log_to == Logger.wandb:
-        logger = pl.loggers.WandbLogger(project="PyLaia")
-        logger.watch(model)
-    else:
-        logger = EpochCSVLogger(common.experiment_dirpath)
+    loggers = [EpochCSVLogger(common.experiment_dirpath)]
+    if train.log_to_wandb:
+        wandb_logger = pl.loggers.WandbLogger(project="PyLaia")
+        wandb_logger.watch(model)
+        loggers.append(wandb_logger)
 
     # prepare the trainer
     trainer = pl.Trainer(
         default_root_dir=common.train_path,
         resume_from_checkpoint=checkpoint_path,
         callbacks=callbacks,
-        logger=logger,
+        logger=loggers,
         checkpoint_callback=True,
         **vars(trainer),
     )
diff --git a/tests/scripts/htr/dataset/validate_cli_test.py b/tests/scripts/htr/dataset/validate_cli_test.py
index 12a1ce3f..a4b5ec87 100755
--- a/tests/scripts/htr/dataset/validate_cli_test.py
+++ b/tests/scripts/htr/dataset/validate_cli_test.py
@@ -48,7 +48,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
diff --git a/tests/scripts/htr/train_ctc_cli_test.py b/tests/scripts/htr/train_ctc_cli_test.py
index 97b5f6da..ca1d81cd 100644
--- a/tests/scripts/htr/train_ctc_cli_test.py
+++ b/tests/scripts/htr/train_ctc_cli_test.py
@@ -101,7 +101,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
-- 
GitLab