Compare revisions: atr/dan

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (3)
@@ -122,14 +122,12 @@ def add_extract_parser(subcommands) -> None:
         type=parse_worker_version,
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
     parser.add_argument(
         "--entity-worker-version",
         type=parse_worker_version,
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
     parser.add_argument(
......
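Dropping `default=False` makes the two worker-version flags tri-state: an omitted flag now yields `None` (no filtering at all) instead of `False` (manual annotations only). A minimal sketch of the resulting semantics; the body of `parse_worker_version` and the value of `MANUAL_SOURCE` are assumptions for illustration, not shown in this diff:

```python
from typing import Union

MANUAL_SOURCE = "manual"  # assumed value, for illustration only

def parse_worker_version(value: str) -> Union[str, bool]:
    """Hypothetical parser: MANUAL_SOURCE maps to False so queries can
    filter on worker_version IS NULL; any other string is a worker UUID."""
    if value == MANUAL_SOURCE:
        return False
    return value

# Flag omitted                 -> None  -> no worker_version filtering
# --...-worker-version manual  -> False -> manual annotations only
# --...-worker-version <uuid>  -> str   -> that worker version only
```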
@@ -3,15 +3,17 @@
 import ast
 from dataclasses import dataclass
 from itertools import starmap
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 from urllib.parse import urljoin
 
 from arkindex_export import Image
-from arkindex_export.models import Element as ArkindexElement
-from arkindex_export.models import Entity as ArkindexEntity
-from arkindex_export.models import EntityType as ArkindexEntityType
-from arkindex_export.models import Transcription as ArkindexTranscription
-from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity
+from arkindex_export.models import (
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
 from arkindex_export.queries import list_children
@@ -25,23 +27,6 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)
 
 
-# DB models
-Transcription = NamedTuple(
-    "Transcription",
-    id=str,
-    text=str,
-)
-
-Entity = NamedTuple(
-    "Entity",
-    type=str,
-    value=str,
-    offset=float,
-    length=float,
-)
-
 
 @dataclass
 class Element:
     id: str
@@ -94,6 +79,7 @@ def get_elements(
             Image.height,
         )
     )
+
     return list(
         starmap(
             lambda *x: Element(*x, max_width=max_width, max_height=max_height),
@@ -118,47 +104,43 @@ def get_transcriptions(
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
    """
-    query = ArkindexTranscription.select(
-        ArkindexTranscription.id, ArkindexTranscription.text
-    ).where(
-        (ArkindexTranscription.element == element_id)
-        & build_worker_version_filter(
-            ArkindexTranscription, worker_version=transcription_worker_version
-        )
-    )
-    return list(
-        starmap(
-            Transcription,
-            query.tuples(),
-        )
-    )
+    query = Transcription.select(
+        Transcription.id, Transcription.text, Transcription.worker_version
+    ).where((Transcription.element == element_id))
+    if transcription_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                Transcription, worker_version=transcription_worker_version
+            )
+        )
+    return query
 
 
 def get_transcription_entities(
     transcription_id: str, entity_worker_version: Union[str, bool]
-) -> List[Entity]:
+) -> List[TranscriptionEntity]:
     """
     Retrieve transcription entities from an SQLite export of an Arkindex corpus
     """
     query = (
-        ArkindexTranscriptionEntity.select(
-            ArkindexEntityType.name,
-            ArkindexEntity.name,
-            ArkindexTranscriptionEntity.offset,
-            ArkindexTranscriptionEntity.length,
-        )
-        .join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
-        .join(ArkindexEntityType, on=ArkindexEntity.type)
-        .where(
-            (ArkindexTranscriptionEntity.transcription == transcription_id)
-            & build_worker_version_filter(
-                ArkindexTranscriptionEntity, worker_version=entity_worker_version
-            )
-        )
+        TranscriptionEntity.select(
+            EntityType.name.alias("type"),
+            Entity.name.alias("name"),
+            TranscriptionEntity.offset,
+            TranscriptionEntity.length,
+            TranscriptionEntity.worker_version,
+        )
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+        .where((TranscriptionEntity.transcription == transcription_id))
     )
-    return list(
-        starmap(
-            Entity,
-            query.order_by(ArkindexTranscriptionEntity.offset).tuples(),
-        )
-    )
+    if entity_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                TranscriptionEntity, worker_version=entity_worker_version
+            )
+        )
+    return query.order_by(TranscriptionEntity.offset).namedtuples()
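Both helpers now hand back the peewee query itself rather than materializing handwritten NamedTuples: `get_transcriptions` yields `Transcription` model instances, while `get_transcription_entities` goes through peewee's `.namedtuples()` wrapper so the `.alias(...)` columns surface as plain attributes. A minimal consumption sketch under those assumptions; the IDs are placeholders, not values from this diff:

```python
from dan.datasets.extract.db import get_transcription_entities, get_transcriptions

transcriptions = get_transcriptions(
    element_id="00000000-0000-0000-0000-000000000000",  # placeholder ID
    transcription_worker_version=None,  # None -> no filtering
)
for transcription in transcriptions:  # peewee executes the query lazily here
    entities = get_transcription_entities(
        transcription_id=transcription.id,
        entity_worker_version=False,  # False -> manual entities only
    )
    for entity in entities:
        # .alias("type") / .alias("name") above expose these attributes
        print(entity.type, entity.name, entity.offset, entity.length)
```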
@@ -12,7 +12,6 @@ from tqdm import tqdm
 from dan import logger
 from dan.datasets.extract.db import (
     Element,
-    Entity,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -51,8 +50,8 @@ class ArkindexExtractor:
         load_entities: bool = None,
         tokens: Path = None,
         use_existing_split: bool = None,
-        transcription_worker_version: str = None,
-        entity_worker_version: str = None,
+        transcription_worker_version: Optional[Union[str, bool]] = None,
+        entity_worker_version: Optional[Union[str, bool]] = None,
         train_prob: float = None,
         val_prob: float = None,
         max_width: Optional[int] = None,
@@ -100,7 +99,7 @@ class ArkindexExtractor:
     def get_random_split(self):
         return next(self._assign_random_split())
 
-    def reconstruct_text(self, text: str, entities: List[Entity]):
+    def reconstruct_text(self, text: str, entities):
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
@@ -226,8 +225,8 @@ def run(
     train_folder: UUID,
     val_folder: UUID,
     test_folder: UUID,
-    transcription_worker_version: Union[str, bool],
-    entity_worker_version: Union[str, bool],
+    transcription_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: Optional[Union[str, bool]],
     train_prob,
     val_prob,
     max_width: Optional[int],
......
 # -*- coding: utf-8 -*-
 import json
 import os
 import random
 from copy import deepcopy
 from enum import Enum
 from time import time
 
 import numpy as np
@@ -454,37 +453,51 @@ class GenericTrainingManager:
     def save_params(self):
         """
-        Output text file containing a summary of all hyperparameters chosen for the training
+        Output yaml file containing a summary of all hyperparameters chosen for the training
         """
 
-        def compute_nb_params(module):
-            return sum([np.prod(p.size()) for p in list(module.parameters())])
-
-        def class_to_str_dict(my_dict):
-            for key in my_dict.keys():
-                if callable(my_dict[key]):
-                    my_dict[key] = my_dict[key].__name__
-                elif isinstance(my_dict[key], np.ndarray):
-                    my_dict[key] = my_dict[key].tolist()
-                elif isinstance(my_dict[key], dict):
-                    my_dict[key] = class_to_str_dict(my_dict[key])
-            return my_dict
-
-        path = os.path.join(self.paths["results"], "params")
+        path = os.path.join(self.paths["results"], "parameters.yml")
         if os.path.isfile(path):
             return
 
-        params = class_to_str_dict(my_dict=deepcopy(self.params))
-        total_params = 0
-        for model_name in self.models.keys():
-            current_params = compute_nb_params(self.models[model_name])
-            params["model_params"]["models"][model_name] = [
-                params["model_params"]["models"][model_name],
-                "{:,}".format(current_params),
-            ]
-            total_params += current_params
-        params["model_params"]["total_params"] = "{:,}".format(total_params)
+        params = {
+            "parameters": {
+                "max_char_prediction": self.params["training_params"][
+                    "max_char_prediction"
+                ],
+                "encoder": {
+                    "dropout": self.params["model_params"]["dropout"],
+                },
+                "decoder": {
+                    key: self.params["model_params"][key]
+                    for key in [
+                        "enc_dim",
+                        "l_max",
+                        "h_max",
+                        "w_max",
+                        "dec_num_layers",
+                        "dec_num_heads",
+                        "dec_res_dropout",
+                        "dec_pred_dropout",
+                        "dec_att_dropout",
+                        "dec_dim_feedforward",
+                        "vocab_size",
+                        "attention_win",
+                    ]
+                },
+                "preprocessings": [
+                    {
+                        key: value.value if isinstance(value, Enum) else value
+                        for key, value in preprocessing.items()
+                    }
+                    for preprocessing in self.params["dataset_params"]["config"].get(
+                        "preprocessings", []
+                    )
+                ],
+            },
+        }
 
         with open(path, "w") as f:
-            json.dump(params, f, indent=4)
+            yaml.dump(params, f)
 
     def backward_loss(self, loss, retain_graph=False):
         self.scaler.scale(loss).backward(retain_graph=retain_graph)
......
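`save_params` now writes a curated YAML summary, the same structure the prediction step expects, instead of JSON-dumping the whole stringified params dict. A round-trip sketch of the new file, assuming PyYAML is imported as `yaml` in this module; the values come from the test expectation later in this diff, not from defaults:

```python
import yaml

# Shape written by the new save_params (decoder keys trimmed for brevity).
params = {
    "parameters": {
        "max_char_prediction": 30,
        "encoder": {"dropout": 0.5},
        "decoder": {"enc_dim": 256, "attention_win": 100},
        "preprocessings": [
            {"type": "max_resize", "max_height": 2000, "max_width": 2000}
        ],
    }
}

with open("parameters.yml", "w") as f:
    yaml.dump(params, f)

# The predict command (and the test at the end of this diff) reads it back:
with open("parameters.yml") as f:
    assert yaml.safe_load(f) == params
```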
@@ -45,31 +45,4 @@ Once the training is complete, you can apply a trained DAN model on an image.
 To do this, you will need to:
 
-1. Create a `parameters.yml` file using the parameters saved during training in the `params` file, located in `{training_params.output_folder}/results`. This file should have the following format:
-```yml
-version: 0.0.1
-parameters:
-  max_char_prediction: int
-  encoder:
-    dropout: float
-  decoder:
-    enc_dim: int
-    l_max: int
-    dec_pred_dropout: float
-    attention_win: int
-    vocab_size: int
-    h_max: int
-    w_max: int
-    dec_num_layers: int
-    dec_dim_feedforward: int
-    dec_num_heads: int
-    dec_att_dropout: float
-    dec_res_dropout: float
-  preprocessings:
-    - type: str
-      max_height: int
-      max_width: int
-      fixed_height: int
-      fixed_width: int
-```
-2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
+1. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
@@ -52,7 +52,7 @@ Usage:
 [
     {
         "type": Preprocessing.FixedWidthResize,
-        "fixed_height": 1500,
+        "fixed_width": 1500,
     }
 ]
 ```
......
---
version: 0.0.1
parameters:
  max_char_prediction: 200
  encoder:
......
@@ -4,8 +4,6 @@ import pytest
 
 from dan.datasets.extract.db import (
     Element,
-    Entity,
-    Transcription,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -35,7 +33,7 @@ def test_get_elements():
 
 
 @pytest.mark.parametrize(
-    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60")
+    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60", None)
 )
 def test_get_transcriptions(worker_version):
     """
@@ -48,22 +46,17 @@ def test_get_transcriptions(worker_version):
     )
 
     # Check number of results
-    assert len(transcriptions) == 1
-    transcription = transcriptions.pop()
-    assert isinstance(transcription, Transcription)
-
-    # Common keys
-    assert transcription.text == "[ T 8º SUP 26200"
-
-    # Differences
-    if worker_version:
-        assert transcription.id == "3bd248d6-998a-4579-a00c-d4639f3825aa"
-    else:
-        assert transcription.id == "c551960a-0f82-4779-b975-77a457bcf273"
+    assert len(transcriptions) == 1 + int(worker_version is None)
+    for transcription in transcriptions:
+        assert transcription.text == "[ T 8º SUP 26200"
+        if worker_version:
+            assert transcription.worker_version.id == worker_version
+        elif worker_version is False:
+            assert transcription.worker_version is None
 
 
 @pytest.mark.parametrize(
-    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965")
+    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965", None)
 )
 def test_get_transcription_entities(worker_version):
     transcription_id = "3bd248d6-998a-4579-a00c-d4639f3825aa"
@@ -73,18 +66,15 @@ def test_get_transcription_entities(worker_version):
     )
 
     # Check number of results
-    assert len(entities) == 1
-    transcription_entity = entities.pop()
-    assert isinstance(transcription_entity, Entity)
-
-    # Differences
-    if worker_version:
-        assert transcription_entity.type == "cote"
-        assert transcription_entity.value == "T 8 º SUP 26200"
-        assert transcription_entity.offset == 2
-        assert transcription_entity.length == 14
-    else:
-        assert transcription_entity.type == "Cote"
-        assert transcription_entity.value == "[ T 8º SUP 26200"
-        assert transcription_entity.offset == 0
-        assert transcription_entity.length == 16
+    assert len(entities) == 1 + (worker_version is None)
+    for transcription_entity in entities:
+        if worker_version:
+            assert transcription_entity.type == "cote"
+            assert transcription_entity.name == "T 8 º SUP 26200"
+            assert transcription_entity.offset == 2
+            assert transcription_entity.length == 14
+        elif worker_version is False:
+            assert transcription_entity.type == "Cote"
+            assert transcription_entity.name == "[ T 8º SUP 26200"
+            assert transcription_entity.offset == 0
+            assert transcription_entity.length == 16
 # -*- coding: utf-8 -*-
+from typing import NamedTuple
+
 import pytest
 
-from dan.datasets.extract.extract import ArkindexExtractor, Entity
+from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
+# NamedTuple to mock actual database result
+Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
+
 
 @pytest.mark.parametrize(
     "text,count,offset,length,expected",
......
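With the shared `Entity` NamedTuple gone from `db.py`, the test defines its own stand-in with the same shape as a database row, so `reconstruct_text` can be exercised without an SQLite export. A hedged sketch of how such a mock row might be built; the entity values are illustrative, not from this diff:

```python
from typing import NamedTuple

# Same mock as above: the field names mirror what reconstruct_text consumes.
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)

# Hypothetical row, as a parametrized test case might build it:
entity = Entity(offset=0, length=5, type="person", value="Louis")
assert (entity.offset, entity.length, entity.type) == (0, 5, "person")
```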
@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES
 
 
 @pytest.mark.parametrize(
-    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
+    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
     (
         (
             "best_0.pt",
@@ -41,6 +41,33 @@ from tests.conftest import FIXTURES
                 "wer_no_punct": 1.0,
                 "nb_samples": 2,
             },
+            {
+                "parameters": {
+                    "max_char_prediction": 30,
+                    "encoder": {"dropout": 0.5},
+                    "decoder": {
+                        "enc_dim": 256,
+                        "l_max": 15000,
+                        "h_max": 500,
+                        "w_max": 1000,
+                        "dec_num_layers": 8,
+                        "dec_num_heads": 4,
+                        "dec_res_dropout": 0.1,
+                        "dec_pred_dropout": 0.1,
+                        "dec_att_dropout": 0.1,
+                        "dec_dim_feedforward": 256,
+                        "vocab_size": 96,
+                        "attention_win": 100,
+                    },
+                    "preprocessings": [
+                        {
+                            "max_height": 2000,
+                            "max_width": 2000,
+                            "type": "max_resize",
+                        }
+                    ],
+                },
+            },
         ),
     ),
 )
@@ -50,6 +77,7 @@ def test_train_and_test(
     training_res,
     val_res,
     test_res,
+    params_res,
     training_config,
     tmp_path,
 ):
@@ -146,3 +174,13 @@
         if "time" not in metric
     }
     assert res == expected_res
+
+    # Check that the parameters file is correct
+    with (
+        tmp_path
+        / training_config["training_params"]["output_folder"]
+        / "results"
+        / "parameters.yml"
+    ).open() as f:
+        res = yaml.safe_load(f)
+    assert res == params_res