Commit a520d362 authored by Eva Bardou

Move code

parent 40b30992
1 merge request: !8 New DatasetExtractor using a DatasetWorker
This commit is part of merge request !8.
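The merge request moves dataset extraction onto a generic `DatasetWorker` base class (defined in the worker code below) that `DatasetExtractor` subclasses. As a rough, illustrative sketch of the intended usage, not part of the commit, a concrete worker only needs to override `process_dataset`, which is called once per Arkindex dataset; the subclass name and description string here are invented:

import logging

from arkindex_worker.models import Dataset
from worker_generic_training_dataset.worker import DatasetWorker

logger = logging.getLogger(__name__)


class MyDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # Called once per dataset returned by list_datasets()
        for set_name, elements in self.list_dataset_elements_per_set(dataset):
            logger.info(f"Set {set_name}: {len(elements)} element(s)")


if __name__ == "__main__":
    MyDatasetWorker(description="Example dataset worker").run()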
@@ -47,11 +47,6 @@ setup(
author="Teklia",
author_email="contact@teklia.com",
install_requires=parse_requirements(),
entry_points={
"console_scripts": [
f"{COMMAND}={MODULE}.worker:main",
"worker-generic-training-dataset-new=worker_generic_training_dataset.dataset_worker:main",
]
},
entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]},
packages=find_packages(),
)
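This setup.py hunk drops the second console script, which pointed at the `dataset_worker` module whose code this commit moves into worker.py, leaving a single entry point. For illustration only, the remaining f-string presumably expands to something close to the following; the exact command name depends on the `COMMAND` constant defined earlier in setup.py and is assumed here from the removed entry:

entry_points = {
    "console_scripts": [
        # Assumed expansion of f"{COMMAND}={MODULE}.worker:main"
        "worker-generic-training-dataset=worker_generic_training_dataset.worker:main",
    ]
}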
# -*- coding: utf-8 -*-
import logging
import sys
import tempfile
from argparse import Namespace
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import Iterator, List, Optional, Tuple
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
CachedClassification,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
create_tables,
create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
from worker_generic_training_dataset.db import (
list_classifications,
list_transcription_entities,
list_transcriptions,
)
from worker_generic_training_dataset.utils import build_image_url
from worker_generic_training_dataset.worker import (
BULK_BATCH_SIZE,
DEFAULT_TRANSCRIPTION_ORIENTATION,
)
logger: logging.Logger = logging.getLogger(__name__)
class DatasetWorker(BaseWorker, DatasetMixin):
def __init__(
self,
description: str = "Arkindex Elements Worker",
support_cache: bool = False,
generator: bool = False,
):
super().__init__(description, support_cache)
self.parser.add_argument(
"--dataset",
type=UUID,
nargs="+",
help="One or more Arkindex dataset ID",
)
self.generator = generator
def list_dataset_elements_per_set(
self, dataset: Dataset
) -> List[Tuple[str, List[Element]]]:
"""
Calls `list_dataset_elements` but returns results grouped by Set
"""
def format_element(element):
return Element.get(Element.id == element[1].id)
def format_set(set):
return (set[0], list(map(format_element, list(set[1]))))
return list(
map(
format_set,
groupby(
sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
key=itemgetter(0),
),
)
)
def process_dataset(self, dataset: Dataset):
"""
Override this method to implement your worker; it is called once for each Arkindex dataset to process.
:param dataset: The dataset to process.
"""
def list_datasets(self) -> List[Dataset] | List[str]:
"""
Calls `list_process_datasets` if the worker is not in read-only mode,
otherwise simply returns the list of dataset IDs provided via the CLI
"""
if self.is_read_only:
return list(map(str, self.args.dataset))
return self.list_process_datasets()
def run(self):
self.configure()
datasets: List[Dataset] | List[str] = self.list_datasets()
if not datasets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(datasets)
failed = 0
for i, item in enumerate(datasets, start=1):
dataset = None
try:
if not self.is_read_only:
# Just use the result of list_datasets as the dataset
dataset = item
else:
# Load dataset using the Arkindex API
dataset = Dataset(**self.request("RetrieveDataset", id=item))
if self.generator:
assert (
dataset.state == DatasetState.Open.value
), "When generating a new dataset, its state should be Open"
else:
assert (
dataset.state == DatasetState.Complete.value
), "When processing an existing dataset, its state should be Complete"
if self.generator:
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
# Process the dataset
self.process_dataset(dataset)
if self.generator:
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while retrieving, processing or patching the state for this dataset.
failed += 1
# Handle the case where we failed retrieving the dataset
dataset_id = dataset.id if dataset else item
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
else:
message = (
f"Failed running worker on dataset {dataset_id}: {repr(e)}"
)
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
if dataset and self.generator:
# Try to update the state to Error regardless of the response
try:
self.update_dataset_state(dataset, DatasetState.Error)
except Exception:
pass
if failed:
logger.error(
"Ran on {} dataset: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
class DatasetExtractor(DatasetWorker):
def configure(self) -> None:
self.args: Namespace = self.parser.parse_args()
if self.is_read_only:
super().configure_for_developers()
else:
super().configure()
if self.user_configuration:
logger.info("Overriding with user_configuration")
self.config.update(self.user_configuration)
# Download corpus
self.download_latest_export()
# Initialize db that will be written
self.configure_cache()
# CachedImage downloaded and created in DB
self.cached_images = dict()
# Where to save the downloaded images
self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
logger.info(f"Images will be saved at `{self.image_folder}`.")
def configure_cache(self) -> None:
"""
Create an SQLite database compatible with base-worker cache and initialize it.
"""
self.use_cache = True
self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
# Remove previous execution result if present
self.cache_path.unlink(missing_ok=True)
init_cache_db(self.cache_path)
create_version_table()
create_tables()
def download_latest_export(self) -> None:
"""
Download the latest export of the current corpus.
Export must be in `"done"` state.
"""
try:
exports = list(
self.api_client.paginate(
"ListExports",
id=self.corpus_id,
)
)
except ErrorResponse as e:
logger.error(
f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
# Find the latest that is in "done" state
exports: List[dict] = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=itemgetter("updated"),
reverse=True,
)
assert (
len(exports) > 0
), f"No available exports found for the corpus {self.corpus_id}."
# Download latest export
try:
export_id: str = exports[0]["id"]
logger.info(f"Downloading export ({export_id})...")
self.export: _TemporaryFileWrapper = self.api_client.request(
"DownloadExport",
id=export_id,
)
logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
open_database(self.export.name)
except ErrorResponse as e:
logger.error(
f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
def insert_classifications(self, element: CachedElement) -> None:
logger.info("Listing classifications")
classifications: list[CachedClassification] = [
CachedClassification(
id=classification.id,
element=element,
class_name=classification.class_name,
confidence=classification.confidence,
state=classification.state,
worker_run_id=classification.worker_run,
)
for classification in list_classifications(element.id)
]
if classifications:
logger.info(f"Inserting {len(classifications)} classification(s)")
with cache_database.atomic():
CachedClassification.bulk_create(
model_list=classifications,
batch_size=BULK_BATCH_SIZE,
)
def insert_transcriptions(
self, element: CachedElement
) -> List[CachedTranscription]:
logger.info("Listing transcriptions")
transcriptions: list[CachedTranscription] = [
CachedTranscription(
id=transcription.id,
element=element,
text=transcription.text,
confidence=transcription.confidence,
orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
worker_version_id=transcription.worker_version,
worker_run_id=transcription.worker_run,
)
for transcription in list_transcriptions(element.id)
]
if transcriptions:
logger.info(f"Inserting {len(transcriptions)} transcription(s)")
with cache_database.atomic():
CachedTranscription.bulk_create(
model_list=transcriptions,
batch_size=BULK_BATCH_SIZE,
)
return transcriptions
def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
logger.info("Listing entities")
entities: List[CachedEntity] = []
transcription_entities: List[CachedTranscriptionEntity] = []
for transcription in transcriptions:
for transcription_entity in list_transcription_entities(transcription.id):
entity = CachedEntity(
id=transcription_entity.entity.id,
type=transcription_entity.entity.type.name,
name=transcription_entity.entity.name,
validated=transcription_entity.entity.validated,
metas=transcription_entity.entity.metas,
worker_run_id=transcription_entity.entity.worker_run,
)
entities.append(entity)
transcription_entities.append(
CachedTranscriptionEntity(
id=transcription_entity.id,
transcription=transcription,
entity=entity,
offset=transcription_entity.offset,
length=transcription_entity.length,
confidence=transcription_entity.confidence,
worker_run_id=transcription_entity.worker_run,
)
)
if entities:
# First insert entities since they are foreign keys on transcription entities
logger.info(f"Inserting {len(entities)} entities")
with cache_database.atomic():
CachedEntity.bulk_create(
model_list=entities,
batch_size=BULK_BATCH_SIZE,
)
if transcription_entities:
# Insert transcription entities
logger.info(
f"Inserting {len(transcription_entities)} transcription entities"
)
with cache_database.atomic():
CachedTranscriptionEntity.bulk_create(
model_list=transcription_entities,
batch_size=BULK_BATCH_SIZE,
)
def insert_element(
self, element: Element, parent_id: Optional[UUID] = None
) -> None:
"""
Insert the given element in the cache database.
Its image will also be saved to disk, if it wasn't already.
The insertion of an element includes:
- its classifications
- its transcriptions
- its transcriptions' entities (both Entity and TranscriptionEntity)
:param element: Element to insert.
:param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
"""
logger.info(f"Processing element ({element.id})")
if element.image and element.image.id not in self.cached_images:
# Download image
logger.info("Downloading image")
download_image(url=build_image_url(element)).save(
self.image_folder / f"{element.image.id}.jpg"
)
# Insert image
logger.info("Inserting image")
# Store the image in case other elements use it as well
with cache_database.atomic():
self.cached_images[element.image.id] = CachedImage.create(
id=element.image.id,
width=element.image.width,
height=element.image.height,
url=element.image.url,
)
# Insert element
logger.info("Inserting element")
with cache_database.atomic():
cached_element: CachedElement = CachedElement.create(
id=element.id,
parent_id=parent_id,
type=element.type,
image=self.cached_images[element.image.id] if element.image else None,
polygon=element.polygon,
rotation_angle=element.rotation_angle,
mirrored=element.mirrored,
worker_version_id=element.worker_version,
worker_run_id=element.worker_run,
confidence=element.confidence,
)
# Insert classifications
self.insert_classifications(cached_element)
# Insert transcriptions
transcriptions: List[CachedTranscription] = self.insert_transcriptions(
cached_element
)
# Insert entities
self.insert_entities(transcriptions)
def process_set(self, set_name: str, elements: List[Element]) -> None:
logger.info(
f"Filling the cache with information from elements in the set {set_name}"
)
# First list all pages
nb_elements: int = len(elements)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
# Insert page
self.insert_element(element)
# List children
children = list_children(element.id)
nb_children: int = children.count()
for child_idx, child in enumerate(children, start=1):
logger.info(f"Processing child ({child_idx}/{nb_children})")
# Insert child
self.insert_element(child, parent_id=element.id)
def process_dataset(self, dataset: Dataset):
# Iterate over given sets
for set_name, elements in self.list_dataset_elements_per_set(dataset):
self.process_set(set_name, elements)
# TAR + ZSTD Image folder and store as task artifact
zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
logger.info(f"Compressing the images to {zstd_archive_path}")
create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
def main():
DatasetExtractor(
description="Fill base-worker cache with information about dataset and extract images",
generator=True,
).run()
if __name__ == "__main__":
main()
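As a standalone sketch of the grouping pattern used by `list_dataset_elements_per_set` above, with plain tuples standing in for the `(set_name, element)` pairs yielded by `list_dataset_elements`:

from itertools import groupby
from operator import itemgetter

pairs = [("train", "element-1"), ("validation", "element-3"), ("train", "element-2")]

# groupby only merges consecutive items sharing a key, hence the sort first
grouped = [
    (set_name, [element for _, element in items])
    for set_name, items in groupby(sorted(pairs, key=itemgetter(0)), key=itemgetter(0))
]

assert grouped == [
    ("train", ["element-1", "element-2"]),
    ("validation", ["element-3"]),
]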
# -*- coding: utf-8 -*-
import logging
import operator
import sys
import tempfile
from argparse import Namespace
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from typing import Iterator, List, Optional, Tuple
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import Element, Image, open_database
from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
CachedClassification,
@@ -24,13 +26,14 @@ from arkindex_worker.cache import (
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
from worker_generic_training_dataset.db import (
list_classifications,
list_transcription_entities,
list_transcriptions,
retrieve_element,
)
from worker_generic_training_dataset.utils import build_image_url
@@ -40,7 +43,142 @@ BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
class DatasetExtractor(BaseWorker):
class DatasetWorker(BaseWorker, DatasetMixin):
def __init__(
self,
description: str = "Arkindex Elements Worker",
support_cache: bool = False,
generator: bool = False,
):
super().__init__(description, support_cache)
self.parser.add_argument(
"--dataset",
type=UUID,
nargs="+",
help="One or more Arkindex dataset ID",
)
self.generator = generator
def list_dataset_elements_per_set(
self, dataset: Dataset
) -> List[Tuple[str, List[Element]]]:
"""
Calls `list_dataset_elements` but returns results grouped by Set
"""
def format_element(element):
return Element.get(Element.id == element[1].id)
def format_set(set):
return (set[0], list(map(format_element, list(set[1]))))
return list(
map(
format_set,
groupby(
sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
key=itemgetter(0),
),
)
)
def process_dataset(self, dataset: Dataset):
"""
Override this method to implement your worker; it is called once for each Arkindex dataset to process.
:param dataset: The dataset to process.
"""
def list_datasets(self) -> List[Dataset] | List[str]:
"""
Calls `list_process_datasets` if the worker is not in read-only mode,
otherwise simply returns the list of dataset IDs provided via the CLI
"""
if self.is_read_only:
return list(map(str, self.args.dataset))
return self.list_process_datasets()
def run(self):
self.configure()
datasets: List[Dataset] | List[str] = self.list_datasets()
if not datasets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(datasets)
failed = 0
for i, item in enumerate(datasets, start=1):
dataset = None
try:
if not self.is_read_only:
# Just use the result of list_datasets as the dataset
dataset = item
else:
# Load dataset using the Arkindex API
dataset = Dataset(**self.request("RetrieveDataset", id=item))
if self.generator:
assert (
dataset.state == DatasetState.Open.value
), "When generating a new dataset, its state should be Open"
else:
assert (
dataset.state == DatasetState.Complete.value
), "When processing an existing dataset, its state should be Complete"
if self.generator:
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
# Process the dataset
self.process_dataset(dataset)
if self.generator:
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while retrieving, processing or patching the state for this dataset.
failed += 1
# Handle the case where we failed retrieving the dataset
dataset_id = dataset.id if dataset else item
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
else:
message = (
f"Failed running worker on dataset {dataset_id}: {repr(e)}"
)
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
if dataset and self.generator:
# Try to update the state to Error regardless of the response
try:
self.update_dataset_state(dataset, DatasetState.Error)
except Exception:
pass
if failed:
logger.error(
"Ran on {} dataset: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
class DatasetExtractor(DatasetWorker):
def configure(self) -> None:
self.args: Namespace = self.parser.parse_args()
if self.is_read_only:
@@ -52,9 +190,6 @@ class DatasetExtractor(BaseWorker):
logger.info("Overriding with user_configuration")
self.config.update(self.user_configuration)
# Read process information
self.read_training_related_information()
# Download corpus
self.download_latest_export()
@@ -68,28 +203,6 @@ class DatasetExtractor(BaseWorker):
self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
logger.info(f"Images will be saved at `{self.image_folder}`.")
def read_training_related_information(self) -> None:
"""
Read from process information
- train_folder_id
- validation_folder_id
- test_folder_id (optional)
"""
logger.info("Retrieving information from process_information")
train_folder_id = self.process_information.get("train_folder_id")
assert train_folder_id, "A training folder id is necessary to use this worker"
self.training_folder_id = UUID(train_folder_id)
val_folder_id = self.process_information.get("validation_folder_id")
assert val_folder_id, "A validation folder id is necessary to use this worker"
self.validation_folder_id = UUID(val_folder_id)
test_folder_id = self.process_information.get("test_folder_id")
self.testing_folder_id: UUID | None = (
UUID(test_folder_id) if test_folder_id else None
)
def configure_cache(self) -> None:
"""
Create an SQLite database compatible with base-worker cache and initialize it.
@@ -126,7 +239,7 @@ class DatasetExtractor(BaseWorker):
# Find the latest that is in "done" state
exports: List[dict] = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=operator.itemgetter("updated"),
key=itemgetter("updated"),
reverse=True,
)
assert (
@@ -301,52 +414,34 @@ class DatasetExtractor(BaseWorker):
# Insert entities
self.insert_entities(transcriptions)
def process_split(self, split_name: str, split_id: UUID) -> None:
"""
Insert all elements under the given parent folder (all queries are recursive).
- `page` elements are linked to this folder (via parent_id foreign key)
- `page` element children are linked to their `page` parent (via parent_id foreign key)
"""
def process_set(self, set_name: str, elements: List[Element]) -> None:
logger.info(
f"Filling the Base-Worker cache with information from children under element ({split_id})"
f"Filling the cache with information from elements in the set {set_name}"
)
# Fill cache
# Retrieve parent and create parent
parent: Element = retrieve_element(split_id)
self.insert_element(parent)
# First list all pages
pages = list_children(split_id).join(Image).where(Element.type == "page")
nb_pages: int = pages.count()
for idx, page in enumerate(pages, start=1):
logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
nb_elements: int = len(elements)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
# Insert page
self.insert_element(page, parent_id=split_id)
self.insert_element(element)
# List children
children = list_children(page.id)
children = list_children(element.id)
nb_children: int = children.count()
for child_idx, child in enumerate(children, start=1):
logger.info(f"Processing child ({child_idx}/{nb_children})")
# Insert child
self.insert_element(child, parent_id=page.id)
def run(self):
self.configure()
self.insert_element(child, parent_id=element.id)
# Iterate over given split
for split_name, split_id in [
("Train", self.training_folder_id),
("Validation", self.validation_folder_id),
("Test", self.testing_folder_id),
]:
if not split_id:
continue
self.process_split(split_name, split_id)
def process_dataset(self, dataset: Dataset):
# Iterate over given sets
for set_name, elements in self.list_dataset_elements_per_set(dataset):
self.process_set(set_name, elements)
# TAR + ZSTD Image folder and store as task artifact
zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
logger.info(f"Compressing the images to {zstd_archive_path}")
create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
@@ -354,7 +449,7 @@ class DatasetExtractor(BaseWorker):
def main():
DatasetExtractor(
description="Fill base-worker cache with information about dataset and extract images",
support_cache=True,
generator=True,
).run()
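A self-contained sketch of the export-selection logic used in `download_latest_export`: keep only exports in the "done" state and pick the most recently updated one. The export dicts below are invented for illustration; the real ones come from the ListExports endpoint.

from operator import itemgetter

exports = [
    {"id": "b", "state": "running", "updated": "2023-05-03T10:00:00Z"},
    {"id": "a", "state": "done", "updated": "2023-05-02T10:00:00Z"},
    {"id": "c", "state": "done", "updated": "2023-05-01T10:00:00Z"},
]

# Keep finished exports only, newest first; the worker downloads exports[0]
done_exports = sorted(
    (export for export in exports if export["state"] == "done"),
    key=itemgetter("updated"),
    reverse=True,
)

assert done_exports[0]["id"] == "a"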