Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (3)
Showing
with 445 additions and 62 deletions
......@@ -2,7 +2,9 @@ from datetime import timedelta
from textwrap import dedent
from django.conf import settings
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.functional import cached_property
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import permissions, serializers, status
from rest_framework.exceptions import PermissionDenied, ValidationError
......@@ -11,9 +13,7 @@ from rest_framework.response import Response
from arkindex.documents.models import Corpus, CorpusExport, CorpusExportState
from arkindex.documents.serializers.export import CorpusExportSerializer
from arkindex.project.mixins import CorpusACLMixin
from arkindex.project.permissions import IsVerified
from arkindex.users.models import Role
@extend_schema(tags=["exports"])
......@@ -27,47 +27,42 @@ from arkindex.users.models import Role
),
post=extend_schema(
operation_id="StartExport",
request=None,
description=dedent(
f"""
Start a corpus export job.
A user must wait for {settings.EXPORT_TTL_SECONDS} seconds after the last successful import
before being able to generate a new export of the same corpus.
before being able to generate a new export of the same corpus from the same source.
Contributor access is required.
"""
),
)
)
class CorpusExportAPIView(CorpusACLMixin, ListCreateAPIView):
class CorpusExportAPIView(ListCreateAPIView):
permission_classes = (IsVerified, )
serializer_class = CorpusExportSerializer
queryset = CorpusExport.objects.none()
@cached_property
def corpus(self):
qs = Corpus.objects.readable(self.request.user)
corpus = get_object_or_404(qs, pk=self.kwargs["pk"])
if self.request.method not in permissions.SAFE_METHODS and not corpus.is_writable(self.request.user):
raise PermissionDenied(detail="You do not have write access to this corpus.")
return corpus
def get_queryset(self):
return CorpusExport \
.objects \
.filter(corpus=self.get_corpus(self.kwargs["pk"])) \
.filter(corpus=self.corpus) \
.select_related("user") \
.order_by("-created")
def post(self, *args, **kwargs):
corpus = self.get_corpus(self.kwargs["pk"], role=Role.Contributor)
if corpus.exports.filter(state__in=(CorpusExportState.Created, CorpusExportState.Running)).exists():
raise ValidationError("An export is already running for this corpus.")
available_exports = corpus.exports.filter(
state=CorpusExportState.Done,
created__gte=timezone.now() - timedelta(seconds=settings.EXPORT_TTL_SECONDS)
)
if available_exports.exists():
raise ValidationError(f"An export has already been made for this corpus in the last {settings.EXPORT_TTL_SECONDS} seconds.")
export = corpus.exports.create(user=self.request.user)
export.start()
return Response(CorpusExportSerializer(export).data, status=status.HTTP_201_CREATED)
def get_serializer_context(self):
context = super().get_serializer_context()
context["corpus"] = self.corpus
return context
@extend_schema(
......
......@@ -46,12 +46,12 @@ EXPORT_QUERIES = [
]
def run_pg_query(query):
def run_pg_query(query, source_db):
"""
Run a single Postgresql query and split the results into chunks.
When a name is given to a cursor, psycopg2 uses a server-side cursor; we just use a random string as a name.
"""
with connections["default"].create_cursor(name=str(uuid.uuid4())) as pg_cursor:
with connections[source_db].create_cursor(name=str(uuid.uuid4())) as pg_cursor:
pg_cursor.itersize = BATCH_SIZE
pg_cursor.execute(query)
......@@ -122,7 +122,11 @@ def export_corpus(corpus_export: CorpusExport) -> None:
corpus_export.state = CorpusExportState.Running
corpus_export.save()
logger.info(f"Exporting corpus {corpus_export.corpus_id} into {db_path}")
export_source = f"{corpus_export.corpus_id}"
if corpus_export.source != "default":
export_source += f" from source {corpus_export.source}"
logger.info(f"Exporting corpus {export_source} into {db_path}")
db = sqlite3.connect(db_path)
cursor = db.cursor()
......@@ -135,7 +139,7 @@ def export_corpus(corpus_export: CorpusExport) -> None:
if rq_job:
rq_job.set_progress(i / (len(EXPORT_QUERIES) + 1))
for chunk in run_pg_query(query.format(corpus_id=corpus_export.corpus_id)):
for chunk in run_pg_query(query.format(corpus_id=corpus_export.corpus_id), corpus_export.source):
save_sqlite(chunk, name, cursor)
db.commit()
......
# Generated by Django 4.1.7 on 2024-02-28 15:56
from django.db import migrations, models
from arkindex.project import settings
class Migration(migrations.Migration):
    """Add a ``source`` field to CorpusExport, selecting which database to export from."""

    dependencies = [
        ("documents", "0008_alter_elementtype_color_alter_entitytype_color"),
    ]

    operations = [
        migrations.AddField(
            model_name="corpusexport",
            name="source",
            field=models.CharField(
                # One choice per configured export source; stored value equals the label.
                choices=[(name, name) for name in settings.EXPORT_SOURCES],
                default="default",
                max_length=50,
            ),
        ),
    ]
......@@ -73,6 +73,18 @@ class Corpus(IndexableModel):
for values in DEFAULT_CORPUS_TYPES
)
def is_writable(self, user) -> bool:
    """
    Tell whether ``user`` has write access to this corpus.

    Anonymous users and agents never have write access; admins always do.
    Otherwise the user's highest access level on the corpus must be at
    least Contributor.
    """
    # Agents and anonymous users can never write, regardless of ACLs.
    if user.is_anonymous or getattr(user, "is_agent", False):
        return False
    # Admins bypass the ACL check entirely.
    if user.is_admin:
        return True
    # Imported here, presumably to avoid a circular import with arkindex.users.
    from arkindex.users.utils import get_max_level
    max_level = get_max_level(user, self)
    if max_level is None:
        return False
    return max_level >= Role.Contributor.value
