Commit cefae223 authored by Solene Tarride, committed by Yoann Schneider

Log to wandb

parent 92bf8775
Merge request !103: Log to wandb
@@ -42,25 +42,26 @@ The full list of parameters is detailed in this section.
### Data arguments

| Name                 | Description                                                                   | Type           | Default       |
| -------------------- | ----------------------------------------------------------------------------- | -------------- | ------------- |
| `data.batch_size`    | Batch size.                                                                   | `int`          | `8`           |
| `data.color_mode`    | Color mode. Must be either `L`, `RGB` or `RGBA`.                              | `ColorMode`    | `ColorMode.L` |
| `data.num_workers`   | Number of worker processes created in dataloaders.                           | `int`          | `None`        |
| `data.reading_order` | Reading order on the input lines: LFT (Left-to-Right) or RTL (Right-to-Left). | `ReadingOrder` | `LFT`         |
### Train arguments

| Name                            | Description                                                                                                                                                   | Type        | Default       |
| ------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------- | ------------- |
| `train.delimiters`              | List of symbols representing the word delimiters.                                                                                                             | `List`      | `["<space>"]` |
| `train.checkpoint_k`            | Model saving mode: `-1` saves all models, `0` saves no models, `k` saves the `k` best models.                                                                  | `int`       | `3`           |
| `train.resume`                  | Whether to resume training from a checkpoint. Use this option to continue training on the same dataset.                                                        | `bool`      | `False`       |
| `train.pretrain`                | Whether to load pretrained weights from a checkpoint. Use this option to load pretrained weights when fine-tuning a model on a new dataset.                    | `bool`      | `False`       |
| `train.freeze_layers`           | List of layers to freeze during training: `"conv"` to freeze convolutional layers, `"rnn"` to freeze recurrent layers, `"linear"` to freeze the linear layer.  | `List[str]` | `None`        |
| `train.early_stopping_patience` | Number of validation epochs with no improvement after which training will be stopped.                                                                          | `int`       | `20`          |
| `train.gpu_stats`               | Whether to include GPU stats in the training progress bar.                                                                                                     | `bool`      | `False`       |
| `train.augment_training`        | Whether to use data augmentation.                                                                                                                              | `bool`      | `False`       |
| `train.log_to_wandb`            | Whether to log training metrics and parameters to Weights & Biases.                                                                                            | `bool`      | `False`       |
### Logging arguments
@@ -201,3 +202,21 @@ trainer:
data:
  reading_order: RTL
```
### Train and log to Weights & Biases
By default, PyLaia logs metrics and losses to a local CSV file. You can choose to log to [Weights & Biases](https://wandb.ai/home) instead.
To set up Weights & Biases:
* Run `pip install pylaia[wandb]` to install the required dependencies
* Sign in to Weights & Biases using `wandb login`
Then, start training with `pylaia-htr-train-ctc --config config_train_model.yaml --train.log_to_wandb true`.
This will create a project called `PyLaia` in W&B, with one run per training. The following are monitored for each run:
* Training and validation metrics (losses, CER, WER)
* Model gradients
* System metrics (GPU and CPU utilisation, temperature, allocated memory)
* Hyperparameters (training configuration)
A public dashboard is available [here](https://wandb.ai/starride-teklia/PyLaia%20demo) as an example.
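Alternatively, the same switch can be set in the training configuration file instead of on the command line. A minimal excerpt (assuming the default configuration layout shown in the config diffs below):

```yaml
train:
  log_to_wandb: true
```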
@@ -228,6 +228,7 @@ class TrainArgs:
    early_stopping_patience: NonNegativeInt = 20
    gpu_stats: bool = False
    augment_training: bool = False
    log_to_wandb: bool = False


@dataclass
@@ -70,6 +70,8 @@ class DataModule(pl.LightningDataModule):
            space_token=space_token,
            space_display=space_display,
        )
        self.save_hyperparameters()
        _logger.info(f"Training data transforms:\n{tr_img_transform}")
        super().__init__(
            train_transforms=(tr_img_transform, txt_transform),
@@ -32,6 +32,7 @@ class HTREngineModule(EngineModule):
        )
        self.delimiters = delimiters
        self.decoder = CTCGreedyDecoder()
        self.save_hyperparameters()

    def training_step(self, batch: Any, *args, **kwargs):
        result = super().training_step(batch, *args, **kwargs)
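For context, the `save_hyperparameters()` calls added above are what make the modules' constructor arguments visible to any attached logger, including the W&B logger. A minimal sketch of the mechanism with a hypothetical module (not PyLaia code):

```python
import pytorch_lightning as pl


class DemoModule(pl.LightningModule):  # hypothetical module, for illustration
    def __init__(self, learning_rate: float = 1e-3, batch_size: int = 8):
        super().__init__()
        # Capture the __init__ arguments into self.hparams so that any
        # attached logger (CSV, W&B, ...) records them as hyperparameters.
        self.save_hyperparameters()


module = DemoModule()
print(module.hparams)  # "batch_size": 8, "learning_rate": 0.001
```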
@@ -158,12 +158,19 @@ def run(
    if scheduler.active:
        callbacks.append(LearningRate(logging_interval="epoch"))

    # prepare the logger
    loggers = [EpochCSVLogger(common.experiment_dirpath)]
    if train.log_to_wandb:
        wandb_logger = pl.loggers.WandbLogger(project="PyLaia")
        wandb_logger.watch(model)
        loggers.append(wandb_logger)

    # prepare the trainer
    trainer = pl.Trainer(
        default_root_dir=common.train_path,
        resume_from_checkpoint=checkpoint_path,
        callbacks=callbacks,
        logger=loggers,
        checkpoint_callback=True,
        **vars(trainer),
    )
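PyTorch Lightning accepts a list of loggers, which is why the W&B logger can simply be appended here: local CSV logging keeps working, W&B receives the same metrics, and `watch(model)` additionally tracks gradients. A minimal, self-contained sketch of the pattern (`DemoModule` is a hypothetical stand-in, not PyLaia's model):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from torch import nn


class DemoModule(pl.LightningModule):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


model = DemoModule()
log_to_wandb = True  # mirrors the train.log_to_wandb flag

loggers = [CSVLogger(save_dir="experiments")]  # local CSV logging stays on
if log_to_wandb:
    wandb_logger = WandbLogger(project="PyLaia")  # one W&B run per training
    wandb_logger.watch(model)  # registers hooks to log gradients
    loggers.append(wandb_logger)

# Every logger in the list receives each metric logged during training.
trainer = pl.Trainer(logger=loggers)
```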
@@ -65,6 +65,7 @@ docs = [
"mkdocs-section-index==0.3.9",
"mkdocstrings-python==1.10.8",
]
wandb = ["wandb==0.18.5"]
[project.scripts]
pylaia-htr-create-model = "laia.scripts.htr.create_model:main"
@@ -48,6 +48,7 @@ train:
  early_stopping_patience: 20
  gpu_stats: false
  augment_training: false
  log_to_wandb: false
logging:
  fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
  level: INFO
@@ -101,6 +101,7 @@ train:
  early_stopping_patience: 20
  gpu_stats: false
  augment_training: false
  log_to_wandb: false
logging:
  fmt: '[%(asctime)s %(levelname)s %(name)s] %(message)s'
  level: INFO