Commit f1c303d4 authored by Manon Blanco

Support joined entities and entity separators

parent 17d90dae
@@ -87,16 +87,18 @@ def add_extract_parser(subcommands) -> None:
         action="store_true",
         help="Extract text with their entities.",
     )
-    parser.add_argument(
-        "--only-entities",
-        action="store_true",
-        help="Remove all text that does not belong to the tokens.",
-    )
     parser.add_argument(
         "--allow-unknown-entities",
         action="store_true",
         help="Ignore entities that do not appear in the list of tokens.",
     )
+    parser.add_argument(
+        "--entity-separators",
+        type=str,
+        nargs="+",
+        help="Remove all text that does not appear in an entity or in the list of given characters. Do not give any arguments to keep the whole text.",
+        required=False,
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
...
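For orientation, a plausible invocation of the updated command might look like the following. The database path, folder UUID, element type, output folder and tokens file are placeholders, only the flag names come from the parser above, and `$'\n'` is Bash syntax for passing a literal newline as a separator:

```shell
teklia-dan dataset extract export.sqlite \
    --parent 11111111-1111-1111-1111-111111111111 \
    --element-type text_line \
    --output ./dataset \
    --load-entities \
    --tokens tokens.yml \
    --entity-separators " " $'\n'
```

Passing `--entity-separators` with one or more characters keeps only the entities and those characters between them; omitting the flag keeps the whole transcription text.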
@@ -2,6 +2,7 @@
 import random
 from collections import defaultdict
+from itertools import pairwise
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
@@ -36,8 +37,6 @@ IMAGES_DIR = "images" # Subpath to the images directory.
 LABELS_DIR = "labels" # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
-EMPTY_CHARS = [" ", "\n", "\t", "\r"]
-
 
 
 class ArkindexExtractor:
     """
@@ -52,7 +51,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
-        only_entities: bool = False,
+        entity_separators: list = [],
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -67,7 +66,7 @@ class ArkindexExtractor:
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
-        self.only_entities = only_entities
+        self.entity_separators = entity_separators
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -110,20 +109,29 @@ class ArkindexExtractor:
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
+
+        def keep_char(char):
+            return keep_all_text or char in self.entity_separators
+
         text, text_offset = "", 0
+        # Keep all text by default if no separator was given
+        keep_all_text = not self.entity_separators
         for entity in entities:
             # Text before entity
-            if not self.only_entities:
-                text += full_text[text_offset : entity.offset]
+            text += "".join(filter(keep_char, full_text[text_offset : entity.offset]))
 
             entity_type: EntityType = self.tokens.get(entity.type)
+            # Unknown entities are not allowed
             if not entity_type and not self.allow_unknown_entities:
                 raise UnknownLabelError(entity.type)
-            if entity_type and not entity_type.end and not self.only_entities:
+            # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends.
+            if entity_type and not entity_type.end and keep_all_text:
                 raise NoEndTokenError(entity.type)
-            # Entity text (optionally with tokens)
-            if entity_type or not self.only_entities:
+            # Entity text:
+            # - with tokens if there is an entity_type,
+            # - without tokens if there is no entity_type but we want to keep the whole text.
+            if entity_type or keep_all_text:
                 text += insert_token(
                     full_text,
                     entity_type,
@@ -132,20 +140,29 @@ class ArkindexExtractor:
                 )
             text_offset = entity.offset + entity.length
 
-            # Keep the exact separator after entity (it can be a space, a line break etc.)
-            if self.only_entities:
-                separator = next(
-                    (char for char in full_text[text_offset:] if char in EMPTY_CHARS),
-                    "",
-                )
-                text += separator
-
-        # Remaining text
-        if not self.only_entities:
-            text += full_text[text_offset:]
+        # Remaining text after the last entity
+        text += "".join(filter(keep_char, full_text[text_offset:]))
+
+        if keep_all_text:
+            return text
 
-        # Remove extra spaces
-        return text.strip("".join(EMPTY_CHARS))
+        # Add some clean up to avoid several separators between entities.
+        text, full_text = "", text
+        for char, next_char in pairwise(full_text):
+            # If several separators follow each other, keep only the last one.
+            if (
+                char not in self.entity_separators
+                or next_char not in self.entity_separators
+            ):
+                text += char
+
+        # Remaining char
+        remaining_char = full_text[-1]
+        if remaining_char not in self.entity_separators:
+            text += remaining_char
+
+        # Remove separators at the beginning and end of text
+        return text.strip("".join(self.entity_separators))
 
     def extract_transcription(self, element: Element):
         """
@@ -160,7 +177,7 @@
         transcription = random.choice(transcriptions)
 
-        if self.load_entities or self.only_entities:
+        if self.load_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -249,7 +266,7 @@ def run(
     output: Path,
     load_entities: bool,
     allow_unknown_entities: bool,
-    only_entities: bool,
+    entity_separators: list,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -302,7 +319,7 @@ def run(
         output=output,
         load_entities=load_entities,
        allow_unknown_entities=allow_unknown_entities,
-        only_entities=only_entities,
+        entity_separators=entity_separators,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
...
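To make the clean-up step in `reconstruct_text` concrete, here is a self-contained sketch of the same idea outside the class. `collapse_separators` is a made-up helper name, not part of the codebase; it only mirrors the `pairwise` logic shown above (collapse runs of separators to the last one, then strip separators from both ends):

```python
from itertools import pairwise  # Python 3.10+


def collapse_separators(text: str, separators: list[str]) -> str:
    """Collapse runs of separator characters, keeping only the last one of
    each run, then strip separators from both ends of the text."""
    if not separators:
        return text  # mirrors the keep_all_text early return above
    out = ""
    for char, next_char in pairwise(text):
        # Drop a separator that is immediately followed by another separator.
        if char not in separators or next_char not in separators:
            out += char
    # pairwise() stops before the final character, so handle it separately.
    if text and text[-1] not in separators:
        out += text[-1]
    return out.strip("".join(separators))


print(collapse_separators("ⓟn°1Ⓟ \n ⓓ16 janvier 1611Ⓓ\n", [" ", "\n"]))
# ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ
```

The run "` \n `" between the two entities collapses to a single space, and the trailing newline is stripped, which is the behaviour the new tests below check.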
@@ -4,27 +4,27 @@
 Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database (SQLite format). This will generate the images and the labels needed to train a DAN model.
 
 | Parameter | Description | Type | Default |
 | --------- | ----------- | ---- | ------- |
 | `database` | Path to an Arkindex export database in SQLite format. | `Path` | |
 | `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str` or `uuid` | |
 | `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
 | `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
 | `--output` | Folder where the data will be generated. | `Path` | |
 | `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
-| `--only-entities` | Remove all text that does not belong to the tokens. | `bool` | `False` |
 | `--allow-unknown-entities` | Ignore entities that do not appear in the list of tokens. | `bool` | `False` |
+| `--entity-separators` | Remove all text that does not appear in an entity or in the list of given characters. Do not give any arguments to keep the whole text. | `str` | |
 | `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
 | `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
 | `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
 | `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
 | `--test-folder` | ID of the testing folder to import from Arkindex. | `uuid` | |
 | `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--train-prob` | Training set split size. | `float` | `0.7` |
 | `--val-prob` | Validation set split size. | `float` | `0.15` |
 | `--max-width` | Images larger than this width will be resized to this width. | `int` | |
 | `--max-height` | Images larger than this height will be resized to this height. | `int` | |
 
 The `--tokens` argument expects a YAML-formatted file with a specific format: a list of entries, each describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively.
...
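For reference, a tokens file along these lines would pair each NER label with its start and end tokens. The labels and circled characters below are only illustrative, and the `start`/`end` key names are assumed from the `EntityType` fields used by the extractor:

```yaml
# Illustrative example of a --tokens file (labels and characters are placeholders)
PERSON:
  start: ⓟ
  end: Ⓟ
DATE:
  start: ⓓ
  end: Ⓓ
```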
@@ -4,12 +4,24 @@ from typing import NamedTuple
 
 import pytest
 
+from dan.datasets.extract.exceptions import NoEndTokenError, UnknownLabelError
 from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
 # NamedTuple to mock actual database result
 Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
 
+TOKENS = {
+    "P": EntityType(start="ⓟ", end="Ⓟ"),
+    "D": EntityType(start="ⓓ", end="Ⓓ"),
+    "N": EntityType(start="ⓝ", end="Ⓝ"),
+    "I": EntityType(start="ⓘ", end="Ⓘ"),
+}
+
+
+def filter_tokens(keys):
+    return {key: value for key, value in TOKENS.items() if key in keys}
+
 
 @pytest.mark.parametrize(
     "text,offset,length,expected",
@@ -24,44 +36,165 @@ def test_insert_token(text, offset, length, expected):
     )
 
 
+def test_reconstruct_text_unknown_label_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = TOKENS
+    with pytest.raises(
+        UnknownLabelError, match="Label `X` is missing in the NER configuration."
+    ):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
+def test_reconstruct_text_no_end_token_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = {
+        "X": EntityType(start=""),
+    }
+    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
 @pytest.mark.parametrize(
-    "only_entities,expected",
+    "entity_separators,tokens,expected",
     (
-        (False, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou"),
-        (True, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # Whole text...
+        # ... + All tokens
+        ([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichou"),
+        # ... + 1st and 3rd tokens
+        ([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # Only entities...
+        # ... + All tokens
+        ([" ", "\n"], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([" ", "\n"], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # ... + 1st and 3rd tokens
+        ([" ", "\n"], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([" ", "\n"], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
     ),
 )
-def test_reconstruct_text(only_entities, expected):
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
     arkindex_extractor = ArkindexExtractor(
-        allow_unknown_entities=True, only_entities=only_entities
+        allow_unknown_entities=True, entity_separators=entity_separators
     )
+    arkindex_extractor.tokens = tokens
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "n°1 x 16 janvier 1611\nMichou" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=3,
+                type="P",
+                value="n°1",
+            ),
+            Entity(
+                offset=6 + len(text_before),
+                length=15,
+                type="D",
+                value="16 janvier 1611",
+            ),
+            Entity(
+                offset=22 + len(text_before),
+                length=6,
+                type="N",
+                value="Michou",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + expected
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize(
+    "entity_separators",
+    (
+        # Whole text
+        [],
+        # Only entities
+        [" ", "\n"],
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = TOKENS
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "LouisXIV" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=5,
+                type="P",
+                value="Louis",
+            ),
+            Entity(
+                offset=5 + len(text_before),
+                length=3,
+                type="I",
+                value="XIV",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + "ⓟLouisⓅⓘXIVⒾ"
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_only_start_token(text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
     arkindex_extractor.tokens = {
-        "P": EntityType(start="ⓟ", end="Ⓟ"),
-        "D": EntityType(start="ⓓ", end="Ⓓ"),
+        "P": EntityType(start="ⓟ"),
+        "I": EntityType(start="ⓘ"),
     }
     assert (
         arkindex_extractor.reconstruct_text(
-            "n°1 x 16 janvier 1611 x Michou",
+            text_before + "LouisXIV" + text_after,
             [
                 Entity(
-                    offset=0,
-                    length=3,
+                    offset=0 + len(text_before),
+                    length=5,
                     type="P",
-                    value="n°1",
+                    value="Louis",
                 ),
                 Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
+                    offset=5 + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
                 ),
-                Entity(
-                    offset=24,
-                    length=6,
-                    type="N",
-                    value="Michou",
-                ),
             ],
         )
-        == expected
+        == "ⓟLouisⓘXIV"
     )