Verified Commit ec6c64fa, authored by Yoann Schneider

first working version

parent a0f977eb
Merge request !2: Implement worker

requirements.txt:
arkindex-base-worker==0.3.2
arkindex-export==0.1.2
imageio==2.27.0
opencv-python==4.7.0.72
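
(For context: arkindex-base-worker provides the ElementsWorker base class and the cache models used below, arkindex-export provides the models for the exported SQLite database, imageio fetches the images, and opencv-python writes them to disk.)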

worker_generic_training_dataset/db.py (module path inferred from the worker's imports below):

# -*- coding: utf-8 -*-
from typing import NamedTuple

from arkindex_export import Classification
from arkindex_export.models import (
    Element,
    Entity,
    EntityType,
    Transcription,
    TranscriptionEntity,
)
from arkindex_worker.cache import (
    CachedElement,
    CachedEntity,
    CachedTranscription,
    CachedTranscriptionEntity,
)

DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


def retrieve_element(element_id: str):
    return Element.get_by_id(element_id)


def list_classifications(element: Element):
    return Classification.select().where(Classification.element == element)


def parse_transcription(transcription: NamedTuple, element: CachedElement):
    return CachedTranscription(
        id=transcription.id,
        element=element,
        text=transcription.text,
        confidence=transcription.confidence,
        orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
        worker_version_id=transcription.worker_version.id
        if transcription.worker_version
        else None,
    )


def list_transcriptions(element: CachedElement):
    query = Transcription.select().where(Transcription.element_id == element.id)
    return [parse_transcription(x, element) for x in query]


def parse_entities(data: NamedTuple, transcription: CachedTranscription):
    # Build the entity and its link to the transcription as a pair
    entity = CachedEntity(
        id=data.entity_id,
        type=data.type,
        name=data.name,
        validated=data.validated,
        metas=data.metas,
    )
    return entity, CachedTranscriptionEntity(
        id=data.transcription_entity_id,
        transcription=transcription,
        entity=entity,
        offset=data.offset,
        length=data.length,
        confidence=data.confidence,
    )


def retrieve_entities(transcription: CachedTranscription):
    query = (
        TranscriptionEntity.select(
            TranscriptionEntity.id.alias("transcription_entity_id"),
            TranscriptionEntity.length.alias("length"),
            TranscriptionEntity.offset.alias("offset"),
            TranscriptionEntity.confidence.alias("confidence"),
            Entity.id.alias("entity_id"),
            EntityType.name.alias("type"),
            Entity.name,
            Entity.validated,
            Entity.metas,
        )
        .join(Entity, on=TranscriptionEntity.entity)
        .join(EntityType, on=Entity.type)
        .where(TranscriptionEntity.transcription_id == transcription.id)
    )
    data = [
        parse_entities(entity_data, transcription)
        for entity_data in query.namedtuples()
    ]
    if not data:
        # zip(*[]) yields nothing, which would break tuple unpacking in callers
        return [], []
    return zip(*data)
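
For context, a minimal usage sketch of these helpers; `cached_element` is hypothetical and stands for a CachedElement built by the worker further below:

# assumes open_database() has already been called on a downloaded export
for transcription in list_transcriptions(cached_element):
    # two aligned sequences: entities and their links to the transcription
    entities, transcription_entities = retrieve_entities(transcription)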

worker_generic_training_dataset/exceptions.py (module path inferred from the imports below):

# -*- coding: utf-8 -*-
class ElementProcessingError(Exception):
    """
    Raised when a problem is encountered while processing an element
    """

    element_id: str
    """
    ID of the element being processed.
    """

    def __init__(self, element_id: str, *args: object) -> None:
        super().__init__(*args)
        self.element_id = element_id


class ImageDownloadError(ElementProcessingError):
    """
    Raised when an element's image could not be downloaded
    """

    error: Exception
    """
    Error encountered.
    """

    def __init__(self, element_id: str, error: Exception, *args: object) -> None:
        super().__init__(element_id, *args)
        self.error = error

    def __str__(self) -> str:
        return (
            f"Couldn't retrieve image of element ({self.element_id}): {self.error}"
        )
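
A quick sanity check of the message format (hypothetical values):

# str(ImageDownloadError("some-id", TimeoutError("timed out")))
# -> "Couldn't retrieve image of element (some-id): timed out"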

worker_generic_training_dataset/utils.py (module path inferred from the worker's imports below):

# -*- coding: utf-8 -*-
import ast
import logging
import time
from pathlib import Path
from urllib.parse import urljoin

import cv2
import imageio.v2 as iio
from arkindex_export.models import Element

from worker_generic_training_dataset.exceptions import ImageDownloadError

logger = logging.getLogger(__name__)

MAX_RETRIES = 5


def bounding_box(polygon: list):
    """
    Returns a 4-tuple (x, y, width, height) for the bounding box of a polygon
    (list of points)
    """
    all_x, all_y = zip(*polygon)
    x, y = min(all_x), min(all_y)
    width, height = max(all_x) - x, max(all_y) - y
    return int(x), int(y), int(width), int(height)
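
# Worked example (hypothetical points):
#   bounding_box([(10, 20), (30, 5), (25, 40)]) == (10, 5, 20, 35)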


def build_image_url(element: Element):
    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
    return urljoin(
        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
    )
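
# The path follows the IIIF Image API pattern {region}/{size}/{rotation}/{quality}.{format}.
# With the bounding box above and a hypothetical host, the URL would be
#   https://iiif.example.com/image/10,5,20,35/full/0/default.jpg
# i.e. the 20x35 region at (10, 5), full size, unrotated, as a default-quality JPEG.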


def download_image(element: Element, folder: Path):
    """
    Download the image to `folder / {element.id}.jpg`
    """
    # Make sure the output folder exists before writing to it
    folder.mkdir(parents=True, exist_ok=True)
    tries = 1
    # Retry loop: only timeouts are retried, any other error aborts immediately
    while True:
        if tries > MAX_RETRIES:
            raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
        try:
            image = iio.imread(build_image_url(element))
            cv2.imwrite(
                str(folder / f"{element.id}.jpg"),
                # imageio returns RGB arrays while OpenCV expects BGR
                cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
            )
            break
        except TimeoutError:
            logger.warning("Timeout, retrying in 1 second.")
            time.sleep(1)
            tries += 1
        except Exception as e:
            raise ImageDownloadError(element.id, e) from e
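
# Hypothetical usage, with `element` coming from db.retrieve_element above:
#   download_image(element, folder=Path("images"))
# writes images/{element.id}.jpg, or raises ImageDownloadError after 5 timeouts.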

The worker module itself (exact filename not shown in the diff):

# -*- coding: utf-8 -*-
import logging
import operator
from pathlib import Path
from typing import Optional

from apistar.exceptions import ErrorResponse
from arkindex_export import open_database
from arkindex_export.models import Element
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.worker import ElementsWorker

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcriptions,
    retrieve_element,
    retrieve_entities,
)
from worker_generic_training_dataset.utils import download_image

logger = logging.getLogger(__name__)

IMAGE_FOLDER = Path("images")
BULK_BATCH_SIZE = 50
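
# peewee's bulk_create splits the inserts into batches of BULK_BATCH_SIZE rows,
# keeping each INSERT statement's bound parameters well below SQLite's limit.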


class DatasetExtractor(ElementsWorker):
    def configure(self):
@@ -26,12 +51,10 @@ class DatasetExtractor(ElementsWorker):

    def initialize_database(self):
        # Create the database at
        # - self.workdir / "db.sqlite" in Arkindex mode
        # - self.args.database in dev mode
        database_path = (
            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
        )
        init_cache_db(database_path)
@@ -49,8 +72,9 @@ class DatasetExtractor(ElementsWorker):
            )["results"]
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

        # Find the latest export in the "done" state
        exports = sorted(
@@ -62,21 +86,111 @@ class DatasetExtractor(ElementsWorker):
        # Download the latest export to a temporary file
        try:
            export_id = exports[0]["id"]
            logger.info(f"Downloading export ({export_id})...")
            # DownloadExport hands back a temporary file object whose .name is
            # the on-disk path of the downloaded SQLite export
            self.export = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e)}"
            )
            raise e

    def insert_element(self, element: Element, parent_id: Optional[str] = None):
        logger.info(f"Processing element ({element.id})")
        if element.image:
            # Download image
            logger.info("Downloading image")
            download_image(element, folder=IMAGE_FOLDER)

            # Insert image
            logger.info("Inserting image")
            CachedImage.create(
                id=element.image.id,
                width=element.image.width,
                height=element.image.height,
                url=element.image.url,
            )

        # Insert element
        logger.info("Inserting element")
        cached_element = CachedElement.create(
            id=element.id,
            parent_id=parent_id,
            type=element.type,
            image=element.image.id if element.image else None,
            polygon=element.polygon,
            rotation_angle=element.rotation_angle,
            mirrored=element.mirrored,
            worker_version_id=element.worker_version.id
            if element.worker_version
            else None,
            confidence=element.confidence,
        )

        # Insert classifications
        logger.info("Listing classifications")
        classifications = [
            CachedClassification(
                id=classification.id,
                element=cached_element,
                class_name=classification.class_name,
                confidence=classification.confidence,
                state=classification.state,
            )
            for classification in list_classifications(element)
        ]
        if classifications:
            logger.info(f"Inserting {len(classifications)} classifications")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

        # Insert transcriptions
        logger.info("Listing transcriptions")
        transcriptions = list_transcriptions(cached_element)
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcriptions")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )

        # Collect entities from all transcriptions
        logger.info("Listing entities")
        entities, transcription_entities = [], []
        for transcription in transcriptions:
            ents, transc_ents = retrieve_entities(transcription)
            entities.extend(ents)
            transcription_entities.extend(transc_ents)

        if entities:
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def process_element(self, element):
        # Retrieve and insert the parent element itself, then all of its children
        parent = retrieve_element(element.id)
        self.insert_element(parent, parent_id=None)

        for child in list_children(parent_id=element.id):
            self.insert_element(child, parent_id=element.id)


def main():
    ...