Compare revisions — workers/base-worker

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (14)
Showing with 1506 additions and 1115 deletions
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,pytest,requests,setuptools,sh,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml
0.1.13
0.2.0-beta1
# -*- coding: utf-8 -*-
import json
import os
import sqlite3

from arkindex_worker import logger

# Removed in this revision: the previous hand-rolled sqlite3 cache layer
from collections import namedtuple

SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
    id VARCHAR(32) PRIMARY KEY,
    parent_id VARCHAR(32),
    name TEXT NOT NULL,
    type TEXT NOT NULL,
    polygon TEXT,
    initial BOOLEAN DEFAULT 0 NOT NULL,
    worker_version_id VARCHAR(32)
)"""

CachedElement = namedtuple(
    "CachedElement",
    ["id", "name", "type", "polygon", "worker_version_id", "parent_id", "initial"],
    defaults=[None, 0],
)

class LocalDB(object):
    def __init__(self, path):
        self.db = sqlite3.connect(path)
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        logger.info(f"Connection to local cache {path} established.")

    def create_tables(self):
        self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)

    def insert(self, table, lines):
        if not lines:
            return
        columns = ", ".join(lines[0]._fields)
        placeholders = ", ".join("?" * len(lines[0]))
        values = [tuple(line) for line in lines]
        self.cursor.executemany(
            f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
        )
        self.db.commit()

# Added in this revision: a peewee-based cache layer
from peewee import (
    BooleanField,
    CharField,
    Field,
    FloatField,
    ForeignKeyField,
    IntegerField,
    Model,
    SqliteDatabase,
    TextField,
    UUIDField,
)

db = SqliteDatabase(None)

class JSONField(Field):
    field_type = "text"

    def db_value(self, value):
        if value is None:
            return
        return json.dumps(value)

    def python_value(self, value):
        if value is None:
            return
        return json.loads(value)

class CachedImage(Model):
    id = UUIDField(primary_key=True)
    width = IntegerField()
    height = IntegerField()
    url = TextField()

    class Meta:
        database = db
        table_name = "images"

class CachedElement(Model):
    id = UUIDField(primary_key=True)
    parent_id = UUIDField(null=True)
    type = CharField(max_length=50)
    image_id = ForeignKeyField(CachedImage, backref="elements", null=True)
    polygon = JSONField(null=True)
    initial = BooleanField(default=False)
    worker_version_id = UUIDField(null=True)

    class Meta:
        database = db
        table_name = "elements"

class CachedTranscription(Model):
    id = UUIDField(primary_key=True)
    element_id = ForeignKeyField(CachedElement, backref="transcriptions")
    text = TextField()
    confidence = FloatField()
    worker_version_id = UUIDField()

    class Meta:
        database = db
        table_name = "transcriptions"

# Add all the managed models to this list
# It is used here, but also in the unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription]

def init_cache_db(path):
    db.init(
        path,
        pragmas={
            # SQLite ignores foreign keys and check constraints by default!
            "foreign_keys": 1,
            "ignore_check_constraints": 0,
        },
    )
    db.connect()
    logger.info(f"Connected to cache on {path}")

def create_tables():
    """
    Creates the tables in the cache DB only if they do not already exist.
    """
    db.create_tables(MODELS)

def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
    """
    Merge all the potential parent tasks' databases into the existing local one
    """
    assert isinstance(parent_ids, list)
    assert os.path.isdir(data_dir)
    assert os.path.exists(current_database)

    # Handle a possible chunk in the parent task name
    # This is needed to support the init_elements databases
    filenames = [
        "db.sqlite",
    ]
    if chunk is not None:
        filenames.append(f"db_{chunk}.sqlite")

    # Find all the paths for these databases
    paths = list(
        filter(
            lambda p: os.path.isfile(p),
            [
                os.path.join(data_dir, parent, name)
                for parent in parent_ids
                for name in filenames
            ],
        )
    )

    if not paths:
        logger.info("No parents cache to use")
        return

    # Open a connection on the current database
    connection = sqlite3.connect(current_database)
    cursor = connection.cursor()

    # Merge each table into the local database
    for idx, path in enumerate(paths):
        logger.info(f"Merging parent db {path} into {current_database}")
        statements = [
            "PRAGMA page_size=80000;",
            "PRAGMA synchronous=OFF;",
            f"ATTACH DATABASE '{path}' AS source_{idx};",
            f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
            f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
        ]
        for statement in statements:
            cursor.execute(statement)
        connection.commit()
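Editor's note: a minimal usage sketch of this cache layer, not part of the diff. It assumes any writable path; the models are bound to the module-level db, so they can be queried as soon as the database is initialised.

from arkindex_worker.cache import CachedElement, create_tables, init_cache_db

init_cache_db("/tmp/demo.sqlite")  # hypothetical path
create_tables()
# Count the cached elements that have no parent yet
print(CachedElement.select().where(CachedElement.parent_id.is_null()).count())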
# -*- coding: utf-8 -*-
import json
import traceback
import warnings
from collections import Counter
from datetime import datetime
......@@ -83,35 +82,12 @@ class Reporter(object):
)
element["classifications"] = dict(counter)
# Removed in this revision: deprecated transcription types
def add_transcription(self, element_id, type=None, type_count=None):
    """
    Report creating a transcription on an element.
    Multiple transcriptions with the same parent can be declared with the type_count parameter.
    """
    if type_count is None:
        if isinstance(type, int):
            type_count, type = type, None
        else:
            type_count = 1

    if type is not None:
        warnings.warn(
            "Transcription types have been deprecated and will be removed in the next release.",
            FutureWarning,
        )

    self._get_element(element_id)["transcriptions"] += type_count

def add_transcriptions(self, element_id, transcriptions):
    """
    Report one or more transcriptions at once.
    """
    assert isinstance(transcriptions, list), "A list is required for transcriptions"
    warnings.warn(
        "Reporter.add_transcriptions is deprecated due to transcription types being removed. Please use Reporter.add_transcription(element_id, count) instead.",
        FutureWarning,
    )
    self.add_transcription(element_id, len(transcriptions))

# Added in this revision: a simple counter
def add_transcription(self, element_id, count=1):
    """
    Report creating a transcription on an element.
    Multiple transcriptions on the same element can be declared with the count parameter.
    """
    self._get_element(element_id)["transcriptions"] += count
def add_entity(self, element_id, entity_id, type, name):
"""
......
# -*- coding: utf-8 -*-
import datetime
import uuid
from timeit import default_timer
......@@ -20,7 +19,3 @@ class Timer(object):
end = self.timer()
self.elapsed = end - self.start
self.delta = datetime.timedelta(seconds=self.elapsed)
# Removed in this revision:
def convert_str_uuid_to_hex(id):
    return uuid.UUID(id).hex
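Editor's note: a usage sketch, assuming Timer is meant as a context manager, which the __exit__ body above suggests:

from arkindex_worker.utils import Timer

with Timer() as t:
    sum(range(1_000_000))  # any workload
print(f"Done in {t.delta}")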
This diff is collapsed.
# -*- coding: utf-8 -*-
import json
import os
import sys
import uuid
from enum import Enum
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401
class ActivityState(Enum):
Queued = "queued"
Started = "started"
Processed = "processed"
Error = "error"
class ElementsWorker(
BaseWorker,
ClassificationMixin,
ElementMixin,
TranscriptionMixin,
WorkerVersionMixin,
EntityMixin,
MetaDataMixin,
):
def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description, use_cache)
# Add report concerning elements
self.report = Reporter("unknown worker")
# Add mandatory argument to process elements
self.parser.add_argument(
"--elements-list",
help="JSON elements list to use",
type=open,
default=os.environ.get("TASK_ELEMENTS"),
)
self.parser.add_argument(
"--element",
type=uuid.UUID,
nargs="+",
help="One or more Arkindex element ID",
)
self.classes = {}
self._worker_version_cache = {}
def list_elements(self):
assert not (
self.args.elements_list and self.args.element
), "elements-list and element CLI args shouldn't both be set"
out = []
# Process elements from JSON file
if self.args.elements_list:
data = json.load(self.args.elements_list)
assert isinstance(data, list), "Elements list must be a list"
assert len(data), "No elements in elements list"
out += list(filter(None, [element.get("id") for element in data]))
# Add any extra element from CLI
elif self.args.element:
out += self.args.element
return out
def run(self):
"""
Process every element from the provided list
"""
self.configure()
# List all elements either from JSON file
# or direct list of elements on CLI
elements = self.list_elements()
if not elements:
logger.warning("No elements to process, stopping.")
sys.exit(1)
# Process every element
count = len(elements)
failed = 0
for i, element_id in enumerate(elements, start=1):
    element = None
    try:
        # Load element using Arkindex API
        element = Element(**self.request("RetrieveElement", id=element_id))
        logger.info(f"Processing {element} ({i}/{count})")

        # Report start of process, run process, then report end of process
        self.update_activity(element, ActivityState.Started)
        self.process_element(element)
        self.update_activity(element, ActivityState.Processed)
    except ErrorResponse as e:
        failed += 1
        logger.warning(
            f"An API error occurred while processing element {element_id}: {e.title} - {e.content}",
            exc_info=e if self.args.verbose else None,
        )
        # Only report the activity failure when the element could be loaded
        if element:
            self.update_activity(element, ActivityState.Error)
        self.report.error(element_id, e)
    except Exception as e:
        failed += 1
        logger.warning(
            f"Failed running worker on element {element_id}: {e}",
            exc_info=e if self.args.verbose else None,
        )
        if element:
            self.update_activity(element, ActivityState.Error)
        self.report.error(element_id, e)
# Save report as local artifact
self.report.save(os.path.join(self.work_dir, "ml_report.json"))
if failed:
logger.error(
"Ran on {} elements: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
def process_element(self, element):
"""Override this method to analyze an Arkindex element from the provided list"""
def update_activity(self, element, state):
"""
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert isinstance(state, ActivityState), "state should be an ActivityState"
if not self.features.get("workers_activity"):
logger.debug("Skipping Worker activity update as it's disabled on backend")
return
if self.is_read_only:
logger.warning("Cannot update activity as this worker is in read-only mode")
return
try:
out = self.request(
"UpdateWorkerActivity",
id=self.worker_version_id,
body={
"element_id": element.id,
"state": state.value,
},
)
logger.debug(f"Updated activity of element {element.id} to {state}")
return out
except ErrorResponse as e:
logger.warning(
f"Failed to update activity of element {element.id} to {state.value} due to an API error: {e.content}"
)
except Exception as e:
logger.warning(
f"Failed to update activity of element {element.id} to {state.value}: {e}"
)
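Editor's note: a minimal worker built on this class, as a sketch (not part of the diff). process_element is the only required override; argument parsing, activity reporting and the ml_report.json artifact all come from run().

from arkindex_worker.worker import ElementsWorker

class DemoWorker(ElementsWorker):
    def process_element(self, element):
        # Purely illustrative: just log the element being processed
        print(f"Seen {element.id}")

if __name__ == "__main__":
    DemoWorker(description="Demo worker").run()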
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import os
from pathlib import Path
import gnupg
import yaml
from apistar.exceptions import ErrorResponse
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
def _is_500_error(exc):
"""
Check if an Arkindex API error is a 50x
This is used to retry most API calls implemented here
"""
if not isinstance(exc, ErrorResponse):
return False
return 500 <= exc.status_code < 600
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker", use_cache=False):
self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home
if os.environ.get("PONOS_DATA"):
self.work_dir = os.path.join(os.environ["PONOS_DATA"], "current")
else:
# We use the official XDG convention to store files for developers
# https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html
xdg_data_home = os.environ.get(
"XDG_DATA_HOME", os.path.expanduser("~/.local/share")
)
self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True)
self.worker_version_id = os.environ.get("WORKER_VERSION_ID")
if not self.worker_version_id:
logger.warning(
"Missing WORKER_VERSION_ID environment variable, worker is in read-only mode"
)
logger.info(f"Worker will use {self.work_dir} as working directory")
self.use_cache = use_cache
if self.use_cache is True:
if os.environ.get("TASK_ID"):
cache_dir = f"/data/{os.environ.get('TASK_ID')}"
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
init_cache_db(self.cache_path)
create_tables()
else:
logger.debug("Cache is disabled")
@property
def is_read_only(self):
"""Worker cannot publish anything without a worker version ID"""
return self.worker_version_id is None
def configure(self):
"""
Configure worker using cli args and environment variables
"""
self.parser.add_argument(
"-c",
"--config",
help="Alternative configuration file when running without a Worker Version ID",
type=open,
)
self.parser.add_argument(
"-v",
"--verbose",
help="Display more information on events and errors",
action="store_true",
default=False,
)
# Call potential extra arguments
self.add_arguments()
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
# Setup logging level
if self.args.verbose:
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
# Load features available on backend, and check authentication
user = self.request("RetrieveUser")
logger.debug(f"Connected as {user['display_name']} - {user['email']}")
self.features = user["features"]
if self.worker_version_id:
# Retrieve initial configuration from API
worker_version = self.request(
"RetrieveWorkerVersion", id=self.worker_version_id
)
logger.info(
f"Loaded worker {worker_version['worker']['name']} revision {worker_version['revision']['hash'][0:7]} from API"
)
self.config = worker_version["configuration"]["configuration"]
required_secrets = worker_version["configuration"].get("secrets", [])
elif self.args.config:
# Load config from YAML file
self.config = yaml.safe_load(self.args.config)
required_secrets = self.config.get("secrets", [])
logger.info(
f"Running with local configuration from {self.args.config.name}"
)
else:
self.config = {}
required_secrets = []
logger.warning("Running without any extra configuration")
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
# Merging parents caches (if there are any) in the current task local cache
task_id = os.environ.get("TASK_ID")
if self.use_cache and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
task["parents"],
self.cache_path,
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
def load_secret(self, name):
"""Load all secrets described in the worker configuration"""
secret = None
# Load from the backend
try:
resp = self.request("RetrieveSecret", name=name)
secret = resp["content"]
logging.info(f"Loaded API secret {name}")
except ErrorResponse as e:
logger.warning(f"Secret {name} not available: {e.content}")
# Load from local developer storage
base_dir = Path(os.environ.get("XDG_CONFIG_HOME") or "~/.config").expanduser()
path = base_dir / "arkindex" / "secrets" / name
if path.exists():
logging.debug(f"Loading local secret from {path}")
try:
gpg = gnupg.GPG()
decrypted = gpg.decrypt_file(open(path, "rb"))
assert (
decrypted.ok
), f"GPG error: {decrypted.status} - {decrypted.stderr}"
secret = decrypted.data.decode("utf-8")
logging.info(f"Loaded local secret {name}")
except Exception as e:
logger.error(f"Local secret {name} is not available as {path}: {e}")
if secret is None:
raise Exception(f"Secret {name} is not available on the API nor locally")
# Parse secret payload, according to its extension
_, ext = os.path.splitext(os.path.basename(name))
try:
ext = ext.lower()
if ext == ".json":
return json.loads(secret)
elif ext in (".yaml", ".yml"):
return yaml.safe_load(secret)
except Exception as e:
logger.error(f"Failed to parse secret {name}: {e}")
# By default give raw secret payload
return secret
@retry(
retry=retry_if_exception(_is_500_error),
wait=wait_exponential(multiplier=2, min=3),
reraise=True,
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.INFO),
)
def request(self, *args, **kwargs):
"""
Proxy all Arkindex API requests with a retry mechanism
in case of 50X errors
The same API call will be retried 5 times, with an exponential sleep time
going through 3, 4, 8 and 16 seconds of wait between calls.
If the 5th call still gives a 50x error, the exception is re-raised
and the caller should catch it.
Log messages are displayed before sleeping (when at least one exception occurred).
"""
return self.api_client.request(*args, **kwargs)
def add_arguments(self):
"""Override this method to add argparse argument to this worker"""
def run(self):
"""Override this method to implement your own process"""
# -*- coding: utf-8 -*-
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.models import Element
class ClassificationMixin(object):
def load_corpus_classes(self, corpus_id):
"""
Load ML classes for the given corpus ID
"""
corpus_classes = self.api_client.paginate(
"ListCorpusMLClasses",
id=corpus_id,
)
self.classes[corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
}
logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes")
def get_ml_class_id(self, corpus_id, ml_class):
"""
Return the ID corresponding to the given class name on a specific corpus
This method will automatically create missing classes
"""
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
ml_class_id = self.classes[corpus_id].get(ml_class)
if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
try:
response = self.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e:
# Only reload for 400 errors
if e.status_code != 400:
raise
# Reload and make sure we have the class
logger.info(
f"Reloading corpus classes to see if {ml_class} already exists"
)
self.load_corpus_classes(corpus_id)
assert (
ml_class in self.classes[corpus_id]
), f"Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class]
return ml_class_id
def create_classification(
self, element, ml_class, confidence, high_confidence=False
):
"""
Create a classification on the given element through the API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert ml_class and isinstance(
ml_class, str
), "ml_class shouldn't be null and should be of type str"
assert (
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence shouldn't be null and should be a float in [0..1] range"
assert isinstance(
high_confidence, bool
), "high_confidence shouldn't be null and should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classification as this worker is in read-only mode"
)
return
try:
self.request(
"CreateClassification",
body={
"element": element.id,
"ml_class": self.get_ml_class_id(element.corpus.id, ml_class),
"worker_version": self.worker_version_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
and "non_field_errors" in e.content
and "The fields element, worker_version, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker version has already set {ml_class} on element {element.id}"
)
return
# Propagate any other API error
raise
self.report.add_classification(element.id, ml_class)
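Editor's note: a sketch (not part of the diff) of how a worker would call this mixin from process_element; the "handwritten" class name is arbitrary and will be created on the corpus if missing.

from arkindex_worker.worker import ElementsWorker

class ClassifierWorker(ElementsWorker):
    def process_element(self, element):
        # confidence must be a float within [0..1]
        self.create_classification(element, "handwritten", 0.87, high_confidence=True)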
# -*- coding: utf-8 -*-
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
class ElementMixin(object):
def create_sub_element(self, element, type, name, polygon):
"""
Create a child element on the given element through the API
Return the ID of the created sub-element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert type and isinstance(
type, str
), "type shouldn't be null and should be of type str"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert polygon and isinstance(
polygon, list
), "polygon shouldn't be null and should be of type list"
assert len(polygon) >= 3, "polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), "polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning("Cannot create element as this worker is in read-only mode")
return
sub_element = self.request(
"CreateElement",
body={
"type": type,
"name": name,
"image": element.zone.image.id,
"corpus": element.corpus.id,
"polygon": polygon,
"parent": element.id,
"worker_version": self.worker_version_id,
},
)
self.report.add_element(element.id, type)
return sub_element["id"]
def create_elements(self, parent, elements):
"""
Create child elements on the given element through the API
Return the IDs of the created elements
"""
if isinstance(parent, Element):
assert parent.get(
"zone"
), "create_elements cannot be used on parents without zones"
elif isinstance(parent, CachedElement):
assert (
parent.image_id
), "create_elements cannot be used on parents without images"
else:
raise TypeError(
"Parent element should be an Element or CachedElement instance"
)
assert elements and isinstance(
elements, list
), "elements shouldn't be null and should be of type list"
for index, element in enumerate(elements):
assert isinstance(
element, dict
), f"Element at index {index} in elements: Should be of type dict"
name = element.get("name")
assert name and isinstance(
name, str
), f"Element at index {index} in elements: name shouldn't be null and should be of type str"
type = element.get("type")
assert type and isinstance(
type, str
), f"Element at index {index} in elements: type shouldn't be null and should be of type str"
polygon = element.get("polygon")
assert polygon and isinstance(
polygon, list
), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list"
assert (
len(polygon) >= 3
), f"Element at index {index} in elements: polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), f"Element at index {index} in elements: polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), f"Element at index {index} in elements: polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning("Cannot create elements as this worker is in read-only mode")
return
created_ids = self.request(
"CreateElements",
id=parent.id,
body={
"worker_version": self.worker_version_id,
"elements": elements,
},
)
for element in elements:
self.report.add_element(parent.id, element["type"])
if self.use_cache:
    # Create the image as needed and handle both an Element and a CachedElement
    if isinstance(parent, CachedElement):
        image_id = parent.image_id
    else:
        image_id = parent.zone.image.id
        CachedImage.get_or_create(
            id=parent.zone.image.id,
            defaults={
                "width": parent.zone.image.width,
                "height": parent.zone.image.height,
                "url": parent.zone.image.url,
            },
        )

    # Store elements in local cache
    try:
        to_insert = [
            {
                "id": created_ids[idx]["id"],
                "parent_id": parent.id,
                "type": element["type"],
                "image_id": image_id,
                "polygon": element["polygon"],
                "worker_version_id": self.worker_version_id,
            }
            for idx, element in enumerate(elements)
        ]
        CachedElement.insert_many(to_insert).execute()
    except IntegrityError as e:
        logger.warning(f"Couldn't save created elements in local cache: {e}")

return created_ids
def list_element_children(
self,
element,
best_class=None,
folder=None,
name=None,
recursive=None,
type=None,
with_best_classes=None,
with_corpus=None,
with_has_children=None,
with_zone=None,
worker_version=None,
):
"""
List children of an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
query_params = {}
if best_class is not None:
assert isinstance(best_class, str) or isinstance(
best_class, bool
), "best_class should be of type str or bool"
query_params["best_class"] = best_class
if folder is not None:
assert isinstance(folder, bool), "folder should be of type bool"
query_params["folder"] = folder
if name:
assert isinstance(name, str), "name should be of type str"
query_params["name"] = name
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if type:
assert isinstance(type, str), "type should be of type str"
query_params["type"] = type
if with_best_classes is not None:
assert isinstance(
with_best_classes, bool
), "with_best_classes should be of type bool"
query_params["with_best_classes"] = with_best_classes
if with_corpus is not None:
assert isinstance(with_corpus, bool), "with_corpus should be of type bool"
query_params["with_corpus"] = with_corpus
if with_has_children is not None:
assert isinstance(
with_has_children, bool
), "with_has_children should be of type bool"
query_params["with_has_children"] = with_has_children
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"type",
"worker_version",
}, "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
return query
else:
children = self.api_client.paginate(
"ListElementChildren", id=element.id, **query_params
)
return children
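Editor's note: a sketch (not part of the diff) of bulk element creation followed by a listing. Polygons are lists of at least three [x, y] points; with use_cache=True the listing only supports the type and worker_version filters.

from arkindex_worker.worker import ElementsWorker

class SegmenterWorker(ElementsWorker):
    def process_element(self, element):
        boxes = [
            {
                "name": "line 1",
                "type": "text_line",  # arbitrary example type
                "polygon": [[0, 0], [100, 0], [100, 20], [0, 20]],
            },
        ]
        self.create_elements(parent=element, elements=boxes)
        for child in self.list_element_children(element, type="text_line"):
            print(child)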
# -*- coding: utf-8 -*-
from enum import Enum
from arkindex_worker import logger
from arkindex_worker.models import Element
class EntityType(Enum):
Person = "person"
Location = "location"
Subject = "subject"
Organization = "organization"
Misc = "misc"
Number = "number"
Date = "date"
class EntityMixin(object):
def create_entity(self, element, name, type, corpus, metas=None, validated=None):
"""
Create an entity on the given corpus through the API
Return the ID of the created entity
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert type and isinstance(
type, EntityType
), "type shouldn't be null and should be of type EntityType"
assert corpus and isinstance(
corpus, str
), "corpus shouldn't be null and should be of type str"
if metas:
assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None:
assert isinstance(validated, bool), "validated should be of type bool"
if self.is_read_only:
logger.warning("Cannot create entity as this worker is in read-only mode")
return
entity = self.request(
"CreateEntity",
body={
"name": name,
"type": type.value,
"metas": metas,
"validated": validated,
"corpus": corpus,
"worker_version": self.worker_version_id,
},
)
self.report.add_entity(element.id, entity["id"], type.value, name)
return entity["id"]
# -*- coding: utf-8 -*-
from enum import Enum
from arkindex_worker import logger
from arkindex_worker.models import Element
class MetaType(Enum):
Text = "text"
HTML = "html"
Date = "date"
Location = "location"
# Element's original structure reference (intended to be indexed)
Reference = "reference"
class MetaDataMixin(object):
def create_metadata(self, element, type, name, value, entity=None):
"""
Create a metadata on the given element through the API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert type and isinstance(
type, MetaType
), "type shouldn't be null and should be of type MetaType"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert value and isinstance(
value, str
), "value shouldn't be null and should be of type str"
if entity:
assert isinstance(entity, str), "entity should be of type str"
if self.is_read_only:
logger.warning("Cannot create metadata as this worker is in read-only mode")
return
metadata = self.request(
"CreateMetaData",
id=element.id,
body={
"type": type.value,
"name": name,
"value": value,
"entity": entity,
"worker_version": self.worker_version_id,
},
)
self.report.add_metadata(element.id, metadata["id"], type.value, name)
return metadata["id"]
# -*- coding: utf-8 -*-
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
class TranscriptionMixin(object):
def create_transcription(self, element, text, score):
"""
Create a transcription on the given element through the API.
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert text and isinstance(
text, str
), "text shouldn't be null and should be of type str"
assert (
isinstance(score, float) and 0 <= score <= 1
), "score shouldn't be null and should be a float in [0..1] range"
if self.is_read_only:
logger.warning(
"Cannot create transcription as this worker is in read-only mode"
)
return
created = self.request(
"CreateTranscription",
id=element.id,
body={
"text": text,
"worker_version": self.worker_version_id,
"score": score,
},
)
self.report.add_transcription(element.id)
if self.use_cache:
# Store transcription in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"text": created["text"],
"confidence": created["confidence"],
"worker_version_id": self.worker_version_id,
}
]
CachedTranscription.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcription in local cache: {e}"
)
def create_transcriptions(self, transcriptions):
"""
Create multiple transcriptions at once on existing elements through the API.
"""
assert transcriptions and isinstance(
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
element_id = transcription.get("element_id")
assert element_id and isinstance(
element_id, str
), f"Transcription at index {index} in transcriptions: element_id shouldn't be null and should be of type str"
text = transcription.get("text")
assert text and isinstance(
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
created_trs = self.request(
"CreateTranscriptions",
body={
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
},
)["transcriptions"]
for created_tr in created_trs:
self.report.add_transcription(created_tr["element_id"])
if self.use_cache:
# Store transcriptions in local cache
try:
to_insert = [
{
"id": created_tr["id"],
"element_id": created_tr["element_id"],
"text": created_tr["text"],
"confidence": created_tr["confidence"],
"worker_version_id": self.worker_version_id,
}
for created_tr in created_trs
]
CachedTranscription.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcriptions in local cache: {e}"
)
def create_element_transcriptions(self, element, sub_element_type, transcriptions):
"""
Create multiple sub-elements with their transcriptions on the given element through the API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert sub_element_type and isinstance(
sub_element_type, str
), "sub_element_type shouldn't be null and should be of type str"
assert transcriptions and isinstance(
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
text = transcription.get("text")
assert text and isinstance(
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
polygon = transcription.get("polygon")
assert polygon and isinstance(
polygon, list
), f"Transcription at index {index} in transcriptions: polygon shouldn't be null and should be of type list"
assert (
len(polygon) >= 3
), f"Transcription at index {index} in transcriptions: polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), f"Transcription at index {index} in transcriptions: polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), f"Transcription at index {index} in transcriptions: polygon points should be lists of two numbers"
if self.is_read_only:
logger.warning(
"Cannot create transcriptions as this worker is in read-only mode"
)
return
annotations = self.request(
"CreateElementTranscriptions",
id=element.id,
body={
"element_type": sub_element_type,
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
"return_elements": True,
},
)
for annotation in annotations:
if annotation["created"]:
logger.debug(
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
)
self.report.add_element(element.id, sub_element_type)
self.report.add_transcription(annotation["element_id"])
if self.use_cache:
# Store transcriptions and their associated element (if created) in local cache
created_ids = set()
elements_to_insert = []
transcriptions_to_insert = []
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
if annotation["element_id"] not in created_ids:
# Even if the API says the element already existed in the DB,
# we need to check if it is available in the local cache.
# Peewee does not have support for SQLite's INSERT OR IGNORE,
# so we do the check here, element by element.
try:
CachedElement.get_by_id(annotation["element_id"])
except CachedElement.DoesNotExist:
elements_to_insert.append(
{
"id": annotation["element_id"],
"parent_id": element.id,
"type": sub_element_type,
"image_id": element.image_id,
"polygon": transcription["polygon"],
"worker_version_id": self.worker_version_id,
}
)
created_ids.add(annotation["element_id"])
transcriptions_to_insert.append(
{
"id": annotation["id"],
"element_id": annotation["element_id"],
"text": transcription["text"],
"confidence": transcription["score"],
"worker_version_id": self.worker_version_id,
}
)
try:
CachedElement.insert_many(elements_to_insert).execute()
CachedTranscription.insert_many(transcriptions_to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created transcriptions in local cache: {e}"
)
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
"""
List transcriptions on an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
query_params = {}
if element_type:
assert isinstance(element_type, str), "element_type should be of type str"
query_params["element_type"] = element_type
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache and recursive is None:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"worker_version",
}, "When using the local cache, you can only filter by 'worker_version'"
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
if worker_version:
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
)
else:
if self.use_cache:
logger.warning(
"'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
)
transcriptions = self.api_client.paginate(
"ListTranscriptions", id=element.id, **query_params
)
return transcriptions
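Editor's note: a sketch (not part of the diff) of the payload expected by create_element_transcriptions; every entry carries its own text, score and polygon, and the endpoint creates the sub-elements as needed.

from arkindex_worker.worker import ElementsWorker

class HtrWorker(ElementsWorker):
    def process_element(self, element):
        self.create_element_transcriptions(
            element,
            sub_element_type="text_line",
            transcriptions=[
                {
                    "text": "Hello world",
                    "score": 0.9,
                    "polygon": [[0, 0], [100, 0], [100, 20], [0, 20]],
                },
            ],
        )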
# -*- coding: utf-8 -*-
MANUAL_SLUG = "manual"
class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
self._worker_version_cache[worker_version_id] = worker_version
return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Use `get_ml_result_slug` instead of calling this method directly
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the slug from an element, classification or transcription
Handles both the old (source) and the new (worker_version) payload formats
:type ml_result: Element or classification or transcription
"""
if (
"source" in ml_result
and ml_result["source"]
and ml_result["source"]["slug"]
):
return ml_result["source"]["slug"]
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.6
sh==1.14.1
tenacity==6.3.1
tenacity==7.0.0
......@@ -3,22 +3,34 @@ import hashlib
import json
import os
import sys
import time
from pathlib import Path
from uuid import UUID
import pytest
import yaml
from peewee import SqliteDatabase
from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker import BaseWorker, ElementsWorker
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
CACHE_DIR = str(Path(__file__).resolve().parent / "data/cache")
CACHE_FILE = os.path.join(CACHE_DIR, "db.sqlite")
__yaml_cache = {}
@pytest.fixture(autouse=True)
def disable_sleep(monkeypatch):
"""
Do not sleep at all between API call retries
when errors occur in unit tests.
This speeds up the test execution a lot
"""
monkeypatch.setattr(time, "sleep", lambda x: None)
@pytest.fixture
def cache_yaml(monkeypatch):
"""
......@@ -80,17 +92,12 @@ def setup_api(responses, monkeypatch, cache_yaml):
# Removed in this revision:
@pytest.fixture(autouse=True)
def handle_cache_file(monkeypatch):
    def _getcwd():
        return CACHE_DIR

    monkeypatch.setattr(os, "getcwd", _getcwd)
    yield
    if os.path.isfile(CACHE_FILE):
        os.remove(CACHE_FILE)

# Added in this revision:
@pytest.fixture(autouse=True)
def temp_working_directory(monkeypatch, tmp_path):
    def _getcwd():
        return str(tmp_path)

    monkeypatch.setattr(os, "getcwd", _getcwd)
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
......@@ -164,6 +171,16 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
return worker
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
"""Build a BaseWorker using SQLite cache, also mocking a TASK_ID"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
monkeypatch.setenv("TASK_ID", "my_task")
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
......@@ -224,3 +241,127 @@ def fake_gitlab_helper_factory():
)
return run
@pytest.fixture
def mock_cached_elements():
"""Insert few elements in local cache"""
CachedElement.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
parent_id="12341234-1234-1234-1234-123412341234",
type="something",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
@pytest.fixture
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
@pytest.fixture(scope="function")
def mock_databases(tmpdir):
"""
Initialize several temporary databases
to help testing the merge algorithm
"""
out = {}
for name in ("target", "first", "second", "conflict", "chunk_42"):
# Build a local database in sub directory
# for each name required
filename = "db_42.sqlite" if name == "chunk_42" else "db.sqlite"
path = tmpdir / name / filename
(tmpdir / name).mkdir()
local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS):
# Create tables on the current local database
# by binding temporarily the models on that database
local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db}
# Add an element in first parent database
with out["first"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("56785678-5678-5678-5678-567856785678"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add another element with a transcription in second parent database
with out["second"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello!",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add a conflicting element
with out["conflict"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
initial=True,
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello again neighbor !",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
# Add an element in chunk parent database
with out["chunk_42"]["db"].bind_ctx(MODELS):
CachedElement.create(
id=UUID("42424242-4242-4242-4242-424242424242"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
initial=True,
)
return out
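Editor's note: a sketch (not part of the diff) of a test combining this fixture with merge_parents_cache; the directory layout matches what the fixture builds under tmpdir.

from arkindex_worker.cache import merge_parents_cache

def test_merge_parents_sketch(mock_databases, tmpdir):
    # Merge the "first" and "second" parent databases into "target"
    merge_parents_cache(
        ["first", "second"],
        str(mock_databases["target"]["path"]),
        data_dir=str(tmpdir),
    )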
File deleted (binary file, no preview for this file type)
......@@ -33,7 +33,7 @@ def test_init_with_local_cache(monkeypatch):
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.cache is not None
assert worker.use_cache is True
def test_init_var_ponos_data_given(monkeypatch):
......
# -*- coding: utf-8 -*-
# Removed in this revision: tests for the old LocalDB cache
import json
import os
import sqlite3
from pathlib import Path

import pytest

from arkindex_worker.cache import CachedElement, LocalDB
from arkindex_worker.utils import convert_str_uuid_to_hex

FIXTURES = Path(__file__).absolute().parent / "data/cache"
ELEMENTS_TO_INSERT = [
    CachedElement(
        id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
        name="0",
        type="something",
        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
        worker_version_id=convert_str_uuid_to_hex(
            "56785678-5678-5678-5678-567856785678"
        ),
    ),
    CachedElement(
        id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
        name="1",
        type="something",
        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
        worker_version_id=convert_str_uuid_to_hex(
            "56785678-5678-5678-5678-567856785678"
        ),
    ),
]

def test_init_non_existent_path():
    with pytest.raises(sqlite3.OperationalError) as e:
        LocalDB("path/not/found.sqlite")
    assert str(e.value) == "unable to open database file"

def test_init():
    db_path = f"{FIXTURES}/db.sqlite"
    LocalDB(db_path)
    assert os.path.isfile(db_path)

def test_create_tables_existing_table():
    db_path = f"{FIXTURES}/tables.sqlite"
    cache = LocalDB(db_path)
    with open(db_path, "rb") as before_file:
        before = before_file.read()
    cache.create_tables()
    with open(db_path, "rb") as after_file:
        after = after_file.read()
    assert before == after, "Cache was modified"

def test_create_tables():
    db_path = f"{FIXTURES}/db.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    expected_cache = LocalDB(f"{FIXTURES}/tables.sqlite")
    # For each table in our new generated cache, we are checking that its structure
    # is the same as the one saved in data/tables.sqlite
    for table in cache.cursor.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'"
    ):
        name = table["name"]
        expected_table = expected_cache.cursor.execute(
            f"SELECT sql FROM sqlite_master WHERE name = '{name}'"
        ).fetchone()
        generated_table = cache.cursor.execute(
            f"SELECT sql FROM sqlite_master WHERE name = '{name}'"
        ).fetchone()
        assert expected_table == generated_table

def test_insert_empty_lines():
    db_path = f"{FIXTURES}/db.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    cache.insert("elements", [])
    expected_cache = LocalDB(f"{FIXTURES}/tables.sqlite")
    assert (
        cache.cursor.execute("SELECT * FROM elements").fetchall()
        == expected_cache.cursor.execute("SELECT * FROM elements").fetchall()
    )

def test_insert_existing_lines():
    db_path = f"{FIXTURES}/lines.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    with open(db_path, "rb") as before_file:
        before = before_file.read()
    with pytest.raises(sqlite3.IntegrityError) as e:
        cache.insert("elements", ELEMENTS_TO_INSERT)
    assert str(e.value) == "UNIQUE constraint failed: elements.id"
    with open(db_path, "rb") as after_file:
        after = after_file.read()
    assert before == after, "Cache was modified"

def test_insert():
    db_path = f"{FIXTURES}/db.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    cache.insert("elements", ELEMENTS_TO_INSERT)
    generated_rows = cache.cursor.execute("SELECT * FROM elements").fetchall()
    expected_cache = LocalDB(f"{FIXTURES}/lines.sqlite")
    assert (
        generated_rows
        == expected_cache.cursor.execute("SELECT * FROM elements").fetchall()
    )
    assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT

# Added in this revision: tests for the new peewee-based cache
import os

import pytest
from peewee import OperationalError

from arkindex_worker.cache import create_tables, db, init_cache_db

def test_init_non_existent_path():
    with pytest.raises(OperationalError) as e:
        init_cache_db("path/not/found.sqlite")
    assert str(e.value) == "unable to open database file"

def test_init(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    assert os.path.isfile(db_path)

def test_create_tables_existing_table(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"

    # Create the tables once…
    init_cache_db(db_path)
    create_tables()
    db.close()

    with open(db_path, "rb") as before_file:
        before = before_file.read()

    # Create them again
    init_cache_db(db_path)
    create_tables()

    with open(db_path, "rb") as after_file:
        after = after_file.read()

    assert before == after, "Existing table structure was modified"

def test_create_tables(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    create_tables()

    expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""

    actual_schema = "\n".join(
        [
            row[0]
            for row in db.connection()
            .execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
            .fetchall()
        ]
    )

    assert expected_schema == actual_schema