diff --git a/dan/datasets/analyze/statistics.py b/dan/datasets/analyze/statistics.py
index d6a716ed1238717fdf33d4854f30f9c68a4e3402..e937ab659ab5d7aedb2d79ab44c55f5648f5d32f 100644
--- a/dan/datasets/analyze/statistics.py
+++ b/dan/datasets/analyze/statistics.py
@@ -3,7 +3,7 @@ import logging
 from collections import Counter, defaultdict
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List
 
 import imagesize
 import numpy as np
@@ -157,7 +157,7 @@ class Statistics:
             level=3,
         )
 
-    def run(self, labels: Dict, tokens: Optional[Dict]):
+    def run(self, labels: Dict, tokens: Dict | None):
         # Iterate over each split
         for split_name, split_data in labels.items():
             self.document.new_header(level=1, title=split_name.capitalize())
@@ -175,7 +175,7 @@ class Statistics:
         self.document.create_md_file()
 
 
-def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
+def run(labels: Dict, tokens: Dict | None, output: Path) -> None:
     """
     Compute and save a dataset statistics.
     """
diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py
index feaf10fd69aa4233dc25d0184838583dc6360700..b492377fb7a2d7091638ef55d06867e6f7b03496 100644
--- a/dan/datasets/download/images.py
+++ b/dan/datasets/download/images.py
@@ -5,7 +5,7 @@ import logging
 from collections import defaultdict
 from concurrent.futures import Future, ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
 import cv2
 import numpy as np
@@ -37,9 +37,9 @@ class ImageDownloader:
 
     def __init__(
         self,
-        output: Path = None,
-        max_width: Optional[int] = None,
-        max_height: Optional[int] = None,
+        output: Path | None = None,
+        max_width: int | None = None,
+        max_height: int | None = None,
         image_extension: str = "",
     ) -> None:
         self.output = output
@@ -61,7 +61,7 @@ class ImageDownloader:
 
         self.data: Dict = defaultdict(dict)
 
-    def check_extraction(self, values: dict) -> Optional[str]:
+    def check_extraction(self, values: dict) -> str | None:
         # Check image parameters
         if not (image := values.get("image")):
             return "Image information not found"
@@ -245,8 +245,8 @@ class ImageDownloader:
 
 def run(
     output: Path,
-    max_width: Optional[int],
-    max_height: Optional[int],
+    max_width: int | None,
+    max_height: int | None,
     image_format: str,
 ):
     """
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 6dbf84bd7c987413206282291e4fe6f1a7212dc3..ad8155c3bbd3a59855a687543171ffffa46e669a 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -5,7 +5,6 @@ Extract dataset from Arkindex using a corpus export.
 import argparse
 import pathlib
-from typing import Union
 from uuid import UUID
 
 from dan.datasets.extract.arkindex import run
 
@@ -13,7 +12,7 @@ from dan.datasets.extract.arkindex import run
 MANUAL_SOURCE = "manual"
 
 
-def parse_worker_version(worker_version_id) -> Union[str, bool]:
+def parse_worker_version(worker_version_id) -> str | bool:
     if worker_version_id == MANUAL_SOURCE:
         return False
 
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index 28fd03c0c3e44045b46991a947378ba8663db3a2..0c8b5abc46dbe8ec069a33460e710ac46324cd85 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -6,7 +6,7 @@ import pickle
 import random
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List
 from uuid import UUID
 
 from tqdm import tqdm
@@ -50,13 +50,13 @@ class ArkindexExtractor:
         self,
         folders: list = [],
         element_type: List[str] = [],
-        parent_element_type: str = None,
-        output: Path = None,
+        parent_element_type: str | None = None,
+        output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "⁇",
-        tokens: Path = None,
-        transcription_worker_version: Optional[Union[str, bool]] = None,
-        entity_worker_version: Optional[Union[str, bool]] = None,
+        tokens: Path | None = None,
+        transcription_worker_version: str | bool | None = None,
+        entity_worker_version: str | bool | None = None,
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -122,7 +122,7 @@ class ArkindexExtractor:
             )
         )
 
-    def format_text(self, text: str, charset: Optional[set] = None):
+    def format_text(self, text: str, charset: set | None = None):
         if not self.keep_spaces:
             text = normalize_spaces(text)
             text = normalize_linebreaks(text)
@@ -312,8 +312,8 @@ def run(
     train_folder: UUID,
     val_folder: UUID,
     test_folder: UUID,
-    transcription_worker_version: Optional[Union[str, bool]],
-    entity_worker_version: Optional[Union[str, bool]],
+    transcription_worker_version: str | bool | None,
+    entity_worker_version: str | bool | None,
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 7f9b7aea8d04ef4ac19ab685c8e4fb82c988ff80..5aeccbf0d78ae744d8d531d20adf18cc20dbcfa8 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from typing import List, Optional, Union
+from typing import List
 
 from arkindex_export import Image
 from arkindex_export.models import (
@@ -41,7 +41,7 @@ def build_worker_version_filter(ArkindexModel, worker_version):
 
 
 def get_transcriptions(
-    element_id: str, transcription_worker_version: Union[str, bool]
+    element_id: str, transcription_worker_version: str | bool
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
@@ -61,7 +61,7 @@ def get_transcriptions(
 
 def get_transcription_entities(
     transcription_id: str,
-    entity_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: str | bool | None,
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 7a39107f1678bd5c5c8306490f4a8e2750d89e8d..538e6ef20925e11a5f57df0b0e9f3ea86ab3ff03 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -6,7 +6,7 @@ import re
 from dataclasses import dataclass, field
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List
 
 import sentencepiece as spm
 from lxml.etree import Element, SubElement, tostring
@@ -72,7 +72,7 @@ class Tokenizer:
     unknown_token: str
     outdir: Path
     mapping: LMTokenMapping
-    tokens: Optional[EntityType] = None
+    tokens: EntityType | None = None
    subword_vocab_size: int = 1000
     sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
 
@@ -81,7 +81,7 @@ class Tokenizer:
         return self.outdir / "subword_tokenizer"
 
     @property
-    def ner_tokens(self) -> Union[List[str], Iterator[str]]:
+    def ner_tokens(self) -> List[str] | Iterator[str]:
         if self.tokens is None:
             return []
         return itertools.chain(
@@ -179,7 +179,7 @@ def slugify(text: str):
     return text.replace(" ", "_")
 
 
-def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
+def get_translation_map(tokens: Dict[str, EntityType]) -> Dict[str, str] | None:
     if not tokens:
         return
 
@@ -247,7 +247,7 @@ class XMLEntity:
 def entities_to_xml(
     text: str,
     predictions: List[TranscriptionEntity],
-    entity_separators: Optional[List[str]] = None,
+    entity_separators: List[str] | None = None,
 ) -> str:
     """Represent the transcription and its entities in XML format.
     Each entity will be exposed with an XML tag. Its type will be used to name the tag.
@@ -267,7 +267,7 @@ def entities_to_xml(
                 return separator
         return ""
 
-    def add_portion(entity_offset: Optional[int] = None):
+    def add_portion(entity_offset: int | None = None):
         """
         Add the portion of text between entities either:
         - after the last node, if there is one before
diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index b62c0f6c77b48518dd0939e984c186b1377e887d..736caa2f8f589b4c7b9e607740aea1085af6ff4e 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from typing import Dict, List, Union
+from typing import Dict, List
 
 import numpy as np
 import torch
@@ -559,7 +559,7 @@ class CTCLanguageDecoder:
 
     def post_process(
         self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         :param hypotheses: List of hypotheses returned by the decoder.
@@ -594,7 +594,7 @@ class CTCLanguageDecoder:
 
     def __call__(
         self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Decode a feature vector using n-gram language modelling.
         :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 602e75ac8b354bf1be321c200af6e6c597136d79..f48da65ca20d8976ab2067ef7828bd20b85f996e 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -3,7 +3,7 @@ import re
 from collections import defaultdict
 from operator import attrgetter
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List
 
 import editdistance
 import numpy as np
@@ -21,9 +21,7 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
 
 
 class MetricManager:
-    def __init__(
-        self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
-    ):
+    def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
         self.dataset_name: str = dataset_name
 
         self.remove_tokens: str = None
diff --git a/dan/ocr/predict/attention.py b/dan/ocr/predict/attention.py
index 8e1b07e70ac67d172062ac17f8a4292209b9b696..7df72e6e00b0c3cb3275285be17528426a3991d4 100644
--- a/dan/ocr/predict/attention.py
+++ b/dan/ocr/predict/attention.py
@@ -363,7 +363,7 @@ def get_polygon(
     max_value: np.float32,
     offset: int,
     weights: np.ndarray,
-    size: Tuple[int, int] = None,
+    size: Tuple[int, int] | None = None,
     max_object_height: int = 50,
 ) -> Tuple[dict, np.ndarray]:
     """
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 85f8e0818c57303ccaaea6306f08529c466323f8..adf9efafe39464ee86dc80242f0a7874091091b2 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -5,7 +5,7 @@ import logging
 import pickle
 import re
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -159,7 +159,7 @@ class DAN:
         word_separators: re.Pattern = parse_delimiters(["\n", " "]),
         line_separators: re.Pattern = parse_delimiters(["\n"]),
         tokens: Dict[str, EntityType] = {},
-        start_token: str = None,
+        start_token: str | None = None,
         max_object_height: int = 50,
     ) -> dict:
         """
@@ -426,7 +426,7 @@ def process_batch(
 
 
 def run(
-    image_dir: Optional[Path],
+    image_dir: Path,
     model: Path,
     output: Path,
     confidence_score: bool,
diff --git a/pyproject.toml b/pyproject.toml
index 6c386a06e0f78c712f4af769c5e9109f45954f03..d7cd4179736d35b94bb4e03ba536af3bf3ea8877 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,9 @@ select = [
     # Isort
     "I",
     # Pathlib usage
-    "PTH"
+    "PTH",
+    # Implicit Optional
+    "RUF013"
 ]
 
 [tool.ruff.isort]
diff --git a/tests/conftest.py b/tests/conftest.py
index 4bc38b0733cc8ea76ed012eefa9a953f2461058c..2fe9bc4b872f76a7daffe34f36e5f57d7650ec37 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 import json
 import uuid
 from operator import itemgetter
-from typing import List, Optional, Union
+from typing import List
 
 import pytest
 
@@ -26,7 +26,7 @@ def mock_database(tmp_path_factory):
 
     def create_transcription_entity(
         transcription: Transcription,
-        worker_version: Union[str, None],
+        worker_version: str | None,
         type: str,
         name: str,
         offset: int,
@@ -80,7 +80,7 @@ def mock_database(tmp_path_factory):
             **entity,
         )
 
-    def create_element(id: str, parent: Optional[Element] = None) -> None:
+    def create_element(id: str, parent: Element | None = None) -> None:
         element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(".json")
         element_json = json.loads(element_path.read_text())
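
Note (not part of the diff): the short sketch below illustrates the pattern this change targets. Ruff's RUF013 rule flags "implicit Optional" parameters (a `None` default without an explicit optional annotation), and the fix applied throughout is the PEP 604 union spelling `X | None`, which removes the need for `typing.Optional`/`typing.Union` on Python 3.10+. The `resolve_output` helper is purely hypothetical, not code from this repository.

    # Illustrative sketch only (Python 3.10+); `resolve_output` is a hypothetical helper.
    from pathlib import Path

    # Before: the None default makes the annotation implicitly Optional; RUF013 flags this.
    def resolve_output_old(output: Path = None) -> Path:
        return output or Path.cwd()

    # After: optionality is spelled out with the PEP 604 union, as done throughout this diff.
    def resolve_output(output: Path | None = None) -> Path:
        return output or Path.cwd()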