Skip to content
Snippets Groups Projects
Commit 2dae2711 authored by Bastien Abadie's avatar Bastien Abadie
Browse files

Merge branch 'class-filters' into 'master'

Replace best_class with classification filters

Closes #908

See merge request !1566
parents 96560e0b a162c210
No related branches found
No related tags found
1 merge request!1566Replace best_class with classification filters
......@@ -43,6 +43,7 @@ from arkindex.documents.models import (
ElementType,
MetaData,
MetaType,
MLClass,
Selection,
Transcription,
)
......@@ -131,13 +132,20 @@ def _fetch_has_children(elements):
return elements
# Operators available for numeric filters in element lists
# Maps valid operator names in the API to Django QuerySet lookups
METADATA_OPERATORS = {
# Comparison operators accepted by numeric filters on element lists.
# Keys are the operator names exposed by the API, values are the
# matching Django QuerySet lookups.
NUMERIC_OPERATORS = dict(
    eq='exact',
    lt='lt',
    gt='gt',
    lte='lte',
    gte='gte',
)

# Operators accepted for metadata values: every numeric operator (only for
# numeric metadata), plus `contains`, which should be case-insensitive.
METADATA_OPERATORS = dict(NUMERIC_OPERATORS, contains='icontains')
......@@ -202,17 +210,6 @@ class ElementsListAutoSchema(AutoSchema):
# Add method-specific parameters
if self.method.lower() == 'get':
parameters.extend([
OpenApiParameter(
'best_class',
description='Restrict to or exclude elements with a best class, '
'or restrict to elements with specific best class',
type={
'oneOf': [
{'type': 'string', 'format': 'uuid'},
{'type': 'boolean'}
]
}
),
OpenApiParameter(
'metadata_name',
description='Restrict to elements having a metadata with the given name.',
......@@ -221,7 +218,7 @@ class ElementsListAutoSchema(AutoSchema):
OpenApiParameter(
'metadata_value',
description='Restrict to elements having a metadata with the given value. '
'Can be set it exclude elements with a value or filter by numerical values using `metadata_operator`. '
'The comparison operator can be set using `metadata_operator`. '
'Requires `metadata_name` to be set.',
required=False,
),
......@@ -243,6 +240,47 @@ class ElementsListAutoSchema(AutoSchema):
default='eq',
required=False,
),
OpenApiParameter(
'class_id',
description='Restrict to elements having a classification with the specified ML class ID. '
'If `classification_confidence` or `classification_high_confidence` are set, '
'the elements must have a classification that satisfies all of the parameters at once.',
type=UUID,
required=False,
),
OpenApiParameter(
'classification_confidence',
description='Restrict to elements having a classification with the given confidence. '
'The comparison operator can be set using `classification_confidence_operator`. '
'If `class_id` or `classification_high_confidence` are set, the elements must have a '
'classification that satisfies all of the parameters at once.',
required=False,
),
OpenApiParameter(
'classification_confidence_operator',
description=dedent("""
Set the comparison operator to filter on classification confidence scores:
* `eq` (default): Elements having a classification with this exact confidence.
* `lt`: Elements having a classification with a confidence strictly lower than the filter.
* `lte`: Elements having a classification with a confidence lower than or equal to the filter.
* `gt`: Elements having a classification with a confidence strictly greater than the filter.
* `gte`: Elements having a classification with a confidence greater than or equal to the filter.
This requires `classification_confidence` to be set.
"""),
enum=NUMERIC_OPERATORS.keys(),
default='eq',
required=False,
),
OpenApiParameter(
'classification_high_confidence',
description='Restrict to elements having a classification marked as `high_confidence`. '
'If `class_id` or `classification_confidence` are set, the elements must have a '
'classification that satisfies all of the parameters at once.',
type=bool,
required=False,
),
OpenApiParameter(
'order',
description='Sort elements by a specific field',
......@@ -260,7 +298,7 @@ class ElementsListAutoSchema(AutoSchema):
OpenApiParameter(
'with_best_classes',
description='Returns best classifications for each element. '
'If not set, elements best_classes field will always be null',
'Otherwise, `best_classes` will always be null.',
type=bool,
required=False,
),
......@@ -460,6 +498,66 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
return queryset.values('element_id')
def get_classification_queryset(self):
    """
    Build a queryset of element IDs matched by the classification filters.

    Reads `class_id`, `classification_confidence`,
    `classification_confidence_operator` and `classification_high_confidence`
    from `self.clean_params`. Every filter is applied to the same
    Classification queryset, so a single classification has to satisfy
    all of the requested filters at once.

    Returns a `values('element_id')` queryset, or None if no classification
    filters apply.
    Raises a ValidationError mapping each invalid parameter name to its
    list of error messages.
    """
    class_id = self.clean_params.get('class_id')
    confidence = self.clean_params.get('classification_confidence')
    confidence_operator = self.clean_params.get('classification_confidence_operator', '').lower().strip()
    high_confidence = self.clean_params.get('classification_high_confidence', '').lower().strip()

    if high_confidence:
        # Anything other than an explicit false value means True
        high_confidence = high_confidence not in ('false', '0')
    else:
        # An empty string should be treated as no filter at all
        high_confidence = None

    if not class_id and confidence is None and not confidence_operator and high_confidence is None:
        # No filters apply
        return None

    queryset = Classification.objects.all()
    errors = defaultdict(list)

    if class_id:
        try:
            # Only ML classes of the selected corpus are accepted
            ml_class = self.selected_corpus.ml_classes.get(id=class_id)
        except DjangoValidationError as e:
            # An invalid UUID would cause a Django ValidationError
            errors['class_id'].extend(e.messages)
        except MLClass.DoesNotExist:
            errors['class_id'].append(f'ML class "{class_id}" not found')
        else:
            queryset = queryset.filter(ml_class=ml_class)

    if confidence_operator:
        if confidence_operator not in NUMERIC_OPERATORS:
            errors['classification_confidence_operator'].append('This operator is not supported.')
        if confidence is None:
            errors['classification_confidence_operator'].append('This option is not supported without classification_confidence.')
    else:
        confidence_operator = 'eq'

    # NOTE(review): an empty `classification_confidence` value is not None, so it
    # gets past the "no filters" check above but is ignored here — confirm intended.
    if confidence:
        try:
            confidence = float(confidence)
            # Explicit check rather than `assert`, which is stripped when Python
            # runs with -O. NaN fails both comparisons and is rejected as well.
            if not 0 <= confidence <= 1:
                raise ValueError('Confidence must be between 0 and 1')
        except (TypeError, ValueError) as e:
            errors['classification_confidence'].append(str(e))
        else:
            # An unsupported operator has already been reported above;
            # the fallback only keeps this lookup from failing.
            lookup = NUMERIC_OPERATORS.get(confidence_operator, 'exact')
            queryset = queryset.filter(**{f'confidence__{lookup}': confidence})

    if high_confidence is not None:
        queryset = queryset.filter(high_confidence=high_confidence)

    if errors:
        raise ValidationError(errors)

    return queryset.values('element_id')
def get_filters(self):
filters = {
'corpus': self.selected_corpus
......@@ -500,43 +598,19 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
if metadata_queryset is not None:
filters['id__in'] = metadata_queryset
try:
classification_queryset = self.get_classification_queryset()
except ValidationError as e:
errors.update(e.detail)
else:
if classification_queryset is not None:
filters['id__in'] = classification_queryset
if errors:
raise ValidationError(errors)
return filters
def get_classifications_filters(self):
    """
    Translate the `best_class` parameter into a Django Q expression.

    A best class is a classification that is either validated, or marked
    as high confidence and not rejected. Depending on `best_class`:
    - not set: no filter is built and None is returned
    - `false` or `0`: elements without any best class
    - a valid UUID: elements with a best class on this ML class
    - anything else: elements with any best class
    """
    raw_value = self.clean_params.get('best_class')
    if raw_value is None:
        return None

    is_validated = Q(classifications__state=ClassificationState.Validated)
    is_high_confidence = Q(classifications__high_confidence=True)
    is_rejected = Q(classifications__state=ClassificationState.Rejected)
    has_best_class = is_validated | (is_high_confidence & ~is_rejected)

    # Elements without best classes are found by negating the generic query
    if raw_value.lower() in ('false', '0'):
        return ~has_best_class

    try:
        ml_class_id = uuid.UUID(raw_value)
    except (TypeError, ValueError):
        # Not a UUID: fall back to matching any best class
        return has_best_class

    # Restrict best classes to the requested ML class
    return has_best_class & Q(classifications__ml_class_id=ml_class_id)
def get_serializer_context(self):
context = super().get_serializer_context()
context['corpus'] = self.selected_corpus
......@@ -577,11 +651,6 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
.prefetch_related(*self.get_prefetch()) \
.order_by(*self.get_order_by())
class_filters = self.get_classifications_filters()
if class_filters is not None:
# Use queryset.distinct() whenever best_class is defined
queryset = queryset.filter(class_filters).distinct()
with_has_children = self.clean_params.get('with_has_children')
if with_has_children and with_has_children.lower() not in ('false', '0'):
queryset = BulkMap(_fetch_has_children, queryset)
......
......@@ -424,165 +424,22 @@ class TestClasses(FixtureAPITestCase):
[(str(self.version1.id), .99), (str(self.version2.id), .99)]
)
def test_class_filter_list_elements(self):
    """
    Filter the corpus elements list by `best_class` with an ML class ID,
    after giving one element a high confidence classification on it.
    """
    target = Element.objects.filter(type=self.classified.id).first()
    target.classifications.create(
        ml_class=self.text,
        confidence=.1337,
        high_confidence=True,
    )
    url = reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id})

    with self.assertNumQueries(5):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.text.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 1)
    self.assertEqual(payload['results'][0]['id'], str(target.id))
def test_class_filter_list_parents(self):
    """
    Filter the parents list by `best_class` with an ML class ID,
    after validating the matching classifications of one parent.
    """
    ancestor = Element.objects.get_ascending(self.common_children.id).get(name='elt_1')
    text_classifications = ancestor.classifications.filter(ml_class=self.text)
    self.assertEqual(text_classifications.count(), 2)
    text_classifications.update(state=ClassificationState.Validated)
    url = reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)})

    with self.assertNumQueries(5):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.text.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 1)
    self.assertEqual(payload['results'][0]['id'], str(ancestor.id))
def test_class_filter_list_children(self):
    """
    Filter the children list by `best_class` with an ML class ID,
    after validating the 0.7 confidence classifications of one child.
    """
    first_child = Element.objects.filter(type=self.classified.id).first()
    first_child.classifications.filter(confidence=.7).update(state=ClassificationState.Validated)
    url = reverse('api:elements-children', kwargs={'pk': str(self.parent.id)})

    with self.assertNumQueries(6):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.text.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 1)
    self.assertEqual(payload['results'][0]['id'], str(first_child.id))
def test_class_filter_list_elements_distinct(self):
    """
    The corpus elements list filtered by `best_class` returns each element
    once, even though the fixtures hold 24 high confidence classifications
    spread over 12 elements.
    """
    high_confidence = Classification.objects.filter(high_confidence=True)
    self.assertEqual(high_confidence.count(), 24)
    self.assertEqual(high_confidence.distinct('element_id').count(), 12)
    url = reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id})

    with self.assertNumQueries(5):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.cover.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 12)
    # Ensure element IDs are unique
    element_ids = [result['id'] for result in payload['results']]
    self.assertCountEqual(element_ids, set(element_ids))
def test_class_filter_list_parents_distinct(self):
    """
    The parents list filtered by `best_class` returns each element once,
    even though the fixtures hold 24 high confidence classifications
    spread over 12 elements.
    """
    high_confidence = Classification.objects.filter(high_confidence=True)
    self.assertEqual(high_confidence.count(), 24)
    self.assertEqual(high_confidence.distinct('element_id').count(), 12)
    url = reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)})

    with self.assertNumQueries(5):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.cover.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 12)
    # Ensure element IDs are unique
    element_ids = [result['id'] for result in payload['results']]
    self.assertCountEqual(element_ids, set(element_ids))
def test_class_filter_list_children_distinct(self):
    """
    The children list filtered by `best_class` returns each element once,
    even though the fixtures hold 24 high confidence classifications
    spread over 12 elements.
    """
    high_confidence = Classification.objects.filter(high_confidence=True)
    self.assertEqual(high_confidence.count(), 24)
    self.assertEqual(high_confidence.distinct('element_id').count(), 12)
    url = reverse('api:elements-children', kwargs={'pk': str(self.parent.id)})

    with self.assertNumQueries(6):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': str(self.cover.id)},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 12)
    # Ensure element IDs are unique
    element_ids = [result['id'] for result in payload['results']]
    self.assertCountEqual(element_ids, set(element_ids))
def test_class_filter_true(self):
    """
    List corpus elements with `best_class=true` and `with_best_classes`,
    after replacing the classifications of one element with a single
    high confidence one, and check which ML classes are returned.
    """
    first_element = Element.objects.filter(type=self.classified.id).first()
    first_element.classifications.all().delete()
    first_element.classifications.create(
        worker_version=self.version2,
        ml_class_id=self.text.id,
        confidence=.1337,
        high_confidence=True,
    )
    url = reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id})

    with self.assertNumQueries(6):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': 'true', 'with_best_classes': 'true'},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 12)
    # Collect every ML class ID found among the returned best classes
    returned_class_ids = {
        best_class['ml_class']['id']
        for result in payload['results']
        for best_class in result['best_classes']
    }
    self.assertSetEqual(returned_class_ids, {str(self.text.id), str(self.cover.id)})
def test_class_filter_false(self):
    """
    List corpus elements with `best_class=false` and `with_best_classes`,
    after replacing the classifications of one element with a single
    classification that is not high confidence.
    """
    first_element = Element.objects.filter(type=self.classified.id).first()
    first_element.classifications.all().delete()
    first_element.classifications.create(
        worker_version=self.version2,
        ml_class_id=self.text.id,
        confidence=.1337,
        high_confidence=False,
    )
    url = reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id})

    with self.assertNumQueries(6):
        response = self.client.get(
            url,
            data={'type': self.classified.slug, 'best_class': 'false', 'with_best_classes': 'true'},
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 1)
    only_result = payload['results'][0]
    self.assertEqual(only_result['id'], str(first_element.id))
    self.assertListEqual(only_result['best_classes'], [])
def test_exclude_rejected(self):
def test_with_best_classes_exclude_rejected(self):
"""
Ensure that best_classes and with_best_classes ignore rejected high confidence classifications
Ensure that with_best_classes ignores rejected high confidence classifications
"""
Classification.objects.all().delete()
# One element with only a rejected classification; should be ignored
element1 = Element.objects.get(type=self.classified.id, name='elt_1')
element1.classifications.create(
worker_version=self.version2,
ml_class_id=self.text.id,
confidence=.1337,
high_confidence=True,
state='rejected'
)
# One element with a rejected classification and a best class, should be included with 1 best class
element2 = Element.objects.get(type=self.classified.id, name='elt_2')
element2.classifications.create(
element = Element.objects.get(type=self.classified.id, name='elt_2')
element.classifications.create(
worker_version=self.version1,
ml_class_id=self.text.id,
confidence=.1337,
high_confidence=True,
state='rejected'
)
expected_classification = element2.classifications.create(
expected_classification = element.classifications.create(
worker_version=self.version2,
ml_class_id=self.text.id,
confidence=.1337,
......@@ -592,13 +449,13 @@ class TestClasses(FixtureAPITestCase):
with self.assertNumQueries(6):
response = self.client.get(
reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}),
data={'type': self.classified.slug, 'best_class': 'true', 'with_best_classes': 'true'},
data={'type': self.classified.slug, 'name': 'elt_2', 'with_best_classes': 'true'},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['count'], 1)
self.assertEqual(data['results'][0]['id'], str(element2.id))
self.assertEqual(data['results'][0]['id'], str(element.id))
self.assertListEqual(data['results'][0]['best_classes'], [
{
"id": str(expected_classification.id),
......@@ -612,3 +469,119 @@ class TestClasses(FixtureAPITestCase):
}
}
])
def test_element_lists_invalid_class_filters(self):
    """
    Each invalid classification filter must cause an HTTP 400 with the
    expected error messages, on the three element list endpoints.
    """
    # An ML class from another corpus must be reported as not found
    foreign_corpus = Corpus.objects.create(name='Corpus 2')
    foreign_class = foreign_corpus.ml_classes.create(name='something')

    out_of_range = {'classification_confidence': ['Confidence must be between 0 and 1']}
    unsupported_operator = 'This operator is not supported.'

    urls = [
        reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}),
        reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}),
        reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
    ]
    # Pairs of (query parameters, expected response body)
    scenarios = [
        ({'class_id': 'lol'}, {'class_id': ['“lol” is not a valid UUID.']}),
        ({'class_id': str(foreign_class.id)}, {'class_id': [f'ML class "{foreign_class.id}" not found']}),
        ({'classification_confidence': 'very'}, {'classification_confidence': ["could not convert string to float: 'very'"]}),
        ({'classification_confidence': 'nan'}, out_of_range),
        ({'classification_confidence': 'inf'}, out_of_range),
        ({'classification_confidence': '42'}, out_of_range),
        ({'classification_confidence': '-.5'}, out_of_range),
        (
            {'classification_confidence_operator': 'hah'},
            {'classification_confidence_operator': [
                unsupported_operator,
                'This option is not supported without classification_confidence.',
            ]},
        ),
        (
            {'classification_confidence': '0.3', 'classification_confidence_operator': 'lol'},
            {'classification_confidence_operator': [unsupported_operator]},
        ),
    ]

    for url in urls:
        for params, expected_body in scenarios:
            with self.subTest(endpoint=url, **params):
                response = self.client.get(url, params)
                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
                self.assertDictEqual(response.json(), expected_body)