class ElementType(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
......@@ -1185,6 +1197,7 @@ class CorpusExport(S3FileMixin, IndexableModel):
corpus = models.ForeignKey(Corpus, related_name="exports", on_delete=models.CASCADE)
user = models.ForeignKey(settings.AUTH_USER_MODEL, related_name="exports", on_delete=models.CASCADE)
state = EnumField(CorpusExportState, max_length=10, default=CorpusExportState.Created)
source = models.CharField(max_length=50, default="default", choices=[(source, source) for source in settings.EXPORT_SOURCES])
s3_bucket = settings.AWS_EXPORT_BUCKET
......
from datetime import timedelta
from django.conf import settings
from django.utils import timezone
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.documents.models import CorpusExport, CorpusExportState
from arkindex.project.serializer_fields import EnumField
......@@ -6,9 +11,38 @@ from arkindex.users.serializers import SimpleUserSerializer
class CorpusExportSerializer(serializers.ModelSerializer):
user = SimpleUserSerializer()
state = EnumField(CorpusExportState)
user = SimpleUserSerializer(read_only=True)
state = EnumField(CorpusExportState, read_only=True)
class Meta:
model = CorpusExport
fields = ("id", "created", "updated", "corpus_id", "user", "state")
fields = ("id", "created", "updated", "corpus_id", "user", "state", "source",)
def validate(self, data):
    """
    Reject the export request when one is already pending or running for the
    corpus, or when a completed export from the same source is more recent
    than the configured TTL. On success, inject the corpus and the resolved
    source into the validated data.
    """
    corpus = self.context["corpus"]
    source = data.get("source", "default")

    # Only one export may be pending or running per corpus at a time.
    pending_states = (CorpusExportState.Created, CorpusExportState.Running)
    if corpus.exports.filter(state__in=pending_states).exists():
        raise ValidationError("An export is already running for this corpus.")

    # A completed export from the same source must be older than the TTL
    # before a new one may be started for this corpus.
    ttl_cutoff = timezone.now() - timedelta(seconds=settings.EXPORT_TTL_SECONDS)
    recent_exports = corpus.exports.filter(
        state=CorpusExportState.Done,
        source=source,
        created__gte=ttl_cutoff,
    )
    if recent_exports.exists():
        raise ValidationError(f"An export has already been made for this corpus in the last {settings.EXPORT_TTL_SECONDS} seconds.")

    data["corpus"] = corpus
    data["source"] = source
    return data
def create(self, validated_data):
    """
    Create the CorpusExport for the requesting user and immediately
    enqueue the export job.
    """
    request = self.context["request"]
    export = CorpusExport.objects.create(
        corpus=validated_data["corpus"],
        source=validated_data["source"],
        user=request.user,
    )
    # Triggers the asynchronous export task for the new record.
    export.start()
    return export
......@@ -31,6 +31,7 @@ class TestExport(FixtureAPITestCase):
},
"corpus_id": str(self.corpus.id),
"state": CorpusExportState.Created.value,
"source": "default"
})
self.assertEqual(delay_mock.call_count, 1)
......@@ -61,19 +62,62 @@ class TestExport(FixtureAPITestCase):
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
@patch("arkindex.project.mixins.has_access", return_value=False)
def test_start_requires_contributor(self, has_access_mock, delay_mock):
@patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
def test_start_requires_contributor(self, max_level_mock, delay_mock):
self.user.rights.update(level=Role.Guest.value)
self.client.force_login(self.user)
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(has_access_mock.call_count, 1)
self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Contributor.value, skip_public=False))
self.assertEqual(max_level_mock.call_count, 1)
self.assertEqual(max_level_mock.call_args, call(self.user, self.corpus))
self.assertFalse(self.corpus.exports.exists())
self.assertFalse(delay_mock.called)
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_bad_source(self, delay_mock):
self.client.force_login(self.superuser)
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}), {"source": "jouvence"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"source": ['"jouvence" is not a valid choice.']})
self.assertEqual(self.corpus.exports.count(), 0)
self.assertFalse(delay_mock.called)
@patch("arkindex.documents.models.CorpusExport.source")
@patch("arkindex.project.triggers.export.export_corpus.delay")
@override_settings(EXPORT_TTL_SECONDS=420)
def test_start_with_source(self, delay_mock, source_field_mock):
source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
self.client.force_login(self.superuser)
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}), {"source": "jouvence"})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
export = self.corpus.exports.get()
self.assertDictEqual(response.json(), {
"id": str(export.id),
"created": export.created.isoformat().replace("+00:00", "Z"),
"updated": export.updated.isoformat().replace("+00:00", "Z"),
"user": {
"id": self.superuser.id,
"display_name": self.superuser.display_name,
"email": self.superuser.email,
},
"corpus_id": str(self.corpus.id),
"state": CorpusExportState.Created.value,
"source": "jouvence"
})
self.assertEqual(delay_mock.call_count, 1)
self.assertEqual(delay_mock.call_args, call(
corpus_export=export,
user_id=self.superuser.id,
description="Export of corpus Unit Tests from source jouvence"
))
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_running(self, delay_mock):
self.client.force_login(self.superuser)
......@@ -81,7 +125,9 @@ class TestExport(FixtureAPITestCase):
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(response.json(), ["An export is already running for this corpus."])
self.assertDictEqual(response.json(), {
"non_field_errors": ["An export is already running for this corpus."]
})
self.assertEqual(self.corpus.exports.count(), 1)
self.assertFalse(delay_mock.called)
......@@ -99,11 +145,53 @@ class TestExport(FixtureAPITestCase):
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(response.json(), ["An export has already been made for this corpus in the last 420 seconds."])
self.assertDictEqual(response.json(), {
"non_field_errors": ["An export has already been made for this corpus in the last 420 seconds."]
})
self.assertEqual(self.corpus.exports.count(), 1)
self.assertFalse(delay_mock.called)
@override_settings(EXPORT_TTL_SECONDS=420)
@patch("arkindex.project.triggers.export.export_corpus.delay")
def test_start_recent_export_different_source(self, delay_mock):
from arkindex.documents.models import CorpusExport
CorpusExport.source.field.choices = [("default", "default"), ("jouvence", "jouvence")]
self.client.force_login(self.superuser)
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) - timedelta(minutes=2)
self.corpus.exports.create(
user=self.user,
state=CorpusExportState.Done,
source="jouvence"
)
response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
export = self.corpus.exports.get(source="default")
self.assertDictEqual(response.json(), {
"id": str(export.id),
"created": export.created.isoformat().replace("+00:00", "Z"),
"updated": export.updated.isoformat().replace("+00:00", "Z"),
"user": {
"id": self.superuser.id,
"display_name": self.superuser.display_name,
"email": self.superuser.email,
},
"corpus_id": str(self.corpus.id),
"state": CorpusExportState.Created.value,
"source": "default"
})
self.assertEqual(delay_mock.call_count, 1)
self.assertEqual(delay_mock.call_args, call(
corpus_export=export,
user_id=self.superuser.id,
description="Export of corpus Unit Tests"
))
def test_list(self):
export1 = self.corpus.exports.create(user=self.user, state=CorpusExportState.Done)
export2 = self.corpus.exports.create(user=self.superuser)
......@@ -123,6 +211,7 @@ class TestExport(FixtureAPITestCase):
"email": self.superuser.email,
},
"corpus_id": str(self.corpus.id),
"source": "default"
},
{
"id": str(export1.id),
......@@ -135,6 +224,7 @@ class TestExport(FixtureAPITestCase):
"email": self.user.email,
},
"corpus_id": str(self.corpus.id),
"source": "default"
},
])
......@@ -149,18 +239,19 @@ class TestExport(FixtureAPITestCase):
response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@patch("arkindex.project.mixins.has_access", return_value=False)
def test_list_requires_guest(self, has_access_mock):
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_list_requires_guest(self, filter_rights_mock):
self.user.rights.all().delete()
self.corpus.public = False
self.corpus.save()
filter_rights_mock.return_value = Corpus.objects.none()
self.client.force_login(self.user)
response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(has_access_mock.call_count, 1)
self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Guest.value, skip_public=False))
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_download_export(self, presigned_url_mock):
......
import shlex
from collections import defaultdict
from datetime import timedelta
from functools import wraps
from os import path
from typing import Dict, List, Sequence, Tuple
......@@ -7,10 +8,12 @@ from uuid import UUID
from django.conf import settings
from django.db.models import Prefetch, prefetch_related_objects
from django.utils import timezone
from django.utils.functional import cached_property
from rest_framework.exceptions import ValidationError
from arkindex.images.models import ImageServer
from arkindex.ponos.models import Task, task_token_default
from arkindex.ponos.models import GPU, Task, task_token_default
class ProcessBuilder(object):
......@@ -162,6 +165,10 @@ class ProcessBuilder(object):
env["ARKINDEX_CORPUS_ID"] = str(self.process.corpus_id)
return env
@cached_property
def active_gpu_agents(self) -> bool:
    """
    Whether any GPU belongs to an agent that pinged within the last
    30 seconds — i.e. an agent considered alive. Cached for the lifetime
    of this builder instance.
    """
    cutoff = timezone.now() - timedelta(seconds=30)
    return GPU.objects.filter(agent__last_ping__gt=cutoff).exists()
