Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (25)
Showing
with 1581 additions and 874 deletions
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,pytest,requests,setuptools,sh,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml
0.1.13
0.2.0-beta2
# -*- coding: utf-8 -*-
import json
import os
import sqlite3
from peewee import (
BooleanField,
CharField,
Field,
FloatField,
ForeignKeyField,
IntegerField,
Model,
SqliteDatabase,
TextField,
UUIDField,
)
from arkindex_worker import logger
from arkindex_worker.image import open_image, polygon_bounding_box
db = SqliteDatabase(None)
class JSONField(Field):
field_type = "text"
def db_value(self, value):
if value is None:
return
return json.dumps(value)
def python_value(self, value):
if value is None:
return
return json.loads(value)
class CachedImage(Model):
id = UUIDField(primary_key=True)
width = IntegerField()
height = IntegerField()
url = TextField()
class Meta:
database = db
table_name = "images"
class CachedElement(Model):
id = UUIDField(primary_key=True)
parent_id = UUIDField(null=True)
type = CharField(max_length=50)
image = ForeignKeyField(CachedImage, backref="elements", null=True)
polygon = JSONField(null=True)
initial = BooleanField(default=False)
worker_version_id = UUIDField(null=True)
class Meta:
database = db
table_name = "elements"
def open_image(self, *args, max_size=None, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported.
:param max_size: Subresolution of the image.
"""
if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image")
if max_size is None:
resize = "full"
else:
bounding_box = polygon_bounding_box(self.polygon)
# Do not resize for polygons that do not exactly match the images
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
resize = "full"
logger.warning(
"Only full size elements covered, downloading full size image"
)
# Do not resize when the image is below the maximum size
elif self.image.width <= max_size and self.image.height <= max_size:
resize = "full"
else:
ratio = max_size / max(self.image.width, self.image.height)
new_width, new_height = int(self.image.width * ratio), int(
self.image.height * ratio
)
resize = f"{new_width},{new_height}"
url = self.image.url
if not url.endswith("/"):
url += "/"
return open_image(f"{url}full/{resize}/0/default.jpg", *args, **kwargs)
class CachedTranscription(Model):
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField()
confidence = FloatField()
worker_version_id = UUIDField()
class Meta:
database = db
table_name = "transcriptions"
# Add all the managed models in that list
# It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription]
def init_cache_db(path):
db.init(
path,
pragmas={
# SQLite ignores foreign keys and check constraints by default!
"foreign_keys": 1,
"ignore_check_constraints": 0,
},
)
db.connect()
logger.info(f"Connected to cache on {path}")
def create_tables():
"""
Creates the tables in the cache DB only if they do not already exist.
"""
db.create_tables(MODELS)
def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
"""
Merge all the potential parent task's databases into the existing local one
"""
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
assert os.path.exists(current_database)
# Handle possible chunk in parent task name
# This is needed to support the init_elements databases
filenames = [
"db.sqlite",
]
if chunk is not None:
filenames.append(f"db_{chunk}.sqlite")
# Find all the paths for these databases
paths = list(
filter(
lambda p: os.path.isfile(p),
[
os.path.join(data_dir, parent, name)
for parent in parent_ids
for name in filenames
],
)
)
if not paths:
logger.info("No parents cache to use")
return
# Open a connection on current database
connection = sqlite3.connect(current_database)
cursor = connection.cursor()
# Merge each table into the local database
for idx, path in enumerate(paths):
logger.info(f"Merging parent db {path} into {current_database}")
statements = [
"PRAGMA page_size=80000;",
"PRAGMA synchronous=OFF;",
f"ATTACH DATABASE '{path}' AS source_{idx};",
f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
]
for statement in statements:
cursor.execute(statement)
connection.commit()
......@@ -48,10 +48,27 @@ def download_image(url):
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
try:
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
except requests.exceptions.SSLError:
logger.warning(
"An SSLError occurred during image download, retrying with a weaker and unsafe SSL configuration"
)
# Saving current ciphers
previous_ciphers = requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS
# Downgrading ciphers to download the image
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL:@SECLEVEL=1"
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
# Restoring previous ciphers
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = previous_ciphers
resp.raise_for_status()
# Preprocess the image and prepare it for classification
......
......@@ -82,6 +82,7 @@ class Element(MagicDict):
def open_image(self, *args, max_size=None, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
:param max_size: Subresolution of the image.
"""
if not self.get("zone"):
......
# -*- coding: utf-8 -*-
import json
import traceback
import warnings
from collections import Counter
from datetime import datetime
......@@ -83,35 +82,12 @@ class Reporter(object):
)
element["classifications"] = dict(counter)
def add_transcription(self, element_id, type=None, type_count=None):
def add_transcription(self, element_id, count=1):
"""
Report creating a transcription on an element.
Multiple transcriptions with the same parent can be declared with the type_count parameter.
"""
if type_count is None:
if isinstance(type, int):
type_count, type = type, None
else:
type_count = 1
if type is not None:
warnings.warn(
"Transcription types have been deprecated and will be removed in the next release.",
FutureWarning,
)
self._get_element(element_id)["transcriptions"] += type_count
def add_transcriptions(self, element_id, transcriptions):
"""
Report one or more transcriptions at once.
"""
assert isinstance(transcriptions, list), "A list is required for transcriptions"
warnings.warn(
"Reporter.add_transcriptions is deprecated due to transcription types being removed. Please use Reporter.add_transcription(element_id, count) instead.",
FutureWarning,
)
self.add_transcription(element_id, len(transcriptions))
self._get_element(element_id)["transcriptions"] += count
def add_entity(self, element_id, entity_id, type, name):
"""
......
This diff is collapsed.
# -*- coding: utf-8 -*-
import json
import os
import sys
import uuid
from enum import Enum
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
class ActivityState(Enum):
Queued = "queued"
Started = "started"
Processed = "processed"
Error = "error"
class ElementsWorker(
BaseWorker,
ClassificationMixin,
ElementMixin,
TranscriptionMixin,
WorkerVersionMixin,
EntityMixin,
MetaDataMixin,
):
def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description, use_cache)
# Add report concerning elements
self.report = Reporter("unknown worker")
# Add mandatory argument to process elements
self.parser.add_argument(
"--elements-list",
help="JSON elements list to use",
type=open,
default=os.environ.get("TASK_ELEMENTS"),
)
self.parser.add_argument(
"--element",
type=uuid.UUID,
nargs="+",
help="One or more Arkindex element ID",
)
self.classes = {}
self._worker_version_cache = {}
def list_elements(self):
assert not (
self.args.elements_list and self.args.element
), "elements-list and element CLI args shouldn't be both set"
out = []
# Load from the cache when available
# Flake8 wants us to use 'is True', but Peewee only supports '== True'
cache_query = CachedElement.select().where(
CachedElement.initial == True # noqa: E712
)
if self.use_cache and cache_query.exists():
return cache_query
# Process elements from JSON file
elif self.args.elements_list:
data = json.load(self.args.elements_list)
assert isinstance(data, list), "Elements list must be a list"
assert len(data), "No elements in elements list"
out += list(filter(None, [element.get("id") for element in data]))
# Add any extra element from CLI
elif self.args.element:
out += self.args.element
return out
def run(self):
"""
Process every elements from the provided list
"""
self.configure()
# List all elements either from JSON file
# or direct list of elements on CLI
elements = self.list_elements()
if not elements:
logger.warning("No elements to process, stopping.")
sys.exit(1)
# Process every element
count = len(elements)
failed = 0
for i, item in enumerate(elements, start=1):
element = None
try:
if self.use_cache:
# Just use the result of list_elements as the element
element = item
else:
# Load element using the Arkindex API
element = Element(**self.request("RetrieveElement", id=item))
logger.info(f"Processing {element} ({i}/{count})")
# Report start of process, run process, then report end of process
self.update_activity(element.id, ActivityState.Started)
self.process_element(element)
self.update_activity(element.id, ActivityState.Processed)
except Exception as e:
failed += 1
element_id = (
element.id
if isinstance(element, (Element, CachedElement))
else item
)
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing element {element_id}: {e.title} - {e.content}"
else:
message = f"Failed running worker on element {element_id}: {e}"
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
self.update_activity(element_id, ActivityState.Error)
self.report.error(element_id, e)
# Save report as local artifact
self.report.save(os.path.join(self.work_dir, "ml_report.json"))
if failed:
logger.error(
"Ran on {} elements: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
def process_element(self, element):
"""Override this method to analyze an Arkindex element from the provided list"""
def update_activity(self, element_id, state):
"""
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
"""
assert element_id and isinstance(
element_id, (uuid.UUID, str)
), "element_id shouldn't be null and should be an UUID or str"
assert isinstance(state, ActivityState), "state should be an ActivityState"
if not self.features.get("workers_activity"):
logger.debug("Skipping Worker activity update as it's disabled on backend")
return
if self.is_read_only:
logger.warning("Cannot update activity as this worker is in read-only mode")
return
try:
out = self.request(
"UpdateWorkerActivity",
id=self.worker_version_id,
body={
"element_id": str(element_id),
"state": state.value,
},
)
logger.debug(f"Updated activity of element {element_id} to {state}")
return out
except ErrorResponse as e:
logger.warning(
f"Failed to update activity of element {element_id} to {state.value} due to an API error: {e.content}"
)
except Exception as e:
logger.warning(
f"Failed to update activity of element {element_id} to {state.value}: {e}"
)
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import os
from pathlib import Path
import gnupg
import yaml
from apistar.exceptions import ErrorResponse
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
def _is_500_error(exc):
"""
Check if an Arkindex API error is a 50x
This is used to retry most API calls implemented here
"""
if not isinstance(exc, ErrorResponse):
return False
return 500 <= exc.status_code < 600
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker", use_cache=False):
self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home
if os.environ.get("PONOS_DATA"):
self.work_dir = os.path.join(os.environ["PONOS_DATA"], "current")
else:
# We use the official XDG convention to store file for developers
# https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html
xdg_data_home = os.environ.get(
"XDG_DATA_HOME", os.path.expanduser("~/.local/share")
)
self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True)
self.worker_version_id = os.environ.get("WORKER_VERSION_ID")
if not self.worker_version_id:
logger.warning(
"Missing WORKER_VERSION_ID environment variable, worker is in read-only mode"
)
logger.info(f"Worker will use {self.work_dir} as working directory")
self.use_cache = use_cache
@property
def is_read_only(self):
"""Worker cannot publish anything without a worker version ID"""
return self.worker_version_id is None
def configure(self):
"""
Configure worker using cli args and environment variables
"""
self.parser.add_argument(
"-c",
"--config",
help="Alternative configuration file when running without a Worker Version ID",
type=open,
)
self.parser.add_argument(
"-d",
"--database",
help="Alternative SQLite database to use for worker caching",
type=str,
default=None,
)
self.parser.add_argument(
"-v",
"--verbose",
help="Display more information on events and errors",
action="store_true",
default=False,
)
# Call potential extra arguments
self.add_arguments()
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
# Setup logging level
if self.args.verbose:
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
# Load features available on backend, and check authentication
user = self.request("RetrieveUser")
logger.debug(f"Connected as {user['display_name']} - {user['email']}")
self.features = user["features"]
if self.worker_version_id:
# Retrieve initial configuration from API
worker_version = self.request(
"RetrieveWorkerVersion", id=self.worker_version_id
)
logger.info(
f"Loaded worker {worker_version['worker']['name']} revision {worker_version['revision']['hash'][0:7]} from API"
)
self.config = worker_version["configuration"]["configuration"]
required_secrets = worker_version["configuration"].get("secrets", [])
elif self.args.config:
# Load config from YAML file
self.config = yaml.safe_load(self.args.config)
required_secrets = self.config.get("secrets", [])
logger.info(
f"Running with local configuration from {self.args.config.name}"
)
else:
self.config = {}
required_secrets = []
logger.warning("Running without any extra configuration")
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
if self.args.database is not None:
self.use_cache = True
if self.use_cache is True:
if self.args.database is not None:
assert os.path.isfile(
self.args.database
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
elif os.environ.get("TASK_ID"):
cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), os.environ.get("TASK_ID")
)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
init_cache_db(self.cache_path)
create_tables()
else:
logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
task_id = os.environ.get("TASK_ID")
if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
task["parents"],
self.cache_path,
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
def load_secret(self, name):
"""Load all secrets described in the worker configuration"""
secret = None
# Load from the backend
try:
resp = self.request("RetrieveSecret", name=name)
secret = resp["content"]
logging.info(f"Loaded API secret {name}")
except ErrorResponse as e:
logger.warning(f"Secret {name} not available: {e.content}")
# Load from local developer storage
base_dir = Path(os.environ.get("XDG_CONFIG_HOME") or "~/.config").expanduser()
path = base_dir / "arkindex" / "secrets" / name
if path.exists():
logging.debug(f"Loading local secret from {path}")
try:
gpg = gnupg.GPG()
decrypted = gpg.decrypt_file(open(path, "rb"))
assert (
decrypted.ok
), f"GPG error: {decrypted.status} - {decrypted.stderr}"
secret = decrypted.data.decode("utf-8")
logging.info(f"Loaded local secret {name}")
except Exception as e:
logger.error(f"Local secret {name} is not available as {path}: {e}")
if secret is None:
raise Exception(f"Secret {name} is not available on the API nor locally")
# Parse secret payload, according to its extension
_, ext = os.path.splitext(os.path.basename(name))
try:
ext = ext.lower()
if ext == ".json":
return json.loads(secret)
elif ext in (".yaml", ".yml"):
return yaml.safe_load(secret)
except Exception as e:
logger.error(f"Failed to parse secret {name}: {e}")
# By default give raw secret payload
return secret
@retry(
retry=retry_if_exception(_is_500_error),
wait=wait_exponential(multiplier=2, min=3),
reraise=True,
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.INFO),
)
def request(self, *args, **kwargs):
"""
Proxy all Arkindex API requests with a retry mechanism
in case of 50X errors
The same API call will be retried 5 times, with an exponential sleep time
going through 3, 4, 8 and 16 seconds of wait between call.
If the 5th call still gives a 50x, the exception is re-raised
and the caller should catch it
Log messages are displayed before sleeping (when at least one exception occurred)
"""
return self.api_client.request(*args, **kwargs)
def add_arguments(self):
"""Override this method to add argparse argument to this worker"""
def run(self):
"""Override this method to implement your own process"""
# -*- coding: utf-8 -*-
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.models import Element
class ClassificationMixin(object):
def load_corpus_classes(self, corpus_id):
"""
Load ML classes for the given corpus ID
"""
corpus_classes = self.api_client.paginate(
"ListCorpusMLClasses",
id=corpus_id,
)
self.classes[corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
}
logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes")
def get_ml_class_id(self, corpus_id, ml_class):
"""
Return the ID corresponding to the given class name on a specific corpus
This method will automatically create missing classes
"""
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
ml_class_id = self.classes[corpus_id].get(ml_class)
if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
try:
response = self.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e:
# Only reload for 400 errors
if e.status_code != 400:
raise
# Reload and make sure we have the class
logger.info(
f"Reloading corpus classes to see if {ml_class} already exists"
)
self.load_corpus_classes(corpus_id)
assert (
ml_class in self.classes[corpus_id]
), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class]
return ml_class_id
def create_classification(
self, element, ml_class, confidence, high_confidence=False
):
"""
Create a classification on the given element through API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert ml_class and isinstance(
ml_class, str
), "ml_class shouldn't be null and should be of type str"
assert (
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence shouldn't be null and should be a float in [0..1] range"
assert isinstance(
high_confidence, bool
), "high_confidence shouldn't be null and should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classification as this worker is in read-only mode"
)
return
try:
self.request(
"CreateClassification",
body={
"element": element.id,
"ml_class": self.get_ml_class_id(element.corpus.id, ml_class),
"worker_version": self.worker_version_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
and "non_field_errors" in e.content
and "The fields element, worker_version, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker version has already set {ml_class} on element {element.id}"
)
return
# Propagate any other API error
raise
self.report.add_classification(element.id, ml_class)
# -*- coding: utf-8 -*-
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
class ElementMixin(object):
def create_sub_element(self, element, type, name, polygon):
"""
Create a child element on the given element through API
Return the ID of the created sub element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert type and isinstance(
type, str
), "type shouldn't be null and should be of type str"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert polygon and isinstance(
polygon, list
), "polygon shouldn't be null and should be of type list"
assert len(polygon) >= 3, "polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), "polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning("Cannot create element as this worker is in read-only mode")
return
sub_element = self.request(
"CreateElement",
body={
"type": type,
"name": name,
"image": element.zone.image.id,
"corpus": element.corpus.id,
"polygon": polygon,
"parent": element.id,
"worker_version": self.worker_version_id,
},
)
self.report.add_element(element.id, type)
return sub_element["id"]
def create_elements(self, parent, elements):
"""
Create children elements on the given element through API
Return the IDs of created elements
"""
if isinstance(parent, Element):
assert parent.get(
"zone"
), "create_elements cannot be used on parents without zones"
elif isinstance(parent, CachedElement):
assert (
parent.image_id
), "create_elements cannot be used on parents without images"
else:
raise TypeError(
"Parent element should be an Element or CachedElement instance"
)
assert elements and isinstance(
elements, list
), "elements shouldn't be null and should be of type list"
for index, element in enumerate(elements):
assert isinstance(
element, dict
), f"Element at index {index} in elements: Should be of type dict"
name = element.get("name")
assert name and isinstance(
name, str
), f"Element at index {index} in elements: name shouldn't be null and should be of type str"
type = element.get("type")
assert type and isinstance(
type, str
), f"Element at index {index} in elements: type shouldn't be null and should be of type str"
polygon = element.get("polygon")
assert polygon and isinstance(
polygon, list
), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list"
assert (
len(polygon) >= 3
), f"Element at index {index} in elements: polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), f"Element at index {index} in elements: polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), f"Element at index {index} in elements: polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning("Cannot create elements as this worker is in read-only mode")
return
created_ids = self.request(
"CreateElements",
id=parent.id,
body={
"worker_version": self.worker_version_id,
"elements": elements,
},
)
for element in elements:
self.report.add_element(parent.id, element["type"])
if self.use_cache:
# Create the image as needed and handle both an Element and a CachedElement
if isinstance(parent, CachedElement):
image_id = parent.image_id
else:
image_id = parent.zone.image.id
CachedImage.get_or_create(
id=parent.zone.image.id,
defaults={
"width": parent.zone.image.width,
"height": parent.zone.image.height,
"url": parent.zone.image.url,
},
)
# Store elements in local cache
try:
to_insert = [
{
"id": created_ids[idx]["id"],
"parent_id": parent.id,
"type": element["type"],
"image_id": image_id,
"polygon": element["polygon"],
"worker_version_id": self.worker_version_id,
}
for idx, element in enumerate(elements)
]
CachedElement.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(f"Couldn't save created elements in local cache: {e}")
return created_ids
def list_element_children(
self,
element,
best_class=None,
folder=None,
name=None,
recursive=None,
type=None,
with_best_classes=None,
with_corpus=None,
with_has_children=None,
with_zone=None,
worker_version=None,
):
"""
List children of an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
query_params = {}
if best_class is not None:
assert isinstance(best_class, str) or isinstance(
best_class, bool
), "best_class should be of type str or bool"
query_params["best_class"] = best_class
if folder is not None:
assert isinstance(folder, bool), "folder should be of type bool"
query_params["folder"] = folder
if name:
assert isinstance(name, str), "name should be of type str"
query_params["name"] = name
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if type:
assert isinstance(type, str), "type should be of type str"
query_params["type"] = type
if with_best_classes is not None:
assert isinstance(
with_best_classes, bool
), "with_best_classes should be of type bool"
query_params["with_best_classes"] = with_best_classes
if with_corpus is not None:
assert isinstance(with_corpus, bool), "with_corpus should be of type bool"
query_params["with_corpus"] = with_corpus
if with_has_children is not None:
assert isinstance(
with_has_children, bool
), "with_has_children should be of type bool"
query_params["with_has_children"] = with_has_children
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"type",
"worker_version",
}, "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
return query
else:
children = self.api_client.paginate(
"ListElementChildren", id=element.id, **query_params
)
return children
# -*- coding: utf-8 -*-
from enum import Enum
from arkindex_worker import logger
from arkindex_worker.models import Element
class EntityType(Enum):
Person = "person"
Location = "location"
Subject = "subject"
Organization = "organization"
Misc = "misc"
Number = "number"
Date = "date"
class EntityMixin(object):
def create_entity(self, element, name, type, corpus, metas=None, validated=None):
"""
Create an entity on the given corpus through API
Return the ID of the created entity
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert type and isinstance(
type, EntityType
), "type shouldn't be null and should be of type EntityType"
assert corpus and isinstance(
corpus, str
), "corpus shouldn't be null and should be of type str"
if metas:
assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None:
assert isinstance(validated, bool), "validated should be of type bool"
if self.is_read_only:
logger.warning("Cannot create entity as this worker is in read-only mode")
return
entity = self.request(
"CreateEntity",
body={
"name": name,
"type": type.value,
"metas": metas,
"validated": validated,
"corpus": corpus,
"worker_version": self.worker_version_id,
},
)
self.report.add_entity(element.id, entity["id"], type.value, name)
return entity["id"]
# -*- coding: utf-8 -*-
from enum import Enum
from arkindex_worker import logger
from arkindex_worker.models import Element
class MetaType(Enum):
Text = "text"
HTML = "html"
Date = "date"
Location = "location"
# Element's original structure reference (intended to be indexed)
Reference = "reference"
class MetaDataMixin(object):
def create_metadata(self, element, type, name, value, entity=None):
"""
Create a metadata on the given element through API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert type and isinstance(
type, MetaType
), "type shouldn't be null and should be of type MetaType"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert value and isinstance(
value, str
), "value shouldn't be null and should be of type str"
if entity:
assert isinstance(entity, str), "entity should be of type str"
if self.is_read_only:
logger.warning("Cannot create metadata as this worker is in read-only mode")
return
metadata = self.request(
"CreateMetaData",
id=element.id,
body={
"type": type.value,
"name": name,
"value": value,
"entity": entity,
"worker_version": self.worker_version_id,
},
)
self.report.add_metadata(element.id, metadata["id"], type.value, name)
return metadata["id"]
# -*- coding: utf-8 -*-
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
class TranscriptionMixin(object):
def create_transcription(self, element, text, score):
"""
Create a transcription on the given element through the API.
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert text and isinstance(
text, str
), "text shouldn't be null and should be of type str"
assert (
isinstance(score, float) and 0 <= score <= 1
), "score shouldn't be null and should be a float in [0..1] range"
if self.is_read_only:
logger.warning(
"Cannot create transcription as this worker is in read-only mode"
)
return
created = self.request(
"CreateTranscription",
id=element.id,
body={
"text": text,
"worker_version": self.worker_version_id,
"score": score,
},
)
self.report.add_transcription(element.id)
if self.use_cache:
# Store transcription in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"text": created["text"],
"confidence": created["confidence"],
"worker_version_id": self.worker_version_id,
}
]
CachedTranscription.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcription in local cache: {e}"
)
def create_transcriptions(self, transcriptions):
"""
Create multiple transcriptions at once on existing elements through the API.
"""
assert transcriptions and isinstance(
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
element_id = transcription.get("element_id")
assert element_id and isinstance(
element_id, str
), f"Transcription at index {index} in transcriptions: element_id shouldn't be null and should be of type str"
text = transcription.get("text")
assert text and isinstance(
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
created_trs = self.request(
"CreateTranscriptions",
body={
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
},
)["transcriptions"]
for created_tr in created_trs:
self.report.add_transcription(created_tr["element_id"])
if self.use_cache:
# Store transcriptions in local cache
try:
to_insert = [
{
"id": created_tr["id"],
"element_id": created_tr["element_id"],
"text": created_tr["text"],
"confidence": created_tr["confidence"],
"worker_version_id": self.worker_version_id,
}
for created_tr in created_trs
]
CachedTranscription.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcriptions in local cache: {e}"
)
def create_element_transcriptions(self, element, sub_element_type, transcriptions):
"""
Create multiple sub elements with their transcriptions on the given element through API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert sub_element_type and isinstance(
sub_element_type, str
), "sub_element_type shouldn't be null and should be of type str"
assert transcriptions and isinstance(
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
text = transcription.get("text")
assert text and isinstance(
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
polygon = transcription.get("polygon")
assert polygon and isinstance(
polygon, list
), f"Transcription at index {index} in transcriptions: polygon shouldn't be null and should be of type list"
assert (
len(polygon) >= 3
), f"Transcription at index {index} in transcriptions: polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), f"Transcription at index {index} in transcriptions: polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), f"Transcription at index {index} in transcriptions: polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning(
"Cannot create transcriptions as this worker is in read-only mode"
)
return
annotations = self.request(
"CreateElementTranscriptions",
id=element.id,
body={
"element_type": sub_element_type,
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
"return_elements": True,
},
)
for annotation in annotations:
if annotation["created"]:
logger.debug(
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
)
self.report.add_element(element.id, sub_element_type)
self.report.add_transcription(annotation["element_id"])
if self.use_cache:
# Store transcriptions and their associated element (if created) in local cache
created_ids = set()
elements_to_insert = []
transcriptions_to_insert = []
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
if annotation["element_id"] not in created_ids:
# Even if the API says the element already existed in the DB,
# we need to check if it is available in the local cache.
# Peewee does not have support for SQLite's INSERT OR IGNORE,
# so we do the check here, element by element.
try:
CachedElement.get_by_id(annotation["element_id"])
except CachedElement.DoesNotExist:
elements_to_insert.append(
{
"id": annotation["element_id"],
"parent_id": element.id,
"type": sub_element_type,
"image_id": element.image_id,
"polygon": transcription["polygon"],
"worker_version_id": self.worker_version_id,
}
)
created_ids.add(annotation["element_id"])
transcriptions_to_insert.append(
{
"id": annotation["id"],
"element_id": annotation["element_id"],
"text": transcription["text"],
"confidence": transcription["score"],
"worker_version_id": self.worker_version_id,
}
)
try:
CachedElement.insert_many(elements_to_insert).execute()
CachedTranscription.insert_many(transcriptions_to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcriptions in local cache: {e}"
)
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
"""
List transcriptions on an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
query_params = {}
if element_type:
assert isinstance(element_type, str), "element_type should be of type str"
query_params["element_type"] = element_type
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache and recursive is None:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"worker_version",
}, "When using the local cache, you can only filter by 'worker_version'"
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
if worker_version:
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
)
else:
if self.use_cache:
logger.warning(
"'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
)
transcriptions = self.api_client.paginate(
"ListTranscriptions", id=element.id, **query_params
)
return transcriptions
# -*- coding: utf-8 -*-
class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id is None:
raise ValueError("No worker version ID")
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
self._worker_version_cache[worker_version_id] = worker_version
return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Helper function to get the worker slug from element, classification or transcription.
Gets the worker version slug from cache if possible, otherwise makes an API request.
Returns None if there is no associated worker version.
:type worker_version_id: A worker version UUID
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.6
sh==1.14.1
tenacity==6.3.1
tenacity==7.0.0
pytest==6.2.2
pytest==6.2.3
pytest-mock==3.5.1
pytest-responses==0.4.0
......@@ -3,20 +3,34 @@ import hashlib
import json
import os
import sys
import time
from pathlib import Path
from uuid import UUID
import pytest
import yaml
from peewee import SqliteDatabase
from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker import BaseWorker, ElementsWorker
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
__yaml_cache = {}
@pytest.fixture(autouse=True)
def disable_sleep(monkeypatch):
"""
Do not sleep at all in between API executions
when errors occur in unit tests.
This speeds up the test execution a lot
"""
monkeypatch.setattr(time, "sleep", lambda x: None)
@pytest.fixture
def cache_yaml(monkeypatch):
"""
......@@ -30,7 +44,9 @@ def cache_yaml(monkeypatch):
# Create a unique cache key for direct YAML strings
# and file descriptors
if isinstance(yaml_payload, str):
key = hashlib.md5(yaml_payload.encode("utf-8")).hexdigest()
yaml_payload = yaml_payload.encode("utf-8")
if isinstance(yaml_payload, bytes):
key = hashlib.md5(yaml_payload).hexdigest()
else:
key = yaml_payload.name
......@@ -75,6 +91,14 @@ def setup_api(responses, monkeypatch, cache_yaml):
monkeypatch.setenv("ARKINDEX_API_TOKEN", "unittest1234")
@pytest.fixture(autouse=True)
def temp_working_directory(monkeypatch, tmp_path):
def _getcwd():
return str(tmp_path)
monkeypatch.setattr(os, "getcwd", _getcwd)
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
......@@ -147,6 +171,26 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
return worker
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
"""Build a BaseWorker using SQLite cache, also mocking a TASK_ID"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
monkeypatch.setenv("TASK_ID", "my_task")
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = ElementsWorker(use_cache=True)
worker.configure()
return worker
@pytest.fixture
def fake_page_element():
with open(FIXTURES_DIR / "page_element.json", "r") as f:
......@@ -197,3 +241,127 @@ def fake_gitlab_helper_factory():
)
return run
@pytest.fixture
def mock_cached_elements():
"""Insert few elements in local cache"""
CachedElement.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
parent_id="12341234-1234-1234-1234-123412341234",
type="something",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
@pytest.fixture
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
@pytest.fixture(scope="function")
def mock_databases(tmpdir):
"""
Initialize several temporary databases
to help testing the merge algorithm
"""
out = {}
for name in ("target", "first", "second", "conflict", "chunk_42"):
# Build a local database in sub directory
# for each name required
filename = "db_42.sqlite" if name == "chunk_42" else "db.sqlite"
path = tmpdir / name / filename
(tmpdir / name).mkdir()
local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS):
# Create tables on the current local database
# by binding temporarily the models on that database
local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db}
# Add an element in first parent database
with out["first"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("56785678-5678-5678-5678-567856785678"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add another element with a transcription in second parent database
with out["second"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello!",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add a conflicting element
with out["conflict"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
initial=True,
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello again neighbor !",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add an element in chunk parent database
with out["chunk_42"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
initial=True,
)
return out
File added
......@@ -12,7 +12,7 @@ from arkindex_worker import logger
from arkindex_worker.worker import BaseWorker
def test_init_default_local_share():
def test_init_default_local_share(monkeypatch):
worker = BaseWorker()
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
......@@ -28,6 +28,14 @@ def test_init_default_xdg_data_home(monkeypatch):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
def test_init_with_local_cache(monkeypatch):
worker = BaseWorker(use_cache=True)
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.use_cache is True
def test_init_var_ponos_data_given(monkeypatch):
path = str(Path(__file__).absolute().parent)
monkeypatch.setenv("PONOS_DATA", path)
......