diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 1e730b73f8fa799886dd30c770493b9a85a09dc0..b5dc503dd46a61b6ac49add716426645441dae7c 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -98,7 +98,7 @@ from arkindex.training.models import DatasetElement, ModelVersion from arkindex.users.models import Role from arkindex.users.utils import filter_rights -classifications_queryset = Classification.objects.select_related("ml_class", "worker_version").order_by("-confidence") +classifications_queryset = Classification.objects.select_related("ml_class", "worker_run").order_by("-confidence") def _fetch_has_children(elements): diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 97ff0a54330934813029a73fd391ab1d51b190b8..dcfab609b3cdf086bfb77f7a570fe702e9803a37 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -4,7 +4,6 @@ from enum import Enum from textwrap import dedent from django.conf import settings -from drf_spectacular.utils import extend_schema_serializer from rest_framework import serializers from rest_framework.exceptions import ValidationError @@ -21,7 +20,7 @@ from arkindex.documents.models import ( from arkindex.documents.serializers.light import ElementZoneSerializer from arkindex.project.serializer_fields import EnumField, ForbiddenField, LinearRingField, WorkerRunIDField from arkindex.project.tools import polygon_outside_image -from arkindex.project.validators import ConditionalUniqueValidator, ForbiddenValidator +from arkindex.project.validators import ConditionalUniqueValidator # Defined here to avoid circular imports, because used by documents serializer @@ -63,7 +62,6 @@ class MLClassSerializer(MLClassLightSerializer): fields = ("id", "name", "corpus") -@extend_schema_serializer(deprecate_fields=("worker_version", )) class ClassificationSerializer(serializers.ModelSerializer): """ Serialize a classification on an Element @@ -72,11 +70,6 @@ class ClassificationSerializer(serializers.ModelSerializer): ml_class = MLClassSerializer() state = EnumField(ClassificationState) worker_run = WorkerRunSummarySerializer(read_only=True, allow_null=True) - worker_version = serializers.UUIDField( - read_only=True, - allow_null=True, - source="worker_version_id", - ) class Meta: model = Classification @@ -86,19 +79,16 @@ class ClassificationSerializer(serializers.ModelSerializer): "state", "confidence", "high_confidence", - "worker_version", "worker_run", ) read_only_fields = ( "id", "confidence", "high_confidence", - "worker_version", "worker_run", ) -@extend_schema_serializer(deprecate_fields=("worker_version", )) class ClassificationCreateSerializer(serializers.ModelSerializer): """ Serializer to create a single classification, defaulting to manual @@ -111,19 +101,6 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): queryset=MLClass.objects.using("default").none(), style={"base_template": "input.html"}, ) - worker_version = serializers.UUIDField( - allow_null=True, - default=None, - source="worker_version_id", - validators=[ - ForbiddenValidator(), - ], - help_text=dedent(""" - ID of a WorkerVersion that created this classification. - - Creating new classifications with a WorkerVersion is forbidden. Use `worker_run_id` instead. - """), - ) worker_run_id = WorkerRunIDField( required=False, allow_null=True, @@ -159,7 +136,6 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): "id", "element", "ml_class", - "worker_version", "worker_run_id", "confidence", "high_confidence", @@ -201,7 +177,7 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): self.fields["ml_class"].queryset = MLClass.objects.using("default").filter(corpus__in=corpora) def validate(self, data): - # Note that (worker_version, class, element) unicity is already checked by DRF + # Note that (worker_version / worker run, class, element) unicity is already checked by DRF errors = {} if data["element"].corpus_id != data["ml_class"].corpus_id: @@ -532,7 +508,6 @@ class ClassificationsSerializer(serializers.Serializer): queryset=Element.objects.none(), style={"base_template": "input.html"}, ) - worker_version = ForbiddenField() worker_run_id = WorkerRunIDField( help_text=dedent(""" A WorkerRun ID that the classifications will refer to. diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index 60f82ea19b14bc9891496c53e80a88cb54ba1595..4b2b13fc3b75ceb98cfa9fe5b5941cf7f3833fbd 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -71,7 +71,6 @@ class TestBulkClassification(FixtureAPITestCase): format="json", data={ "parent": str(self.page.id), - "worker_version": str(self.worker_version.id), "classifications": [ { "ml_class": str(self.dog_class.id), @@ -87,8 +86,7 @@ class TestBulkClassification(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), { - "worker_run_id": ["This field is required."], - "worker_version": ["This field is forbidden."], + "worker_run_id": ["This field is required."] }) def test_worker_run_not_found(self): @@ -378,12 +376,11 @@ class TestBulkClassification(FixtureAPITestCase): "ml_class__name", "confidence", "high_confidence", - "worker_version_id", "worker_run_id", )), [ - ("dog", 0.99, True, self.custom_version.id, self.local_worker_run.id), - ("cat", 0.42, False, self.custom_version.id, self.local_worker_run.id), + ("dog", 0.99, True, self.local_worker_run.id), + ("cat", 0.42, False, self.local_worker_run.id), ], ) # Worker run is set, and worker version is deduced from it diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index 20409bbcc53b93f679a3e3562b6ad42b178815f3..b0fd1b3f29c284359506fe9a0c03d733fbe2aa37 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -27,6 +27,9 @@ class TestClasses(FixtureAPITestCase): cls.version1 = WorkerVersion.objects.get(worker__slug="reco") cls.version2 = WorkerVersion.objects.get(worker__slug="dla") + cls.worker_run_1 = cls.version1.worker_runs.first() + cls.worker_run_2 = cls.version2.worker_runs.first() + for elt_num in range(1, 13): elt = cls.corpus.elements.create( name=f"elt_{elt_num}", @@ -35,9 +38,10 @@ class TestClasses(FixtureAPITestCase): elt.add_parent(cls.parent) cls.common_children.add_parent(elt) for ml_class, score in ((cls.text, .7), (cls.cover, .99)): - for worker_version in (cls.version1, cls.version2): + for worker_run in (cls.worker_run_1, cls.worker_run_2): elt.classifications.create( - worker_version=worker_version, + worker_version=worker_run.version, + worker_run=worker_run, ml_class=ml_class, confidence=score, high_confidence=bool(score == .99) @@ -327,12 +331,12 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data["count"], 12) for elt in data["results"]: self.assertCountEqual( - list(map(lambda c: (c["worker_version"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), + list(map(lambda c: (c["worker_run"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), [ - (str(self.version1.id), "cover", .99), - (str(self.version2.id), "cover", .99), - (str(self.version1.id), "text", .7), - (str(self.version2.id), "text", .7), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "cover", .99), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "cover", .99), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "text", .7), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "text", .7), ] ) @@ -360,12 +364,12 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data["count"], 12) for elt in data["results"]: self.assertCountEqual( - list(map(lambda c: (c["worker_version"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), + list(map(lambda c: (c["worker_run"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), [ - (str(self.version1.id), "cover", .99), - (str(self.version2.id), "cover", .99), - (str(self.version1.id), "text", .7), - (str(self.version2.id), "text", .7), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "cover", .99), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "cover", .99), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "text", .7), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "text", .7), ] ) @@ -380,12 +384,12 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data["count"], 12) for elt in data["results"]: self.assertCountEqual( - list(map(lambda c: (c["worker_version"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), + list(map(lambda c: (c["worker_run"], c["ml_class"]["name"], c["confidence"]), elt["classes"])), [ - (str(self.version1.id), "cover", .99), - (str(self.version2.id), "cover", .99), - (str(self.version1.id), "text", .7), - (str(self.version2.id), "text", .7), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "cover", .99), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "cover", .99), + ({"id": str(self.worker_run_1.id), "summary": self.worker_run_1.summary}, "text", .7), + ({"id": str(self.worker_run_2.id), "summary": self.worker_run_2.summary}, "text", .7), ] ) diff --git a/arkindex/documents/tests/test_classification.py b/arkindex/documents/tests/test_classification.py index 6aade6578dfaba50b88048bc7353e6bf6a340497..a8dd00e79b37ef31700517da94e2cb3d107ccc1a 100644 --- a/arkindex/documents/tests/test_classification.py +++ b/arkindex/documents/tests/test_classification.py @@ -52,15 +52,14 @@ class TestClassifications(FixtureAPITestCase): "element": str(self.element.id), "ml_class": str(self.text.id), "worker_run_id": None, - "worker_version": None, "state": ClassificationState.Validated.value, "confidence": 1.0, "high_confidence": True, }) - def test_create_null_version(self): + def test_create_null_worker_run(self): """ - A manual classification may be created specifying the version as null + A manual classification may be created specifying the worker run ID as null """ self.client.force_login(self.user) with self.assertNumQueries(6): @@ -69,7 +68,7 @@ class TestClassifications(FixtureAPITestCase): data={ "element": str(self.element.id), "ml_class": str(self.text.id), - "worker_version": None, + "worker_run_id": None, }, format="json" ) @@ -188,7 +187,6 @@ class TestClassifications(FixtureAPITestCase): "element": str(self.element.id), "ml_class": str(self.text.id), "worker_run_id": None, - "worker_version": None, "state": ClassificationState.Validated.value, "confidence": 1, "high_confidence": True, @@ -206,21 +204,6 @@ class TestClassifications(FixtureAPITestCase): }) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_create_worker_version(self): - self.client.force_login(self.user) - with self.assertNumQueries(4): - response = self.client.post(reverse("api:classification-create"), { - "element": str(self.element.id), - "ml_class": str(self.text.id), - "worker_version": str(self.worker_version_1.id), - "confidence": 0.42, - "high_confidence": False, - }) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - "worker_version": ["This field is forbidden."], - }) - def test_create_worker_run_local(self): """ A regular user can create a classification with a WorkerRun of their own local process @@ -249,7 +232,6 @@ class TestClassifications(FixtureAPITestCase): "element": str(self.element.id), "ml_class": str(self.text.id), "worker_run_id": str(self.local_worker_run.id), - "worker_version": str(self.custom_version.id), "state": ClassificationState.Pending.value, "confidence": 0.42, "high_confidence": False, @@ -290,7 +272,6 @@ class TestClassifications(FixtureAPITestCase): "element": str(self.element.id), "ml_class": str(self.text.id), "worker_run_id": str(self.local_worker_run.id), - "worker_version": str(self.custom_version.id), "state": ClassificationState.Pending.value, "confidence": 0.42, "high_confidence": False, @@ -331,7 +312,6 @@ class TestClassifications(FixtureAPITestCase): "element": str(self.element.id), "ml_class": str(self.text.id), "worker_run_id": str(self.worker_run.id), - "worker_version": str(self.worker_version_1.id), "state": ClassificationState.Pending.value, "confidence": 0.42, "high_confidence": False, @@ -409,21 +389,6 @@ class TestClassifications(FixtureAPITestCase): "worker_run_id": ["Only the WorkerRuns of the authenticated task's process may be used."], }) - def test_create_worker_version_xor_worker_run(self): - self.client.force_login(self.user) - with self.assertNumQueries(5): - response = self.client.post(reverse("api:classification-create"), { - "element": str(self.element.id), - "ml_class": str(self.text.id), - "worker_version": str(self.worker_version_1.id), - "worker_run_id": str(self.local_worker_run.id), - }) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertEqual(response.json(), { - "worker_version": ["This field is forbidden."], - }) - def test_create_worker_run_not_found(self): self.client.force_login(self.user) with self.assertNumQueries(5): @@ -461,16 +426,16 @@ class TestClassifications(FixtureAPITestCase): self.assertEqual(classification.confidence, 0) self.assertFalse(classification.high_confidence) - def test_create_manual_and_worker_version(self): + def test_create_manual_and_worker_run(self): """ - CreateClassification should allow creating the same ML class from a worker run, - a worker version and no worker version on the same element. + CreateClassification should allow creating the same ML class from a worker run + and no worker run on the same element. """ self.client.force_login(self.user) - # Create a classification with a worker version and no worker run self.element.classifications.create( - worker_version=self.worker_version_2, + worker_version=self.worker_run.version, + worker_run=self.worker_run, ml_class=self.text, confidence=0.5, high_confidence=False, @@ -488,44 +453,33 @@ class TestClassifications(FixtureAPITestCase): self.assertEqual(self.element.classifications.count(), 2) - # Create a classification with the same class and a worker run - with self.assertNumQueries(7): - response = self.client.post(reverse("api:classification-create"), { - "element": str(self.element.id), - "ml_class": str(self.text.id), - "worker_run_id": str(self.local_worker_run.id), - "confidence": 0.5, - "high_confidence": False, - }) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - self.assertEqual(self.element.classifications.count(), 3) - self.assertCountEqual( list(self.element.classifications.values("worker_version_id", "worker_run_id")), [ {"worker_version_id": None, "worker_run_id": None}, - {"worker_version_id": self.worker_version_2.id, "worker_run_id": None}, - {"worker_version_id": self.custom_version.id, "worker_run_id": self.local_worker_run.id}, + {"worker_version_id": self.worker_run.version.id, "worker_run_id": self.worker_run.id}, ] ) def test_validate(self): classification = self.element.classifications.create( worker_version=self.worker_version_1, + worker_run=self.worker_run, ml_class=self.text, confidence=.1 ) self.client.force_login(self.user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.put(reverse("api:classification-validate", kwargs={"pk": classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { "id": str(classification.id), - "worker_version": str(self.worker_version_1.id), - "worker_run": None, + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Document layout analyser @ version 1" + }, "ml_class": { "id": str(classification.ml_class.id), "name": classification.ml_class.name @@ -561,22 +515,25 @@ class TestClassifications(FixtureAPITestCase): response = self.client.put(reverse("api:classification-validate", kwargs={"pk": classification.id})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_reject_worker_version(self): + def test_reject_worker_run(self): self.client.force_login(self.user) classification = self.element.classifications.create( worker_version=self.worker_version_1, + worker_run=self.worker_run, ml_class=self.text, confidence=.1, ) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.put(reverse("api:classification-reject", kwargs={"pk": classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { "id": str(classification.id), - "worker_version": str(self.worker_version_1.id), - "worker_run": None, + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Document layout analyser @ version 1" + }, "ml_class": { "id": str(classification.ml_class.id), "name": classification.ml_class.name @@ -630,11 +587,12 @@ class TestClassifications(FixtureAPITestCase): confidence=.5, moderator=self.user, state=ClassificationState.Validated, - worker_version=self.worker_version_2, + worker_version=self.worker_version_1, + worker_run=self.worker_run ) # First try to reject - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.put(reverse("api:classification-reject", kwargs={"pk": classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -647,12 +605,14 @@ class TestClassifications(FixtureAPITestCase): "state": ClassificationState.Rejected.value, "confidence": classification.confidence, "high_confidence": False, - "worker_version": str(self.worker_version_2.id), - "worker_run": None, + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Document layout analyser @ version 1" + }, }) # Then try to validate - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.put(reverse("api:classification-validate", kwargs={"pk": classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -665,8 +625,10 @@ class TestClassifications(FixtureAPITestCase): "state": ClassificationState.Validated.value, "confidence": classification.confidence, "high_confidence": False, - "worker_version": str(self.worker_version_2.id), - "worker_run": None, + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Document layout analyser @ version 1" + }, }) def test_create_selection_requires_login(self): diff --git a/arkindex/documents/tests/test_retrieve_elements.py b/arkindex/documents/tests/test_retrieve_elements.py index 493861f881fb81ac6e7a1f884812a13d7af99f16..ebd458883552d9ae31b18ee09120167a4113ca14 100644 --- a/arkindex/documents/tests/test_retrieve_elements.py +++ b/arkindex/documents/tests/test_retrieve_elements.py @@ -27,7 +27,7 @@ class TestRetrieveElements(FixtureAPITestCase): def test_get_element(self): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) - classification = self.vol.classifications.create(worker_version=self.worker_version, ml_class=ml_class, confidence=0.8) + classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class, confidence=0.8) with self.assertNumQueries(2): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) @@ -59,8 +59,10 @@ class TestRetrieveElements(FixtureAPITestCase): "confidence": 0.8, "high_confidence": False, "state": "pending", - "worker_version": str(self.worker_version.id), - "worker_run": None, + "worker_run": { + "id": str(self.worker_run.id), + "summary": self.worker_run.summary + }, "ml_class": { "id": str(ml_class.id), "name": "text", @@ -253,7 +255,7 @@ class TestRetrieveElements(FixtureAPITestCase): ml_class = MLClass.objects.create(name="text", corpus=self.corpus) classification = self.vol.classifications.create(worker_version=self.worker_version, worker_run=self.worker_run, ml_class=ml_class, confidence=0.89) - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -283,7 +285,6 @@ class TestRetrieveElements(FixtureAPITestCase): "confidence": 0.89, "high_confidence": False, "state": "pending", - "worker_version": str(self.worker_version.id), "worker_run": { "id": str(self.worker_run.id), "summary": self.worker_run.summary