
Implement extraction command

Merged Yoann Schneider requested to merge implement-extraction-command into main
5 files  +242  −68
@@ -10,6 +10,8 @@ import pathlib
import random
import uuid
from collections import defaultdict
+ from pathlib import Path
+ from typing import List, NamedTuple
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
@@ -28,6 +30,8 @@ IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
MANUAL_SOURCE = "manual"
+ Entity = NamedTuple("Entity", offset=int, length=int, label=str)
def parse_worker_version(worker_version_id):
if worker_version_id == MANUAL_SOURCE:
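The Entity tuple added above uses typing.NamedTuple's functional keyword syntax. A quick standalone sketch of how such a tuple behaves (the field values here are made up for illustration):

    from typing import NamedTuple

    Entity = NamedTuple("Entity", offset=int, length=int, label=str)

    # Hypothetical entity: "John" at the start of a transcription, tagged "name".
    ent = Entity(offset=0, length=4, label="name")
    print(ent.offset, ent.length, ent.label)  # 0 4 name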
@@ -110,6 +114,13 @@ def add_extract_parser(subcommands) -> None:
help="Type of elements to retrieve",
required=True,
)
+ parser.add_argument(
+ "--parent-element-type",
+ type=str,
+ help="Type of the parent element containing the data.",
+ required=False,
+ default="page",
+ )
parser.add_argument(
"--output",
type=pathlib.Path,
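The new --parent-element-type flag is optional and falls back to "page", so the previous behaviour is preserved. A minimal, standalone argparse sketch showing the default in action:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--parent-element-type",
        type=str,
        required=False,
        default="page",
    )

    print(parser.parse_args([]).parent_element_type)  # page
    print(parser.parse_args(["--parent-element-type", "folder"]).parent_element_type)  # folder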
@@ -186,21 +197,23 @@ class ArkindexExtractor:
def __init__(
self,
- client,
- folders,
- element_type,
- split_names,
- output,
- load_entities,
- tokens,
- use_existing_split,
- transcription_worker_version,
- entity_worker_version,
- train_prob,
- val_prob,
+ client: ArkindexClient,
+ folders: list = [],
+ element_type: list = [],
+ parent_element_type: list = ["page"],
+ split_names: list = [],
+ output: Path = None,
+ load_entities: bool = None,
+ tokens: Path = None,
+ use_existing_split: bool = None,
+ transcription_worker_version: str = None,
+ entity_worker_version: str = None,
+ train_prob: float = None,
+ val_prob: float = None,
) -> None:
self.client = client
self.element_type = element_type
+ self.parent_element_type = parent_element_type
self.split_names = split_names
self.output = output
self.load_entities = load_entities
@@ -213,7 +226,11 @@ class ArkindexExtractor:
self.get_subsets(folders)
- def get_subsets(self, folders):
+ def get_subsets(self, folders: list):
+ """
+ Assign each folder to its split if it's already known.
+ Assign None if it's unknown.
+ """
if self.use_existing_split:
self.subsets = [
(folder, split) for folder, split in zip(folders, self.split_names)
@@ -221,35 +238,49 @@ class ArkindexExtractor:
else:
self.subsets = [(folder, None) for folder in folders]
- def assign_random_split(self):
+ def _assign_random_split(self):
"""
- assuming train_prob + valid_prob + test_prob = 1
+ Yields a randomly chosen split for an element.
+ Assumes that train_prob + valid_prob + test_prob = 1
"""
prob = random.random()
if prob <= self.train_prob:
- return self.split_names[0]
+ yield self.split_names[0]
elif prob <= self.train_prob + self.val_prob:
- return self.split_names[1]
+ yield self.split_names[1]
else:
- return self.split_names[2]
+ yield self.split_names[2]
- def extract_entities(self, transcription):
- entities = self.client.request(
+ def get_random_split(self):
+ return next(self._assign_random_split())
+ def extract_entities(self, transcription: dict):
+ entities = self.client.paginate(
"ListTranscriptionEntities",
id=transcription["id"],
worker_version=self.entity_worker_version,
)
if entities["count"] == 0:
if entities is None:
logger.warning(
f"No entities found on transcription ({transcription['id']})."
)
return
- else:
- text = transcription["text"]
+ return [
+ Entity(
+ offset=entity["offset"],
+ length=entity["length"],
+ label=entity["entity"]["metas"]["subtype"],
+ )
+ for entity in entities
+ ]
+ def reconstruct_text(self, text: str, entities: List[Entity]):
+ """
+ Insert tokens delimiting the start/end of each entity on the transcription.
+ """
count = 0
for entity in entities["results"]:
matching_tokens = self.tokens[entity["entity"]["metas"]["subtype"]]
for entity in entities:
matching_tokens = self.tokens[entity.label]
start_token, end_token = (
matching_tokens["start"],
matching_tokens["end"],
@@ -259,21 +290,24 @@ class ArkindexExtractor:
count,
start_token,
end_token,
offset=entity["offset"],
length=entity["length"],
offset=entity.offset,
length=entity.length,
)
return text
def extract_transcription(
self,
- element,
+ element: dict,
):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = self.client.request(
"ListTranscriptions",
id=element["id"],
worker_version=self.transcription_worker_version,
)
if transcriptions["count"] != 1:
logger.warning(
f"More than one transcription found on element ({element['id']}) with this config."
@@ -282,23 +316,26 @@ class ArkindexExtractor:
transcription = transcriptions["results"].pop()
if self.load_entities:
- return self.extract_entities(transcription)
+ entities = self.extract_entities(transcription)
+ return self.reconstruct_text(transcription["text"], entities)
else:
return transcription["text"].strip()
def process_element(
self,
- element,
- split,
+ element: dict,
+ split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(
element,
)
if not text:
- logging.warning(
- f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)"
- )
+ logging.warning(f"Skipping {element['id']}")
else:
logging.info(f"Processed {element['type']} {element['id']}")
@@ -315,29 +352,32 @@ class ArkindexExtractor:
save_image(im_path, image)
except Exception:
logger.error(f"Couldn't retrieve image of element ({element['id']}")
- pass
+ raise
return element["id"]
- def process_page(
+ def process_parent(
self,
- page,
- split,
+ parent: dict,
+ split: str,
):
- # Extract only pages
+ """
+ Extract data from a parent element.
+ Depending on the given types, process either the parent itself or its children elements.
+ """
data = defaultdict(list)
if self.element_type == ["page"]:
data["page"] = [
if self.element_type == [parent["type"]]:
data[self.element_type[0]] = [
self.process_element(
- page,
+ parent,
split,
)
]
- # Extract page's children elements (text_zone, text_line)
+ # Extract children elements
else:
for element_type in self.element_type:
for element in self.client.paginate(
"ListElementChildren",
id=page["id"],
id=parent["id"],
type=element_type,
recursive=True,
):
@@ -355,16 +395,19 @@ class ArkindexExtractor:
for subset_id, subset_split in self.subsets:
page_idx = 0
# Iterate over the pages to create splits at page level.
- for page in tqdm(
+ for parent in tqdm(
self.client.paginate(
- "ListElementChildren", id=subset_id, type="page", recursive=True
+ "ListElementChildren",
+ id=subset_id,
+ type=self.parent_element_type,
+ recursive=True,
)
):
page_idx += 1
- split = subset_split or self.assign_random_split()
+ split = subset_split or self.get_random_split()
- split_dict[split][page["id"]] = self.process_page(
- page=page,
+ split_dict[split][parent["id"]] = self.process_parent(
+ parent=parent,
split=split,
)
@@ -374,6 +417,7 @@ class ArkindexExtractor:
def run(
parent,
element_type,
+ parent_element_type,
output,
load_entities,
tokens,
@@ -426,6 +470,7 @@ def run(
client=client,
folders=folders,
element_type=element_type,
+ parent_element_type=parent_element_type,
split_names=split_names,
output=output,
load_entities=load_entities,
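For context on the client these calls go through: the file imports arkindex-client's documented bootstrap helpers, which read the API credentials from the environment. A minimal connection sketch (the paginate call shown is the same operation used throughout the extractor):

    from arkindex import ArkindexClient, options_from_env

    # options_from_env() typically picks up ARKINDEX_API_URL and ARKINDEX_API_TOKEN.
    client = ArkindexClient(**options_from_env())

    # Yields child elements one by one across pages of API results.
    for child in client.paginate(
        "ListElementChildren", id="some-element-uuid", type="page", recursive=True
    ):
        print(child["id"])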