Commit 5f99f175 authored by Manon Blanco

Merge branch 'poc-entities-xml' into 'main'

Refactor entities extraction with lxml

See merge request !316
parents a5695ae7 ab67a16a
@@ -106,8 +106,6 @@ def add_extract_parser(subcommands) -> None:
         If several separators follow each other, keep only the first to appear in the list.
         Do not give any arguments to keep the whole text.
         """,
-        required=False,
-        default=list(map(validate_char, ("\n", " "))),
     )
     parser.add_argument(
         "--unknown-token",
......
...@@ -24,7 +24,6 @@ from dan.datasets.extract.db import ( ...@@ -24,7 +24,6 @@ from dan.datasets.extract.db import (
) )
from dan.datasets.extract.exceptions import ( from dan.datasets.extract.exceptions import (
ImageDownloadError, ImageDownloadError,
NoEndTokenError,
NoTranscriptionError, NoTranscriptionError,
ProcessingError, ProcessingError,
UnknownTokenInText, UnknownTokenInText,
...@@ -32,13 +31,14 @@ from dan.datasets.extract.exceptions import ( ...@@ -32,13 +31,14 @@ from dan.datasets.extract.exceptions import (
from dan.datasets.extract.utils import ( from dan.datasets.extract.utils import (
Tokenizer, Tokenizer,
download_image, download_image,
entities_to_xml,
get_bbox, get_bbox,
get_translation_map,
get_vocabulary, get_vocabulary,
insert_token,
normalize_linebreaks, normalize_linebreaks,
normalize_spaces, normalize_spaces,
) )
from dan.utils import EntityType, LMTokenMapping, parse_tokens from dan.utils import LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import ( from line_image_extractor.image_utils import (
BoundingBox, BoundingBox,
@@ -87,7 +87,7 @@ class ArkindexExtractor:
         self.output = output
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
-        self.tokens = parse_tokens(tokens) if tokens else None
+        self.tokens = parse_tokens(tokens) if tokens else {}
         self.transcription_worker_version = transcription_worker_version
         self.entity_worker_version = entity_worker_version
         self.max_width = max_width
@@ -107,6 +107,9 @@ class ArkindexExtractor:
         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []

+        # NER extraction
+        self.translation_map: Dict[str, str] | None = get_translation_map(self.tokens)
+
     def get_iiif_size_arg(self, width: int, height: int) -> str:
         if (self.max_width is None or width <= self.max_width) and (
             self.max_height is None or height <= self.max_height
@@ -136,67 +139,13 @@ class ArkindexExtractor:
             image_url=image_url, bbox=get_bbox(polygon), size=size
         )

-    def _keep_char(self, char: str) -> bool:
-        # Keep all text by default if no separator was given
-        return not self.entity_separators or char in self.entity_separators
-
-    def reconstruct_text(self, full_text: str, entities) -> str:
-        """
-        Insert tokens delimiting the start/end of each entity on the transcription.
-        """
-        text, text_offset = "", 0
-        for entity in entities:
-            # Text before entity
-            text += "".join(
-                filter(self._keep_char, full_text[text_offset : entity.offset])
-            )
-
-            entity_type: EntityType = self.tokens.get(entity.type)
-            if not entity_type:
-                logger.warning(
-                    f"Label `{entity.type}` is missing in the NER configuration."
-                )
-            # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends
-            elif not entity_type.end and not self.entity_separators:
-                raise NoEndTokenError(entity.type)
-
-            # Entity text:
-            # - with tokens if there is an entity_type
-            # - without tokens if there is no entity_type but we want to keep the whole text
-            if entity_type or not self.entity_separators:
-                text += insert_token(
-                    full_text,
-                    entity_type,
-                    offset=entity.offset,
-                    length=entity.length,
-                )
-            text_offset = entity.offset + entity.length
-
-        # Remaining text after the last entity
-        text += "".join(filter(self._keep_char, full_text[text_offset:]))
-
-        if not self.entity_separators or self.keep_spaces:
-            return text
-
-        # Add some clean up to avoid several separators between entities
-        text, full_text = "", text
-        for char in full_text:
-            last_char = text[-1] if len(text) else ""
-
-            # Keep the current character if there are no two consecutive separators
-            if (
-                char not in self.entity_separators
-                or last_char not in self.entity_separators
-            ):
-                text += char
-            # If several separators follow each other, keep only one according to the given order
-            elif self.entity_separators.index(char) < self.entity_separators.index(
-                last_char
-            ):
-                text = text[:-1] + char
-
-        # Remove separators at the beginning and end of text
-        return text.strip("".join(self.entity_separators))
+    def translate(self, text: str):
+        """
+        Use the translation map to replace XML tags with the actual tokens
+        """
+        for pattern, repl in self.translation_map.items():
+            text = text.replace(pattern, repl)
+        return text

     def extract_transcription(self, element: Element):
         """
@@ -217,9 +166,16 @@ class ArkindexExtractor:
             return transcription.text.strip()

         entities = get_transcription_entities(
-            transcription.id, self.entity_worker_version
+            transcription.id,
+            self.entity_worker_version,
+            supported_types=list(self.tokens),
         )

-        return self.reconstruct_text(transcription.text, entities)
+        return self.translate(
+            entities_to_xml(
+                transcription.text, entities, entity_separators=self.entity_separators
+            )
+        )

     def get_image(
         self,
......
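The `translate` step above is a plain string replacement driven by `get_translation_map`. Here is a minimal sketch of how the two fit together; it is not part of the diff, and the ⓝ/Ⓝ and ⓟ/Ⓟ tokens are hypothetical placeholders:

```python
from dan.datasets.extract.utils import get_translation_map
from dan.utils import EntityType

# Hypothetical NER tokens, in the shape returned by parse_tokens()
tokens = {
    "name": EntityType(start="ⓝ", end="Ⓝ"),
    "person": EntityType(start="ⓟ", end="Ⓟ"),
}

# {"<root>": "", "</root>": "", "<name>": "ⓝ", "</name>": "Ⓝ", "<person>": "ⓟ", "</person>": "Ⓟ"}
translation_map = get_translation_map(tokens)

# Same replacement loop as ArkindexExtractor.translate()
text = "<root><name>Charles</name>\n<person>us</person></root>"
for pattern, repl in translation_map.items():
    text = text.replace(pattern, repl)

assert text == "ⓝCharlesⓃ\nⓟusⓅ"
```

Because `<root>` and `</root>` map to empty strings, the wrapping element added by `entities_to_xml` disappears from the final transcription.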
 # -*- coding: utf-8 -*-
-from typing import List, Union
+from typing import List, Optional, Union

 from arkindex_export import Image
 from arkindex_export.models import (
@@ -60,7 +60,9 @@ def get_transcriptions(
 def get_transcription_entities(
-    transcription_id: str, entity_worker_version: Union[str, bool]
+    transcription_id: str,
+    entity_worker_version: Optional[Union[str, bool]],
+    supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
     Retrieve transcription entities from an SQLite export of an Arkindex corpus
@@ -75,7 +77,10 @@ def get_transcription_entities(
         )
         .join(Entity, on=TranscriptionEntity.entity)
         .join(EntityType, on=Entity.type)
-        .where((TranscriptionEntity.transcription == transcription_id))
+        .where(
+            TranscriptionEntity.transcription == transcription_id,
+            EntityType.name.in_(supported_types),
+        )
     )

     if entity_worker_version is not None:
@@ -85,4 +90,6 @@ def get_transcription_entities(
             )
         )

-    return query.order_by(TranscriptionEntity.offset).namedtuples()
+    return query.order_by(
+        TranscriptionEntity.offset, TranscriptionEntity.length.desc()
+    ).dicts()
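The new `order_by` clause matters for nested entities: sorting by offset ascending and then length descending guarantees that an enclosing entity always comes before the entities it contains, which is the order `entities_to_xml` relies on to detect nesting. A small illustrative sketch, with hand-written entity dicts (not taken from the diff):

```python
# Hypothetical entity rows, in the dict shape returned by .dicts()
entities = [
    {"type": "name", "name": "Charles", "offset": 15, "length": 7},
    {"type": "fullname", "name": "Charles III", "offset": 15, "length": 11},
    {"type": "person", "name": "us", "offset": 43, "length": 2},
]

# Same ordering as the query: offset ascending, then length descending
entities.sort(key=lambda entity: (entity["offset"], -entity["length"]))

# The enclosing entity now comes before the entity nested inside it
assert [entity["type"] for entity in entities] == ["fullname", "name", "person"]
```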
@@ -52,18 +52,3 @@ class UnknownTokenInText(ElementProcessingError):
     def __str__(self) -> str:
         return f"Unknown token found in the transcription text of element ({self.element_id})"
-
-
-class NoEndTokenError(ProcessingError):
-    """
-    Raised when the specified label has no end token and there is potentially additional text around the labels
-    """
-
-    label: str
-
-    def __init__(self, label: str, *args: object) -> None:
-        super().__init__(*args)
-        self.label = label
-
-    def __str__(self) -> str:
-        return f"Label `{self.label}` has no end token."
@@ -7,10 +7,11 @@ from dataclasses import dataclass, field
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union

 import requests
 import sentencepiece as spm
+from lxml.etree import Element, SubElement, tostring
 from nltk import wordpunct_tokenize
 from PIL import Image, ImageOps
 from tenacity import (
@@ -20,6 +21,7 @@ from tenacity import (
     wait_exponential,
 )

+from arkindex_export import TranscriptionEntity
 from dan.utils import EntityType, LMTokenMapping

 logger = logging.getLogger(__name__)
@@ -31,6 +33,14 @@ DOWNLOAD_TIMEOUT = (30, 60)
 TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
 TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
# Some characters are encoded during XML serialization but we don't want them encoded in the final output
ENCODING_MAP = {
"&#13;": "\r",
"&lt;": "<",
"&gt;": ">",
"&amp;": "&",
}
 def _retry_log(retry_state, *args, **kwargs):
     logger.warning(
@@ -83,20 +93,6 @@ def download_image(url):
     return image
def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
"""
Insert the given tokens at the right position in the text
"""
return (
# Starting token
(entity_type.start if entity_type else "")
# Entity
+ text[offset : offset + length]
# End token
+ (entity_type.end if entity_type else "")
)
 def normalize_linebreaks(text: str) -> str:
     """
     Remove begin/ending linebreaks.
@@ -248,3 +244,161 @@ class Tokenizer:
         :param text: Text to be encoded.
         """
         return map(self.mapping.encode_token, text)
def slugify(text: str):
"""
Replace spaces in the text with underscores so it can be used as an XML tag name.
"""
return text.replace(" ", "_")
def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
if not tokens:
return
translation_map = {
# Roots
"<root>": "",
"</root>": "",
}
# Tokens
for entity_name, token_type in tokens.items():
translation_map[f"<{slugify(entity_name)}>"] = token_type.start
translation_map[f"</{slugify(entity_name)}>"] = token_type.end
return translation_map
@dataclass
class XMLEntity:
type: str
name: str
offset: int
length: int
worker_version: str
children: List["XMLEntity"] = field(default_factory=list)
@property
def end(self) -> int:
return self.offset + self.length
def add_child(self, child: TranscriptionEntity):
self.children.append(
XMLEntity(
type=child["type"],
name=child["name"],
offset=child["offset"] - self.offset,
length=child["length"],
worker_version=child["worker_version"],
)
)
def insert(self, parent: Element):
e = SubElement(parent, slugify(self.type))
if not self.children:
# No children
e.text = self.name
return
offset = 0
for child in self.children:
# Add text before entity
portion_before = self.name[offset : child.offset]
offset += len(portion_before)
if len(e):
e[-1].tail = portion_before
else:
e.text = portion_before
child.insert(e)
offset += child.length
# Text after the last entity
e[-1].tail = self.name[self.children[-1].end : self.end]
def entities_to_xml(
text: str,
predictions: List[TranscriptionEntity],
entity_separators: Optional[List[str]] = None,
) -> str:
"""Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag.
Its type will be used to name the tag.
:param text: The text of the transcription
:param predictions: The list of entities linked to the transcription
:param entity_separators: When provided, the text between entities is not kept; only the first of these separators found in that text is added instead. The order of the list sets the priority. Defaults to None
:return: The representation of the transcription in XML format
"""
def _find_separator(transcription: str) -> str:
"""
Find the first entity separator in the provided transcription.
"""
for separator in entity_separators:
if separator in transcription:
return separator
return ""
def add_portion(entity_offset: Optional[int] = None):
"""
Add the portion of text between entities either:
- after the last node, if there is one before
- on this node
If we remove the text between entities, we keep one of the separators provided. Order matters.
"""
portion = text[offset:entity_offset]
if entity_separators:
# Remove the text except the first entity_separator encountered
portion = _find_separator(portion)
if len(root):
root[-1].tail = portion
else:
root.text = portion
entities = iter(predictions)
# This will mark the ending position of the first-level of entities
last_end = None
parsed: List[XMLEntity] = []
for entity in entities:
# First entity is not inside any other
# If the entity starts after the previous top-level entity ends, it is not nested
if not last_end or entity["offset"] >= last_end:
parsed.append(XMLEntity(**entity))
last_end = entity["offset"] + entity["length"]
continue
# Nested entity
parsed[-1].add_child(entity)
# XML export
offset = 0
root = Element("root")
for entity in parsed:
add_portion(entity.offset)
entity.insert(root)
offset = entity.end
# Add text after last entity
add_portion()
# Cleanup separators introduced when text was removed
if entity_separators:
characters = "".join(entity_separators)
root.text = root.text.lstrip(characters)
# Strip trailing spaces on last child
root[-1].tail = root[-1].tail.rstrip(characters)
encoded_transcription = tostring(root, encoding="utf-8").decode()
for pattern, repl in ENCODING_MAP.items():
encoded_transcription = encoded_transcription.replace(pattern, repl)
return encoded_transcription
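Taken together, `entities_to_xml` builds an XML representation of the transcription and the translation map turns it back into token-delimited text. Here is a minimal usage sketch of the XML step (not part of the merge request), reusing the sentence from the test fixtures below; the entity dicts are written by hand in the shape returned by `get_transcription_entities`:

```python
from dan.datasets.extract.utils import entities_to_xml

text = "The great king Charles III has eaten \nwith us."

# Hand-written entity rows, in the dict shape returned by get_transcription_entities()
predictions = [
    {"type": "fullname", "name": "Charles III", "offset": 15, "length": 11, "worker_version": "nested-id"},
    {"type": "name", "name": "Charles", "offset": 15, "length": 7, "worker_version": "nested-id"},
    {"type": "person", "name": "us", "offset": 43, "length": 2, "worker_version": "nested-id"},
]

# Keep the whole transcription: nested entities become nested XML tags
assert entities_to_xml(text, predictions) == (
    "<root>The great king <fullname><name>Charles</name> III</fullname>"
    " has eaten \nwith <person>us</person>.</root>"
)

# Keep only the entities, joined by the first separator found between them
assert entities_to_xml(text, predictions, entity_separators=["\n", " "]) == (
    "<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>"
)
```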
.entities-block {
background-color: var(--md-code-bg-color);
padding: .7720588235em 1.1764705882em;
}
/* Light mode */
body[data-md-color-scheme="default"] .entities-block > span[type=adj] {
background: #7fffd4;
}
body[data-md-color-scheme="default"] .entities-block > span[type=name] {
background: #83ff7f;
}
body[data-md-color-scheme="default"] .entities-block > span[type=person] {
background: #ffc17f;
}
/* Dark mode */
body[data-md-color-scheme="slate"] .entities-block > span[type=adj] {
background: #7fffd48c;
}
body[data-md-color-scheme="slate"] .entities-block > span[type=name] {
background: #83ff7f63;
}
body[data-md-color-scheme="slate"] .entities-block > span[type=person] {
background: #ffc17f6c;
}
.entities-block > span {
padding: 0.1em;
border-radius: 0.35em;
}
.entities-block > span::after {
content: ' ' attr(type);
font-size: 0.8em;
font-weight: bold;
vertical-align: middle;
}
@@ -11,26 +11,26 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind

If an image download fails for whatever reason, it won't appear in the transcriptions file. The reason will be printed to stdout at the end of the process. Before trying to download the image, it checks that it wasn't downloaded previously. It is thus safe to run this command twice if a few images failed.

| Parameter | Description | Type | Default |
| --------- | ----------- | ---- | ------- |
| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
-| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str` | `["\n", " "]` (see [dedicated section](#examples)) |
+| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
| `--unknown-token` | Token used to replace characters in the validation/test sets that are not included in the training set. | `str` | `⁇` |
| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
| `--train-folder` | ID of the training folder to extract from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to extract from Arkindex. | `uuid` | |
| `--test-folder` | ID of the testing folder to extract from Arkindex. | `uuid` | |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
| `--image-format` | Images will be saved under this format. | `str` | `.jpg` |
| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |

The `--tokens` argument expects a YAML-formatted file listing one entry per NER entity: the entity label is the key of a dict giving its starting and ending tokens. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
@@ -75,15 +75,44 @@ teklia-dan dataset extract \
    --tokens tokens.yml
```

-If there is no end token, it is possible to define the characters to keep with the `--entity-separators` parameter:
-
-```shell
-teklia-dan dataset extract \
-    [...] \
-    --entity-separators $'\n' " "
-```

If the model should predict entities only, and not the text surrounding them, use the `--entity-separators` parameter to list the only characters allowed in the transcription outside of entities. Only one of them is kept between two consecutive entities; priority follows the order in which the characters are given.

Here is an example of a transcription with entities, on two lines:

<div class="entities-block highlight">
The
<span type="adj">great</span>
king
<span type="name">Charles</span>
III has eaten <br />with
<span type="person">us</span>
.
</div>
Here is the extraction with `--entity-separators=" "`:
<div class="entities-block highlight">
<span type="adj">great</span>
<span type="name">Charles</span>
<span type="person">us</span>
</div>
Here is the extraction with `--entity-separators="\n" " "`:
<div class="entities-block highlight">
<span type="adj">great</span>
<span type="name">Charles</span>
<br />
<span type="person">us</span>
</div>
-If several separators follow each other, it will keep only one, ideally a line break if there is one, otherwise a space. If you change the order of the `--entity-separators` parameters, then it will keep a space if there is one, otherwise a line break.

The order of the arguments is important. If whitespace takes priority over line breaks, i.e. `--entity-separators=" " "\n"`, the extraction will result in:

<div class="entities-block highlight">
<span type="adj">great</span>
<span type="name">Charles</span>
<span type="person">us</span>
</div>
### HTR from multiple element types
......
@@ -141,3 +141,6 @@ extra:
     - icon: fontawesome/brands/linkedin
       name: Teklia @ LinkedIn
       link: https://www.linkedin.com/company/teklia
extra_css:
- css/ner.css
@@ -5,6 +5,7 @@ editdistance==0.6.2
 flashlight-text==0.0.4
 imageio==2.26.1
 imagesize==1.4.1
+lxml==4.9.3
 mdutils==1.6.0
 nltk==3.8.1
 numpy==1.24.3
......
@@ -178,6 +178,71 @@ def mock_database(tmp_path_factory):
    # Create folders
    create_element(id="root")
# Create data for entities extraction tests
# Create transcription
transcription = Transcription.create(
id="tr-with-entities",
text="The great king Charles III has eaten \nwith us.",
element=Element.select().first(),
)
WorkerVersion.bulk_create(
[
WorkerVersion(
id=f"{nestation}-id",
slug=nestation,
name=nestation,
repository_url="http://repository/url",
revision="main",
type="worker",
)
for nestation in ("nested", "non-nested")
]
)
entities = [
# Non-nested entities
{
"worker_version": "non-nested-id",
"type": "adj",
"name": "great",
"offset": 4,
},
{
"worker_version": "non-nested-id",
"type": "name",
"name": "Charles",
"offset": 15,
},
{
"worker_version": "non-nested-id",
"type": "person",
"name": "us",
"offset": 43,
},
# Nested entities
{
"worker_version": "nested-id",
"type": "fullname",
"name": "Charles III",
"offset": 15,
},
{
"worker_version": "nested-id",
"type": "name",
"name": "Charles",
"offset": 15,
},
{
"worker_version": "nested-id",
"type": "person",
"name": "us",
"offset": 43,
},
]
for entity in entities:
create_transcription_entity(transcription=transcription, **entity)
    return database_path
......
 ---
 entities:
+  - adj
   - birthdate
   - firstname
+  - fullname
+  - name
+  - person
   - surname
 ---
-birthdate:
+adj:
   start:
   end:
-firstname:
+birthdate:
   start:
   end:
-surname:
+firstname:
   start:
   end:
+fullname:
+  start:
+  end:
+name:
+  start:
+  end:
+person:
+  start:
+  end:
+surname:
+  start:
+  end:
 ---
-birthdate:
+adj:
   start:
   end: ''
-firstname:
+birthdate:
   start:
   end: ''
-surname:
+firstname:
   start:
   end: ''
+fullname:
+  start:
+  end: ''
+name:
+  start:
+  end: ''
+person:
+  start:
+  end: ''
+surname:
+  start:
+  end: ''
@@ -72,11 +72,15 @@ def test_get_transcriptions(worker_version, mock_database):

 @pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
-def test_get_transcription_entities(worker_version, mock_database):
+@pytest.mark.parametrize(
+    "supported_types", (["surname"], ["surname", "firstname", "birthdate"])
+)
+def test_get_transcription_entities(worker_version, mock_database, supported_types):
     transcription_id = "train-page_1-line_1" + (worker_version or "")

     entities = get_transcription_entities(
         transcription_id=transcription_id,
         entity_worker_version=worker_version,
+        supported_types=supported_types,
     )

     expected_entities = [
@@ -99,23 +103,18 @@ def test_get_transcription_entities(worker_version, mock_database):
             "length": 7,
         },
     ]
+    expected_entities = list(
+        filter(lambda ent: ent["type"] in supported_types, expected_entities)
+    )

     for entity in expected_entities:
         if worker_version:
             entity["name"] = entity["name"].lower()
-        entity["worker_version_id"] = worker_version or None
+        entity["worker_version"] = worker_version or None

     assert (
         sorted(
-            [
-                {
-                    "name": transcription_entity.name,
-                    "type": transcription_entity.type,
-                    "offset": transcription_entity.offset,
-                    "length": transcription_entity.length,
-                    "worker_version_id": transcription_entity.worker_version,
-                }
-                for transcription_entity in entities
-            ],
+            entities,
             key=itemgetter("offset"),
         )
         == expected_entities
......
...@@ -12,17 +12,17 @@ from unittest.mock import patch ...@@ -12,17 +12,17 @@ from unittest.mock import patch
import pytest import pytest
from PIL import Image, ImageChops from PIL import Image, ImageChops
from arkindex_export import Element, Transcription from arkindex_export import Element, Transcription, TranscriptionEntity
from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor
from dan.datasets.extract.db import get_transcription_entities
from dan.datasets.extract.exceptions import ( from dan.datasets.extract.exceptions import (
NoEndTokenError,
NoTranscriptionError, NoTranscriptionError,
UnknownTokenInText, UnknownTokenInText,
) )
from dan.datasets.extract.utils import ( from dan.datasets.extract.utils import (
EntityType, EntityType,
download_image, download_image,
insert_token, entities_to_xml,
normalize_linebreaks, normalize_linebreaks,
normalize_spaces, normalize_spaces,
) )
...@@ -69,95 +69,6 @@ def test_get_iiif_size_arg(max_width, max_height, width, height, resize): ...@@ -69,95 +69,6 @@ def test_get_iiif_size_arg(max_width, max_height, width, height, resize):
) )
@pytest.mark.parametrize(
"text,offset,length,expected",
(
("n°1 16 janvier 1611", 0, 3, "ⓘn°1Ⓘ"),
("ⓘn°1Ⓘ 16 janvier 1611", 6, 15, "ⓘ16 janvier 1611Ⓘ"),
),
)
def test_insert_token(text, offset, length, expected):
assert (
insert_token(text, EntityType(start="", end=""), offset, length) == expected
)
def test_reconstruct_text_no_end_token_error():
arkindex_extractor = ArkindexExtractor(entity_separators=[])
arkindex_extractor.tokens = {
"X": EntityType(start=""),
}
with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
arkindex_extractor.reconstruct_text(
"n°1 x 16 janvier 1611",
[
Entity(
offset=0,
length=3,
type="X",
value="n°1",
),
],
)
@pytest.mark.parametrize(
"entity_separators,tokens,expected",
(
# Whole text...
# ... + All tokens
([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
# ... + 1rst and 2nd tokens
([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichel"),
# ... + 1rst and 3rd tokens
([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichelⓃ"),
# ... + 2nd and 3rd tokens
([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
# Only entities...
# ... + All tokens
(["\n", " "], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
# ... + 1rst and 2nd tokens
(["\n", " "], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
# ... + 1rst and 3rd tokens
(["\n", " "], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichelⓃ"),
# ... + 2nd and 3rd tokens
(["\n", " "], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
),
)
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
arkindex_extractor.tokens = tokens
assert arkindex_extractor.reconstruct_text(
text_before + "n°1 x 16 janvier 1611\nMichel" + text_after,
[
Entity(
offset=0 + len(text_before),
length=3,
type="P",
value="n°1",
),
Entity(
offset=6 + len(text_before),
length=15,
type="D",
value="16 janvier 1611",
),
Entity(
offset=22 + len(text_before),
length=6,
type="N",
value="Michel",
),
],
) == (
(text_before if not entity_separators else "")
+ expected
+ (text_after if not entity_separators else "")
)
 @pytest.mark.parametrize(
     "text,trimmed",
     (
@@ -194,104 +105,6 @@ def test_normalize_linebreaks(text, trimmed):
     assert normalize_linebreaks(text) == trimmed
@pytest.mark.parametrize(
"entity_separators",
(
# Whole text
[],
# Only entities
[" ", "\n"],
),
)
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
arkindex_extractor.tokens = TOKENS
assert arkindex_extractor.reconstruct_text(
text_before + "LouisXIV" + text_after,
[
Entity(
offset=0 + len(text_before),
length=5,
type="P",
value="Louis",
),
Entity(
offset=5 + len(text_before),
length=3,
type="I",
value="XIV",
),
],
) == (
(text_before if not entity_separators else "")
+ "ⓟLouisⓅⓘXIVⒾ"
+ (text_after if not entity_separators else "")
)
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_several_separators(text_before, text_after):
arkindex_extractor = ArkindexExtractor(entity_separators=["\n", " "])
arkindex_extractor.tokens = TOKENS
# Keep "\n" instead of " "
assert (
arkindex_extractor.reconstruct_text(
text_before + "King\nLouis XIV" + text_after,
[
Entity(
offset=0 + len(text_before),
length=4,
type="D",
value="King",
),
Entity(
offset=11 + len(text_before),
length=3,
type="I",
value="XIV",
),
],
)
== "ⓓKingⒹ\nⓘXIVⒾ"
)
@pytest.mark.parametrize("joined", (True, False))
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_only_start_token(joined, text_before, text_after):
separator = " " if not joined else ""
arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
arkindex_extractor.tokens = {
"P": EntityType(start=""),
"I": EntityType(start=""),
}
assert (
arkindex_extractor.reconstruct_text(
text_before + "Louis" + separator + "XIV" + text_after,
[
Entity(
offset=0 + len(text_before),
length=5,
type="P",
value="Louis",
),
Entity(
offset=5 + len(separator) + len(text_before),
length=3,
type="I",
value="XIV",
),
],
)
== "ⓟLouis" + separator + "ⓘXIV"
)
 def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
     output = tmp_path / "extraction"
     arkindex_extractor = ArkindexExtractor(output=output)
@@ -473,7 +286,8 @@ def test_extract(
         element_type=["text_line"],
         parent_element_type="double_page",
         output=output,
-        entity_separators=[" "] if load_entities else None,
+        # Keep the whole text
+        entity_separators=None,
         tokens=tokens_path if load_entities else None,
         transcription_worker_version=transcription_entities_worker_version,
         entity_worker_version=transcription_entities_worker_version
@@ -803,3 +617,108 @@ def test_empty_transcription(allow_empty, mock_database):
     else:
         with pytest.raises(NoTranscriptionError):
             extractor.extract_transcription(element_no_transcription)
@pytest.mark.parametrize(
"nestation, xml_output, separators",
(
# Non-nested
(
"non-nested-id",
"<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
None,
),
# Non-nested no text between entities
(
"non-nested-id",
"<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>",
["\n", " "],
),
# Nested
(
"nested-id",
"<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>",
None,
),
# Nested no text between entities
(
"nested-id",
"<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>",
["\n", " "],
),
),
)
def test_entities_to_xml(mock_database, nestation, xml_output, separators):
transcription = Transcription.get_by_id("tr-with-entities")
assert (
entities_to_xml(
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_version=nestation,
supported_types=["name", "fullname", "person", "adj"],
),
entity_separators=separators,
)
== xml_output
)
@pytest.mark.parametrize(
"supported_entities, xml_output, separators",
(
# <adj> missing, no text between entities
(
["name", "person"],
"<root><name>Charles</name>\n<person>us</person></root>",
["\n", " "],
),
# <adj> missing, text between entities
(
["name", "person"],
"<root>The great king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
None,
),
),
)
def test_entities_to_xml_partial_entities(
mock_database, supported_entities, xml_output, separators
):
transcription = Transcription.get_by_id("tr-with-entities")
assert (
entities_to_xml(
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_version="non-nested-id",
supported_types=supported_entities,
),
entity_separators=separators,
)
== xml_output
)
@pytest.mark.parametrize(
"transcription",
(
"Something\n",
"Something\r",
"Something\t",
'Something"',
"Something'",
"Something<",
"Something>",
"Something&",
),
)
def test_entities_to_xml_no_encode(transcription):
assert (
entities_to_xml(
text=transcription,
# Empty queryset
predictions=TranscriptionEntity.select().where(TranscriptionEntity.id == 0),
entity_separators=None,
)
== f"<root>{transcription}</root>"
)