diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 28f317bcf633bfbf9fe7615fe89212e36444edd8..ab4a425ac5eea1bb7f84cc8adb38202e5e565772 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -112,9 +112,22 @@ class CachedTranscription(Model): table_name = "transcriptions" +class CachedClassification(Model): + id = UUIDField(primary_key=True) + element = ForeignKeyField(CachedElement, backref="classifications") + class_name = TextField() + confidence = FloatField() + state = CharField(max_length=10) + worker_version_id = UUIDField() + + class Meta: + database = db + table_name = "classifications" + + # Add all the managed models in that list # It's used here, but also in unit tests -MODELS = [CachedImage, CachedElement, CachedTranscription] +MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification] def init_cache_db(path): diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index 8bc3b7ba4ff9d811e6a707e084cfa42be7a35338..5e57326efdd5f061f0266e98e9cb18f4b655f8d1 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -2,9 +2,10 @@ import os from apistar.exceptions import ErrorResponse +from peewee import IntegrityError from arkindex_worker import logger -from arkindex_worker.cache import CachedElement +from arkindex_worker.cache import CachedClassification, CachedElement from arkindex_worker.models import Element @@ -84,18 +85,36 @@ class ClassificationMixin(object): return try: - self.request( + created = self.request( "CreateClassification", body={ - "element": element.id, + "element": str(element.id), "ml_class": self.get_ml_class_id(None, ml_class), "worker_version": self.worker_version_id, "confidence": confidence, "high_confidence": high_confidence, }, ) - except ErrorResponse as e: + if self.use_cache: + # Store classification in local cache + try: + to_insert = [ + { + "id": created["id"], + "element_id": element.id, + "class_name": ml_class, + "confidence": created["confidence"], + "state": created["state"], + "worker_version_id": self.worker_version_id, + } + ] + CachedClassification.insert_many(to_insert).execute() + except IntegrityError as e: + logger.warning( + f"Couldn't save created classification in local cache: {e}" + ) + except ErrorResponse as e: # Detect already existing classification if ( e.status_code == 400 diff --git a/tests/test_cache.py b/tests/test_cache.py index 798ae8aad78b50780c050e4b5fe6790292047b4c..6d2b147f5d5e6e534dc0deae82123518da459533 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -54,7 +54,8 @@ def test_create_tables(tmp_path): init_cache_db(db_path) create_tables() - expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) + expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) +CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 7ea525ead5eef5e24c0355bd8a6d88b905671286..939d48791191689c5c54d825826c9444da3bdd93 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- import json +from uuid import UUID import pytest from apistar.exceptions import ErrorResponse -from arkindex_worker.cache import CachedElement +from arkindex_worker.cache import CachedClassification, CachedElement from arkindex_worker.models import Element @@ -401,19 +402,28 @@ def test_create_classification(responses, mock_elements_worker): ] == {"a_class": 1} -def test_create_classification_with_cached_element(responses, mock_elements_worker): - mock_elements_worker.classes = { +def test_create_classification_with_cache(responses, mock_elements_worker_with_cache): + mock_elements_worker_with_cache.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = CachedElement(id="12341234-1234-1234-1234-123412341234") + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( responses.POST, "http://testserver/api/v1/classifications/", status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "0000", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + "state": "pending", + }, ) - mock_elements_worker.create_classification( + mock_elements_worker_with_cache.create_classification( element=elt, ml_class="a_class", confidence=0.42, @@ -436,10 +446,22 @@ def test_create_classification_with_cached_element(responses, mock_elements_work } # Classification has been created and reported - assert mock_elements_worker.report.report_data["elements"][elt.id][ + assert mock_elements_worker_with_cache.report.report_data["elements"][elt.id][ "classifications" ] == {"a_class": 1} + # Check that created classification was properly stored in SQLite cache + assert list(CachedClassification.select()) == [ + CachedClassification( + id=UUID("56785678-5678-5678-5678-567856785678"), + element_id=UUID(elt.id), + class_name="a_class", + confidence=0.42, + state="pending", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + ] + def test_create_classification_duplicate(responses, mock_elements_worker): mock_elements_worker.classes = {