Commit d4655a95 authored by Eva Bardou, committed by Yoann Schneider

Rework the worker due to `Dataset` API changes

parent 64f3411b
1 merge request: !23 Rework the worker due to `Dataset` API changes
@@ -39,3 +39,7 @@ repos:
   - repo: meta
     hooks:
       - id: check-useless-excludes
+  - repo: https://github.com/shellcheck-py/shellcheck-py
+    rev: v0.10.0.1
+    hooks:
+      - id: shellcheck

 #!/bin/sh -e
 # Build the tasks Docker image.
 # Requires CI_PROJECT_DIR and CI_REGISTRY_IMAGE to be set.
+# VERSION defaults to latest.
 # Will automatically login to a registry if CI_REGISTRY, CI_REGISTRY_USER and CI_REGISTRY_PASSWORD are set.
 # Will only push an image if $CI_REGISTRY is set.
-if [ -z "$VERSION" -o -z "$CI_PROJECT_DIR" -o -z "$CI_REGISTRY_IMAGE" ]; then
+if [ -z "$VERSION" ] || [ -z "$CI_PROJECT_DIR" ] || [ -z "$CI_REGISTRY_IMAGE" ]; then
     echo Missing environment variables
     exit 1
 fi
 IMAGE_TAG="$CI_REGISTRY_IMAGE:$VERSION"
-cd $CI_PROJECT_DIR
+cd "$CI_PROJECT_DIR"
 docker build -f Dockerfile . -t "$IMAGE_TAG"
-if [ -n "$CI_REGISTRY" -a -n "$CI_REGISTRY_USER" -a -n "$CI_REGISTRY_PASSWORD" ]; then
-    echo $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY
-    docker push $IMAGE_TAG
+if [ -n "$CI_REGISTRY" ] && [ -n "$CI_REGISTRY_USER" ] && [ -n "$CI_REGISTRY_PASSWORD" ]; then
+    echo "$CI_REGISTRY_PASSWORD" | docker login -u "$CI_REGISTRY_USER" --password-stdin "$CI_REGISTRY"
+    docker push "$IMAGE_TAG"
 else
     echo "Missing environment variables to log in to the container registry…"
 fi

@@ -4,6 +4,7 @@ import json
 from argparse import Namespace
 from uuid import UUID, uuid4

+from arkindex_export.models import Element
 from arkindex_worker.cache import (
     CachedClassification,
     CachedDataset,
@@ -14,7 +15,6 @@ from arkindex_worker.cache import (
     CachedTranscription,
     CachedTranscriptionEntity,
 )
-from worker_generic_training_dataset.db import retrieve_element
 from worker_generic_training_dataset.worker import DatasetExtractor
@@ -44,8 +44,8 @@ def test_process_split(tmp_path, downloaded_images):
     worker.process_split(
         "train",
         [
-            retrieve_element(first_page_id),
-            retrieve_element(second_page_id),
+            Element.get_by_id(first_page_id),
+            Element.get_by_id(second_page_id),
         ],
     )
...

@@ -4,16 +4,26 @@ from uuid import UUID
 from arkindex_export import Classification
 from arkindex_export.models import (
+    DatasetElement,
     Element,
     Entity,
     EntityType,
+    Image,
     Transcription,
     TranscriptionEntity,
 )


-def retrieve_element(element_id: UUID) -> Element:
-    return Element.get_by_id(element_id)
+def list_dataset_elements(dataset_id: UUID, set_name: str):
+    return (
+        Element.select()
+        .join(Image)
+        .switch(Element)
+        .join(DatasetElement, on=DatasetElement.element)
+        .where(
+            DatasetElement.dataset == dataset_id, DatasetElement.set_name == set_name
+        )
+    )


 def list_classifications(element_id: UUID):
...
 # -*- coding: utf-8 -*-
+import contextlib
 import json
 import logging
+import sys
 import tempfile
 import uuid
-from argparse import Namespace
-from operator import itemgetter
+from itertools import groupby
+from operator import attrgetter, itemgetter
 from pathlib import Path
 from tempfile import _TemporaryFileWrapper
 from typing import List, Optional
@@ -28,16 +30,16 @@ from arkindex_worker.cache import (
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.image import download_image
-from arkindex_worker.models import Dataset
-from arkindex_worker.models import Element as WorkerElement
+from arkindex_worker.models import Dataset, Set
 from arkindex_worker.utils import create_tar_zst_archive
 from arkindex_worker.worker import DatasetWorker
+from arkindex_worker.worker.dataset import DatasetState
 from peewee import CharField
 from worker_generic_training_dataset.db import (
     list_classifications,
+    list_dataset_elements,
     list_transcription_entities,
     list_transcriptions,
-    retrieve_element,
 )
 from worker_generic_training_dataset.utils import build_image_url
@@ -47,29 +49,11 @@ BULK_BATCH_SIZE = 50
 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


-def _format_element(element: WorkerElement) -> Element:
-    return retrieve_element(element.id)
-
-
 def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
     return instance.id if instance else None


 class DatasetExtractor(DatasetWorker):
-    def configure(self) -> None:
-        self.args: Namespace = self.parser.parse_args()
-
-        if self.is_read_only:
-            super().configure_for_developers()
-        else:
-            super().configure()
-
-        if self.user_configuration:
-            logger.info("Overriding with user_configuration")
-            self.config.update(self.user_configuration)
-
-        # Download corpus
-        self.download_latest_export()
-
     def configure_storage(self) -> None:
         self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
         self.data_folder_path = Path(self.data_folder.name)
@@ -351,7 +335,7 @@ class DatasetExtractor(DatasetWorker):
             sets=json.dumps(dataset.sets),
         )

-    def process_dataset(self, dataset: Dataset):
+    def process_dataset(self, dataset: Dataset, sets: list[Set]):
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
@@ -359,9 +343,9 @@ class DatasetExtractor(DatasetWorker):
         self.insert_dataset(dataset)

         # Iterate over given splits
-        for split_name, elements in self.list_dataset_elements_per_split(dataset):
-            casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements)
+        for dataset_set in sets:
+            elements = list_dataset_elements(dataset.id, dataset_set.name)
+            self.process_split(dataset_set.name, elements)

         # TAR + ZST the cache and the images folder, and store as task artifact
         zst_archive_path: Path = self.work_dir / dataset.filepath
@@ -371,11 +355,70 @@
         )
         self.data_folder.cleanup()

+    def run(self):
+        self.configure()
+
+        # Download corpus
+        self.download_latest_export()
+
+        dataset_sets: list[Set] = list(self.list_sets())
+        grouped_sets: list[tuple[Dataset, list[Set]]] = [
+            (dataset, list(sets))
+            for dataset, sets in groupby(dataset_sets, attrgetter("dataset"))
+        ]
+
+        if not grouped_sets:
+            logger.warning("No datasets to process, stopping.")
+            sys.exit(1)
+
+        # Process every dataset
+        count = len(grouped_sets)
+        failed = 0
+        for i, (dataset, sets) in enumerate(grouped_sets, start=1):
+            try:
+                assert dataset.state in [
+                    DatasetState.Open.value,
+                    DatasetState.Error.value,
+                ], "When generating a new dataset, its state should be Open or Error."
+
+                # Update the dataset state to Building
+                logger.info(f"Building {dataset} ({i}/{count})")
+                self.update_dataset_state(dataset, DatasetState.Building)
+
+                logger.info(f"Processing {dataset} ({i}/{count})")
+                self.process_dataset(dataset, sets)
+
+                # Update the dataset state to Complete
+                logger.info(f"Completed {dataset} ({i}/{count})")
+                self.update_dataset_state(dataset, DatasetState.Complete)
+            except Exception as e:
+                # Handle errors occurring while processing or patching the state for this dataset
+                failed += 1
+
+                if isinstance(e, ErrorResponse):
+                    message = f"An API error occurred while processing {dataset}: {e.title} - {e.content}"
+                else:
+                    message = f"Failed running worker on {dataset}: {repr(e)}"
+
+                logger.warning(
+                    message,
+                    exc_info=e if self.args.verbose else None,
+                )
+
+                # Try to update the state to Error regardless of the response
+                with contextlib.suppress(Exception):
+                    self.update_dataset_state(dataset, DatasetState.Error)
+
+        message = f'Ran on {count} dataset{"s"[:count > 1]}: {count - failed} completed, {failed} failed'
+        if failed:
+            logger.error(message)
+            if failed >= count:  # Everything failed!
+                sys.exit(1)
+        else:
+            logger.info(message)
+

 def main():
     DatasetExtractor(
         description="Fill base-worker cache with information about dataset and extract images",
-        generator=True,
     ).run()
...
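
Side note on the new run() method: the Set objects returned by list_sets() are grouped per dataset with itertools.groupby, which only merges consecutive items sharing the same key. A small standalone sketch of that grouping pattern, using made-up stand-in tuples rather than the worker's Dataset and Set models:

    from itertools import groupby
    from operator import attrgetter
    from typing import NamedTuple


    # Stand-in type for illustration; the worker uses arkindex_worker.models.Set,
    # whose .dataset attribute points to its parent Dataset.
    class FakeSet(NamedTuple):
        dataset: str
        name: str


    sets = [
        FakeSet("dataset-A", "train"),
        FakeSet("dataset-A", "dev"),
        FakeSet("dataset-B", "train"),
    ]

    # groupby only groups *consecutive* items with the same key, so sets of the
    # same dataset are assumed to be listed contiguously by list_sets().
    grouped = [
        (dataset, list(items))
        for dataset, items in groupby(sets, attrgetter("dataset"))
    ]
    assert grouped == [
        ("dataset-A", [FakeSet("dataset-A", "train"), FakeSet("dataset-A", "dev")]),
        ("dataset-B", [FakeSet("dataset-B", "train")]),
    ]

If the sets ever came back interleaved across datasets, they would need to be sorted by dataset before grouping, since groupby starts a new group every time the key changes.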