Verified commit 2a841ce2 authored by Yoann Schneider

correctly link pages to split folder

parent 5c88bf98
1 merge request: !2 Implement worker
Pipeline #81797 failed
 # -*- coding: utf-8 -*-
 from typing import NamedTuple

-from arkindex_export import Classification, ElementPath, Image
+from arkindex_export import Classification, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -9,6 +9,7 @@ from arkindex_export.models import (
     Transcription,
     TranscriptionEntity,
 )
+from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedElement,
     CachedEntity,
@@ -92,45 +93,8 @@ def retrieve_entities(transcription: CachedTranscription):
     return zip(*data)


-def list_children(parent_id):
-    # First, build the base query to get direct children
-    base = (
-        ElementPath.select(
-            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
-        )
-        .where(ElementPath.parent_id == parent_id)
-        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
-    )
-    # Then build the second recursive query, using an alias to join both queries on the same table
-    EP = ElementPath.alias()
-    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
-        base, on=(EP.parent_id == base.c.child_id)
-    )
-    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
-    # that might be found multiple times with complex element structures
-    cte = base.union(recursion)
-    # And load all the elements found in the CTE
-    query = (
-        Element.select(
-            Element.id,
-            Element.type,
-            Image.id.alias("image_id"),
-            Image.width.alias("image_width"),
-            Image.height.alias("image_height"),
-            Image.url.alias("image_url"),
-            Element.polygon,
-            Element.rotation_angle,
-            Element.mirrored,
-            Element.worker_version,
-            Element.confidence,
-            cte.c.parent_id,
-        )
-        .with_cte(cte)
-        .join(cte, on=(Element.id == cte.c.child_id))
-        .join(Image, on=(Element.image_id == Image.id))
-        .order_by(cte.c.ordering.asc())
-    )
+def get_children(parent_id, element_type=None):
+    query = list_children(parent_id).join(Image)
+    if element_type:
+        query = query.where(Element.type == element_type)
     return query
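For context, the new get_children helper wraps the list_children query shipped with arkindex_export, joins the Image table and optionally filters on the element type. A minimal usage sketch (the database path and folder UUID below are placeholders, not values from the repository):

from arkindex_export import open_database
from worker_generic_training_dataset.db import get_children

# Placeholder path and UUID, for illustration only
open_database("corpus_export.sqlite")
split_folder_id = "11111111-2222-3333-4444-555555555555"

# Direct and indirect children of the folder, filtered down to pages
pages = get_children(parent_id=split_folder_id, element_type="page")
print(f"{pages.count()} pages under {split_folder_id}")

for page in pages:
    # Each row is an Element with its Image joined in
    print(page.id, page.type, page.image.url)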
@@ -30,7 +30,7 @@ def bounding_box(polygon: list):
 def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
@@ -46,7 +46,7 @@ def download_image(element, folder: Path):
        try:
            image = iio.imread(build_image_url(element))
            cv2.imwrite(
-                str(folder / f"{element.image.id}.png"),
+                str(folder / f"{element.image.id}.jpg"),
                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
            )
            break
...
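The URL built in build_image_url follows the IIIF Image API region syntax. A standalone sketch of the same construction, where the image URL and polygon are made up and the inline bounding-box computation is an assumed equivalent of the repository's bounding_box helper:

from urllib.parse import urljoin

# Hypothetical values, for illustration only
image_url = "https://iiif.example.com/iiif/2/some-image"
polygon = [[0, 0], [0, 400], [300, 400], [300, 0], [0, 0]]

# Assumed equivalent of bounding_box(): top-left corner plus width and height
xs, ys = zip(*polygon)
x, y = min(xs), min(ys)
width, height = max(xs) - x, max(ys) - y

# The trailing "/" matters: without it, urljoin would replace the last path segment
url = urljoin(image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg")
print(url)  # https://iiif.example.com/iiif/2/some-image/0,0,300,400/full/0/default.jpg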
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import shutil
 import tempfile
 from pathlib import Path
+from typing import Optional
+from uuid import UUID

 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
@@ -18,9 +21,9 @@ from arkindex_worker.cache import (
 )
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
-from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.worker.base import BaseWorker
 from worker_generic_training_dataset.db import (
-    list_children,
+    get_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
@@ -36,15 +39,27 @@ logger = logging.getLogger(__name__)
 BULK_BATCH_SIZE = 50


-class DatasetExtractor(ElementsWorker):
+class DatasetExtractor(BaseWorker):
     def configure(self):
-        super().configure()
+        self.args = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
+            self.process_information = {
+                "train_folder_id": "47a0e07b-d07a-4969-aced-44450d132f0d",
+                "validation_folder_id": "8cbc4b53-9e07-4a72-b4e6-93f7f5b0cbed",
+                "test_folder_id": "659a37ea-3b26-42f0-8b65-78964f9e433e",
+            }
+        else:
+            super().configure()

         # database arg is mandatory in dev mode
         assert (
             not self.is_read_only or self.args.database is not None
         ), "`--database` arg is mandatory in developer mode."

+        # Read process information
+        self.read_training_related_information()
+
         # Download corpus
         self.download_latest_export()
@@ -54,6 +69,27 @@ class DatasetExtractor(ElementsWorker):
         # Cached Images downloaded and created in DB
         self.cached_images = dict()

+    def read_training_related_information(self):
+        """
+        Read from process information
+            - train_folder_id
+            - validation_folder_id
+            - test_folder_id (optional)
+            - model_id
+        """
+        logger.info("Retrieving information from process_information")
+
+        train_folder_id = self.process_information.get("train_folder_id")
+        assert train_folder_id, "A training folder id is necessary to use this worker"
+        self.training_folder_id = UUID(train_folder_id)
+
+        val_folder_id = self.process_information.get("validation_folder_id")
+        assert val_folder_id, "A validation folder id is necessary to use this worker"
+        self.validation_folder_id = UUID(val_folder_id)
+
+        test_folder_id = self.process_information.get("test_folder_id")
+        self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
+
     def initialize_database(self):
         # Create db at
         # - self.workdir / "db.sqlite" in Arkindex mode
@@ -109,9 +145,11 @@ class DatasetExtractor(ElementsWorker):
             )
             raise e

-    def insert_element(self, element, image_folder: Path, root: bool = False):
+    def insert_element(
+        self, element, image_folder: Path, parent_id: Optional[str] = None
+    ):
         logger.info(f"Processing element ({element.id})")
-        if element.image_id and element.image_id not in self.cached_images:
+        if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
             download_image(element, folder=image_folder)
@@ -119,19 +157,19 @@ class DatasetExtractor(ElementsWorker):
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
             self.cached_images[element.image_id] = CachedImage.create(
-                id=element.image_id,
-                width=element.image_width,
-                height=element.image_height,
-                url=element.image_url,
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
             )

         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=None if root else element.parent_id,
+            parent_id=parent_id,
             type=element.type,
-            image=self.cached_images[element.image_id] if element.image_id else None,
+            image=self.cached_images[element.image.id] if element.image else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
@@ -196,29 +234,56 @@ class DatasetExtractor(ElementsWorker):
             batch_size=BULK_BATCH_SIZE,
         )

-    def process_element(self, element):
-        # Where to save the downloaded images
-        image_folder = Path(tempfile.mkdtemp())
+    def process_split(self, split_name, split_id, image_folder):
         logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+            f"Filling the Base-Worker cache with information from children under element ({split_id})"
         )
         # Fill cache
         # Retrieve parent and create parent
-        parent = retrieve_element(element.id)
-        self.insert_element(parent, image_folder, root=True)
-        # Create children
-        children = list_children(parent_id=element.id)
-        nb_children = children.count()
-        for idx, child in enumerate(children.namedtuples(), start=1):
-            logger.info(f"Processing child ({idx}/{nb_children})")
-            self.insert_element(child, image_folder)
+        parent = retrieve_element(split_id)
+        self.insert_element(parent, image_folder)
+
+        # First list all pages
+        pages = get_children(parent_id=split_id, element_type="page")
+        nb_pages = pages.count()
+        for idx, page in enumerate(pages, start=1):
+            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
+            # Insert page
+            self.insert_element(page, image_folder, parent_id=split_id)
+
+            # List children
+            children = get_children(parent_id=page.id)
+            nb_children = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, image_folder, parent_id=page.id)
+
+    def run(self):
+        self.configure()
+
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        # Iterate over given split
+        for split_name, split_id in [
+            ("Train", self.training_folder_id),
+            ("Validation", self.validation_folder_id),
+            ("Test", self.testing_folder_id),
+        ]:
+            if not split_id:
+                continue
+            self.process_split(split_name, split_id, image_folder)

         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
         create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
+
+        # Cleanup image folder
+        shutil.rmtree(image_folder)

 def main():
     DatasetExtractor(
...
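create_tar_zstd_archive is defined elsewhere in the repository and is not part of this diff. For readers who want the gist of the final step (tar the image folder, compress it with zstd, store it as a task artifact), here is a minimal sketch of the technique, under the assumption that the zstandard package is available; it is not the repository's implementation:

import tarfile
from pathlib import Path

import zstandard


def tar_zstd(folder_path: Path, destination: Path) -> None:
    # Stream the tar archive straight into the zstd writer, so the
    # uncompressed tar never has to be held in memory or on disk
    compressor = zstandard.ZstdCompressor()
    with destination.open("wb") as archive, compressor.stream_writer(archive) as writer:
        with tarfile.open(fileobj=writer, mode="w|") as tar:
            tar.add(folder_path, arcname=".")


# Example call with placeholder paths:
# tar_zstd(Path("/tmp/images"), Path("arkindex_data.zstd"))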