Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Showing
with 219 additions and 268 deletions
......@@ -22,7 +22,6 @@ from arkindex.documents.models import (
TranscriptionEntity,
)
from arkindex.images.models import Image, ImageServer
from arkindex.ponos.models import Artifact
from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState
from arkindex.project.tests import FixtureTestCase
......@@ -65,7 +64,7 @@ class TestExport(FixtureTestCase):
worker=self.repo.workers.create(slug=str(uuid4()), type=self.worker_type),
configuration={},
state=WorkerVersionState.Available,
docker_image=Artifact.objects.first(),
docker_image_iid="registry.somewhere.com/something:latest",
)
@override_settings(PUBLIC_HOSTNAME="https://darkindex.lol")
......
......@@ -92,8 +92,8 @@ class TestCreateElements(FixtureAPITestCase):
"id": str(volume.id),
"type": volume.type.slug,
"name": volume.name,
"thumbnail_put_url": None,
"thumbnail_url": volume.thumbnail.s3_url,
"thumbnail_put_url": volume.thumbnail.s3_put_url,
"worker_version": None,
"confidence": None,
"creator": "Test user",
......
......@@ -161,7 +161,6 @@ class TestEntitiesAPI(FixtureAPITestCase):
"public": self.corpus.public
},
"thumbnail_url": None,
"thumbnail_put_url": None,
"zone": {
"id": str(e.id),
"url": e.iiif_url,
......
......@@ -31,6 +31,7 @@ class TestExport(FixtureAPITestCase):
},
"corpus_id": str(self.corpus.id),
"state": CorpusExportState.Created.value,
"source": "default"
})
self.assertEqual(delay_mock.call_count, 1)
......@@ -61,19 +62,62 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.mixins.has_access", return_value=False)
def test_start_requires_contributor(self, has_access_mock, delay_mock):
@patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
def test_start_requires_contributor(self, max_level_mock, delay_mock):
self.user.rights.update(level=Role.Guest.value)
self.client.force_login(self.user)
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(has_access_mock.call_count, 1)
self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Contributor.value, skip_public=False))
self.assertEqual(max_level_mock.call_count, 1)
self.assertEqual(max_level_mock.call_args, call(self.user, self.corpus))
self.assertFalse(self.corpus.exports.exists())
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_bad_source(self, delay_mock):
    """
    Requesting an export from a source that is not a valid choice
    is rejected: no export is created and no task is triggered
    """
    self.client.force_login(self.superuser)
    export_url = reverse("api:corpus-export", kwargs={"pk": self.corpus.id})
    payload = {"source": "jouvence"}

    response = self.client.post(export_url, payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = {"source": ['"jouvence" is not a valid choice.']}
    self.assertDictEqual(response.json(), expected_errors)

    # Nothing must have been stored or sent to the task queue
    self.assertEqual(self.corpus.exports.count(), 0)
    self.assertFalse(delay_mock.called)
@patch("arkindex.documents.models.CorpusExport.source")
@patch("arkindex.project.triggers.export.export_corpus.delay")
@override_settings(EXPORT_TTL_SECONDS=420)
def test_start_with_source(self, delay_mock, source_field_mock):
    """
    Starting an export with an explicit source creates the export
    and triggers the export task with a description naming that source
    """
    # The patched source attribute is configured with two choices: default and jouvence
    source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
    self.client.force_login(self.superuser)
    export_url = reverse("api:corpus-export", kwargs={"pk": self.corpus.id})

    response = self.client.post(export_url, {"source": "jouvence"})
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # Exactly one export is expected on the corpus at this point
    export = self.corpus.exports.get()
    expected_user = {
        "id": self.superuser.id,
        "display_name": self.superuser.display_name,
        "email": self.superuser.email,
    }
    expected_payload = {
        "id": str(export.id),
        "created": export.created.isoformat().replace("+00:00", "Z"),
        "updated": export.updated.isoformat().replace("+00:00", "Z"),
        "user": expected_user,
        "corpus_id": str(self.corpus.id),
        "state": CorpusExportState.Created.value,
        "source": "jouvence"
    }
    self.assertDictEqual(response.json(), expected_payload)

    expected_call = call(
        corpus_export=export,
        user_id=self.superuser.id,
        description="Export of corpus Unit Tests from source jouvence"
    )
    self.assertEqual(delay_mock.call_count, 1)
    self.assertEqual(delay_mock.call_args, expected_call)
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_running(self, delay_mock):
self.client.force_login(self.superuser)
......@@ -81,7 +125,9 @@ class TestExport(FixtureAPITestCase):
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(response.json(), ["An export is already running for this corpus."])
self.assertDictEqual(response.json(), {
"non_field_errors": ["An export is already running for this corpus."]
})
self.assertEqual(self.corpus.exports.count(), 1)
self.assertFalse(delay_mock.called)
......@@ -99,11 +145,53 @@ class TestExport(FixtureAPITestCase):
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(response.json(), ["An export has already been made for this corpus in the last 420 seconds."])
self.assertDictEqual(response.json(), {
"non_field_errors": ["An export has already been made for this corpus in the last 420 seconds."]
})
self.assertEqual(self.corpus.exports.count(), 1)
self.assertFalse(delay_mock.called)
@override_settings(EXPORT_TTL_SECONDS=420)
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_recent_export_different_source(self, delay_mock):
    """
    An export made less than EXPORT_TTL_SECONDS ago from another source
    does not prevent starting a new export from the default source
    """
    from arkindex.documents.models import CorpusExport

    # The model field is shared by every test of the process: register a cleanup
    # restoring its original choices before overriding them, so that the extra
    # "jouvence" source cannot leak into the tests that run afterwards
    source_field = CorpusExport.source.field
    self.addCleanup(setattr, source_field, "choices", source_field.choices)
    source_field.choices = [("default", "default"), ("jouvence", "jouvence")]

    self.client.force_login(self.superuser)

    # Create a finished export from the other source, dated two minutes ago
    with patch("django.utils.timezone.now") as mock_now:
        mock_now.return_value = datetime.now(timezone.utc) - timedelta(minutes=2)
        self.corpus.exports.create(
            user=self.user,
            state=CorpusExportState.Done,
            source="jouvence"
        )

    response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # The new export is the only one using the default source
    export = self.corpus.exports.get(source="default")
    self.assertDictEqual(response.json(), {
        "id": str(export.id),
        "created": export.created.isoformat().replace("+00:00", "Z"),
        "updated": export.updated.isoformat().replace("+00:00", "Z"),
        "user": {
            "id": self.superuser.id,
            "display_name": self.superuser.display_name,
            "email": self.superuser.email,
        },
        "corpus_id": str(self.corpus.id),
        "state": CorpusExportState.Created.value,
        "source": "default"
    })

    self.assertEqual(delay_mock.call_count, 1)
    self.assertEqual(delay_mock.call_args, call(
        corpus_export=export,
        user_id=self.superuser.id,
        description="Export of corpus Unit Tests"
    ))
def test_list(self):
export1 = self.corpus.exports.create(user=self.user, state=CorpusExportState.Done)
export2 = self.corpus.exports.create(user=self.superuser)
......@@ -123,6 +211,7 @@ class TestExport(FixtureAPITestCase):
"email": self.superuser.email,
},
"corpus_id": str(self.corpus.id),
"source": "default"
},
{
"id": str(export1.id),
......@@ -135,6 +224,7 @@ class TestExport(FixtureAPITestCase):
"email": self.user.email,
},
"corpus_id": str(self.corpus.id),
"source": "default"
},
])
......@@ -149,18 +239,19 @@ class TestExport(FixtureAPITestCase):
response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@patch("arkindex.project.mixins.has_access", return_value=False)
def test_list_requires_guest(self, has_access_mock):
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_list_requires_guest(self, filter_rights_mock):
self.user.rights.all().delete()
self.corpus.public = False
self.corpus.save()
filter_rights_mock.return_value = Corpus.objects.none()
self.client.force_login(self.user)
response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(has_access_mock.call_count, 1)
self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Guest.value, skip_public=False))
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_download_export(self, presigned_url_mock):
......
......@@ -43,7 +43,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True,
},
"thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": None,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url,
"worker_version": None,
"confidence": None,
"zone": None,
......@@ -106,7 +106,7 @@ class TestRetrieveElements(FixtureAPITestCase):
response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["thumbnail_url"], self.vol.thumbnail.s3_url)
self.assertIsNone(response.json()["thumbnail_put_url"])
self.assertEqual(response.json()["thumbnail_put_url"], self.vol.thumbnail.s3_put_url)
self.assertFalse(self.page.type.folder)
response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.page.id)}))
......@@ -114,6 +114,38 @@ class TestRetrieveElements(FixtureAPITestCase):
self.assertIsNone(response.json()["thumbnail_url"])
self.assertIsNone(response.json()["thumbnail_put_url"])
@patch("arkindex.documents.serializers.elements.get_max_level")
def test_get_element_thumbnail_put_acl(self, get_max_level_mock):
    """
    RetrieveElement returns a thumbnail_put_url for corpus admins,
    or for corpus contributors that have created the element
    """
    self.client.force_login(self.user)

    # (role of the logged-in user, creator of the element, whether a PUT URL is expected)
    cases = [
        (Role.Guest, self.superuser, False),
        (Role.Guest, self.user, False),
        (Role.Contributor, self.superuser, False),
        (Role.Contributor, self.user, True),
        (Role.Admin, self.superuser, True),
        (Role.Admin, self.user, True),
    ]
    for role, creator, has_put_url in cases:
        # Each role is tested with two creators: include the creator in the
        # subtest parameters so that a failure identifies a single case
        with self.subTest(role=role, creator=creator.email):
            get_max_level_mock.return_value = role.value
            self.corpus.memberships.filter(user=self.user).update(level=role.value)
            self.vol.creator = creator
            self.vol.save()

            with self.assertNumQueries(4):
                response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
                self.assertEqual(response.status_code, status.HTTP_200_OK)

            self.assertEqual(
                response.json()["thumbnail_put_url"],
                self.vol.thumbnail.s3_put_url if has_put_url else None
            )
@override_settings(ARKINDEX_TASKS_IMAGE="task_image")
def test_get_element_thumbnail_put_ponos_task(self):
"""
......@@ -121,14 +153,14 @@ class TestRetrieveElements(FixtureAPITestCase):
running the thumbnails generation command on a folder element
"""
process = Process.objects.create(
mode=ProcessMode.Repository,
revision=self.worker_version.revision,
mode=ProcessMode.Workers,
creator=self.user,
generate_thumbnails=True,
farm=Farm.objects.first(),
corpus=self.corpus
)
process.run()
task = process.tasks.get()
task = process.tasks.first()
task.image = "task_image"
task.command = "python generate_thumbnails"
......@@ -148,27 +180,6 @@ class TestRetrieveElements(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIsNone(response.json()["thumbnail_put_url"])
def test_get_element_thumbnail_put_requires_thumbnails_task(self):
"""
Only tasks that are intended to generate thumbnails (ARKINDEX_TASKS_IMAGE + thumbnails_generation command)
can retrieve the thumbnails PUT URL.
"""
process = Process.objects.create(
mode=ProcessMode.Repository,
revision=self.worker_version.revision,
creator=self.user,
farm=Farm.objects.first(),
)
process.run()
task = process.tasks.get()
self.assertTrue(self.vol.type.folder)
response = self.client.get(
reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}),
HTTP_AUTHORIZATION=f"Ponos {task.token}",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIsNone(response.json()["thumbnail_put_url"])
def test_get_element_creator(self):
self.vol.creator = self.user
self.vol.save()
......@@ -220,7 +231,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True,
},
"thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": None,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url,
"worker_version": str(self.worker_version.id),
"confidence": None,
"zone": None,
......@@ -255,7 +266,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True,
},
"thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": None,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url,
"worker_version": None,
"confidence": None,
"zone": None,
......
......@@ -7,7 +7,7 @@ from rest_framework.response import Response
from rest_framework.serializers import ValidationError
from arkindex.documents.models import Corpus, Element
from arkindex.documents.serializers.elements import ElementSlimSerializer
from arkindex.documents.serializers.elements import ElementTinySerializer
from arkindex.images.models import Image
from arkindex.images.serializers import (
CreateImageErrorResponseSerializer,
......@@ -151,7 +151,7 @@ class ImageElements(ListAPIView):
# For OpenAPI type discovery: an image's ID is in the path
queryset = Image.objects.none()
permission_classes = (IsVerified, )
serializer_class = ElementSlimSerializer
serializer_class = ElementTinySerializer
def get_queryset(self):
filters = {
......
......@@ -34,7 +34,7 @@ class Migration(migrations.Migration):
constraint=models.UniqueConstraint(models.F("url"), name="unique_imageserver_url"),
),
],
# This can be removed by manage.py squashmigrations
# This can be removed by `arkindex squashmigrations`
elidable=True,
),
]
......@@ -15,7 +15,7 @@ from django.utils.text import slugify
from enumfields import EnumField
from arkindex.images.managers import ImageServerManager
from arkindex.project.aws import S3FileMixin, S3FileStatus
from arkindex.project.aws import S3FileMixin, S3FileStatus, should_verify_cert
from arkindex.project.fields import LStripTextField, MD5HashField, StripSlashURLField
from arkindex.project.models import IndexableModel
......@@ -238,7 +238,7 @@ class Image(S3FileMixin, IndexableModel):
requests_exception = None
try:
# Load info
resp = requests.get(info_url, timeout=15, allow_redirects=True)
resp = requests.get(info_url, timeout=15, allow_redirects=True, verify=should_verify_cert(info_url))
resp.raise_for_status()
try:
payload = resp.json()
......
......@@ -8,7 +8,7 @@ from rest_framework import status
from arkindex.images.models import Image, ImageServer
from arkindex.ponos.models import Farm
from arkindex.process.models import Process, ProcessMode, Revision
from arkindex.process.models import Process, ProcessMode
from arkindex.project.aws import S3FileStatus
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Scope
......@@ -608,13 +608,13 @@ class TestImageApi(FixtureAPITestCase):
IIIF images created by a Ponos task are immediately checked
"""
process = Process.objects.create(
mode=ProcessMode.Repository,
revision=Revision.objects.first(),
mode=ProcessMode.Workers,
creator=self.user,
corpus=self.corpus,
farm=Farm.objects.first(),
)
process.run()
task = process.tasks.get()
task = process.tasks.first()
# The user scope should not be necessary with Ponos task authentication
self.assertFalse(self.user.user_scopes.filter(scope=Scope.CreateIIIFImage).exists())
......@@ -655,13 +655,13 @@ class TestImageApi(FixtureAPITestCase):
height=100,
)
process = Process.objects.create(
mode=ProcessMode.Repository,
revision=Revision.objects.first(),
mode=ProcessMode.Workers,
creator=self.user,
corpus=self.corpus,
farm=Farm.objects.first(),
)
process.run()
task = process.tasks.get()
task = process.tasks.first()
# The user scope should not be necessary with Ponos task authentication
self.assertFalse(self.user.user_scopes.filter(scope=Scope.CreateIIIFImage).exists())
......
......@@ -2,7 +2,8 @@
import os
import sys
if __name__ == "__main__":
def main():
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "arkindex.project.settings")
try:
from django.core.management import execute_from_command_line
......@@ -20,3 +21,8 @@ if __name__ == "__main__":
)
raise
execute_from_command_line(sys.argv)
if __name__ == "__main__":
sys.stderr.write("WARNING: manage.py is deprecated, you should use the `arkindex` script instead\n")
main()
......@@ -52,8 +52,7 @@ class TaskDetailsFromAgent(RetrieveUpdateAPIView):
"agent__farm",
"gpu",
# Used for permission checks
"process__corpus",
"process__revision__repo",
"process__corpus"
)
permission_classes = (
# On all HTTP methods, require either any Ponos agent, an instance admin, the task itself, or guest access to the process' task
......@@ -102,7 +101,7 @@ class TaskArtifacts(ListCreateAPIView):
def task(self):
task = get_object_or_404(
# Select the required tables for permissions checking
Task.objects.select_related("process__corpus", "process__revision__repo"),
Task.objects.select_related("process__corpus"),
pk=self.kwargs["pk"],
)
self.check_object_permissions(self.request, task)
......@@ -126,7 +125,7 @@ class TaskArtifactDownload(APIView):
def get_object(self, pk, path):
artifact = get_object_or_404(
# Select the required tables for permissions checking
Artifact.objects.select_related("task__process__corpus", "task__process__revision__repo"),
Artifact.objects.select_related("task__process__corpus"),
task_id=pk,
path=path,
)
......@@ -167,5 +166,5 @@ class TaskUpdate(UpdateAPIView):
authentication_classes = (TokenAuthentication, SessionAuthentication)
# Only allow regular users that have admin access to the task's process
permission_classes = (IsTaskAdmin, )
queryset = Task.objects.select_related("process__corpus", "process__revision__repo")
queryset = Task.objects.select_related("process__corpus")
serializer_class = TaskTinySerializer
......@@ -42,7 +42,7 @@ class Migration(migrations.Migration):
),
),
],
# This can be removed by manage.py squashmigrations
# This can be removed by `arkindex squashmigrations`
elidable=True,
),
# Remove the implicit LIKE index on Secret.name and make the unique constraint explicit
......@@ -70,7 +70,7 @@ class Migration(migrations.Migration):
),
),
],
# This can be removed by manage.py squashmigrations
# This can be removed by `arkindex squashmigrations`
elidable=True,
),
# Remove the implicit LIKE index on Task.token and make the unique constraint explicit
......@@ -101,7 +101,7 @@ class Migration(migrations.Migration):
),
),
],
# This can be removed by manage.py squashmigrations
# This can be removed by `arkindex squashmigrations`
elidable=True,
),
]
# Generated by Django 4.1.7 on 2024-03-20 08:45
from django.db import migrations
# Schema migration removing the `has_docker_socket` field from the `task` model.
class Migration(migrations.Migration):
# Must be applied after the `0006_task_worker_run` migration of the `ponos` app
dependencies = [
("ponos", "0006_task_worker_run"),
]
# Single operation: drop the field, no data migration is performed
operations = [
migrations.RemoveField(
model_name="task",
name="has_docker_socket",
),
]
......@@ -273,7 +273,6 @@ class Task(models.Model):
shm_size = models.CharField(max_length=80, blank=True, null=True, editable=False)
command = models.TextField(blank=True, null=True)
env = HStoreField(default=dict)
has_docker_socket = models.BooleanField(default=False)
image_artifact = models.ForeignKey(
"ponos.Artifact",
related_name="tasks_using_image",
......
......@@ -174,60 +174,6 @@ class TestAPI(FixtureAPITestCase):
call("get_object", Params={"Bucket": "ponos", "Key": "somelog"}),
)
@patch("arkindex.project.aws.s3")
def test_task_details_process_level_repo(self, s3_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
self.client.force_login(self.user)
self.process.mode = ProcessMode.Repository
self.process.corpus = None
self.process.revision = self.rev
self.process.save()
membership = self.rev.repo.memberships.create(user=self.user, level=Role.Guest.value)
for role in Role:
s3_mock.reset_mock()
s3_mock.reset_mock()
# Recreate the BytesIO each time, because its contents get consumed each time the API is called
s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
with self.subTest(role=role):
membership.level = role.value
membership.save()
with self.assertNumQueries(4):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertDictEqual(
resp.json(),
{
"id": str(self.task1.id),
"run": 0,
"depth": 0,
"slug": "initialisation",
"state": "unscheduled",
"parents": [],
"logs": "Failed successfully",
"full_log": "http://somewhere",
"extra_files": {},
"agent": None,
"gpu": None,
"shm_size": None,
},
)
self.assertEqual(s3_mock.Object.call_count, 2)
self.assertEqual(s3_mock.Object().get.call_count, 1)
self.assertEqual(s3_mock.Object().get.call_args, call(Range="bytes=-42"))
self.assertEqual(s3_mock.meta.client.generate_presigned_url.call_count, 1)
self.assertEqual(
s3_mock.meta.client.generate_presigned_url.call_args,
call("get_object", Params={"Bucket": "ponos", "Key": "somelog"}),
)
@patch("arkindex.project.aws.s3")
def test_task_details(self, s3_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
......@@ -319,28 +265,6 @@ class TestAPI(FixtureAPITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
@expectedFailure
def test_update_task_requires_process_admin_repo(self):
self.process.mode = ProcessMode.Repository
self.process.corpus = None
self.process.revision = self.rev
self.process.creator = self.superuser
self.process.save()
self.client.force_login(self.user)
for role in [None, Role.Guest, Role.Contributor]:
with self.subTest(role=role):
self.rev.repo.memberships.filter(user=self.user).delete()
if role:
self.rev.repo.memberships.create(user=self.user, level=role.value)
with self.assertNumQueries(5):
resp = self.client.put(
reverse("api:task-update", args=[self.task1.id]),
data={"state": State.Stopping.value},
)
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_update_running_task_state_stopping(self):
self.task1.state = State.Running
self.task1.save()
......@@ -500,28 +424,6 @@ class TestAPI(FixtureAPITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
@expectedFailure
def test_partial_update_task_requires_process_admin_repo(self):
self.process.mode = ProcessMode.Repository
self.process.corpus = None
self.process.revision = self.rev
self.process.creator = self.superuser
self.process.save()
self.client.force_login(self.user)
for role in [None, Role.Guest, Role.Contributor]:
with self.subTest(role=role):
self.rev.repo.memberships.filter(user=self.user).delete()
if role:
self.rev.repo.memberships.create(user=self.user, level=role.value)
with self.assertNumQueries(5):
resp = self.client.patch(
reverse("api:task-update", args=[self.task1.id]),
data={"state": State.Stopping.value},
)
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_partial_update_running_task_state_stopping(self):
self.task1.state = State.Running
self.task1.save()
......
......@@ -6,7 +6,7 @@ from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.process.models import Process, ProcessMode, Repository
from arkindex.process.models import Process, ProcessMode
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Right, Role, User
......@@ -17,14 +17,6 @@ class TestAPI(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.repo = Repository.objects.first()
cls.repository_process = Process.objects.create(
mode=ProcessMode.Repository,
creator=cls.superuser,
revision=cls.repo.revisions.first(),
)
# Make corpus private
cls.corpus.public = False
cls.corpus.save()
......@@ -119,43 +111,6 @@ class TestAPI(FixtureAPITestCase):
],
)
def test_list_process_level_repo(self):
self.client.force_login(self.user)
membership = self.repo.memberships.create(user=self.user, level=Role.Guest.value)
for role in Role:
with self.subTest(role=role):
membership.level = role.value
membership.save()
with self.assertNumQueries(4):
response = self.client.get(reverse("api:task-artifacts", args=[self.task1.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual(
response.json(),
[
{
"content_type": "application/json",
"created": self.artifact1.created.isoformat().replace("+00:00", "Z"),
"id": str(self.artifact1.id),
"path": "path/to/file.json",
"s3_put_url": None,
"size": 42,
"updated": self.artifact1.updated.isoformat().replace("+00:00", "Z"),
},
{
"content_type": "text/plain",
"created": self.artifact2.created.isoformat().replace("+00:00", "Z"),
"id": str(self.artifact2.id),
"path": "some/text.txt",
"s3_put_url": None,
"size": 1337,
"updated": self.artifact2.updated.isoformat().replace("+00:00", "Z"),
},
],
)
def test_list_admin(self):
self.client.force_login(self.superuser)
with self.assertNumQueries(4):
......@@ -423,31 +378,6 @@ class TestAPI(FixtureAPITestCase):
)
)
def test_download_process_level_repo(self):
self.client.force_login(self.user)
membership = self.repo.memberships.create(user=self.user, level=Role.Guest.value)
task = self.repository_process.tasks.create(run=0, depth=0, slug="a")
task.artifacts.create(path="path/to/file.json", content_type="application/json", size=42)
for role in Role:
with self.subTest(role=role):
membership.level = role.value
membership.save()
with self.assertNumQueries(3):
response = self.client.get(
reverse("api:task-artifact-download", args=[task.id, "path/to/file.json"]),
)
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertTrue(response.has_header("Location"))
self.assertTrue(
response["Location"].startswith(
f"http://s3/ponos-artifacts/{task.id}/path/to/file.json"
)
)
def test_download_task(self):
with self.assertNumQueries(2):
response = self.client.get(
......
......@@ -306,8 +306,6 @@ class ProcessQuerysetMixin(object):
# Element and folder types are serialized as their slugs
"element_type",
"folder_type",
# The revision is serialized with its commit URL, which also requires the repository
"revision__repo",
)
# Files and tasks are also listed
.prefetch_related("files", "tasks__parents")
......@@ -441,10 +439,6 @@ class ProcessRetry(ProcessACLMixin, ProcessQuerysetMixin, GenericAPIView):
elif state == State.Stopping:
raise ValidationError({"__all__": ["This process is stopping"]})
if process.mode == ProcessMode.Repository:
if not process.revision:
raise ValidationError({"__all__": ["Git repository imports must have a revision set"]})
@extend_schema(
operation_id="RetryProcess",
tags=["process"],
......@@ -701,7 +695,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
@cached_property
def process(self):
process = get_object_or_404(
Process.objects.using("default").select_related("corpus", "revision__repo"),
Process.objects.using("default").select_related("corpus"),
Q(pk=self.kwargs["pk"])
)
if not self.process_access_level(process):
......@@ -986,7 +980,7 @@ class WorkerTypesList(ListAPIView):
get=extend_schema(
description=(
"List versions for a given worker ID with their revision and associated git references.\n\n"
"Requires an **execution** access to the worker or its repository."
"Requires an **execution** access to the worker."
),
parameters=[
OpenApiParameter(
......@@ -1001,13 +995,7 @@ class WorkerTypesList(ListAPIView):
description=dedent("""
Create a new version for a worker.
Authentication can be done:
* Using a user authentication (via a cookie or token).
The user must have an administrator access to the worker.
* Using a ponos task authentication.
The worker must be linked to a repository.
The `revision_id` parameter must be set for workers linked to a repository only.
""")
)
)
......@@ -1030,8 +1018,7 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
raise PermissionDenied(detail="You do not have an execution access to this worker.")
if (
self.request.method not in permissions.SAFE_METHODS
# Either a task authentication or an admin access is required for creation
and not isinstance(self.request.auth, Task)
# An admin access is required for creation
and not self.has_admin_access(worker)
):
raise PermissionDenied(detail="You do not have an admin access to this worker.")
......
import shlex
from collections import defaultdict
from datetime import timedelta
from functools import wraps
from os import path
from typing import Dict, List, Sequence, Tuple
......@@ -7,10 +8,12 @@ from uuid import UUID
from django.conf import settings
from django.db.models import Prefetch, prefetch_related_objects
from django.utils import timezone
from django.utils.functional import cached_property
from rest_framework.exceptions import ValidationError
from arkindex.images.models import ImageServer
from arkindex.ponos.models import Task, task_token_default
from arkindex.ponos.models import GPU, Task, task_token_default
class ProcessBuilder(object):
......@@ -85,7 +88,6 @@ class ProcessBuilder(object):
env={},
image=None,
artifact=None,
has_docker_socket=False,
extra_files={},
requires_gpu=False,
shm_size=None,
......@@ -109,7 +111,6 @@ class ProcessBuilder(object):
image=image,
requires_gpu=requires_gpu,
shm_size=shm_size,
has_docker_socket=has_docker_socket,
extra_files=extra_files,
image_artifact_id=artifact,
worker_run=worker_run,
......@@ -162,6 +163,10 @@ class ProcessBuilder(object):
env["ARKINDEX_CORPUS_ID"] = str(self.process.corpus_id)
return env
@cached_property
def active_gpu_agents(self) -> bool:
    """
    Whether at least one GPU is attached to an agent
    whose last ping was received less than 30 seconds ago
    """
    ping_threshold = timezone.now() - timedelta(seconds=30)
    recently_seen_gpus = GPU.objects.filter(agent__last_ping__gt=ping_threshold)
    return recently_seen_gpus.exists()
@prefetch_worker_runs
def validate_gpu_requirement(self):
from arkindex.process.models import FeatureUsage
......@@ -184,10 +189,6 @@ class ProcessBuilder(object):
):
raise ValidationError("Some model versions are on archived models and cannot be executed.")
def validate_repository(self) -> None:
if self.process.revision is None:
raise ValidationError("A revision is required to create an import workflow from GitLab repository")
def validate_s3(self) -> None:
if not self.process.bucket_name:
raise ValidationError("Missing S3 bucket name")
......@@ -244,13 +245,6 @@ class ProcessBuilder(object):
)
self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)])
def build_repository(self):
self._build_task(
command=f"python -m arkindex_tasks.import_git {self.process.revision.id}",
slug="import_git",
env=self.base_env,
)
def build_iiif(self):
from arkindex.process.models import WorkerVersion
......@@ -298,6 +292,7 @@ class ProcessBuilder(object):
chunk=index if len(chunks) > 1 else None,
workflow_runs=worker_runs,
run=self.run,
active_gpu_agents=self.active_gpu_agents,
)
self.tasks.append(task)
self.tasks_parents[task.slug].extend(parent_slugs)
......
......@@ -3,12 +3,18 @@
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
from enumfields import Enum, fields
import arkindex.process.models
import pgtrigger.compiler
import pgtrigger.migrations
# Migration-local enum that only declares the "repository" and "local" process modes.
# The `mode` field of `process` is temporarily altered to use it while the
# `check_process_corpus` constraint is added, then altered back to
# `arkindex.process.models.ProcessMode`.
# NOTE(review): presumably required because `ProcessMode` no longer declares `Repository` — confirm.
class TmpProcessMode(Enum):
Repository = "repository"
Local = "local"
class Migration(migrations.Migration):
initial = True
......@@ -189,15 +195,25 @@ class Migration(migrations.Migration):
),
migrations.AddConstraint(
model_name="process",
constraint=models.CheckConstraint(check=models.Q(models.Q(("mode", arkindex.process.models.ProcessMode["Local"]), _negated=True), ("workflow", None), _connector="OR"), name="local_process_no_workflow", violation_error_message="Local processes cannot be started."),
constraint=models.UniqueConstraint(models.F("creator"), condition=models.Q(("mode", arkindex.process.models.ProcessMode["Local"])), name="unique_local_process", violation_error_message="Only one local process is allowed per user."),
),
migrations.AddConstraint(
model_name="process",
constraint=models.CheckConstraint(check=models.Q(("mode__in", (arkindex.process.models.ProcessMode["Local"], arkindex.process.models.ProcessMode["Repository"])), models.Q(("corpus", None), _negated=True), _connector="XOR"), name="check_process_corpus", violation_error_message="Local and repository processes cannot have a corpus, and other modes must have one set."),
constraint=models.CheckConstraint(check=models.Q(models.Q(("mode", arkindex.process.models.ProcessMode["Local"]), _negated=True), ("workflow", None), _connector="OR"), name="local_process_no_workflow", violation_error_message="Local processes cannot be started."),
),
migrations.AlterField(
model_name="process",
name="mode",
field=fields.EnumField(enum=TmpProcessMode, max_length=30)
),
migrations.AddConstraint(
model_name="process",
constraint=models.UniqueConstraint(models.F("creator"), condition=models.Q(("mode", arkindex.process.models.ProcessMode["Local"])), name="unique_local_process", violation_error_message="Only one local process is allowed per user."),
constraint=models.CheckConstraint(check=models.Q(("mode__in", (TmpProcessMode["Local"], TmpProcessMode["Repository"])), models.Q(("corpus", None), _negated=True), _connector="XOR"), name="check_process_corpus", violation_error_message="Local and repository processes cannot have a corpus, and other modes must have one set."),
),
migrations.AlterField(
model_name="process",
name="mode",
field=fields.EnumField(enum=arkindex.process.models.ProcessMode, max_length=30)
),
migrations.AlterUniqueTogether(
name="gitref",
......
......@@ -124,7 +124,7 @@ class Migration(migrations.Migration):
SET workflow_id = process_id
""",
],
# manage.py squashmigrations is allowed to remove this data migration
# `arkindex squashmigrations` is allowed to remove this data migration
elidable=True,
),
]