Verified commit 231047c5 authored by Yoann Schneider

Refactor and implement API version of the worker

parent e4452dc1
Merge request !25 (Draft): Refactor and implement API version of the worker
Pipeline #170921 passed
@@ -7,3 +7,9 @@ workers:
     type: data-extract
     docker:
       build: Dockerfile
+  - slug: generic-training-dataset-api
+    name: Generic Training Dataset Extractor (API)
+    type: data-extract
+    docker:
+      build: Dockerfile
+      command: worker-generic-training-dataset-api
\ No newline at end of file
@@ -47,6 +47,11 @@ setup(
     author="Teklia",
     author_email="contact@teklia.com",
     install_requires=parse_requirements(),
-    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]},
+    entry_points={
+        "console_scripts": [
+            f"{COMMAND}={MODULE}.worker:main",
+            f"{COMMAND}-api={MODULE}.worker:main",
+        ]
+    },
     packages=find_packages(),
 )
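For reference (an illustrative note, not part of this commit): with the COMMAND and MODULE constants defined earlier in setup.py, installing the package now generates two console scripts. A minimal sketch of the resulting mapping, assuming COMMAND matches the worker slug and MODULE the package name:

# Illustrative sketch only, not part of the diff.
# Assuming COMMAND == "worker-generic-training-dataset" and
# MODULE == "worker_generic_training_dataset", `pip install .` exposes:
#   worker-generic-training-dataset      -> worker_generic_training_dataset.worker:main
#   worker-generic-training-dataset-api  -> worker_generic_training_dataset.worker:main
# As written, both scripts resolve to the same worker:main callable.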
@@ -15,11 +15,11 @@ from arkindex_worker.cache import (
     CachedTranscription,
     CachedTranscriptionEntity,
 )
-from worker_generic_training_dataset.worker import DatasetExtractor
+from worker_generic_training_dataset.from_sql import DatasetExtractorFromSQL
 
 
 def test_process_split(tmp_path, downloaded_images):
-    worker = DatasetExtractor()
+    worker = DatasetExtractorFromSQL()
     # Parse some arguments
     worker.args = Namespace(database=None)
     worker.data_folder_path = tmp_path
# -*- coding: utf-8 -*-
import contextlib
import json
import logging
import sys
import tempfile
import uuid
from itertools import groupby
from operator import attrgetter
from pathlib import Path
from typing import List, Optional
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import Element, WorkerRun, WorkerVersion
from arkindex_worker.cache import (
CachedClassification,
CachedDataset,
CachedDatasetElement,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
create_tables,
create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Set
from arkindex_worker.models import Transcription as ArkindexTranscription
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker import DatasetWorker
from arkindex_worker.worker.dataset import DatasetState
from peewee import CharField
from worker_generic_training_dataset.utils import build_image_url
logger: logging.Logger = logging.getLogger(__name__)
BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
return instance.id if instance else None
class Extractor(DatasetWorker):
def configure_storage(self) -> None:
self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
self.data_folder_path = Path(self.data_folder.name)
# Initialize db that will be written
self.configure_cache()
# CachedImage downloaded and created in DB
self.cached_images = dict()
# Where to save the downloaded images
self.images_folder = self.data_folder_path / "images"
self.images_folder.mkdir(parents=True)
logger.info(f"Images will be saved at `{self.images_folder}`.")
def configure_cache(self) -> None:
"""
Create an SQLite database compatible with base-worker cache and initialize it.
"""
self.use_cache = True
self.cache_path: Path = self.data_folder_path / "db.sqlite"
logger.info(f"Cached database will be saved at `{self.cache_path}`.")
init_cache_db(self.cache_path)
create_version_table()
create_tables()
def list_classifications(self, element_id: UUID):
raise NotImplementedError
def list_transcriptions(self, element: ArkindexElement, **kwargs):
raise NotImplementedError
def list_transcription_entities(
self, transcription: ArkindexTranscription, **kwargs
):
raise NotImplementedError
def list_element_children(self, element: ArkindexElement, **kwargs):
raise NotImplementedError
def insert_classifications(self, element: CachedElement) -> None:
logger.info("Listing classifications")
classifications: list[CachedClassification] = [
CachedClassification(
id=classification.id,
element=element,
class_name=classification.class_name,
confidence=classification.confidence,
state=classification.state,
worker_run_id=get_object_id(classification.worker_run),
)
for classification in self.list_classifications(element.id)
]
if classifications:
logger.info(f"Inserting {len(classifications)} classification(s)")
with cache_database.atomic():
CachedClassification.bulk_create(
model_list=classifications,
batch_size=BULK_BATCH_SIZE,
)
def insert_transcriptions(
self, element: CachedElement
) -> List[CachedTranscription]:
logger.info("Listing transcriptions")
transcriptions: list[CachedTranscription] = [
CachedTranscription(
id=transcription.id,
element=element,
text=transcription.text,
confidence=transcription.confidence,
orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
worker_version_id=get_object_id(transcription.worker_version),
worker_run_id=get_object_id(transcription.worker_run),
)
for transcription in self.list_transcriptions(
ArkindexElement(id=element.id)
)
]
if transcriptions:
logger.info(f"Inserting {len(transcriptions)} transcription(s)")
with cache_database.atomic():
CachedTranscription.bulk_create(
model_list=transcriptions,
batch_size=BULK_BATCH_SIZE,
)
return transcriptions
def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
logger.info("Listing entities")
entities: List[CachedEntity] = []
transcription_entities: List[CachedTranscriptionEntity] = []
for transcription in transcriptions:
for transcription_entity in self.list_transcription_entities(
ArkindexTranscription(id=transcription.id)
):
entity = CachedEntity(
id=transcription_entity.entity.id,
type=transcription_entity.entity.type.name,
name=transcription_entity.entity.name,
validated=transcription_entity.entity.validated,
metas=transcription_entity.entity.metas,
worker_run_id=get_object_id(transcription_entity.entity.worker_run),
)
entities.append(entity)
transcription_entities.append(
CachedTranscriptionEntity(
id=transcription_entity.id,
transcription=transcription,
entity=entity,
offset=transcription_entity.offset,
length=transcription_entity.length,
confidence=transcription_entity.confidence,
worker_run_id=get_object_id(transcription_entity.worker_run),
)
)
if entities:
# First insert entities since they are foreign keys on transcription entities
logger.info(f"Inserting {len(entities)} entities")
with cache_database.atomic():
CachedEntity.bulk_create(
model_list=entities,
batch_size=BULK_BATCH_SIZE,
)
if transcription_entities:
# Insert transcription entities
logger.info(
f"Inserting {len(transcription_entities)} transcription entities"
)
with cache_database.atomic():
CachedTranscriptionEntity.bulk_create(
model_list=transcription_entities,
batch_size=BULK_BATCH_SIZE,
)
def insert_element(
self,
element: Element,
split_name: Optional[str] = None,
parent_id: Optional[str] = None,
) -> None:
"""
Insert the given element in the cache database.
Its image will also be saved to disk, if it wasn't already.
The insertion of an element includes:
- its classifications
- its transcriptions
- its transcriptions' entities (both Entity and TranscriptionEntity)
The element will also be linked to the appropriate split in the current dataset.
:param element: Element to insert.
:param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
"""
logger.info(f"Processing element ({element.id})")
if element.image and element.image.id not in self.cached_images:
# Download image
logger.info("Downloading image")
download_image(url=build_image_url(element)).save(
self.images_folder / f"{element.image.id}.jpg"
)
# Insert image
logger.info("Inserting image")
# Store images in case some other elements use it as well
with cache_database.atomic():
self.cached_images[element.image.id] = CachedImage.create(
id=element.image.id,
width=element.image.width,
height=element.image.height,
url=element.image.url,
)
# Insert element
logger.info("Inserting element")
with cache_database.atomic():
cached_element: CachedElement = CachedElement.create(
id=element.id,
parent_id=parent_id,
type=element.type,
image=self.cached_images[element.image.id] if element.image else None,
polygon=element.polygon,
rotation_angle=element.rotation_angle,
mirrored=element.mirrored,
worker_version_id=get_object_id(element.worker_version),
worker_run_id=get_object_id(element.worker_run),
confidence=element.confidence,
)
# Insert classifications
self.insert_classifications(cached_element)
# Insert transcriptions
transcriptions: List[CachedTranscription] = self.insert_transcriptions(
cached_element
)
# Insert entities
self.insert_entities(transcriptions)
# Link the element to the dataset split
if split_name:
logger.info(
f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
)
with cache_database.atomic():
cached_element: CachedDatasetElement = CachedDatasetElement.create(
id=uuid.uuid4(),
element=cached_element,
dataset=self.cached_dataset,
set_name=split_name,
)
def process_split(self, split_name: str, elements: List[Element]) -> None:
logger.info(
f"Filling the cache with information from elements in the split {split_name}"
)
# First list all pages
nb_elements: int = len(elements)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
# Insert page
self.insert_element(element, split_name=split_name)
# List children
children = self.list_element_children(ArkindexElement(id=element.id))
for child_idx, child in enumerate(children, start=1):
logger.info(f"Processing child ({child_idx})")
# Insert child
self.insert_element(child, parent_id=element.id)
def insert_dataset(self, dataset: Dataset) -> None:
"""
Insert the given dataset in the cache database.
:param dataset: Dataset to insert.
"""
logger.info(f"Inserting dataset ({dataset.id})")
with cache_database.atomic():
self.cached_dataset = CachedDataset.create(
id=dataset.id,
name=dataset.name,
state=dataset.state,
sets=json.dumps(dataset.sets),
)
def process_dataset(self, dataset: Dataset, sets: list[Set]):
# Configure temporary storage for the dataset data (cache + images)
self.configure_storage()
# Insert dataset in cache database
self.insert_dataset(dataset)
# Iterate over given splits
for dataset_set in sets:
elements = self.list_set_elements(dataset.id, dataset_set.name)
self.process_split(dataset_set.name, elements)
# TAR + ZST the cache and the images folder, and store as task artifact
zst_archive_path: Path = self.work_dir / dataset.filepath
logger.info(f"Compressing the images to {zst_archive_path}")
create_tar_zst_archive(
source=self.data_folder_path, destination=zst_archive_path
)
self.data_folder.cleanup()
def run(self):
self.configure()
dataset_sets: list[Set] = list(self.list_sets())
grouped_sets: list[tuple[Dataset, list[Set]]] = [
(dataset, list(sets))
for dataset, sets in groupby(dataset_sets, attrgetter("dataset"))
]
if not grouped_sets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(grouped_sets)
failed = 0
for i, (dataset, sets) in enumerate(grouped_sets, start=1):
try:
assert dataset.state in [
DatasetState.Open.value,
DatasetState.Error.value,
], "When generating a new dataset, its state should be Open or Error."
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
logger.info(f"Processing {dataset} ({i}/{count})")
self.process_dataset(dataset, sets)
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while processing or patching the state for this dataset
failed += 1
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing {dataset}: {e.title} - {e.content}"
else:
message = f"Failed running worker on {dataset}: {repr(e)}"
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
# Try to update the state to Error regardless of the response
with contextlib.suppress(Exception):
self.update_dataset_state(dataset, DatasetState.Error)
message = f'Ran on {count} dataset{"s"[:count > 1]}: {count - failed} completed, {failed} failed'
if failed:
logger.error(message)
if failed >= count: # Everything failed!
sys.exit(1)
else:
logger.info(message)
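To summarize the control flow of run() above (an illustrative note, not part of this commit), the dataset state transitions it drives are:

# Illustrative comment sketch only, using the DatasetState members already referenced above.
#   Open / Error  --update_dataset_state-->  Building
#   Building      --process_dataset succeeds-->  Complete
#   Building      --any exception-->  Error (best effort, failures suppressed)
# The worker exits with status 1 when no datasets are found or when every dataset failed.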
# -*- coding: utf-8 -*-
from uuid import UUID
from arkindex_export import Classification
from arkindex_export.models import (
DatasetElement,
Element,
Entity,
EntityType,
Image,
Transcription,
TranscriptionEntity,
)
def list_dataset_elements(dataset_id: UUID, set_name: str):
return (
Element.select()
.join(Image)
.switch(Element)
.join(DatasetElement, on=DatasetElement.element)
.where(
DatasetElement.dataset == dataset_id, DatasetElement.set_name == set_name
)
)
def list_classifications(element_id: UUID):
return Classification.select().where(Classification.element == element_id)
def list_transcriptions(element_id: UUID):
return Transcription.select().where(Transcription.element == element_id)
def list_transcription_entities(transcription_id: UUID):
return (
TranscriptionEntity.select()
.where(TranscriptionEntity.transcription == transcription_id)
.join(Entity, on=TranscriptionEntity.entity)
.join(EntityType, on=Entity.type)
)
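A minimal usage sketch of these query helpers (not part of this commit); the export path, dataset ID and set name are placeholders, and the export database must have been downloaded beforehand:

# Illustrative sketch only, not part of the diff.
from arkindex_export import open_database
from worker_generic_training_dataset.db import list_dataset_elements, list_transcriptions

open_database("export.sqlite")  # placeholder: path to an Arkindex SQLite export
for element in list_dataset_elements("aaaaaaaa-0000-0000-0000-000000000000", "train"):
    for transcription in list_transcriptions(element.id).iterator():
        print(element.id, transcription.text)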
# -*- coding: utf-8 -*-
from uuid import UUID
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin
from arkindex_worker.worker.metadata import MetaDataMixin
from arkindex_worker.worker.transcription import TranscriptionMixin
from worker_generic_training_dataset import Extractor
class DatasetExtractorFromAPI(
Extractor,
ElementMixin,
ClassificationMixin,
EntityMixin,
TranscriptionMixin,
MetaDataMixin,
):
def list_classifications(self, element_id: UUID):
return iter(
self.api_client.request("RetrieveElement", id=str(element_id))[
"classifications"
]
)
def main():
DatasetExtractorFromAPI(
description="Fill base-worker cache with information about dataset and extract images",
).run()
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
import logging
from operator import itemgetter
from tempfile import _TemporaryFileWrapper
from typing import List
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import (
Classification,
DatasetElement,
Element,
Entity,
EntityType,
Image,
Transcription,
TranscriptionEntity,
WorkerRun,
WorkerVersion,
open_database,
)
from arkindex_export.queries import list_children
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Transcription as ArkindexTranscription
from peewee import CharField
from worker_generic_training_dataset import Extractor
logger: logging.Logger = logging.getLogger(__name__)
BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
return instance.id if instance else None
class DatasetExtractorFromSQL(Extractor):
def configure(self):
super().configure()
self.download_latest_export()
def download_latest_export(self) -> None:
"""
Download the latest export of the current corpus.
Export must be in `"done"` state.
"""
try:
exports = list(
self.api_client.paginate(
"ListExports",
id=self.corpus_id,
)
)
except ErrorResponse as e:
logger.error(
f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
# Find the latest that is in "done" state
exports: List[dict] = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=itemgetter("updated"),
reverse=True,
)
assert (
len(exports) > 0
), f"No available exports found for the corpus {self.corpus_id}."
# Download latest export
try:
export_id: str = exports[0]["id"]
logger.info(f"Downloading export ({export_id})...")
self.export: _TemporaryFileWrapper = self.api_client.request(
"DownloadExport",
id=export_id,
)
logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
open_database(self.export.name)
except ErrorResponse as e:
logger.error(
f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
def list_set_elements(self, dataset_id: UUID, set_name: str):
return (
Element.select()
.join(Image)
.switch(Element)
.join(DatasetElement, on=DatasetElement.element)
.where(
DatasetElement.dataset == dataset_id,
DatasetElement.set_name == set_name,
)
)
def list_classifications(self, element_id: UUID):
return (
Classification.select()
.where(Classification.element == element_id)
.iterator()
)
def list_transcriptions(self, element: ArkindexElement, **kwargs):
return (
Transcription.select().where(Transcription.element == element.id).iterator()
)
def list_transcription_entities(
self, transcription: ArkindexTranscription, **kwargs
):
return (
TranscriptionEntity.select()
.where(TranscriptionEntity.transcription == transcription.id)
.join(Entity, on=TranscriptionEntity.entity)
.join(EntityType, on=Entity.type)
).iterator()
def list_element_children(self, element: ArkindexElement, **kwargs):
return list_children(element.id).iterator()
def main():
DatasetExtractorFromSQL(
description="Fill base-worker cache with information about dataset and extract images",
).run()
if __name__ == "__main__":
main()
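As an aside (not part of this commit), the export-selection rule in download_latest_export boils down to: keep the exports whose state is "done" and pick the most recently updated one. The same logic on made-up payloads:

# Illustrative sketch only; these dictionaries are invented sample data.
from operator import itemgetter

exports = [
    {"id": "a", "state": "running", "updated": "2023-05-01T10:00:00.000000Z"},
    {"id": "b", "state": "done", "updated": "2023-04-30T08:00:00.000000Z"},
    {"id": "c", "state": "done", "updated": "2023-05-02T09:00:00.000000Z"},
]
done = sorted(
    (exp for exp in exports if exp["state"] == "done"),
    key=itemgetter("updated"),
    reverse=True,
)
assert done[0]["id"] == "c"  # the latest finished export wins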
# -*- coding: utf-8 -*-
import contextlib
import json
import logging
import sys
import tempfile
import uuid
from itertools import groupby
from operator import attrgetter, itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import Element, WorkerRun, WorkerVersion, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
CachedClassification,
CachedDataset,
CachedDatasetElement,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
create_tables,
create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset, Set
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker import DatasetWorker
from arkindex_worker.worker.dataset import DatasetState
from peewee import CharField
from worker_generic_training_dataset.db import (
list_classifications,
list_dataset_elements,
list_transcription_entities,
list_transcriptions,
)
from worker_generic_training_dataset.utils import build_image_url
logger: logging.Logger = logging.getLogger(__name__)
BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
return instance.id if instance else None
class DatasetExtractor(DatasetWorker):
def configure_storage(self) -> None:
self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
self.data_folder_path = Path(self.data_folder.name)
# Initialize db that will be written
self.configure_cache()
# CachedImage downloaded and created in DB
self.cached_images = dict()
# Where to save the downloaded images
self.images_folder = self.data_folder_path / "images"
self.images_folder.mkdir(parents=True)
logger.info(f"Images will be saved at `{self.images_folder}`.")
def configure_cache(self) -> None:
"""
Create an SQLite database compatible with base-worker cache and initialize it.
"""
self.use_cache = True
self.cache_path: Path = self.data_folder_path / "db.sqlite"
logger.info(f"Cached database will be saved at `{self.cache_path}`.")
init_cache_db(self.cache_path)
create_version_table()
create_tables()
def download_latest_export(self) -> None:
"""
Download the latest export of the current corpus.
Export must be in `"done"` state.
"""
try:
exports = list(
self.api_client.paginate(
"ListExports",
id=self.corpus_id,
)
)
except ErrorResponse as e:
logger.error(
f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
# Find the latest that is in "done" state
exports: List[dict] = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=itemgetter("updated"),
reverse=True,
)
assert (
len(exports) > 0
), f"No available exports found for the corpus {self.corpus_id}."
# Download latest export
try:
export_id: str = exports[0]["id"]
logger.info(f"Downloading export ({export_id})...")
self.export: _TemporaryFileWrapper = self.api_client.request(
"DownloadExport",
id=export_id,
)
logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
open_database(self.export.name)
except ErrorResponse as e:
logger.error(
f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
)
raise e
def insert_classifications(self, element: CachedElement) -> None:
logger.info("Listing classifications")
classifications: list[CachedClassification] = [
CachedClassification(
id=classification.id,
element=element,
class_name=classification.class_name,
confidence=classification.confidence,
state=classification.state,
worker_run_id=get_object_id(classification.worker_run),
)
for classification in list_classifications(element.id).iterator()
]
if classifications:
logger.info(f"Inserting {len(classifications)} classification(s)")
with cache_database.atomic():
CachedClassification.bulk_create(
model_list=classifications,
batch_size=BULK_BATCH_SIZE,
)
def insert_transcriptions(
self, element: CachedElement
) -> List[CachedTranscription]:
logger.info("Listing transcriptions")
transcriptions: list[CachedTranscription] = [
CachedTranscription(
id=transcription.id,
element=element,
text=transcription.text,
confidence=transcription.confidence,
orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
worker_version_id=get_object_id(transcription.worker_version),
worker_run_id=get_object_id(transcription.worker_run),
)
for transcription in list_transcriptions(element.id).iterator()
]
if transcriptions:
logger.info(f"Inserting {len(transcriptions)} transcription(s)")
with cache_database.atomic():
CachedTranscription.bulk_create(
model_list=transcriptions,
batch_size=BULK_BATCH_SIZE,
)
return transcriptions
def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
logger.info("Listing entities")
entities: List[CachedEntity] = []
transcription_entities: List[CachedTranscriptionEntity] = []
for transcription in transcriptions:
for transcription_entity in list_transcription_entities(
transcription.id
).iterator():
entity = CachedEntity(
id=transcription_entity.entity.id,
type=transcription_entity.entity.type.name,
name=transcription_entity.entity.name,
validated=transcription_entity.entity.validated,
metas=transcription_entity.entity.metas,
worker_run_id=get_object_id(transcription_entity.entity.worker_run),
)
entities.append(entity)
transcription_entities.append(
CachedTranscriptionEntity(
id=transcription_entity.id,
transcription=transcription,
entity=entity,
offset=transcription_entity.offset,
length=transcription_entity.length,
confidence=transcription_entity.confidence,
worker_run_id=get_object_id(transcription_entity.worker_run),
)
)
if entities:
# First insert entities since they are foreign keys on transcription entities
logger.info(f"Inserting {len(entities)} entities")
with cache_database.atomic():
CachedEntity.bulk_create(
model_list=entities,
batch_size=BULK_BATCH_SIZE,
)
if transcription_entities:
# Insert transcription entities
logger.info(
f"Inserting {len(transcription_entities)} transcription entities"
)
with cache_database.atomic():
CachedTranscriptionEntity.bulk_create(
model_list=transcription_entities,
batch_size=BULK_BATCH_SIZE,
)
def insert_element(
self,
element: Element,
split_name: Optional[str] = None,
parent_id: Optional[UUID] = None,
) -> None:
"""
Insert the given element in the cache database.
Its image will also be saved to disk, if it wasn't already.
The insertion of an element includes:
- its classifications
- its transcriptions
- its transcriptions' entities (both Entity and TranscriptionEntity)
The element will also be linked to the appropriate split in the current dataset.
:param element: Element to insert.
:param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
"""
logger.info(f"Processing element ({element.id})")
if element.image and element.image.id not in self.cached_images:
# Download image
logger.info("Downloading image")
download_image(url=build_image_url(element)).save(
self.images_folder / f"{element.image.id}.jpg"
)
# Insert image
logger.info("Inserting image")
# Store images in case some other elements use it as well
with cache_database.atomic():
self.cached_images[element.image.id] = CachedImage.create(
id=element.image.id,
width=element.image.width,
height=element.image.height,
url=element.image.url,
)
# Insert element
logger.info("Inserting element")
with cache_database.atomic():
cached_element: CachedElement = CachedElement.create(
id=element.id,
parent_id=parent_id,
type=element.type,
image=self.cached_images[element.image.id] if element.image else None,
polygon=element.polygon,
rotation_angle=element.rotation_angle,
mirrored=element.mirrored,
worker_version_id=get_object_id(element.worker_version),
worker_run_id=get_object_id(element.worker_run),
confidence=element.confidence,
)
# Insert classifications
self.insert_classifications(cached_element)
# Insert transcriptions
transcriptions: List[CachedTranscription] = self.insert_transcriptions(
cached_element
)
# Insert entities
self.insert_entities(transcriptions)
# Link the element to the dataset split
if split_name:
logger.info(
f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
)
with cache_database.atomic():
cached_element: CachedDatasetElement = CachedDatasetElement.create(
id=uuid.uuid4(),
element=cached_element,
dataset=self.cached_dataset,
set_name=split_name,
)
def process_split(self, split_name: str, elements: List[Element]) -> None:
logger.info(
f"Filling the cache with information from elements in the split {split_name}"
)
# First list all pages
nb_elements: int = len(elements)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
# Insert page
self.insert_element(element, split_name=split_name)
# List children
children = list_children(element.id)
nb_children: int = children.count()
for child_idx, child in enumerate(children.iterator(), start=1):
logger.info(f"Processing child ({child_idx}/{nb_children})")
# Insert child
self.insert_element(child, parent_id=element.id)
def insert_dataset(self, dataset: Dataset) -> None:
"""
Insert the given dataset in the cache database.
:param dataset: Dataset to insert.
"""
logger.info(f"Inserting dataset ({dataset.id})")
with cache_database.atomic():
self.cached_dataset = CachedDataset.create(
id=dataset.id,
name=dataset.name,
state=dataset.state,
sets=json.dumps(dataset.sets),
)
def process_dataset(self, dataset: Dataset, sets: list[Set]):
# Configure temporary storage for the dataset data (cache + images)
self.configure_storage()
# Insert dataset in cache database
self.insert_dataset(dataset)
# Iterate over given splits
for dataset_set in sets:
elements = list_dataset_elements(dataset.id, dataset_set.name)
self.process_split(dataset_set.name, elements)
# TAR + ZST the cache and the images folder, and store as task artifact
zst_archive_path: Path = self.work_dir / dataset.filepath
logger.info(f"Compressing the images to {zst_archive_path}")
create_tar_zst_archive(
source=self.data_folder_path, destination=zst_archive_path
)
self.data_folder.cleanup()
def run(self):
self.configure()
# Download corpus
self.download_latest_export()
dataset_sets: list[Set] = list(self.list_sets())
grouped_sets: list[tuple[Dataset, list[Set]]] = [
(dataset, list(sets))
for dataset, sets in groupby(dataset_sets, attrgetter("dataset"))
]
if not grouped_sets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(grouped_sets)
failed = 0
for i, (dataset, sets) in enumerate(grouped_sets, start=1):
try:
assert dataset.state in [
DatasetState.Open.value,
DatasetState.Error.value,
], "When generating a new dataset, its state should be Open or Error."
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
logger.info(f"Processing {dataset} ({i}/{count})")
self.process_dataset(dataset, sets)
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while processing or patching the state for this dataset
failed += 1
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing {dataset}: {e.title} - {e.content}"
else:
message = f"Failed running worker on {dataset}: {repr(e)}"
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
# Try to update the state to Error regardless of the response
with contextlib.suppress(Exception):
self.update_dataset_state(dataset, DatasetState.Error)
message = f'Ran on {count} dataset{"s"[:count > 1]}: {count - failed} completed, {failed} failed'
if failed:
logger.error(message)
if failed >= count: # Everything failed!
sys.exit(1)
else:
logger.info(message)
def main():
DatasetExtractor(
description="Fill base-worker cache with information about dataset and extract images",
).run()
if __name__ == "__main__":
main()
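One last readability note (not part of this commit): the summary message uses `"s"[:count > 1]` for pluralization, slicing a one-character string with a boolean:

# Illustrative sketch only, not part of the diff.
assert "s"[:1 > 1] == ""   # count == 1: False acts as slice bound 0
assert "s"[:2 > 1] == "s"  # count > 1: True acts as slice bound 1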