Commit 0ccc3e85 authored by Manon Blanco

Allow keeping text around entities

parent 28eb8e3f
@@ -83,7 +83,14 @@ def add_extract_parser(subcommands) -> None:
     # Optional arguments.
     parser.add_argument(
-        "--load-entities", action="store_true", help="Extract text with their entities."
+        "--load-entities",
+        action="store_true",
+        help="Extract text with their entities.",
     )
+    parser.add_argument(
+        "--only-entities",
+        action="store_true",
+        help="Remove all text that does not belong to the tokens.",
+    )
     parser.add_argument(
         "--allow-unknown-entities",
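The two options above are independent `store_true` flags, so `--only-entities` can in principle be passed without `--load-entities`; this is why the extractor later checks `self.load_entities or self.only_entities` before fetching entities. A quick standalone check (illustrative only; argument names are taken from the patch):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--load-entities", action="store_true")
    parser.add_argument("--only-entities", action="store_true")

    args = parser.parse_args(["--only-entities"])
    assert args.only_entities and not args.load_entities

Note that, as the patch stands, the token mapping is only parsed when `load_entities` is set, so in practice the two flags are combined.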
@@ -62,3 +62,18 @@ class UnknownLabelError(ProcessingError):
     def __str__(self) -> str:
         return f"Label `{self.label}` is missing in the NER configuration."
+
+
+class NoEndTokenError(ProcessingError):
+    """
+    Raised when the specified label has no end token and there is potentially additional text around the labels
+    """
+
+    label: str
+
+    def __init__(self, label: str, *args: object) -> None:
+        super().__init__(*args)
+        self.label = label
+
+    def __str__(self) -> str:
+        return f"Label `{self.label}` has no end token."
@@ -17,11 +17,13 @@ from dan.datasets.extract.db import (
     get_transcriptions,
 )
 from dan.datasets.extract.exceptions import (
+    NoEndTokenError,
     NoTranscriptionError,
     ProcessingError,
     UnknownLabelError,
 )
 from dan.datasets.extract.utils import (
+    EntityType,
     Subset,
     download_image,
     insert_token,
@@ -34,6 +36,8 @@ IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
+EMPTY_CHARS = [" ", "\n", "\t", "\r"]
 
 
 class ArkindexExtractor:
     """
@@ -48,6 +52,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
+        only_entities: bool = False,
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -62,6 +67,7 @@
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
+        self.only_entities = only_entities
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -100,31 +106,46 @@
     def get_random_split(self):
         return next(self._assign_random_split())
 
-    def reconstruct_text(self, text: str, entities):
+    def reconstruct_text(self, full_text: str, entities):
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
-        # Filter entities
-        for entity in entities.copy():
-            # Tokens known for this entity
-            if entity.type in self.tokens:
-                continue
-            # Tokens unknown for this entity
-            if not self.allow_unknown_entities:
+        text, text_offset = "", 0
+        for entity in entities:
+            # Text before entity
+            if not self.only_entities:
+                text += full_text[text_offset : entity.offset]
+
+            entity_type: EntityType = self.tokens.get(entity.type)
+            if not entity_type and not self.allow_unknown_entities:
                 raise UnknownLabelError(entity.type)
-            entities.remove(entity)
 
-        # Only keep text of the filtered entities
-        return " ".join(
-            insert_token(
-                text,
-                entity_type=self.tokens[entity.type],
-                offset=entity.offset,
-                length=entity.length,
-            )
-            for entity in entities
-        )
+            if entity_type and not entity_type.end and not self.only_entities:
+                raise NoEndTokenError(entity.type)
+
+            # Entity text (optionally with tokens)
+            if entity_type or not self.only_entities:
+                text += insert_token(
+                    full_text,
+                    entity_type,
+                    offset=entity.offset,
+                    length=entity.length,
+                )
+            text_offset = entity.offset + entity.length
+
+            # Keep the exact separator after the entity (it can be a space, a line break, etc.)
+            if self.only_entities:
+                separator = next(
+                    (char for char in full_text[text_offset:] if char in EMPTY_CHARS),
+                    "",
+                )
+                text += separator
+
+        # Remaining text
+        if not self.only_entities:
+            text += full_text[text_offset:]
+
+        # Remove extra spaces
+        return text.strip("".join(EMPTY_CHARS))
 
     def extract_transcription(self, element: Element):
         """
@@ -139,7 +160,7 @@
         transcription = random.choice(transcriptions)
 
-        if self.load_entities:
+        if self.load_entities or self.only_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -228,6 +249,7 @@ def run(
     output: Path,
     load_entities: bool,
     allow_unknown_entities: bool,
+    only_entities: bool,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -280,6 +302,7 @@
         output=output,
         load_entities=load_entities,
         allow_unknown_entities=allow_unknown_entities,
+        only_entities=only_entities,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
@@ -80,11 +80,11 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
     """
     return (
         # Starting token
-        entity_type.start
+        (entity_type.start if entity_type else "")
         # Entity
         + text[offset : offset + length]
         # End token
-        + entity_type.end
+        + (entity_type.end if entity_type else "")
     )
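`insert_token` now tolerates `entity_type=None`, which is how unknown entities are kept as plain text when `--allow-unknown-entities` is set. Expected behavior after this change (illustrative interpreter session):

    >>> from dan.datasets.extract.utils import EntityType, insert_token
    >>> insert_token("n°1 x 16 janvier 1611", EntityType(start="ⓟ", end="Ⓟ"), offset=0, length=3)
    'ⓟn°1Ⓟ'
    >>> insert_token("n°1 x 16 janvier 1611", None, offset=0, length=3)
    'n°1'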
@@ -12,6 +12,7 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
 | `--parent-element-type`    | Type of the parent element containing the data.                       | `str`  | `page`  |
 | `--output`                 | Folder where the data will be generated.                              | `Path` |         |
 | `--load-entities`          | Extract text with their entities. Needed for NER tasks.               | `bool` | `False` |
+| `--only-entities`          | Remove all text that does not belong to the tokens.                   | `bool` | `False` |
 | `--allow-unknown-entities` | Ignore entities that do not appear in the list of tokens.             | `bool` | `False` |
 | `--tokens`                 | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` |         |
 | `--use-existing-split`     | Use the specified folder IDs for the dataset split.                   | `bool` |         |
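For instance, a hypothetical NER extraction that keeps only the tagged entities (`tokens.yml` and `data/` are placeholders; the remaining required arguments, such as the folder IDs, are elided):

    teklia-dan dataset extract \
        --load-entities \
        --only-entities \
        --tokens tokens.yml \
        --output data/ \
        ...

`--only-entities` is combined with `--load-entities` here because the token mapping is only parsed when `load_entities` is set.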
@@ -25,14 +25,23 @@ def test_insert_token(text, offset, length, expected):
 @pytest.mark.parametrize(
-    "tokens,text,entities,expected",
+    "only_entities,expected",
     (
-        (
-            {
-                "P": EntityType(start="ⓟ", end="Ⓟ"),
-                "D": EntityType(start="ⓓ", end="Ⓓ"),
-            },
-            "n°1 x 16 janvier 1611",
+        (False, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou"),
+        (True, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+    ),
+)
+def test_reconstruct_text(only_entities, expected):
+    arkindex_extractor = ArkindexExtractor(
+        allow_unknown_entities=True, only_entities=only_entities
+    )
+    arkindex_extractor.tokens = {
+        "P": EntityType(start="ⓟ", end="Ⓟ"),
+        "D": EntityType(start="ⓓ", end="Ⓓ"),
+    }
+    assert (
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611 x Michou",
             [
                 Entity(
                     offset=0,
@@ -46,33 +55,13 @@ def test_insert_token(text, offset, length, expected):
                     type="D",
                     value="16 janvier 1611",
                 ),
+                Entity(
+                    offset=24,
+                    length=6,
+                    type="N",
+                    value="Michou",
+                ),
             ],
-            "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ",
-        ),
-        (
-            {
-                "P": EntityType(start="ⓟ", end="Ⓟ"),
-            },
-            "n°1 x 16 janvier 1611",
-            [
-                Entity(
-                    offset=0,
-                    length=3,
-                    type="P",
-                    value="n°1",
-                ),
-                Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
-                ),
-            ],
-            "ⓟn°1Ⓟ",
-        ),
-    ),
-)
-def test_reconstruct_text(tokens, text, entities, expected):
-    arkindex_extractor = ArkindexExtractor(allow_unknown_entities=True)
-    arkindex_extractor.tokens = tokens
-    assert arkindex_extractor.reconstruct_text(text, entities) == expected
+        )
+        == expected
+    )