Compare revisions
......@@ -3,20 +3,43 @@
Extract dataset from Arkindex using API.
"""
import argparse
import pathlib
import uuid
from uuid import UUID
from dan.datasets.extract.extract_from_arkindex import run
from dan.datasets.extract.extract import run
MANUAL_SOURCE = "manual"
def validate_uuid(arg_uuid):
try:
return UUID(arg_uuid)
except ValueError:
raise argparse.ArgumentTypeError(f"`{arg_uuid}` is not a valid UUID.")
def parse_worker_version(worker_version_id):
if worker_version_id == MANUAL_SOURCE:
return False
validate_uuid(worker_version_id)
return worker_version_id
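# E.g. parse_worker_version("manual") returns False (manual source), while a valid UUID string
# is returned unchanged.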
def validate_probability(proba):
try:
proba = float(proba)
except ValueError:
raise argparse.ArgumentTypeError(f"`{proba}` is not a valid float.")
if proba > 1 or proba < 0:
raise argparse.ArgumentTypeError(
f"`{proba}` is not a valid probability. Must be between 0 and 1 (both exclusive)."
)
return proba
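# E.g. validate_probability("0.7") returns 0.7; "1.5" and "abc" both raise an ArgumentTypeError.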
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
......@@ -25,9 +48,14 @@ def add_extract_parser(subcommands) -> None:
)
# Required arguments.
parser.add_argument(
"database",
type=pathlib.Path,
help="Path where the data were exported from Arkindex.",
)
parser.add_argument(
"--parent",
type=uuid.UUID,
type=validate_uuid,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
......@@ -36,7 +64,7 @@ def add_extract_parser(subcommands) -> None:
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
help="Type of elements to retrieve.",
required=True,
)
parser.add_argument(
......@@ -55,7 +83,7 @@ def add_extract_parser(subcommands) -> None:
# Optional arguments.
parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities"
"--load-entities", action="store_true", help="Extract text with their entities."
)
parser.add_argument(
"--tokens",
......@@ -72,19 +100,19 @@ def add_extract_parser(subcommands) -> None:
parser.add_argument(
"--train-folder",
type=uuid.UUID,
type=validate_uuid,
help="ID of the training folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--val-folder",
type=uuid.UUID,
type=validate_uuid,
help="ID of the validation folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--test-folder",
type=uuid.UUID,
type=validate_uuid,
help="ID of the testing folder to import from Arkindex.",
required=False,
)
......@@ -94,22 +122,28 @@ def add_extract_parser(subcommands) -> None:
type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
default=False,
)
parser.add_argument(
"--entity-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
default=False,
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set split size."
"--train-prob",
type=validate_probability,
default=0.7,
help="Training set split size.",
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set split size"
"--val-prob",
type=validate_probability,
default=0.15,
help="Validation set split size.",
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
import ast
from itertools import starmap
from typing import List, NamedTuple, Union
from urllib.parse import urljoin
from arkindex_export import Image
from arkindex_export.models import Element as ArkindexElement
from arkindex_export.models import Entity as ArkindexEntity
from arkindex_export.models import EntityType as ArkindexEntityType
from arkindex_export.models import Transcription as ArkindexTranscription
from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity
from arkindex_export.queries import list_children
def bounding_box(polygon: list):
"""
Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return int(x), int(y), int(width), int(height)
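# E.g. bounding_box([[10, 20], [30, 5], [25, 40]]) returns (10, 5, 20, 35).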
# DB models
Transcription = NamedTuple(
"Transcription",
id=str,
text=str,
)
Entity = NamedTuple(
"Entity",
type=str,
value=str,
offset=float,
length=float,
)
class Element(NamedTuple):
id: str
type: str
polygon: str
url: str
width: str
height: str
@property
def bounding_box(self):
return bounding_box(ast.literal_eval(self.polygon))
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(self.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg")
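# E.g. for an element whose image URL is "https://iiif.example/image" (illustrative) and whose
# bounding box is (10, 20, 30, 40), image_url is
# "https://iiif.example/image/10,20,30,40/full/0/default.jpg".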
def get_elements(parent_id: str, element_type: str) -> List[Element]:
"""
Retrieve elements from an SQLite export of an Arkindex corpus
"""
# Load all the elements found in the CTE
query = (
list_children(parent_id=parent_id)
.join(Image)
.where(ArkindexElement.type == element_type)
.select(
ArkindexElement.id,
ArkindexElement.type,
ArkindexElement.polygon,
Image.url,
Image.width,
Image.height,
)
)
return list(
starmap(
Element,
query.tuples(),
)
)
def build_worker_version_filter(ArkindexModel, worker_version):
"""
A `False` worker version means a manual source, i.e. the `worker_version` field is null.
"""
if worker_version:
return ArkindexModel.worker_version == worker_version
else:
return ArkindexModel.worker_version.is_null()
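# E.g. build_worker_version_filter(ArkindexTranscription, worker_version=False) keeps only
# manual transcriptions (worker_version IS NULL); passing a worker version ID keeps only
# the results produced by that worker version.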
def get_transcriptions(
element_id: str, transcription_worker_version: Union[str, bool]
) -> List[Transcription]:
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
"""
query = ArkindexTranscription.select(
ArkindexTranscription.id, ArkindexTranscription.text
).where(
(ArkindexTranscription.element == element_id)
& build_worker_version_filter(
ArkindexTranscription, worker_version=transcription_worker_version
)
)
return list(
starmap(
Transcription,
query.tuples(),
)
)
def get_transcription_entities(
transcription_id: str, entity_worker_version: Union[str, bool]
) -> List[Entity]:
"""
Retrieve transcription entities from an SQLite export of an Arkindex corpus
"""
query = (
ArkindexTranscriptionEntity.select(
ArkindexEntityType.name,
ArkindexEntity.name,
ArkindexTranscriptionEntity.offset,
ArkindexTranscriptionEntity.length,
)
.join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
.join(ArkindexEntityType, on=ArkindexEntity.type)
.where(
(ArkindexTranscriptionEntity.transcription == transcription_id)
& build_worker_version_filter(
ArkindexTranscriptionEntity, worker_version=entity_worker_version
)
)
)
return list(
starmap(
Entity,
query.order_by(ArkindexTranscriptionEntity.offset).tuples(),
)
)
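# E.g. on the demo export used in the unit tests, get_transcription_entities(
#     "3bd248d6-998a-4579-a00c-d4639f3825aa", entity_worker_version=False
# ) returns [Entity(type="Cote", value="[ T 8º SUP 26200", offset=0, length=16)].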
# -*- coding: utf-8 -*-
class ProcessingError(Exception):
"""
Raised when there is a problem somewhere in the processing
"""
class ElementProcessingError(ProcessingError):
element_id: str
def __init__(self, element_id: str, *args: object) -> None:
super().__init__(*args)
self.element_id = element_id
class ImageDownloadError(ElementProcessingError):
"""
Raised when an element's image could not be downloaded
"""
error: Exception
def __init__(self, element_id: str, error: Exception, *args: object) -> None:
super().__init__(element_id, *args)
self.error = error
def __str__(self) -> str:
return (
f"Couldn't retrieve image of element ({self.element_id}: {str(self.error)})"
)
class NoTranscriptionError(ElementProcessingError):
"""
Raised when there are no transcriptions on an element
"""
def __str__(self) -> str:
return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
class MultipleTranscriptionsError(ElementProcessingError):
"""
Raised when there is more than one transcription on an element
"""
def __str__(self) -> str:
return f"More than one transcription found on element ({self.element_id}) with this config."
class UnknownLabelError(ProcessingError):
"""
Raised when the specified label is not known
"""
label: str
def __init__(self, label: str, *args: object) -> None:
super().__init__(*args)
self.label = label
def __str__(self) -> str:
return f"Label `{self.label}` is missing in the NER configuration."
# -*- coding: utf-8 -*-
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Union
from uuid import UUID
from arkindex_export import open_database
from tqdm import tqdm
from dan import logger
from dan.datasets.extract.db import (
Element,
Entity,
get_elements,
get_transcription_entities,
get_transcriptions,
)
from dan.datasets.extract.exceptions import (
MultipleTranscriptionsError,
NoTranscriptionError,
ProcessingError,
UnknownLabelError,
)
from dan.datasets.extract.utils import (
EntityType,
Subset,
download_image,
insert_token,
parse_tokens,
save_json,
save_text,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
SPLIT_NAMES = ["train", "val", "test"]
class ArkindexExtractor:
"""
Extract data from Arkindex
"""
def __init__(
self,
folders: list = [],
element_type: list = [],
parent_element_type: str = None,
output: Path = None,
load_entities: bool = None,
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: str = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
) -> None:
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.load_entities = load_entities
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.subsets = self.get_subsets(folders)
def get_subsets(self, folders: list) -> List[Subset]:
"""
Assign each folder to its split if it's already known.
"""
if self.use_existing_split:
return [
Subset(folder, split) for folder, split in zip(folders, SPLIT_NAMES)
]
else:
return [Subset(folder) for folder in folders]
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
Assumes that train_prob + val_prob <= 1; the remaining probability goes to the test split.
"""
prob = random.random()
if prob <= self.train_prob:
yield SPLIT_NAMES[0]
elif prob <= self.train_prob + self.val_prob:
yield SPLIT_NAMES[1]
else:
yield SPLIT_NAMES[2]
def get_random_split(self):
return next(self._assign_random_split())
def reconstruct_text(self, text: str, entities: List[Entity]):
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
count = 0
for entity in entities:
if entity.type not in self.tokens:
raise UnknownLabelError(entity.type)
entity_type: EntityType = self.tokens[entity.type]
text = insert_token(
text,
count,
entity_type,
offset=entity.offset,
length=entity.length,
)
count += entity_type.offset
return text
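# E.g. with tokens {"P": EntityType(start="ⓟ", end="Ⓟ"), "D": EntityType(start="ⓓ", end="Ⓓ")},
# reconstruct_text("n°1 16 janvier 1611", [Entity("P", "n°1", 0, 3), Entity("D", "16 janvier 1611", 4, 15)])
# returns "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ".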
def extract_transcription(self, element: Element):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = get_transcriptions(
element.id, self.transcription_worker_version
)
if len(transcriptions) > 1:
raise MultipleTranscriptionsError(element.id)
elif len(transcriptions) == 0:
raise NoTranscriptionError(element.id)
transcription = transcriptions.pop()
if self.load_entities:
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
else:
return transcription.text.strip()
def process_element(
self,
element: Element,
split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(element)
txt_path = Path(
self.output, LABELS_DIR, split, f"{element.type}_{element.id}.txt"
)
save_text(txt_path, text)
im_path = Path(
self.output, IMAGES_DIR, split, f"{element.type}_{element.id}.jpg"
)
download_image(element, im_path)
return element.id
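# E.g. a "page" element with ID "1234" assigned to the "train" split is written to
# <output>/labels/train/page_1234.txt and <output>/images/train/page_1234.jpg.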
def process_parent(
self,
parent: Element,
split: str,
):
"""
Extract data from a parent element.
"""
data = defaultdict(list)
if self.element_type == [parent.type]:
try:
data[parent.type].append(self.process_element(parent, split))
except ProcessingError as e:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
for element_type in self.element_type:
for element in get_elements(parent.id, element_type):
try:
data[element_type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for idx, subset in enumerate(self.subsets, start=1):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
get_elements(subset.id, self.parent_element_type),
desc=f"Processing {subset} {idx}/{len(self.subsets)}",
):
split = subset.split or self.get_random_split()
split_dict[split][parent.id] = self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
def run(
database: Path,
parent: list,
element_type: str,
parent_element_type: str,
output: Path,
load_entities: bool,
tokens: Path,
use_existing_split: bool,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Union[str, bool],
entity_worker_version: Union[str, bool],
train_prob,
val_prob,
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
assert use_existing_split ^ bool(
parent
), "Only one of `--use-existing-split` and `--parent` must be set"
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [str(train_folder), str(val_folder), str(test_folder)]
else:
folders = [str(parent_id) for parent_id in parent]
if load_entities:
assert tokens, "Please provide the entities to match."
# Create directories
for split in SPLIT_NAMES:
Path(output, LABELS_DIR, split).mkdir(parents=True, exist_ok=True)
Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
ArkindexExtractor(
folders=folders,
element_type=element_type,
parent_element_type=parent_element_type,
output=output,
load_entities=load_entities,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
).run()
# -*- coding: utf-8 -*-
import json
import logging
import time
from pathlib import Path
from typing import NamedTuple
import cv2
import imageio.v2 as iio
import yaml
from numpy import ndarray
from dan.datasets.extract.db import Element
from dan.datasets.extract.exceptions import ImageDownloadError
def save_text(path, text):
with open(path, "w") as f:
logger = logging.getLogger(__name__)
MAX_RETRIES = 5
class Subset(NamedTuple):
id: str
split: str = None
def __str__(self) -> str:
return (
f"Subset(id='{self.id}', split='{self.split.capitalize()}')"
if self.split
else f"Subset(id='{self.id}')"
)
class EntityType(NamedTuple):
start: str
end: str = ""
@property
def offset(self):
return len(self.start) + len(self.end)
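# E.g. EntityType(start="ⓘ", end="Ⓘ").offset == 2: wrapping an entity adds two characters to the text.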
def download_image(element: Element, im_path: Path):
tries = 1
# retry loop
while True:
if tries > MAX_RETRIES:
raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
try:
image = iio.imread(element.image_url)
save_image(im_path, image)
return
except TimeoutError:
logger.warning("Timeout, retry in 1 second.")
time.sleep(1)
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def save_text(path: Path, text: str):
with path.open("w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_image(path: Path, image: ndarray):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
def save_json(path: Path, data: dict):
with path.open("w") as outfile:
json.dump(data, outfile, indent=4)
def insert_token(text, count, start_token, end_token, offset, length):
def insert_token(
text: str, count: int, entity_type: EntityType, offset: int, length: int
) -> str:
"""
Insert the entity's start and end tokens at the right position in the text.
The entity_type's start or end token can be an empty string.
"""
text = (
return (
# Text before entity
text[: count + offset]
# Starting token
+ start_token
+ entity_type.start
# Entity
+ text[count + offset : count + 1 + offset + length]
+ text[count + offset : count + offset + length]
# End token
+ end_token
+ entity_type.end
# Text after entity
+ text[count + 1 + offset + length :]
+ text[count + offset + length :]
)
token_offset = len(start_token) + len(end_token)
return text, count + token_offset
def parse_tokens(filename):
with open(filename) as f:
return yaml.safe_load(f)
def parse_tokens(filename: Path) -> dict:
with filename.open() as f:
return {
name: EntityType(**tokens) for name, tokens in yaml.safe_load(f).items()
}
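# E.g. for a file containing "INTITULE:\n  start: ⓘ\n  end: Ⓘ", parse_tokens returns
# {"INTITULE": EntityType(start="ⓘ", end="Ⓘ")}.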
......@@ -2,10 +2,11 @@
## Description
Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database (SQLite format). This will generate the images and the labels needed to train a DAN model.
| Parameter | Description | Type | Default |
| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
| `database` | Path to an Arkindex export database in SQLite format. | `Path` | |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str\|uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
......@@ -24,9 +25,9 @@ Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex.
The `--tokens` argument expects a YAML-formatted file listing one entry per NER entity. The label of the entity is the key to a dict mapping its starting and ending tokens.
```yaml
---
INTITULE:
start: ⓘ
end: Ⓘ
INTITULE: # Type of the entity on Arkindex
start: ⓘ # Starting token for this entity
end: Ⓘ # Optional ending token for this entity
DATE:
start: ⓓ
end: Ⓓ
......@@ -54,6 +55,7 @@ CLASSEMENT:
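For illustration, here is a minimal, self-contained sketch (not part of DAN itself) of how such a file is interpreted: each entry is loaded into an `EntityType`-like named tuple, and its start/end tokens are wrapped around the matching entity in the transcription text.

```python
import yaml
from typing import Dict, NamedTuple


class EntityType(NamedTuple):
    # Mirrors dan.datasets.extract.utils.EntityType
    start: str
    end: str = ""  # the ending token may be omitted


# Inline equivalent of a tokens.yml file (illustrative values)
TOKENS_YAML = """
INTITULE:
  start: ⓘ
  end: Ⓘ
"""

tokens: Dict[str, EntityType] = {
    name: EntityType(**marks) for name, marks in yaml.safe_load(TOKENS_YAML).items()
}

# Wrap the INTITULE entity "n°1" (offset 0, length 3) of a transcription
text = "n°1 16 janvier 1611"
entity = tokens["INTITULE"]
print(entity.start + text[:3] + entity.end + text[3:])
# ⓘn°1Ⓘ 16 janvier 1611
```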
To extract HTR+NER data from **pages** from a folder, use the following command:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder_uuid \
--element-type page \
--output data \
......@@ -66,6 +68,7 @@ with `tokens.yml` compliant with the format described before.
To do the same but only use the data from three folders, the command becomes:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder1_uuid folder2_uuid folder3_uuid \
--element-type page \
--output data \
......@@ -77,6 +80,7 @@ teklia-dan dataset extract \
To use the data from three folders as **training**, **validation** and **testing** datasets respectively, the command becomes:
```shell
teklia-dan dataset extract \
database.sqlite \
--use-existing-split \
--train-folder train_folder_uuid \
--val-folder val_folder_uuid \
......@@ -91,6 +95,7 @@ teklia-dan dataset extract \
To extract HTR data from **annotations** and **text_zones** from a folder, but only keep those that are children of **single_pages**, use the following command:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder_uuid \
--element-type text_zone annotation \
--parent-element-type single_page \
......
arkindex-client==1.0.11
boto3==1.26.96
arkindex-export==0.1.1
boto3==1.26.97
editdistance==0.6.2
fontTools==4.39.2
imageio==2.26.1
......
# -*- coding: utf-8 -*-
import os
from pathlib import Path
import pytest
from arkindex_export import open_database
FIXTURES = Path(__file__).resolve().parent / "data"
@pytest.fixture(autouse=True)
......@@ -18,3 +22,16 @@ def setup_environment(responses):
# Set schema url in environment
os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
@pytest.fixture
def database_path():
return FIXTURES / "export.sqlite"
@pytest.fixture(autouse=True)
def demo_db(database_path):
"""
Open a connection to a known demo database
"""
open_database(database_path)
File added
# -*- coding: utf-8 -*-
import pytest
from dan.datasets.extract.db import (
Element,
Entity,
Transcription,
get_elements,
get_transcription_entities,
get_transcriptions,
)
def test_get_elements():
"""
Assert elements retrieval output against verified results
"""
elements = get_elements(
parent_id="d2b9fe93-3198-42de-8c07-f4ab67990e21",
element_type="page",
)
# Check number of results
assert len(elements) == 4
assert all(isinstance(element, Element) for element in elements)
# ID verification
assert [element.id for element in elements] == [
"0c8c62ef-3a2b-4b7b-a1bb-b0048864ee08",
"e03c645b-017d-4946-b65c-491eeeb6888b",
"e1676371-0d3e-44cc-a6e4-c4f33af878a9",
"bc60dfc9-f180-48b0-873c-ba7629d4f6d8",
]
@pytest.mark.parametrize(
"worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60")
)
def test_get_transcriptions(worker_version):
"""
Assert transcriptions retrieval output against verified results
"""
element_id = "a3bf4b60-a149-49b4-80dd-5fbe27137efa"
transcriptions = get_transcriptions(
element_id=element_id,
transcription_worker_version=worker_version,
)
# Check number of results
assert len(transcriptions) == 1
transcription = transcriptions.pop()
assert isinstance(transcription, Transcription)
# Common keys
assert transcription.text == "[ T 8º SUP 26200"
# Differences
if worker_version:
assert transcription.id == "3bd248d6-998a-4579-a00c-d4639f3825aa"
else:
assert transcription.id == "c551960a-0f82-4779-b975-77a457bcf273"
@pytest.mark.parametrize(
"worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965")
)
def test_get_transcription_entities(worker_version):
transcription_id = "3bd248d6-998a-4579-a00c-d4639f3825aa"
entities = get_transcription_entities(
transcription_id=transcription_id,
entity_worker_version=worker_version,
)
# Check number of results
assert len(entities) == 1
transcription_entity = entities.pop()
assert isinstance(transcription_entity, Entity)
# Differences
if worker_version:
assert transcription_entity.type == "cote"
assert transcription_entity.value == "T 8 º SUP 26200"
assert transcription_entity.offset == 2
assert transcription_entity.length == 14
else:
assert transcription_entity.type == "Cote"
assert transcription_entity.value == "[ T 8º SUP 26200"
assert transcription_entity.offset == 0
assert transcription_entity.length == 16
# -*- coding: utf-8 -*-
import pytest
from arkindex.mock import MockApiClient
from dan.datasets.extract.extract_from_arkindex import ArkindexExtractor, Entity
from dan.datasets.extract.utils import insert_token
@pytest.fixture
def arkindex_extractor():
return ArkindexExtractor(
client=MockApiClient(), split_names=["train", "val", "test"]
)
from dan.datasets.extract.extract import ArkindexExtractor, Entity
from dan.datasets.extract.utils import EntityType, insert_token
@pytest.mark.parametrize(
"text,count,offset,length,expected",
(
("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1 Ⓘ16 janvier 1611"),
("ⓘn°1 Ⓘ16 janvier 1611", 2, 4, 15, "ⓘn°1 Ⓘⓘ16 janvier 1611Ⓘ"),
("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1Ⓘ 16 janvier 1611"),
("ⓘn°1Ⓘ 16 janvier 1611", 2, 4, 15, "ⓘn°1Ⓘ ⓘ16 janvier 1611Ⓘ"),
),
)
def test_insert_token(text, count, offset, length, expected):
start_token, end_token = "", ""
assert (
insert_token(text, count, start_token, end_token, offset, length)[0] == expected
insert_token(text, count, EntityType(start="", end=""), offset, length)
== expected
)
......@@ -34,87 +26,27 @@ def test_insert_token(text, count, offset, length, expected):
(
"n°1 16 janvier 1611",
[
Entity(offset=0, length=3, label="P"),
Entity(offset=4, length=15, label="D"),
Entity(
offset=0,
length=3,
type="P",
value="n°1",
),
Entity(
offset=4,
length=15,
type="D",
value="16 janvier 1611",
),
],
"ⓟn°1 Ⓟⓓ16 janvier 1611Ⓓ",
"ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ",
),
),
)
def test_reconstruct_text(arkindex_extractor, text, entities, expected):
def test_reconstruct_text(text, entities, expected):
arkindex_extractor = ArkindexExtractor()
arkindex_extractor.tokens = {
"P": {"start": "", "end": ""},
"D": {"start": "", "end": ""},
"P": EntityType(start="", end=""),
"D": EntityType(start="", end=""),
}
assert arkindex_extractor.reconstruct_text(text, entities) == expected
@pytest.mark.parametrize(
"text,offset,length,label,expected",
(
(" n°1 16 janvier 1611 ", None, None, None, "n°1 16 janvier 1611"),
("n°1 16 janvier 1611", 0, 3, "P", "ⓟn°1 Ⓟ16 janvier 1611"),
),
)
def test_extract_transcription(
arkindex_extractor, text, offset, length, label, expected
):
element = {"id": "element_id"}
transcription = {"id": "transcription_id", "text": text}
arkindex_extractor.client.add_response(
"ListTranscriptions",
id="element_id",
worker_version=None,
response={"count": 1, "results": [transcription]},
)
if label:
arkindex_extractor.load_entities = True
arkindex_extractor.tokens = {
"P": {"start": "", "end": ""},
}
arkindex_extractor.client.add_response(
"ListTranscriptionEntities",
id="transcription_id",
worker_version=None,
response=[
{
"entity": {"id": "entity_id", "metas": {"subtype": label}},
"offset": offset,
"length": length,
"worker_version": None,
"worker_run_id": None,
}
],
)
assert arkindex_extractor.extract_transcription(element) == expected
@pytest.mark.parametrize(
"offset,length,label",
((0, 3, "P"),),
)
def test_extract_entities(arkindex_extractor, offset, length, label):
transcription = {"id": "transcription_id"}
arkindex_extractor.tokens = {
"P": {"start": "", "end": ""},
}
arkindex_extractor.client.add_response(
"ListTranscriptionEntities",
id="transcription_id",
worker_version=None,
response=[
{
"entity": {"id": "entity_id", "metas": {"subtype": label}},
"offset": offset,
"length": length,
"worker_version": None,
"worker_run_id": None,
}
],
)
assert arkindex_extractor.extract_entities(transcription) == [
Entity(offset=offset, length=length, label=label)
]