Skip to content
Snippets Groups Projects
Commit 1a54fd1e authored by Eva Bardou's avatar Eva Bardou Committed by Bastien Abadie
Browse files

Store Classification in local cache during create_classification

parent d2c8950f
No related branches found
No related tags found
1 merge request!89Store Classification in local cache during create_classification
Pipeline #78436 passed
......@@ -112,9 +112,22 @@ class CachedTranscription(Model):
table_name = "transcriptions"
class CachedClassification(Model):
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="classifications")
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
class Meta:
database = db
table_name = "classifications"
# Add all the managed models in that list
# It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription]
MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification]
def init_cache_db(path):
......
......@@ -2,9 +2,10 @@
import os
from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -84,18 +85,36 @@ class ClassificationMixin(object):
return
try:
self.request(
created = self.request(
"CreateClassification",
body={
"element": element.id,
"element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class),
"worker_version": self.worker_version_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
)
except ErrorResponse as e:
if self.use_cache:
# Store classification in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"class_name": ml_class,
"confidence": created["confidence"],
"state": created["state"],
"worker_version_id": self.worker_version_id,
}
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classification in local cache: {e}"
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
......
......@@ -54,7 +54,8 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
......
# -*- coding: utf-8 -*-
import json
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -401,19 +402,28 @@ def test_create_classification(responses, mock_elements_worker):
] == {"a_class": 1}
def test_create_classification_with_cached_element(responses, mock_elements_worker):
mock_elements_worker.classes = {
def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
mock_elements_worker_with_cache.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = CachedElement(id="12341234-1234-1234-1234-123412341234")
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
"high_confidence": True,
"state": "pending",
},
)
mock_elements_worker.create_classification(
mock_elements_worker_with_cache.create_classification(
element=elt,
ml_class="a_class",
confidence=0.42,
......@@ -436,10 +446,22 @@ def test_create_classification_with_cached_element(responses, mock_elements_work
}
# Classification has been created and reported
assert mock_elements_worker.report.report_data["elements"][elt.id][
assert mock_elements_worker_with_cache.report.report_data["elements"][elt.id][
"classifications"
] == {"a_class": 1}
# Check that created classification was properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("56785678-5678-5678-5678-567856785678"),
element_id=UUID(elt.id),
class_name="a_class",
confidence=0.42,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_classification_duplicate(responses, mock_elements_worker):
mock_elements_worker.classes = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment