Verified Commit 7fb7097c authored by Yoann Schneider

working implem

parent ac10b18e
1 merge request: !2 Implement worker
Pipeline #81794 failed
 # -*- coding: utf-8 -*-
 from typing import NamedTuple
 
-from arkindex_export import Classification
+from arkindex_export import Classification, ElementPath, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -23,8 +23,8 @@ def retrieve_element(element_id: str):
     return Element.get_by_id(element_id)
 
 
-def list_classifications(element: Element):
-    query = Classification.select().where(Classification.element == element)
+def list_classifications(element_id: str):
+    query = Classification.select().where(Classification.element_id == element_id)
     return query
@@ -33,7 +33,8 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
         id=transcription.id,
         element=element,
         text=transcription.text,
-        confidence=transcription.confidence,
+        # Dodge not-null constraint for now
+        confidence=transcription.confidence or 1.0,
         orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
         worker_version_id=transcription.worker_version.id
         if transcription.worker_version
@@ -89,3 +90,47 @@ def retrieve_entities(transcription: CachedTranscription):
         return [], []
 
     return zip(*data)
+
+
+def list_children(parent_id):
+    # First, build the base query to get direct children
+    base = (
+        ElementPath.select(
+            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
+        )
+        .where(ElementPath.parent_id == parent_id)
+        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
+    )
+
+    # Then build the second recursive query, using an alias to join both queries on the same table
+    EP = ElementPath.alias()
+    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
+        base, on=(EP.parent_id == base.c.child_id)
+    )
+
+    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
+    # that might be found multiple times with complex element structures
+    cte = base.union(recursion)
+
+    # And load all the elements found in the CTE
+    query = (
+        Element.select(
+            Element.id,
+            Element.type,
+            Image.id.alias("image_id"),
+            Image.width.alias("image_width"),
+            Image.height.alias("image_height"),
+            Image.url.alias("image_url"),
+            Element.polygon,
+            Element.rotation_angle,
+            Element.mirrored,
+            Element.worker_version,
+            Element.confidence,
+            cte.c.parent_id,
+        )
+        .with_cte(cte)
+        .join(cte, on=(Element.id == cte.c.child_id))
+        .join(Image, on=(Element.image_id == Image.id))
+        .order_by(cte.c.ordering.asc())
+    )
+    return query
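The list_children helper added above walks the full element hierarchy with a single recursive CTE and joins the Image table, so every returned row already carries the aliased image columns (image_id, image_width, image_height, image_url). A minimal consumption sketch, assuming open_database from arkindex_export accepts the path to a downloaded SQLite export; the path and parent UUID below are purely hypothetical:

from arkindex_export import open_database

from worker_generic_training_dataset.db import list_children

# Hypothetical path to a downloaded Arkindex SQLite export
open_database("export.sqlite")

# Every descendant of the parent element, ordered by ElementPath.ordering
children = list_children(parent_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
for child in children.namedtuples():
    # Each row is a flat named tuple mixing element fields and aliased image columns
    print(child.id, child.type, child.image_url, child.parent_id)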
 # -*- coding: utf-8 -*-
 import ast
 import logging
+import os
+import tarfile
+import tempfile
 import time
 from pathlib import Path
 from urllib.parse import urljoin
 
 import cv2
 import imageio.v2 as iio
-from arkindex_export.models import Element
+import zstandard as zstd
 
 from worker_generic_training_dataset.exceptions import ImageDownloadError
 
 logger = logging.getLogger(__name__)
@@ -24,14 +27,14 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)
 
 
-def build_image_url(element: Element):
+def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
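For reference, a small sketch of the IIIF crop URL this helper builds. The element below is a hypothetical stand-in carrying only the two attributes the helper reads (polygon and image_url), and it assumes bounding_box returns the top-left corner plus width and height of the polygon's bounding rectangle:

from types import SimpleNamespace

from worker_generic_training_dataset.utils import build_image_url

# Hypothetical element row: a rectangular polygon on a IIIF image
element = SimpleNamespace(
    polygon="[[10, 20], [310, 20], [310, 220], [10, 220]]",
    image_url="https://iiif.example.com/iiif/2/some-image",
)

# -> https://iiif.example.com/iiif/2/some-image/10,20,300,200/full/0/default.jpg
print(build_image_url(element))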
-def download_image(element: Element, folder: Path):
+def download_image(element, folder: Path):
     """
     Download the image to `folder / {element.image.id}.jpg`
     """
@@ -43,7 +46,7 @@ def download_image(element: Element, folder: Path):
         try:
             image = iio.imread(build_image_url(element))
             cv2.imwrite(
-                str(folder / f"{element.image.id}.jpg"),
+                str(folder / f"{element.image_id}.jpg"),
                 cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
             )
             break
@@ -53,3 +56,27 @@ def download_image(element: Element, folder: Path):
             tries += 1
         except Exception as e:
             raise ImageDownloadError(element.id, e)
+
+
+def create_tar_zstd_archive(folder_path, destination: Path, chunk_size=1024):
+    compressor = zstd.ZstdCompressor(level=3)
+
+    # Create a temporary file to hold the intermediate tar archive
+    _, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
+
+    # Create an uncompressed tar archive with all the needed files
+    # Files hierarchy is kept in the archive.
+    with tarfile.open(path_to_tar_archive, "w") as tar:
+        for p in folder_path.glob("**/*"):
+            x = p.relative_to(folder_path)
+            tar.add(p, arcname=x, recursive=False)
+
+    # Compress the archive
+    with destination.open("wb") as archive_file:
+        with open(path_to_tar_archive, "rb") as model_data:
+            for model_chunk in iter(lambda: model_data.read(chunk_size), b""):
+                compressed_chunk = compressor.compress(model_chunk)
+                archive_file.write(compressed_chunk)
+
+    # Remove the tar archive
+    os.remove(path_to_tar_archive)
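Because the helper above compresses the tar stream chunk by chunk with separate compress() calls, the resulting .zstd artifact is a sequence of zstandard frames rather than one single frame. A hedged sketch of how such an archive could be unpacked on the consumer side; the function name is hypothetical, and it assumes the zstandard stream reader's read_across_frames option to keep reading past each frame boundary:

import tarfile
from pathlib import Path

import zstandard as zstd


def extract_tar_zstd_archive(archive: Path, destination: Path):
    dctx = zstd.ZstdDecompressor()
    with archive.open("rb") as compressed:
        # Stream-decompress across all frames and untar on the fly
        with dctx.stream_reader(compressed, read_across_frames=True) as reader:
            with tarfile.open(fileobj=reader, mode="r|") as tar:
                tar.extractall(destination)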
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import tempfile
 from pathlib import Path
 
 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
-from arkindex_export.models import Element
-from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
     CachedElement,
@@ -21,16 +20,19 @@ from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.worker import ElementsWorker
 from worker_generic_training_dataset.db import (
+    list_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
     retrieve_entities,
 )
-from worker_generic_training_dataset.utils import download_image
+from worker_generic_training_dataset.utils import (
+    create_tar_zstd_archive,
+    download_image,
+)
 
 logger = logging.getLogger(__name__)
 
-IMAGE_FOLDER = Path("images")
 BULK_BATCH_SIZE = 50
@@ -57,8 +59,12 @@ class DatasetExtractor(ElementsWorker):
         # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
+            Path(self.args.database)
+            if self.is_read_only
+            else self.workdir / "db.sqlite"
         )
+        if database_path.exists():
+            database_path.unlink()
 
         init_cache_db(database_path)
@@ -83,6 +89,7 @@ class DatasetExtractor(ElementsWorker):
         exports = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
             key=operator.itemgetter("updated"),
+            reverse=True,
         )
 
         assert len(exports) > 0, "No available exports found."
@@ -102,33 +109,33 @@ class DatasetExtractor(ElementsWorker):
             )
             raise e
-    def insert_element(self, element: Element, parent_id: str):
+    def insert_element(self, element, image_folder: Path, root: bool = False):
         logger.info(f"Processing element ({element.id})")
-        if element.image and element.image.id not in self.cached_images:
+        if element.image_id and element.image_id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(element, folder=IMAGE_FOLDER)
+            download_image(element, folder=image_folder)
 
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
-            self.cached_images[element.image.id] = CachedImage.create(
-                id=element.image.id,
-                width=element.image.width,
-                height=element.image.height,
-                url=element.image.url,
+            self.cached_images[element.image_id] = CachedImage.create(
+                id=element.image_id,
+                width=element.image_width,
+                height=element.image_height,
+                url=element.image_url,
             )
 
         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=parent_id,
+            parent_id=None if root else element.parent_id,
             type=element.type,
-            image=element.image.id if element.image else None,
+            image=self.cached_images[element.image_id] if element.image_id else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version.id
+            worker_version_id=element.worker_version
             if element.worker_version
             else None,
             confidence=element.confidence,
@@ -144,10 +151,10 @@ class DatasetExtractor(ElementsWorker):
                 confidence=classification.confidence,
                 state=classification.state,
             )
-            for classification in list_classifications(element)
+            for classification in list_classifications(element.id)
         ]
         if classifications:
-            logger.info(f"Inserting {len(classifications)} classifications")
+            logger.info(f"Inserting {len(classifications)} classification(s)")
             with cache_database.atomic():
                 CachedClassification.bulk_create(
                     model_list=classifications,
@@ -158,7 +165,7 @@ class DatasetExtractor(ElementsWorker):
         logger.info("Listing transcriptions")
         transcriptions = list_transcriptions(cached_element)
         if transcriptions:
-            logger.info(f"Inserting {len(transcriptions)} transcriptions")
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
             with cache_database.atomic():
                 CachedTranscription.bulk_create(
                     model_list=transcriptions,
@@ -190,11 +197,27 @@ class DatasetExtractor(ElementsWorker):
             )
     def process_element(self, element):
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        logger.info(
+            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+        )
+        # Fill cache
         # Retrieve parent and create parent
         parent = retrieve_element(element.id)
-        self.insert_element(parent, parent_id=None)
+        self.insert_element(parent, image_folder, root=True)
 
-        for child in list_children(parent_id=element.id):
-            self.insert_element(child, parent_id=element.id)
+        # Create children
+        children = list_children(parent_id=element.id)
+        nb_children = children.count()
+        for idx, child in enumerate(children.namedtuples(), start=1):
+            logger.info(f"Processing child ({idx}/{nb_children})")
+            self.insert_element(child, image_folder)
+
+        # TAR + ZSTD the image folder and store it as a task artifact
+        zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
 def main():
...
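Once process_element has run, the worker's working directory holds both the filled SQLite cache (db.sqlite) and the arkindex_data.zstd image archive. A quick sketch for inspecting the cache afterwards, reusing the cache helpers imported above; the path is hypothetical and assumes the extraction has already finished:

from arkindex_worker.cache import (
    CachedElement,
    CachedTranscription,
    init_cache_db,
)
from arkindex_worker.cache import db as cache_database

# Hypothetical path to the db.sqlite produced in the worker's working directory
init_cache_db("db.sqlite")

print(CachedElement.select().count(), "cached elements")
print(CachedTranscription.select().count(), "cached transcriptions")

cache_database.close()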