Commit 88720b74 authored by Manon Blanco

Support joined entities and entity separators

parent 0ccc3e85
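
In short: the `--only-entities` flag is replaced by a more flexible `--entity-separators` option. When separator characters are given, text outside entities is reduced to those characters (and runs of consecutive separators collapse to the last one); when the option is omitted, the whole text is kept. Entities that directly touch each other ("joined entities", e.g. `LouisXIV`) are handled in both modes. A minimal behaviour sketch, mirroring the updated tests below (the `Entity` NamedTuple is the tests' mock of the database result):

    from typing import NamedTuple
    from dan.datasets.extract.extract import ArkindexExtractor
    from dan.datasets.extract.utils import EntityType

    Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)

    extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
    extractor.tokens = {
        "P": EntityType(start="ⓟ", end="Ⓟ"),
        "D": EntityType(start="ⓓ", end="Ⓓ"),
        "N": EntityType(start="ⓝ", end="Ⓝ"),
    }
    print(extractor.reconstruct_text(
        "n°1 x 16 janvier 1611\nMichou",
        [
            Entity(offset=0, length=3, type="P", value="n°1"),
            Entity(offset=6, length=15, type="D", value="16 janvier 1611"),
            Entity(offset=22, length=6, type="N", value="Michou"),
        ],
    ))
    # -> "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"  (the " x " fillers are dropped)
    # With entity_separators=[] (flag omitted) the whole text is kept:
    # -> "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"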
@@ -87,16 +87,18 @@ def add_extract_parser(subcommands) -> None:
         action="store_true",
         help="Extract text with their entities.",
     )
-    parser.add_argument(
-        "--only-entities",
-        action="store_true",
-        help="Remove all text that does not belong to the tokens.",
-    )
     parser.add_argument(
         "--allow-unknown-entities",
         action="store_true",
         help="Ignore entities that do not appear in the list of tokens.",
     )
+    parser.add_argument(
+        "--entity-separators",
+        type=str,
+        nargs="+",
+        help="Removes all text that does not appear in an entity or in the list of given characters. Do not give any arguments for keeping the whole text.",
+        required=False,
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
...
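
A quick sketch of how the new option parses (the argument definition is taken from the diff; the standalone parser here is a stand-in for the project's actual CLI, not part of this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--entity-separators", type=str, nargs="+", required=False)

    # Keep only entities, allowing a space or a line break between them:
    print(parser.parse_args(["--entity-separators", " ", "\n"]).entity_separators)
    # -> [' ', '\n']

    # Omit the option to keep the whole text (the extractor treats a
    # falsy value as "keep everything"):
    print(parser.parse_args([]).entity_separators)
    # -> None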
@@ -2,6 +2,7 @@
 import random
 from collections import defaultdict
+from itertools import pairwise
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
@@ -36,8 +37,6 @@ IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
-
-EMPTY_CHARS = [" ", "\n", "\t", "\r"]
 
 
 class ArkindexExtractor:
     """
@@ -52,7 +51,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
-        only_entities: bool = False,
+        entity_separators: list = [],
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -67,7 +66,7 @@
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
-        self.only_entities = only_entities
+        self.entity_separators = entity_separators
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -110,20 +109,29 @@
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
def keep_char(char):
return keep_all_text or char in self.entity_separators
text, text_offset = "", 0
# Keep all text by default if no separator was given
keep_all_text = not self.entity_separators
for entity in entities:
# Text before entity
if not self.only_entities:
text += full_text[text_offset : entity.offset]
text += "".join(filter(keep_char, full_text[text_offset : entity.offset]))
entity_type: EntityType = self.tokens.get(entity.type)
# Unknown entities are not allowed
if not entity_type and not self.allow_unknown_entities:
raise UnknownLabelError(entity.type)
if entity_type and not entity_type.end and not self.only_entities:
# We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends.
if entity_type and not entity_type.end and keep_all_text:
raise NoEndTokenError(entity.type)
# Entity text (optionally with tokens)
if entity_type or not self.only_entities:
# Entity text:
# - with tokens if there is an entity_type,
# - without tokens if there is no entity_type but we want to keep the whole text.
if entity_type or keep_all_text:
text += insert_token(
full_text,
entity_type,
@@ -132,20 +140,29 @@
                 )
             text_offset = entity.offset + entity.length
-            # Keep the exact separator after entity (it can be a space, a line break etc.)
-            if self.only_entities:
-                separator = next(
-                    (char for char in full_text[text_offset:] if char in EMPTY_CHARS),
-                    "",
-                )
-                text += separator
+        # Remaining text after the last entity
+        text += "".join(filter(keep_char, full_text[text_offset:]))
+        if keep_all_text:
+            return text
+        # Add some clean up to avoid several separators between entities.
+        text, full_text = "", text
+        for char, next_char in pairwise(full_text):
+            # If several separators follow each other, keep only the last one.
+            if (
+                char not in self.entity_separators
+                or next_char not in self.entity_separators
+            ):
+                text += char
-        # Remaining text
-        if not self.only_entities:
-            text += full_text[text_offset:]
+        # Remaining char
+        remaining_char = full_text[-1]
+        if remaining_char not in self.entity_separators:
+            text += remaining_char
-        # Remove extra spaces
-        return text.strip("".join(EMPTY_CHARS))
+        # Remove separators at the beginning and end of text
+        return text.strip("".join(self.entity_separators))
 
     def extract_transcription(self, element: Element):
         """
@@ -160,7 +177,7 @@
         transcription = random.choice(transcriptions)
 
-        if self.load_entities or self.only_entities:
+        if self.load_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -249,7 +266,7 @@ def run(
     output: Path,
     load_entities: bool,
     allow_unknown_entities: bool,
-    only_entities: bool,
+    entity_separators: list,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -302,7 +319,7 @@ def run(
         output=output,
         load_entities=load_entities,
         allow_unknown_entities=allow_unknown_entities,
-        only_entities=only_entities,
+        entity_separators=entity_separators,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
...
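
The clean-up pass in `reconstruct_text` above is the subtle part: `pairwise` (new in Python 3.10, hence the added import) never yields the final character as the left element of a pair, which is why the diff handles `remaining_char` separately before stripping. A self-contained sketch of the same idea (the function name is mine, not from the diff):

    from itertools import pairwise  # Python 3.10+

    def collapse_separators(text: str, separators: list) -> str:
        """Keep only the last separator of each run, then trim both ends."""
        out = ""
        for char, next_char in pairwise(text):
            # Drop a separator only when the next character is also a separator.
            if char not in separators or next_char not in separators:
                out += char
        # pairwise() stops before the last character, so append it manually.
        if text and text[-1] not in separators:
            out += text[-1]
        return out.strip("".join(separators))

    # The run " \n" ("space, space, newline") collapses to its last member, "\n".
    assert collapse_separators("ⓟn°1Ⓟ  \nⓓ1611Ⓓ ", [" ", "\n"]) == "ⓟn°1Ⓟ\nⓓ1611Ⓓ"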
@@ -4,12 +4,24 @@ from typing import NamedTuple
 import pytest
 
 from dan.datasets.extract.exceptions import NoEndTokenError, UnknownLabelError
 from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
 # NamedTuple to mock actual database result
 Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
 
+TOKENS = {
+    "P": EntityType(start="ⓟ", end="Ⓟ"),
+    "D": EntityType(start="ⓓ", end="Ⓓ"),
+    "N": EntityType(start="ⓝ", end="Ⓝ"),
+    "I": EntityType(start="ⓘ", end="Ⓘ"),
+}
+
+
+def filter_tokens(keys):
+    return {key: value for key, value in TOKENS.items() if key in keys}
 
 
 @pytest.mark.parametrize(
     "text,offset,length,expected",
@@ -24,44 +36,165 @@ def test_insert_token(text, offset, length, expected):
 )
 
 
 def test_reconstruct_text_unknown_label_error():
     arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = TOKENS
     with pytest.raises(
         UnknownLabelError, match="Label `X` is missing in the NER configuration."
     ):
         arkindex_extractor.reconstruct_text(
             "n°1 x 16 janvier 1611",
             [
                 Entity(
                     offset=0,
                     length=3,
                     type="X",
                     value="n°1",
                 ),
             ],
         )
+
+
+def test_reconstruct_text_no_end_token_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = {
+        "X": EntityType(start="ⓧ"),
+    }
+    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
 @pytest.mark.parametrize(
-    "only_entities,expected",
+    "entity_separators,tokens,expected",
     (
-        (False, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou"),
-        (True, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # Whole text...
+        # ... + All tokens
+        ([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichou"),
+        # ... + 1st and 3rd tokens
+        ([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # Only entities...
+        # ... + All tokens
+        ([" ", "\n"], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([" ", "\n"], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # ... + 1st and 3rd tokens
+        ([" ", "\n"], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([" ", "\n"], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
     ),
 )
-def test_reconstruct_text(only_entities, expected):
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
     arkindex_extractor = ArkindexExtractor(
-        allow_unknown_entities=True, only_entities=only_entities
+        allow_unknown_entities=True, entity_separators=entity_separators
     )
+    arkindex_extractor.tokens = tokens
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "n°1 x 16 janvier 1611\nMichou" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=3,
+                type="P",
+                value="n°1",
+            ),
+            Entity(
+                offset=6 + len(text_before),
+                length=15,
+                type="D",
+                value="16 janvier 1611",
+            ),
+            Entity(
+                offset=22 + len(text_before),
+                length=6,
+                type="N",
+                value="Michou",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + expected
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize(
+    "entity_separators",
+    (
+        # Whole text
+        [],
+        # Only entities
+        [" ", "\n"],
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = TOKENS
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "LouisXIV" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=5,
+                type="P",
+                value="Louis",
+            ),
+            Entity(
+                offset=5 + len(text_before),
+                length=3,
+                type="I",
+                value="XIV",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + "ⓟLouisⓅⓘXIVⒾ"
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_only_start_token(text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
     arkindex_extractor.tokens = {
-        "P": EntityType(start="ⓟ", end="Ⓟ"),
-        "D": EntityType(start="ⓓ", end="Ⓓ"),
+        "P": EntityType(start="ⓟ"),
+        "I": EntityType(start="ⓘ"),
     }
     assert (
         arkindex_extractor.reconstruct_text(
-            "n°1 x 16 janvier 1611 x Michou",
+            text_before + "LouisXIV" + text_after,
             [
                 Entity(
-                    offset=0,
-                    length=3,
+                    offset=0 + len(text_before),
+                    length=5,
                     type="P",
-                    value="n°1",
+                    value="Louis",
                 ),
                 Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
-                ),
-                Entity(
-                    offset=24,
-                    length=6,
-                    type="N",
-                    value="Michou",
+                    offset=5 + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
                 ),
             ],
         )
-        == expected
+        == "ⓟLouisⓘXIV"
     )