Verified commit 3a76f403, authored by Yoann Schneider

Fix API worker

parent 231047c5
Merge request !25 (Draft): Refactor and implement API version of the worker
Pipeline #170937 passed

@@ -49,8 +49,8 @@ setup(
     install_requires=parse_requirements(),
     entry_points={
         "console_scripts": [
-            f"{COMMAND}={MODULE}.worker:main",
-            f"{COMMAND}-api={MODULE}.worker:main",
+            f"{COMMAND}={MODULE}.from_sql:main",
+            f"{COMMAND}-api={MODULE}.from_api:main",
         ]
     },
     packages=find_packages(),
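
Both console scripts previously resolved to the same worker:main; each execution mode now gets its own module, so the SQL-backed and API-backed extractors are installed as separate commands. For reference, the wrapper setuptools generates for the second entry point behaves roughly like the sketch below; MODULE is worker_generic_training_dataset per the imports later in this commit, while COMMAND's exact value is not shown here.

# Hedged sketch of the wrapper setuptools generates for
# f"{COMMAND}-api={MODULE}.from_api:main"; only the module path
# comes from the setup.py hunk above.
import sys
from worker_generic_training_dataset.from_api import main

if __name__ == "__main__":
    sys.exit(main())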
@@ -5,11 +5,11 @@ import logging
 import sys
 import tempfile
 import uuid
+from collections.abc import Iterable
 from itertools import groupby
 from operator import attrgetter
 from pathlib import Path
 from typing import List, Optional
-from uuid import UUID
 from apistar.exceptions import ErrorResponse
 from arkindex_export import Element, WorkerRun, WorkerVersion
@@ -78,20 +78,6 @@ class Extractor(DatasetWorker):
         create_tables()

-    def list_classifications(self, element_id: UUID):
-        raise NotImplementedError
-
-    def list_transcriptions(self, element: ArkindexElement, **kwargs):
-        raise NotImplementedError
-
-    def list_transcription_entities(
-        self, transcription: ArkindexTranscription, **kwargs
-    ):
-        raise NotImplementedError
-
-    def list_element_children(self, element: ArkindexElement, **kwargs):
-        raise NotImplementedError
-
     def insert_classifications(self, element: CachedElement) -> None:
         logger.info("Listing classifications")
         classifications: list[CachedClassification] = [
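
The NotImplementedError stubs leave the base Extractor entirely; each backend now overrides the listing methods itself (the SQL versions appear further down in DatasetExtractorFromSQL). A hedged sketch of what the API-side counterpart might look like, using the standard arkindex_worker API client; the class name and body are assumptions, not part of this commit.

# Hypothetical API-side override; self.api_client.paginate is the usual
# arkindex_worker way to iterate a paginated endpoint.
class DatasetExtractorFromAPI(Extractor):
    def list_element_children(self, element, **kwargs):
        return self.api_client.paginate("ListElementChildren", id=element.id, **kwargs)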
@@ -190,7 +176,7 @@ class Extractor(DatasetWorker):

     def insert_element(
         self,
-        element: Element,
+        element: Element | ArkindexElement,
         split_name: Optional[str] = None,
         parent_id: Optional[str] = None,
     ) -> None:
@@ -209,21 +195,33 @@ class Extractor(DatasetWorker):
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
         logger.info(f"Processing element ({element.id})")

-        if element.image and element.image.id not in self.cached_images:
+        if isinstance(element, Element):
+            image = element.image
+            polygon = element.polygon
+            wk_version = get_object_id(element.worker_version)
+            wk_run = get_object_id(element.worker_run)
+        else:
+            image = element.zone.image
+            polygon = element.zone.polygon
+            wk_version = element.worker_version_id
+            wk_run = element.worker_run.id if element.worker_run else None
+
+        if image and image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(url=build_image_url(element)).save(
-                self.images_folder / f"{element.image.id}.jpg"
+            download_image(url=build_image_url(image, polygon)).save(
+                self.images_folder / f"{image.id}.jpg"
             )
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
             with cache_database.atomic():
-                self.cached_images[element.image.id] = CachedImage.create(
-                    id=element.image.id,
-                    width=element.image.width,
-                    height=element.image.height,
-                    url=element.image.url,
+                self.cached_images[image.id] = CachedImage.create(
+                    id=image.id,
+                    width=image.width,
+                    height=image.height,
+                    url=image.url,
                 )

         # Insert element
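
The new isinstance branch normalizes the two element shapes up front: SQL-export Element rows expose image, polygon, and worker references directly, while API ArkindexElement payloads nest image and polygon under zone and carry worker_version_id/worker_run instead. get_object_id itself is not shown in this diff; a plausible reading, stated as an assumption, is an ID-or-model passthrough:

# Assumed behaviour of get_object_id (not shown in this commit): return the
# .id of a model instance, pass a raw ID through, and map None to None.
def get_object_id(instance_or_id):
    if instance_or_id is None:
        return None
    return getattr(instance_or_id, "id", instance_or_id)

Under that reading, the later get_object_id(wk_version) call below is harmless even though the SQL branch already applied it, since an ID passes through unchanged.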
@@ -233,12 +231,12 @@ class Extractor(DatasetWorker):
                 id=element.id,
                 parent_id=parent_id,
                 type=element.type,
-                image=self.cached_images[element.image.id] if element.image else None,
-                polygon=element.polygon,
+                image=self.cached_images[image.id] if image else None,
+                polygon=polygon,
                 rotation_angle=element.rotation_angle,
                 mirrored=element.mirrored,
-                worker_version_id=get_object_id(element.worker_version),
-                worker_run_id=get_object_id(element.worker_run),
+                worker_version_id=get_object_id(wk_version),
+                worker_run_id=get_object_id(wk_run),
                 confidence=element.confidence,
             )
@@ -266,21 +264,20 @@ class Extractor(DatasetWorker):
                 set_name=split_name,
             )

-    def process_split(self, split_name: str, elements: List[Element]) -> None:
+    def process_split(
+        self, split_name: str, elements: Iterable[Element | ArkindexElement]
+    ) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
         # First list all pages
-        nb_elements: int = len(elements)
         for idx, element in enumerate(elements, start=1):
-            logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
+            logger.info(f"Processing `{split_name}` element ({idx})")
             # Insert page
             self.insert_element(element, split_name=split_name)

             # List children
-            children = self.list_element_children(ArkindexElement(id=element.id))
+            children = self.list_element_children(element)
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx})")
                 # Insert child
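
Loosening the parameter from List[Element] to Iterable also explains the log change: a lazy peewee query or an API paginator has no len(), so the fixed total had to go. A minimal illustration in plain Python:

# Generators (like lazy queries and paginators) do not support len();
# enumerate still provides a running counter.
def elements():
    yield from ("page-1", "page-2", "page-3")

# len(elements())  # would raise TypeError: object of type 'generator' has no len()
for idx, element in enumerate(elements(), start=1):
    print(f"Processing element ({idx})")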
@@ -310,7 +307,7 @@ class Extractor(DatasetWorker):
         # Iterate over given splits
         for dataset_set in sets:
-            elements = self.list_set_elements(dataset.id, dataset_set.name)
+            elements = self.list_set_elements(dataset_set)
             self.process_split(dataset_set.name, elements)

         # TAR + ZST the cache and the images folder, and store as task artifact
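
The trailing comment is untouched by this commit; for readers unfamiliar with that step, a hedged sketch of tarring a folder and compressing it with zstandard follows. The function name, paths, and the zstandard dependency are assumptions, not part of this diff.

import tarfile
import zstandard
from pathlib import Path

# Hedged sketch: archive a folder, then zstd-compress the archive.
def tar_and_zst(folder: Path, output: Path) -> None:
    # `output` is expected to end in .zst, e.g. Path("artifact.tar.zst")
    tar_path = output.with_suffix("")  # e.g. artifact.tar
    with tarfile.open(tar_path, mode="w") as tar:
        tar.add(folder, arcname=".")
    with open(tar_path, "rb") as src, open(output, "wb") as dst:
        zstandard.ZstdCompressor().copy_stream(src, dst)
    tar_path.unlink()  # keep only the compressed artifact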
@@ -21,6 +21,7 @@ from arkindex_export import (
 )
 from arkindex_export.queries import list_children
 from arkindex_worker.models import Element as ArkindexElement
+from arkindex_worker.models import Set
 from arkindex_worker.models import Transcription as ArkindexTranscription
 from peewee import CharField
 from worker_generic_training_dataset import Extractor
@@ -84,15 +85,15 @@ class DatasetExtractorFromSQL(Extractor):
             )
             raise e

-    def list_set_elements(self, dataset_id: UUID, set_name: str):
+    def list_set_elements(self, dataset_set: Set):
         return (
             Element.select()
             .join(Image)
             .switch(Element)
             .join(DatasetElement, on=DatasetElement.element)
             .where(
-                DatasetElement.dataset == dataset_id,
-                DatasetElement.set_name == set_name,
+                DatasetElement.dataset == dataset_set.dataset.id,
+                DatasetElement.set_name == dataset_set.name,
             )
         )
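
Passing the Set model directly keeps the dataset ID and split name together instead of threading them as two separate arguments; the query only relies on dataset_set.dataset.id and dataset_set.name. A hedged usage sketch, where extractor stands in for a configured DatasetExtractorFromSQL instance:

# The peewee query is lazy, so a split can be streamed straight into
# insert_element without materializing a list first.
for element in extractor.list_set_elements(dataset_set):
    extractor.insert_element(element, split_name=dataset_set.name)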
@@ -110,7 +111,7 @@ class DatasetExtractorFromSQL(Extractor):
     def list_transcription_entities(
         self, transcription: ArkindexTranscription, **kwargs
-    ):  # -> Any:
+    ):
         return (
             TranscriptionEntity.select()
             .where(TranscriptionEntity.transcription == transcription.id)
@@ -9,13 +9,14 @@ from arkindex_worker.image import BoundingBox, polygon_bounding_box
 logger: Logger = logging.getLogger(__name__)

-def build_image_url(element) -> str:
-    bbox: BoundingBox = polygon_bounding_box(json.loads(element.polygon))
+def build_image_url(image, polygon: str | list[list[int]]) -> str:
+    if isinstance(polygon, str):
+        polygon = json.loads(polygon)
+    bbox: BoundingBox = polygon_bounding_box(polygon)
     x: int
     y: int
     width: int
     height: int
     x, y, width, height = bbox
-    return urljoin(
-        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
-    )
+    return urljoin(image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg")
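
build_image_url now takes the image and polygon separately and accepts the polygon either as the JSON string stored in SQL exports or as the already-parsed list the API returns. The resulting URL follows the IIIF Image API region syntax. A quick example with fabricated values:

# Example with made-up values; FakeImage only mimics the .url attribute
# the function reads.
class FakeImage:
    url = "https://iiif.example.com/image"

build_image_url(FakeImage(), "[[0, 0], [200, 0], [200, 100], [0, 100]]")
# -> "https://iiif.example.com/image/0,0,200,100/full/0/default.jpg"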