Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Project: atr/dan
Commits on Source (16)
Showing with 466 additions and 244 deletions
......@@ -61,15 +61,17 @@ from dan.ocr.predict.inference import DAN
image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
```
-Then one can initialize and load the trained model with the parameters used during training.
+Then one can initialize and load the trained model with the parameters used during training. The directory passed as a parameter should contain:
+
+- a `model.pt` file,
+- a `charset.pkl` file,
+- a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.

```python
-model_path = "model.pt"
-params_path = "parameters.yml"
-charset_path = "charset.pkl"
+model_path = "models"
model = DAN("cpu")
-model.load(model_path, params_path, charset_path, mode="eval")
+model.load(model_path, mode="eval")
```
To run the inference on a GPU, one can replace `cpu` with the name of the GPU device. Finally, one can run the prediction:
......
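As a hedged illustration of the GPU note above, a minimal sketch (the `cuda:0` device string follows PyTorch conventions and is an assumed example, not taken from this diff):

```python
from pathlib import Path

# Load the same model directory on the first CUDA device instead of the CPU.
model = DAN("cuda:0")
model.load(Path("models"), mode="eval")
```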
-0.2.0-dev3
+0.2.0-dev4
{
  "dataset": {
    "datasets": {
      "training": "tests/data/training/training_dataset"
    },
    "train": {
      "name": "training-train",
      "datasets": [
        ["training", "train"]
      ]
    },
    "val": {
      "training-val": [
        ["training", "val"]
      ]
    },
    "test": {
      "training-test": [
        ["training", "test"]
      ]
    },
    "max_char_prediction": 30,
    "tokens": null
  },
  "model": {
    "transfered_charset": true,
    "additional_tokens": 1,
    "encoder": {
      "dropout": 0.5,
      "nb_layers": 5
    },
    "h_max": 500,
    "w_max": 1000,
    "decoder": {
      "l_max": 15000,
      "dec_num_layers": 8,
      "dec_num_heads": 4,
      "dec_res_dropout": 0.1,
      "dec_pred_dropout": 0.1,
      "dec_att_dropout": 0.1,
      "dec_dim_feedforward": 256,
      "attention_win": 100,
      "enc_dim": 256
    }
  },
  "training": {
    "data": {
      "batch_size": 2,
      "load_in_memory": true,
      "worker_per_gpu": 4,
      "preprocessings": [
        {
          "type": "max_resize",
          "max_width": 2000,
          "max_height": 2000
        }
      ],
      "augmentation": true
    },
    "device": {
      "use_ddp": false,
      "ddp_port": "20027",
      "use_amp": true,
      "nb_gpu": 0,
      "force": "cpu"
    },
    "metrics": {
      "train": [
        "loss_ce",
        "cer",
        "wer",
        "wer_no_punct"
      ],
      "eval": [
        "cer",
        "wer",
        "wer_no_punct"
      ]
    },
    "validation": {
      "eval_on_valid": true,
      "eval_on_valid_interval": 2,
      "set_name_focus_metric": "training-val"
    },
    "output_folder": "tests/data/evaluate",
    "gradient_clipping": {},
    "max_nb_epochs": 4,
    "load_epoch": "best",
    "optimizers": {
      "all": {
        "args": {
          "lr": 0.0001,
          "amsgrad": false
        }
      }
    },
    "lr_schedulers": null,
    "label_noise_scheduler": {
      "min_error_rate": 0.2,
      "max_error_rate": 0.2,
      "total_num_steps": 5e4
    },
    "transfer_learning": null
  }
}
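For orientation, a minimal sketch of how a configuration file like the one above can be loaded and completed before use, relying on the `read_json` and `update_config` helpers that appear later in this diff (the file path is hypothetical):

```python
from dan.ocr.utils import update_config
from dan.utils import read_json

config = read_json("config.json")  # hypothetical path to a file like the one above
update_config(config)  # casts paths, resolves encoder/decoder classes, fills nb_gpu
```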
......@@ -71,7 +71,7 @@
"ddp_port": "20027",
"use_amp": true,
"nb_gpu": null,
"force_cpu": false
"force": null
},
"metrics": {
"train": [
......
{
  "dataset": {
    "datasets": {
      "training": "tests/data/training/training_dataset"
    },
    "train": {
      "name": "training-train",
      "datasets": [
        ["training", "train"]
      ]
    },
    "val": {
      "training-val": [
        ["training", "val"]
      ]
    },
    "test": {
      "training-test": [
        ["training", "test"]
      ]
    },
    "max_char_prediction": 30,
    "tokens": null
  },
  "model": {
    "transfered_charset": true,
    "additional_tokens": 1,
    "encoder": {
      "dropout": 0.5,
      "nb_layers": 5
    },
    "h_max": 500,
    "w_max": 1000,
    "decoder": {
      "l_max": 15000,
      "dec_num_layers": 8,
      "dec_num_heads": 4,
      "dec_res_dropout": 0.1,
      "dec_pred_dropout": 0.1,
      "dec_att_dropout": 0.1,
      "dec_dim_feedforward": 256,
      "attention_win": 100,
      "enc_dim": 256
    }
  },
  "training": {
    "data": {
      "batch_size": 2,
      "load_in_memory": true,
      "worker_per_gpu": 4,
      "preprocessings": [
        {
          "type": "max_resize",
          "max_width": 2000,
          "max_height": 2000
        }
      ],
      "augmentation": true
    },
    "device": {
      "use_ddp": false,
      "ddp_port": "20027",
      "use_amp": true,
      "nb_gpu": 0,
      "force": "cpu"
    },
    "metrics": {
      "train": [
        "loss_ce",
        "cer",
        "wer",
        "wer_no_punct"
      ],
      "eval": [
        "cer",
        "wer",
        "wer_no_punct"
      ]
    },
    "validation": {
      "eval_on_valid": true,
      "eval_on_valid_interval": 2,
      "set_name_focus_metric": "training-val"
    },
    "output_folder": "dan_trained_model",
    "gradient_clipping": {},
    "max_nb_epochs": 4,
    "load_epoch": "last",
    "optimizers": {
      "all": {
        "args": {
          "lr": 0.0001,
          "amsgrad": false
        }
      }
    },
    "lr_schedulers": null,
    "label_noise_scheduler": {
      "min_error_rate": 0.2,
      "max_error_rate": 0.2,
      "total_num_steps": 5e4
    },
    "transfer_learning": null
  }
}
......@@ -3,7 +3,7 @@ import argparse
import errno
from dan.datasets import add_dataset_parser
-from dan.ocr import add_predict_parser, add_train_parser
+from dan.ocr import add_evaluate_parser, add_predict_parser, add_train_parser
def get_parser():
......@@ -12,6 +12,7 @@ def get_parser():
    add_dataset_parser(subcommands)
    add_train_parser(subcommands)
+    add_evaluate_parser(subcommands)
    add_predict_parser(subcommands)

    return parser
......
......@@ -384,6 +384,9 @@ class ArkindexExtractor:
            subword_vocab_size=self.subword_vocab_size,
        )
+        if not tokenizer.sentencepiece_model:
+            return
+
        for level, tokenize in (
            ("characters", tokenizer.char_tokenize),
            ("words", tokenizer.word_tokenize),
......@@ -478,6 +481,11 @@ class ArkindexExtractor:
                pbar.update()
                pbar.refresh()

+        if not self.data:
+            raise Exception(
+                "No data was extracted using the provided export database and parameters."
+            )
+
        self.download_images()
        self.format_lm_files()
        self.export()
......
......@@ -186,12 +186,22 @@ class Tokenizer:
        with NamedTemporaryFile(dir=self.outdir, suffix=".txt", mode="w") as tmp:
            tmp.write("\n".join(self.training_corpus))
            tmp.flush()
-            spm.SentencePieceTrainer.train(
-                input=tmp.name,
-                vocab_size=self.subword_vocab_size,
-                model_prefix=self.prefix,
-                user_defined_symbols=self.special_tokens,
-            )
+            try:
+                spm.SentencePieceTrainer.train(
+                    input=tmp.name,
+                    vocab_size=self.subword_vocab_size,
+                    model_prefix=self.prefix,
+                    user_defined_symbols=self.special_tokens,
+                    minloglevel=1,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to train a sentencepiece model for subword tokenization: {e} "
+                    "Try again by editing the `--subword-vocab-size` parameter."
+                )
+                self.sentencepiece_model = None
+                return

        # Load the model
        self.sentencepiece_model = spm.SentencePieceProcessor(
......
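For context, a hedged sketch of using a trained SentencePiece model directly, with the standard `sentencepiece` API (the model file name is hypothetical; the `Tokenizer` stores the same kind of object in `self.sentencepiece_model`):

```python
import sentencepiece as spm

# Load a model produced by SentencePieceTrainer.train (hypothetical prefix).
sp = spm.SentencePieceProcessor(model_file="subword_tokenizer.model")
# Split a transcription into subword pieces.
print(sp.encode("a line of handwritten text", out_type=str))
```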
......@@ -3,6 +3,8 @@
Train a new DAN model.
"""
+from dan.ocr.evaluate import add_evaluate_parser  # noqa
from dan.ocr.predict import add_predict_parser  # noqa
from dan.ocr.train import run
from dan.utils import read_json
......
# -*- coding: utf-8 -*-
"""
Evaluate a trained DAN model.
"""
import logging
import random

import numpy as np
import torch
import torch.multiprocessing as mp

from dan.ocr.manager.training import Manager
from dan.ocr.utils import update_config
from dan.utils import read_json

logger = logging.getLogger(__name__)


def add_evaluate_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "evaluate",
        description=__doc__,
        help=__doc__,
    )

    parser.add_argument(
        "--config",
        type=read_json,
        required=True,
        help="Configuration file.",
    )

    parser.set_defaults(func=run)


def eval(rank, config, mlflow_logging):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    config["training"]["device"]["ddp_rank"] = rank

    # Load best checkpoint
    config["training"]["load_epoch"] = "best"

    model = Manager(config)
    model.load_model()

    metrics = ["cer", "wer", "wer_no_punct", "time"]
    for dataset_name in config["dataset"]["datasets"]:
        for set_name in ["test", "val", "train"]:
            logger.info(f"Evaluating on set `{set_name}`")
            model.evaluate(
                "{}-{}".format(dataset_name, set_name),
                [
                    (dataset_name, set_name),
                ],
                metrics,
                output=True,
                mlflow_logging=mlflow_logging,
            )


def run(config: dict):
    update_config(config)

    mlflow_logging = bool(config.get("mlflow"))
    if mlflow_logging:
        logger.info("MLflow logging enabled")

    if (
        config["training"]["device"]["use_ddp"]
        and config["training"]["device"]["force"] in [None, "cuda"]
        and torch.cuda.is_available()
    ):
        mp.spawn(
            eval,
            args=(config, mlflow_logging),
            nprocs=config["training"]["device"]["nb_gpu"],
        )
    else:
        eval(0, config, mlflow_logging)
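A hedged sketch of driving the new subcommand programmatically through the CLI parser (assuming the parser lives in `dan.cli` as the earlier hunk suggests, and that dispatch goes through the `func` default; the config path is hypothetical):

```python
from dan.cli import get_parser

# parse_args yields a Namespace holding the parsed config and the `func` default.
args = vars(get_parser().parse_args(["evaluate", "--config", "config.json"]))
func = args.pop("func")  # set by parser.set_defaults(func=run)
func(**args)  # calls dan.ocr.evaluate.run(config=...)
```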
......@@ -100,8 +100,10 @@ class GenericTrainingManager:
        self.dataset.load_dataloaders()

    def init_hardware_config(self):
+        cuda_is_available = torch.cuda.is_available()
+
        # Debug mode
-        if self.device_params["force_cpu"]:
+        if self.device_params["force"] not in [None, "cuda"] or not cuda_is_available:
            self.device_params["use_ddp"] = False
            self.device_params["use_amp"] = False
......@@ -116,17 +118,14 @@
            "rank": self.device_params["ddp_rank"],
        }
        self.is_master = self.ddp_config["master"] or not self.device_params["use_ddp"]
-        if self.device_params["force_cpu"]:
-            self.device = torch.device("cpu")
+        if self.device_params["use_ddp"]:
+            self.device = torch.device(self.ddp_config["rank"])
+            self.device_params["ddp_rank"] = self.ddp_config["rank"]
+            self.launch_ddp()
        else:
-            if self.device_params["use_ddp"]:
-                self.device = torch.device(self.ddp_config["rank"])
-                self.device_params["ddp_rank"] = self.ddp_config["rank"]
-                self.launch_ddp()
-            else:
-                self.device = torch.device(
-                    "cuda:0" if torch.cuda.is_available() else "cpu"
-                )
+            self.device = torch.device(
+                self.device_params["force"] or "cuda" if cuda_is_available else "cpu"
+            )
        if self.device == torch.device("cpu"):
            self.params["model"]["device"] = "cpu"
        else:
......
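The string-valued `force` field replaces the old boolean `force_cpu`: any value other than `cuda` (or an unavailable CUDA runtime) disables DDP and AMP, and the final device falls back from `force` to `cuda` to `cpu`. A standalone sketch of that resolution, assuming it matches the expression above:

```python
import torch

def resolve_device(force: str = None) -> torch.device:
    # An explicit `force` value wins when CUDA is available;
    # without CUDA everything falls back to the CPU.
    if torch.cuda.is_available():
        return torch.device(force or "cuda")
    return torch.device("cpu")
```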
......@@ -17,13 +17,7 @@ def add_predict_parser(subcommands) -> None:
        help=__doc__,
    )

    # Required arguments.
-    image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
-    image_or_folder_input.add_argument(
-        "--image",
-        type=pathlib.Path,
-        help="Path to the image to predict.",
-    )
-    image_or_folder_input.add_argument(
+    parser.add_argument(
        "--image-dir",
        type=pathlib.Path,
        help="Path to the folder where the images to predict are stored.",
......@@ -31,20 +25,7 @@ def add_predict_parser(subcommands) -> None:
    parser.add_argument(
        "--model",
        type=pathlib.Path,
-        help="Path to the model to use for prediction.",
-        required=True,
-    )
-    parser.add_argument(
-        "--parameters",
-        type=pathlib.Path,
-        help="Path to the YAML parameters file.",
-        required=True,
-        default="page",
-    )
-    parser.add_argument(
-        "--charset",
-        type=pathlib.Path,
-        help="Path to the charset file.",
+        help="Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.",
        required=True,
    )
    parser.add_argument(
......@@ -135,19 +116,6 @@ def add_predict_parser(subcommands) -> None:
        type=int,
        default=None,
    )
-    parser.add_argument(
-        "--threshold-method",
-        help="Thresholding method.",
-        choices=["otsu", "simple"],
-        type=str,
-        default="otsu",
-    )
-    parser.add_argument(
-        "--threshold-value",
-        help="Thresholding value.",
-        type=int,
-        default=0,
-    )
    parser.add_argument(
        "--gpu-device",
        help="Use a specific GPU if available.",
......
......@@ -220,8 +220,6 @@ def get_predicted_polygons_with_confidence(
    level: Level,
    height: int,
    width: int,
-    threshold_method: str = "otsu",
-    threshold_value: int = 0,
    max_object_height: int = 50,
    word_separators: re.Pattern = parse_delimiters(["\n", " "]),
    line_separators: re.Pattern = parse_delimiters(["\n"]),
......@@ -235,8 +233,6 @@ def get_predicted_polygons_with_confidence(
    :param level: Level to display (must be in [char, word, line, ner])
    :param height: Original image height
    :param width: Original image width
-    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]
-    :param threshold_value: Thresholding value for the "simple" method.
    :param max_object_height: Maximum height of predicted objects.
    :param word_separators: List of word separators
    :param line_separators: List of line separators
......@@ -256,8 +252,6 @@ def get_predicted_polygons_with_confidence(
            max_value,
            start_index,
            weights,
-            threshold_method=threshold_method,
-            threshold_value=threshold_value,
            max_object_height=max_object_height,
            size=(width, height),
        )
......@@ -347,35 +341,21 @@ def polygon_to_bbx(polygon: np.ndarray) -> List[Tuple[int, int]]:
    return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]


-def threshold(
-    mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0
-) -> np.ndarray:
+def threshold(mask: np.ndarray) -> np.ndarray:
    """
    Threshold a grayscale mask.
    :param mask: a grayscale image (np.array)
-    :param threshold_method: method to be used for thresholding. Should be in ["otsu", "simple"].
-    :param threshold_value: the threshold value used for binarization (used for the "simple" method).
    """
    min_kernel = 1
    max_kernel = mask.shape[1] // 100

-    if threshold_method == "simple":
-        bin_mask = np.array(np.where(mask > threshold_value, 255, 0), dtype=np.uint8)
-        return np.asarray(bin_mask, dtype=np.uint8)
-    elif threshold_method == "otsu":
-        # Blur and apply Otsu thresholding
-        blur = cv2.GaussianBlur(mask, (15, 15), 0)
-        _, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-        # Apply dilation
-        kernel_width = cv2.getStructuringElement(
-            cv2.MORPH_CROSS, (max_kernel, min_kernel)
-        )
-        dilated = cv2.dilate(bin_mask, kernel_width, iterations=3)
-        return np.asarray(dilated, dtype=np.uint8)
-    else:
-        raise NotImplementedError(f"Method {threshold_method} is not implemented.")
+    # Blur and apply Otsu thresholding
+    blur = cv2.GaussianBlur(mask, (15, 15), 0)
+    _, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+    # Apply dilation
+    kernel_width = cv2.getStructuringElement(cv2.MORPH_CROSS, (max_kernel, min_kernel))
+    dilated = cv2.dilate(bin_mask, kernel_width, iterations=3)
+
+    return np.asarray(dilated, dtype=np.uint8)
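A small usage sketch of the simplified function on a synthetic mask (the random input is hypothetical, and wide enough that the `mask.shape[1] // 100` dilation kernel stays positive):

```python
import numpy as np

# Binarize a synthetic 100x300 grayscale mask with the Otsu-only threshold().
mask = (np.random.rand(100, 300) * 255).astype(np.uint8)
bin_mask = threshold(mask)
assert set(np.unique(bin_mask)) <= {0, 255}
```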
def get_polygon(
......@@ -383,8 +363,6 @@ def get_polygon(
    max_value: np.float32,
    offset: int,
    weights: np.ndarray,
-    threshold_method: str,
-    threshold_value: int,
    size: Tuple[int, int] = None,
    max_object_height: int = 50,
) -> Tuple[dict, np.ndarray]:
......@@ -394,19 +372,13 @@ def get_polygon(
    :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
    :param offset: Offset value to get the relevant part of text piece
    :param size: Target size (width, height) to resize the coverage vector
-    :param threshold_method: Binarization method to use (should be in ["simple", "otsu"])
    :param max_object_height: Maximum height of predicted objects.
-    :param threshold_value: Threshold value used for the "simple" binarization method
    """
    # Compute coverage vector
    coverage_vector = compute_coverage(text, max_value, offset, weights, size=size)

    # Generate a binary image for the current channel.
-    bin_mask = threshold(
-        coverage_vector,
-        threshold_method=threshold_method,
-        threshold_value=threshold_value,
-    )
+    bin_mask = threshold(coverage_vector)

    coord, confidence = (
        get_grid_search_contour(coverage_vector, bin_mask, height=max_object_height)
......@@ -475,8 +447,6 @@ def plot_attention(
    level: Level,
    scale: float,
    outname: str,
-    threshold_method: str = "otsu",
-    threshold_value: int = 0,
    max_object_height: int = 50,
    word_separators: re.Pattern = parse_delimiters(["\n", " "]),
    line_separators: re.Pattern = parse_delimiters(["\n"]),
......@@ -527,8 +497,6 @@ def plot_attention(
                max_value,
                tot_len,
                weights,
-                threshold_method=threshold_method,
-                threshold_value=threshold_value,
                max_object_height=max_object_height,
                size=(image.width, image.height),
            )
......
......@@ -49,20 +49,25 @@ class DAN:
    def load(
        self,
-        model_path: Path,
-        params_path: Path,
-        charset_path: Path,
+        path: Path,
        mode: str = "eval",
        use_language_model: bool = False,
    ) -> None:
        """
        Load a trained model.
-        :param model_path: Path to the model.
-        :param params_path: Path to the parameters.
-        :param charset_path: Path to the charset.
+        :param path: Path to the directory containing the model, the YAML parameters file and the charset file.
        :param mode: The mode to load the model (train or eval).
        :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
        """
+        model_path = path / "model.pt"
+        assert model_path.is_file(), f"File {model_path} not found"
+        params_path = path / "parameters.yml"
+        assert params_path.is_file(), f"File {params_path} not found"
+        charset_path = path / "charset.pkl"
+        assert charset_path.is_file(), f"File {charset_path} not found"
+
        parameters = yaml.safe_load(params_path.read_text())["parameters"]
        parameters["decoder"]["device"] = self.device
......@@ -104,8 +109,8 @@ class DAN:
        )
        self.mean, self.std = (
-            torch.tensor(parameters["mean"]) / 255,
-            torch.tensor(parameters["std"]) / 255,
+            torch.tensor(parameters["mean"]) / 255 if "mean" in parameters else None,
+            torch.tensor(parameters["std"]) / 255 if "std" in parameters else None,
        )
        self.preprocessing_transforms = get_preprocessing_transforms(
            parameters.get("preprocessings", [])
......@@ -119,11 +124,21 @@ class DAN:
"""
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
normalized_image = torch.zeros(preprocessed_image.shape)
for ch in range(preprocessed_image.shape[0]):
if self.mean is None and self.std is None:
return preprocessed_image, preprocessed_image
size = preprocessed_image.shape
normalized_image = torch.zeros(size)
mean = self.mean if self.mean is not None else torch.zeros(size[0])
std = self.std if self.std is not None else torch.ones(size[0])
for ch in range(size[0]):
normalized_image[ch, :, :] = (
preprocessed_image[ch, :, :] - self.mean[ch]
) / self.std[ch]
preprocessed_image[ch, :, :] - mean[ch]
) / std[ch]
return preprocessed_image, normalized_image
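The per-channel loop above is equivalent to broadcasting over the channel axis; a minimal sketch with hypothetical `mean` and `std` values (in the real flow they come from `parameters.yml`):

```python
import torch

mean = torch.tensor([238.0, 238.0, 238.0]) / 255  # hypothetical values
std = torch.tensor([42.0, 42.0, 42.0]) / 255
image = torch.rand(3, 128, 128)
# Broadcast mean/std over height and width instead of looping per channel.
normalized = (image - mean[:, None, None]) / std[:, None, None]
```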
def predict(
......@@ -138,8 +153,6 @@ class DAN:
        line_separators: re.Pattern = parse_delimiters(["\n"]),
        tokens: Dict[str, EntityType] = {},
        start_token: str = None,
-        threshold_method: str = "otsu",
-        threshold_value: int = 0,
        max_object_height: int = 50,
        use_language_model: bool = False,
    ) -> dict:
......@@ -151,8 +164,6 @@ class DAN:
        :param attentions: Return characters attention weights.
        :param attention_level: Level of text pieces (must be in [char, word, line, ner])
        :param extract_objects: Whether to extract polygons' coordinates.
-        :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
-        :param threshold_value: Thresholding value to use for the "simple" thresholding method.
        :param max_object_height: Maximum height of predicted objects.
        """
        input_tensor = input_tensor.to(self.device)
......@@ -284,8 +295,6 @@ class DAN:
                    attention_level,
                    input_sizes[i][0],
                    input_sizes[i][1],
-                    threshold_method=threshold_method,
-                    threshold_value=threshold_value,
                    max_object_height=max_object_height,
                    word_separators=word_separators,
                    line_separators=line_separators,
......@@ -309,8 +318,6 @@ def process_batch(
    word_separators: List[str],
    line_separators: List[str],
    predict_objects: bool,
-    threshold_method: str,
-    threshold_value: int,
    max_object_height: int,
    tokens: Dict[str, EntityType],
    start_token: str,
......@@ -346,8 +353,6 @@ def process_batch(
        word_separators=word_separators,
        line_separators=line_separators,
        tokens=tokens,
-        threshold_method=threshold_method,
-        threshold_value=threshold_value,
        max_object_height=max_object_height,
        start_token=start_token,
        use_language_model=use_language_model,
......@@ -406,24 +411,19 @@ def process_batch(
            line_separators=line_separators,
            tokens=tokens,
            display_polygons=predict_objects,
-            threshold_method=threshold_method,
-            threshold_value=threshold_value,
            max_object_height=max_object_height,
            outname=gif_filename,
        )
        result["attention_gif"] = gif_filename

-    json_filename = Path(output, image_path.stem).with_suffix(".json")
+    json_filename = Path(output, f"{image_path.stem}.json")
    logger.info(f"Saving JSON prediction in {json_filename}")
    json_filename.write_text(json.dumps(result, indent=2))
def run(
-    image: Optional[Path],
    image_dir: Optional[Path],
    model: Path,
-    parameters: Path,
-    charset: Path,
    output: Path,
    confidence_score: bool,
    confidence_score_levels: List[Level],
......@@ -434,8 +434,6 @@ def run(
    line_separators: List[str],
    temperature: float,
    predict_objects: bool,
-    threshold_method: str,
-    threshold_value: int,
    max_object_height: int,
    image_extension: str,
    gpu_device: int,
......@@ -446,11 +444,8 @@ def run(
) -> None:
"""
Predict a single image save the output
:param image: Path to the image to predict.
:param image_dir: Path to the folder where the images to predict are stored.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param model: Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.
:param output: Path to the output folder where the results will be saved.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
......@@ -459,8 +454,6 @@ def run(
    :param word_separators: List of word separators.
    :param line_separators: List of line separators.
    :param predict_objects: Whether to extract objects.
-    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
-    :param threshold_value: Thresholding value to use for the "simple" thresholding method.
    :param max_object_height: Maximum height of predicted objects.
    :param gpu_device: Use a specific GPU if available.
    :param batch_size: Size of the batches for prediction.
......@@ -476,14 +469,12 @@ def run(
    cuda_device = f":{gpu_device}" if gpu_device is not None else ""
    device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
    dan_model = DAN(device, temperature)
-    dan_model.load(
-        model, parameters, charset, mode="eval", use_language_model=use_language_model
-    )
+    dan_model.load(model, mode="eval", use_language_model=use_language_model)

    # Do not use LM with invalid LM weight
    use_language_model = dan_model.lm_decoder is not None

-    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    images = image_dir.rglob(f"*{image_extension}")
    for image_batch in list_to_batches(images, n=batch_size):
        process_batch(
            image_batch,
......@@ -498,8 +489,6 @@ def run(
            word_separators,
            line_separators,
            predict_objects,
-            threshold_method,
-            threshold_value,
            max_object_height,
            tokens,
            start_token,
......
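Since `--image` is gone, every input now comes from `--image-dir`; a minimal sketch of the equivalent discovery and batching (paths and extension are hypothetical, and the list slicing stands in for `list_to_batches`):

```python
from pathlib import Path

image_dir = Path("images")  # hypothetical folder
images = sorted(image_dir.rglob("*.jpg"))
batch_size = 4
# Group the images into fixed-size batches for prediction.
batches = [images[i : i + batch_size] for i in range(0, len(images), batch_size)]
```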
......@@ -3,18 +3,14 @@ import json
import logging
import random
from copy import deepcopy
-from pathlib import Path

import numpy as np
import torch
import torch.multiprocessing as mp
-from torch.optim import Adam

-from dan.ocr.decoder import GlobalHTADecoder
-from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.training import Manager
from dan.ocr.mlflow import MLFLOW_AVAILABLE
-from dan.ocr.transforms import Preprocessing
+from dan.ocr.utils import update_config
from dan.utils import MLflowNotInstalled

if MLFLOW_AVAILABLE:
......@@ -26,7 +22,7 @@ if MLFLOW_AVAILABLE:
logger = logging.getLogger(__name__)
-def train_and_test(rank, params, mlflow_logging=False):
+def train(rank, params, mlflow_logging=False):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
......@@ -43,67 +39,6 @@ def train_and_test(rank, params, mlflow_logging=False):
    model.train(mlflow_logging=mlflow_logging)
-
-    # load weights giving best CER on valid set
-    model.params["training"]["load_epoch"] = "best"
-    model.load_model()
-
-    metrics = ["cer", "wer", "wer_no_punct", "time"]
-    for dataset_name in params["dataset"]["datasets"]:
-        for set_name in ["test", "val", "train"]:
-            model.evaluate(
-                "{}-{}".format(dataset_name, set_name),
-                [
-                    (dataset_name, set_name),
-                ],
-                metrics,
-                output=True,
-                mlflow_logging=mlflow_logging,
-            )
-
-
-def update_config(config: dict):
-    """
-    Update some fields for easier
-    """
-    # .dataset.datasets cast all values to Path
-    config["dataset"]["datasets"] = {
-        name: Path(path) for name, path in config["dataset"]["datasets"].items()
-    }
-
-    # .model.encoder.class = FCN_ENCODER
-    config["model"]["encoder"]["class"] = FCN_Encoder
-
-    # .model.decoder.class = GlobalHTADecoder
-    config["model"]["decoder"]["class"] = GlobalHTADecoder
-
-    # Update preprocessing type
-    for prepro in config["training"]["data"]["preprocessings"]:
-        prepro["type"] = Preprocessing(prepro["type"])
-
-    # .training.output_folder to Path
-    config["training"]["output_folder"] = Path(config["training"]["output_folder"])
-
-    if config["training"]["transfer_learning"]:
-        # .training.transfer_learning.encoder[1]
-        config["training"]["transfer_learning"]["encoder"][1] = Path(
-            config["training"]["transfer_learning"]["encoder"][1]
-        )
-
-        # .training.transfer_learning.decoder[1]
-        config["training"]["transfer_learning"]["decoder"][1] = Path(
-            config["training"]["transfer_learning"]["decoder"][1]
-        )
-
-    # Parse optimizers
-    for optimizer_setup in config["training"]["optimizers"].values():
-        # Only supported optimizer is Adam
-        optimizer_setup["class"] = Adam
-
-    # set nb_gpu if not present
-    if config["training"]["device"]["nb_gpu"] is None:
-        config["training"]["device"]["nb_gpu"] = torch.cuda.device_count()
def serialize_config(config):
"""
......@@ -146,15 +81,16 @@ def serialize_config(config):
def start_training(config, mlflow_logging: bool) -> None:
    if (
        config["training"]["device"]["use_ddp"]
-        and not config["training"]["device"]["force_cpu"]
+        and config["training"]["device"]["force"] in [None, "cuda"]
        and torch.cuda.is_available()
    ):
        mp.spawn(
-            train_and_test,
+            train,
            args=(config, mlflow_logging),
            nprocs=config["training"]["device"]["nb_gpu"],
        )
    else:
-        train_and_test(0, config, mlflow_logging)
+        train(0, config, mlflow_logging)
def run(config: dict):
......
......@@ -6,6 +6,7 @@ from enum import Enum
from random import randint

import albumentations as A
+import cv2
import numpy as np
from albumentations.augmentations import (
Affine,
......@@ -15,16 +16,16 @@ from albumentations.augmentations import (
    GaussianBlur,
    GaussNoise,
    Perspective,
+    RandomScale,
    Sharpen,
    ToGray,
)
from albumentations.core.transforms_interface import ImageOnlyTransform
-from cv2 import dilate, erode, resize
+from cv2 import dilate, erode
from numpy import random
from torch import Tensor
from torch.distributions.uniform import Uniform
from torchvision.transforms import Compose, ToPILImage
-from torchvision.transforms.functional import resize as resize_tensor
+from torchvision.transforms.functional import resize
class Preprocessing(str, Enum):
......@@ -54,7 +55,7 @@ class FixedHeightResize:
    def __call__(self, img: Tensor) -> Tensor:
        size = (self.height, self._calc_new_width(img))
-        return resize_tensor(img, size, antialias=False)
+        return resize(img, size, antialias=False)

    def _calc_new_width(self, img: Tensor) -> int:
        aspect_ratio = img.shape[2] / img.shape[1]
......@@ -71,7 +72,7 @@ class FixedWidthResize:
    def __call__(self, img: Tensor) -> Tensor:
        size = (self._calc_new_height(img), self.width)
-        return resize_tensor(img, size, antialias=False)
+        return resize(img, size, antialias=False)

    def _calc_new_height(self, img: Tensor) -> int:
        aspect_ratio = img.shape[1] / img.shape[2]
......@@ -96,7 +97,7 @@ class MaxResize:
        ratio = min(height_ratio, width_ratio)
        new_width = int(width * ratio)
        new_height = int(height * ratio)
-        return resize_tensor(img, (new_height, new_width), antialias=False)
+        return resize(img, (new_height, new_width), antialias=False)
class Dilation:
......@@ -156,29 +157,6 @@ class ErosionDilation(ImageOnlyTransform):
)
-class DPIAdjusting(ImageOnlyTransform):
-    """
-    Resolution modification
-    """
-
-    def __init__(
-        self,
-        min_factor: float = 0.75,
-        max_factor: float = 1,
-        always_apply: bool = False,
-        p: float = 1.0,
-    ):
-        super(DPIAdjusting, self).__init__(always_apply, p)
-        self.min_factor = min_factor
-        self.max_factor = max_factor
-        self.p = p
-        self.always_apply = False
-
-    def apply(self, img: np.ndarray, **params):
-        factor = float(Uniform(self.min_factor, self.max_factor).sample())
-        return resize(img, None, fx=factor, fy=factor)
def get_preprocessing_transforms(
    preprocessings: list, to_pil_image: bool = False
) -> Compose:
......@@ -212,7 +190,10 @@ def get_augmentation_transforms() -> A.Compose:
"""
return A.Compose(
[
DPIAdjusting(min_factor=0.75, max_factor=1),
# Scale between 0.75 and 1.0
RandomScale(
scale_limit=[-0.25, 0], always_apply=True, interpolation=cv2.INTER_AREA
),
A.SomeOf(
[
ErosionDilation(min_kernel=1, max_kernel=4, iterations=1),
......@@ -220,10 +201,18 @@ def get_augmentation_transforms() -> A.Compose:
                    GaussianBlur(sigma_limit=2.5, p=1),
                    GaussNoise(var_limit=50**2, p=1),
                    ColorJitter(
-                        contrast=0.2, brightness=0.2, saturation=0.2, hue=0.2, p=1
+                        contrast=0.2,
+                        brightness=0.2,
+                        saturation=0.2,
+                        hue=0.2,
+                        p=1,
                    ),
                    ElasticTransform(
-                        alpha=20.0, sigma=5.0, alpha_affine=1.0, border_mode=0, p=1
+                        alpha=20.0,
+                        sigma=5.0,
+                        alpha_affine=1.0,
+                        border_mode=0,
+                        p=1,
                    ),
                    Sharpen(alpha=(0.0, 1.0), p=1),
                    Affine(shear={"x": (-20, 20), "y": (0, 0)}, p=1),
......
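The `RandomScale` call keeps the old `DPIAdjusting` range: `scale_limit=[-0.25, 0]` draws a factor in [0.75, 1.0], so images are only ever downscaled. A hedged sketch applying just that transform to a synthetic image:

```python
import cv2
import numpy as np
from albumentations.augmentations import RandomScale

transform = RandomScale(
    scale_limit=[-0.25, 0], always_apply=True, interpolation=cv2.INTER_AREA
)
image = (np.random.rand(100, 200, 3) * 255).astype(np.uint8)
scaled = transform(image=image)["image"]
assert scaled.shape[0] <= 100  # never upscaled
```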
# -*- coding: utf-8 -*-
from pathlib import Path

import torch
from torch.optim import Adam

from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.transforms import Preprocessing


def update_config(config: dict):
    """
    Complete the fields that are not JSON serializable.
    """
    # .dataset.datasets: cast all values to Path
    config["dataset"]["datasets"] = {
        name: Path(path) for name, path in config["dataset"]["datasets"].items()
    }

    # .model.encoder.class = FCN_Encoder
    config["model"]["encoder"]["class"] = FCN_Encoder

    # .model.decoder.class = GlobalHTADecoder
    config["model"]["decoder"]["class"] = GlobalHTADecoder

    # Update the preprocessing type
    for prepro in config["training"]["data"]["preprocessings"]:
        prepro["type"] = Preprocessing(prepro["type"])

    # .training.output_folder to Path
    config["training"]["output_folder"] = Path(config["training"]["output_folder"])

    if config["training"]["transfer_learning"]:
        # .training.transfer_learning.encoder[1]
        config["training"]["transfer_learning"]["encoder"][1] = Path(
            config["training"]["transfer_learning"]["encoder"][1]
        )

        # .training.transfer_learning.decoder[1]
        config["training"]["transfer_learning"]["decoder"][1] = Path(
            config["training"]["transfer_learning"]["decoder"][1]
        )

    # Parse the optimizers (the only supported optimizer is Adam)
    for optimizer_setup in config["training"]["optimizers"].values():
        optimizer_setup["class"] = Adam

    # Set nb_gpu from the visible devices when it is not present
    if config["training"]["device"]["nb_gpu"] is None:
        config["training"]["device"]["nb_gpu"] = torch.cuda.device_count()
docs/assets/augmentations/document_original.png (213 KiB)
docs/assets/augmentations/document_random_scale.png (748 KiB)
docs/assets/augmentations/line_full_pipeline_2.png (204 KiB)