From 1a54fd1e3e40742b11efc38badae8361206cb61d Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Fri, 9 Apr 2021 11:31:01 +0000 Subject: [PATCH] Store Classification in local cache during create_classification --- arkindex_worker/cache.py | 15 +++++++- arkindex_worker/worker/classification.py | 27 ++++++++++++--- tests/test_cache.py | 3 +- .../test_classifications.py | 34 +++++++++++++++---- 4 files changed, 67 insertions(+), 12 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 28f317bc..ab4a425a 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -112,9 +112,22 @@ class CachedTranscription(Model): table_name = "transcriptions" +class CachedClassification(Model): + id = UUIDField(primary_key=True) + element = ForeignKeyField(CachedElement, backref="classifications") + class_name = TextField() + confidence = FloatField() + state = CharField(max_length=10) + worker_version_id = UUIDField() + + class Meta: + database = db + table_name = "classifications" + + # Add all the managed models in that list # It's used here, but also in unit tests -MODELS = [CachedImage, CachedElement, CachedTranscription] +MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification] def init_cache_db(path): diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index 8bc3b7ba..5e57326e 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -2,9 +2,10 @@ import os from apistar.exceptions import ErrorResponse +from peewee import IntegrityError from arkindex_worker import logger -from arkindex_worker.cache import CachedElement +from arkindex_worker.cache import CachedClassification, CachedElement from arkindex_worker.models import Element @@ -84,18 +85,36 @@ class ClassificationMixin(object): return try: - self.request( + created = self.request( "CreateClassification", body={ - "element": element.id, + "element": str(element.id), "ml_class": self.get_ml_class_id(None, ml_class), "worker_version": self.worker_version_id, "confidence": confidence, "high_confidence": high_confidence, }, ) - except ErrorResponse as e: + if self.use_cache: + # Store classification in local cache + try: + to_insert = [ + { + "id": created["id"], + "element_id": element.id, + "class_name": ml_class, + "confidence": created["confidence"], + "state": created["state"], + "worker_version_id": self.worker_version_id, + } + ] + CachedClassification.insert_many(to_insert).execute() + except IntegrityError as e: + logger.warning( + f"Couldn't save created classification in local cache: {e}" + ) + except ErrorResponse as e: # Detect already existing classification if ( e.status_code == 400 diff --git a/tests/test_cache.py b/tests/test_cache.py index 798ae8aa..6d2b147f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -54,7 +54,8 @@ def test_create_tables(tmp_path): init_cache_db(db_path) create_tables() - expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) + expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) +CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 7ea525ea..939d4879 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- import json +from uuid import UUID import pytest from apistar.exceptions import ErrorResponse -from arkindex_worker.cache import CachedElement +from arkindex_worker.cache import CachedClassification, CachedElement from arkindex_worker.models import Element @@ -401,19 +402,28 @@ def test_create_classification(responses, mock_elements_worker): ] == {"a_class": 1} -def test_create_classification_with_cached_element(responses, mock_elements_worker): - mock_elements_worker.classes = { +def test_create_classification_with_cache(responses, mock_elements_worker_with_cache): + mock_elements_worker_with_cache.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = CachedElement(id="12341234-1234-1234-1234-123412341234") + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( responses.POST, "http://testserver/api/v1/classifications/", status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "0000", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + "state": "pending", + }, ) - mock_elements_worker.create_classification( + mock_elements_worker_with_cache.create_classification( element=elt, ml_class="a_class", confidence=0.42, @@ -436,10 +446,22 @@ def test_create_classification_with_cached_element(responses, mock_elements_work } # Classification has been created and reported - assert mock_elements_worker.report.report_data["elements"][elt.id][ + assert mock_elements_worker_with_cache.report.report_data["elements"][elt.id][ "classifications" ] == {"a_class": 1} + # Check that created classification was properly stored in SQLite cache + assert list(CachedClassification.select()) == [ + CachedClassification( + id=UUID("56785678-5678-5678-5678-567856785678"), + element_id=UUID(elt.id), + class_name="a_class", + confidence=0.42, + state="pending", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + ] + def test_create_classification_duplicate(responses, mock_elements_worker): mock_elements_worker.classes = { -- GitLab