Commit d1df04f8 authored by Solene Tarride

Support two loggers

parent e1025f58
Merge request !103: Log to wandb
Pipeline #197303 passed
@@ -61,7 +61,7 @@ The full list of parameters is detailed in this section.
 | `train.early_stopping_patience` | Number of validation epochs with no improvement after which training will be stopped. | `int` | `20` |
 | `train.gpu_stats` | Whether to include GPU stats in the training progress bar. | `bool` | `False` |
 | `train.augment_training` | Whether to use data augmentation. | `bool` | `False` |
-| `train.log_to` | Logger to use during training. Should be either `"csv"` to log metrics locally, or `"wandb"` to report to Weights & Biases. | `Logger` | `"csv"` |
+| `train.log_to_wandb` | Whether to log training metrics and parameters to Weights & Biases. | `bool` | `False` |

 ### Logging arguments
@@ -211,7 +211,7 @@ To set up Weights & Biases:
 * Run `pip install pylaia[wandb]` to install the required dependencies
 * Sign in to Weights & Biases using `wandb login`

-Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to wandb`.
+Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to_wandb true`.

 This will create a project called `PyLaia` in W&B with one run for each training. The following are monitored for each run:
 * Training and validation metrics (losses, CER, WER)
......
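The docs above promise one W&B run per training in the `PyLaia` project. As a quick sanity check after a few trainings, the runs and their logged metrics can be read back through the public `wandb` API. This is a minimal sketch, not part of the commit; `my-team` is a placeholder entity, and the metric key `va_cer` is taken from the `Monitor` enum in the diff below.

```python
# Minimal sketch (not from this commit): list the runs that trainings
# created in the PyLaia project. "my-team" is a placeholder entity.
import wandb

api = wandb.Api()
for run in api.runs("my-team/PyLaia"):
    # run.summary holds the last logged value of each metric,
    # e.g. the validation CER tracked as "va_cer"
    print(run.name, run.state, run.summary.get("va_cer"))
```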
@@ -18,11 +18,6 @@ from jsonargparse.typing import (
 GeNeg1Int = restricted_number_type(None, int, (">=", -1))

-class Logger(str, Enum):
-    csv = "csv"
-    wandb = "wandb"
-
 class Monitor(str, Enum):
     va_loss = "va_loss"
     va_cer = "va_cer"
@@ -233,7 +228,7 @@ class TrainArgs:
     early_stopping_patience: NonNegativeInt = 20
     gpu_stats: bool = False
     augment_training: bool = False
-    log_to: Logger = Logger.csv
+    log_to_wandb: bool = False


 @dataclass
......
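For context, PyLaia builds its CLI and config handling on jsonargparse (visible in the imports above), so a dataclass field like the new boolean surfaces as a `--train.log_to_wandb` option. Below is a simplified, self-contained sketch of that mapping; the real `TrainArgs` in `laia.common.arguments` has many more fields, and PyLaia's actual wiring may differ in detail.

```python
# Simplified sketch of how a dataclass field becomes a CLI/config option
# via jsonargparse; the real TrainArgs lives in laia.common.arguments.
from dataclasses import dataclass

from jsonargparse import ArgumentParser


@dataclass
class TrainArgs:
    augment_training: bool = False
    log_to_wandb: bool = False  # the field added by this commit


parser = ArgumentParser()
parser.add_dataclass_arguments(TrainArgs, "train")

# Equivalent to passing `--train.log_to_wandb true` to pylaia-htr-train-ctc
cfg = parser.parse_args(["--train.log_to_wandb", "true"])
assert cfg.train.log_to_wandb is True
```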
@@ -12,7 +12,6 @@ from laia.common.arguments import (
     CommonArgs,
     DataArgs,
     DecodeArgs,
-    Logger,
     OptimizerArgs,
     SchedulerArgs,
     TrainArgs,
@@ -160,18 +159,18 @@ def run(
     callbacks.append(LearningRate(logging_interval="epoch"))

     # prepare the logger
-    if train.log_to == Logger.wandb:
-        logger = pl.loggers.WandbLogger(project="PyLaia")
-        logger.watch(model)
-    else:
-        logger = EpochCSVLogger(common.experiment_dirpath)
+    loggers = [EpochCSVLogger(common.experiment_dirpath)]
+    if train.log_to_wandb:
+        wandb_logger = pl.loggers.WandbLogger(project="PyLaia")
+        wandb_logger.watch(model)
+        loggers.append(wandb_logger)

     # prepare the trainer
     trainer = pl.Trainer(
         default_root_dir=common.train_path,
         resume_from_checkpoint=checkpoint_path,
         callbacks=callbacks,
-        logger=logger,
+        logger=loggers,
         checkpoint_callback=True,
         **vars(trainer),
     )
......
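The key change is that Lightning's `Trainer` accepts a list of loggers, so the CSV logger always runs and the W&B logger is only appended on demand. Here is a standalone sketch of the same pattern using stock Lightning loggers; `CSVLogger` stands in for PyLaia's `EpochCSVLogger` subclass, and the paths are placeholders.

```python
# Standalone sketch of the two-logger pattern used above; CSVLogger
# stands in for PyLaia's EpochCSVLogger subclass.
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, WandbLogger

log_to_wandb = True  # corresponds to train.log_to_wandb

# the CSV logger is always active, so metrics are always kept locally
loggers = [CSVLogger(save_dir="experiments")]
if log_to_wandb:
    loggers.append(WandbLogger(project="PyLaia"))

# Lightning fans every self.log(...) call out to each logger in the list
trainer = pl.Trainer(logger=loggers, max_epochs=1)
```

Keeping the CSV logger unconditional is what the commit message means by "support two loggers": enabling W&B adds a second sink rather than replacing the local one.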
@@ -48,7 +48,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
......
@@ -101,7 +101,7 @@ train:
   early_stopping_patience: 20
   gpu_stats: false
   augment_training: false
-  log_to: csv
+  log_to_wandb: false
 logging:
   fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
   level: INFO
......