diff --git a/arkindex/documents/export/structure.sql b/arkindex/documents/export/structure.sql index 3676a26f39fec4dc0c3bc49d1450821f2fb84e5a..580f5342a047dd3f154ee99e45dad294e11de57d 100644 --- a/arkindex/documents/export/structure.sql +++ b/arkindex/documents/export/structure.sql @@ -110,7 +110,7 @@ CREATE TABLE classification ( class_name VARCHAR(1024) NOT NULL, state VARCHAR(16) NOT NULL DEFAULT 'pending', moderator VARCHAR(255), - confidence REAL, + confidence REAL NOT NULL, high_confidence INTEGER NOT NULL DEFAULT 0, worker_version_id VARCHAR(37), worker_run_id VARCHAR(37), @@ -118,7 +118,7 @@ CREATE TABLE classification ( FOREIGN KEY (element_id) REFERENCES element (id) ON DELETE CASCADE, FOREIGN KEY (worker_version_id) REFERENCES worker_version (id) ON DELETE CASCADE, FOREIGN KEY (worker_run_id) REFERENCES worker_run (id) ON DELETE CASCADE, - CHECK (confidence IS NULL OR (confidence >= 0 AND confidence <= 1)), + CHECK (confidence >= 0 AND confidence <= 1), CHECK (high_confidence = 0 OR high_confidence = 1), CHECK (worker_run_id IS NULL OR worker_version_id IS NOT NULL) ); diff --git a/arkindex/documents/migrations/0011_alter_classification_confidence.py b/arkindex/documents/migrations/0011_alter_classification_confidence.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd5ecf021fdcef4fe8c044d78c78417e2cb4e7e --- /dev/null +++ b/arkindex/documents/migrations/0011_alter_classification_confidence.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.13 on 2024-07-03 16:03 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("documents", "0010_delete_entityrole_entitylink"), + ] + + operations = [ + migrations.AlterField( + model_name="classification", + name="confidence", + field=models.FloatField(default=1), + preserve_default=False, + ), + ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 656815442afd5adfcfa6e3d60c41397cfa7ad745..01c5107c8ac4a27f5b92373aaf3071d4ae5ceb10 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -1079,7 +1079,7 @@ class Classification(models.Model): # Predicted class is considered as correct by its creator high_confidence = models.BooleanField(default=False) state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending) - confidence = models.FloatField(null=True, blank=True) + confidence = models.FloatField() class Meta: constraints = [ diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py index 5a0e02f3c5f61e31f86bd3e13b71abf81492d96b..906523cf88b498db4ea64e9dae526ff945cb1c86 100644 --- a/arkindex/documents/tests/tasks/test_corpus_delete.py +++ b/arkindex/documents/tests/tasks/test_corpus_delete.py @@ -75,6 +75,7 @@ class TestDeleteCorpus(FixtureTestCase): ml_class=cls.corpus.ml_classes.create(name="something"), worker_run=worker_run, worker_version=cls.worker_version, + confidence=0.89 ) element.metadatas.create( type=MetaType.Text, diff --git a/arkindex/documents/tests/test_classification.py b/arkindex/documents/tests/test_classification.py index 04a2e5c434a2b6ededee676ff1d375ff08348b3f..6aade6578dfaba50b88048bc7353e6bf6a340497 100644 --- a/arkindex/documents/tests/test_classification.py +++ b/arkindex/documents/tests/test_classification.py @@ -836,6 +836,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_1, state=ClassificationState.Pending, high_confidence=True, + confidence=1, ml_class=line ) @@ -845,6 +846,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_1, state=ClassificationState.Pending, high_confidence=True, + confidence=1, ml_class=self.text ) Classification.objects.create( @@ -852,6 +854,7 @@ class TestClassifications(FixtureAPITestCase): worker_version=self.worker_version_2, state=ClassificationState.Pending, high_confidence=False, + confidence=0.4, ml_class=self.text ) diff --git a/arkindex/documents/tests/test_retrieve_elements.py b/arkindex/documents/tests/test_retrieve_elements.py index af36b57722593356d0a4478e7df04da982f384db..493861f881fb81ac6e7a1f884812a13d7af99f16 100644 --- a/arkindex/documents/tests/test_retrieve_elements.py +++ b/arkindex/documents/tests/test_retrieve_elements.py @@ -27,7 +27,7 @@ class TestRetrieveElements(FixtureAPITestCase): def test_get_element(self): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) - classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class) + classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class, confidence=0.8) with self.assertNumQueries(2): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) @@ -56,7 +56,7 @@ class TestRetrieveElements(FixtureAPITestCase): "classifications": [ { "id": str(classification.id), - "confidence": None, + "confidence": 0.8, "high_confidence": False, "state": "pending", "worker_version": str(self.worker_version.id), @@ -251,7 +251,7 @@ class TestRetrieveElements(FixtureAPITestCase): def test_get_element_classification_worker_run(self): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) - classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class) + classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class, confidence=0.89) with self.assertNumQueries(3): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) @@ -280,7 +280,7 @@ class TestRetrieveElements(FixtureAPITestCase): "classifications": [ { "id": str(classification.id), - "confidence": None, + "confidence": 0.89, "high_confidence": False, "state": "pending", "worker_version": str(self.worker_version.id), diff --git a/arkindex/process/tests/test_managers.py b/arkindex/process/tests/test_managers.py index 84384496215eee8242755f6e0779ea4d82768199..629f96e778fcfd8e18c60598bc27365dc59eb267 100644 --- a/arkindex/process/tests/test_managers.py +++ b/arkindex/process/tests/test_managers.py @@ -79,6 +79,7 @@ class TestManagers(FixtureTestCase): element=Element.objects.first(), ml_class=MLClass.objects.create(corpus=self.corpus, name="b"), worker_version=self.worker_version, + confidence=0.89, ) self.assertFalse(self.corpus.worker_version_cache.exists()) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 39307558dffcc99a7f62085699feb850436bff5f..b3370d28e8491f87195db503cbf1cce7ab241aeb 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -806,7 +806,7 @@ class TestProcesses(FixtureAPITestCase): entity = self.corpus.entities.create(name="test", type=self.corpus.entity_types.create()) transcription = page.transcriptions.first() transcription_entity = transcription.transcription_entities.create(entity=entity, offset=1, length=1) - classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test")) + classification = page.classifications.create(ml_class=self.corpus.ml_classes.create(name="test"), confidence=1) for related_obj in (page, metadata, entity, transcription, transcription_entity, classification): # The atomic() block ensures a rollback before the next subtest, in case of a failure