diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 82c9c5dc3f5548bdbe79c3219c64af1e17eb9a04..4c371922c2e43cccf83554c9e37f3f2105f7ef70 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -43,6 +43,7 @@ from arkindex.documents.models import ( ElementType, MetaData, MetaType, + MLClass, Selection, Transcription, ) @@ -131,13 +132,20 @@ def _fetch_has_children(elements): return elements +# Operators available for numeric filters in element lists # Maps valid operator names in the API to Django QuerySet lookups -METADATA_OPERATORS = { +NUMERIC_OPERATORS = { 'eq': 'exact', 'lt': 'lt', 'gt': 'gt', 'lte': 'lte', 'gte': 'gte', +} + +# Operators available for metadata values. +METADATA_OPERATORS = { + # Only for numeric metadata + **NUMERIC_OPERATORS, # The contains operator should be case-insensitive 'contains': 'icontains', } @@ -202,17 +210,6 @@ class ElementsListAutoSchema(AutoSchema): # Add method-specific parameters if self.method.lower() == 'get': parameters.extend([ - OpenApiParameter( - 'best_class', - description='Restrict to or exclude elements with a best class, ' - 'or restrict to elements with specific best class', - type={ - 'oneOf': [ - {'type': 'string', 'format': 'uuid'}, - {'type': 'boolean'} - ] - } - ), OpenApiParameter( 'metadata_name', description='Restrict to elements having a metadata with the given name.', @@ -221,7 +218,7 @@ class ElementsListAutoSchema(AutoSchema): OpenApiParameter( 'metadata_value', description='Restrict to elements having a metadata with the given value. ' - 'Can be set it exclude elements with a value or filter by numerical values using `metadata_operator`. ' + 'The comparison operator can be set using `metadata_operator`. 
' 'Requires `metadata_name` to be set.', required=False, ), @@ -243,6 +240,47 @@ class ElementsListAutoSchema(AutoSchema): default='eq', required=False, ), + OpenApiParameter( + 'class_id', + description='Restrict to elements having a classification with the specified ML class ID. ' + 'If `classification_confidence` or `classification_high_confidence` are set, ' + 'the elements must have a classification that satisfies all of the parameters at once.', + type=UUID, + required=False, + ), + OpenApiParameter( + 'classification_confidence', + description='Restrict to elements having a classification with the given confidence. ' + 'The comparison operator can be set using `classification_confidence_operator`. ' + 'If `class_id` or `classification_high_confidence` are set, the elements must have a ' + 'classification that satisfies all of the parameters at once.', + required=False, + ), + OpenApiParameter( + 'classification_confidence_operator', + description=dedent(""" + Set the comparison operator to filter on classification confidence scores: + + * `eq` (default): Elements having a classification with this exact confidence. + * `lt`: Elements having a classification with a confidence strictly lower than the filter. + * `lte`: Elements having a classification with a confidence lower than or equal to the filter. + * `gt`: Elements having a classification with a confidence strictly greater than the filter. + * `gte`: Elements having a classification with a confidence greater than or equal to the filter. + + This requires `classification_confidence` to be set. + """), + enum=NUMERIC_OPERATORS.keys(), + default='eq', + required=False, + ), + OpenApiParameter( + 'classification_high_confidence', + description='Restrict to elements having a classification marked as `high_confidence`. 
' + 'If `class_id` or `classification_confidence` are set, the elements must have a ' + 'classification that satisfies all of the parameters at once.', + type=bool, + required=False, + ), OpenApiParameter( 'order', description='Sort elements by a specific field', @@ -260,7 +298,7 @@ class ElementsListAutoSchema(AutoSchema): OpenApiParameter( 'with_best_classes', description='Returns best classifications for each element. ' - 'If not set, elements best_classes field will always be null', + 'Otherwise, `best_classes` will always be null.', type=bool, required=False, ), @@ -460,6 +498,66 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView): return queryset.values('element_id') + def get_classification_queryset(self): + """ + Returns a queryset that includes matched element IDs from classification filters, + or None if no classification filters apply. + """ + class_id = self.clean_params.get('class_id') + confidence = self.clean_params.get('classification_confidence') + confidence_operator = self.clean_params.get('classification_confidence_operator', '').lower().strip() + high_confidence = self.clean_params.get('classification_high_confidence', '').lower().strip() + + if len(high_confidence): + high_confidence = high_confidence not in ('false', '0') + else: + # An empty string should be treated as no filter at all + high_confidence = None + + if not class_id and confidence is None and not confidence_operator and high_confidence is None: + # No filters apply + return None + + queryset = Classification.objects.all() + errors = defaultdict(list) + + if class_id: + try: + ml_class = self.selected_corpus.ml_classes.get(id=class_id) + except DjangoValidationError as e: + # An invalid UUID would cause a Django ValidationError + errors['class_id'].extend(e.messages) + except MLClass.DoesNotExist: + errors['class_id'].append(f'ML class "{class_id}" not found') + else: + queryset = queryset.filter(ml_class=ml_class) + + if confidence_operator: + if 
confidence_operator not in NUMERIC_OPERATORS: + errors['classification_confidence_operator'].append('This operator is not supported.') + if confidence is None: + errors['classification_confidence_operator'].append('This option is not supported without classification_confidence.') + else: + confidence_operator = 'eq' + + if confidence: + try: + confidence = float(confidence) + assert 0 <= confidence <= 1, 'Confidence must be between 0 and 1' + except (TypeError, ValueError, AssertionError) as e: + errors['classification_confidence'].append(str(e)) + else: + lookup = NUMERIC_OPERATORS.get(confidence_operator, 'exact') + queryset = queryset.filter(**{f'confidence__{lookup}': confidence}) + + if high_confidence is not None: + queryset = queryset.filter(high_confidence=high_confidence) + + if errors: + raise ValidationError(errors) + + return queryset.values('element_id') + def get_filters(self): filters = { 'corpus': self.selected_corpus @@ -500,43 +598,19 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView): if metadata_queryset is not None: filters['id__in'] = metadata_queryset + try: + classification_queryset = self.get_classification_queryset() + except ValidationError as e: + errors.update(e.detail) + else: + if classification_queryset is not None: + filters['pk__in'] = classification_queryset  # `pk__in` is ANDed with the metadata `id__in` filter instead of overwriting it + if errors: raise ValidationError(errors) return filters - def get_classifications_filters(self): - """ - Build Django ORM filters using Q expressions related to best classes. 
- Supports 3 modes: - - elements without any best classes - - elements with any best classes - - elements with a specific best class - """ - class_filter = self.clean_params.get('best_class') - if class_filter is None: - return - - # Generic ORM query to find best classes: - # - elements with a validated classification - # - OR where high confidence is True and not rejected - best_classifications = Q(classifications__state=ClassificationState.Validated) | ( - Q(classifications__high_confidence=True) - & ~Q(classifications__state=ClassificationState.Rejected) - ) - - # List elements without any best classes, by inverting the query above - if class_filter.lower() in ('false', '0'): - return ~best_classifications - - try: - # Filter on a specific class - class_filter = uuid.UUID(class_filter) - return best_classifications & Q(classifications__ml_class_id=class_filter) - except (TypeError, ValueError): - # By default, use all best classifications - return best_classifications - def get_serializer_context(self): context = super().get_serializer_context() context['corpus'] = self.selected_corpus @@ -577,11 +651,6 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView): .prefetch_related(*self.get_prefetch()) \ .order_by(*self.get_order_by()) - class_filters = self.get_classifications_filters() - if class_filters is not None: - # Use queryset.distinct() whenever best_class is defined - queryset = queryset.filter(class_filters).distinct() - with_has_children = self.clean_params.get('with_has_children') if with_has_children and with_has_children.lower() not in ('false', '0'): queryset = BulkMap(_fetch_has_children, queryset) diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index 65942e9e0dcf161d97cb8fb86c68fe2f741a62e6..eeac8361bc93e371abfa9e1df001f81793d0c0f4 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -424,165 +424,22 @@ class 
TestClasses(FixtureAPITestCase): [(str(self.version1.id), .99), (str(self.version2.id), .99)] ) - def test_class_filter_list_elements(self): - element = Element.objects.filter(type=self.classified.id).first() - element.classifications.create( - ml_class=self.text, - confidence=.1337, - high_confidence=True, - ) - with self.assertNumQueries(5): - response = self.client.get( - reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), - data={'type': self.classified.slug, 'best_class': str(self.text.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 1) - self.assertEqual(data['results'][0]['id'], str(element.id)) - - def test_class_filter_list_parents(self): - parent = Element.objects.get_ascending(self.common_children.id).get(name='elt_1') - self.assertEqual(parent.classifications.filter(ml_class=self.text).count(), 2) - parent.classifications.filter(ml_class=self.text).update(state=ClassificationState.Validated) - with self.assertNumQueries(5): - response = self.client.get( - reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), - data={'type': self.classified.slug, 'best_class': str(self.text.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 1) - self.assertEqual(data['results'][0]['id'], str(parent.id)) - - def test_class_filter_list_children(self): - child = Element.objects.filter(type=self.classified.id).first() - child.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) - with self.assertNumQueries(6): - response = self.client.get( - reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), - data={'type': self.classified.slug, 'best_class': str(self.text.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 1) - 
self.assertEqual(data['results'][0]['id'], str(child.id)) - - def test_class_filter_list_elements_distinct(self): - self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) - self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) - with self.assertNumQueries(5): - response = self.client.get( - reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), - data={'type': self.classified.slug, 'best_class': str(self.cover.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 12) - # Ensure element IDs are unique - ids = [e['id'] for e in data['results']] - self.assertCountEqual(ids, set(ids)) - - def test_class_filter_list_parents_distinct(self): - self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) - self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) - with self.assertNumQueries(5): - response = self.client.get( - reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), - data={'type': self.classified.slug, 'best_class': str(self.cover.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 12) - # Ensure element IDs are unique - ids = [e['id'] for e in data['results']] - self.assertCountEqual(ids, set(ids)) - - def test_class_filter_list_children_distinct(self): - self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) - self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) - with self.assertNumQueries(6): - response = self.client.get( - reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), - data={'type': self.classified.slug, 'best_class': str(self.cover.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = 
response.json() - self.assertEqual(data['count'], 12) - # Ensure element IDs are unique - ids = [e['id'] for e in data['results']] - self.assertCountEqual(ids, set(ids)) - - def test_class_filter_true(self): - element = Element.objects.filter(type=self.classified.id).first() - element.classifications.all().delete() - element.classifications.create( - worker_version=self.version2, - ml_class_id=self.text.id, - confidence=.1337, - high_confidence=True, - ) - with self.assertNumQueries(6): - response = self.client.get( - reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), - data={'type': self.classified.slug, 'best_class': 'true', 'with_best_classes': 'true'}, - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 12) - best_class_ids = set( - best_class['ml_class']['id'] - for element in data['results'] - for best_class in element['best_classes'] - ) - self.assertSetEqual(best_class_ids, {str(self.text.id), str(self.cover.id)}) - - def test_class_filter_false(self): - element = Element.objects.filter(type=self.classified.id).first() - element.classifications.all().delete() - element.classifications.create( - worker_version=self.version2, - ml_class_id=self.text.id, - confidence=.1337, - high_confidence=False, - ) - with self.assertNumQueries(6): - response = self.client.get( - reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), - data={'type': self.classified.slug, 'best_class': 'false', 'with_best_classes': 'true'} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['count'], 1) - self.assertEqual(data['results'][0]['id'], str(element.id)) - self.assertListEqual(data['results'][0]['best_classes'], []) - - def test_exclude_rejected(self): + def test_with_best_classes_exclude_rejected(self): """ - Ensure that best_classes and with_best_classes ignore rejected high confidence classifications + Ensure that 
with_best_classes ignores rejected high confidence classifications """ Classification.objects.all().delete() - # One element with only a rejected classification; should be ignored - element1 = Element.objects.get(type=self.classified.id, name='elt_1') - element1.classifications.create( - worker_version=self.version2, - ml_class_id=self.text.id, - confidence=.1337, - high_confidence=True, - state='rejected' - ) - # One element with a rejected classification and a best class, should be included with 1 best class - element2 = Element.objects.get(type=self.classified.id, name='elt_2') - element2.classifications.create( + element = Element.objects.get(type=self.classified.id, name='elt_2') + element.classifications.create( worker_version=self.version1, ml_class_id=self.text.id, confidence=.1337, high_confidence=True, state='rejected' ) - expected_classification = element2.classifications.create( + expected_classification = element.classifications.create( worker_version=self.version2, ml_class_id=self.text.id, confidence=.1337, @@ -592,13 +449,13 @@ class TestClasses(FixtureAPITestCase): with self.assertNumQueries(6): response = self.client.get( reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), - data={'type': self.classified.slug, 'best_class': 'true', 'with_best_classes': 'true'}, + data={'type': self.classified.slug, 'name': 'elt_2', 'with_best_classes': 'true'}, ) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() self.assertEqual(data['count'], 1) - self.assertEqual(data['results'][0]['id'], str(element2.id)) + self.assertEqual(data['results'][0]['id'], str(element.id)) self.assertListEqual(data['results'][0]['best_classes'], [ { "id": str(expected_classification.id), @@ -612,3 +469,119 @@ class TestClasses(FixtureAPITestCase): } } ]) + + def test_element_lists_invalid_class_filters(self): + corpus2 = Corpus.objects.create(name='Corpus 2') + other_class = corpus2.ml_classes.create(name='something') + + endpoints = [ + 
reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), + reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), + reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), + ] + cases = [ + ({'class_id': 'lol'}, {'class_id': ['“lol” is not a valid UUID.']}), + ({'class_id': str(other_class.id)}, {'class_id': [f'ML class "{other_class.id}" not found']}), + ({'classification_confidence': 'very'}, {'classification_confidence': ["could not convert string to float: 'very'"]}), + ({'classification_confidence': 'nan'}, {'classification_confidence': ['Confidence must be between 0 and 1']}), + ({'classification_confidence': 'inf'}, {'classification_confidence': ['Confidence must be between 0 and 1']}), + ({'classification_confidence': '42'}, {'classification_confidence': ['Confidence must be between 0 and 1']}), + ({'classification_confidence': '-.5'}, {'classification_confidence': ['Confidence must be between 0 and 1']}), + ( + {'classification_confidence_operator': 'hah'}, + {'classification_confidence_operator': [ + 'This operator is not supported.', + 'This option is not supported without classification_confidence.', + ]} + ), + ( + {'classification_confidence': '0.3', 'classification_confidence_operator': 'lol'}, + {'classification_confidence_operator': ['This operator is not supported.']} + ) + ] + for endpoint in endpoints: + for data, expected_errors in cases: + with self.subTest(endpoint=endpoint, **data): + response = self.client.get(endpoint, data) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), expected_errors) + + def test_element_lists_class_filters(self): + # Create more diverse test data for these filters + Classification.objects.all().delete() + for i in range(1, 7): + self.corpus.elements.get(name=f'elt_{i}').classifications.create( + worker_version=self.version1, + ml_class=self.text, + confidence=i / 6, + # Give every even element a high confidence 
classification + high_confidence=i % 2 == 0 + ) + # Second half of elements gets self.cover + for i in range(7, 13): + self.corpus.elements.get(name=f'elt_{i}').classifications.create( + worker_version=self.version2, + ml_class=self.cover, + confidence=(i - 6) / 6, + high_confidence=i % 2 == 0 + ) + + endpoints = [ + reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), + reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), + reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), + ] + cases = [ + ( + {'class_id': str(self.text.id)}, + ['elt_1', 'elt_2', 'elt_3', 'elt_4', 'elt_5', 'elt_6'], + ), + ( + {'classification_confidence': '0.5'}, + ['elt_3', 'elt_9'], + ), + ( + {'classification_confidence': '0.51'}, + [], + ), + ( + {'classification_confidence': '0.5', 'classification_confidence_operator': 'gt'}, + ['elt_4', 'elt_5', 'elt_6', 'elt_10', 'elt_11', 'elt_12'] + ), + ( + {'classification_confidence': '0.7', 'classification_confidence_operator': 'lte'}, + ['elt_1', 'elt_2', 'elt_3', 'elt_4', 'elt_7', 'elt_8', 'elt_9', 'elt_10'], + ), + ( + {'class_id': str(self.cover.id), 'classification_confidence': '0.7'}, + [], + ), + ( + {'classification_high_confidence': True}, + ['elt_2', 'elt_4', 'elt_6', 'elt_8', 'elt_10', 'elt_12'], + ), + ( + {'classification_high_confidence': False}, + ['elt_1', 'elt_3', 'elt_5', 'elt_7', 'elt_9', 'elt_11'], + ), + ( + {'class_id': str(self.text.id), 'classification_high_confidence': True}, + ['elt_2', 'elt_4', 'elt_6'], + ), + ( + {'class_id': str(self.text.id), 'classification_high_confidence': False}, + ['elt_1', 'elt_3', 'elt_5'], + ), + ] + for endpoint in endpoints: + for data, element_names in cases: + with self.subTest(endpoint=endpoint, **data): + response = self.client.get( + endpoint, + data={**data, 'type': self.classified.slug}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) + self.assertCountEqual( + [element['name'] for element in 
response.json()['results']], + element_names, + )