@prefetch_worker_runs
def validate_gpu_requirement(self):
from arkindex.process.models import FeatureUsage
......@@ -298,6 +305,7 @@ class ProcessBuilder(object):
chunk=index if len(chunks) > 1 else None,
workflow_runs=worker_runs,
run=self.run,
active_gpu_agents=self.active_gpu_agents,
)
self.tasks.append(task)
self.tasks_parents[task.slug].extend(parent_slugs)
......
......@@ -911,7 +911,7 @@ class WorkerRun(models.Model):
# we add the WorkerRun ID at the end of the slug
return f"{self.version.worker.slug}_{str(self.id)[:6]}"
def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None):
def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None, active_gpu_agents=False):
"""
Build the Task that will represent this WorkerRun in ponos using :
- the docker image name given by the WorkerVersion
......@@ -967,6 +967,12 @@ class WorkerRun(models.Model):
assert self.model_version.state == ModelVersionState.Available, f"ModelVersion {self.model_version.id} is not available and cannot be used to build a task."
extra_files = {"model": settings.PUBLIC_HOSTNAME + reverse("api:model-version-download", kwargs={"pk": self.model_version.id}) + f"?token={self.model_version.build_authentication_token_hash()}"}
requires_gpu = process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
# Do not require a GPU if there are no active agents with GPU and the GPU feature is only supported by the worker version;
# this does not make sense in the context of RQ tasks execution
if not settings.PONOS_RQ_EXECUTION and not active_gpu_agents and self.version.gpu_usage != FeatureUsage.Required:
requires_gpu = False
task = Task(
command=self.version.docker_command,
image=self.version.docker_image_iid or self.version.docker_image_name,
......@@ -981,7 +987,7 @@ class WorkerRun(models.Model):
process=process,
worker_run=self,
extra_files=extra_files,
requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
requires_gpu=requires_gpu
)
return task, parents
......
import uuid
from collections import namedtuple
from datetime import datetime, timezone
from unittest.mock import call, patch
......@@ -7,9 +8,10 @@ from rest_framework import status
from rest_framework.reverse import reverse
from arkindex.documents.models import Corpus, Element
from arkindex.ponos.models import Farm, State
from arkindex.ponos.models import GPU, Agent, Farm, State
from arkindex.process.models import (
ActivityState,
FeatureUsage,
Process,
ProcessDataset,
ProcessMode,
......@@ -31,6 +33,21 @@ class TestCreateProcess(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.agent = Agent.objects.create(
farm=Farm.objects.first(),
hostname="claude",
cpu_cores=42,
cpu_frequency=1e15,
ram_total=99e9,
last_ping=datetime.now(timezone.utc),
)
cls.agent.gpus.create(
id=uuid.uuid4(),
name="claudette",
index=2,
ram_total=12
)
cls.volume = Element.objects.get(name="Volume 1")
cls.pages = Element.objects.get_descending(cls.volume.id).filter(type__slug="page", polygon__isnull=False)
cls.ml_class = cls.corpus.ml_classes.create(name="bretzel")
......@@ -585,7 +602,7 @@ class TestCreateProcess(FixtureAPITestCase):
self.assertFalse(self.corpus.worker_versions.exists())
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
{"worker_activity": True},
......@@ -676,7 +693,7 @@ class TestCreateProcess(FixtureAPITestCase):
)
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
{"use_cache": True},
......@@ -714,7 +731,7 @@ class TestCreateProcess(FixtureAPITestCase):
@patch("arkindex.ponos.models.base64.encodebytes")
def test_create_process_use_gpu_option(self, token_mock):
"""
A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks than need one
A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks that need one
"""
token_mock.side_effect = [b"12345", b"67891"]
process_2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
......@@ -724,7 +741,7 @@ class TestCreateProcess(FixtureAPITestCase):
)
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
{"use_gpu": True},
......@@ -755,6 +772,60 @@ class TestCreateProcess(FixtureAPITestCase):
self.assertEqual(len(worker_task.parents.all()), 1)
self.assertEqual(worker_task.parents.first(), init_task)
@override_settings(
    ARKINDEX_TASKS_IMAGE="registry.teklia.com/tasks",
    PONOS_DEFAULT_ENV={}
)
@patch("arkindex.ponos.models.base64.encodebytes")
def test_create_process_use_gpu_option_no_available_gpus(self, token_mock):
    """
    If there are no available agents with a GPU, ``requires_gpu`` is not set
    when the worker version only supports (and does not require) GPU usage;
    a version that requires a GPU still gets ``requires_gpu=True``.
    """
    # Remove every GPU so no agent can offer one.
    self.agent.gpus.all().delete()
    # One token per created task: init + worker task, for each of the two sub-tests.
    token_mock.side_effect = [b"12345", b"67891", b"54321", b"19876"]
    for feature_usage, requires_gpu, task_token in [
        (FeatureUsage.Supported, False, "67891"),
        (FeatureUsage.Required, True, "19876")
    ]:
        with self.subTest(feature_usage=feature_usage, requires_gpu=requires_gpu):
            process = self.corpus.processes.create(
                creator=self.user,
                mode=ProcessMode.Workers,
                farm=Farm.objects.first(),
            )
            run = process.worker_runs.create(
                version=self.version_3,
                parents=[],
            )
            process.use_gpu = True
            # Sanity check: no GPU is available anywhere.
            self.assertEqual(GPU.objects.count(), 0)
            self.version_3.gpu_usage = feature_usage
            self.version_3.save()
            process.run()
            init_task = process.tasks.get(slug="initialisation")
            self.assertEqual(init_task.command, f"python -m arkindex_tasks.init_elements {process.id} --chunks-number 1")
            self.assertEqual(init_task.image, "registry.teklia.com/tasks")
            worker_task = process.tasks.get(slug=run.task_slug)
            self.assertEqual(worker_task.command, None)
            self.assertEqual(worker_task.image, f"my_repo.fake/workers/worker/worker-gpu:{self.version_3.id}")
            self.assertEqual(worker_task.image_artifact.id, self.version_3.docker_image.id)
            self.assertEqual(worker_task.shm_size, None)
            self.assertEqual(worker_task.env, {
                "TASK_ELEMENTS": "/data/initialisation/elements.json",
                "ARKINDEX_CORPUS_ID": str(self.corpus.id),
                "ARKINDEX_PROCESS_ID": str(process.id),
                "ARKINDEX_WORKER_RUN_ID": str(process.worker_runs.get().id),
                "ARKINDEX_TASK_TOKEN": task_token
            })
            # The actual behavior under test: GPU requirement depends on availability and usage mode.
            self.assertEqual(worker_task.requires_gpu, requires_gpu)
            self.assertEqual(len(worker_task.parents.all()), 1)
            self.assertEqual(worker_task.parents.first(), init_task)
def test_retry_keeps_requires_gpu(self):
"""
When a process is retried, the newly created tasks keep the same requires_gpu values
......@@ -819,7 +890,7 @@ class TestCreateProcess(FixtureAPITestCase):
process.use_gpu = True
process.save()
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process.id)}),
{"use_gpu": "true"}
......@@ -907,7 +978,7 @@ class TestCreateProcess(FixtureAPITestCase):
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
process.versions.add(custom_version)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)}))
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......
......@@ -1749,7 +1749,7 @@ class TestProcesses(FixtureAPITestCase):
self.workers_process.activity_state = ActivityState.Error
self.workers_process.save()
with self.assertNumQueries(13):
with self.assertNumQueries(14):
response = self.client.post(reverse("api:process-retry", kwargs={"pk": self.workers_process.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -2126,7 +2126,7 @@ class TestProcesses(FixtureAPITestCase):
with (
self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)),
self.assertNumQueries(8)
self.assertNumQueries(9)
):
response = self.client.post(reverse("api:files-process"), {
"files": [str(self.img_df.id)],
......@@ -2217,7 +2217,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process2.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process2.id)})
)
......@@ -2362,7 +2362,7 @@ class TestProcesses(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process2.id)})
)
......@@ -2477,7 +2477,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process.id)})
)
......@@ -2499,7 +2499,7 @@ class TestProcesses(FixtureAPITestCase):
farm = Farm.objects.get(name="Wheat farm")
self.client.force_login(self.user)
with self.assertNumQueries(15):
with self.assertNumQueries(16):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(workers_process.id)}),
{"farm": str(farm.id)}
......@@ -2697,7 +2697,7 @@ class TestProcesses(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(15):
with self.assertNumQueries(16):
response = self.client.post(
reverse("api:process-start", kwargs={"pk": str(process.id)}),
{"use_cache": "true", "worker_activity": "true", "use_gpu": "true"}
......@@ -2733,7 +2733,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertNotEqual(run_1.task_slug, run_2.task_slug)
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)}))
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......
......@@ -362,6 +362,8 @@ REDIS_ZREM_CHUNK_SIZE = 10000
# How long before a corpus export can be run again after a successful one
EXPORT_TTL_SECONDS = conf["export"]["ttl"]
# Available database sources for corpus exports
EXPORT_SOURCES = ["default"]
LOGGING = {
"version": 1,
......
......@@ -189,10 +189,13 @@ def export_corpus(corpus_export: CorpusExport) -> None:
"""
Export a corpus to a SQLite database
"""
description = f"Export of corpus {corpus_export.corpus.name}"
if corpus_export.source != "default":
description += f" from source {corpus_export.source}"
export.export_corpus.delay(
corpus_export=corpus_export,
user_id=corpus_export.user_id,
description=f"Export of corpus {corpus_export.corpus.name}"
description=description
)
......
......@@ -307,10 +307,15 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView):
# Set the current model version as erroneous and return the available one
instance.state = ModelVersionState.Error
instance.save(update_fields=["state"])
# Set context
context = {
**self.get_serializer_context(),
"is_contributor": True
}
return Response(
ModelVersionSerializer(
existing_model_version,
context={"is_contributor": True, "model": instance},
context=context,
).data,
status=status.HTTP_409_CONFLICT,
)
......
......@@ -213,7 +213,7 @@ class ModelVersionSerializer(serializers.ModelSerializer):
else:
model = self.context.get("model")
if model:
qs = ModelVersion.objects.filter(model_id=model.id)
qs = ModelVersion.objects.filter(model__in=Model.objects.readable(self.context["request"].user))
if getattr(self.instance, "id", None):
qs = qs.exclude(id=self.instance.id)
self.fields["parent"].queryset = qs
......
......@@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock):
"""
Any readable model version can be set as parent of a model version, not just versions of the same model
"""
self.client.force_login(self.user1)
s3_presigned_url_mock.return_value = "http://s3/upload_put_url"
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(6):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertIn("id", data)
df = ModelVersion.objects.get(id=data["id"])
self.assertDictEqual(
data,
{
"id": str(df.id),
"model_id": str(self.model1.id),
"parent": str(self.model_version3.id),
"description": "description",
"state": ModelVersionState.Created.value,
"configuration": {"hello": "this is me"},
"tag": "TAG",
"size": None,
"hash": None,
"created": fake_now.isoformat().replace("+00:00", "Z"),
"s3_url": None,
"s3_put_url": s3_presigned_url_mock.return_value
}
)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_create_model_version_readable_parent_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id)
self.client.force_login(self.user1)
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(4):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.']
})
def test_retrieve_model_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)}))
......@@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase):
"size": 8,
})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_partial_update_any_parent_model_version(self, s3_presigned_url):
"""
A model version can have any model version as a parent, not just a version from the same model
"""
s3_presigned_url.return_value = "http://s3/get_url"
self.client.force_login(self.user2)
with self.assertNumQueries(6):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.model_version3.id),
"model_id": str(self.model2.id),
"created": self.model_version3.created.isoformat().replace("+00:00", "Z"),
"s3_url": "http://s3/get_url",
"s3_put_url": None,
"tag": "tagged",
"description": "",
"configuration": {},
"parent": str(self.model_version1.id),
"state": ModelVersionState.Available.value,
"hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba",
"size": 8,
})
self.model_version3.refresh_from_db()
self.assertIsNone(self.model_version3.parent_id)
self.assertEqual(self.model_version3.parent_id, self.model_version1.id)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_partial_update_readable_parent_model_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id])
self.client.force_login(self.user2)
with self.assertNumQueries(4):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.']
})
@patch("arkindex.training.api.get_max_level", return_value=None)
def test_update_model_version_requires_contributor(self, get_max_level_mock):
......