From 6ccba7ae5d2a3a294217e0ad7bf472eb51799df4 Mon Sep 17 00:00:00 2001 From: ml bonhomme <bonhomme@teklia.com> Date: Thu, 4 Jul 2024 08:43:14 +0000 Subject: [PATCH] Make classification confidence non nullable --- arkindex/documents/export/structure.sql | 4 ++-- .../0011_alter_classification_confidence.py | 19 +++++++++++++++++++ arkindex/documents/models.py | 2 +- .../tests/tasks/test_corpus_delete.py | 1 + .../documents/tests/test_classification.py | 3 +++ .../documents/tests/test_retrieve_elements.py | 8 ++++---- arkindex/process/tests/test_managers.py | 1 + arkindex/process/tests/test_processes.py | 2 +- 8 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 arkindex/documents/migrations/0011_alter_classification_confidence.py diff --git a/arkindex/documents/export/structure.sql b/arkindex/documents/export/structure.sql index 3676a26f39..580f5342a0 100644 --- a/arkindex/documents/export/structure.sql +++ b/arkindex/documents/export/structure.sql @@ -110,7 +110,7 @@ CREATE TABLE classification ( class_name VARCHAR(1024) NOT NULL, state VARCHAR(16) NOT NULL DEFAULT 'pending', moderator VARCHAR(255), - confidence REAL, + confidence REAL NOT NULL, high_confidence INTEGER NOT NULL DEFAULT 0, worker_version_id VARCHAR(37), worker_run_id VARCHAR(37), @@ -118,7 +118,7 @@ CREATE TABLE classification ( FOREIGN KEY (element_id) REFERENCES element (id) ON DELETE CASCADE, FOREIGN KEY (worker_version_id) REFERENCES worker_version (id) ON DELETE CASCADE, FOREIGN KEY (worker_run_id) REFERENCES worker_run (id) ON DELETE CASCADE, - CHECK (confidence IS NULL OR (confidence >= 0 AND confidence <= 1)), + CHECK (confidence >= 0 AND confidence <= 1), CHECK (high_confidence = 0 OR high_confidence = 1), CHECK (worker_run_id IS NULL OR worker_version_id IS NOT NULL) ); diff --git a/arkindex/documents/migrations/0011_alter_classification_confidence.py b/arkindex/documents/migrations/0011_alter_classification_confidence.py new file mode 100644 index 0000000000..0cd5ecf021 --- /dev/null +++ b/arkindex/documents/migrations/0011_alter_classification_confidence.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.13 on 2024-07-03 16:03 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("documents", "0010_delete_entityrole_entitylink"), + ] + + operations = [ + migrations.AlterField( + model_name="classification", + name="confidence", + field=models.FloatField(default=1), + preserve_default=False, + ), + ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 656815442a..01c5107c8a 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -1079,7 +1079,7 @@ class Classification(models.Model): # Predicted class is considered as correct by its creator high_confidence = models.BooleanField(default=False) state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending) - confidence = models.FloatField(null=True, blank=True) + confidence = models.FloatField() class Meta: constraints = [ diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py index 5a0e02f3c5..906523cf88 100644 --- a/arkindex/documents/tests/tasks/test_corpus_delete.py +++ b/arkindex/documents/tests/tasks/test_corpus_delete.py @@ -75,6 +75,7 @@ class TestDeleteCorpus(FixtureTestCase): ml_class=cls.corpus.ml_classes.create(name="something"), worker_run=worker_run, worker_version=cls.worker_version, + confidence=0.89 ) element.metadatas.create( type=MetaType.Text, diff --git a/arkindex/documents/tests/test_classification.py b/arkindex/documents/tests/test_classification.py index 04a2e5c434..6aade6578d 100644 --- a/arkindex/documents/tests/test_classification.py +++ b/arkindex/documents/tests/test_classification.py @@ -836,6 +836,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_1, state=ClassificationState.Pending, high_confidence=True, + confidence=1, ml_class=line ) @@ -845,6 +846,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_1, state=ClassificationState.Pending, high_confidence=True, + confidence=1, ml_class=self.text ) Classification.objects.create( @@ -852,6 +854,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_2, state=ClassificationState.Pending, high_confidence=False, + confidence=0.4, ml_class=self.text ) diff --git a/arkindex/documents/tests/test_retrieve_elements.py b/arkindex/documents/tests/test_retrieve_elements.py index af36b57722..493861f881 100644 --- a/arkindex/documents/tests/test_retrieve_elements.py +++ b/arkindex/documents/tests/test_retrieve_elements.py @@ -27,7 +27,7 @@ class TestRetrieveElements(FixtureAPITestCase): def test_get_element(self): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) - classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class) + classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class, confidence=0.8) with self.assertNumQueries(2): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) @@ -56,7 +56,7 @@ class TestRetrieveElements(FixtureAPITestCase): "classifications": [ { "id": str(classification.id), - "confidence": None, + "confidence": 0.8, "high_confidence": False, "state": "pending", "worker_version": str(self.worker_version.id), @@ -251,7 +251,7 @@ class TestRetrieveElements(FixtureAPITestCase): def test_get_element_classification_worker_run(self): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) - classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class) + classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class, confidence=0.89) with self.assertNumQueries(3): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) @@ -280,7 +280,7 @@ class TestRetrieveElements(FixtureAPITestCase): "classifications": [ { "id": str(classification.id), - "confidence": None, + "confidence": 0.89, "high_confidence": False, "state": "pending", "worker_version": str(self.worker_version.id), diff --git a/arkindex/process/tests/test_managers.py b/arkindex/process/tests/test_managers.py index 8438449621..629f96e778 100644 --- a/arkindex/process/tests/test_managers.py +++ b/arkindex/process/tests/test_managers.py @@ -79,6 +79,7 @@ class TestManagers(FixtureTestCase): element=Element.objects.first(), ml_class=MLClass.objects.create(corpus=self.corpus, name="b"), worker_version=self.worker_version, + confidence=0.89, ) self.assertFalse(self.corpus.worker_version_cache.exists()) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 39307558df..b3370d28e8 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -806,7 +806,7 @@ class TestProcesses(FixtureAPITestCase): entity = self.corpus.entities.create(name="test", type=self.corpus.entity_types.create()) transcription = page.transcriptions.first() transcription_entity = transcription.transcription_entities.create(entity=entity, offset=1, length=1) - classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test")) + classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test"), confidence=1) for related_obj in (page, metadata, entity, transcription, transcription_entity, classification): # The atomic() block ensures a rollback before the next subtest, in case of a failure -- GitLab