Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (6)
Showing
with 331 additions and 318 deletions
...@@ -12,8 +12,10 @@ repos: ...@@ -12,8 +12,10 @@ repos:
- 'flake8-debugger==3.1.0' - 'flake8-debugger==3.1.0'
- 'flake8-quotes==3.3.2' - 'flake8-quotes==3.3.2'
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11 # Ruff version.
rev: v0.3.7
hooks: hooks:
# Run the linter.
- id: ruff - id: ruff
args: [--fix] args: [--fix]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
......
...@@ -31,7 +31,7 @@ class EntityTypeSerializer(serializers.ModelSerializer): ...@@ -31,7 +31,7 @@ class EntityTypeSerializer(serializers.ModelSerializer):
# Get an error if a request tries to change an entity type's corpus # Get an error if a request tries to change an entity type's corpus
corpus = data.get("corpus") corpus = data.get("corpus")
if self.instance and corpus: if self.instance and corpus:
raise ValidationError({"corpus": ["It is not possible to update an Entity Type\'s corpus."]}) raise ValidationError({"corpus": ["It is not possible to update an Entity Type's corpus."]})
data = super().to_internal_value(data) data = super().to_internal_value(data)
return data return data
......
...@@ -74,7 +74,6 @@ class TestCorpus(FixtureAPITestCase): ...@@ -74,7 +74,6 @@ class TestCorpus(FixtureAPITestCase):
mock_now.return_value = FAKE_NOW mock_now.return_value = FAKE_NOW
cls.corpus_hidden = Corpus.objects.create(name="C Hidden") cls.corpus_hidden = Corpus.objects.create(name="C Hidden")
@expectedFailure
def test_anon(self): def test_anon(self):
# An anonymous user has only access to public # An anonymous user has only access to public
with self.assertNumQueries(4): with self.assertNumQueries(4):
...@@ -225,7 +224,6 @@ class TestCorpus(FixtureAPITestCase): ...@@ -225,7 +224,6 @@ class TestCorpus(FixtureAPITestCase):
self.assertEqual(len(data), 13) self.assertEqual(len(data), 13)
self.assertSetEqual({corpus["top_level_type"] for corpus in data}, {None, "top_level"}) self.assertSetEqual({corpus["top_level_type"] for corpus in data}, {None, "top_level"})
@expectedFailure
def test_mixin(self): def test_mixin(self):
vol1 = Element.objects.get(name="Volume 1") vol1 = Element.objects.get(name="Volume 1")
vol2 = Element.objects.get(name="Volume 2") vol2 = Element.objects.get(name="Volume 2")
...@@ -345,7 +343,7 @@ class TestCorpus(FixtureAPITestCase): ...@@ -345,7 +343,7 @@ class TestCorpus(FixtureAPITestCase):
"description": self.corpus_public.description, "description": self.corpus_public.description,
"public": True, "public": True,
"indexable": False, "indexable": False,
"rights": ["read", "write", "admin"], "rights": ["read"],
"created": DB_CREATED, "created": DB_CREATED,
"authorized_users": 1, "authorized_users": 1,
"top_level_type": None, "top_level_type": None,
......
...@@ -43,7 +43,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -43,7 +43,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True, "public": True,
}, },
"thumbnail_url": self.vol.thumbnail.s3_url, "thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url, "thumbnail_put_url": None,
"worker_version": None, "worker_version": None,
"confidence": None, "confidence": None,
"zone": None, "zone": None,
...@@ -51,7 +51,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -51,7 +51,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"mirrored": False, "mirrored": False,
"created": "2020-02-02T01:23:45.678000Z", "created": "2020-02-02T01:23:45.678000Z",
"creator": None, "creator": None,
"rights": ["read", "write", "admin"], "rights": ["read"],
"metadata_count": 0, "metadata_count": 0,
"classifications": [ "classifications": [
{ {
...@@ -102,6 +102,8 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -102,6 +102,8 @@ class TestRetrieveElements(FixtureAPITestCase):
""" """
Check getting an element only gives a thumbnail URL with folders Check getting an element only gives a thumbnail URL with folders
""" """
self.client.force_login(self.user)
self.assertTrue(self.vol.type.folder) self.assertTrue(self.vol.type.folder)
response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) response = self.client.get(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -230,7 +232,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -230,7 +232,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True, "public": True,
}, },
"thumbnail_url": self.vol.thumbnail.s3_url, "thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url, "thumbnail_put_url": None,
"worker_version": str(self.worker_version.id), "worker_version": str(self.worker_version.id),
"confidence": None, "confidence": None,
"zone": None, "zone": None,
...@@ -238,7 +240,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -238,7 +240,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"mirrored": False, "mirrored": False,
"created": "2020-02-02T01:23:45.678000Z", "created": "2020-02-02T01:23:45.678000Z",
"creator": None, "creator": None,
"rights": ["read", "write", "admin"], "rights": ["read"],
"metadata_count": 0, "metadata_count": 0,
"classifications": [], "classifications": [],
"worker_run": { "worker_run": {
...@@ -265,7 +267,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -265,7 +267,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"public": True, "public": True,
}, },
"thumbnail_url": self.vol.thumbnail.s3_url, "thumbnail_url": self.vol.thumbnail.s3_url,
"thumbnail_put_url": self.vol.thumbnail.s3_put_url, "thumbnail_put_url": None,
"worker_version": None, "worker_version": None,
"confidence": None, "confidence": None,
"zone": None, "zone": None,
...@@ -273,7 +275,7 @@ class TestRetrieveElements(FixtureAPITestCase): ...@@ -273,7 +275,7 @@ class TestRetrieveElements(FixtureAPITestCase):
"mirrored": False, "mirrored": False,
"created": "2020-02-02T01:23:45.678000Z", "created": "2020-02-02T01:23:45.678000Z",
"creator": None, "creator": None,
"rights": ["read", "write", "admin"], "rights": ["read"],
"metadata_count": 0, "metadata_count": 0,
"classifications": [ "classifications": [
{ {
......
...@@ -2014,8 +2014,15 @@ class CreateProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView): ...@@ -2014,8 +2014,15 @@ class CreateProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView):
serializer_class = CreateProcessTemplateSerializer serializer_class = CreateProcessTemplateSerializer
def get_queryset(self): def get_queryset(self):
return Process.objects \ return (
.prefetch_related(Prefetch("worker_runs", queryset=WorkerRun.objects.select_related("version__worker__type"))) Process
.objects
.prefetch_related(Prefetch("worker_runs", queryset=WorkerRun.objects.select_related(
"version__worker__type",
"version__worker__repository",
)))
.select_related("corpus")
)
def check_object_permissions(self, request, process): def check_object_permissions(self, request, process):
access_level = self.process_access_level(process) access_level = self.process_access_level(process)
...@@ -2077,9 +2084,16 @@ class ApplyProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView): ...@@ -2077,9 +2084,16 @@ class ApplyProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView):
serializer_class = ApplyProcessTemplateSerializer serializer_class = ApplyProcessTemplateSerializer
def get_queryset(self): def get_queryset(self):
return Process.objects \ return (
.filter(mode=ProcessMode.Template) \ Process.objects
.prefetch_related(Prefetch("worker_runs", queryset=WorkerRun.objects.select_related("version__worker__type", "model_version__model"))) .filter(mode=ProcessMode.Template)
.prefetch_related(Prefetch("worker_runs", queryset=WorkerRun.objects.select_related(
"version__worker__type",
"version__worker__repository",
"model_version__model",
)))
.select_related("corpus")
)
def check_object_permissions(self, request, template): def check_object_permissions(self, request, template):
access_level = self.process_access_level(template) access_level = self.process_access_level(template)
......
...@@ -18,6 +18,7 @@ from arkindex.process.models import ( ...@@ -18,6 +18,7 @@ from arkindex.process.models import (
WorkerRun, WorkerRun,
WorkerVersionState, WorkerVersionState,
) )
from arkindex.process.utils import get_default_farm
from arkindex.project.mixins import ProcessACLMixin from arkindex.project.mixins import ProcessACLMixin
from arkindex.project.serializer_fields import EnumField, LinearRingField from arkindex.project.serializer_fields import EnumField, LinearRingField
from arkindex.project.validators import MaxValueValidator from arkindex.project.validators import MaxValueValidator
...@@ -26,7 +27,6 @@ from arkindex.users.models import Role ...@@ -26,7 +27,6 @@ from arkindex.users.models import Role
from arkindex.users.utils import get_max_level from arkindex.users.utils import get_max_level
ProcessFarmField = import_string(getattr(settings, "PROCESS_FARM_FIELD", None) or "arkindex.project.serializer_fields.NullField") ProcessFarmField = import_string(getattr(settings, "PROCESS_FARM_FIELD", None) or "arkindex.project.serializer_fields.NullField")
get_default_farm = import_string(getattr(settings, "GET_DEFAULT_FARM", None) or "arkindex.process.utils.get_default_farm")
class ProcessLightSerializer(serializers.ModelSerializer): class ProcessLightSerializer(serializers.ModelSerializer):
...@@ -522,7 +522,7 @@ class CreateProcessTemplateSerializer(serializers.ModelSerializer): ...@@ -522,7 +522,7 @@ class CreateProcessTemplateSerializer(serializers.ModelSerializer):
class ApplyProcessTemplateSerializer(ProcessACLMixin, serializers.Serializer): class ApplyProcessTemplateSerializer(ProcessACLMixin, serializers.Serializer):
process_id = serializers.PrimaryKeyRelatedField( process_id = serializers.PrimaryKeyRelatedField(
queryset=Process.objects.all(), queryset=Process.objects.select_related("corpus"),
source="process", source="process",
required=True, required=True,
help_text="ID of the process to apply the template to", help_text="ID of the process to apply the template to",
......
from unittest import expectedFailure
from unittest.mock import call, patch from unittest.mock import call, patch
from django.test import override_settings from django.test import override_settings
...@@ -19,8 +18,6 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -19,8 +18,6 @@ class TestCreateS3Import(FixtureTestCase):
def setUpTestData(cls): def setUpTestData(cls):
super().setUpTestData() super().setUpTestData()
cls.import_worker_version = WorkerVersion.objects.get(worker__slug="file_import") cls.import_worker_version = WorkerVersion.objects.get(worker__slug="file_import")
cls.default_farm = Farm.objects.create(name="Crypto farm")
cls.default_farm.memberships.create(user=cls.user, level=Role.Guest.value)
def test_requires_login(self): def test_requires_login(self):
with self.assertNumQueries(0): with self.assertNumQueries(0):
...@@ -138,8 +135,9 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -138,8 +135,9 @@ class TestCreateS3Import(FixtureTestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com") ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com")
element = self.corpus.elements.get(name="Volume 1") element = self.corpus.elements.get(name="Volume 1")
farm = Farm.objects.create(name="Crypto farm")
with self.assertNumQueries(22), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)): with self.assertNumQueries(23), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)):
response = self.client.post(reverse("api:s3-import-create"), { response = self.client.post(reverse("api:s3-import-create"), {
"corpus_id": str(self.corpus.id), "corpus_id": str(self.corpus.id),
"element_id": str(element.id), "element_id": str(element.id),
...@@ -147,6 +145,7 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -147,6 +145,7 @@ class TestCreateS3Import(FixtureTestCase):
"element_type": "page", "element_type": "page",
"bucket_name": "blah", "bucket_name": "blah",
"prefix": "a/b/c", "prefix": "a/b/c",
"farm_id": str(farm.id),
}) })
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.json()) self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.json())
data = response.json() data = response.json()
...@@ -161,6 +160,7 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -161,6 +160,7 @@ class TestCreateS3Import(FixtureTestCase):
self.assertEqual(process.element_type, self.corpus.types.get(slug="page")) self.assertEqual(process.element_type, self.corpus.types.get(slug="page"))
self.assertEqual(process.bucket_name, "blah") self.assertEqual(process.bucket_name, "blah")
self.assertEqual(process.prefix, "a/b/c") self.assertEqual(process.prefix, "a/b/c")
self.assertEqual(process.farm, farm)
worker_run = process.worker_runs.get() worker_run = process.worker_runs.get()
self.assertEqual(worker_run.version, self.import_worker_version) self.assertEqual(worker_run.version, self.import_worker_version)
...@@ -245,19 +245,21 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -245,19 +245,21 @@ class TestCreateS3Import(FixtureTestCase):
"INGEST_S3_SECRET_KEY": "its-secret-i-wont-tell-you", "INGEST_S3_SECRET_KEY": "its-secret-i-wont-tell-you",
}) })
@expectedFailure
@override_settings(INGEST_IMAGESERVER_ID=999) @override_settings(INGEST_IMAGESERVER_ID=999)
@patch("arkindex.users.utils.get_max_level", return_value=None) @patch("arkindex.ponos.models.Farm.is_available", return_value=False)
def test_farm_guest(self, get_max_level_mock): def test_farm_guest(self, is_available_mock):
self.user.user_scopes.create(scope=Scope.S3Ingest) self.user.user_scopes.create(scope=Scope.S3Ingest)
self.client.force_login(self.user) self.client.force_login(self.user)
self.corpus.types.create(slug="folder", display_name="Folder", folder=True) self.corpus.types.create(slug="folder", display_name="Folder", folder=True)
ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com") ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com")
farm = Farm.objects.create(name="Crypto farm")
with self.assertNumQueries(5), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)): with self.assertNumQueries(5), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)):
response = self.client.post(reverse("api:s3-import-create"), { response = self.client.post(reverse("api:s3-import-create"), {
"corpus_id": str(self.corpus.id), "corpus_id": str(self.corpus.id),
"bucket_name": "blah", "bucket_name": "blah",
"farm_id": str(farm.id),
}) })
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
...@@ -266,20 +268,43 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -266,20 +268,43 @@ class TestCreateS3Import(FixtureTestCase):
}) })
self.assertFalse(Process.objects.filter(mode=ProcessMode.S3).exists()) self.assertFalse(Process.objects.filter(mode=ProcessMode.S3).exists())
self.assertEqual(get_max_level_mock.call_count, 1) self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(is_available_mock.call_args, call(self.user))
@expectedFailure
@override_settings(INGEST_IMAGESERVER_ID=999) @override_settings(INGEST_IMAGESERVER_ID=999)
@patch("arkindex.users.utils.get_max_level", return_value=None) @patch("arkindex.process.serializers.ingest.get_default_farm")
def test_default_farm_guest(self, get_max_level_mock): def test_default_farm(self, get_default_farm_mock):
self.user.user_scopes.create(scope=Scope.S3Ingest) self.user.user_scopes.create(scope=Scope.S3Ingest)
self.client.force_login(self.user) self.client.force_login(self.user)
self.corpus.types.create(slug="folder", display_name="Folder", folder=True) self.corpus.types.create(slug="folder", display_name="Folder", folder=True)
ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com") ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com")
self.default_farm.memberships.filter(user=self.user).delete() default_farm = Farm.objects.create(name="Crypto farm")
get_default_farm_mock.return_value = default_farm
with self.assertNumQueries(5), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)): with self.assertNumQueries(21), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)):
response = self.client.post(reverse("api:s3-import-create"), {
"corpus_id": str(self.corpus.id),
"bucket_name": "blah",
})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
process = Process.objects.get(id=response.json()["id"])
self.assertEqual(process.farm, default_farm)
@override_settings(INGEST_IMAGESERVER_ID=999)
@patch("arkindex.ponos.models.Farm.is_available", return_value=False)
@patch("arkindex.process.serializers.ingest.get_default_farm")
def test_default_farm_guest(self, get_default_farm_mock, is_available_mock):
self.user.user_scopes.create(scope=Scope.S3Ingest)
self.client.force_login(self.user)
self.corpus.types.create(slug="folder", display_name="Folder", folder=True)
ImageServer.objects.create(id=999, display_name="Ingest image server", url="https://dev.null.teklia.com")
default_farm = Farm.objects.create(name="Crypto farm")
get_default_farm_mock.return_value = default_farm
with self.assertNumQueries(4), self.settings(IMPORTS_WORKER_VERSION=str(self.import_worker_version.id)):
response = self.client.post(reverse("api:s3-import-create"), { response = self.client.post(reverse("api:s3-import-create"), {
"corpus_id": str(self.corpus.id), "corpus_id": str(self.corpus.id),
"bucket_name": "blah", "bucket_name": "blah",
...@@ -291,5 +316,5 @@ class TestCreateS3Import(FixtureTestCase): ...@@ -291,5 +316,5 @@ class TestCreateS3Import(FixtureTestCase):
}) })
self.assertFalse(Process.objects.filter(mode=ProcessMode.S3).exists()) self.assertFalse(Process.objects.filter(mode=ProcessMode.S3).exists())
self.assertEqual(get_max_level_mock.call_count, 1) self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.default_farm)) self.assertEqual(is_available_mock.call_args, call(self.user))
import uuid import uuid
from unittest import expectedFailure
from unittest.mock import call, patch from unittest.mock import call, patch
from django.conf import settings from django.conf import settings
...@@ -753,16 +752,35 @@ class TestProcesses(FixtureAPITestCase): ...@@ -753,16 +752,35 @@ class TestProcesses(FixtureAPITestCase):
{"__all__": ["Please wait activities to be initialized before deleting this process"]} {"__all__": ["Please wait activities to be initialized before deleting this process"]}
) )
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=None)
def test_delete_process_no_permission(self): def test_delete_process_no_permission(self, get_max_level_mock):
""" """
A user cannot delete a process linked to a corpus they have no admin access to A user cannot delete a process linked to a corpus they have no access to
""" """
self.client.force_login(self.user) self.client.force_login(self.user)
self.assertFalse(self.user_img_process.corpus.memberships.filter(user=self.user).exists())
response = self.client.delete(reverse("api:process-details", kwargs={"pk": self.user_img_process.id})) with self.assertNumQueries(4):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) response = self.client.delete(reverse("api:process-details", kwargs={"pk": self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {"detail": "Not found."}) self.assertDictEqual(response.json(), {"detail": "Not found."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.private_corpus))
@patch("arkindex.project.mixins.get_max_level", return_value=Role.Contributor.value)
def test_delete_process_not_admin(self, get_max_level_mock):
"""
A user cannot delete a process linked to a corpus they have no admin access to
"""
self.client.force_login(self.user)
with self.assertNumQueries(4):
response = self.client.delete(reverse("api:process-details", kwargs={"pk": self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.private_corpus))
@patch("arkindex.project.triggers.process_tasks.process_delete.delay") @patch("arkindex.project.triggers.process_tasks.process_delete.delay")
def test_delete_worker_run_in_use(self, process_delay_mock): def test_delete_worker_run_in_use(self, process_delay_mock):
...@@ -944,16 +962,20 @@ class TestProcesses(FixtureAPITestCase): ...@@ -944,16 +962,20 @@ class TestProcesses(FixtureAPITestCase):
process.refresh_from_db() process.refresh_from_db()
self.assertEqual(process.element, element) self.assertEqual(process.element, element)
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=None)
def test_partial_update_no_permission(self): def test_partial_update_no_permission(self, get_max_level_mock):
""" """
A user cannot update a process linked to a corpus he has no admin access to A user cannot update a process linked to a corpus they do not have admin access to
""" """
self.client.force_login(self.user) self.client.force_login(self.user)
self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
response = self.client.patch(reverse("api:process-details", kwargs={"pk": self.elts_process.id})) with self.assertNumQueries(5):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) response = self.client.patch(reverse("api:process-details", kwargs={"pk": self.elts_process.id}))
self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."}) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {"detail": "Not found."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.corpus))
def test_partial_update_stop(self): def test_partial_update_stop(self):
""" """
...@@ -1374,13 +1396,17 @@ class TestProcesses(FixtureAPITestCase): ...@@ -1374,13 +1396,17 @@ class TestProcesses(FixtureAPITestCase):
process.refresh_from_db() process.refresh_from_db()
self.assertEqual(process.name, "newName") self.assertEqual(process.name, "newName")
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=Role.Guest.value)
def test_partial_update_corpus_no_write_right(self): def test_partial_update_corpus_no_write_right(self, get_max_level_mock):
self.client.force_login(self.user) self.client.force_login(self.user)
self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
response = self.client.patch(reverse("api:process-details", kwargs={"pk": self.elts_process.id})) with self.assertNumQueries(5):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) response = self.client.patch(reverse("api:process-details", kwargs={"pk": self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."}) self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.corpus))
def test_partial_update(self): def test_partial_update(self):
""" """
...@@ -1584,24 +1610,27 @@ class TestProcesses(FixtureAPITestCase): ...@@ -1584,24 +1610,27 @@ class TestProcesses(FixtureAPITestCase):
# Activity initialization runs again # Activity initialization runs again
self.assertFalse(delay_mock.called) self.assertFalse(delay_mock.called)
@expectedFailure @patch("arkindex.ponos.models.Farm.is_available", return_value=False)
def test_retry_farm_guest(self): def test_retry_farm_unavailable(self, is_available_mock):
self.elts_process.run() self.elts_process.run()
self.elts_process.tasks.all().update(state=State.Error) self.elts_process.tasks.all().update(state=State.Error)
self.elts_process.finished = timezone.now() self.elts_process.finished = timezone.now()
self.elts_process.farm = Farm.objects.get(name="Wheat farm")
self.elts_process.save() self.elts_process.save()
self.assertEqual(self.elts_process.state, State.Error) self.assertEqual(self.elts_process.state, State.Error)
self.elts_process.farm.memberships.filter(user=self.user).delete() self.elts_process.farm.memberships.filter(user=self.user).delete()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(10): with self.assertNumQueries(6):
response = self.client.post(reverse("api:process-retry", kwargs={"pk": self.elts_process.id})) response = self.client.post(reverse("api:process-retry", kwargs={"pk": self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json(), { self.assertEqual(response.json(), {
"farm": ["You do not have access to this farm."], "farm": ["You do not have access to this farm."],
}) })
self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(is_available_mock.call_args, call(self.user))
@override_settings(PUBLIC_HOSTNAME="https://darkindex.lol") @override_settings(PUBLIC_HOSTNAME="https://darkindex.lol")
def test_retry_archived_worker(self): def test_retry_archived_worker(self):
...@@ -2047,27 +2076,50 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2047,27 +2076,50 @@ class TestProcesses(FixtureAPITestCase):
process = Process.objects.get(id=data["id"]) process = Process.objects.get(id=data["id"])
self.assertEqual(process.farm, farm) self.assertEqual(process.farm, farm)
@expectedFailure @patch("arkindex.process.serializers.imports.get_default_farm")
def test_from_files_farm_guest(self): def test_from_files_default_farm(self, get_default_farm_mock):
farm = Farm.objects.get(name="Wheat farm")
get_default_farm_mock.return_value = farm
self.client.force_login(self.user)
self.assertEqual(self.version_with_model.worker_runs.count(), 0)
with (
self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)),
self.assertNumQueries(25),
):
response = self.client.post(reverse("api:files-process"), {
"files": [str(self.img_df.id)],
"folder_type": "volume",
"element_type": "page",
}, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
process = Process.objects.get(id=data["id"])
self.assertEqual(process.farm, farm)
@patch("arkindex.ponos.models.Farm.is_available", return_value=False)
def test_from_files_farm_guest(self, is_available_mock):
self.client.force_login(self.user) self.client.force_login(self.user)
self.assertEqual(self.version_with_model.worker_runs.count(), 0) self.assertEqual(self.version_with_model.worker_runs.count(), 0)
self.other_farm.memberships.filter(user=self.user).delete()
with ( with (
self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)), self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)),
self.assertNumQueries(9) self.assertNumQueries(6)
): ):
response = self.client.post(reverse("api:files-process"), { response = self.client.post(reverse("api:files-process"), {
"files": [str(self.img_df.id)], "files": [str(self.img_df.id)],
"folder_type": "volume", "folder_type": "volume",
"element_type": "page", "element_type": "page",
"farm_id": str(self.other_farm.id), "farm_id": str(Farm.objects.first().id),
}, format="json") }, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json(), { self.assertEqual(response.json(), {
"farm_id": ["You do not have access to this farm."], "farm_id": ["You do not have access to this farm."],
}) })
self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(is_available_mock.call_args, call(self.user))
def test_start_process_requires_login(self): def test_start_process_requires_login(self):
response = self.client.post( response = self.client.post(
...@@ -2442,14 +2494,37 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2442,14 +2494,37 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(workers_process.state, State.Unscheduled) self.assertEqual(workers_process.state, State.Unscheduled)
self.assertEqual(workers_process.farm_id, farm.id) self.assertEqual(workers_process.farm_id, farm.id)
@expectedFailure @patch("arkindex.process.serializers.imports.get_default_farm")
def test_start_process_default_farm_guest(self): def test_start_process_default_farm(self, get_default_farm_mock):
farm = Farm.objects.get(name="Wheat farm")
get_default_farm_mock.return_value = farm
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.assertFalse(process2.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process2.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.json()["id"], str(process2.id))
process2.refresh_from_db()
self.assertEqual(process2.state, State.Unscheduled)
self.assertEqual(process2.farm, farm)
@patch("arkindex.ponos.models.Farm.is_available", return_value=False)
@patch("arkindex.process.serializers.imports.get_default_farm")
def test_start_process_default_farm_guest(self, get_default_farm_mock, is_available_mock):
get_default_farm_mock.return_value = Farm.objects.first()
workers_process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) workers_process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
workers_process.worker_runs.create(version=self.recognizer, parents=[], configuration=None) workers_process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user) self.client.force_login(self.user)
self.default_farm.memberships.filter(user=self.user).delete()
with self.assertNumQueries(10): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(workers_process.id)}) reverse("api:process-start", kwargs={"pk": str(workers_process.id)})
) )
...@@ -2461,18 +2536,20 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2461,18 +2536,20 @@ class TestProcesses(FixtureAPITestCase):
workers_process.refresh_from_db() workers_process.refresh_from_db()
self.assertEqual(workers_process.state, State.Unscheduled) self.assertEqual(workers_process.state, State.Unscheduled)
self.assertIsNone(workers_process.farm) self.assertIsNone(workers_process.farm)
self.assertEqual(get_default_farm_mock.call_count, 1)
self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(is_available_mock.call_args, call(self.user))
@expectedFailure @patch("arkindex.ponos.models.Farm.is_available", return_value=False)
def test_start_process_farm_guest(self): def test_start_process_farm_guest(self, is_available_mock):
workers_process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) workers_process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
workers_process.worker_runs.create(version=self.recognizer, parents=[], configuration=None) workers_process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user) self.client.force_login(self.user)
self.other_farm.memberships.filter(user=self.user).delete()
with self.assertNumQueries(10): with self.assertNumQueries(7):
response = self.client.post( response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(workers_process.id)}), reverse("api:process-start", kwargs={"pk": str(workers_process.id)}),
{"farm": str(self.other_farm.id)} {"farm": str(Farm.objects.first().id)}
) )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
...@@ -2482,6 +2559,8 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2482,6 +2559,8 @@ class TestProcesses(FixtureAPITestCase):
workers_process.refresh_from_db() workers_process.refresh_from_db()
self.assertEqual(workers_process.state, State.Unscheduled) self.assertEqual(workers_process.state, State.Unscheduled)
self.assertIsNone(workers_process.farm) self.assertIsNone(workers_process.farm)
self.assertEqual(is_available_mock.call_count, 1)
self.assertEqual(is_available_mock.call_args, call(self.user))
def test_start_process_wrong_farm_id(self): def test_start_process_wrong_farm_id(self):
""" """
...@@ -2857,8 +2936,8 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2857,8 +2936,8 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"__all__": ["A process can only be cleared before getting started."]}) self.assertDictEqual(response.json(), {"__all__": ["A process can only be cleared before getting started."]})
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=Role.Contributor.value)
def test_clear_process_requires_permissions(self): def test_clear_process_requires_permissions(self, get_max_level_mock):
process = self.corpus.processes.create( process = self.corpus.processes.create(
creator=self.user, creator=self.user,
mode=ProcessMode.Workers, mode=ProcessMode.Workers,
...@@ -2872,12 +2951,14 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2872,12 +2951,14 @@ class TestProcesses(FixtureAPITestCase):
user2.verified_email = True user2.verified_email = True
user2.save() user2.save()
self.corpus.memberships.create(user=user2, level=Role.Contributor.value)
self.client.force_login(user2) self.client.force_login(user2)
with self.assertNumQueries(8): with self.assertNumQueries(4):
response = self.client.delete(reverse("api:clear-process", kwargs={"pk": str(process.id)})) response = self.client.delete(reverse("api:clear-process", kwargs={"pk": str(process.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."}) self.assertDictEqual(response.json(), {"detail": "You do not have a sufficient access level to this process."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(user2, self.corpus))
def test_select_failed_elts_forbidden_methods(self): def test_select_failed_elts_forbidden_methods(self):
self.client.force_login(self.user) self.client.force_login(self.user)
...@@ -2905,16 +2986,17 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2905,16 +2986,17 @@ class TestProcesses(FixtureAPITestCase):
response = self.client.post(reverse("api:process-select-failures", kwargs={"pk": str(uuid.uuid4())})) response = self.client.post(reverse("api:process-select-failures", kwargs={"pk": str(uuid.uuid4())}))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=None)
def test_select_failed_elts_requires_corpus_read_access(self): def test_select_failed_elts_requires_corpus_read_access(self, get_max_level_mock):
self.client.force_login(self.user) self.client.force_login(self.user)
self.elts_process.corpus.memberships.filter(user=self.user).delete()
self.elts_process.corpus.public = False with self.assertNumQueries(3):
self.elts_process.corpus.save()
with self.assertNumQueries(5):
response = self.client.post(reverse("api:process-select-failures", kwargs={"pk": str(self.elts_process.id)})) response = self.client.post(reverse("api:process-select-failures", kwargs={"pk": str(self.elts_process.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have a read access to the corpus of this process."}) self.assertDictEqual(response.json(), {"detail": "You do not have a read access to the corpus of this process."})
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.corpus))
def test_select_failed_elts_requires_workers_mode(self): def test_select_failed_elts_requires_workers_mode(self):
self.client.force_login(self.user) self.client.force_login(self.user)
......
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest import expectedFailure from unittest.mock import call, patch
from rest_framework import status from rest_framework import status
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
...@@ -82,7 +82,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -82,7 +82,7 @@ class TestTemplates(FixtureAPITestCase):
def test_create(self): def test_create(self):
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(9): with self.assertNumQueries(8):
response = self.client.post( response = self.client.post(
reverse( reverse(
"api:create-process-template", kwargs={"pk": str(self.process_template.id)} "api:create-process-template", kwargs={"pk": str(self.process_template.id)}
...@@ -132,38 +132,47 @@ class TestTemplates(FixtureAPITestCase): ...@@ -132,38 +132,47 @@ class TestTemplates(FixtureAPITestCase):
) )
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=Role.Guest.value)
def test_create_requires_contributor_access_rights_process(self): def test_create_requires_contributor_access_rights_process(self, get_max_level_mock):
new_user = User.objects.create(email="new@test.fr", verified_email=True) self.client.force_login(self.user)
self.worker_1.memberships.create(user=new_user, level=Role.Contributor.value)
self.private_corpus.memberships.create(user=new_user, level=Role.Guest.value) with self.assertNumQueries(4):
self.client.force_login(new_user) response = self.client.post(
response = self.client.post( reverse(
reverse( "api:create-process-template",
"api:create-process-template", kwargs={"pk": str(self.private_process_template.id)},
kwargs={"pk": str(self.private_process_template.id)}, )
) )
) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual( self.assertDictEqual(
response.json(), response.json(),
{"detail": "You do not have a contributor access to this process."}, {"detail": "You do not have a contributor access to this process."},
) )
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.private_corpus))
@patch("arkindex.project.mixins.has_access", return_value=False)
@patch("arkindex.project.mixins.get_max_level", return_value=Role.Contributor.value)
def test_create_requires_access_rights_all_workers(self, get_max_level_mock, has_access_mock):
self.client.force_login(self.user)
with self.assertNumQueries(4):
response = self.client.post(
reverse("api:create-process-template", kwargs={"pk": str(self.private_process_template.id)})
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@expectedFailure
def test_create_requires_access_rights_all_workers(self):
new_user = User.objects.create(email="new@test.fr", verified_email=True)
self.private_corpus.memberships.create(user=new_user, level=Role.Contributor.value)
self.worker_1.memberships.create(user=new_user, level=Role.Guest.value)
self.client.force_login(new_user)
response = self.client.post(
reverse("api:create-process-template", kwargs={"pk": str(self.private_process_template.id)})
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual( self.assertEqual(
response.json(), response.json(),
["You do not have an execution access to every worker of this process."], ["You do not have an execution access to every worker of this process."],
) )
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.private_corpus))
self.assertEqual(has_access_mock.call_args_list, [
call(self.user, self.worker_1, Role.Contributor.value, skip_public=False),
call(self.user, self.worker_1.repository, Role.Contributor.value, skip_public=False),
])
def test_create_unsupported_mode(self): def test_create_unsupported_mode(self):
self.client.force_login(self.user) self.client.force_login(self.user)
...@@ -172,7 +181,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -172,7 +181,7 @@ class TestTemplates(FixtureAPITestCase):
self.process.mode = mode self.process.mode = mode
self.process.save() self.process.save()
with self.assertNumQueries(5): with self.assertNumQueries(4):
response = self.client.post( response = self.client.post(
reverse("api:create-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:create-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -188,7 +197,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -188,7 +197,7 @@ class TestTemplates(FixtureAPITestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
local_process = self.user.processes.get(mode=ProcessMode.Local) local_process = self.user.processes.get(mode=ProcessMode.Local)
with self.assertNumQueries(5): with self.assertNumQueries(4):
response = self.client.post( response = self.client.post(
reverse("api:create-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:create-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(local_process.id)}), data=json.dumps({"process_id": str(local_process.id)}),
...@@ -200,74 +209,72 @@ class TestTemplates(FixtureAPITestCase): ...@@ -200,74 +209,72 @@ class TestTemplates(FixtureAPITestCase):
local_process.refresh_from_db() local_process.refresh_from_db()
self.assertEqual(local_process.template, None) self.assertEqual(local_process.template, None)
@expectedFailure @patch("arkindex.project.mixins.get_max_level", return_value=Role.Guest.value)
def test_apply_requires_contributor_rights_on_template(self): def test_apply_requires_contributor_rights_on_template(self, get_max_level_mock):
"""Raise 403 if the user does not have rights on template """Raise 403 if the user does not have rights on template
""" """
new_user = User.objects.create(email="new@test.fr", verified_email=True) self.client.force_login(self.user)
# rights on worker of template
self.worker_1.memberships.create(user=new_user, level=Role.Contributor.value) with self.assertNumQueries(4):
# no rights on template response = self.client.post(
self.private_corpus.memberships.create(user=new_user, level=Role.Guest.value) reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
# rights on target process data=json.dumps({"process_id": str(self.process.id)}),
self.corpus.memberships.create(user=new_user, level=Role.Contributor.value) content_type="application/json",
self.client.force_login(new_user) )
response = self.client.post( self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
data=json.dumps({"process_id": str(self.process.id)}),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual( self.assertDictEqual(
response.json(), response.json(),
{"detail": "You do not have a contributor access to this process."}, {"detail": "You do not have a contributor access to this process."},
) )
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.private_corpus))
@expectedFailure @patch("arkindex.project.mixins.get_max_level", side_effect=[Role.Contributor.value, Role.Guest.value])
def test_apply_requires_contributor_rights_on_process(self): def test_apply_requires_contributor_rights_on_process(self, get_max_level_mock):
"""Raise 403 if the user does not have rights on the target process """Raise 403 if the user does not have rights on the target process
""" """
new_user = User.objects.create(email="new@test.fr", verified_email=True) self.client.force_login(self.user)
# rights on worker of template
self.worker_1.memberships.create(user=new_user, level=Role.Contributor.value) with self.assertNumQueries(5):
# rights on template response = self.client.post(
self.private_corpus.memberships.create(user=new_user, level=Role.Contributor.value) reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
# no rights on target process data=json.dumps({"process_id": str(self.process.id)}),
self.corpus.memberships.create(user=new_user, level=Role.Guest.value) content_type="application/json",
self.client.force_login(new_user) )
response = self.client.post( self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
data=json.dumps({"process_id": str(self.process.id)}),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual( self.assertDictEqual(
response.json(), response.json(),
{"detail": "You do not have a contributor access to this process."}, {"detail": "You do not have a contributor access to this process."},
) )
self.assertEqual(get_max_level_mock.call_args_list, [
call(self.user, self.private_corpus),
call(self.user, self.corpus),
])
@expectedFailure @patch("arkindex.project.mixins.has_access", return_value=False)
def test_apply_requires_access_rights_all_workers(self): def test_apply_requires_access_rights_all_workers(self, has_access_mock):
"""Raise 403 if the user does not have rights on all workers concerned """Raise 403 if the user does not have rights on all workers concerned
""" """
new_user = User.objects.create(email="new@test.fr", verified_email=True) self.client.force_login(self.user)
# no rights on worker of template
self.worker_1.memberships.create(user=new_user, level=Role.Guest.value) with self.assertNumQueries(4):
# rights on template response = self.client.post(
self.private_corpus.memberships.create(user=new_user, level=Role.Contributor.value) reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
# rights on target process data=json.dumps({"process_id": str(self.process.id)}),
self.corpus.memberships.create(user=new_user, level=Role.Contributor.value) content_type="application/json",
self.client.force_login(new_user) )
response = self.client.post( self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
reverse("api:apply-process-template", kwargs={"pk": str(self.private_template.id)}),
data=json.dumps({"process_id": str(self.process.id)}),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual( self.assertDictEqual(
response.json(), response.json(),
{"detail": "You do not have an execution access to this worker."}, {"detail": "You do not have an execution access to this worker."},
) )
self.assertEqual(has_access_mock.call_args_list, [
call(self.user, self.worker_1, Role.Contributor.value, skip_public=False),
call(self.user, self.worker_1.repository, Role.Contributor.value, skip_public=False),
])
def test_apply_already_applied(self): def test_apply_already_applied(self):
"""Raise 400 if the process already has a template attached """Raise 400 if the process already has a template attached
...@@ -301,7 +308,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -301,7 +308,7 @@ class TestTemplates(FixtureAPITestCase):
def test_apply(self): def test_apply(self):
self.assertIsNotNone(self.version_2.docker_image_iid) self.assertIsNotNone(self.version_2.docker_image_iid)
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(11): with self.assertNumQueries(9):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -334,7 +341,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -334,7 +341,7 @@ class TestTemplates(FixtureAPITestCase):
self.version_2.save() self.version_2.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(11): with self.assertNumQueries(9):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -354,7 +361,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -354,7 +361,7 @@ class TestTemplates(FixtureAPITestCase):
parents=[], parents=[],
) )
# Apply a template that has two other worker runs # Apply a template that has two other worker runs
with self.assertNumQueries(13): with self.assertNumQueries(11):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(process.id)}), data=json.dumps({"process_id": str(process.id)}),
...@@ -382,7 +389,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -382,7 +389,7 @@ class TestTemplates(FixtureAPITestCase):
self.version_2.state = WorkerVersionState.Error self.version_2.state = WorkerVersionState.Error
self.version_2.save() self.version_2.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(7): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -397,7 +404,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -397,7 +404,7 @@ class TestTemplates(FixtureAPITestCase):
self.worker_2.archived = datetime.now(timezone.utc) self.worker_2.archived = datetime.now(timezone.utc)
self.worker_2.save() self.worker_2.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(7): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -413,7 +420,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -413,7 +420,7 @@ class TestTemplates(FixtureAPITestCase):
self.model_version.save() self.model_version.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(7): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -431,7 +438,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -431,7 +438,7 @@ class TestTemplates(FixtureAPITestCase):
self.model.save() self.model.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(7): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -450,7 +457,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -450,7 +457,7 @@ class TestTemplates(FixtureAPITestCase):
self.process.mode = mode self.process.mode = mode
self.process.save() self.process.save()
with self.assertNumQueries(7): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}), data=json.dumps({"process_id": str(self.process.id)}),
...@@ -468,7 +475,7 @@ class TestTemplates(FixtureAPITestCase): ...@@ -468,7 +475,7 @@ class TestTemplates(FixtureAPITestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
local_process = self.user.processes.get(mode=ProcessMode.Local) local_process = self.user.processes.get(mode=ProcessMode.Local)
with self.assertNumQueries(6): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(local_process.id)}), data=json.dumps({"process_id": str(local_process.id)}),
......
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest import expectedFailure
from unittest.mock import call, patch from unittest.mock import call, patch
from django.test import override_settings from django.test import override_settings
...@@ -97,6 +96,9 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -97,6 +96,9 @@ class TestWorkerRuns(FixtureAPITestCase):
ram_total=99e9, ram_total=99e9,
last_ping=datetime.now(timezone.utc), last_ping=datetime.now(timezone.utc),
) )
# Add custom attributes to make the agent usable as an authenticated user
cls.agent.is_agent = True
cls.agent.is_anonymous = False
def test_list_requires_login(self): def test_list_requires_login(self):
with self.assertNumQueries(0): with self.assertNumQueries(0):
...@@ -964,8 +966,9 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -964,8 +966,9 @@ class TestWorkerRuns(FixtureAPITestCase):
""" """
self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent) self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent)
self.client.force_login(self.user) # Agent auth is not implemented in CE
with self.assertNumQueries(7): self.client.force_authenticate(user=self.agent)
with self.assertNumQueries(5):
response = self.client.get( response = self.client.get(
reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}), reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}),
) )
...@@ -1015,7 +1018,6 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -1015,7 +1018,6 @@ class TestWorkerRuns(FixtureAPITestCase):
"summary": f"Worker Recognizer @ {str(self.version_1.id)[:6]}", "summary": f"Worker Recognizer @ {str(self.version_1.id)[:6]}",
}) })
@expectedFailure
def test_retrieve_agent_unassigned(self): def test_retrieve_agent_unassigned(self):
""" """
A Ponos agent cannot retrieve a WorkerRun on a process where it does not have any assigned tasks A Ponos agent cannot retrieve a WorkerRun on a process where it does not have any assigned tasks
...@@ -1023,8 +1025,8 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -1023,8 +1025,8 @@ class TestWorkerRuns(FixtureAPITestCase):
self.process_1.tasks.create(run=0, depth=0, slug="something", agent=None) self.process_1.tasks.create(run=0, depth=0, slug="something", agent=None)
# Agent auth is not implemented in CE # Agent auth is not implemented in CE
self.client.force_login(self.user) self.client.force_authenticate(user=self.agent)
with self.assertNumQueries(3): with self.assertNumQueries(1):
response = self.client.get( response = self.client.get(
reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}), reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}),
) )
...@@ -2007,7 +2009,6 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -2007,7 +2009,6 @@ class TestWorkerRuns(FixtureAPITestCase):
"__all__": ["A WorkerRun already exists on this process with the selected worker version, model version and configuration."], "__all__": ["A WorkerRun already exists on this process with the selected worker version, model version and configuration."],
}) })
@expectedFailure
def test_update_agent(self): def test_update_agent(self):
""" """
Ponos agents cannot update WorkerRuns, even when they can access them Ponos agents cannot update WorkerRuns, even when they can access them
...@@ -2015,8 +2016,8 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -2015,8 +2016,8 @@ class TestWorkerRuns(FixtureAPITestCase):
self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent) self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent)
# Agent auth is not implemented in CE # Agent auth is not implemented in CE
self.client.force_login(self.user) self.client.force_authenticate(user=self.agent)
with self.assertNumQueries(4): with self.assertNumQueries(2):
response = self.client.put( response = self.client.put(
reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}), reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}),
) )
...@@ -2932,7 +2933,6 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -2932,7 +2933,6 @@ class TestWorkerRuns(FixtureAPITestCase):
"__all__": ["A WorkerRun already exists on this process with the selected worker version, model version and configuration."], "__all__": ["A WorkerRun already exists on this process with the selected worker version, model version and configuration."],
}) })
@expectedFailure
def test_partial_update_agent(self): def test_partial_update_agent(self):
""" """
Ponos agents cannot update WorkerRuns, even when they can access them Ponos agents cannot update WorkerRuns, even when they can access them
...@@ -2940,8 +2940,8 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -2940,8 +2940,8 @@ class TestWorkerRuns(FixtureAPITestCase):
self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent) self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent)
# Agent auth is not implemented in CE # Agent auth is not implemented in CE
self.client.force_login(self.user) self.client.force_authenticate(user=self.agent)
with self.assertNumQueries(4): with self.assertNumQueries(2):
response = self.client.patch( response = self.client.patch(
reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}), reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}),
) )
...@@ -3071,7 +3071,6 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -3071,7 +3071,6 @@ class TestWorkerRuns(FixtureAPITestCase):
self.assertEqual(response.json(), ["WorkerRuns cannot be deleted from a process that has already started."]) self.assertEqual(response.json(), ["WorkerRuns cannot be deleted from a process that has already started."])
@expectedFailure
def test_delete_agent(self): def test_delete_agent(self):
""" """
Ponos agents cannot delete WorkerRuns, even when they can access them Ponos agents cannot delete WorkerRuns, even when they can access them
...@@ -3079,8 +3078,8 @@ class TestWorkerRuns(FixtureAPITestCase): ...@@ -3079,8 +3078,8 @@ class TestWorkerRuns(FixtureAPITestCase):
self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent) self.process_1.tasks.create(run=0, depth=0, slug="something", agent=self.agent)
# Agent auth is not implemented in CE # Agent auth is not implemented in CE
self.client.force_login(self.user) self.client.force_authenticate(user=self.agent)
with self.assertNumQueries(3): with self.assertNumQueries(1):
response = self.client.delete( response = self.client.delete(
reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}), reverse("api:worker-run-details", kwargs={"pk": str(self.run_1.id)}),
) )
......
import json import json
from hashlib import md5 from hashlib import md5
from django.conf import settings
from django.db.models import CharField, Value from django.db.models import CharField, Value
from django.db.models.functions import Cast, Concat, NullIf from django.db.models.functions import Cast, Concat, NullIf
from django.utils.module_loading import import_string
from arkindex.project.tools import RTrimChr from arkindex.project.tools import RTrimChr
__default_farm = None get_default_farm = (
import_string(settings.GET_DEFAULT_FARM)
if getattr(settings, "GET_DEFAULT_FARM", None)
def get_default_farm(): else lambda: None
return None )
def hash_object(object): def hash_object(object):
......
...@@ -131,10 +131,6 @@ def get_settings_parser(base_dir): ...@@ -131,10 +131,6 @@ def get_settings_parser(base_dir):
doorbell_parser.add_option("id", type=str, default=None) doorbell_parser.add_option("id", type=str, default=None)
doorbell_parser.add_option("appkey", type=str, default=None) doorbell_parser.add_option("appkey", type=str, default=None)
gitlab_parser = parser.add_subparser("gitlab", default={})
gitlab_parser.add_option("app_id", type=str, default=None)
gitlab_parser.add_option("app_secret", type=str, default=None)
redis_parser = parser.add_subparser("redis", default={}) redis_parser = parser.add_subparser("redis", default={})
redis_parser.add_option("host", type=str, default="localhost") redis_parser.add_option("host", type=str, default="localhost")
redis_parser.add_option("port", type=int, default=6379) redis_parser.add_option("port", type=int, default=6379)
......
...@@ -8,8 +8,7 @@ from rest_framework.exceptions import APIException, PermissionDenied ...@@ -8,8 +8,7 @@ from rest_framework.exceptions import APIException, PermissionDenied
from rest_framework.serializers import CharField, Serializer from rest_framework.serializers import CharField, Serializer
from arkindex.documents.models import Corpus from arkindex.documents.models import Corpus
from arkindex.process.models import Process, Repository, Worker from arkindex.process.models import Process, Repository
from arkindex.training.models import Model
from arkindex.users.models import Role from arkindex.users.models import Role
from arkindex.users.utils import filter_rights, get_max_level, has_access from arkindex.users.utils import filter_rights, get_max_level, has_access
...@@ -62,25 +61,6 @@ class WorkerACLMixin(ACLMixin): ...@@ -62,25 +61,6 @@ class WorkerACLMixin(ACLMixin):
A public worker is considered as executable (i.e. any user has a contributor access) A public worker is considered as executable (i.e. any user has a contributor access)
""" """
@property
def executable_workers(self):
return Worker.objects.filter(
Q(public=True)
| Q(id__in=filter_rights(self.user, Worker, Role.Contributor.value).values("id"))
| Q(repository_id__in=filter_rights(self.user, Repository, Role.Contributor.value).values("id"))
).distinct()
def get_max_level(self, worker):
# Access right on a worker can be defined by a right on its repository
worker_level = get_max_level(self.user, worker)
if not worker.repository:
return worker_level
repo_level = get_max_level(self.user, worker.repository)
return max(
filter(None, (worker_level, repo_level)),
default=None
)
def has_worker_access(self, worker, level): def has_worker_access(self, worker, level):
if worker.public and level <= Role.Contributor.value: if worker.public and level <= Role.Contributor.value:
return True return True
...@@ -113,18 +93,6 @@ class CorpusACLMixin(ACLMixin): ...@@ -113,18 +93,6 @@ class CorpusACLMixin(ACLMixin):
raise PermissionDenied(detail=f"You do not have {str(role).lower()} access to this corpus.") raise PermissionDenied(detail=f"You do not have {str(role).lower()} access to this corpus.")
return corpus return corpus
@property
def readable_corpora(self):
return Corpus.objects.filter(
id__in=filter_rights(self.user, Corpus, Role.Guest.value).values("id")
)
@property
def writable_corpora(self):
return Corpus.objects.filter(
id__in=filter_rights(self.user, Corpus, Role.Contributor.value).values("id")
)
def has_read_access(self, corpus): def has_read_access(self, corpus):
return self.has_access(corpus, Role.Guest.value) return self.has_access(corpus, Role.Guest.value)
...@@ -140,18 +108,6 @@ class TrainingModelMixin(ACLMixin): ...@@ -140,18 +108,6 @@ class TrainingModelMixin(ACLMixin):
Access control mixin for machine learning models Access control mixin for machine learning models
""" """
@property
def readable_models(self):
return Model.objects.filter(
id__in=filter_rights(self.user, Model, Role.Guest.value).values("id")
)
@property
def editable_models(self):
return Model.objects.filter(
id__in=filter_rights(self.user, Model, Role.Contributor.value).values("id")
)
def has_read_access(self, model): def has_read_access(self, model):
return self.has_access(model, Role.Guest.value) return self.has_access(model, Role.Guest.value)
......
...@@ -116,11 +116,6 @@ class FixtureMixin(object): ...@@ -116,11 +116,6 @@ class FixtureMixin(object):
# Clean content type cache for SQL requests checks consistency # Clean content type cache for SQL requests checks consistency
ContentType.objects.clear_cache() ContentType.objects.clear_cache()
# Force clean the default farm global variable in `arkindex.process.utils` module
# This is required not to alter query counts and avoid caching a farm that does not exist in the fixture
from arkindex.process import utils
setattr(utils, "__default_farm", None)
# Clear the local cached properties so that it is re-fetched on each test # Clear the local cached properties so that it is re-fetched on each test
# to avoid intermittently changing query counts. # to avoid intermittently changing query counts.
# Using `del` on a cached property that has not been accessed yet can cause an AttributeError. # Using `del` on a cached property that has not been accessed yet can cause an AttributeError.
......
...@@ -39,9 +39,6 @@ features: ...@@ -39,9 +39,6 @@ features:
search: false search: false
selection: true selection: true
signup: true signup: true
gitlab:
app_id: null
app_secret: null
imports_worker_version: null imports_worker_version: null
ingest: ingest:
access_key_id: null access_key_id: null
......
...@@ -26,9 +26,6 @@ export: ...@@ -26,9 +26,6 @@ export:
ttl: forever ttl: forever
features: features:
sv_cheats: 1 sv_cheats: 1
gitlab:
app_id: yes
app_secret: []
ingest: ingest:
endpoint: https://ohno endpoint: https://ohno
access_key_id: a access_key_id: a
......
...@@ -51,9 +51,6 @@ features: ...@@ -51,9 +51,6 @@ features:
search: true search: true
selection: false selection: false
signup: false signup: false
gitlab:
app_id: a
app_secret: b
imports_worker_version: aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa imports_worker_version: aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa
ingest: ingest:
access_key_id: abcd access_key_id: abcd
......
...@@ -166,58 +166,6 @@ class TestACLMixin(FixtureTestCase): ...@@ -166,58 +166,6 @@ class TestACLMixin(FixtureTestCase):
self.assertEqual(admin_access, access_check) self.assertEqual(admin_access, access_check)
ContentType.objects.clear_cache() ContentType.objects.clear_cache()
@expectedFailure
def test_corpus_acl_mixin_writable(self):
corpus_acl_mixin = CorpusACLMixin(user=self.user1)
with self.assertNumQueries(1):
corpora = list(corpus_acl_mixin.writable_corpora)
self.assertCountEqual(
list(corpora),
[self.corpus1]
)
def test_corpus_readable_orderable(self):
# Assert corpora retrieved via the mixin are still orderable
corpus_acl_mixin = CorpusACLMixin(user=self.user3)
with self.assertNumQueries(1):
corpora = list(corpus_acl_mixin.readable_corpora.order_by("name"))
self.assertListEqual(
[c.name for c in corpora],
["Corpus1", "Corpus2", "Unit Tests"]
)
def test_super_admin_readable_corpora(self):
# A super admin should retrieve all existing corpora with Admin rights
corpus_acl_mixin = CorpusACLMixin(user=self.superuser)
with self.assertNumQueries(1):
corpora = list(corpus_acl_mixin.readable_corpora)
self.assertCountEqual(
list(corpora),
list(Corpus.objects.all())
)
@expectedFailure
def test_anonymous_user_readable_corpora(self):
# An anonymous user should have guest access to any public corpora
corpus_acl_mixin = CorpusACLMixin(user=AnonymousUser())
with self.assertNumQueries(1):
corpora = list(corpus_acl_mixin.readable_corpora)
self.assertCountEqual(
list(corpora),
list(Corpus.objects.filter(public=True))
)
def test_corpus_right_and_public(self):
# User specific rights should be returned instead of the the defaults access for public rights
Right.objects.create(user=self.user3, content_object=self.corpus, level=42)
corpus_acl_mixin = CorpusACLMixin(user=self.user3)
with self.assertNumQueries(1):
corpora = list(corpus_acl_mixin.readable_corpora)
self.assertCountEqual(
list(corpora),
[self.corpus1, self.corpus2, self.corpus]
)
@expectedFailure @expectedFailure
def test_max_level_does_not_exists(self): def test_max_level_does_not_exists(self):
with self.assertNumQueries(3): with self.assertNumQueries(3):
...@@ -306,22 +254,3 @@ class TestACLMixin(FixtureTestCase): ...@@ -306,22 +254,3 @@ class TestACLMixin(FixtureTestCase):
with self.assertNumQueries(0): with self.assertNumQueries(0):
admin_access = TrainingModelMixin(user=self.user5).has_admin_access(self.model2) admin_access = TrainingModelMixin(user=self.user5).has_admin_access(self.model2)
self.assertTrue(admin_access) self.assertTrue(admin_access)
@expectedFailure
def test_models_readable(self):
"""
To view a model, a user needs guest access.
"""
with self.assertNumQueries(1):
readable_models = list(TrainingModelMixin(user=self.user4).readable_models)
self.assertListEqual(readable_models, [self.model1])
@expectedFailure
def test_models_editable(self):
"""
To edit a model, a user needs contributor access.
User5 only has that access on model2.
"""
with self.assertNumQueries(1):
editable_models = list(TrainingModelMixin(user=self.user5).editable_models)
self.assertListEqual(editable_models, [self.model2])
...@@ -10,6 +10,8 @@ def has_access(user: User, instance, level: int, skip_public: bool = False) -> b ...@@ -10,6 +10,8 @@ def has_access(user: User, instance, level: int, skip_public: bool = False) -> b
Check if the user has access to a generic instance with a minimum level Check if the user has access to a generic instance with a minimum level
If skip_public parameter is set to true, exclude rights on public instances If skip_public parameter is set to true, exclude rights on public instances
""" """
if user.is_anonymous:
return level <= Role.Guest.value and not skip_public and getattr(instance, "public", False)
return True return True
...@@ -18,6 +20,11 @@ def filter_rights(user: User, model, level: int): ...@@ -18,6 +20,11 @@ def filter_rights(user: User, model, level: int):
Return a generic queryset of objects with access rights for this user. Return a generic queryset of objects with access rights for this user.
Level filtering parameter should be an integer between 1 and 100. Level filtering parameter should be an integer between 1 and 100.
""" """
if user.is_anonymous:
if hasattr(model, "public"):
return model.objects.filter(public=True).annotate(max_level=Value(Role.Guest.value, IntegerField()))
return model.objects.none()
return model.objects.annotate(max_level=Value(Role.Admin.value, IntegerField())) return model.objects.annotate(max_level=Value(Role.Admin.value, IntegerField()))
...@@ -25,4 +32,9 @@ def get_max_level(user: User, instance) -> Optional[int]: ...@@ -25,4 +32,9 @@ def get_max_level(user: User, instance) -> Optional[int]:
""" """
Returns the maximum access level on a given model instance Returns the maximum access level on a given model instance
""" """
if user.is_anonymous:
if getattr(instance, "public", False):
return Role.Guest.value
return None
return Role.Admin.value return Role.Admin.value
...@@ -5,5 +5,8 @@ line-length = 120 ...@@ -5,5 +5,8 @@ line-length = 120
quote-style = "double" quote-style = "double"
[lint] [lint]
select = ["Q0", "F", "W", "E"] select = ["Q0", "F", "W", "E",
# request-without-timeout
"S113",
]
ignore = ["E501"] ignore = ["E501"]