Compare revisions (target project: atr/dan)

Changes are shown as if the source revision was being merged into the target revision.
Commits on Source (7), showing 419 additions and 337 deletions.
@@ -52,6 +52,21 @@ test:
script:
- tox -- -v
docker-build:
stage: build
image: docker:19.03.1
services:
- docker:dind
variables:
DOCKER_DRIVER: overlay2
DOCKER_HOST: tcp://docker:2375/
except:
- schedules
script:
- ci/build.sh
# Make sure docs still build correctly
.docs:
image: python:3.10
...
FROM nvidia/cuda:12.2.0-base-ubuntu22.04
# Install python and pip
RUN apt-get -y update && \
apt-get -y install python3 python3-pip && \
apt-get clean -y
WORKDIR /src
# Install DAN as a package
COPY dan dan
COPY teklia_line_image_extractor teklia_line_image_extractor
COPY requirements.txt *-requirements.txt setup.py VERSION README.md ./
RUN pip install . --no-cache-dir
@@ -2,3 +2,4 @@ include requirements.txt
include doc-requirements.txt
include mlflow-requirements.txt
include VERSION
include README.md
#!/bin/sh -e
# Build the tasks Docker image.
# Requires CI_PROJECT_DIR and CI_REGISTRY_IMAGE to be set.
# VERSION defaults to latest.
# Will automatically login to a registry if CI_REGISTRY, CI_REGISTRY_USER and CI_REGISTRY_PASSWORD are set.
# Will only push an image if $CI_REGISTRY is set.
if [ -z "$VERSION" ]; then
VERSION=${CI_COMMIT_TAG:-latest}
fi
if [ -z "$VERSION" -o -z "$CI_PROJECT_DIR" -o -z "$CI_REGISTRY_IMAGE" ]; then
echo Missing environment variables
exit 1
fi
IMAGE_TAG="$CI_REGISTRY_IMAGE:$VERSION"
cd $CI_PROJECT_DIR
docker build -f Dockerfile . -t "$IMAGE_TAG"
# Publish the image on the main branch or on a tag
if [ "$CI_COMMIT_REF_NAME" = "main" -o -n "$CI_COMMIT_TAG" ]; then
if [ -n "$CI_REGISTRY" -a -n "$CI_REGISTRY_USER" -a -n "$CI_REGISTRY_PASSWORD" ]; then
echo $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY
docker push $IMAGE_TAG
else
echo "Missing environment variables to log in to the container registry…"
fi
else
echo "The build was not published to the repository registry (only for main branch or tags)…"
fi
{
"mlflow": {
"run_name": "Test log DAN",
"run_id": null,
"s3_endpoint_url": "",
"tracking_uri": "",
"experiment_id": "0",
"aws_access_key_id": "",
"aws_secret_access_key": ""
},
"dataset": {
"datasets": {
"$dataset_name": "$dataset_path"
},
"train": {
"name": "$dataset_name-train",
"datasets": [
["$dataset_name", "train"]
]
},
"val": {
"$dataset_name-val": [
["$dataset_name", "val"]
]
},
"test": {
"$dataset_name-test": [
["$dataset_name", "test"]
]
},
"max_char_prediction": 1000,
"tokens": null
},
"model": {
"transfered_charset": true,
"additional_tokens": 1,
"encoder": {
"dropout": 0.5,
"nb_layers": 5
},
"h_max": 500,
"w_max": 1000,
"decoder": {
"l_max": 15000,
"dec_num_layers": 8,
"dec_num_heads": 4,
"dec_res_dropout": 0.1,
"dec_pred_dropout": 0.1,
"dec_att_dropout": 0.1,
"dec_dim_feedforward": 256,
"attention_win": 100,
"enc_dim": 256
}
},
"training": {
"data": {
"batch_size": 2,
"load_in_memory": true,
"worker_per_gpu": 4,
"preprocessings": [
{
"type": "max_resize",
"max_width": 2000,
"max_height": 2000
}
],
"augmentation": true
},
"device": {
"use_ddp": false,
"ddp_port": "20027",
"use_amp": true,
"nb_gpu": null,
"force_cpu": false
},
"metrics": {
"train": [
"loss_ce",
"cer",
"wer",
"wer_no_punct"
],
"eval": [
"cer",
"wer",
"wer_no_punct"
]
},
"validation": {
"eval_on_valid": true,
"eval_on_valid_interval": 5,
"set_name_focus_metric": "$dataset_name-val"
},
"output_folder": "$dataset_path/output",
"max_nb_epochs": 800,
"load_epoch": "last",
"optimizers": {
"all": {
"args": {
"lr": 0.0001,
"amsgrad": false
}
}
},
"lr_schedulers": null,
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4
},
"transfer_learning": {
"encoder": [
"encoder",
"pretrained_models/dan_rimes_page.pt",
true,
true
],
"decoder": [
"decoder",
"pretrained_models/dan_rimes_page.pt",
true,
false
]
}
}
}
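A minimal sketch of how this JSON template is consumed once its $-placeholders are filled in and the file is saved, e.g., as config.json (an illustrative name):

import json
from pathlib import Path

from dan.ocr.train import run

# Parse the configuration exactly as the new `--config` argparse type does.
config = json.loads(Path("config.json").read_text())

# run() checks that exactly one dataset is declared, re-hydrates plain JSON
# values into Python objects via update_config() (dataset paths to pathlib.Path,
# "max_resize" to Preprocessing.MaxResize, optimizer to Adam), then trains.
run(config)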
@@ -5,4 +5,3 @@ logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
@@ -3,31 +3,10 @@
Analyze dataset and display statistics in markdown format.
"""
import json
from pathlib import Path
from typing import Dict
import yaml
from dan.datasets.analyze.statistics import run
def read_yaml(yaml_path: str) -> Dict:
"""
Read YAML tokens file
"""
filename = Path(yaml_path)
assert filename.exists()
return yaml.safe_load(filename.read_text())
def read_json(json_path: str) -> Dict:
"""
Read labels JSON file
"""
filename = Path(json_path)
assert filename.exists()
return json.loads(filename.read_text())
from dan.utils import read_json, read_yaml
def add_analyze_parser(subcommands) -> None:
...
# -*- coding: utf-8 -*-
import logging
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional
@@ -8,7 +9,7 @@ import numpy as np
from mdutils.mdutils import MdUtils
from prettytable import MARKDOWN, PrettyTable
from dan import logger
logger = logging.getLogger(__name__)
METRIC_COLUMN = "Metric"
...
# -*- coding: utf-8 -*-
import json
import logging
import pickle
import random
from collections import defaultdict
@@ -12,7 +13,6 @@ import numpy as np
from tqdm import tqdm
from arkindex_export import open_database
from dan import logger
from dan.datasets.extract.db import (
Element,
get_elements,
@@ -37,6 +37,7 @@ IMAGES_DIR = "images" # Subpath to the images directory.
SPLIT_NAMES = ["train", "val", "test"]
IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
logger = logging.getLogger(__name__)
class ArkindexExtractor:
...
@@ -5,6 +5,7 @@ Train a new DAN model.
from dan.ocr.predict import add_predict_parser # noqa
from dan.ocr.train import run
from dan.utils import read_json
def add_train_parser(subcommands) -> None:
@@ -14,4 +15,11 @@ def add_train_parser(subcommands) -> None:
help=__doc__,
)
parser.add_argument(
"--config",
type=read_json,
required=True,
help="Configuration file.",
)
parser.set_defaults(func=run)
@@ -95,15 +95,11 @@ class OCRDataset(Dataset):
set_name = path_and_set["set_name"]
gt = gt_per_set[set_name]
for filename in natural_sort(gt):
if isinstance(gt[filename], dict) and "text" in gt[filename]:
label = gt[filename]["text"]
else:
label = gt[filename]
filepath = Path(filename)
samples.append(
{
"name": filepath.name,
"label": label,
"label": gt[filename],
"path": filepath.resolve(),
}
)
...
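With the removed branch above, entries in the ground-truth file are expected to map a file name directly to its transcription string, no longer to a {"text": ...} dict. A small sketch under that assumption (file name and transcription are invented):

from pathlib import Path

# Assumed post-change ground-truth shape: file name -> transcription.
gt = {"images/train/page_001.jpg": "le quinze mars"}

samples = [
    {"name": Path(f).name, "label": label, "path": Path(f).resolve()}
    for f, label in gt.items()
]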
@@ -5,6 +5,7 @@ Predict on an image using a trained DAN model.
import pathlib
from dan.ocr.predict.attention import Level
from dan.ocr.predict.prediction import run
from dan.utils import parse_tokens
@@ -81,7 +82,7 @@ def add_predict_parser(subcommands) -> None:
parser.add_argument(
"--confidence-score-levels",
default=[],
type=str,
type=Level,
nargs="+",
help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
required=False,
@@ -94,10 +95,9 @@ def add_predict_parser(subcommands) -> None:
)
parser.add_argument(
"--attention-map-level",
type=str,
choices=["line", "word", "char"],
default="line",
help="Level of attention maps.",
type=Level,
default=Level.Line,
help="Level to plot the attention maps. Should be in ['line', 'word', 'char'].",
required=False,
)
parser.add_argument(
...
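Both predict arguments are now cast straight to the enum. Since Level subclasses str, argparse can hand it raw strings and existing string comparisons keep working; a quick illustration:

from dan.ocr.predict.attention import Level

assert Level("word") is Level.Word  # what argparse does with type=Level
assert Level.Line == "line"         # str subclass: equal to its value
assert [level.value for level in Level] == ["char", "word", "line"]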
# -*- coding: utf-8 -*-
import logging
import re
from enum import Enum
from operator import attrgetter
from typing import List, Tuple
import cv2
@@ -8,7 +11,13 @@ import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger
logger = logging.getLogger(__name__)
class Level(str, Enum):
Char = "char"
Word = "word"
Line = "line"
def parse_delimiters(delimiters: List[str]) -> re.Pattern:
@@ -36,7 +45,7 @@ def compute_prob_by_separator(
def split_text(
text: str, level: str, word_separators: re.Pattern, line_separators: re.Pattern
text: str, level: Level, word_separators: re.Pattern, line_separators: re.Pattern
) -> Tuple[List[str], int]:
"""
Split text into a list of characters, words, or lines.
@@ -45,27 +54,27 @@ def split_text(
:param word_separators: List of word separators
:param line_separators: List of line separators
"""
if level == "char":
text_split = list(text)
offset = 0
# split into words
elif level == "word":
text_split = re.split(word_separators, text)
offset = 1
# split into lines
elif level == "line":
text_split = re.split(line_separators, text)
offset = 1
else:
logger.error("Level should be either 'char', 'word', or 'line'")
match level:
case Level.Char:
text_split = list(text)
# split into words
case Level.Word:
text_split = re.split(word_separators, text)
# split into lines
case Level.Line:
text_split = re.split(line_separators, text)
case _:
choices = ", ".join(list(map(attrgetter("value"), Level)))
logger.error(f"Level should be either {choices}")
offset = int(level != Level.Char)
return text_split, offset
def split_text_and_confidences(
text: str,
confidences: List[float],
level: str,
level: Level,
word_separators: re.Pattern,
line_separators: re.Pattern,
) -> Tuple[List[str], List[np.float64], int]:
@@ -77,21 +86,22 @@ def split_text_and_confidences(
:param word_separators: List of word separators
:param line_separators: List of line separators
"""
if level == "char":
texts = list(text)
offset = 0
elif level == "word":
texts, confidences = compute_prob_by_separator(
text, confidences, word_separators
)
offset = 1
elif level == "line":
texts, confidences = compute_prob_by_separator(
text, confidences, line_separators
)
offset = 1
else:
logger.error("Level should be either 'char', 'word', or 'line'")
match level:
case Level.Char:
texts = list(text)
case Level.Word:
texts, confidences = compute_prob_by_separator(
text, confidences, word_separators
)
case Level.Line:
texts, confidences = compute_prob_by_separator(
text, confidences, line_separators
)
case _:
choices = ", ".join(list(map(attrgetter("value"), Level)))
logger.error(f"Level should be either {choices}")
offset = int(level != Level.Char)
return texts, [np.around(num, 2) for num in confidences], offset
@@ -99,7 +109,7 @@ def get_predicted_polygons_with_confidence(
text: str,
weights: np.ndarray,
confidences: List[float],
level: str,
level: Level,
height: int,
width: int,
threshold_method: str = "otsu",
@@ -310,7 +320,7 @@ def plot_attention(
image: torch.Tensor,
text: str,
weights: np.ndarray,
level: str,
level: Level,
scale: float,
outname: str,
threshold_method: str = "otsu",
...
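With the typed levels in place, the split helpers take a Level instead of a bare string. A minimal usage sketch (the sample text is invented):

from dan.ocr.predict.attention import Level, parse_delimiters, split_text

word_separators = parse_delimiters(["\n", " "])
line_separators = parse_delimiters(["\n"])

words, offset = split_text(
    "le quinze mars", Level.Word, word_separators, line_separators
)
# words == ["le", "quinze", "mars"]; offset == 1, since one separator
# character sits between consecutive items (it is 0 for Level.Char).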
# -*- coding: utf-8 -*-
import json
import logging
import pickle
import re
from itertools import pairwise
@@ -11,10 +12,10 @@ import numpy as np
import torch
import yaml
from dan import logger
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import (
Level,
get_predicted_polygons_with_confidence,
parse_delimiters,
plot_attention,
@@ -29,6 +30,8 @@ from dan.utils import (
read_image,
)
logger = logging.getLogger(__name__)
class DAN:
"""
@@ -114,7 +117,7 @@
input_sizes: List[torch.Size],
confidences: bool = False,
attentions: bool = False,
attention_level: str = "line",
attention_level: Level = Level.Line,
extract_objects: bool = False,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
@@ -278,9 +281,9 @@ def process_batch(
device: str,
output: Path,
confidence_score: bool,
confidence_score_levels: List[str],
confidence_score_levels: List[Level],
attention_map: bool,
attention_map_level: str,
attention_map_level: Level,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
@@ -362,7 +365,7 @@ def process_batch(
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
result["confidences"][level.value] = []
texts, confidences, _ = split_text_and_confidences(
predicted_text,
char_confidences,
@@ -372,7 +375,7 @@
)
for text, conf in zip(texts, confidences):
result["confidences"][level].append(
result["confidences"][level.value].append(
{"text": text, "confidence": conf}
)
@@ -410,9 +413,9 @@ def run(
charset: Path,
output: Path,
confidence_score: bool,
confidence_score_levels: List[str],
confidence_score_levels: List[Level],
attention_map: bool,
attention_map_level: str,
attention_map_level: Level,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
...
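Note the .value lookups in process_batch: result dictionaries get serialized, so enum members are unwrapped to their string values when used as keys. A sketch of the resulting shape, with made-up confidence scores:

import numpy as np

from dan.ocr.predict.attention import Level

char_confidences = [0.91, 0.87, 0.99]
result = {"confidences": {"total": np.around(np.mean(char_confidences), 2)}}
for level in [Level.Word, Level.Line]:  # e.g., --confidence-score-levels word line
    result["confidences"][level.value] = []  # keys are "word" and "line"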
@@ -20,7 +20,7 @@ from dan.utils import MLflowNotInstalled
if MLFLOW_AVAILABLE:
import mlflow
from dan.mlflow import make_mlflow_request, start_mlflow_run
from dan.ocr.mlflow import make_mlflow_request, start_mlflow_run
logger = logging.getLogger(__name__)
@@ -61,153 +61,48 @@ def train_and_test(rank, params, mlflow_logging=False):
)
def get_config():
def update_config(config: dict):
"""
Retrieve model configuration
Update some fields for easier processing.
"""
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = "_debug"
dataset_path = "."
params = {
"mlflow": {
"run_name": "Test log DAN",
"run_id": None,
"s3_endpoint_url": "",
"tracking_uri": "",
"experiment_id": "0",
"aws_access_key_id": "",
"aws_secret_access_key": "",
},
"dataset": {
"datasets": {
dataset_name: Path(dataset_path)
/ "{}_{}{}".format(dataset_name, dataset_level, dataset_variant),
},
"train": {
"name": "{}-train".format(dataset_name),
"datasets": [
(dataset_name, "train"),
],
},
"val": {
"{}-val".format(dataset_name): [
(dataset_name, "val"),
],
},
"test": {
"{}-test".format(dataset_name): [
(dataset_name, "test"),
],
},
"max_char_prediction": 1000, # max number of token prediction
"tokens": None,
},
"model": {
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
"encoder": {
"class": FCN_Encoder,
"dropout": 0.5, # dropout rate for encoder
"nb_layers": 5, # encoder
},
"h_max": 500, # maximum height for encoder output (for 2D positional embedding)
"w_max": 1000, # maximum width for encoder output (for 2D positional embedding)
"decoder": {
"class": GlobalHTADecoder,
"l_max": 15000, # max predicted sequence (for 1D positional embedding)
"dec_num_layers": 8, # number of transformer decoder layers
"dec_num_heads": 4, # number of heads in transformer decoder layers
"dec_res_dropout": 0.1, # dropout in transformer decoder layers
"dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"attention_win": 100, # length of attention window
"enc_dim": 256, # dimension of extracted features
},
},
"training": {
"data": {
"batch_size": 2, # mini-batch size for training
"load_in_memory": True, # Load all images in CPU memory
"worker_per_gpu": 4, # Num of parallel processes per gpu for data loading
"preprocessings": [
{
"type": Preprocessing.MaxResize,
"max_width": 2000,
"max_height": 2000,
}
],
"augmentation": True,
},
"device": {
"use_ddp": False, # Use DistributedDataParallel
"ddp_port": "20027",
"use_amp": True, # Enable automatic mix-precision
"nb_gpu": torch.cuda.device_count(),
"force_cpu": False, # True for debug purposes
},
"metrics": {
"train": [
"loss_ce",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for training
"eval": [
"cer",
"wer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
},
"validation": {
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
},
"output_folder": Path(
"outputs/dan_esposalles_record"
), # folder name for checkpoint and results
"max_nb_epochs": 800, # maximum number of epochs before to stop
"load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate
"optimizers": {
"all": {
"class": Adam,
"args": {
"lr": 0.0001,
"amsgrad": False,
},
},
},
"lr_schedulers": None, # Learning rate schedulers
# Keep teacher forcing rate to 20% during whole training
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4,
},
# "transfer_learning": None,
"transfer_learning": {
# model_name: [state_dict_name, checkpoint_path, learnable, strict]
"encoder": [
"encoder",
Path("pretrained_models/dan_rimes_page.pt"),
True,
True,
],
"decoder": [
"decoder",
Path("pretrained_models/dan_rimes_page.pt"),
True,
False,
],
},
},
# .dataset.datasets cast all values to Path
config["dataset"]["datasets"] = {
name: Path(path) for name, path in config["dataset"]["datasets"].items()
}
return params, dataset_name
# .model.encoder.class = FCN_ENCODER
config["model"]["encoder"]["class"] = FCN_Encoder
# .model.decoder.class = GlobalHTADecoder
config["model"]["decoder"]["class"] = GlobalHTADecoder
# Update preprocessing type
for prepro in config["training"]["data"]["preprocessings"]:
prepro["type"] = Preprocessing(prepro["type"])
# .training.output_folder to Path
config["training"]["output_folder"] = Path(config["training"]["output_folder"])
if config["training"]["transfer_learning"]:
# .training.transfer_learning.encoder[1]
config["training"]["transfer_learning"]["encoder"][1] = Path(
config["training"]["transfer_learning"]["encoder"][1]
)
# .training.transfer_learning.decoder[1]
config["training"]["transfer_learning"]["decoder"][1] = Path(
config["training"]["transfer_learning"]["decoder"][1]
)
# Parse optimizers
for optimizer_setup in config["training"]["optimizers"].values():
# Only supported optimizer is Adam
optimizer_setup["class"] = Adam
# set nb_gpu if not present
if config["training"]["device"]["nb_gpu"] is None:
config["training"]["device"]["nb_gpu"] = torch.cuda.device_count()
def serialize_config(config):
@@ -262,23 +157,27 @@ def start_training(config, mlflow_logging: bool) -> None:
train_and_test(0, config, mlflow_logging)
def run():
def run(config: dict):
"""
Main program, training a new model, using a valid configuration
"""
names = list(config["dataset"]["datasets"].keys())
# We should only have one dataset
assert len(names) == 1, f"Found {len(names)} datasets but only one is expected"
config, dataset_name = get_config()
dataset_name = names.pop()
update_config(config)
if "mlflow" in config and not MLFLOW_AVAILABLE:
if config.get("mlflow") and not MLFLOW_AVAILABLE:
logger.error(
"Cannot log to MLflow. Please install the `mlflow` extra requirements."
)
raise MLflowNotInstalled()
if "mlflow" not in config:
if not config.get("mlflow"):
start_training(config, mlflow_logging=False)
else:
labels_path = Path(config["dataset"]["datasets"][dataset_name]) / "labels.json"
labels_path = config["dataset"]["datasets"][dataset_name] / "labels.json"
with start_mlflow_run(config["mlflow"]) as (run, created):
if created:
logger.info(f"Started MLflow run with ID ({run.info.run_id})")
...
# -*- coding: utf-8 -*-
import json
from argparse import ArgumentTypeError
from itertools import islice
from pathlib import Path
from typing import Dict, NamedTuple
@@ -106,8 +108,32 @@ def list_to_batches(iterable, n):
yield batch
def parse_tokens(filename: Path) -> Dict[str, EntityType]:
def parse_tokens(filename: str) -> Dict[str, EntityType]:
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
for name, tokens in yaml.safe_load(Path(filename).read_text()).items()
}
def read_yaml(yaml_path: str) -> Dict:
"""
Read YAML tokens file
"""
filename = Path(yaml_path)
assert filename.exists(), f"{yaml_path} does not resolve."
try:
return yaml.safe_load(filename.read_text())
except yaml.YAMLError as e:
raise ArgumentTypeError(e)
def read_json(json_path: str) -> Dict:
"""
Read labels JSON file
"""
filename = Path(json_path)
assert filename.exists(), f"{json_path} does not resolve."
try:
return json.loads(filename.read_text())
except json.JSONDecodeError as e:
raise ArgumentTypeError(e)
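Raising ArgumentTypeError instead of letting the parse error escape lets both helpers double as argparse types, which is how the train command consumes read_json; a sketch (the --tokens flag is illustrative):

from argparse import ArgumentParser

from dan.utils import read_json, read_yaml

parser = ArgumentParser()
# A malformed file surfaces as a clean argparse usage error.
parser.add_argument("--config", type=read_json, required=True)
parser.add_argument("--tokens", type=read_yaml)
args = parser.parse_args()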
@@ -24,16 +24,8 @@ output/
## 2. Train
The training command does not take any input parameters for now. To train a DAN model, you will therefore need to:
1. Update the parameters from those listed in the [dedicated page](../usage/train/parameters.md). You will always need to update at least these variables:
- `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path`,
- `model_params.transfer_learning.*[checkpoint_path]` to finetune an existing model,
- `training_params.output_folder`.
1. Train a DAN model with the [train command](../usage/train/index.md).
To train a DAN model, please refer to the [documentation of the training command](../usage/train/index.md).
## 3. Predict
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict.md) and the `inference_parameters.yml` file, located in `{training_params.output_folder}/results`.
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`.
@@ -2,22 +2,14 @@
Use the `teklia-dan train` command to train a new DAN model. It is able to train a DAN model at line or document-level and evaluate it.
## Examples
To train DAN on your dataset:
### Document
To train DAN on documents:
1. Set your training configuration in `dan/ocr/train.py`. Refer to the [dedicated section](parameters.md) for a description of parameters.
1. Run `teklia-dan train`.
1. Look into evaluation results in the `output` folder:
1. Create a training JSON configuration file. Refer to the [dedicated section](parameters.md) for a description of parameters.
1. Run `teklia-dan train --config path/to/your/config.json`.
1. Look into evaluation results in the output folder indicated in your configuration:
- `checkpoints` contains model weights for the last trained epoch and for the epoch giving the best valid CER.
- `results` contains the tensorboard log file, the parameters file, and the evaluation results for the best epoch.
### Line
To train DAN on lines, run `teklia-dan train` with a line dataset.
## Additional pages
- [Jean Zay tutorial](jeanzay.md)
...
(This file's diff is collapsed and not shown.)
@@ -2,12 +2,9 @@
from pathlib import Path
import pytest
from torch.optim import Adam
from arkindex_export import open_database
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.transforms import Preprocessing
from dan.ocr.train import update_config
FIXTURES = Path(__file__).resolve().parent / "data"
@@ -27,10 +24,10 @@ def demo_db(database_path):
@pytest.fixture
def training_config():
return {
config = {
"dataset": {
"datasets": {
"training": FIXTURES / "training" / "training_dataset",
"training": str(FIXTURES / "training" / "training_dataset"),
},
"train": {
"name": "training-train",
@@ -55,14 +52,12 @@ def training_config():
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
"encoder": {
"class": FCN_Encoder,
"dropout": 0.5, # dropout rate for encoder
"nb_layers": 5, # encoder
},
"h_max": 500, # maximum height for encoder output (for 2D positional embedding)
"w_max": 1000, # maximum width for encoder output (for 2D positional embedding)
"decoder": {
"class": GlobalHTADecoder,
"l_max": 15000, # max predicted sequence (for 1D positional embedding)
"dec_num_layers": 8, # number of transformer decoder layers
"dec_num_heads": 4, # number of heads in transformer decoder layers
@@ -81,7 +76,7 @@ def training_config():
"worker_per_gpu": 4, # Num of parallel processes per gpu for data loading
"preprocessings": [
{
"type": Preprocessing.MaxResize,
"type": "max_resize",
"max_width": 2000,
"max_height": 2000,
}
@@ -113,15 +108,12 @@ def training_config():
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"set_name_focus_metric": "training-val",
},
"output_folder": Path(
"dan_trained_model"
), # folder name for checkpoint and results
"output_folder": "dan_trained_model", # folder name for checkpoint and results
"gradient_clipping": {},
"max_nb_epochs": 4, # maximum number of epochs before to stop
"load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate
"optimizers": {
"all": {
"class": Adam,
"args": {
"lr": 0.0001,
"amsgrad": False,
@@ -138,6 +130,8 @@ def training_config():
"transfer_learning": None,
},
}
update_config(config)
return config
@pytest.fixture
...