Skip to content
Snippets Groups Projects
Commit c844adec authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Merge branch 'classes-filters' into 'master'

Class filter on ListElement{s,Parents,Children}

See merge request !558
parents 0a8c778f d4e5a88c
No related branches found
No related tags found
1 merge request!558Class filter on ListElement{s,Parents,Children}
......@@ -77,6 +77,13 @@ class ElementsList(CorpusACLMixin, ListAPIView):
'required': False,
'schema': {'type': 'string'},
},
{
'name': 'best_class',
'in': 'query',
'description': 'Filter elements by best class id',
'required': False,
'schema': {'type': 'string', 'format': 'uuid'},
},
{
'name': 'hidden',
'in': 'query',
......@@ -140,6 +147,17 @@ class ElementsList(CorpusACLMixin, ListAPIView):
filtered_queryset = super().filter_queryset(queryset).filter(**filters)
class_filter = self.request.query_params.get('best_class')
if class_filter is not None:
filtered_queryset = filtered_queryset.filter(
classifications__in=Classification.objects.all()
.filter(
Q(state=ClassificationState.Validated) | Q(best=True),
ml_class_id=class_filter
)
.exclude(state=ClassificationState.Rejected)
)
best_classes = self.request.query_params.get('best_classes')
if best_classes and best_classes.lower() not in ('false', '0'):
filtered_queryset = filtered_queryset.prefetch_related(best_classifications_prefetch)
......@@ -316,6 +334,13 @@ class ElementParents(ListAPIView):
'default': False
}
},
{
'name': 'best_class',
'in': 'query',
'description': 'Filter elements by best class id',
'required': False,
'schema': {'type': 'string', 'format': 'uuid'},
},
{
'name': 'hidden',
'in': 'query',
......@@ -377,6 +402,13 @@ class ElementParents(ListAPIView):
if best_classes and best_classes.lower() not in ('false', '0'):
prefetch_related_lookups += (best_classifications_prefetch, )
class_filter = self.request.query_params.get('best_class')
if class_filter is not None:
filters['classifications__in'] = Classification.objects.all().filter(
Q(state=ClassificationState.Validated) | Q(best=True),
ml_class_id=class_filter
).exclude(state=ClassificationState.Rejected)
if recursive_param is None or recursive_param.lower() in ('false', '0'):
# List direct parents only: elements whose IDs are the last in the element's paths
return Element.objects.filter(
......@@ -387,6 +419,7 @@ class ElementParents(ListAPIView):
).prefetch_related(*prefetch_related_lookups)
parents = Element.objects.get_ascending(self.kwargs['pk'], **filters)
prefetch_related_objects(parents, *prefetch_related_lookups)
return parents
......@@ -431,6 +464,13 @@ class ElementChildren(ListAPIView):
'default': False
}
},
{
'name': 'best_class',
'in': 'query',
'description': 'Filter elements by best class id',
'required': False,
'schema': {'type': 'string', 'format': 'uuid'},
},
{
'name': 'hidden',
'in': 'query',
......@@ -500,6 +540,13 @@ class ElementChildren(ListAPIView):
# so we append the __last filter to it.
filters['paths__path__last'] = self.kwargs['pk']
class_filter = self.request.query_params.get('best_class')
if class_filter is not None:
filters['classifications__in'] = Classification.objects.all().filter(
Q(state=ClassificationState.Validated) | Q(best=True),
ml_class_id=class_filter
).exclude(state=ClassificationState.Rejected)
return super().filter_queryset(
Element.objects.get_descending(self.kwargs['pk'])
).filter(**filters).prefetch_related(*prefetch_related_lookups)
......
......@@ -14,6 +14,26 @@ class TestClasses(FixtureAPITestCase):
cls.cover = MLClass.objects.create(name='cover', corpus=cls.corpus)
cls.classified = cls.corpus.types.create(slug='classified', folder=True)
def populate_classified_elements(self):
self.folder_type = self.corpus.types.create(slug='folder', folder=True)
self.parent = self.corpus.elements.create(type=self.folder_type)
self.common_children = self.corpus.elements.create(type=self.folder_type)
for elt_num in range(1, 13):
elt = Element.objects.create(
name='elt_{}'.format(elt_num),
type=self.classified,
corpus_id=self.corpus.id
)
elt.add_parent(self.parent)
self.common_children.add_parent(elt)
for ml_class, score in zip((self.text, self.cover), (.7, .99)):
elt.classifications.create(
source_id=DataSource.objects.get(slug='test').id,
ml_class_id=ml_class.id,
confidence=score,
best=bool(score == .99)
)
def test_list(self):
"""
Test listing results alpha-ordered
......@@ -168,26 +188,6 @@ class TestClasses(FixtureAPITestCase):
response = self.client.delete(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}))
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def populate_classified_elements(self):
self.folder_type = self.corpus.types.create(slug='folder', folder=True)
self.parent = self.corpus.elements.create(type=self.folder_type)
self.common_children = self.corpus.elements.create(type=self.folder_type)
for elt_num in range(1, 13):
elt = Element.objects.create(
name='elt_{}'.format(elt_num),
type=self.classified,
corpus_id=self.corpus.id
)
elt.add_parent(self.parent)
self.common_children.add_parent(elt)
for ml_class, score in zip((self.text, self.cover), (.7, .99)):
elt.classifications.create(
source_id=DataSource.objects.get(slug='test').id,
ml_class_id=ml_class.id,
confidence=score,
best=bool(score == .99)
)
def test_list_elements_db_queries(self):
self.populate_classified_elements()
with self.assertNumQueries(5):
......@@ -344,3 +344,50 @@ class TestClasses(FixtureAPITestCase):
list(map(lambda c: (c['source']['slug'], c['confidence']), elt['best_classes'])),
[('test', .99)]
)
def test_class_filter_list_elements(self):
self.populate_classified_elements()
element = Element.objects.filter(type=self.classified.id).first()
element.classifications.create(
source_id=DataSource.objects.create(type=MLToolType.NER, slug='ner', internal=False).id,
ml_class_id=self.text.id,
confidence=.1337,
best=True,
)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements'),
data={'type': self.classified.slug, 'best_class': str(self.text.id)}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['count'], 1)
self.assertEqual(data['results'][0]['id'], str(element.id))
def test_class_filter_list_parents(self):
self.populate_classified_elements()
parent = Element.objects.get_ascending(self.common_children.id)[-1]
parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
data={'type': self.classified.slug, 'best_class': str(self.text.id)}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['count'], 1)
self.assertEqual(data['results'][0]['id'], str(parent.id))
def test_class_filter_list_children(self):
self.populate_classified_elements()
child = Element.objects.filter(type=self.classified.id).first()
child.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}),
data={'type': self.classified.slug, 'best_class': str(self.text.id)}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['count'], 1)
self.assertEqual(data['results'][0]['id'], str(child.id))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment