Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: arkindex/backend

Commits on Source (26)
Showing changes with 1319 additions and 635 deletions
1.6.1-rc1
1.6.2-alpha1
......@@ -88,6 +88,7 @@ from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import BulkMap
from arkindex.project.triggers import (
corpus_delete,
element_delete,
element_trash,
move_element,
selection_worker_results_delete,
......@@ -944,6 +945,7 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
def filter_queryset(self, queryset):
queryset = queryset \
.filter(self.get_filters()) \
.select_related("worker_run") \
.prefetch_related(*self.get_prefetch()) \
.order_by(*self.get_order_by())
......@@ -1231,12 +1233,13 @@ class ElementRetrieve(ACLMixin, RetrieveUpdateDestroyAPIView):
corpora = Corpus.objects.readable(self.request.user)
queryset = Element.objects.filter(corpus__in=corpora)
if self.request and self.request.method == "DELETE":
# Only include corpus and creator for ACL check and ID for deletion
# Only include the corpus and creator for ACL check,
# the ID for the deletion task, and the type and name for the task's description.
return (
queryset
.select_related("corpus")
.select_related("corpus", "type")
.annotate(has_dataset=Exists(DatasetElement.objects.filter(element_id=OuterRef("pk"))))
.only("id", "creator_id", "corpus")
.only("id", "creator_id", "corpus", "type__display_name", "name")
)
return (
......@@ -1274,12 +1277,12 @@ class ElementRetrieve(ACLMixin, RetrieveUpdateDestroyAPIView):
return context
def delete(self, request, *args, **kwargs):
self.check_object_permissions(self.request, self.get_object())
element = self.get_object()
self.check_object_permissions(self.request, element)
queryset = Element.objects.filter(id=self.kwargs["pk"])
delete_children = self.request.query_params.get("delete_children", "false").lower() not in ("false", "0")
element_trash(queryset, user_id=self.request.user.id, delete_children=delete_children)
element_delete(element, user_id=self.request.user.id, delete_children=delete_children)
return Response(status=status.HTTP_204_NO_CONTENT)
......
......@@ -53,11 +53,14 @@ class CorpusExportAPIView(ListCreateAPIView):
return corpus
def get_queryset(self):
return CorpusExport \
.objects \
.filter(corpus=self.corpus) \
.select_related("user") \
return (
CorpusExport.objects
# Avoid a stale read when an export has just been created
.using("default")
.filter(corpus=self.corpus)
.select_related("user")
.order_by("-created")
)
def get_serializer_context(self):
context = super().get_serializer_context()
......
......@@ -11,6 +11,7 @@ from pathlib import Path
from django.conf import settings
from django.core.mail import send_mail
from django.db import connections
from django.db.utils import InterfaceError, OperationalError
from django.template.loader import render_to_string
from django.utils.text import slugify
from django_rq import job
......@@ -118,7 +119,35 @@ def send_email(subject, template_name, corpus_export, **context):
logger.error(f"Failed to send email to {corpus_export.user.email}")
def update_state(corpus_export: CorpusExport, state: CorpusExportState):
"""
Make sure the corpus export instance from the default DB is still available for updates.
The DB connection may drop after a very long export, especially with remote DB exports.
"""
try:
corpus_export.state = state
corpus_export.save()
except (InterfaceError, OperationalError) as e:
logger.warning(f"Database connection has been lost, retrying: {e}")
connections["default"].connection = None
connections["default"].connect()
corpus_export.refresh_from_db(using="default")
corpus_export.state = state
corpus_export.save()
@job("high", timeout=settings.RQ_TIMEOUTS["export_corpus"])
def local_export(corpus_export: CorpusExport) -> None:
assert corpus_export.source == "default"
export_corpus(corpus_export)
@job("export", timeout=settings.RQ_TIMEOUTS["export_corpus"])
def remote_export(corpus_export: CorpusExport) -> None:
assert corpus_export.source != "default"
export_corpus(corpus_export)
def export_corpus(corpus_export: CorpusExport) -> None:
_, db_path = tempfile.mkstemp(suffix=".db")
try:
......@@ -181,8 +210,8 @@ def export_corpus(corpus_export: CorpusExport) -> None:
},
)
corpus_export.state = CorpusExportState.Done
corpus_export.save()
# Safely update state with auto-reconnect to db
update_state(corpus_export, CorpusExportState.Done)
send_email(
"Arkindex project export completed",
......@@ -190,8 +219,9 @@ def export_corpus(corpus_export: CorpusExport) -> None:
corpus_export,
)
except Exception as e:
corpus_export.state = CorpusExportState.Failed
corpus_export.save()
# Safely update state with auto-reconnect to db
update_state(corpus_export, CorpusExportState.Failed)
send_email(
"Arkindex project export failed",
"export_error.html",
......
This diff is collapsed.
......@@ -14,7 +14,6 @@ from arkindex.process.models import (
ProcessMode,
Repository,
Worker,
WorkerRun,
WorkerType,
WorkerVersion,
WorkerVersionState,
......@@ -100,10 +99,29 @@ class Command(BaseCommand):
gpu_worker_type = WorkerType.objects.create(slug="worker", display_name="Worker requiring a GPU")
import_worker_type = WorkerType.objects.create(slug="import", display_name="Import")
custom_worker_type = WorkerType.objects.create(slug="custom", display_name="Custom")
init_type = WorkerType.objects.create(slug="init", display_name="Elements Initialisation")
farm = Farm.objects.create(name="Wheat farm")
farm.memberships.create(user=user, level=Role.Guest.value)
# Create the elements initialisation worker version
init_worker = WorkerVersion.objects.create(
worker=Worker.objects.create(
name="Elements Initialisation Worker",
slug="initialisation",
type=init_type,
),
revision=None,
version=1,
configuration={
"docker": {
"command": "worker-init-elements"
}
},
state=WorkerVersionState.Available,
docker_image_iid="registry.gitlab.teklia.com/arkindex/workers/init-elements:latest"
)
# Create some workers with available versions
recognizer_worker = WorkerVersion.objects.create(
worker=worker_repo.workers.create(
......@@ -194,14 +212,11 @@ class Command(BaseCommand):
mode=ProcessMode.Local,
creator=superuser,
)
WorkerRun.objects.create(
process=user_local_process,
user_local_process.worker_runs.create(
version=custom_version,
parents=[],
)
WorkerRun.objects.create(
process=superuser_local_process,
superuser_local_process.worker_runs.create(
version=custom_version,
parents=[],
)
......@@ -223,9 +238,13 @@ class Command(BaseCommand):
corpus=corpus,
name="Process fixture",
)
init_worker_run = process.worker_runs.create(
version=init_worker,
parents=[]
)
dla_worker_run = process.worker_runs.create(
version=dla_worker,
parents=[],
parents=[init_worker_run.id],
)
process.worker_runs.create(
version=recognizer_worker,
......@@ -276,10 +295,10 @@ class Command(BaseCommand):
dataset_2 = corpus.datasets.create(name="Second Dataset", description="dataset number two", creator=user)
# Create their sets
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_1.id) for name in ["training", "validation", "test"]
DatasetSet(name=name, dataset_id=dataset_1.id) for name in ["train", "dev", "test"]
)
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_2.id) for name in ["training", "validation", "test"]
DatasetSet(name=name, dataset_id=dataset_2.id) for name in ["train", "dev", "test"]
)
# Create 2 volumes
......
......@@ -11,7 +11,7 @@ from django.contrib.postgres.indexes import GinIndex
from django.core.exceptions import ValidationError
from django.core.validators import MaxValueValidator, MinValueValidator, RegexValidator
from django.db import connections, models, transaction
from django.db.models import Deferrable, Q
from django.db.models import Count, Deferrable, Q, Window
from django.db.models.functions import Cast, Least
from django.urls import reverse
from django.utils.functional import cached_property
......@@ -645,6 +645,82 @@ class Element(IndexableModel):
# Now that the child's descendants are handled, we can clean up the child's own paths.
child.paths.filter(path__last=self.id).delete()
@transaction.atomic
def remove_children(self):
"""
Remove this parent element from all of its children at once
"""
# Fetch two values that we will need to detect which queries to run, and build them:
# - How many paths this parent has, so we can tell if we need to delete paths and not only update them
# - One path on the parent, so that we can perform the update.
# In the rare edge case where the element has zero paths, this returns nothing, so we'll act as if there was a top-level path.
first_parent_path, parent_paths_count = (
self.paths
.using("default")
.annotate(count=Window(Count("*")))
.values_list("path", "count")
# Unnecessary for this algorithm to work, but simplifies unit testing a lot.
.order_by("id")
.first()
) or ([], 0)
with connections["default"].cursor() as cursor:
if parent_paths_count > 1:
# Delete all child paths that are not the first parent path.
# If we tried to also update those, we would end up with duplicates.
cursor.execute(
"""
DELETE FROM documents_elementpath child_paths
USING documents_elementpath parent_paths
WHERE
parent_paths.element_id = %(parent_id)s
AND parent_paths.path <> %(first_parent_path)s
AND child_paths.path @> (parent_paths.path || %(parent_id)s)
""",
{
"parent_id": self.id,
"first_parent_path": first_parent_path,
},
)
# For children that have other parents, delete all the paths for this parent.
# The paths for the other parents will preserve the structure of the child's descendants,
# so we will have no updates to make.
cursor.execute(
"""
DELETE FROM documents_elementpath parent_paths
USING documents_elementpath other_paths
WHERE
parent_paths.path && ARRAY[%(parent_id)s]
AND other_paths.element_id = parent_paths.element_id
AND NOT other_paths.path && ARRAY[%(parent_id)s]
""",
{"parent_id": self.id},
)
# For the child elements that had no other parent, we have one path starting with `first_parent_path`.
# We strip that from the children's paths and their descendants',
# meaning each child will have a top-level path left (empty array)
# and its descendants will remain descendants of this child.
# As an extra precaution, we will check that the path really starts with this prefix before updating,
# since @> is really a set operation ([1,2,3,4] @> [3,1,4] is true).
prefix = first_parent_path + [self.id]
# TODO: In Django 5.1, rewrite this without raw SQL by using F() slicing
# See https://docs.djangoproject.com/en/dev/ref/models/expressions/#slicing-f-expressions
cursor.execute(
"""
UPDATE documents_elementpath
SET path = path[%(prefix_size)s + 1:]
WHERE
path @> %(prefix)s
AND path[:%(prefix_size)s] = %(prefix)s
""",
{
"prefix": prefix,
"prefix_size": len(prefix),
},
)
@cached_property
def thumbnail(self):
from arkindex.images.models import Thumbnail # Prevent circular imports
......
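The comment above notes that @> is a set-style containment check, which is why the UPDATE also verifies the path prefix. As a reading aid, here is a minimal Python sketch (illustrative only, not part of the diff; the helper names are made up) modelling the three PostgreSQL array operators used in remove_children on plain lists:

# Illustrative Python equivalents of the array operators used in remove_children.
def contains(a, b):
    # a @> b: every element of b appears somewhere in a (set semantics, order ignored)
    return set(b) <= set(a)

def overlaps(a, b):
    # a && b: the two arrays share at least one element
    return bool(set(a) & set(b))

def has_prefix(path, prefix):
    # path[:len(prefix)] = prefix: a strict, order-sensitive prefix check
    return path[:len(prefix)] == prefix

# [1, 2, 3, 4] @> [3, 1, 4] is true even though [3, 1, 4] is not a prefix of [1, 2, 3, 4],
# which is why the UPDATE checks both conditions before stripping the prefix.
assert contains([1, 2, 3, 4], [3, 1, 4])
assert not has_prefix([1, 2, 3, 4], [3, 1, 4])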
......@@ -100,6 +100,21 @@ def corpus_delete(corpus_id: str) -> None:
logger.info(f"Deleted corpus {corpus_id}")
@job("high", timeout=settings.RQ_TIMEOUTS["element_trash"])
def element_delete(element: Element, delete_children: bool) -> None:
"""
Wrapper around the element_trash task that removes the element from its children's paths
when not deleting recursively.
"""
if not delete_children:
element.remove_children()
element_trash(
Element.objects.filter(id=element.id),
delete_children=delete_children,
)
@job("high", timeout=settings.RQ_TIMEOUTS["element_trash"])
def element_trash(queryset: ElementQuerySet, delete_children: bool) -> None:
queryset.trash(delete_children=delete_children)
......
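For context, the element_delete trigger wrapper that the API view calls lives in arkindex/project/triggers.py and is not part of this diff. A hedged sketch of its likely shape, reconstructed from the view above and the test expectations below (the exact signature and import path are assumptions):

# Assumed sketch of the trigger in arkindex/project/triggers.py (not shown in this diff):
from arkindex.documents import tasks as documents_tasks  # assumed import path
from arkindex.documents.models import Element

def element_delete(element: Element, user_id=None, delete_children: bool = False) -> None:
    prefix = "Recursive deletion" if delete_children else "Deletion"
    documents_tasks.element_delete.delay(
        element=element,
        delete_children=delete_children,
        user_id=user_id,
        description=f"{prefix} of {element.type.display_name}: {element.name}",
    )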
......@@ -136,6 +136,7 @@ class TestLoadExport(FixtureTestCase):
reco_run = reco_version.worker_runs.get(process__mode=ProcessMode.Workers)
dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
init_version = WorkerVersion.objects.get(worker__slug="initialisation")
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
......@@ -194,6 +195,10 @@ class TestLoadExport(FixtureTestCase):
# Clean up WorkerRuns from local processes so we can check new worker runs are created
WorkerRun.objects.filter(process__mode=ProcessMode.Local).delete()
# Delete the elements initialisation worker run before calling dumpdata as it is not exported
init_version = WorkerVersion.objects.get(worker__slug="initialisation")
init_version.worker_runs.get().delete()
# Call dumpdata command before the deletion
# Ignore django_rq as it uses a fake database table to insert itself into Django's permissions system
_, dump_path_before = tempfile.mkstemp(suffix=".json")
......
......@@ -303,7 +303,7 @@ class TestChildrenElements(FixtureAPITestCase):
worker_run_child = Element.objects.create(name="bob", type=self.page.type, worker_run=self.worker_run, worker_version=self.worker_version, corpus=self.corpus)
worker_run_child.add_parent(self.vol)
with self.assertNumQueries(6):
with self.assertNumQueries(5):
response = self.client.get(
reverse("api:elements-children", kwargs={"pk": str(self.vol.id)}),
data={"worker_run": str(self.worker_run.id)}
......
......@@ -46,7 +46,7 @@ class TestDestroyElements(FixtureAPITestCase):
{"detail": "You do not have admin access to this element."}
)
@patch("arkindex.project.triggers.documents_tasks.element_trash.delay")
@patch("arkindex.project.triggers.documents_tasks.element_delete.delay")
def test_element_destroy(self, delay_mock):
self.client.force_login(self.user)
castle_story = self.corpus.elements.create(
......@@ -61,14 +61,14 @@ class TestDestroyElements(FixtureAPITestCase):
self.assertEqual(delay_mock.call_count, 1)
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop("queryset")), list(self.corpus.elements.filter(id=castle_story.id)))
self.assertDictEqual(kwargs, {
"element": castle_story,
"delete_children": False,
"user_id": self.user.id,
"description": "Element deletion",
"description": "Deletion of Volume: Castle story",
})
@patch("arkindex.project.triggers.documents_tasks.element_trash.delay")
@patch("arkindex.project.triggers.documents_tasks.element_delete.delay")
def test_element_destroy_delete_children(self, delay_mock):
self.client.force_login(self.user)
castle_story = self.corpus.elements.create(
......@@ -77,7 +77,12 @@ class TestDestroyElements(FixtureAPITestCase):
)
self.assertTrue(self.corpus.elements.filter(id=castle_story.id).exists())
for delete_children in [True, False]:
cases = [
(True, "Recursive deletion of Volume: Castle story"),
(False, "Deletion of Volume: Castle story"),
]
for delete_children, expected_description in cases:
with self.subTest(delete_children=delete_children):
delay_mock.reset_mock()
......@@ -91,14 +96,14 @@ class TestDestroyElements(FixtureAPITestCase):
self.assertEqual(delay_mock.call_count, 1)
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop("queryset")), list(self.corpus.elements.filter(id=castle_story.id)))
self.assertDictEqual(kwargs, {
"element": castle_story,
"delete_children": delete_children,
"user_id": self.user.id,
"description": "Element deletion",
"description": expected_description,
})
@patch("arkindex.project.triggers.documents_tasks.element_trash.delay")
@patch("arkindex.project.triggers.documents_tasks.element_delete.delay")
def test_element_destroy_creator(self, delay_mock):
"""
An element's creator can delete the element if they have write access
......@@ -117,11 +122,11 @@ class TestDestroyElements(FixtureAPITestCase):
self.assertEqual(delay_mock.call_count, 1)
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop("queryset")), list(self.private_corpus.elements.filter(id=castle_story.id)))
self.assertDictEqual(kwargs, {
"element": castle_story,
"delete_children": False,
"user_id": self.user.id,
"description": "Element deletion",
"description": "Deletion of Volume: Castle story",
})
@patch("arkindex.project.mixins.has_access", return_value=False)
......@@ -156,7 +161,7 @@ class TestDestroyElements(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You cannot delete an element that is part of a dataset."})
@patch("arkindex.project.triggers.documents_tasks.element_trash.delay")
@patch("arkindex.project.triggers.documents_tasks.element_delete.delay")
def test_non_empty_element(self, delay_mock):
"""
We can now delete a non-empty element
......@@ -169,11 +174,11 @@ class TestDestroyElements(FixtureAPITestCase):
self.assertEqual(delay_mock.call_count, 1)
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop("queryset")), list(self.corpus.elements.filter(id=self.vol.id)))
self.assertDictEqual(kwargs, {
"element": self.vol,
"delete_children": False,
"user_id": self.user.id,
"description": "Element deletion",
"description": "Deletion of Volume: Volume 1",
})
def test_element_trash_dataset_failure(self):
......@@ -542,7 +547,6 @@ class TestDestroyElements(FixtureAPITestCase):
"""
test Element.delete method
"""
self.maxDiff = None
self.client.force_login(self.user)
with self.assertExactQueries("element_dot_delete.sql", params={"id": str(self.vol.id)}):
self.vol.delete()
......
......@@ -469,3 +469,196 @@ class TestEditElementPath(FixtureTestCase):
self.check_parents(elements, "Element", *[[parent_name] for parent_name in parent_names])
self.check_parents(elements, "Child 1", *[[parent_name, "Element"] for parent_name in parent_names])
self.check_parents(elements, "Child 2", *[[parent_name, "Element"] for parent_name in parent_names])
def test_remove_children_no_parent(self):
r"""
Test removing all children from parent A, with no parent paths at all.
K A
\ - -
B C
/ \ \
D E F
/
G
"""
elements = build_tree(
{
"C": "A",
"F": "C",
"B": ["A", "K"],
"D": "B",
"E": "B",
"G": "D",
},
corpus=self.corpus,
type=self.element_type,
)
# Trigger the edge case where there are no paths at all, even top-level paths
elements["A"].paths.all().delete()
self.check_parents(elements, "A")
self.check_parents(elements, "B", ["A"],
["K"])
self.check_parents(elements, "C", ["A"])
self.check_parents(elements, "D", ["A", "B"],
["K", "B"])
self.check_parents(elements, "E", ["A", "B"],
["K", "B"])
self.check_parents(elements, "F", ["A", "C"])
self.check_parents(elements, "G", ["A", "B", "D"],
["K", "B", "D"])
with self.assertExactQueries(
"remove_children_no_parents.sql", params={
# remove_children uses transaction.atomic(), and we are running in a unit test, which is already in a transaction.
# This will cause a savepoint to be created, with a name that is hard to mock.
"savepoint": f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 1}",
"A": elements["A"].id,
}
):
elements["A"].remove_children()
self.check_parents(elements, "K", [])
self.check_parents(elements, "A")
self.check_parents(elements, "B", ["K"])
self.check_parents(elements, "C", [])
self.check_parents(elements, "D", ["K", "B"])
self.check_parents(elements, "E", ["K", "B"])
self.check_parents(elements, "F", ["C"])
self.check_parents(elements, "G", ["K", "B", "D"])
def test_remove_children_single_parent(self):
r"""
Test removing all children from parent A, with one parent path.
X
\
K A
\ - -
B C
/ \ \
D E F
/
G
"""
elements = build_tree(
{
"A": "X",
"C": "A",
"F": "C",
"B": ["A", "K"],
"D": "B",
"E": "B",
"G": "D",
},
corpus=self.corpus,
type=self.element_type,
)
self.check_parents(elements, "X", [])
self.check_parents(elements, "K", [])
self.check_parents(elements, "A", ["X"])
self.check_parents(elements, "B", ["X", "A"],
["K"])
self.check_parents(elements, "C", ["X", "A"])
self.check_parents(elements, "D", ["X", "A", "B"],
["K", "B"])
self.check_parents(elements, "E", ["X", "A", "B"],
["K", "B"])
self.check_parents(elements, "F", ["X", "A", "C"])
self.check_parents(elements, "G", ["X", "A", "B", "D"],
["K", "B", "D"])
with self.assertExactQueries(
"remove_children_single_parent.sql", params={
# remove_children uses transaction.atomic(), and we are running in a unit test, which is already in a transaction.
# This will cause a savepoint to be created, with a name that is hard to mock.
"savepoint": f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 1}",
"A": elements["A"].id,
"X": elements["X"].id,
}
):
elements["A"].remove_children()
self.check_parents(elements, "X", [])
self.check_parents(elements, "K", [])
self.check_parents(elements, "A", ["X"])
self.check_parents(elements, "B", ["K"])
self.check_parents(elements, "C", [])
self.check_parents(elements, "D", ["K", "B"])
self.check_parents(elements, "E", ["K", "B"])
self.check_parents(elements, "F", ["C"])
self.check_parents(elements, "G", ["K", "B", "D"])
def test_remove_children_multiple_parents(self):
r"""
Test removing all children from parent A, with multiple parent paths.
X Y
\ /
K A
\ - -
B C
/ \ \
D E F
/
G
"""
elements = build_tree(
{
"A": ["X", "Y"],
"C": "A",
"F": "C",
"B": ["A", "K"],
"D": "B",
"E": "B",
"G": "D",
},
corpus=self.corpus,
type=self.element_type,
)
self.check_parents(elements, "X", [])
self.check_parents(elements, "Y", [])
self.check_parents(elements, "K", [])
self.check_parents(elements, "A", ["X"],
["Y"])
self.check_parents(elements, "B", ["X", "A"],
["Y", "A"],
["K"])
self.check_parents(elements, "C", ["X", "A"],
["Y", "A"])
self.check_parents(elements, "D", ["X", "A", "B"],
["Y", "A", "B"],
["K", "B"])
self.check_parents(elements, "E", ["X", "A", "B"],
["Y", "A", "B"],
["K", "B"])
self.check_parents(elements, "F", ["X", "A", "C"],
["Y", "A", "C"])
self.check_parents(elements, "G", ["X", "A", "B", "D"],
["Y", "A", "B", "D"],
["K", "B", "D"])
with self.assertExactQueries(
"remove_children_multiple_parents.sql", params={
# remove_children uses transaction.atomic(), and we are running in a unit test, which is already in a transaction.
# This will cause a savepoint to be created, with a name that is hard to mock.
"savepoint": f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 1}",
"A": elements["A"].id,
"first_parent": elements["A"].paths.order_by("id").first().path[0],
}
):
elements["A"].remove_children()
self.check_parents(elements, "X", [])
self.check_parents(elements, "Y", [])
self.check_parents(elements, "K", [])
self.check_parents(elements, "A", ["X"],
["Y"])
self.check_parents(elements, "B", ["K"])
self.check_parents(elements, "C", [])
self.check_parents(elements, "D", ["K", "B"])
self.check_parents(elements, "E", ["K", "B"])
self.check_parents(elements, "F", ["C"])
self.check_parents(elements, "G", ["K", "B", "D"])
......@@ -5,6 +5,7 @@ from django.test import override_settings
from django.urls import reverse
from rest_framework import status
from arkindex.documents.export import local_export, remote_export
from arkindex.documents.models import Corpus, CorpusExportState
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role
......@@ -12,7 +13,7 @@ from arkindex.users.models import Role
class TestExport(FixtureAPITestCase):
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
@override_settings(EXPORT_TTL_SECONDS=420)
def test_start(self, delay_mock):
self.client.force_login(self.superuser)
......@@ -41,7 +42,7 @@ class TestExport(FixtureAPITestCase):
description="Export of corpus Unit Tests"
))
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
def test_start_requires_login(self, delay_mock):
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
......@@ -49,7 +50,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(self.corpus.exports.exists())
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
def test_start_requires_verified(self, delay_mock):
self.user.verified_email = False
self.user.save()
......@@ -61,7 +62,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(self.corpus.exports.exists())
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
@patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
def test_start_requires_contributor(self, max_level_mock, delay_mock):
self.user.rights.update(level=Role.Guest.value)
......@@ -75,7 +76,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(self.corpus.exports.exists())
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.remote_export.delay")
def test_start_bad_source(self, delay_mock):
self.client.force_login(self.superuser)
......@@ -87,7 +88,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(delay_mock.called)
@patch("arkindex.documents.models.CorpusExport.source")
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.remote_export.delay")
@override_settings(EXPORT_TTL_SECONDS=420)
def test_start_with_source(self, delay_mock, source_field_mock):
source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
......@@ -118,7 +119,7 @@ class TestExport(FixtureAPITestCase):
description="Export of corpus Unit Tests from source jouvence"
))
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
def test_start_running(self, delay_mock):
self.client.force_login(self.superuser)
self.corpus.exports.create(user=self.user, state=CorpusExportState.Running)
......@@ -133,7 +134,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(delay_mock.called)
@override_settings(EXPORT_TTL_SECONDS=420)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
def test_start_recent_export(self, delay_mock):
self.client.force_login(self.superuser)
with patch("django.utils.timezone.now") as mock_now:
......@@ -153,7 +154,7 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(delay_mock.called)
@override_settings(EXPORT_TTL_SECONDS=420)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.triggers.export.local_export.delay")
def test_start_recent_export_different_source(self, delay_mock):
from arkindex.documents.models import CorpusExport
CorpusExport.source.field.choices = [("default", "default"), ("jouvence", "jouvence")]
......@@ -433,3 +434,23 @@ class TestExport(FixtureAPITestCase):
response = self.client.delete(reverse("api:manage-export", kwargs={"pk": export.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
assert not self.corpus.exports.exists()
@patch("arkindex.documents.models.CorpusExport.source")
def test_local_export_default_db_only(self, source_field_mock):
source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
export = self.corpus.exports.create(user=self.superuser, state=CorpusExportState.Created, source="jouvence")
with self.assertRaises(AssertionError):
local_export(export)
export.refresh_from_db()
self.assertEqual(export.state, CorpusExportState.Created)
@patch("arkindex.documents.models.CorpusExport.source")
def test_remote_export_not_default_db(self, source_field_mock):
source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
export = self.corpus.exports.create(user=self.superuser, state=CorpusExportState.Created, source="default")
with self.assertRaises(AssertionError):
remote_export(export)
export.refresh_from_db()
self.assertEqual(export.state, CorpusExportState.Created)
......@@ -1613,7 +1613,7 @@ class TestMetaData(FixtureAPITestCase):
Ponos tasks are allowed to create metadata with any text/name.
"""
self.vol.metadatas.all().delete()
process_worker_run = self.process.worker_runs.get()
process_worker_run = self.process.worker_runs.get(version_id=self.worker_version.id)
with self.assertNumQueries(7):
response = self.client.post(
reverse("api:element-metadata-bulk", kwargs={"pk": str(self.vol.id)}),
......
......@@ -10,7 +10,7 @@ from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task
from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Artifact, State, Task
from arkindex.ponos.signals import task_failure
from arkindex.project.serializer_fields import EnumField
from arkindex.project.triggers import notify_process_completion
......@@ -36,6 +36,7 @@ class TaskLightSerializer(serializers.ModelSerializer):
└⟶ Error
Stopping ⟶ Stopped
└⟶ Error
Slurm agents are also allowed to update state from Pending to Completed or Failed.
""").strip(),
)
......@@ -49,6 +50,8 @@ class TaskLightSerializer(serializers.ModelSerializer):
"state",
"parents",
"shm_size",
"requires_gpu",
"original_task_id",
)
read_only_fields = (
"id",
......@@ -57,15 +60,21 @@ class TaskLightSerializer(serializers.ModelSerializer):
"slug",
"parents",
"shm_size",
"requires_gpu",
"original_task_id",
)
def validate_state(self, state):
# Updates from a state to the same state are blocked to avoid side effects on finished tasks
allowed_transitions = {
State.Unscheduled: [State.Pending],
State.Pending: [State.Running, State.Error],
State.Running: [State.Completed, State.Failed, State.Error],
State.Stopping: [State.Stopped, State.Error],
}
user = self.context["request"].user
if isinstance(user, Agent) and user.mode == AgentMode.Slurm:
allowed_transitions[State.Pending].extend([State.Completed, State.Failed])
if self.instance and state not in allowed_transitions.get(self.instance.state, []):
raise ValidationError(f"Transition from state {self.instance.state} to state {state} is forbidden.")
return state
......@@ -89,7 +98,6 @@ class TaskSerializer(TaskLightSerializer):
"agent",
"gpu",
"extra_files",
"original_task_id"
)
read_only_fields = TaskLightSerializer.Meta.read_only_fields + (
"logs",
......@@ -97,7 +105,6 @@ class TaskSerializer(TaskLightSerializer):
"agent",
"gpu",
"extra_files",
"original_task_id"
)
@extend_schema_field(serializers.CharField())
......
import copy
import uuid
from datetime import datetime, timedelta, timezone
from io import BytesIO
from itertools import combinations
from unittest import expectedFailure
from unittest.mock import call, patch, seal
......@@ -10,7 +12,7 @@ from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.ponos.models import State, Task
from arkindex.ponos.models import Agent, AgentMode, Farm, State, Task
from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Right, Role, User
......@@ -44,6 +46,41 @@ class TestAPI(FixtureAPITestCase):
cls.dla = WorkerVersion.objects.get(worker__slug="dla")
cls.recognizer = WorkerVersion.objects.get(worker__slug="reco")
cls.farm = Farm.objects.first()
cls.docker_agent = Agent.objects.create(
mode=AgentMode.Docker,
farm=cls.farm,
last_ping=datetime.now(),
cpu_cores=42,
cpu_frequency=42e8,
ram_total=42e3
)
cls.slurm_agent = Agent.objects.create(
mode=AgentMode.Slurm,
farm=cls.farm,
last_ping=datetime.now(),
)
@property
def docker_task_transitions(self):
return (
(State.Unscheduled, State.Pending),
(State.Pending, State.Error),
(State.Pending, State.Running),
(State.Running, State.Completed),
(State.Running, State.Failed),
(State.Running, State.Error),
(State.Stopping, State.Stopped),
(State.Stopping, State.Error),
)
@property
def slurm_task_transitions(self):
return self.docker_task_transitions + (
(State.Pending, State.Completed),
(State.Pending, State.Failed),
)
def test_task_details_requires_login(self):
with self.assertNumQueries(0):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
......@@ -81,6 +118,7 @@ class TestAPI(FixtureAPITestCase):
"agent": None,
"gpu": None,
"shm_size": None,
"requires_gpu": False,
},
)
......@@ -165,6 +203,7 @@ class TestAPI(FixtureAPITestCase):
"agent": None,
"gpu": None,
"shm_size": None,
"requires_gpu": False,
},
)
......@@ -207,6 +246,7 @@ class TestAPI(FixtureAPITestCase):
"agent": None,
"gpu": None,
"shm_size": None,
"requires_gpu": False,
},
)
......@@ -374,6 +414,139 @@ class TestAPI(FixtureAPITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.project.aws.s3")
@patch("arkindex.ponos.serializers.notify_process_completion")
@patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple())
def test_partial_update_task_from_docker_agent_allowed_transitions(self, notify_mock, s3_mock, get_user_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
seal(s3_mock)
# Authenticate from a possible agent
custom_user = copy.copy(self.docker_agent)
custom_user.is_active = True
custom_user.is_authenticated = True
get_user_mock.return_value = custom_user
self.client.force_login(self.user)
for (state_from, state_to) in self.docker_task_transitions:
with self.subTest(state_from=state_from, state_to=state_to):
self.task1.state = state_from
self.task1.save()
resp = self.client.patch(
reverse("api:task-details", args=[self.task1.id]),
data={"state": state_to.value},
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.task1.refresh_from_db()
self.assertEqual(self.task1.state, state_to)
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.project.aws.s3")
@patch("arkindex.ponos.serializers.notify_process_completion")
@patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple())
def test_partial_update_task_from_slurm_agent_allowed_transitions(self, notify_mock, s3_mock, get_user_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
seal(s3_mock)
# Authenticate from a possible agent
custom_user = copy.copy(self.slurm_agent)
custom_user.is_active = True
custom_user.is_authenticated = True
get_user_mock.return_value = custom_user
self.client.force_login(self.user)
for (state_from, state_to) in self.slurm_task_transitions:
with self.subTest(state_from=state_from, state_to=state_to):
self.task1.state = state_from
self.task1.save()
resp = self.client.patch(
reverse("api:task-details", args=[self.task1.id]),
data={"state": state_to.value},
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.task1.refresh_from_db()
self.assertEqual(self.task1.state, state_to)
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.project.aws.s3")
@patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple())
def test_partial_update_task_from_docker_agent_forbidden_transitions(self, s3_mock, get_user_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
seal(s3_mock)
# Mock authenticating from an agent
custom_user = copy.copy(self.docker_agent)
custom_user.is_active = True
custom_user.is_authenticated = True
get_user_mock.return_value = custom_user
self.client.force_login(self.user)
for (state_from, state_to) in combinations(State, 2):
if (state_from, state_to) in self.docker_task_transitions:
continue
with self.subTest(state_from=state_from, state_to=state_to):
self.task1.state = state_from
self.task1.save()
resp = self.client.patch(
reverse("api:task-details", args=[self.task1.id]),
data={"state": state_to.value},
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(resp.json(), {
"state": [f"Transition from state {state_from} to state {state_to} is forbidden."]
})
self.task1.refresh_from_db()
self.assertEqual(self.task1.state, state_from)
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.project.aws.s3")
@patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple())
def test_partial_update_task_from_slurm_agent_forbidden_transitions(self, s3_mock, get_user_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
seal(s3_mock)
# Mock authenticating from an agent
custom_user = copy.copy(self.slurm_agent)
custom_user.is_active = True
custom_user.is_authenticated = True
get_user_mock.return_value = custom_user
self.client.force_login(self.user)
for (state_from, state_to) in combinations(State, 2):
if (state_from, state_to) in self.slurm_task_transitions:
continue
with self.subTest(state_from=state_from, state_to=state_to):
self.task1.state = state_from
self.task1.save()
resp = self.client.patch(
reverse("api:task-details", args=[self.task1.id]),
data={"state": state_to.value},
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(resp.json(), {
"state": [f"Transition from state {state_from} to state {state_to} is forbidden."]
})
self.task1.refresh_from_db()
self.assertEqual(self.task1.state, state_from)
def test_partial_update_task_requires_login(self):
with self.assertNumQueries(0):
resp = self.client.patch(
......@@ -641,6 +814,7 @@ class TestAPI(FixtureAPITestCase):
self.task1.state = State.Completed.value
self.task1.save()
self.task2.state = State.Error.value
self.task2.requires_gpu = True
self.task2.save()
self.client.force_login(self.user)
......@@ -669,6 +843,7 @@ class TestAPI(FixtureAPITestCase):
"shm_size": None,
"slug": task_2_slug,
"state": "pending",
"requires_gpu": True,
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
......@@ -734,6 +909,7 @@ class TestAPI(FixtureAPITestCase):
"shm_size": None,
"slug": task_2_slug,
"state": "pending",
"requires_gpu": False,
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
......
......@@ -228,7 +228,8 @@ class ProcessList(ProcessACLMixin, ListAPIView):
filters &= Q(name__icontains=self.request.query_params["name"])
qs = (
self.readable_processes
Process.objects
.filter(corpus__in=Corpus.objects.readable(self.user))
.select_related("creator")
.filter(filters)
# Order processes by completion date when available, or start date, or last update
......@@ -1515,6 +1516,11 @@ class WorkerRunDetails(ProcessACLMixin, RetrieveUpdateDestroyAPIView):
"version__revision__repo",
"configuration",
"model_version__model",
"process__element__corpus",
"process__element__type",
"process__element__image__server",
"process__folder_type",
"process__element_type",
)
.prefetch_related(Prefetch(
"version__revision__refs",
......
......@@ -242,22 +242,49 @@ class ProcessBuilder:
@prefetch_worker_runs
def build_workers(self):
# Build the initialisation task that lists elements in chunks
args = [
"python", "-m", "arkindex_tasks.init_elements", str(self.process.id),
"--chunks-number", str(self.process.chunks),
from arkindex.process.models import WorkerVersion
# Retrieve worker runs
worker_runs = list(self.process.worker_runs.all())
# Find the WorkerRun to use for the initialisation task, or create it
initialisation_runs = [
run for run in worker_runs if run.version == WorkerVersion.objects.init_elements_version and not len(run.parents)
]
if self.process.use_cache:
args.append("--use-cache")
if len(initialisation_runs):
# In case there is more than one run using the elements initialisation worker, use the first one and ignore the others.
initialisation_worker_run = initialisation_runs[0]
# Remove the elements initialisation run from the worker runs list, so that it doesn't go through the regular task
# creation process with WorkerRun.build_task and is not split between chunks
worker_runs.remove(initialisation_worker_run)
# If there is no elements initialisation worker run in the process, create one
else:
initialisation_worker = WorkerVersion.objects.init_elements_version
initialisation_worker_run = self.process.worker_runs.create(
version=initialisation_worker
)
# Link all parentless worker runs to the initialisation worker run
no_parents = [run for run in worker_runs if not len(run.parents)]
for run in no_parents:
run.parents = [initialisation_worker_run.id]
from arkindex.process.models import WorkerRun
WorkerRun.objects.bulk_update(no_parents, ["parents"])
# Create the initialisation task
import_task_slug = "initialisation"
env = {
"ARKINDEX_WORKER_RUN_ID": str(initialisation_worker_run.id),
**self.base_env
}
self._build_task(
command=" ".join(args),
slug=import_task_slug,
env=self.base_env,
command=initialisation_worker_run.version.docker_command,
image=initialisation_worker_run.version.docker_image_iid,
worker_run=initialisation_worker_run,
env=env,
)
# Distribute worker run tasks
worker_runs = list(self.process.worker_runs.all())
chunks = self._get_elements_json_chunks()
for index, chunk in enumerate(chunks, start=1):
elements_path = shlex.quote(path.join("/data", import_task_slug, chunk))
......@@ -268,7 +295,9 @@ class ProcessBuilder:
import_task_slug,
elements_path,
chunk=index if len(chunks) > 1 else None,
workflow_runs=worker_runs,
# Slip the elements initialisation worker run back in, so that WorkerRun.build_task
# can use it when building the tasks from worker runs that have it as a parent
workflow_runs=worker_runs + [initialisation_worker_run],
run=self.run,
active_gpu_agents=self.active_gpu_agents,
)
......
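To summarise the re-parenting logic above, a rough sketch (illustrative only, not part of the change) of the dependency graph that build_workers produces for a process whose worker runs A and B initially have no parents:

# Illustrative only:
#
#   initialisation    <- run on the elements initialisation worker version, task slug "initialisation"
#   ├── A              <- parents updated to [initialisation_worker_run.id]
#   └── B              <- parents updated to [initialisation_worker_run.id]
#
# Worker runs that already had parents keep them; only parentless runs are re-parented,
# and the initialisation run itself is excluded from the per-chunk task distribution.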
......@@ -158,6 +158,22 @@ class WorkerResultSourceQuerySet(QuerySet):
class WorkerVersionManager(Manager):
@cached_property
def init_elements_version(self):
"""
WorkerVersion for elements initialization.
"""
from arkindex.process.models import WorkerVersionState
init_version = (
self
.select_related("worker", "revision")
.prefetch_related("revision__refs")
.get(docker_image_iid=settings.INIT_ELEMENTS_DOCKER_IMAGE)
)
if init_version.state != WorkerVersionState.Available:
raise ValueError("The elements initialization worker version must be 'available'.")
return init_version
@cached_property
def imports_version(self):
"""
......
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("process", "0033_remove_process_generate_thumbnails"),
]
operations = [
migrations.RunSQL(
[
"""
UPDATE process_process
SET mode = 'files'
WHERE mode = 'iiif'
"""
],
reverse_sql=migrations.RunSQL.noop,
elidable=True,
)
]