# -*- coding: utf-8 -*-
import json
import uuid
from operator import itemgetter
from typing import List, Optional, Union

import pytest
from arkindex_export import (
    Element,
    ElementPath,
    Entity,
    EntityType,
    Image,
    ImageServer,
    Transcription,
    TranscriptionEntity,
    WorkerRun,
    WorkerVersion,
    database,
)

from dan.ocr.train import update_config
from tests import FIXTURES


@pytest.fixture(scope="session")
def mock_database(tmp_path_factory):
    """Build a temporary SQLite Arkindex export database seeded from JSON fixtures.

    Creates the peewee tables, an image server, a worker version, then
    recursively creates elements (with images, transcriptions and
    transcription entities) starting from the "root" fixture file.

    Returns the path to the created SQLite database.
    """

    def create_transcription_entity(
        transcription: Transcription,
        worker_version: Union[str, None],
        type: str,
        name: str,
        offset: int,
    ) -> None:
        # Create (or reuse) the entity type, then attach a new entity
        # to the given transcription at the given offset.
        entity_type, _ = EntityType.get_or_create(
            name=type, defaults={"id": f"{type}_id"}
        )
        entity = Entity.create(
            id=str(uuid.uuid4()),
            name=name,
            type=entity_type,
            worker_version=worker_version,
        )
        TranscriptionEntity.create(
            entity=entity,
            length=len(name),
            offset=offset,
            transcription=transcription,
            worker_version=worker_version,
        )

    def create_transcriptions(element: Element, entities: List[dict]) -> None:
        # One transcription per worker version (manual + worker), each
        # carrying the same entities.
        if not entities:
            return

        # Add transcription with entities
        entities = sorted(entities, key=itemgetter("offset"))
        # We will add extra spaces to test the "keep_spaces" parameters of the "extract" command:
        # each separator is one character wider than a single space, so every
        # entity after the first is shifted right by its index.
        for offset, entity in enumerate(entities[1:], start=1):
            entity["offset"] += offset

        for worker_version in [None, "worker_version_id"]:
            # Use different transcriptions to filter by worker version
            if worker_version == "worker_version_id":
                for entity in entities:
                    entity["name"] = entity["name"].lower()

            transcription = Transcription.create(
                id=element.id + (worker_version or ""),
                # Add extra spaces to test the "keep_spaces" parameters of the "extract" command
                text="  ".join(map(itemgetter("name"), entities)),
                element=element,
                worker_version=worker_version,
            )
            for entity in entities:
                create_transcription_entity(
                    transcription=transcription,
                    worker_version=worker_version,
                    **entity,
                )

    def create_element(id: str, parent: Optional[Element] = None) -> None:
        # Load the element description from its JSON fixture and create the
        # element, its image, its transcriptions and its children.
        element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(
            ".json"
        )
        element_json = json.loads(element_path.read_text())

        element_type = element_json["type"]

        image_path = (
            FIXTURES / "extraction" / "images" / element_type / id
        ).with_suffix(".jpg")
        polygon = element_json.get("polygon")

        # Always use page images because polygons are based on the full image
        image, _ = (
            Image.get_or_create(
                id=id + "-image",
                defaults={
                    "server": image_server,
                    # Use path to image instead of actual URL since we won't be doing any download
                    "url": image_path,
                    "width": 0,
                    "height": 0,
                },
            )
            if polygon
            else (None, False)
        )

        element = Element.create(
            id=id,
            name=id,
            type=element_type,
            image=image,
            polygon=json.dumps(polygon) if polygon else None,
            created=0.0,
            updated=0.0,
        )

        if parent:
            ElementPath.create(id=str(uuid.uuid4()), parent=parent, child=element)

        create_transcriptions(
            element=element,
            entities=element_json.get("transcription_entities", []),
        )

        # Recursive function to create children
        for child in element_json.get("children", []):
            create_element(id=child, parent=element)

    MODELS = [
        WorkerVersion,
        WorkerRun,
        ImageServer,
        Image,
        Element,
        ElementPath,
        EntityType,
        Entity,
        Transcription,
        TranscriptionEntity,
    ]

    # Initialisation
    tmp_path = tmp_path_factory.mktemp("data")
    database_path = tmp_path / "db.sqlite"
    database.init(
        database_path,
        pragmas={
            # Recommended settings from peewee
            # http://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
            # Do not set journal mode to WAL as it writes in the database
            "cache_size": -1 * 64000,  # 64MB
            "foreign_keys": 1,
            "ignore_check_constraints": 0,
            "synchronous": 0,
        },
    )
    database.connect()

    # Create tables
    database.create_tables(MODELS)

    image_server = ImageServer.create(
        url="http://image/server/url",
        display_name="Image server",
    )

    WorkerVersion.create(
        id="worker_version_id",
        slug="worker_version",
        name="Worker version",
        repository_url="http://repository/url",
        revision="main",
        type="worker",
    )

    # Create folders
    create_element(id="root")

    return database_path
@pytest.fixture
def training_config():
    """Return a complete DAN training configuration for tests.

    The configuration points at the fixture training dataset, uses a tiny
    model/training budget (4 epochs, CPU only) so tests stay fast, and is
    normalized through ``update_config`` before being returned.
    """
    config = {
        "dataset": {
            "datasets": {
                "training": str(FIXTURES / "training" / "training_dataset"),
            },
            "train": {
                "name": "training-train",
                "datasets": [
                    ("training", "train"),
                ],
            },
            "val": {
                "training-val": [
                    ("training", "val"),
                ],
            },
            "test": {
                "training-test": [
                    ("training", "test"),
                ],
            },
            "max_char_prediction": 30,  # max number of token prediction
            "tokens": None,
        },
        "model": {
            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
            "encoder": {
                "dropout": 0.5,  # dropout rate for encoder
                "nb_layers": 5,  # encoder
            },
            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
            "decoder": {
                "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
                "dec_num_layers": 8,  # number of transformer decoder layers
                "dec_num_heads": 4,  # number of heads in transformer decoder layers
                "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
                "dec_pred_dropout": 0.1,  # dropout rate before decision layer
                "dec_att_dropout": 0.1,  # dropout rate in multi head attention
                "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
                "attention_win": 100,  # length of attention window
                "enc_dim": 256,  # dimension of extracted features
            },
        },
        "training": {
            "data": {
                "batch_size": 2,  # mini-batch size for training
                "load_in_memory": True,  # Load all images in CPU memory
                "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
                "preprocessings": [
                    {
                        "type": "max_resize",
                        "max_width": 2000,
                        "max_height": 2000,
                    }
                ],
                "augmentation": True,
            },
            "device": {
                "use_ddp": False,  # Use DistributedDataParallel
                "ddp_port": "20027",
                "use_amp": True,  # Enable automatic mix-precision
                "nb_gpu": 0,
                "force_cpu": True,  # True for debug purposes
            },
            "metrics": {
                "train": [
                    "loss_ce",
                    "cer",
                    "wer",
                    "wer_no_punct",
                ],  # Metrics name for training
                "eval": [
                    "cer",
                    "wer",
                    "wer_no_punct",
                ],  # Metrics name for evaluation on validation set during training
            },
            "validation": {
                "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
                "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
                "set_name_focus_metric": "training-val",
            },
            "output_folder": "dan_trained_model",  # folder name for checkpoint and results
            "gradient_clipping": {},
            "max_nb_epochs": 4,  # maximum number of epochs before to stop
            "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
            "optimizers": {
                "all": {
                    "args": {
                        "lr": 0.0001,
                        "amsgrad": False,
                    },
                },
            },
            "lr_schedulers": None,  # Learning rate schedulers
            # Keep teacher forcing rate to 20% during whole training
            "label_noise_scheduler": {
                "min_error_rate": 0.2,
                "max_error_rate": 0.2,
                "total_num_steps": 5e4,
            },
            "transfer_learning": None,
        },
    }

    update_config(config)

    return config


@pytest.fixture
def prediction_data_path():
    """Return the path to the prediction fixtures directory."""
    return FIXTURES / "prediction"