Skip to content
Snippets Groups Projects
Commit 0b941c06 authored by Valentin Rigal's avatar Valentin Rigal Committed by Bastien Abadie
Browse files

Revert "Avoid retrieving related elts in case of invalid corpus"

This reverts commit f02f68d6fe700d6c46fdd37acf2e1e8223b0e5a1.
parent 975dd9f4
No related branches found
No related tags found
1 merge request: !1196 "Move methods to list elements of a process in dataimport.models"
......@@ -47,7 +47,7 @@ from arkindex.dataimport.serializers.imports import (
)
from arkindex.dataimport.serializers.workers import RepositorySerializer, WorkerSerializer, WorkerVersionSerializer
from arkindex.documents.api.elements import ElementsListMixin
from arkindex.documents.models import ClassificationState, Corpus, Element, ElementType
from arkindex.documents.models import Corpus, ElementType
from arkindex.project.fields import ArrayRemove
from arkindex.project.mixins import CorpusACLMixin, CustomPaginationViewMixin, DeprecatedMixin, SelectionMixin
from arkindex.project.openapi import AutoSchema
......@@ -907,110 +907,18 @@ class ListProcessElements(CustomPaginationViewMixin, ListAPIView):
'tags': ['imports']
}
def retrieve_elements(self, dataimport):
    """
    Build the queryset of elements targeted by a process.

    Resolution order:
    1. A single element explicitly attached to the process (``dataimport.element``)
    2. A selection of elements (``dataimport.elements``)
    3. Otherwise, every element of the process corpus

    Raises a DRF ValidationError when the targeted element(s) do not belong
    to the same corpus as the process.
    Returns a values queryset limited to the fields the serializer needs.
    """
    elements = None
    if dataimport.element:
        # Handle a single element
        if dataimport.element.corpus_id != dataimport.corpus_id:
            raise ValidationError({
                'element_id': [
                    'Element {} is not part of corpus {}'.format(
                        dataimport.element.id, dataimport.corpus.id
                    )
                ]
            })
        elements = Element.objects.filter(id=dataimport.element.id)
    elif dataimport.elements.exists():
        # Handle a selection of elements
        # The list() is necessary to make a unique SQL request to check length & content
        # It's more performant, as we usually have only 1 corpus in here
        elements = dataimport.elements.all()
        corpus = list(elements.values_list('corpus_id', flat=True).distinct())
        # Check all elements are in the same corpus as the process
        if len(corpus) != 1 or corpus[0] != dataimport.corpus_id:
            raise ValidationError({
                'elements': [
                    'Some elements on this process are not part of corpus {}'.format(
                        dataimport.corpus.id
                    )]
            })
    # Apply base filters as early as possible to trim results
    base_filters = self.get_filters(dataimport)
    if elements is not None and dataimport.load_children:
        # Load all the children elements whose path contains the pre-selected elements
        # Those children are appended to the pre-selection
        # NOTE(review): a lazy map() iterator is passed to the `overlap` lookup;
        # presumably the ORM materializes it — confirm this against the Django version in use
        elements |= Element.objects.filter(
            paths__path__overlap=map(str, elements.values_list('id', flat=True)),
            **base_filters
        )
    # Load the full corpus, only when elements has not been populated before
    if elements is None:
        # Handle all elements of the process corpus
        elements = Element.objects.filter(corpus=dataimport.corpus_id)
    # Filter elements depending on process properties
    elements = elements.filter(**base_filters)
    class_filters = self.get_classifications_filters(dataimport)
    if class_filters is not None:
        # distinct() is needed because classification joins may yield duplicate rows
        elements = elements.filter(class_filters).distinct()
    # Only retrieve necessary values for the serializer
    return elements.values('id', 'type__slug', 'name')
def get_filters(self, dataimport):
    """
    Build the base keyword filters applied to Element querysets
    for this process.

    Always restricts to the process corpus; optionally filters on a
    name substring and on the selected element type.
    """
    element_filters = {'corpus_id': dataimport.corpus_id}

    if dataimport.name_contains:
        element_filters['name__contains'] = dataimport.name_contains

    if dataimport.element_type:
        element_filters['type_id'] = dataimport.element_type_id
    else:
        # Without a selected type, restrict the type join to the process corpus.
        # This prevents memory from exploding when no type is selected.
        element_filters['type__corpus_id'] = dataimport.corpus_id

    return element_filters
def get_classifications_filters(self, dataimport):
    """
    Translate the process ``best_class`` attribute into a Q filter
    on element classifications, or None when no filtering is requested.
    """
    if dataimport.best_class is None:
        return
    # An element has a "best class" when one of its classifications
    # is validated OR flagged as high confidence.
    has_best_class = (
        Q(classifications__state=ClassificationState.Validated)
        | Q(classifications__high_confidence=True)
    )
    if dataimport.best_class in ('false', '0'):
        # Explicitly list elements without any best class
        return ~has_best_class
    try:
        ml_class_id = UUID(dataimport.best_class)
    except (TypeError, ValueError):
        # Any other unparseable value falls back to all best classifications
        return has_best_class
    # A valid UUID restricts best classes to one specific ML class
    return has_best_class & Q(classifications__ml_class_id=ml_class_id)
def get_queryset(self):
    """
    List the elements targeted by a process, for processes over a corpus
    readable by the requesting user.

    The diff-scrape of this block interleaved the pre- and post-commit
    statements (duplicate queryset arguments and Q filters, plus both the
    old ``retrieve_elements`` return and the new ``list_elements`` call);
    this is the reconstructed, coherent post-commit version.
    """
    dataimport = get_object_or_404(
        # Avoid stale read on newly created dataimports
        DataImport.objects.using('default'),
        Q(pk=self.kwargs['pk'])
        & Q(
            Q(corpus_id__isnull=False)
            & Q(corpus_id__in=Corpus.objects.readable(self.request.user).values('id'))
        )
    )
    try:
        # Only retrieve the values required by the serializer
        return dataimport.list_elements().values('id', 'type__slug', 'name')
    except AssertionError as e:
        # list_elements() performs corpus-consistency checks with assertions;
        # surface them to the API client as validation errors
        raise ValidationError({'__all__': [str(e)]})
......@@ -2,16 +2,19 @@ import shlex
import urllib.parse
import uuid
from os import path
from uuid import UUID
import yaml
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import Q
from django.utils.functional import cached_property
from enumfields import Enum, EnumField
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.providers import get_provider, git_providers
from arkindex.documents.models import ClassificationState, Element
from arkindex.project.aws import S3FileMixin, S3FileStatus
from arkindex.project.fields import ArrayField
from arkindex.project.models import IndexableModel
......@@ -93,6 +96,88 @@ class DataImport(IndexableModel):
else:
return self.workflow.state
def _get_filters(self):
    """
    Build the base keyword filters applied to Element querysets
    for this process.

    Always restricts to the process corpus; optionally filters on a
    name substring and on the selected element type.
    """
    element_filters = {'corpus_id': self.corpus_id}

    if self.name_contains:
        element_filters['name__contains'] = self.name_contains

    if self.element_type:
        element_filters['type_id'] = self.element_type_id
    else:
        # Without a selected type, restrict the type join to the process corpus.
        # This prevents memory from exploding when no type is selected.
        element_filters['type__corpus_id'] = self.corpus_id

    return element_filters
def _get_classifications_filters(self):
    """
    Translate the process ``best_class`` attribute into a Q filter
    on element classifications, or None when no filtering is requested.
    """
    if self.best_class is None:
        return
    # An element has a "best class" when one of its classifications
    # is validated OR flagged as high confidence.
    has_best_class = (
        Q(classifications__state=ClassificationState.Validated)
        | Q(classifications__high_confidence=True)
    )
    if self.best_class in ('false', '0'):
        # Explicitly list elements without any best class
        return ~has_best_class
    try:
        ml_class_id = UUID(self.best_class)
    except (TypeError, ValueError):
        # Any other unparseable value falls back to all best classifications
        return has_best_class
    # A valid UUID restricts best classes to one specific ML class
    return has_best_class & Q(classifications__ml_class_id=ml_class_id)
def list_elements(self):
    """
    Return a queryset of elements involved in this process.

    Resolution order:
    1. A single element explicitly attached to the process (``self.element``)
    2. A selection of elements (``self.elements``)
    3. Otherwise, every element of the process corpus

    Only Elements and Workers processes target elements; any other mode
    yields an empty queryset.

    NOTE(review): corpus-consistency checks use ``assert`` — these are
    stripped when Python runs with ``-O``; callers are expected to catch
    AssertionError and turn it into an API error.
    """
    if self.mode not in (DataImportMode.Elements, DataImportMode.Workers):
        return Element.objects.none()
    elements = None
    if self.element:
        # Assert the element has the same corpus as the process
        assert self.element.corpus_id == self.corpus_id, \
            f'Element {self.element.id} is not part of corpus {self.corpus_id}'
        # Handle a single element
        elements = Element.objects.filter(id=self.element.id)
    elif self.elements.exists():
        # Handle a selection of elements
        elements = self.elements.all()
        # Check all elements are in the same corpus as the process
        # list() forces a single SQL query checking both length & content
        corpus = list(elements.values_list('corpus_id', flat=True).distinct())
        assert len(corpus) == 1 and corpus[0] == self.corpus_id, \
            f'Some elements on this process are not part of corpus {self.corpus_id}'
    if elements is not None and self.load_children:
        # Load all the children elements whose path contains the pre-selected elements
        # Those children are appended to the pre-selection
        # NOTE(review): a lazy map() iterator is passed to the `overlap` lookup;
        # presumably the ORM materializes it — confirm against the Django version in use
        elements |= Element.objects.filter(
            paths__path__overlap=map(str, elements.values_list('id', flat=True)),
        )
    # Load the full corpus, only when elements has not been populated before
    if elements is None:
        # Handle all elements of the process corpus
        elements = Element.objects.filter(corpus=self.corpus_id)
    # Filter elements depending on process properties
    elements = elements.filter(**self._get_filters())
    class_filters = self._get_classifications_filters()
    if class_filters is not None:
        # Distinct is required because multiple classes may match the filter
        elements = elements.filter(class_filters).distinct()
    return elements
def build_workflow(self, chunks=None, thumbnails=False, corpus_id=None):
'''
Create a ponos workflow with a recipe according to configuration
......@@ -220,7 +305,6 @@ class DataImport(IndexableModel):
def start(self, chunks=None, thumbnails=False, corpus_id=None):
    """
    Build and persist a ponos workflow for this process, then start it.

    The arguments are forwarded unchanged to ``build_workflow``.
    """
    self.workflow = self.build_workflow(chunks, thumbnails, corpus_id)
    # Persist the workflow assignment before starting it
    self.save()
    # Start the associated workflow
    self.workflow.start()
def retry(self):
......
......@@ -192,19 +192,25 @@ class TestProcessElements(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_filter_elements_wrong_corpus(self):
    """
    Selected elements must be part of the same corpus as the process.

    The diff-scrape of this block interleaved the pre- and post-commit
    statements (two response assignments and two response-body assertions);
    this is the reconstructed post-commit version.
    """
    # Move the process to a corpus the selected elements do not belong to
    self.dataimport.corpus = Corpus.objects.create(name='private')
    self.dataimport.save()
    self.dataimport.elements.add(self.page_1.id, self.folder_2.id)
    self.client.force_login(self.superuser)
    with self.assertNumQueries(5):
        response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}))
    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        '__all__': [f'Some elements on this process are not part of corpus {self.dataimport.corpus_id}']
    })
def test_filter_elements_multiple_corpus(self):
"""
Selected elements are part of multiple corpora
This should not happen in usual situations
"""
corpus = Corpus.objects.create(name='private')
element = Element.objects.create(
corpus=corpus,
......@@ -214,24 +220,26 @@ class TestProcessElements(FixtureAPITestCase):
self.dataimport.elements.add(element.id, self.page_1.id)
self.client.force_login(self.superuser)
response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
data = response.json()
self.assertEqual(data, {
'elements': [f'Some elements on this process are not part of corpus {self.dataimport.corpus.id}']
with self.assertNumQueries(5):
response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}))
self.assertDictEqual(response.json(), {
'__all__': [f'Some elements on this process are not part of corpus {self.dataimport.corpus_id}']
})
def test_filter_element_wrong_corpus(self):
    """
    Element and corpus are different.
    This should not happen in usual situations.

    The diff-scrape of this block interleaved the pre- and post-commit
    statements (two response assignments and two response-body assertions);
    this is the reconstructed post-commit version.
    """
    # Move the process to a corpus its single element does not belong to
    self.dataimport.corpus = Corpus.objects.create(name='private')
    self.dataimport.element = self.page_1
    self.dataimport.save()
    self.client.force_login(self.superuser)
    with self.assertNumQueries(4):
        response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}))
    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        '__all__': [f'Element {self.page_1.id} is not part of corpus {self.dataimport.corpus_id}']
    })
def test_filter_name(self):
......
......@@ -442,9 +442,10 @@ class TestWorkflows(FixtureAPITestCase):
self.assertIsNone(dataimport_2.workflow)
self.client.force_login(self.user)
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
)
with self.assertNumQueries(35):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['id'], str(dataimport_2.id))
dataimport_2.refresh_from_db()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment