Compare revisions — arkindex/backend
Commits on Source (3), showing 445 additions and 62 deletions
@@ -2,7 +2,9 @@ from datetime import timedelta
 from textwrap import dedent

 from django.conf import settings
+from django.shortcuts import get_object_or_404
 from django.utils import timezone
+from django.utils.functional import cached_property
 from drf_spectacular.utils import extend_schema, extend_schema_view
 from rest_framework import permissions, serializers, status
 from rest_framework.exceptions import PermissionDenied, ValidationError
@@ -11,9 +13,7 @@ from rest_framework.response import Response

 from arkindex.documents.models import Corpus, CorpusExport, CorpusExportState
 from arkindex.documents.serializers.export import CorpusExportSerializer
-from arkindex.project.mixins import CorpusACLMixin
 from arkindex.project.permissions import IsVerified
-from arkindex.users.models import Role


 @extend_schema(tags=["exports"])
@@ -27,47 +27,42 @@ from arkindex.users.models import Role
     ),
     post=extend_schema(
         operation_id="StartExport",
+        request=None,
         description=dedent(
             f"""
            Start a corpus export job.
            A user must wait for {settings.EXPORT_TTL_SECONDS} seconds after the last successful import
-           before being able to generate a new export of the same corpus.
+           before being able to generate a new export of the same corpus from the same source.
            Contributor access is required.
            """
         ),
     )
 )
-class CorpusExportAPIView(CorpusACLMixin, ListCreateAPIView):
+class CorpusExportAPIView(ListCreateAPIView):
     permission_classes = (IsVerified, )
     serializer_class = CorpusExportSerializer
     queryset = CorpusExport.objects.none()

+    @cached_property
+    def corpus(self):
+        qs = Corpus.objects.readable(self.request.user)
+        corpus = get_object_or_404(qs, pk=self.kwargs["pk"])
+        if self.request.method not in permissions.SAFE_METHODS and not corpus.is_writable(self.request.user):
+            raise PermissionDenied(detail="You do not have write access to this corpus.")
+        return corpus
+
     def get_queryset(self):
         return CorpusExport \
             .objects \
-            .filter(corpus=self.get_corpus(self.kwargs["pk"])) \
+            .filter(corpus=self.corpus) \
             .select_related("user") \
             .order_by("-created")

-    def post(self, *args, **kwargs):
-        corpus = self.get_corpus(self.kwargs["pk"], role=Role.Contributor)
-
-        if corpus.exports.filter(state__in=(CorpusExportState.Created, CorpusExportState.Running)).exists():
-            raise ValidationError("An export is already running for this corpus.")
-
-        available_exports = corpus.exports.filter(
-            state=CorpusExportState.Done,
-            created__gte=timezone.now() - timedelta(seconds=settings.EXPORT_TTL_SECONDS)
-        )
-        if available_exports.exists():
-            raise ValidationError(f"An export has already been made for this corpus in the last {settings.EXPORT_TTL_SECONDS} seconds.")
-
-        export = corpus.exports.create(user=self.request.user)
-        export.start()
-        return Response(CorpusExportSerializer(export).data, status=status.HTTP_201_CREATED)
+    def get_serializer_context(self):
+        context = super().get_serializer_context()
+        context["corpus"] = self.corpus
+        return context


 @extend_schema(
...
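Taken together, these view changes move the permission and throttling logic out of post(): read access is resolved through Corpus.objects.readable, write access through the new Corpus.is_writable, and the TTL checks move into the serializer further down. A rough client-side sketch of the resulting status codes; the host, token and URL layout are placeholders, not part of this change (the real path is whatever api:corpus-export resolves to):

import requests  # hypothetical client, for illustration only

resp = requests.post(
    "https://arkindex.example.com/api/v1/corpus/<corpus-uuid>/export/",  # placeholder URL
    headers={"Authorization": "Token <api-token>"},
    json={"source": "default"},  # must be one of settings.EXPORT_SOURCES
)
# 201: export created and started
# 400: an export is already running, or one finished within EXPORT_TTL_SECONDS
# 403: the corpus is readable but not writable by this verified user
# 404: the corpus is not readable at all
print(resp.status_code, resp.json())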
@@ -46,12 +46,12 @@ EXPORT_QUERIES = [
 ]


-def run_pg_query(query):
+def run_pg_query(query, source_db):
     """
     Run a single Postgresql query and split the results into chunks.
     When a name is given to a cursor, psycopg2 uses a server-side cursor; we just use a random string as a name.
     """
-    with connections["default"].create_cursor(name=str(uuid.uuid4())) as pg_cursor:
+    with connections[source_db].create_cursor(name=str(uuid.uuid4())) as pg_cursor:
         pg_cursor.itersize = BATCH_SIZE
         pg_cursor.execute(query)
@@ -122,7 +122,11 @@ def export_corpus(corpus_export: CorpusExport) -> None:
     corpus_export.state = CorpusExportState.Running
     corpus_export.save()

-    logger.info(f"Exporting corpus {corpus_export.corpus_id} into {db_path}")
+    export_source = f"{corpus_export.corpus_id}"
+    if corpus_export.source != "default":
+        export_source += f" from source {corpus_export.source}"
+    logger.info(f"Exporting corpus {export_source} into {db_path}")

     db = sqlite3.connect(db_path)
     cursor = db.cursor()
@@ -135,7 +139,7 @@ def export_corpus(corpus_export: CorpusExport) -> None:
         if rq_job:
             rq_job.set_progress(i / (len(EXPORT_QUERIES) + 1))
-        for chunk in run_pg_query(query.format(corpus_id=corpus_export.corpus_id)):
+        for chunk in run_pg_query(query.format(corpus_id=corpus_export.corpus_id), corpus_export.source):
             save_sqlite(chunk, name, cursor)
     db.commit()
...
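The named-cursor trick documented in run_pg_query is easy to miss. Here is a standalone sketch of the same pattern on a raw psycopg2 connection (Arkindex goes through Django's connections[source_db].create_cursor() instead; the DSN and query below are placeholders):

import uuid

import psycopg2

BATCH_SIZE = 10000

def iter_chunks(dsn, query):
    with psycopg2.connect(dsn) as conn:
        # Passing a name makes psycopg2 declare a server-side cursor, so rows
        # are streamed from Postgres in batches instead of being fetched all
        # at once into client memory.
        with conn.cursor(name=str(uuid.uuid4())) as cursor:
            cursor.itersize = BATCH_SIZE
            cursor.execute(query)
            while True:
                chunk = cursor.fetchmany(BATCH_SIZE)
                if not chunk:
                    break
                yield chunk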
+# Generated by Django 4.1.7 on 2024-02-28 15:56
+from django.db import migrations, models
+
+from arkindex.project import settings
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("documents", "0008_alter_elementtype_color_alter_entitytype_color"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="corpusexport",
+            name="source",
+            field=models.CharField(choices=[(source, source) for source in settings.EXPORT_SOURCES], default="default", max_length=50),
+        ),
+    ]
@@ -73,6 +73,18 @@ class Corpus(IndexableModel):
         for values in DEFAULT_CORPUS_TYPES
     )

+    def is_writable(self, user) -> bool:
+        """
+        Whether a user has write access to this corpus
+        """
+        if user.is_anonymous or getattr(user, "is_agent", False):
+            return False
+        if user.is_admin:
+            return True
+        from arkindex.users.utils import get_max_level
+        level = get_max_level(user, self)
+        return level is not None and level >= Role.Contributor.value
+

 class ElementType(models.Model):
     id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
@@ -1185,6 +1197,7 @@ class CorpusExport(S3FileMixin, IndexableModel):
     corpus = models.ForeignKey(Corpus, related_name="exports", on_delete=models.CASCADE)
     user = models.ForeignKey(settings.AUTH_USER_MODEL, related_name="exports", on_delete=models.CASCADE)
     state = EnumField(CorpusExportState, max_length=10, default=CorpusExportState.Created)
+    source = models.CharField(max_length=50, default="default", choices=[(source, source) for source in settings.EXPORT_SOURCES])

     s3_bucket = settings.AWS_EXPORT_BUCKET
...
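A quick sketch of the access rules Corpus.is_writable encodes; every user name below is a hypothetical stand-in, not a fixture from this diff:

corpus.is_writable(anonymous_user)    # False: anonymous users are always rejected
corpus.is_writable(agent_user)        # False: Ponos agents are always rejected
corpus.is_writable(admin_user)        # True: admins bypass ACL levels entirely
corpus.is_writable(guest_user)        # False: get_max_level() < Role.Contributor.value
corpus.is_writable(contributor_user)  # True: level >= Role.Contributor.value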
+from datetime import timedelta
+
+from django.conf import settings
+from django.utils import timezone
 from rest_framework import serializers
+from rest_framework.exceptions import ValidationError

 from arkindex.documents.models import CorpusExport, CorpusExportState
 from arkindex.project.serializer_fields import EnumField
@@ -6,9 +11,38 @@ from arkindex.users.serializers import SimpleUserSerializer


 class CorpusExportSerializer(serializers.ModelSerializer):
-    user = SimpleUserSerializer()
-    state = EnumField(CorpusExportState)
+    user = SimpleUserSerializer(read_only=True)
+    state = EnumField(CorpusExportState, read_only=True)

     class Meta:
         model = CorpusExport
-        fields = ("id", "created", "updated", "corpus_id", "user", "state")
+        fields = ("id", "created", "updated", "corpus_id", "user", "state", "source",)
+
+    def validate(self, data):
+        corpus = self.context["corpus"]
+        source = data.get("source", "default")
+        # Check that there is no export already running for this corpus
+        if corpus.exports.filter(state__in=(CorpusExportState.Created, CorpusExportState.Running)).exists():
+            raise ValidationError("An export is already running for this corpus.")
+        # Check that there is no available completed export from the same source
+        # created less than {EXPORT_TTL_SECONDS} ago for this corpus
+        available_exports = corpus.exports.filter(
+            state=CorpusExportState.Done,
+            source=source,
+            created__gte=timezone.now() - timedelta(seconds=settings.EXPORT_TTL_SECONDS)
+        )
+        if available_exports.exists():
+            raise ValidationError(f"An export has already been made for this corpus in the last {settings.EXPORT_TTL_SECONDS} seconds.")
+        data["corpus"] = corpus
+        data["source"] = source
+        return data
+
+    def create(self, validated_data):
+        export = CorpusExport.objects.create(
+            user=self.context["request"].user,
+            corpus=validated_data["corpus"],
+            source=validated_data["source"]
+        )
+        export.start()
+        return export
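Since the checks now live in validate(), DRF reports their messages under non_field_errors instead of the bare list the view used to return, which is exactly what the updated tests below assert. A rough illustration, assuming corpus and request objects are in scope and a completed "default" export exists within the TTL:

serializer = CorpusExportSerializer(
    data={"source": "default"},
    context={"corpus": corpus, "request": request},  # both assumed available
)
serializer.is_valid()  # returns False
serializer.errors
# {'non_field_errors': ['An export has already been made for this corpus in the last 420 seconds.']}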
@@ -31,6 +31,7 @@ class TestExport(FixtureAPITestCase):
             },
             "corpus_id": str(self.corpus.id),
             "state": CorpusExportState.Created.value,
+            "source": "default"
         })
         self.assertEqual(delay_mock.call_count, 1)
@@ -61,19 +62,62 @@ class TestExport(FixtureAPITestCase):
         self.assertFalse(delay_mock.called)

     @patch("arkindex.project.triggers.export.export_corpus.delay")
-    @patch("arkindex.project.mixins.has_access", return_value=False)
-    def test_start_requires_contributor(self, has_access_mock, delay_mock):
+    @patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
+    def test_start_requires_contributor(self, max_level_mock, delay_mock):
         self.user.rights.update(level=Role.Guest.value)
         self.client.force_login(self.user)
         response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEqual(has_access_mock.call_count, 1)
-        self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Contributor.value, skip_public=False))
+        self.assertEqual(max_level_mock.call_count, 1)
+        self.assertEqual(max_level_mock.call_args, call(self.user, self.corpus))
         self.assertFalse(self.corpus.exports.exists())
         self.assertFalse(delay_mock.called)

+    @patch("arkindex.project.triggers.export.export_corpus.delay")
+    def test_start_bad_source(self, delay_mock):
+        self.client.force_login(self.superuser)
+        response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}), {"source": "jouvence"})
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"source": ['"jouvence" is not a valid choice.']})
+        self.assertEqual(self.corpus.exports.count(), 0)
+        self.assertFalse(delay_mock.called)
+
+    @patch("arkindex.documents.models.CorpusExport.source")
+    @patch("arkindex.project.triggers.export.export_corpus.delay")
+    @override_settings(EXPORT_TTL_SECONDS=420)
+    def test_start_with_source(self, delay_mock, source_field_mock):
+        source_field_mock.field.choices.return_value = [("default", "default"), ("jouvence", "jouvence")]
+        self.client.force_login(self.superuser)
+        response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}), {"source": "jouvence"})
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        export = self.corpus.exports.get()
+        self.assertDictEqual(response.json(), {
+            "id": str(export.id),
+            "created": export.created.isoformat().replace("+00:00", "Z"),
+            "updated": export.updated.isoformat().replace("+00:00", "Z"),
+            "user": {
+                "id": self.superuser.id,
+                "display_name": self.superuser.display_name,
+                "email": self.superuser.email,
+            },
+            "corpus_id": str(self.corpus.id),
+            "state": CorpusExportState.Created.value,
+            "source": "jouvence"
+        })
+        self.assertEqual(delay_mock.call_count, 1)
+        self.assertEqual(delay_mock.call_args, call(
+            corpus_export=export,
+            user_id=self.superuser.id,
+            description="Export of corpus Unit Tests from source jouvence"
+        ))
+
     @patch("arkindex.project.triggers.export.export_corpus.delay")
     def test_start_running(self, delay_mock):
         self.client.force_login(self.superuser)
@@ -81,7 +125,9 @@ class TestExport(FixtureAPITestCase):
         response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertListEqual(response.json(), ["An export is already running for this corpus."])
+        self.assertDictEqual(response.json(), {
+            "non_field_errors": ["An export is already running for this corpus."]
+        })
         self.assertEqual(self.corpus.exports.count(), 1)
         self.assertFalse(delay_mock.called)
@@ -99,11 +145,53 @@ class TestExport(FixtureAPITestCase):
         response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertListEqual(response.json(), ["An export has already been made for this corpus in the last 420 seconds."])
+        self.assertDictEqual(response.json(), {
+            "non_field_errors": ["An export has already been made for this corpus in the last 420 seconds."]
+        })
         self.assertEqual(self.corpus.exports.count(), 1)
         self.assertFalse(delay_mock.called)

+    @override_settings(EXPORT_TTL_SECONDS=420)
+    @patch("arkindex.project.triggers.export.export_corpus.delay")
+    def test_start_recent_export_different_source(self, delay_mock):
+        from arkindex.documents.models import CorpusExport
+        CorpusExport.source.field.choices = [("default", "default"), ("jouvence", "jouvence")]
+        self.client.force_login(self.superuser)
+        with patch("django.utils.timezone.now") as mock_now:
+            mock_now.return_value = datetime.now(timezone.utc) - timedelta(minutes=2)
+            self.corpus.exports.create(
+                user=self.user,
+                state=CorpusExportState.Done,
+                source="jouvence"
+            )
+        response = self.client.post(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        export = self.corpus.exports.get(source="default")
+        self.assertDictEqual(response.json(), {
+            "id": str(export.id),
+            "created": export.created.isoformat().replace("+00:00", "Z"),
+            "updated": export.updated.isoformat().replace("+00:00", "Z"),
+            "user": {
+                "id": self.superuser.id,
+                "display_name": self.superuser.display_name,
+                "email": self.superuser.email,
+            },
+            "corpus_id": str(self.corpus.id),
+            "state": CorpusExportState.Created.value,
+            "source": "default"
+        })
+        self.assertEqual(delay_mock.call_count, 1)
+        self.assertEqual(delay_mock.call_args, call(
+            corpus_export=export,
+            user_id=self.superuser.id,
+            description="Export of corpus Unit Tests"
+        ))
+
     def test_list(self):
         export1 = self.corpus.exports.create(user=self.user, state=CorpusExportState.Done)
         export2 = self.corpus.exports.create(user=self.superuser)
@@ -123,6 +211,7 @@ class TestExport(FixtureAPITestCase):
                 "email": self.superuser.email,
             },
             "corpus_id": str(self.corpus.id),
+            "source": "default"
         },
         {
             "id": str(export1.id),
@@ -135,6 +224,7 @@ class TestExport(FixtureAPITestCase):
                 "email": self.user.email,
             },
             "corpus_id": str(self.corpus.id),
+            "source": "default"
         },
     ])
@@ -149,18 +239,19 @@ class TestExport(FixtureAPITestCase):
         response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

-    @patch("arkindex.project.mixins.has_access", return_value=False)
-    def test_list_requires_guest(self, has_access_mock):
+    @patch("arkindex.users.managers.BaseACLManager.filter_rights")
+    def test_list_requires_guest(self, filter_rights_mock):
         self.user.rights.all().delete()
         self.corpus.public = False
         self.corpus.save()
+        filter_rights_mock.return_value = Corpus.objects.none()
         self.client.force_login(self.user)
         response = self.client.get(reverse("api:corpus-export", kwargs={"pk": self.corpus.id}))
-        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEqual(has_access_mock.call_count, 1)
-        self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Guest.value, skip_public=False))
+        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        self.assertEqual(filter_rights_mock.call_count, 1)
+        self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))

     @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
     def test_download_export(self, presigned_url_mock):
...
 import shlex
 from collections import defaultdict
+from datetime import timedelta
 from functools import wraps
 from os import path
 from typing import Dict, List, Sequence, Tuple
@@ -7,10 +8,12 @@ from uuid import UUID

 from django.conf import settings
 from django.db.models import Prefetch, prefetch_related_objects
+from django.utils import timezone
+from django.utils.functional import cached_property
 from rest_framework.exceptions import ValidationError

 from arkindex.images.models import ImageServer
-from arkindex.ponos.models import Task, task_token_default
+from arkindex.ponos.models import GPU, Task, task_token_default


 class ProcessBuilder(object):
@@ -162,6 +165,10 @@ class ProcessBuilder(object):
         env["ARKINDEX_CORPUS_ID"] = str(self.process.corpus_id)
         return env

+    @cached_property
+    def active_gpu_agents(self) -> bool:
+        return GPU.objects.filter(agent__last_ping__gt=timezone.now() - timedelta(seconds=30)).exists()
+
     @prefetch_worker_runs
     def validate_gpu_requirement(self):
         from arkindex.process.models import FeatureUsage
@@ -298,6 +305,7 @@ class ProcessBuilder(object):
                 chunk=index if len(chunks) > 1 else None,
                 workflow_runs=worker_runs,
                 run=self.run,
+                active_gpu_agents=self.active_gpu_agents,
             )
             self.tasks.append(task)
             self.tasks_parents[task.slug].extend(parent_slugs)
...
@@ -911,7 +911,7 @@ class WorkerRun(models.Model):
         # we add the WorkerRun ID at the end of the slug
         return f"{self.version.worker.slug}_{str(self.id)[:6]}"

-    def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None):
+    def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None, active_gpu_agents=False):
         """
         Build the Task that will represent this WorkerRun in ponos using :
         - the docker image name given by the WorkerVersion
@@ -967,6 +967,12 @@ class WorkerRun(models.Model):
             assert self.model_version.state == ModelVersionState.Available, f"ModelVersion {self.model_version.id} is not available and cannot be used to build a task."
             extra_files = {"model": settings.PUBLIC_HOSTNAME + reverse("api:model-version-download", kwargs={"pk": self.model_version.id}) + f"?token={self.model_version.build_authentication_token_hash()}"}

+        requires_gpu = process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
+        # Do not require a GPU if there are no active agents with a GPU and the GPU feature is only supported,
+        # not required, by the worker version; this does not make sense in the context of RQ task execution
+        if not settings.PONOS_RQ_EXECUTION and not active_gpu_agents and self.version.gpu_usage != FeatureUsage.Required:
+            requires_gpu = False
+
         task = Task(
             command=self.version.docker_command,
             image=self.version.docker_image_iid or self.version.docker_image_name,
@@ -981,7 +987,7 @@ class WorkerRun(models.Model):
             process=process,
             worker_run=self,
             extra_files=extra_files,
-            requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
+            requires_gpu=requires_gpu
         )
         return task, parents
...
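The combined effect of active_gpu_agents and PONOS_RQ_EXECUTION is easiest to read as a standalone function; this is just a restatement of the inline logic above with local stand-in names:

def compute_requires_gpu(use_gpu, gpu_usage, active_gpu_agents, rq_execution):
    # Baseline: the process asked for GPUs and the worker version can use one
    requires_gpu = use_gpu and gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
    # With agent-based execution and no GPU agent pinging in the last 30 seconds,
    # a merely Supported usage is dropped; a Required usage is kept as-is
    if not rq_execution and not active_gpu_agents and gpu_usage != FeatureUsage.Required:
        requires_gpu = False
    return requires_gpu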
+import uuid
 from collections import namedtuple
 from datetime import datetime, timezone
 from unittest.mock import call, patch
@@ -7,9 +8,10 @@ from rest_framework import status
 from rest_framework.reverse import reverse

 from arkindex.documents.models import Corpus, Element
-from arkindex.ponos.models import Farm, State
+from arkindex.ponos.models import GPU, Agent, Farm, State
 from arkindex.process.models import (
     ActivityState,
+    FeatureUsage,
     Process,
     ProcessDataset,
     ProcessMode,
@@ -31,6 +33,21 @@ class TestCreateProcess(FixtureAPITestCase):
     @classmethod
     def setUpTestData(cls):
         super().setUpTestData()
+        cls.agent = Agent.objects.create(
+            farm=Farm.objects.first(),
+            hostname="claude",
+            cpu_cores=42,
+            cpu_frequency=1e15,
+            ram_total=99e9,
+            last_ping=datetime.now(timezone.utc),
+        )
+        cls.agent.gpus.create(
+            id=uuid.uuid4(),
+            name="claudette",
+            index=2,
+            ram_total=12
+        )
         cls.volume = Element.objects.get(name="Volume 1")
         cls.pages = Element.objects.get_descending(cls.volume.id).filter(type__slug="page", polygon__isnull=False)
         cls.ml_class = cls.corpus.ml_classes.create(name="bretzel")
@@ -585,7 +602,7 @@ class TestCreateProcess(FixtureAPITestCase):
         self.assertFalse(self.corpus.worker_versions.exists())
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
                 {"worker_activity": True},
@@ -676,7 +693,7 @@ class TestCreateProcess(FixtureAPITestCase):
         )
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
                 {"use_cache": True},
@@ -714,7 +731,7 @@ class TestCreateProcess(FixtureAPITestCase):
     @patch("arkindex.ponos.models.base64.encodebytes")
     def test_create_process_use_gpu_option(self, token_mock):
         """
-        A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks than need one
+        A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks that need one
         """
         token_mock.side_effect = [b"12345", b"67891"]
         process_2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
@@ -724,7 +741,7 @@ class TestCreateProcess(FixtureAPITestCase):
         )
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process_2.id)}),
                 {"use_gpu": True},
@@ -755,6 +772,60 @@ class TestCreateProcess(FixtureAPITestCase):
         self.assertEqual(len(worker_task.parents.all()), 1)
         self.assertEqual(worker_task.parents.first(), init_task)

+    @override_settings(
+        ARKINDEX_TASKS_IMAGE="registry.teklia.com/tasks",
+        PONOS_DEFAULT_ENV={}
+    )
+    @patch("arkindex.ponos.models.base64.encodebytes")
+    def test_create_process_use_gpu_option_no_available_gpus(self, token_mock):
+        """
+        If there are no available agents with a GPU, requires_gpu is not set when the worker
+        version only supports a GPU instead of requiring one
+        """
+        self.agent.gpus.all().delete()
+        token_mock.side_effect = [b"12345", b"67891", b"54321", b"19876"]
+        for feature_usage, requires_gpu, task_token in [
+            (FeatureUsage.Supported, False, "67891"),
+            (FeatureUsage.Required, True, "19876")
+        ]:
+            with self.subTest(feature_usage=feature_usage, requires_gpu=requires_gpu):
+                process = self.corpus.processes.create(
+                    creator=self.user,
+                    mode=ProcessMode.Workers,
+                    farm=Farm.objects.first(),
+                )
+                run = process.worker_runs.create(
+                    version=self.version_3,
+                    parents=[],
+                )
+                process.use_gpu = True
+                self.assertEqual(GPU.objects.count(), 0)
+                self.version_3.gpu_usage = feature_usage
+                self.version_3.save()
+
+                process.run()
+
+                init_task = process.tasks.get(slug="initialisation")
+                self.assertEqual(init_task.command, f"python -m arkindex_tasks.init_elements {process.id} --chunks-number 1")
+                self.assertEqual(init_task.image, "registry.teklia.com/tasks")
+                worker_task = process.tasks.get(slug=run.task_slug)
+                self.assertEqual(worker_task.command, None)
+                self.assertEqual(worker_task.image, f"my_repo.fake/workers/worker/worker-gpu:{self.version_3.id}")
+                self.assertEqual(worker_task.image_artifact.id, self.version_3.docker_image.id)
+                self.assertEqual(worker_task.shm_size, None)
+                self.assertEqual(worker_task.env, {
+                    "TASK_ELEMENTS": "/data/initialisation/elements.json",
+                    "ARKINDEX_CORPUS_ID": str(self.corpus.id),
+                    "ARKINDEX_PROCESS_ID": str(process.id),
+                    "ARKINDEX_WORKER_RUN_ID": str(process.worker_runs.get().id),
+                    "ARKINDEX_TASK_TOKEN": task_token
+                })
+                self.assertEqual(worker_task.requires_gpu, requires_gpu)
+                self.assertEqual(len(worker_task.parents.all()), 1)
+                self.assertEqual(worker_task.parents.first(), init_task)
+
     def test_retry_keeps_requires_gpu(self):
         """
         When a process is retried, the newly created tasks keep the same requires_gpu values
@@ -819,7 +890,7 @@ class TestCreateProcess(FixtureAPITestCase):
         process.use_gpu = True
         process.save()
         self.client.force_login(self.user)
-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process.id)}),
                 {"use_gpu": "true"}
@@ -907,7 +978,7 @@ class TestCreateProcess(FixtureAPITestCase):
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
         process.versions.add(custom_version)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)}))

         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...
@@ -1749,7 +1749,7 @@ class TestProcesses(FixtureAPITestCase):
         self.workers_process.activity_state = ActivityState.Error
         self.workers_process.save()

-        with self.assertNumQueries(13):
+        with self.assertNumQueries(14):
             response = self.client.post(reverse("api:process-retry", kwargs={"pk": self.workers_process.id}))

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -2126,7 +2126,7 @@ class TestProcesses(FixtureAPITestCase):
         with (
             self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)),
-            self.assertNumQueries(8)
+            self.assertNumQueries(9)
         ):
             response = self.client.post(reverse("api:files-process"), {
                 "files": [str(self.img_df.id)],
@@ -2217,7 +2217,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertFalse(process2.tasks.exists())
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process2.id)})
             )
@@ -2362,7 +2362,7 @@ class TestProcesses(FixtureAPITestCase):
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process2.id)})
             )
@@ -2477,7 +2477,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertFalse(process.tasks.exists())
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process.id)})
             )
@@ -2499,7 +2499,7 @@ class TestProcesses(FixtureAPITestCase):
         farm = Farm.objects.get(name="Wheat farm")
         self.client.force_login(self.user)

-        with self.assertNumQueries(15):
+        with self.assertNumQueries(16):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(workers_process.id)}),
                 {"farm": str(farm.id)}
@@ -2697,7 +2697,7 @@ class TestProcesses(FixtureAPITestCase):
         self.client.force_login(self.user)

-        with self.assertNumQueries(15):
+        with self.assertNumQueries(16):
             response = self.client.post(
                 reverse("api:process-start", kwargs={"pk": str(process.id)}),
                 {"use_cache": "true", "worker_activity": "true", "use_gpu": "true"}
@@ -2733,7 +2733,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertNotEqual(run_1.task_slug, run_2.task_slug)
         self.client.force_login(self.user)

-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)}))

         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...
@@ -362,6 +362,8 @@ REDIS_ZREM_CHUNK_SIZE = 10000

 # How long before a corpus export can be run again after a successful one
 EXPORT_TTL_SECONDS = conf["export"]["ttl"]
+# Available database sources for corpus exports
+EXPORT_SOURCES = ["default"]

 LOGGING = {
     "version": 1,
...
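For a deployment that wants exports to read from somewhere other than the primary database, one hypothetical override is sketched below. The alias has to exist in DATABASES, since run_pg_query opens its server-side cursor on connections[source_db]; the replica name and host are made up for illustration:

# Hypothetical deployment-specific settings override
DATABASES["replica"] = {
    **DATABASES["default"],
    "HOST": "replica.postgres.internal",  # placeholder host
}
EXPORT_SOURCES = ["default", "replica"]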
@@ -189,10 +189,13 @@ def export_corpus(corpus_export: CorpusExport) -> None:
     """
     Export a corpus to a SQLite database
     """
+    description = f"Export of corpus {corpus_export.corpus.name}"
+    if corpus_export.source != "default":
+        description += f" from source {corpus_export.source}"
     export.export_corpus.delay(
         corpus_export=corpus_export,
         user_id=corpus_export.user_id,
-        description=f"Export of corpus {corpus_export.corpus.name}"
+        description=description
     )
...
@@ -307,10 +307,15 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView):
             # Set the current model version as erroneous and return the available one
             instance.state = ModelVersionState.Error
             instance.save(update_fields=["state"])
+            # Set context
+            context = {
+                **self.get_serializer_context(),
+                "is_contributor": True
+            }
             return Response(
                 ModelVersionSerializer(
                     existing_model_version,
-                    context={"is_contributor": True, "model": instance},
+                    context=context,
                 ).data,
                 status=status.HTTP_409_CONFLICT,
             )
...
@@ -213,7 +213,7 @@ class ModelVersionSerializer(serializers.ModelSerializer):
         else:
             model = self.context.get("model")
             if model:
-                qs = ModelVersion.objects.filter(model_id=model.id)
+                qs = ModelVersion.objects.filter(model__in=Model.objects.readable(self.context["request"].user))
                 if getattr(self.instance, "id", None):
                     qs = qs.exclude(id=self.instance.id)
                 self.fields["parent"].queryset = qs
...
@@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase):
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})

+    @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
+    def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock):
+        """
+        Any readable model version can be set as parent of a model version, not just versions of the same model
+        """
+        self.client.force_login(self.user1)
+        s3_presigned_url_mock.return_value = "http://s3/upload_put_url"
+        fake_now = timezone.now()
+        # To mock the creation date
+        with patch("django.utils.timezone.now") as mock_now:
+            mock_now.return_value = fake_now
+            with self.assertNumQueries(6):
+                response = self.client.post(
+                    reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
+                    {
+                        "tag": "TAG",
+                        "description": "description",
+                        "configuration": {"hello": "this is me"},
+                        # self.model_version3 belongs to self.model2
+                        "parent": str(self.model_version3.id)
+                    },
+                    format="json",
+                )
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        data = response.json()
+        self.assertIn("id", data)
+        df = ModelVersion.objects.get(id=data["id"])
+        self.assertDictEqual(
+            data,
+            {
+                "id": str(df.id),
+                "model_id": str(self.model1.id),
+                "parent": str(self.model_version3.id),
+                "description": "description",
+                "state": ModelVersionState.Created.value,
+                "configuration": {"hello": "this is me"},
+                "tag": "TAG",
+                "size": None,
+                "hash": None,
+                "created": fake_now.isoformat().replace("+00:00", "Z"),
+                "s3_url": None,
+                "s3_put_url": s3_presigned_url_mock.return_value
+            }
+        )
+
+    @patch("arkindex.users.managers.BaseACLManager.filter_rights")
+    def test_create_model_version_readable_parent_version(self, filter_rights_mock):
+        """
+        Only model versions the user has read access to can be set as parents of a model version
+        """
+        filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id)
+        self.client.force_login(self.user1)
+        fake_now = timezone.now()
+        # To mock the creation date
+        with patch("django.utils.timezone.now") as mock_now:
+            mock_now.return_value = fake_now
+            with self.assertNumQueries(4):
+                response = self.client.post(
+                    reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
+                    {
+                        "tag": "TAG",
+                        "description": "description",
+                        "configuration": {"hello": "this is me"},
+                        # self.model_version3 belongs to self.model2
+                        "parent": str(self.model_version3.id)
+                    },
+                    format="json",
+                )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(filter_rights_mock.call_count, 1)
+        self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value))
+        self.assertDictEqual(response.json(), {
+            "parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.']
+        })
+
     def test_retrieve_model_requires_login(self):
         with self.assertNumQueries(0):
             response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)}))
@@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase):
             "size": 8,
         })

+    @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
+    def test_partial_update_any_parent_model_version(self, s3_presigned_url):
+        """
+        A model version can have any model version as a parent, not just a version from the same model
+        """
+        s3_presigned_url.return_value = "http://s3/get_url"
+        self.client.force_login(self.user2)
+        with self.assertNumQueries(6):
+            response = self.client.patch(
+                reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
+                # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
+                {"parent": str(self.model_version1.id)},
+                format="json"
+            )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "id": str(self.model_version3.id),
+            "model_id": str(self.model2.id),
+            "created": self.model_version3.created.isoformat().replace("+00:00", "Z"),
+            "s3_url": "http://s3/get_url",
+            "s3_put_url": None,
+            "tag": "tagged",
+            "description": "",
+            "configuration": {},
+            "parent": str(self.model_version1.id),
+            "state": ModelVersionState.Available.value,
+            "hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba",
+            "size": 8,
+        })
         self.model_version3.refresh_from_db()
-        self.assertIsNone(self.model_version3.parent_id)
+        self.assertEqual(self.model_version3.parent_id, self.model_version1.id)
+
+    @patch("arkindex.users.managers.BaseACLManager.filter_rights")
+    def test_partial_update_readable_parent_model_version(self, filter_rights_mock):
+        """
+        Only model versions the user has read access to can be set as parents of a model version
+        """
+        filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id])
+        self.client.force_login(self.user2)
+        with self.assertNumQueries(4):
+            response = self.client.patch(
+                reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
+                # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
+                {"parent": str(self.model_version1.id)},
+                format="json"
+            )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(filter_rights_mock.call_count, 1)
+        self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value))
+        self.assertDictEqual(response.json(), {
+            "parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.']
+        })
+
     @patch("arkindex.training.api.get_max_level", return_value=None)
     def test_update_model_version_requires_contributor(self, get_max_level_mock):
...