def test_element_lists_class_filters(self):
    """
    Check the element names returned by the three element list endpoints
    for various combinations of classification filters.
    """
    # Create more diverse test data for these filters
    Classification.objects.all().delete()
    # elt_1 to elt_6 get self.text, elt_7 to elt_12 get self.cover.
    # In both halves, confidences go from 1/6 up to 1.
    fixtures = [
        (range(1, 7), self.version1, self.text, 0),
        (range(7, 13), self.version2, self.cover, 6),
    ]
    for indexes, version, ml_class, offset in fixtures:
        for i in indexes:
            self.corpus.elements.get(name=f'elt_{i}').classifications.create(
                worker_version=version,
                ml_class=ml_class,
                confidence=(i - offset) / 6,
                # Give every even element a high confidence classification
                high_confidence=i % 2 == 0,
            )

    text_id = str(self.text.id)
    cover_id = str(self.cover.id)

    urls = [
        reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}),
        reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}),
        reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
    ]
    # Pairs of (query parameters, names of the elements expected in the results)
    scenarios = [
        (
            {'class_id': text_id},
            ['elt_1', 'elt_2', 'elt_3', 'elt_4', 'elt_5', 'elt_6'],
        ),
        (
            {'classification_confidence': '0.5'},
            ['elt_3', 'elt_9'],
        ),
        (
            {'classification_confidence': '0.51'},
            [],
        ),
        (
            {'classification_confidence': '0.5', 'classification_confidence_operator': 'gt'},
            ['elt_4', 'elt_5', 'elt_6', 'elt_10', 'elt_11', 'elt_12'],
        ),
        (
            {'classification_confidence': '0.7', 'classification_confidence_operator': 'lte'},
            ['elt_1', 'elt_2', 'elt_3', 'elt_4', 'elt_7', 'elt_8', 'elt_9', 'elt_10'],
        ),
        (
            {'class_id': cover_id, 'classification_confidence': '0.7'},
            [],
        ),
        (
            {'classification_high_confidence': True},
            ['elt_2', 'elt_4', 'elt_6', 'elt_8', 'elt_10', 'elt_12'],
        ),
        (
            {'classification_high_confidence': False},
            ['elt_1', 'elt_3', 'elt_5', 'elt_7', 'elt_9', 'elt_11'],
        ),
        (
            {'class_id': text_id, 'classification_high_confidence': True},
            ['elt_2', 'elt_4', 'elt_6'],
        ),
        (
            {'class_id': text_id, 'classification_high_confidence': False},
            ['elt_1', 'elt_3', 'elt_5'],
        ),
    ]

    for url in urls:
        for params, expected_names in scenarios:
            with self.subTest(endpoint=url, **params):
                response = self.client.get(
                    url,
                    data={**params, 'type': self.classified.slug},
                )
                self.assertEqual(response.status_code, status.HTTP_200_OK, response.json())
                returned_names = [result['name'] for result in response.json()['results']]
                self.assertCountEqual(returned_names, expected_names)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment