Verified commit 2a841ce2, authored by Yoann Schneider

correctly link pages to split folder

parent 5c88bf98
Merge request !2: Implement worker
# -*- coding: utf-8 -*-
from typing import NamedTuple

-from arkindex_export import Classification, ElementPath, Image
+from arkindex_export import Classification, Image
from arkindex_export.models import (
    Element,
    Entity,
@@ -9,6 +9,7 @@ from arkindex_export.models import (
    Transcription,
    TranscriptionEntity,
)
+from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedElement,
    CachedEntity,
@@ -92,45 +93,8 @@ def retrieve_entities(transcription: CachedTranscription):
    return zip(*data)
-def list_children(parent_id):
-    # First, build the base query to get direct children
-    base = (
-        ElementPath.select(
-            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
-        )
-        .where(ElementPath.parent_id == parent_id)
-        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
-    )
-    # Then build the second recursive query, using an alias to join both queries on the same table
-    EP = ElementPath.alias()
-    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
-        base, on=(EP.parent_id == base.c.child_id)
-    )
-    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
-    # that might be found multiple times with complex element structures
-    cte = base.union(recursion)
-    # And load all the elements found in the CTE
-    query = (
-        Element.select(
-            Element.id,
-            Element.type,
-            Image.id.alias("image_id"),
-            Image.width.alias("image_width"),
-            Image.height.alias("image_height"),
-            Image.url.alias("image_url"),
-            Element.polygon,
-            Element.rotation_angle,
-            Element.mirrored,
-            Element.worker_version,
-            Element.confidence,
-            cte.c.parent_id,
-        )
-        .with_cte(cte)
-        .join(cte, on=(Element.id == cte.c.child_id))
-        .join(Image, on=(Element.image_id == Image.id))
-        .order_by(cte.c.ordering.asc())
-    )
-    return query
+def get_children(parent_id, element_type=None):
+    query = list_children(parent_id).join(Image)
+    if element_type:
+        query = query.where(Element.type == element_type)
+    return query
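With this change the worker stops maintaining its own recursive CTE and relies on `list_children` from `arkindex_export.queries`; `get_children` is now just a thin wrapper that joins the `Image` table and optionally filters by element type. A minimal usage sketch (the parent UUID is a placeholder; `.count()` and iteration behave as in the worker code further below):

# Hedged usage sketch of the new helper; the UUID is a placeholder.
pages = get_children(
    parent_id="00000000-0000-0000-0000-000000000000", element_type="page"
)
print(pages.count())  # peewee ModelSelect supports COUNT aggregation
for page in pages:    # iterating executes the underlying SELECT lazily
    print(page.id, page.type)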
@@ -30,7 +30,7 @@ def bounding_box(polygon: list):
def build_image_url(element):
    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
    return urljoin(
-        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
    )
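The only change here is reading the URL through the joined `Image` model (`element.image.url`) instead of the `image_url` alias that the old flattened query provided. The string being built is a IIIF Image API region request; a rough illustration with made-up values:

from urllib.parse import urljoin

# Placeholder values standing in for element.image.url and bounding_box(...)
image_url = "https://iiif.example.com/iiif/2/some-image"
x, y, width, height = 10, 20, 300, 400
print(urljoin(image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"))
# https://iiif.example.com/iiif/2/some-image/10,20,300,400/full/0/default.jpg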
@@ -46,7 +46,7 @@ def download_image(element, folder: Path):
    try:
        image = iio.imread(build_image_url(element))
        cv2.imwrite(
-            str(folder / f"{element.image_id}.png"),
+            str(folder / f"{element.image.id}.jpg"),
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        )
        break
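Note that `cv2.imwrite` picks the encoder from the file extension, so renaming the cached file from `.png` to `.jpg` also switches the output from lossless PNG to lossy JPEG. A tiny illustration with a synthetic array:

import cv2
import numpy as np

img = np.zeros((4, 4, 3), dtype=np.uint8)  # synthetic black image
cv2.imwrite("sample.png", img)  # encoded as PNG, inferred from the extension
cv2.imwrite("sample.jpg", img)  # same pixels, encoded as JPEG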
...
# -*- coding: utf-8 -*-
import logging
import operator
+import shutil
+import tempfile
from pathlib import Path
+from typing import Optional
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import open_database
@@ -18,9 +21,9 @@ from arkindex_worker.cache import (
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
-from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.worker.base import BaseWorker
from worker_generic_training_dataset.db import (
-    list_children,
+    get_children,
    list_classifications,
    list_transcriptions,
    retrieve_element,
@@ -36,15 +39,27 @@ logger = logging.getLogger(__name__)
BULK_BATCH_SIZE = 50

-class DatasetExtractor(ElementsWorker):
+class DatasetExtractor(BaseWorker):
    def configure(self):
-        super().configure()
        self.args = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
+            self.process_information = {
+                "train_folder_id": "47a0e07b-d07a-4969-aced-44450d132f0d",
+                "validation_folder_id": "8cbc4b53-9e07-4a72-b4e6-93f7f5b0cbed",
+                "test_folder_id": "659a37ea-3b26-42f0-8b65-78964f9e433e",
+            }
+        else:
+            super().configure()

        # database arg is mandatory in dev mode
        assert (
            not self.is_read_only or self.args.database is not None
        ), "`--database` arg is mandatory in developer mode."

+        # Read process information
+        self.read_training_related_information()

        # Download corpus
        self.download_latest_export()
@@ -54,6 +69,27 @@ class DatasetExtractor(ElementsWorker):
        # Cached Images downloaded and created in DB
        self.cached_images = dict()
+    def read_training_related_information(self):
+        """
+        Read from process information
+            - train_folder_id
+            - validation_folder_id
+            - test_folder_id (optional)
+            - model_id
+        """
+        logger.info("Retrieving information from process_information")

+        train_folder_id = self.process_information.get("train_folder_id")
+        assert train_folder_id, "A training folder id is necessary to use this worker"
+        self.training_folder_id = UUID(train_folder_id)

+        val_folder_id = self.process_information.get("validation_folder_id")
+        assert val_folder_id, "A validation folder id is necessary to use this worker"
+        self.validation_folder_id = UUID(val_folder_id)

+        test_folder_id = self.process_information.get("test_folder_id")
+        self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
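The method only consumes the three folder IDs: in Arkindex mode they arrive with the process, in developer mode they come from the hard-coded dictionary in `configure()`. A sketch of the expected payload and the attributes it produces (the UUIDs are placeholders):

# Hypothetical process_information payload and resulting attributes.
process_information = {
    "train_folder_id": "11111111-1111-1111-1111-111111111111",       # required
    "validation_folder_id": "22222222-2222-2222-2222-222222222222",  # required
    # "test_folder_id" may be absent or None: the Test split is then skipped
}
# After read_training_related_information():
#   self.training_folder_id   == UUID("11111111-1111-1111-1111-111111111111")
#   self.validation_folder_id == UUID("22222222-2222-2222-2222-222222222222")
#   self.testing_folder_id    is None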
    def initialize_database(self):
        # Create db at
        # - self.workdir / "db.sqlite" in Arkindex mode
@@ -109,9 +145,11 @@ class DatasetExtractor(ElementsWorker):
            )
            raise e
-    def insert_element(self, element, image_folder: Path, root: bool = False):
+    def insert_element(
+        self, element, image_folder: Path, parent_id: Optional[str] = None
+    ):
        logger.info(f"Processing element ({element.id})")
-        if element.image_id and element.image_id not in self.cached_images:
+        if element.image and element.image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(element, folder=image_folder)
@@ -119,19 +157,19 @@ class DatasetExtractor(ElementsWorker):
            logger.info("Inserting image")
            # Store images in case some other elements use it as well
            self.cached_images[element.image_id] = CachedImage.create(
-                id=element.image_id,
-                width=element.image_width,
-                height=element.image_height,
-                url=element.image_url,
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
            )

        # Insert element
        logger.info("Inserting element")
        cached_element = CachedElement.create(
            id=element.id,
-            parent_id=None if root else element.parent_id,
+            parent_id=parent_id,
            type=element.type,
-            image=self.cached_images[element.image_id] if element.image_id else None,
+            image=self.cached_images[element.image.id] if element.image else None,
            polygon=element.polygon,
            rotation_angle=element.rotation_angle,
            mirrored=element.mirrored,
@@ -196,29 +234,56 @@ class DatasetExtractor(ElementsWorker):
            batch_size=BULK_BATCH_SIZE,
        )
-    def process_element(self, element):
-        # Where to save the downloaded images
-        image_folder = Path(tempfile.mkdtemp())
+    def process_split(self, split_name, split_id, image_folder):
        logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+            f"Filling the Base-Worker cache with information from children under element ({split_id})"
        )
-        # Fill cache
        # Retrieve parent and create parent
-        parent = retrieve_element(element.id)
-        self.insert_element(parent, image_folder, root=True)
-        # Create children
-        children = list_children(parent_id=element.id)
-        nb_children = children.count()
-        for idx, child in enumerate(children.namedtuples(), start=1):
-            logger.info(f"Processing child ({idx}/{nb_children})")
-            self.insert_element(child, image_folder)
+        parent = retrieve_element(split_id)
+        self.insert_element(parent, image_folder)

+        # First list all pages
+        pages = get_children(parent_id=split_id, element_type="page")
+        nb_pages = pages.count()
+        for idx, page in enumerate(pages, start=1):
+            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")

+            # Insert page
+            self.insert_element(page, image_folder, parent_id=split_id)

+            # List children
+            children = get_children(parent_id=page.id)
+            nb_children = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, image_folder, parent_id=page.id)
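This loop is the fix named in the commit message: pages are inserted with `parent_id=split_id` and their children with `parent_id=page.id`, so the cache now records the split / page / child hierarchy instead of detaching pages from their split folder. A toy sketch of the resulting `CachedElement` rows (IDs shortened, not real values):

# Hypothetical cache content after process_split("Train", "split-1", folder):
# CachedElement(id="split-1", parent_id=None)       # the split folder itself
# CachedElement(id="page-1",  parent_id="split-1")  # pages link to the split
# CachedElement(id="line-1",  parent_id="page-1")   # children link to their page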
+    def run(self):
+        self.configure()

+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())

+        # Iterate over given split
+        for split_name, split_id in [
+            ("Train", self.training_folder_id),
+            ("Validation", self.validation_folder_id),
+            ("Test", self.testing_folder_id),
+        ]:
+            if not split_id:
+                continue
+            self.process_split(split_name, split_id, image_folder)

+        # TAR + ZSTD Image folder and store as task artifact
+        zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)

+        # Cleanup image folder
+        shutil.rmtree(image_folder)
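`create_tar_zstd_archive` is the project's own helper and its body is not shown in this diff. A minimal sketch of what such a helper could look like, assuming the standard-library `tarfile` module plus the `zstandard` package (only the name and keyword arguments are taken from the call above; the implementation is a guess):

import tarfile
from pathlib import Path

import zstandard


def create_tar_zstd_archive(folder_path: Path, destination: Path) -> None:
    # Stream a tar of the folder through a Zstandard compressor to disk
    compressor = zstandard.ZstdCompressor()
    with destination.open("wb") as archive_file:
        with compressor.stream_writer(archive_file) as compressed:
            with tarfile.open(mode="w|", fileobj=compressed) as tar:
                tar.add(folder_path, arcname=".")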
def main():
    DatasetExtractor(
...