diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py index 6771e77bc522d435ea8e2b163f273fa4d3a2e1e1..a7c83525c437f06306b5e8bbf77e8e9e576fd1ff 100644 --- a/dan/datasets/extract/__init__.py +++ b/dan/datasets/extract/__init__.py @@ -106,8 +106,6 @@ def add_extract_parser(subcommands) -> None: If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. """, - required=False, - default=list(map(validate_char, ("\n", " "))), ) parser.add_argument( "--unknown-token", diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py index 1e953cf83a2a186e15d22b8cfea493634716897b..31af3e0a25565689786da8c3f236f0a172d6bf98 100644 --- a/dan/datasets/extract/arkindex.py +++ b/dan/datasets/extract/arkindex.py @@ -24,7 +24,6 @@ from dan.datasets.extract.db import ( ) from dan.datasets.extract.exceptions import ( ImageDownloadError, - NoEndTokenError, NoTranscriptionError, ProcessingError, UnknownTokenInText, @@ -32,13 +31,14 @@ from dan.datasets.extract.exceptions import ( from dan.datasets.extract.utils import ( Tokenizer, download_image, + entities_to_xml, get_bbox, + get_translation_map, get_vocabulary, - insert_token, normalize_linebreaks, normalize_spaces, ) -from dan.utils import EntityType, LMTokenMapping, parse_tokens +from dan.utils import LMTokenMapping, parse_tokens from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, @@ -87,7 +87,7 @@ class ArkindexExtractor: self.output = output self.entity_separators = entity_separators self.unknown_token = unknown_token - self.tokens = parse_tokens(tokens) if tokens else None + self.tokens = parse_tokens(tokens) if tokens else {} self.transcription_worker_version = transcription_worker_version self.entity_worker_version = entity_worker_version self.max_width = max_width @@ -107,6 +107,9 @@ class ArkindexExtractor: # Image download tasks to process self.tasks: List[Dict[str, str]] = [] + # NER extraction + self.translation_map: Dict[str, str] | None = get_translation_map(self.tokens) + def get_iiif_size_arg(self, width: int, height: int) -> str: if (self.max_width is None or width <= self.max_width) and ( self.max_height is None or height <= self.max_height @@ -136,67 +139,13 @@ class ArkindexExtractor: image_url=image_url, bbox=get_bbox(polygon), size=size ) - def _keep_char(self, char: str) -> bool: - # Keep all text by default if no separator was given - return not self.entity_separators or char in self.entity_separators - - def reconstruct_text(self, full_text: str, entities) -> str: + def translate(self, text: str): """ - Insert tokens delimiting the start/end of each entity on the transcription. + Use translation map to replace XML tags to actual tokens """ - text, text_offset = "", 0 - for entity in entities: - # Text before entity - text += "".join( - filter(self._keep_char, full_text[text_offset : entity.offset]) - ) - - entity_type: EntityType = self.tokens.get(entity.type) - if not entity_type: - logger.warning( - f"Label `{entity.type}` is missing in the NER configuration." 
- ) - # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends - elif not entity_type.end and not self.entity_separators: - raise NoEndTokenError(entity.type) - - # Entity text: - # - with tokens if there is an entity_type - # - without tokens if there is no entity_type but we want to keep the whole text - if entity_type or not self.entity_separators: - text += insert_token( - full_text, - entity_type, - offset=entity.offset, - length=entity.length, - ) - text_offset = entity.offset + entity.length - - # Remaining text after the last entity - text += "".join(filter(self._keep_char, full_text[text_offset:])) - - if not self.entity_separators or self.keep_spaces: - return text - - # Add some clean up to avoid several separators between entities - text, full_text = "", text - for char in full_text: - last_char = text[-1] if len(text) else "" - - # Keep the current character if there are no two consecutive separators - if ( - char not in self.entity_separators - or last_char not in self.entity_separators - ): - text += char - # If several separators follow each other, keep only one according to the given order - elif self.entity_separators.index(char) < self.entity_separators.index( - last_char - ): - text = text[:-1] + char - - # Remove separators at the beginning and end of text - return text.strip("".join(self.entity_separators)) + for pattern, repl in self.translation_map.items(): + text = text.replace(pattern, repl) + return text def extract_transcription(self, element: Element): """ @@ -217,9 +166,16 @@ class ArkindexExtractor: return transcription.text.strip() entities = get_transcription_entities( - transcription.id, self.entity_worker_version + transcription.id, + self.entity_worker_version, + supported_types=list(self.tokens), + ) + + return self.translate( + entities_to_xml( + transcription.text, entities, entity_separators=self.entity_separators + ) ) - return self.reconstruct_text(transcription.text, entities) def get_image( self, diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py index a8933d063b09c6e7f7174f15f40af7698fdf7055..7f9b7aea8d04ef4ac19ab685c8e4fb82c988ff80 100644 --- a/dan/datasets/extract/db.py +++ b/dan/datasets/extract/db.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import List, Union +from typing import List, Optional, Union from arkindex_export import Image from arkindex_export.models import ( @@ -60,7 +60,9 @@ def get_transcriptions( def get_transcription_entities( - transcription_id: str, entity_worker_version: Union[str, bool] + transcription_id: str, + entity_worker_version: Optional[Union[str, bool]], + supported_types: List[str], ) -> List[TranscriptionEntity]: """ Retrieve transcription entities from an SQLite export of an Arkindex corpus @@ -75,7 +77,10 @@ def get_transcription_entities( ) .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) - .where((TranscriptionEntity.transcription == transcription_id)) + .where( + TranscriptionEntity.transcription == transcription_id, + EntityType.name.in_(supported_types), + ) ) if entity_worker_version is not None: @@ -85,4 +90,6 @@ def get_transcription_entities( ) ) - return query.order_by(TranscriptionEntity.offset).namedtuples() + return query.order_by( + TranscriptionEntity.offset, TranscriptionEntity.length.desc() + ).dicts() diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py index 2a703b1a8b26d270e73e1869ed960fe95e03fc85..ef7ba5b9b36809ac013ff2b6d6195d96681fcf08 
100644
--- a/dan/datasets/extract/exceptions.py
+++ b/dan/datasets/extract/exceptions.py
@@ -52,18 +52,3 @@ class UnknownTokenInText(ElementProcessingError):
    def __str__(self) -> str:
        return f"Unknown token found in the transcription text of element ({self.element_id})"
-
-
-class NoEndTokenError(ProcessingError):
-    """
-    Raised when the specified label has no end token and there is potentially additional text around the labels
-    """
-
-    label: str
-
-    def __init__(self, label: str, *args: object) -> None:
-        super().__init__(*args)
-        self.label = label
-
-    def __str__(self) -> str:
-        return f"Label `{self.label}` has no end token."
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 70903330686018c93d1cb871453703bd662e6ec6..ef43c7bd81ca8673299d1ef92c31083525038ba5 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -7,10 +7,11 @@ from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
-from typing import Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union

import requests
import sentencepiece as spm
+from lxml.etree import Element, SubElement, tostring
from nltk import wordpunct_tokenize
from PIL import Image, ImageOps
from tenacity import (
@@ -20,6 +21,7 @@ from tenacity import (
    wait_exponential,
)

+from arkindex_export import TranscriptionEntity
from dan.utils import EntityType, LMTokenMapping

logger = logging.getLogger(__name__)
@@ -31,6 +33,14 @@ DOWNLOAD_TIMEOUT = (30, 60)
TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")

+# Some characters are escaped in the XML output but we do not want them escaped in the final text
+ENCODING_MAP = {
+    "&#13;": "\r",
+    "&lt;": "<",
+    "&gt;": ">",
+    "&amp;": "&",
+}
+

def _retry_log(retry_state, *args, **kwargs):
    logger.warning(
@@ -83,20 +93,6 @@ def download_image(url):
    return image

-def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
-    """
-    Insert the given tokens at the right position in the text
-    """
-    return (
-        # Starting token
-        (entity_type.start if entity_type else "")
-        # Entity
-        + text[offset : offset + length]
-        # End token
-        + (entity_type.end if entity_type else "")
-    )
-
-
def normalize_linebreaks(text: str) -> str:
    """
    Remove begin/ending linebreaks.
@@ -248,3 +244,161 @@ class Tokenizer:
        :param text: Text to be encoded.
        """
        return map(self.mapping.encode_token, text)
+
+
+def slugify(text: str):
+    """
+    Replace spaces in the text with underscores so it can be used as an XML tag.
+ """ + return text.replace(" ", "_") + + +def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]: + if not tokens: + return + + translation_map = { + # Roots + "<root>": "", + "</root>": "", + } + # Tokens + for entity_name, token_type in tokens.items(): + translation_map[f"<{slugify(entity_name)}>"] = token_type.start + translation_map[f"</{slugify(entity_name)}>"] = token_type.end + + return translation_map + + +@dataclass +class XMLEntity: + type: str + name: str + offset: int + length: int + worker_version: str + children: List["XMLEntity"] = field(default_factory=list) + + @property + def end(self) -> int: + return self.offset + self.length + + def add_child(self, child: TranscriptionEntity): + self.children.append( + XMLEntity( + type=child["type"], + name=child["name"], + offset=child["offset"] - self.offset, + length=child["length"], + worker_version=child["worker_version"], + ) + ) + + def insert(self, parent: Element): + e = SubElement(parent, slugify(self.type)) + + if not self.children: + # No children + e.text = self.name + return + + offset = 0 + for child in self.children: + # Add text before entity + portion_before = self.name[offset : child.offset] + offset += len(portion_before) + if len(e): + e[-1].tail = portion_before + else: + e.text = portion_before + child.insert(e) + offset += child.length + + # Text after the last entity + e[-1].tail = self.name[self.children[-1].end : self.end] + + +def entities_to_xml( + text: str, + predictions: List[TranscriptionEntity], + entity_separators: Optional[List[str]] = None, +) -> str: + """Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag. + Its type will be used to name the tag. + + :param text: The text of the transcription + :param predictions: The list of entities linked to the transcription + :param entity_separators: When provided, instead of adding the text between entities, add one separator encountered in this text. The order is kept when looking for separators. Defaults to None + :return: The representation of the transcription in XML format + """ + + def _find_separator(transcription: str) -> str: + """ + Find the first entity separator in the provided transcription. + """ + for separator in entity_separators: + if separator in transcription: + return separator + return "" + + def add_portion(entity_offset: Optional[int] = None): + """ + Add the portion of text between entities either: + - after the last node, if there is one before + - on this node + + If we remove the text between entities, we keep one of the separators provided. Order matters. 
+ """ + portion = text[offset:entity_offset] + + if entity_separators: + # Remove the text except the first entity_separator encountered + portion = _find_separator(portion) + + if len(root): + root[-1].tail = portion + else: + root.text = portion + + entities = iter(predictions) + + # This will mark the ending position of the first-level of entities + last_end = None + parsed: List[XMLEntity] = [] + + for entity in entities: + # First entity is not inside any other + # If offset is too high, no nestation + if not last_end or entity["offset"] >= last_end: + parsed.append(XMLEntity(**entity)) + last_end = entity["offset"] + entity["length"] + continue + + # Nested entity + parsed[-1].add_child(entity) + + # XML export + offset = 0 + root = Element("root") + + for entity in parsed: + add_portion(entity.offset) + + entity.insert(root) + + offset = entity.end + + # Add text after last entity + add_portion() + + # Cleanup separators introduced when text was removed + if entity_separators: + characters = "".join(entity_separators) + root.text = root.text.lstrip(characters) + # Strip trailing spaces on last child + root[-1].tail = root[-1].tail.rstrip(characters) + + encoded_transcription = tostring(root, encoding="utf-8").decode() + for pattern, repl in ENCODING_MAP.items(): + encoded_transcription = encoded_transcription.replace(pattern, repl) + return encoded_transcription diff --git a/docs/css/ner.css b/docs/css/ner.css new file mode 100644 index 0000000000000000000000000000000000000000..5628eabae876f2e793f611031924b508c4e08f50 --- /dev/null +++ b/docs/css/ner.css @@ -0,0 +1,42 @@ +.entities-block { + background-color: var(--md-code-bg-color); + padding: .7720588235em 1.1764705882em; +} + +/* Light mode */ +body[data-md-color-scheme="default"] .entities-block > span[type=adj] { +background: #7fffd4; +} + +body[data-md-color-scheme="default"] .entities-block > span[type=name] { + background: #83ff7f; +} + +body[data-md-color-scheme="default"] .entities-block > span[type=person] { + background: #ffc17f; +} + +/* Dark mode */ +body[data-md-color-scheme="slate"] .entities-block > span[type=adj] { + background: #7fffd48c; +} + +body[data-md-color-scheme="slate"] .entities-block > span[type=name] { + background: #83ff7f63; +} + +body[data-md-color-scheme="slate"] .entities-block > span[type=person] { + background: #ffc17f6c; +} + +.entities-block > span { + padding: 0.1em; + border-radius: 0.35em; +} + +.entities-block > span::after { + content: ' ' attr(type); + font-size: 0.8em; + font-weight: bold; + vertical-align: middle; +} diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md index e19d8eb6dd70302b9587cd83745f54101c52dd48..b4b7992a644ddc54813793519a598c7384f2263b 100644 --- a/docs/usage/datasets/extract.md +++ b/docs/usage/datasets/extract.md @@ -11,26 +11,26 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind If an image download fails for whatever reason, it won't appear in the transcriptions file. The reason will be printed to stdout at the end of the process. Before trying to download the image, it checks that it wasn't downloaded previously. It is thus safe to run this command twice if a few images failed. 
-| Parameter | Description | Type | Default | -| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------- | -------------------------------------------------- | -| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | | -| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | | -| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` | -| `--output` | Folder where the data will be generated. | `pathlib.Path` | | -| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str` | `["\n", " "]` (see [dedicated section](#examples)) | -| `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `â‡` | -| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | | -| `--train-folder` | ID of the training folder to extract from Arkindex. | `uuid` | | -| `--val-folder` | ID of the validation folder to extract from Arkindex. | `uuid` | | -| `--test-folder` | ID of the training folder to extract from Arkindex. | `uuid` | | -| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | | -| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | | -| `--max-width` | Images larger than this width will be resized to this width. | `int` | | -| `--max-height` | Images larger than this height will be resized to this height. | `int` | | -| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` | -| `--image-format` | Images will be saved under this format. | `str` | `.jpg` | -| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` | -| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` | +| Parameter | Description | Type | Default | +| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- | +| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | | +| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | | +| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` | +| `--output` | Folder where the data will be generated. | `pathlib.Path` | | +| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. 
Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
+| `--unknown-token` | Token used to replace characters in the validation/test sets that are not present in the training set. | `str` | `⁇` |
+| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
+| `--train-folder` | ID of the training folder to extract from Arkindex. | `uuid` | |
+| `--val-folder` | ID of the validation folder to extract from Arkindex. | `uuid` | |
+| `--test-folder` | ID of the testing folder to extract from Arkindex. | `uuid` | |
+| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
+| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
+| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
+| `--image-format` | Images will be saved in this format. | `str` | `.jpg` |
+| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
+| `--subword-vocab-size` | Size of the vocabulary of the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |

The `--tokens` argument expects a YAML-formatted file with a specific format: a list of entries, each describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).

@@ -75,15 +75,44 @@ teklia-dan dataset extract \
    --tokens tokens.yml
```
-If there is no end token, it is possible to define the characters to keep with the `--entity-separators` parameter:
+If the model should predict only the entities and not the text surrounding them, the `--entity-separators` parameter can be used to list the only characters allowed in the transcription outside of entities. Only one of them will be kept between two entities; the priority follows the order in which the characters are given.
-```shell
-teklia-dan dataset extract \
-    [...] \
-    --entity-separators $'\n' " "
-```
+Here is an example of a transcription with entities, on two lines:
+
+<div class="entities-block highlight">
+  The
+  <span type="adj">great</span>
+  king
+  <span type="name">Charles</span>
+  III has eaten <br />with
+  <span type="person">us</span>
+  .
+</div>
+
+Here is the extraction with `--entity-separators=" "`:
+
+<div class="entities-block highlight">
+  <span type="adj">great</span>
+  <span type="name">Charles</span>
+  <span type="person">us</span>
+</div>
+
+Here is the extraction with `--entity-separators="\n" " "`:
+
+<div class="entities-block highlight">
+  <span type="adj">great</span>
+  <span type="name">Charles</span>
+  <br />
+  <span type="person">us</span>
+</div>
+
+The order of the arguments is important. If whitespaces take priority over line breaks, i.e. `--entity-separators=" " "\n"`, the extraction will result in:
-If several separators follow each other, it will keep only one, ideally a line break if there is one, otherwise a space.
If you change the order of the `--entity-separators` parameters, then it will keep a space if there is one, otherwise a line break. +<div class="entities-block highlight"> + <span type="adj">great</span> + <span type="name">Charles</span> + <span type="person">us</span> +</div> ### HTR from multiple element types diff --git a/mkdocs.yml b/mkdocs.yml index a4c58007032dff3aa61d4f3505345c6e137e654b..1a9e00fae86e743ae59bb0ccbbe5398adea96882 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -141,3 +141,6 @@ extra: - icon: fontawesome/brands/linkedin name: Teklia @ LinkedIn link: https://www.linkedin.com/company/teklia + +extra_css: + - css/ner.css diff --git a/requirements.txt b/requirements.txt index 1ac038d1abba1c78129d9a37c01b52d32e4177b7..445189ae68ef6febce26b477a5360a308208a7e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ editdistance==0.6.2 flashlight-text==0.0.4 imageio==2.26.1 imagesize==1.4.1 +lxml==4.9.3 mdutils==1.6.0 nltk==3.8.1 numpy==1.24.3 diff --git a/tests/conftest.py b/tests/conftest.py index 1da4a88b3916b99bd6d6b6eb83c1492e4f16da7c..e629efb9c840f52c3273900b51b3665387b6ad8b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -178,6 +178,71 @@ def mock_database(tmp_path_factory): # Create folders create_element(id="root") + # Create data for entities extraction tests + # Create transcription + transcription = Transcription.create( + id="tr-with-entities", + text="The great king Charles III has eaten \nwith us.", + element=Element.select().first(), + ) + + WorkerVersion.bulk_create( + [ + WorkerVersion( + id=f"{nestation}-id", + slug=nestation, + name=nestation, + repository_url="http://repository/url", + revision="main", + type="worker", + ) + for nestation in ("nested", "non-nested") + ] + ) + + entities = [ + # Non-nested entities + { + "worker_version": "non-nested-id", + "type": "adj", + "name": "great", + "offset": 4, + }, + { + "worker_version": "non-nested-id", + "type": "name", + "name": "Charles", + "offset": 15, + }, + { + "worker_version": "non-nested-id", + "type": "person", + "name": "us", + "offset": 43, + }, + # Nested entities + { + "worker_version": "nested-id", + "type": "fullname", + "name": "Charles III", + "offset": 15, + }, + { + "worker_version": "nested-id", + "type": "name", + "name": "Charles", + "offset": 15, + }, + { + "worker_version": "nested-id", + "type": "person", + "name": "us", + "offset": 43, + }, + ] + for entity in entities: + create_transcription_entity(transcription=transcription, **entity) + return database_path diff --git a/tests/data/entities.yml b/tests/data/entities.yml index 3b85f005878884c8199020ebbcd28b8f3bf0073a..c690cbc0d122060622d8172ef1c57ff0700765ba 100644 --- a/tests/data/entities.yml +++ b/tests/data/entities.yml @@ -1,5 +1,9 @@ --- entities: +- adj - birthdate - firstname +- fullname +- name +- person - surname diff --git a/tests/data/tokens/end_tokens.yml b/tests/data/tokens/end_tokens.yml index 78a68cab7503bb0583921bd5253d1679662a7dcc..c55ca9946d22e52e0259045727f988240ba8c812 100644 --- a/tests/data/tokens/end_tokens.yml +++ b/tests/data/tokens/end_tokens.yml @@ -1,10 +1,22 @@ --- -birthdate: +adj: start: â’¶ end: â’· -firstname: +birthdate: start: â’¸ end: â’¹ -surname: +firstname: start: â’º end: â’» +fullname: + start: â’¼ + end: â’½ +name: + start: â’¾ + end: â’¿ +person: + start: â“€ + end: â“ +surname: + start: â“‚ + end: Ⓝ diff --git a/tests/data/tokens/no_end_tokens.yml b/tests/data/tokens/no_end_tokens.yml index 
08c5790004c038281b1e13fcbddfec032b77d0d7..49ab427c21e83e276130886c15fec8596e2fd7fb 100644 --- a/tests/data/tokens/no_end_tokens.yml +++ b/tests/data/tokens/no_end_tokens.yml @@ -1,10 +1,22 @@ --- -birthdate: +adj: start: â’¶ end: '' -firstname: +birthdate: start: â’· end: '' -surname: +firstname: start: â’¸ end: '' +fullname: + start: â’¹ + end: '' +name: + start: â’º + end: '' +person: + start: â’» + end: '' +surname: + start: â’¼ + end: '' diff --git a/tests/test_db.py b/tests/test_db.py index e22d472b702bf8718c86c4832cc0a99cd8a9908e..60fa969bc7868959b22a4fa2aa95076b992bbf75 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -72,11 +72,15 @@ def test_get_transcriptions(worker_version, mock_database): @pytest.mark.parametrize("worker_version", (False, "worker_version_id", None)) -def test_get_transcription_entities(worker_version, mock_database): +@pytest.mark.parametrize( + "supported_types", (["surname"], ["surname", "firstname", "birthdate"]) +) +def test_get_transcription_entities(worker_version, mock_database, supported_types): transcription_id = "train-page_1-line_1" + (worker_version or "") entities = get_transcription_entities( transcription_id=transcription_id, entity_worker_version=worker_version, + supported_types=supported_types, ) expected_entities = [ @@ -99,23 +103,18 @@ def test_get_transcription_entities(worker_version, mock_database): "length": 7, }, ] + + expected_entities = list( + filter(lambda ent: ent["type"] in supported_types, expected_entities) + ) for entity in expected_entities: if worker_version: entity["name"] = entity["name"].lower() - entity["worker_version_id"] = worker_version or None + entity["worker_version"] = worker_version or None assert ( sorted( - [ - { - "name": transcription_entity.name, - "type": transcription_entity.type, - "offset": transcription_entity.offset, - "length": transcription_entity.length, - "worker_version_id": transcription_entity.worker_version, - } - for transcription_entity in entities - ], + entities, key=itemgetter("offset"), ) == expected_entities diff --git a/tests/test_extract.py b/tests/test_extract.py index d56484c37d55dcdcf90ab737d77eac9e7bd7d7b1..265649d043e758abf4176a500be9db9886d41458 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -12,17 +12,17 @@ from unittest.mock import patch import pytest from PIL import Image, ImageChops -from arkindex_export import Element, Transcription +from arkindex_export import Element, Transcription, TranscriptionEntity from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor +from dan.datasets.extract.db import get_transcription_entities from dan.datasets.extract.exceptions import ( - NoEndTokenError, NoTranscriptionError, UnknownTokenInText, ) from dan.datasets.extract.utils import ( EntityType, download_image, - insert_token, + entities_to_xml, normalize_linebreaks, normalize_spaces, ) @@ -69,95 +69,6 @@ def test_get_iiif_size_arg(max_width, max_height, width, height, resize): ) -@pytest.mark.parametrize( - "text,offset,length,expected", - ( - ("n°1 16 janvier 1611", 0, 3, "ⓘn°1â’¾"), - ("ⓘn°1â’¾ 16 janvier 1611", 6, 15, "ⓘ16 janvier 1611â’¾"), - ), -) -def test_insert_token(text, offset, length, expected): - assert ( - insert_token(text, EntityType(start="ⓘ", end="â’¾"), offset, length) == expected - ) - - -def test_reconstruct_text_no_end_token_error(): - arkindex_extractor = ArkindexExtractor(entity_separators=[]) - arkindex_extractor.tokens = { - "X": EntityType(start="ⓧ"), - } - with pytest.raises(NoEndTokenError, match="Label `X` has no 
end token."): - arkindex_extractor.reconstruct_text( - "n°1 x 16 janvier 1611", - [ - Entity( - offset=0, - length=3, - type="X", - value="n°1", - ), - ], - ) - - -@pytest.mark.parametrize( - "entity_separators,tokens,expected", - ( - # Whole text... - # ... + All tokens - ([], TOKENS, "â“Ÿn°1â“… x â““16 janvier 1611â’¹\nâ“MichelⓃ"), - # ... + 1rst and 2nd tokens - ([], filter_tokens(["P", "D"]), "â“Ÿn°1â“… x â““16 janvier 1611â’¹\nMichel"), - # ... + 1rst and 3rd tokens - ([], filter_tokens(["P", "N"]), "â“Ÿn°1â“… x 16 janvier 1611\nâ“MichelⓃ"), - # ... + 2nd and 3rd tokens - ([], filter_tokens(["D", "N"]), "n°1 x â““16 janvier 1611â’¹\nâ“MichelⓃ"), - # Only entities... - # ... + All tokens - (["\n", " "], TOKENS, "â“Ÿn°1â“… â““16 janvier 1611â’¹\nâ“MichelⓃ"), - # ... + 1rst and 2nd tokens - (["\n", " "], filter_tokens(["P", "D"]), "â“Ÿn°1â“… â““16 janvier 1611â’¹"), - # ... + 1rst and 3rd tokens - (["\n", " "], filter_tokens(["P", "N"]), "â“Ÿn°1â“…\nâ“MichelⓃ"), - # ... + 2nd and 3rd tokens - (["\n", " "], filter_tokens(["D", "N"]), "â““16 janvier 1611â’¹\nâ“MichelⓃ"), - ), -) -@pytest.mark.parametrize("text_before", ("", "text before ")) -@pytest.mark.parametrize("text_after", ("", " text after")) -def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after): - arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators) - arkindex_extractor.tokens = tokens - assert arkindex_extractor.reconstruct_text( - text_before + "n°1 x 16 janvier 1611\nMichel" + text_after, - [ - Entity( - offset=0 + len(text_before), - length=3, - type="P", - value="n°1", - ), - Entity( - offset=6 + len(text_before), - length=15, - type="D", - value="16 janvier 1611", - ), - Entity( - offset=22 + len(text_before), - length=6, - type="N", - value="Michel", - ), - ], - ) == ( - (text_before if not entity_separators else "") - + expected - + (text_after if not entity_separators else "") - ) - - @pytest.mark.parametrize( "text,trimmed", ( @@ -194,104 +105,6 @@ def test_normalize_linebreaks(text, trimmed): assert normalize_linebreaks(text) == trimmed -@pytest.mark.parametrize( - "entity_separators", - ( - # Whole text - [], - # Only entities - [" ", "\n"], - ), -) -@pytest.mark.parametrize("text_before", ("", "text before ")) -@pytest.mark.parametrize("text_after", ("", " text after")) -def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after): - arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators) - arkindex_extractor.tokens = TOKENS - assert arkindex_extractor.reconstruct_text( - text_before + "LouisXIV" + text_after, - [ - Entity( - offset=0 + len(text_before), - length=5, - type="P", - value="Louis", - ), - Entity( - offset=5 + len(text_before), - length=3, - type="I", - value="XIV", - ), - ], - ) == ( - (text_before if not entity_separators else "") - + "â“ŸLouisⓅⓘXIVâ’¾" - + (text_after if not entity_separators else "") - ) - - -@pytest.mark.parametrize("text_before", ("", "text before ")) -@pytest.mark.parametrize("text_after", ("", " text after")) -def test_reconstruct_text_several_separators(text_before, text_after): - arkindex_extractor = ArkindexExtractor(entity_separators=["\n", " "]) - arkindex_extractor.tokens = TOKENS - # Keep "\n" instead of " " - assert ( - arkindex_extractor.reconstruct_text( - text_before + "King\nLouis XIV" + text_after, - [ - Entity( - offset=0 + len(text_before), - length=4, - type="D", - value="King", - ), - Entity( - offset=11 + len(text_before), - length=3, - type="I", - value="XIV", - 
), - ], - ) - == "â““Kingâ’¹\nⓘXIVâ’¾" - ) - - -@pytest.mark.parametrize("joined", (True, False)) -@pytest.mark.parametrize("text_before", ("", "text before ")) -@pytest.mark.parametrize("text_after", ("", " text after")) -def test_reconstruct_text_only_start_token(joined, text_before, text_after): - separator = " " if not joined else "" - - arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"]) - arkindex_extractor.tokens = { - "P": EntityType(start="â“Ÿ"), - "I": EntityType(start="ⓘ"), - } - assert ( - arkindex_extractor.reconstruct_text( - text_before + "Louis" + separator + "XIV" + text_after, - [ - Entity( - offset=0 + len(text_before), - length=5, - type="P", - value="Louis", - ), - Entity( - offset=5 + len(separator) + len(text_before), - length=3, - type="I", - value="XIV", - ), - ], - ) - == "â“ŸLouis" + separator + "ⓘXIV" - ) - - def test_process_element_unknown_token_in_text_error(mock_database, tmp_path): output = tmp_path / "extraction" arkindex_extractor = ArkindexExtractor(output=output) @@ -473,7 +286,8 @@ def test_extract( element_type=["text_line"], parent_element_type="double_page", output=output, - entity_separators=[" "] if load_entities else None, + # Keep the whole text + entity_separators=None, tokens=tokens_path if load_entities else None, transcription_worker_version=transcription_entities_worker_version, entity_worker_version=transcription_entities_worker_version @@ -803,3 +617,108 @@ def test_empty_transcription(allow_empty, mock_database): else: with pytest.raises(NoTranscriptionError): extractor.extract_transcription(element_no_transcription) + + +@pytest.mark.parametrize( + "nestation, xml_output, separators", + ( + # Non-nested + ( + "non-nested-id", + "<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>", + None, + ), + # Non-nested no text between entities + ( + "non-nested-id", + "<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>", + ["\n", " "], + ), + # Nested + ( + "nested-id", + "<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>", + None, + ), + # Nested no text between entities + ( + "nested-id", + "<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>", + ["\n", " "], + ), + ), +) +def test_entities_to_xml(mock_database, nestation, xml_output, separators): + transcription = Transcription.get_by_id("tr-with-entities") + assert ( + entities_to_xml( + text=transcription.text, + predictions=get_transcription_entities( + transcription_id="tr-with-entities", + entity_worker_version=nestation, + supported_types=["name", "fullname", "person", "adj"], + ), + entity_separators=separators, + ) + == xml_output + ) + + +@pytest.mark.parametrize( + "supported_entities, xml_output, separators", + ( + # <adj> missing, no text between entities + ( + ["name", "person"], + "<root><name>Charles</name>\n<person>us</person></root>", + ["\n", " "], + ), + # <adj> missing, text between entities + ( + ["name", "person"], + "<root>The great king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>", + None, + ), + ), +) +def test_entities_to_xml_partial_entities( + mock_database, supported_entities, xml_output, separators +): + transcription = Transcription.get_by_id("tr-with-entities") + assert ( + entities_to_xml( + text=transcription.text, + predictions=get_transcription_entities( + transcription_id="tr-with-entities", + entity_worker_version="non-nested-id", + 
supported_types=supported_entities, + ), + entity_separators=separators, + ) + == xml_output + ) + + +@pytest.mark.parametrize( + "transcription", + ( + "Something\n", + "Something\r", + "Something\t", + 'Something"', + "Something'", + "Something<", + "Something>", + "Something&", + ), +) +def test_entities_to_xml_no_encode(transcription): + assert ( + entities_to_xml( + text=transcription, + # Empty queryset + predictions=TranscriptionEntity.select().where(TranscriptionEntity.id == 0), + entity_separators=None, + ) + == f"<root>{transcription}</root>" + )
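For reference, the new extraction path chains `entities_to_xml` with `get_translation_map` instead of the removed `reconstruct_text`/`insert_token` logic. Below is a minimal sketch of that flow (not part of the patch); the entity dicts and circled start/end tokens are illustrative placeholders mirroring the `tr-with-entities` fixture and the `test_entities_to_xml` expectation above.

```python
from dan.datasets.extract.utils import entities_to_xml, get_translation_map
from dan.utils import EntityType

text = "The great king Charles III has eaten \nwith us."
# Shaped like the rows returned by get_transcription_entities(), ordered by offset
entities = [
    {"type": "adj", "name": "great", "offset": 4, "length": 5, "worker_version": None},
    {"type": "name", "name": "Charles", "offset": 15, "length": 7, "worker_version": None},
    {"type": "person", "name": "us", "offset": 43, "length": 2, "worker_version": None},
]
# Start/end tokens as parsed from a --tokens YAML file (placeholder characters)
tokens = {
    "adj": EntityType(start="ⓐ", end="Ⓐ"),
    "name": EntityType(start="ⓝ", end="Ⓝ"),
    "person": EntityType(start="ⓟ", end="Ⓟ"),
}

# 1. Serialize the transcription and its entities as XML, keeping only one
#    separator between entities (line breaks take priority over spaces here)
xml = entities_to_xml(text, entities, entity_separators=["\n", " "])
assert xml == "<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>"

# 2. Replace the XML tags with the configured tokens, as ArkindexExtractor.translate() does
tokenized = xml
for pattern, repl in get_translation_map(tokens).items():
    tokenized = tokenized.replace(pattern, repl)

print(tokenized)  # "ⓐgreatⒶ ⓝCharlesⓃ" + newline + "ⓟusⓅ"
```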