Implement training in the script
Depends on #5 (closed)
We now need to update worker_pylaia/train.py to support training from a local dataset.
New arguments:
- --img-path, Path, required: path to the training dataset's images
- --data-path, Path, required: path to the training dataset
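A minimal argparse sketch of the two new arguments, assuming train.py builds its CLI with the standard-library parser; the helper name add_training_arguments is hypothetical. argparse turns --img-path into args.img_path and --data-path into args.data_path, which is how the configuration below refers to them.

```python
import argparse
from pathlib import Path


def add_training_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the two new dataset arguments (hypothetical helper name)."""
    parser.add_argument(
        "--img-path",
        type=Path,        # parsed straight into a pathlib.Path
        required=True,    # argparse exits with a usage error if missing
        help="Path to the training dataset's images",
    )
    parser.add_argument(
        "--data-path",
        type=Path,
        required=True,
        help="Path to the training dataset (containing train.txt and val.txt)",
    )
    return parser
```

For example, parsing `--img-path imgs --data-path data` yields `args.img_path == Path("imgs")` and `args.data_path == Path("data")`.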
Define the following functions:
- get_model_training_config(): this will return the following YAML configuration as JSON (i.e. a Python dict), inspired by the original training YAML config file:
img_dirs: <args.img_path>
syms: <args.syms_path>
tr_txt_table: <args.data_path> / "train.txt"
va_txt_table: <args.data_path> / "val.txt"
common:
  train_path: <created_tmpdir>
data:
  batch_size: 8
  color_mode: L
optimizer:
  learning_rate: 0.0005
  name: RMSProp
scheduler:
  active: true
  monitor: va_loss
train:
  augment_training: false
  early_stopping_patience: 50
  delimiters:
    - <space>
trainer:
  auto_select_gpus: true
  max_epochs: 1
- train_model(): this will call laia.scripts.htr.train_ctc.run (renamed as train_model), defined here, with the following arguments:
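# Argument dataclasses (assumed to live in laia.common.arguments)
from laia.common.arguments import (
    CommonArgs, DataArgs, OptimizerArgs, SchedulerArgs, TrainArgs, TrainerArgs,
)
from laia.scripts.htr.train_ctc import run

# Configuration built by get_model_training_config()
config = get_model_training_config()
run(
    syms=config["syms"],
    img_dirs=config["img_dirs"],
    tr_txt_table=config["tr_txt_table"],
    va_txt_table=config["va_txt_table"],
    common=CommonArgs(**config["common"]),
    train=TrainArgs(**config["train"]),
    optimizer=OptimizerArgs(**config["optimizer"]),
    scheduler=SchedulerArgs(**config["scheduler"]),
    data=DataArgs(**config["data"]),
    trainer=TrainerArgs(**config["trainer"]),
)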
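A possible shape for get_model_training_config(), sketched under assumptions: it takes args and the created temporary directory as parameters (the spec lists none), args.syms_path is assumed to come from the existing model-creation arguments (#5), and paths are converted to str so the result stays JSON-serializable. Keys and values mirror the YAML above.

```python
from pathlib import Path
from typing import Any, Dict


def get_model_training_config(args: Any, train_path: Path) -> Dict[str, Any]:
    """Build the training configuration described by the YAML above.

    Assumptions: `args` exposes img_path, syms_path and data_path (Paths),
    and `train_path` is the temporary directory created by the script.
    """
    return {
        "img_dirs": str(args.img_path),
        "syms": str(args.syms_path),
        "tr_txt_table": str(args.data_path / "train.txt"),
        "va_txt_table": str(args.data_path / "val.txt"),
        "common": {"train_path": str(train_path)},
        "data": {"batch_size": 8, "color_mode": "L"},
        "optimizer": {"learning_rate": 0.0005, "name": "RMSProp"},
        "scheduler": {"active": True, "monitor": "va_loss"},
        "train": {
            "augment_training": False,
            "early_stopping_patience": 50,
            "delimiters": ["<space>"],
        },
        "trainer": {"auto_select_gpus": True, "max_epochs": 1},
    }
```

Each nested section (common, train, optimizer, scheduler, data, trainer) maps one-to-one onto the dataclass it is unpacked into in the run(...) call above.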
Running the script will first call get_model_creation_config() and create_model() to create the model, then get_model_training_config() and train_model() to train it.
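The call order can be sketched as follows. The four steps are passed in as callables (hypothetical wiring, so the flow can run without PyLaia installed), and the config-builder signatures taking train_path are assumptions. Both steps share one temporary directory, assuming the model that create_model() writes under train_path is the one train_model() later loads via common.train_path.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict

Config = Dict[str, Any]


def run_script(
    args: Any,
    get_model_creation_config: Callable[[Any, Path], Config],
    create_model: Callable[[Config], Any],
    get_model_training_config: Callable[[Any, Path], Config],
    train_model: Callable[[Config], Any],
) -> None:
    """Create the model, then train it, inside one temporary directory."""
    # The directory (and everything written to it) is deleted on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        train_path = Path(tmpdir)
        # Step 1: model creation (from #5)
        create_model(get_model_creation_config(args, train_path))
        # Step 2: training, reusing the same train_path
        train_model(get_model_training_config(args, train_path))
```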
To run the script, the following directory structure is now expected (line images may be in subfolders):
<args.data_path>
├── train.txt
├── val.txt
└── images
    ├── line001.jpg
    ├── line002.jpg
    └── ...
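A small pre-flight check for this layout could look like the following; check_dataset_layout and IMAGE_SUFFIXES are hypothetical names, and the accepted image extensions are an assumption. Images are collected recursively with Path.rglob because line images may be in subfolders.

```python
from pathlib import Path
from typing import List

# Hypothetical: adjust to the image formats actually used in the dataset
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_dataset_layout(data_path: Path) -> List[Path]:
    """Check the expected layout and return the line images found, sorted."""
    for name in ("train.txt", "val.txt"):
        if not (data_path / name).is_file():
            raise FileNotFoundError(f"Missing {name} in {data_path}")
    images_dir = data_path / "images"
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Missing images/ directory in {data_path}")
    # rglob("*") walks images/ and all of its subfolders
    return sorted(
        p for p in images_dir.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```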
Edited by Yoann Schneider