Commit 4df2bbd3 authored by Yoann Schneider

Merge branch 'typing-use-logical-or' into 'main'

Typing with `|` instead of `Union`/`Optional`

Closes #235

See merge request !324
parents dbfffbec 1d68b924
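This merge request replaces `typing.Optional[X]` and `typing.Union[X, Y]` annotations with the PEP 604 union syntax (`X | None`, `X | Y`) throughout the codebase. As a reference, a minimal before/after sketch (illustrative only, not code from this repository):

```python
from typing import Optional, Union


# Before: typing-module constructs.
def crop(size: Optional[int] = None) -> Union[str, bool]:
    ...


# After: PEP 604 unions; the typing import is no longer needed.
def crop_pep604(size: int | None = None) -> str | bool:
    ...
```

The `X | Y` form is evaluated at runtime only on Python 3.10+; in annotation-only positions it also works on 3.7+ with `from __future__ import annotations`, which defers annotation evaluation.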
@@ -3,7 +3,7 @@ import logging
 from collections import Counter, defaultdict
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List

 import imagesize
 import numpy as np
@@ -157,7 +157,7 @@ class Statistics:
             level=3,
         )

-    def run(self, labels: Dict, tokens: Optional[Dict]):
+    def run(self, labels: Dict, tokens: Dict | None):
         # Iterate over each split
         for split_name, split_data in labels.items():
             self.document.new_header(level=1, title=split_name.capitalize())
@@ -175,7 +175,7 @@ class Statistics:
         self.document.create_md_file()


-def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
+def run(labels: Dict, tokens: Dict | None, output: Path) -> None:
     """
     Compute and save a dataset statistics.
     """
@@ -5,7 +5,7 @@ import logging
 from collections import defaultdict
 from concurrent.futures import Future, ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

 import cv2
 import numpy as np
@@ -37,9 +37,9 @@ class ImageDownloader:
     def __init__(
         self,
-        output: Path = None,
-        max_width: Optional[int] = None,
-        max_height: Optional[int] = None,
+        output: Path | None = None,
+        max_width: int | None = None,
+        max_height: int | None = None,
         image_extension: str = "",
     ) -> None:
         self.output = output
@@ -61,7 +61,7 @@ class ImageDownloader:
         self.data: Dict = defaultdict(dict)

-    def check_extraction(self, values: dict) -> Optional[str]:
+    def check_extraction(self, values: dict) -> str | None:
         # Check image parameters
         if not (image := values.get("image")):
             return "Image information not found"
@@ -245,8 +245,8 @@ class ImageDownloader:
 def run(
     output: Path,
-    max_width: Optional[int],
-    max_height: Optional[int],
+    max_width: int | None,
+    max_height: int | None,
     image_format: str,
 ):
     """
@@ -5,7 +5,6 @@ Extract dataset from Arkindex using a corpus export.
 import argparse
 import pathlib
-from typing import Union
 from uuid import UUID

 from dan.datasets.extract.arkindex import run
@@ -13,7 +12,7 @@ from dan.datasets.extract.arkindex import run
 MANUAL_SOURCE = "manual"


-def parse_worker_version(worker_version_id) -> Union[str, bool]:
+def parse_worker_version(worker_version_id) -> str | bool:
     if worker_version_id == MANUAL_SOURCE:
         return False
@@ -6,7 +6,7 @@ import pickle
 import random
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List
 from uuid import UUID

 from tqdm import tqdm
@@ -50,13 +50,13 @@ class ArkindexExtractor:
         self,
         folders: list = [],
         element_type: List[str] = [],
-        parent_element_type: str = None,
-        output: Path = None,
+        parent_element_type: str | None = None,
+        output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "",
-        tokens: Path = None,
-        transcription_worker_version: Optional[Union[str, bool]] = None,
-        entity_worker_version: Optional[Union[str, bool]] = None,
+        tokens: Path | None = None,
+        transcription_worker_version: str | bool | None = None,
+        entity_worker_version: str | bool | None = None,
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -122,7 +122,7 @@ class ArkindexExtractor:
             )
         )

-    def format_text(self, text: str, charset: Optional[set] = None):
+    def format_text(self, text: str, charset: set | None = None):
         if not self.keep_spaces:
             text = normalize_spaces(text)
             text = normalize_linebreaks(text)
@@ -312,8 +312,8 @@ def run(
     train_folder: UUID,
     val_folder: UUID,
     test_folder: UUID,
-    transcription_worker_version: Optional[Union[str, bool]],
-    entity_worker_version: Optional[Union[str, bool]],
+    transcription_worker_version: str | bool | None,
+    entity_worker_version: str | bool | None,
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,

 # -*- coding: utf-8 -*-
-from typing import List, Optional, Union
+from typing import List

 from arkindex_export import Image
 from arkindex_export.models import (
@@ -41,7 +41,7 @@ def build_worker_version_filter(ArkindexModel, worker_version):
 def get_transcriptions(
-    element_id: str, transcription_worker_version: Union[str, bool]
+    element_id: str, transcription_worker_version: str | bool
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
@@ -61,7 +61,7 @@ def get_transcriptions(
 def get_transcription_entities(
     transcription_id: str,
-    entity_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: str | bool | None,
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
@@ -6,7 +6,7 @@ import re
 from dataclasses import dataclass, field
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List

 import sentencepiece as spm
 from lxml.etree import Element, SubElement, tostring
@@ -72,7 +72,7 @@ class Tokenizer:
     unknown_token: str
     outdir: Path
     mapping: LMTokenMapping
-    tokens: Optional[EntityType] = None
+    tokens: EntityType | None = None
     subword_vocab_size: int = 1000
     sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
@@ -81,7 +81,7 @@ class Tokenizer:
         return self.outdir / "subword_tokenizer"

     @property
-    def ner_tokens(self) -> Union[List[str], Iterator[str]]:
+    def ner_tokens(self) -> List[str] | Iterator[str]:
         if self.tokens is None:
             return []
         return itertools.chain(
@@ -179,7 +179,7 @@ def slugify(text: str):
     return text.replace(" ", "_")


-def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
+def get_translation_map(tokens: Dict[str, EntityType]) -> Dict[str, str] | None:
     if not tokens:
         return
@@ -247,7 +247,7 @@ class XMLEntity:
 def entities_to_xml(
     text: str,
     predictions: List[TranscriptionEntity],
-    entity_separators: Optional[List[str]] = None,
+    entity_separators: List[str] | None = None,
 ) -> str:
     """Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag.
     Its type will be used to name the tag.
@@ -267,7 +267,7 @@ def entities_to_xml(
             return separator
         return ""

-    def add_portion(entity_offset: Optional[int] = None):
+    def add_portion(entity_offset: int | None = None):
         """
         Add the portion of text between entities either:
         - after the last node, if there is one before

 # -*- coding: utf-8 -*-
-from typing import Dict, List, Union
+from typing import Dict, List

 import numpy as np
 import torch
@@ -559,7 +559,7 @@ class CTCLanguageDecoder:
     def post_process(
         self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         :param hypotheses: List of hypotheses returned by the decoder.
@@ -594,7 +594,7 @@ class CTCLanguageDecoder:
     def __call__(
         self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Decode a feature vector using n-gram language modelling.
         :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
@@ -3,7 +3,7 @@ import re
 from collections import defaultdict
 from operator import attrgetter
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List

 import editdistance
 import numpy as np
@@ -21,9 +21,7 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
 class MetricManager:
-    def __init__(
-        self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
-    ):
+    def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
         self.dataset_name: str = dataset_name
         self.remove_tokens: str = None
@@ -363,7 +363,7 @@ def get_polygon(
     max_value: np.float32,
     offset: int,
     weights: np.ndarray,
-    size: Tuple[int, int] = None,
+    size: Tuple[int, int] | None = None,
     max_object_height: int = 50,
 ) -> Tuple[dict, np.ndarray]:
     """
@@ -5,7 +5,7 @@ import logging
 import pickle
 import re
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

 import numpy as np
 import torch
@@ -159,7 +159,7 @@ class DAN:
         word_separators: re.Pattern = parse_delimiters(["\n", " "]),
         line_separators: re.Pattern = parse_delimiters(["\n"]),
         tokens: Dict[str, EntityType] = {},
-        start_token: str = None,
+        start_token: str | None = None,
         max_object_height: int = 50,
     ) -> dict:
         """
@@ -426,7 +426,7 @@ def process_batch(
 def run(
-    image_dir: Optional[Path],
+    image_dir: Path,
     model: Path,
     output: Path,
     confidence_score: bool,
@@ -12,7 +12,9 @@ select = [
     # Isort
     "I",
     # Pathlib usage
-    "PTH"
+    "PTH",
+    # Implicit Optional
+    "RUF013"
 ]

 [tool.ruff.isort]
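The newly enabled `RUF013` rule bans implicit `Optional`: parameters that default to `None` while their annotation omits it, which is exactly the pattern fixed in several signatures above (e.g. `output: Path = None`). A hypothetical snippet (names invented for illustration) showing a violation and its fix:

```python
# Flagged by RUF013: the default is None, but the annotation claims plain int.
def download(max_width: int = None):
    ...


# Compliant: the None default is explicit in the annotation.
def download_fixed(max_width: int | None = None):
    ...
```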
@@ -2,7 +2,7 @@
 import json
 import uuid
 from operator import itemgetter
-from typing import List, Optional, Union
+from typing import List

 import pytest
@@ -26,7 +26,7 @@ from tests import FIXTURES
 def mock_database(tmp_path_factory):
     def create_transcription_entity(
         transcription: Transcription,
-        worker_version: Union[str, None],
+        worker_version: str | None,
         type: str,
         name: str,
         offset: int,
@@ -80,7 +80,7 @@ def mock_database(tmp_path_factory):
             **entity,
         )

-    def create_element(id: str, parent: Optional[Element] = None) -> None:
+    def create_element(id: str, parent: Element | None = None) -> None:
         element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(".json")
         element_json = json.loads(element_path.read_text())