From 6ccba7ae5d2a3a294217e0ad7bf472eb51799df4 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Thu, 4 Jul 2024 08:43:14 +0000
Subject: [PATCH] Make classification confidence non nullable

---
 arkindex/documents/export/structure.sql       |  4 ++--
 .../0011_alter_classification_confidence.py   | 19 +++++++++++++++++++
 arkindex/documents/models.py                  |  2 +-
 .../tests/tasks/test_corpus_delete.py         |  1 +
 .../documents/tests/test_classification.py    |  3 +++
 .../documents/tests/test_retrieve_elements.py |  8 ++++----
 arkindex/process/tests/test_managers.py       |  1 +
 arkindex/process/tests/test_processes.py      |  2 +-
 8 files changed, 32 insertions(+), 8 deletions(-)
 create mode 100644 arkindex/documents/migrations/0011_alter_classification_confidence.py

diff --git a/arkindex/documents/export/structure.sql b/arkindex/documents/export/structure.sql
index 3676a26f39..580f5342a0 100644
--- a/arkindex/documents/export/structure.sql
+++ b/arkindex/documents/export/structure.sql
@@ -110,7 +110,7 @@ CREATE TABLE classification (
     class_name VARCHAR(1024) NOT NULL,
     state VARCHAR(16) NOT NULL DEFAULT 'pending',
     moderator VARCHAR(255),
-    confidence REAL,
+    confidence REAL NOT NULL,
     high_confidence INTEGER NOT NULL DEFAULT 0,
     worker_version_id VARCHAR(37),
     worker_run_id VARCHAR(37),
@@ -118,7 +118,7 @@ CREATE TABLE classification (
     FOREIGN KEY (element_id) REFERENCES element (id) ON DELETE CASCADE,
     FOREIGN KEY (worker_version_id) REFERENCES worker_version (id) ON DELETE CASCADE,
     FOREIGN KEY (worker_run_id) REFERENCES worker_run (id) ON DELETE CASCADE,
-    CHECK (confidence IS NULL OR (confidence >= 0 AND confidence <= 1)),
+    CHECK (confidence >= 0 AND confidence <= 1),
     CHECK (high_confidence = 0 OR high_confidence = 1),
     CHECK (worker_run_id IS NULL OR worker_version_id IS NOT NULL)
 );
diff --git a/arkindex/documents/migrations/0011_alter_classification_confidence.py b/arkindex/documents/migrations/0011_alter_classification_confidence.py
new file mode 100644
index 0000000000..0cd5ecf021
--- /dev/null
+++ b/arkindex/documents/migrations/0011_alter_classification_confidence.py
@@ -0,0 +1,19 @@
+# Generated by Django 4.2.13 on 2024-07-03 16:03
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("documents", "0010_delete_entityrole_entitylink"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="classification",
+            name="confidence",
+            field=models.FloatField(default=1),
+            preserve_default=False,
+        ),
+    ]
diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py
index 656815442a..01c5107c8a 100644
--- a/arkindex/documents/models.py
+++ b/arkindex/documents/models.py
@@ -1079,7 +1079,7 @@ class Classification(models.Model):
     # Predicted class is considered as correct by its creator
     high_confidence = models.BooleanField(default=False)
     state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending)
-    confidence = models.FloatField(null=True, blank=True)
+    confidence = models.FloatField()
 
     class Meta:
         constraints = [
diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index 5a0e02f3c5..906523cf88 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -75,6 +75,7 @@ class TestDeleteCorpus(FixtureTestCase):
             ml_class=cls.corpus.ml_classes.create(name="something"),
             worker_run=worker_run,
             worker_version=cls.worker_version,
+            confidence=0.89
         )
         element.metadatas.create(
             type=MetaType.Text,
diff --git a/arkindex/documents/tests/test_classification.py b/arkindex/documents/tests/test_classification.py
index 04a2e5c434..6aade6578d 100644
--- a/arkindex/documents/tests/test_classification.py
+++ b/arkindex/documents/tests/test_classification.py
@@ -836,6 +836,7 @@ class TestClassifications(FixtureAPITestCase):
             worker_version=self.worker_version_1,
             state=ClassificationState.Pending,
             high_confidence=True,
+            confidence=1,
             ml_class=line
         )
 
@@ -845,6 +846,7 @@ class TestClassifications(FixtureAPITestCase):
                 worker_version=self.worker_version_1,
                 state=ClassificationState.Pending,
                 high_confidence=True,
+                confidence=1,
                 ml_class=self.text
             )
             Classification.objects.create(
@@ -852,6 +854,7 @@ class TestClassifications(FixtureAPITestCase):
                 worker_version=self.worker_version_2,
                 state=ClassificationState.Pending,
                 high_confidence=False,
+                confidence=0.4,
                 ml_class=self.text
             )
 
diff --git a/arkindex/documents/tests/test_retrieve_elements.py b/arkindex/documents/tests/test_retrieve_elements.py
index af36b57722..493861f881 100644
--- a/arkindex/documents/tests/test_retrieve_elements.py
+++ b/arkindex/documents/tests/test_retrieve_elements.py
@@ -27,7 +27,7 @@ class TestRetrieveElements(FixtureAPITestCase):
 
     def test_get_element(self):
         ml_class = MLClass.objects.create(name="text", corpus=self.corpus)
-        classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class)
+        classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class, confidence=0.8)
 
         with self.assertNumQueries(2):
             response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
@@ -56,7 +56,7 @@ class TestRetrieveElements(FixtureAPITestCase):
             "classifications": [
                 {
                     "id": str(classification.id),
-                    "confidence": None,
+                    "confidence": 0.8,
                     "high_confidence": False,
                     "state": "pending",
                     "worker_version": str(self.worker_version.id),
@@ -251,7 +251,7 @@ class TestRetrieveElements(FixtureAPITestCase):
 
     def test_get_element_classification_worker_run(self):
         ml_class = MLClass.objects.create(name="text", corpus=self.corpus)
-        classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class)
+        classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class, confidence=0.89)
 
         with self.assertNumQueries(3):
             response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
@@ -280,7 +280,7 @@ class TestRetrieveElements(FixtureAPITestCase):
             "classifications": [
                 {
                     "id": str(classification.id),
-                    "confidence": None,
+                    "confidence": 0.89,
                     "high_confidence": False,
                     "state": "pending",
                     "worker_version": str(self.worker_version.id),
diff --git a/arkindex/process/tests/test_managers.py b/arkindex/process/tests/test_managers.py
index 8438449621..629f96e778 100644
--- a/arkindex/process/tests/test_managers.py
+++ b/arkindex/process/tests/test_managers.py
@@ -79,6 +79,7 @@ class TestManagers(FixtureTestCase):
             element=Element.objects.first(),
             ml_class=MLClass.objects.create(corpus=self.corpus, name="b"),
             worker_version=self.worker_version,
+            confidence=0.89,
         )
 
         self.assertFalse(self.corpus.worker_version_cache.exists())
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index 39307558df..b3370d28e8 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -806,7 +806,7 @@ class TestProcesses(FixtureAPITestCase):
         entity = self.corpus.entities.create(name="test", type=self.corpus.entity_types.create())
         transcription = page.transcriptions.first()
         transcription_entity = transcription.transcription_entities.create(entity=entity, offset=1, length=1)
-        classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test"))
+        classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test"), confidence=1)
 
         for related_obj in (page, metadata, entity, transcription, transcription_entity, classification):
             # The atomic() block ensures a rollback before the next subtest, in case of a failure
-- 
GitLab