Verified Commit 7fb7097c authored by Yoann Schneider :tennis:

working implem

parent ac10b18e
1 merge request: !2 Implement worker
Pipeline #81794 failed
# -*- coding: utf-8 -*-
from typing import NamedTuple
from arkindex_export import Classification
from arkindex_export import Classification, ElementPath, Image
from arkindex_export.models import (
Element,
Entity,
@@ -23,8 +23,8 @@ def retrieve_element(element_id: str):
return Element.get_by_id(element_id)
def list_classifications(element: Element):
query = Classification.select().where(Classification.element == element)
def list_classifications(element_id: str):
query = Classification.select().where(Classification.element_id == element_id)
return query
@@ -33,7 +33,8 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
id=transcription.id,
element=element,
text=transcription.text,
confidence=transcription.confidence,
# Dodge not-null constraint for now
confidence=transcription.confidence or 1.0,
orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
worker_version_id=transcription.worker_version.id
if transcription.worker_version
@@ -89,3 +90,47 @@ def retrieve_entities(transcription: CachedTranscription):
return [], []
return zip(*data)
def list_children(parent_id):
# First, build the base query to get direct children
base = (
ElementPath.select(
ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
)
.where(ElementPath.parent_id == parent_id)
.cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
)
# Then build the second recursive query, using an alias to join both queries on the same table
EP = ElementPath.alias()
recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
base, on=(EP.parent_id == base.c.child_id)
)
# Combine both queries, using UNION and not UNION ALL to deduplicate parents
# that might be found multiple times with complex element structures
cte = base.union(recursion)
# And load all the elements found in the CTE
query = (
Element.select(
Element.id,
Element.type,
Image.id.alias("image_id"),
Image.width.alias("image_width"),
Image.height.alias("image_height"),
Image.url.alias("image_url"),
Element.polygon,
Element.rotation_angle,
Element.mirrored,
Element.worker_version,
Element.confidence,
cte.c.parent_id,
)
.with_cte(cte)
.join(cte, on=(Element.id == cte.c.child_id))
.join(Image, on=(Element.image_id == Image.id))
.order_by(cte.c.ordering.asc())
)
return query
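# Usage sketch (illustration only, not part of this module): peewee evaluates the CTE
# query lazily, so the caller can count the rows and iterate them as named tuples whose
# fields match the selected columns and aliases (id, type, image_url, parent_id, ...).
# The parent UUID below is a hypothetical placeholder.
#
#     children = list_children(parent_id="00000000-0000-0000-0000-000000000000")
#     print(children.count())
#     for child in children.namedtuples():
#         print(child.id, child.type, child.image_url, child.parent_id)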
# -*- coding: utf-8 -*-
import ast
import logging
import os
import tarfile
import tempfile
import time
from pathlib import Path
from urllib.parse import urljoin
import cv2
import imageio.v2 as iio
from arkindex_export.models import Element
import zstandard as zstd
from worker_generic_training_dataset.exceptions import ImageDownloadError
logger = logging.getLogger(__name__)
@@ -24,14 +27,14 @@ def bounding_box(polygon: list):
return int(x), int(y), int(width), int(height)
def build_image_url(element: Element):
def build_image_url(element):
x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
return urljoin(
element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
)
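# Example (hypothetical values): for an element whose polygon bounding box is
# x=10, y=20, width=300, height=400 and whose image_url is
# https://iiif.example.com/imageid, this builds the IIIF crop URL
# https://iiif.example.com/imageid/10,20,300,400/full/0/default.jpg
# (region x,y,w,h / full size / no rotation / default quality, JPEG).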
def download_image(element: Element, folder: Path):
def download_image(element, folder: Path):
"""
Download the image to `folder / {element.image.id}.jpg`
"""
@@ -43,7 +46,7 @@ def download_image(element: Element, folder: Path):
try:
image = iio.imread(build_image_url(element))
cv2.imwrite(
str(folder / f"{element.image.id}.jpg"),
str(folder / f"{element.image_id}.jpg"),
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
)
break
@@ -53,3 +56,27 @@ def download_image(element: Element, folder: Path):
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def create_tar_zstd_archive(folder_path, destination: Path, chunk_size=1024):
compressor = zstd.ZstdCompressor(level=3)
# Create a temporary file to hold the uncompressed tar archive
_, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
# Create an uncompressed tar archive with all the needed files
# The file hierarchy is kept in the archive.
with tarfile.open(path_to_tar_archive, "w") as tar:
for p in folder_path.glob("**/*"):
x = p.relative_to(folder_path)
tar.add(p, arcname=x, recursive=False)
# Compress the archive
with destination.open("wb") as archive_file:
with open(path_to_tar_archive, "rb") as model_data:
for model_chunk in iter(lambda: model_data.read(chunk_size), b""):
compressed_chunk = compressor.compress(model_chunk)
archive_file.write(compressed_chunk)
# Remove the tar archive
os.remove(path_to_tar_archive)
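# Usage sketch (illustration only, not part of this module): the file written by
# create_tar_zstd_archive is a sequence of zstandard frames (one per chunk) wrapping a
# plain tar stream, so it can be unpacked with the same libraries. Paths below are
# hypothetical placeholders.
#
#     import tarfile
#     import zstandard as zstd
#
#     with open("arkindex_data.zstd", "rb") as fh:
#         reader = zstd.ZstdDecompressor().stream_reader(fh, read_across_frames=True)
#         with tarfile.open(fileobj=reader, mode="r|") as tar:
#             tar.extractall("extracted_images")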
# -*- coding: utf-8 -*-
import logging
import operator
import tempfile
from pathlib import Path
from apistar.exceptions import ErrorResponse
from arkindex_export import open_database
from arkindex_export.models import Element
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
CachedClassification,
CachedElement,
@@ -21,16 +20,19 @@ from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.worker import ElementsWorker
from worker_generic_training_dataset.db import (
list_children,
list_classifications,
list_transcriptions,
retrieve_element,
retrieve_entities,
)
from worker_generic_training_dataset.utils import download_image
from worker_generic_training_dataset.utils import (
create_tar_zstd_archive,
download_image,
)
logger = logging.getLogger(__name__)
IMAGE_FOLDER = Path("images")
BULK_BATCH_SIZE = 50
@@ -57,8 +59,12 @@ class DatasetExtractor(ElementsWorker):
# - self.workdir / "db.sqlite" in Arkindex mode
# - self.args.database in dev mode
database_path = (
self.args.database if self.is_read_only else self.workdir / "db.sqlite"
Path(self.args.database)
if self.is_read_only
else self.workdir / "db.sqlite"
)
if database_path.exists():
database_path.unlink()
init_cache_db(database_path)
@@ -83,6 +89,7 @@ class DatasetExtractor(ElementsWorker):
exports = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=operator.itemgetter("updated"),
reverse=True,
)
assert len(exports) > 0, "No available exports found."
@@ -102,33 +109,33 @@ class DatasetExtractor(ElementsWorker):
)
raise e
def insert_element(self, element: Element, parent_id: str):
def insert_element(self, element, image_folder: Path, root: bool = False):
logger.info(f"Processing element ({element.id})")
if element.image and element.image.id not in self.cached_images:
if element.image_id and element.image_id not in self.cached_images:
# Download image
logger.info("Downloading image")
download_image(element, folder=IMAGE_FOLDER)
download_image(element, folder=image_folder)
# Insert image
logger.info("Inserting image")
# Store the image in case some other elements use it as well
self.cached_images[element.image.id] = CachedImage.create(
id=element.image.id,
width=element.image.width,
height=element.image.height,
url=element.image.url,
self.cached_images[element.image_id] = CachedImage.create(
id=element.image_id,
width=element.image_width,
height=element.image_height,
url=element.image_url,
)
# Insert element
logger.info("Inserting element")
cached_element = CachedElement.create(
id=element.id,
parent_id=parent_id,
parent_id=None if root else element.parent_id,
type=element.type,
image=element.image.id if element.image else None,
image=self.cached_images[element.image_id] if element.image_id else None,
polygon=element.polygon,
rotation_angle=element.rotation_angle,
mirrored=element.mirrored,
worker_version_id=element.worker_version.id
worker_version_id=element.worker_version
if element.worker_version
else None,
confidence=element.confidence,
@@ -144,10 +151,10 @@ class DatasetExtractor(ElementsWorker):
confidence=classification.confidence,
state=classification.state,
)
for classification in list_classifications(element)
for classification in list_classifications(element.id)
]
if classifications:
logger.info(f"Inserting {len(classifications)} classifications")
logger.info(f"Inserting {len(classifications)} classification(s)")
with cache_database.atomic():
CachedClassification.bulk_create(
model_list=classifications,
@@ -158,7 +165,7 @@ class DatasetExtractor(ElementsWorker):
logger.info("Listing transcriptions")
transcriptions = list_transcriptions(cached_element)
if transcriptions:
logger.info(f"Inserting {len(transcriptions)} transcriptions")
logger.info(f"Inserting {len(transcriptions)} transcription(s)")
with cache_database.atomic():
CachedTranscription.bulk_create(
model_list=transcriptions,
@@ -190,11 +197,27 @@ class DatasetExtractor(ElementsWorker):
)
def process_element(self, element):
# Where to save the downloaded images
image_folder = Path(tempfile.mkdtemp())
logger.info(
f"Filling the Base-Worker cache with information from children under element ({element.id})"
)
# Fill cache
# Retrieve the parent element and insert it
parent = retrieve_element(element.id)
self.insert_element(parent, parent_id=None)
for child in list_children(parent_id=element.id):
self.insert_element(child, parent_id=element.id)
self.insert_element(parent, image_folder, root=True)
# Create children
children = list_children(parent_id=element.id)
nb_children = children.count()
for idx, child in enumerate(children.namedtuples(), start=1):
logger.info(f"Processing child ({idx}/{nb_children})")
self.insert_element(child, image_folder)
# Tar + zstd-compress the image folder and store it as a task artifact
zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
logger.info(f"Compressing the images to {zstd_archive_path}")
create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
def main():
......