Skip to content
Snippets Groups Projects
Commit 704b0b7b authored by ml bonhomme's avatar ml bonhomme :bee:
Browse files

process dataset sets

parent c367fd09
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !2264. Comments created here will be created in the context of that merge request.
...@@ -23,7 +23,7 @@ from arkindex.documents.models import ( ...@@ -23,7 +23,7 @@ from arkindex.documents.models import (
TranscriptionEntity, TranscriptionEntity,
) )
from arkindex.ponos.models import Task from arkindex.ponos.models import Task
from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun from arkindex.process.models import Process, ProcessDatasetSet, ProcessElement, WorkerActivity, WorkerRun
from arkindex.training.models import DatasetElement, DatasetSet from arkindex.training.models import DatasetElement, DatasetSet
from arkindex.users.models import User from arkindex.users.models import User
...@@ -70,9 +70,9 @@ def corpus_delete(corpus_id: str) -> None: ...@@ -70,9 +70,9 @@ def corpus_delete(corpus_id: str) -> None:
Selection.objects.filter(element__corpus_id=corpus_id), Selection.objects.filter(element__corpus_id=corpus_id),
corpus.memberships.all(), corpus.memberships.all(),
corpus.exports.all(), corpus.exports.all(),
# ProcessDataset M2M # ProcessDatasetSet M2M
ProcessDataset.objects.filter(dataset__corpus_id=corpus_id), ProcessDatasetSet.objects.filter(set__dataset__corpus_id=corpus_id),
ProcessDataset.objects.filter(process__corpus_id=corpus_id), ProcessDatasetSet.objects.filter(process__corpus_id=corpus_id),
DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id), DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
DatasetSet.objects.filter(dataset__corpus_id=corpus_id), DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
corpus.datasets.all(), corpus.datasets.all(),
......
...@@ -46,7 +46,6 @@ from rest_framework.generics import ( ...@@ -46,7 +46,6 @@ from rest_framework.generics import (
RetrieveDestroyAPIView, RetrieveDestroyAPIView,
RetrieveUpdateAPIView, RetrieveUpdateAPIView,
RetrieveUpdateDestroyAPIView, RetrieveUpdateDestroyAPIView,
UpdateAPIView,
) )
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
...@@ -61,7 +60,7 @@ from arkindex.process.models import ( ...@@ -61,7 +60,7 @@ from arkindex.process.models import (
GitRef, GitRef,
GitRefType, GitRefType,
Process, Process,
ProcessDataset, ProcessDatasetSet,
ProcessMode, ProcessMode,
Revision, Revision,
Worker, Worker,
...@@ -87,7 +86,7 @@ from arkindex.process.serializers.imports import ( ...@@ -87,7 +86,7 @@ from arkindex.process.serializers.imports import (
StartProcessSerializer, StartProcessSerializer,
) )
from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer
from arkindex.process.serializers.training import ProcessDatasetSerializer from arkindex.process.serializers.training import ProcessDatasetSetSerializer
from arkindex.process.serializers.worker_runs import ( from arkindex.process.serializers.worker_runs import (
CorpusWorkerRunSerializer, CorpusWorkerRunSerializer,
UserWorkerRunSerializer, UserWorkerRunSerializer,
...@@ -565,7 +564,7 @@ class StartProcess(CorpusACLMixin, CreateAPIView): ...@@ -565,7 +564,7 @@ class StartProcess(CorpusACLMixin, CreateAPIView):
"model_version__model", "model_version__model",
"configuration", "configuration",
))) )))
.prefetch_related("datasets") .prefetch_related("sets")
# Uses Exists() for has_tasks and not a __isnull because we are not joining on tasks and do not need to fetch them # Uses Exists() for has_tasks and not a __isnull because we are not joining on tasks and do not need to fetch them
.annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk")))) .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk"))))
) )
...@@ -677,20 +676,20 @@ class DataFileCreate(CreateAPIView): ...@@ -677,20 +676,20 @@ class DataFileCreate(CreateAPIView):
@extend_schema(tags=["process"]) @extend_schema(tags=["process"])
@extend_schema_view( @extend_schema_view(
get=extend_schema( get=extend_schema(
operation_id="ListProcessDatasets", operation_id="ListProcessSets",
description=dedent( description=dedent(
""" """
List all datasets on a process. List all dataset sets on a process.
Requires a **guest** access to the process. Requires a **guest** access to the process.
""" """
), ),
), ),
) )
class ProcessDatasets(ProcessACLMixin, ListAPIView): class ProcessDatasetSets(ProcessACLMixin, ListAPIView):
permission_classes = (IsVerified, ) permission_classes = (IsVerified, )
serializer_class = ProcessDatasetSerializer serializer_class = ProcessDatasetSetSerializer
queryset = ProcessDataset.objects.none() queryset = ProcessDatasetSet.objects.none()
@cached_property @cached_property
def process(self): def process(self):
...@@ -704,10 +703,10 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): ...@@ -704,10 +703,10 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
def get_queryset(self): def get_queryset(self):
return ( return (
ProcessDataset.objects.filter(process_id=self.process.id) ProcessDatasetSet.objects.filter(process_id=self.process.id)
.select_related("process__creator", "dataset__creator") .select_related("process__creator", "set__dataset__creator")
.prefetch_related("dataset__sets") .prefetch_related("set__dataset__sets")
.order_by("dataset__name") .order_by("set__dataset__name", "set__name")
) )
def get_serializer_context(self): def get_serializer_context(self):
...@@ -724,38 +723,38 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): ...@@ -724,38 +723,38 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
@extend_schema(tags=["process"]) @extend_schema(tags=["process"])
@extend_schema_view( @extend_schema_view(
post=extend_schema( post=extend_schema(
operation_id="CreateProcessDataset", operation_id="CreateProcessSet",
description=dedent( description=dedent(
""" """
Add a dataset to a process. Add a dataset set to a process.
Requires an **admin** access to the process and a **guest** access to the dataset's corpus. Requires an **admin** access to the process and a **guest** access to the dataset's corpus.
""" """
), ),
), ),
delete=extend_schema( delete=extend_schema(
operation_id="DestroyProcessDataset", operation_id="DestroyProcessSet",
description=dedent( description=dedent(
""" """
Remove a dataset from a process. Remove a dataset set from a process.
Requires an **admin** access to the process. Requires an **admin** access to the process.
""" """
), ),
), ),
) )
class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): class ProcessDatasetSetManage(CreateAPIView, DestroyAPIView):
permission_classes = (IsVerified, ) permission_classes = (IsVerified, )
serializer_class = ProcessDatasetSerializer serializer_class = ProcessDatasetSetSerializer
def get_object(self): def get_object(self):
process_dataset = get_object_or_404( process_dataset = get_object_or_404(
ProcessDataset.objects ProcessDatasetSet.objects
.select_related("dataset__creator", "process__corpus") .select_related("set__dataset__creator", "process__corpus")
.prefetch_related("dataset__sets") .prefetch_related("set__dataset__sets")
# Required to check for a process that have already started # Required to check for a process that have already started
.annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))), .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"] set_id=self.kwargs["set"], process_id=self.kwargs["process"]
) )
# Copy the has_tasks annotation onto the process # Copy the has_tasks annotation onto the process
process_dataset.process.has_tasks = process_dataset.process_has_tasks process_dataset.process.has_tasks = process_dataset.process_has_tasks
...@@ -770,11 +769,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): ...@@ -770,11 +769,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
# Ignore the sets when retrieving the ProcessDataset instance, as there cannot be get_object_or_404(ProcessDatasetSet, **serializer.validated_data).delete()
# two ProcessDatasets with the same dataset and process, whatever the sets
validated_data = serializer.validated_data
del validated_data["sets"]
get_object_or_404(ProcessDataset, **validated_data).delete()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
......
# Generated by Django 4.1.7 on 2024-03-21 12:00
import uuid
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    """Replace the ProcessDataset through-model (process <-> dataset plus a
    ``sets`` array of set names) with ProcessDatasetSet, linking a process
    directly to individual training.DatasetSet rows.
    """

    dependencies = [
        ("training", "0007_datasetset_model"),
        ("process", "0031_process_corpus_check_and_remove_revision_field"),
    ]

    operations = [
        migrations.CreateModel(
            name="ProcessDatasetSet",
            fields=[
                ("id", models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
                ("process", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name="process_sets", to="process.process")),
                ("set", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name="process_sets", to="training.datasetset")),
            ],
        ),
        migrations.AddConstraint(
            model_name="processdatasetset",
            constraint=models.UniqueConstraint(fields=("process", "set"), name="unique_process_set"),
        ),
        # Data migration: expand every ProcessDataset row's ``sets`` name array
        # into one ProcessDatasetSet row per matching DatasetSet.
        #
        # The subquery must keep ``dataset_id``: without it, the join condition
        # ``dataset_id = dss.dataset_id`` resolves both sides to the same
        # ``training_datasetset`` column (always true), so set names would be
        # matched across ALL datasets, linking unrelated sets to the process.
        # The unnested column is aliased ``set_name`` rather than ``set``
        # because SET is a reserved SQL keyword.
        migrations.RunSQL(
            """
            INSERT INTO process_processdatasetset (id, process_id, set_id)
            SELECT uuid_generate_v4(), p.process_id, dss.id
            FROM (
                SELECT process_id, dataset_id, unnest(sets) AS set_name
                FROM process_processdataset
            ) p
            INNER JOIN training_datasetset AS dss
                ON (p.dataset_id = dss.dataset_id AND p.set_name = dss.name)
            """,
            # One-way migration: the old ``sets`` array cannot be rebuilt losslessly.
            reverse_sql=migrations.RunSQL.noop,
        ),
        migrations.RemoveField(
            model_name="process",
            name="datasets",
        ),
        migrations.AddField(
            model_name="process",
            name="sets",
            field=models.ManyToManyField(related_name="processes", through="process.ProcessDatasetSet", to="training.datasetset"),
        ),
        migrations.DeleteModel(
            name="ProcessDataset",
        ),
    ]
...@@ -75,7 +75,7 @@ class Process(IndexableModel): ...@@ -75,7 +75,7 @@ class Process(IndexableModel):
corpus = models.ForeignKey("documents.Corpus", on_delete=models.CASCADE, related_name="processes", null=True) corpus = models.ForeignKey("documents.Corpus", on_delete=models.CASCADE, related_name="processes", null=True)
mode = EnumField(ProcessMode, max_length=30) mode = EnumField(ProcessMode, max_length=30)
files = models.ManyToManyField("process.DataFile", related_name="processes") files = models.ManyToManyField("process.DataFile", related_name="processes")
datasets = models.ManyToManyField("training.Dataset", related_name="processes", through="process.ProcessDataset") sets = models.ManyToManyField("training.DatasetSet", related_name="processes", through="process.ProcessDatasetSet")
versions = models.ManyToManyField("process.WorkerVersion", through="process.WorkerRun", related_name="processes") versions = models.ManyToManyField("process.WorkerVersion", through="process.WorkerRun", related_name="processes")
activity_state = EnumField(ActivityState, max_length=32, default=ActivityState.Disabled) activity_state = EnumField(ActivityState, max_length=32, default=ActivityState.Disabled)
...@@ -468,26 +468,19 @@ class ProcessElement(models.Model): ...@@ -468,26 +468,19 @@ class ProcessElement(models.Model):
) )
class ProcessDataset(models.Model): class ProcessDatasetSet(models.Model):
""" """
Link between Processes and Datasets. Link between Processes and Dataset Sets.
""" """
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_datasets") process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_sets")
dataset = models.ForeignKey("training.Dataset", on_delete=models.CASCADE, related_name="process_datasets") set = models.ForeignKey("training.DatasetSet", on_delete=models.CASCADE, related_name="process_sets")
sets = ArrayField(
models.CharField(max_length=50, validators=[MinLengthValidator(1)]),
validators=[
MinLengthValidator(1),
validate_unique_set_names,
]
)
class Meta: class Meta:
constraints = [ constraints = [
models.UniqueConstraint( models.UniqueConstraint(
fields=["process", "dataset"], fields=["process", "set"],
name="unique_process_dataset", name="unique_process_set",
) )
] ]
......
...@@ -5,47 +5,41 @@ from rest_framework import serializers ...@@ -5,47 +5,41 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.exceptions import PermissionDenied, ValidationError
from arkindex.documents.models import Corpus from arkindex.documents.models import Corpus
from arkindex.process.models import Process, ProcessDataset, ProcessMode, Task from arkindex.process.models import Process, ProcessDatasetSet, ProcessMode, Task
from arkindex.project.mixins import ProcessACLMixin from arkindex.project.mixins import ProcessACLMixin
from arkindex.training.models import Dataset from arkindex.training.models import DatasetSet
from arkindex.training.serializers import DatasetSerializer from arkindex.training.serializers import DatasetSerializer
from arkindex.users.models import Role from arkindex.users.models import Role
def _dataset_id_from_context(serializer_field): def _set_id_from_context(serializer_field):
return serializer_field.context.get("view").kwargs["dataset"] return serializer_field.context.get("view").kwargs["set"]
def _process_id_from_context(serializer_field): def _process_id_from_context(serializer_field):
return serializer_field.context.get("view").kwargs["process"] return serializer_field.context.get("view").kwargs["process"]
_dataset_id_from_context.requires_context = True _set_id_from_context.requires_context = True
_process_id_from_context.requires_context = True _process_id_from_context.requires_context = True
class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): class ProcessDatasetSetSerializer(ProcessACLMixin, serializers.ModelSerializer):
process_id = serializers.HiddenField( process_id = serializers.HiddenField(
write_only=True, write_only=True,
default=_process_id_from_context default=_process_id_from_context
) )
dataset_id = serializers.HiddenField( set_id = serializers.HiddenField(
write_only=True, write_only=True,
default=_dataset_id_from_context default=_set_id_from_context
)
dataset = DatasetSerializer(read_only=True)
sets = serializers.ListField(
child=serializers.CharField(max_length=50),
required=False,
allow_null=False,
allow_empty=False,
min_length=1
) )
dataset = DatasetSerializer(read_only=True, source="set.dataset")
set_name = serializers.CharField(read_only=True, source="set.name")
class Meta: class Meta:
model = ProcessDataset model = ProcessDatasetSet
fields = ("dataset_id", "dataset", "process_id", "id", "sets", ) fields = ("process_id", "set_id", "id", "dataset", "set_name", )
read_only_fields = ("process_id", "id", ) read_only_fields = ("process_id", "id", "dataset", "set_name", )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -81,44 +75,34 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): ...@@ -81,44 +75,34 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
if not access or not (access >= Role.Admin.value): if not access or not (access >= Role.Admin.value):
raise PermissionDenied(detail="You do not have admin access to this process.") raise PermissionDenied(detail="You do not have admin access to this process.")
# Validate dataset # Validate dataset set
if not self.instance: if not self.instance:
if request_method == "DELETE": if request_method == "DELETE":
# Allow deleting ProcessDatasets even if the user looses access to the corpus # Allow deleting ProcessDatasetSets even if the user looses access to the corpus
dataset_qs = Dataset.objects.all() set_qs = DatasetSet.objects.all()
else: else:
dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user)) set_qs = (
DatasetSet.objects.filter(dataset__corpus__in=Corpus.objects.readable(self._user))
.select_related("dataset__creator").prefetch_related("dataset__sets")
)
try: try:
dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"]) dataset_set = set_qs.get(pk=data["set_id"])
except Dataset.DoesNotExist: except DatasetSet.DoesNotExist:
raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']}) raise ValidationError({"set": [f'Invalid pk "{str(data["set_id"])}" - object does not exist.']})
else: else:
dataset = self.instance.dataset dataset_set = self.instance.set
data["dataset"] = dataset
if process.mode != ProcessMode.Dataset: if process.mode != ProcessMode.Dataset:
errors["process"].append('Datasets can only be added to or removed from processes of mode "dataset".') errors["process"].append('Dataset sets can only be added to or removed from processes of mode "dataset".')
if process.has_tasks: if process.has_tasks:
errors["process"].append("Datasets cannot be updated on processes that have already started.") errors["process"].append("Dataset sets cannot be updated on processes that have already started.")
if self.context["request"].method == "POST" and process.datasets.filter(id=dataset.id).exists():
errors["dataset"].append("This dataset is already selected in this process.")
# Validate sets if self.context["request"].method == "POST" and process.sets.filter(id=dataset_set.id).exists():
sets = data.get("sets") errors["set"].append("This dataset set is already selected in this process.")
if not sets or len(sets) == 0:
if not self.instance:
data["sets"] = [item.name for item in list(dataset.sets.all())]
else:
errors["sets"].append("This field cannot be empty.")
else:
if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets):
errors["sets"].append("The specified sets must all exist in the specified dataset.")
if len(set(sets)) != len(sets):
errors["sets"].append("Sets must be unique.")
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
data["set"] = dataset_set
return data return data
...@@ -13,7 +13,7 @@ from arkindex.process.models import ( ...@@ -13,7 +13,7 @@ from arkindex.process.models import (
ActivityState, ActivityState,
FeatureUsage, FeatureUsage,
Process, Process,
ProcessDataset, ProcessDatasetSet,
ProcessMode, ProcessMode,
Repository, Repository,
WorkerActivity, WorkerActivity,
...@@ -899,8 +899,8 @@ class TestCreateProcess(FixtureAPITestCase): ...@@ -899,8 +899,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first() dataset = self.corpus.datasets.first()
test_sets = list(dataset.sets.values_list("name", flat=True)) test_set = dataset.sets.get(name="test")
ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) ProcessDatasetSet.objects.create(process=process, set=test_set)
process.versions.set([self.version_2, self.version_3]) process.versions.set([self.version_2, self.version_3])
with self.assertNumQueries(9): with self.assertNumQueries(9):
...@@ -930,8 +930,8 @@ class TestCreateProcess(FixtureAPITestCase): ...@@ -930,8 +930,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.worker_1.save() self.worker_1.save()
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first() dataset = self.corpus.datasets.first()
test_sets = list(dataset.sets.values_list("name", flat=True)) test_set = dataset.sets.get(name="test")
ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) ProcessDatasetSet.objects.create(process=process, set=test_set)
process.versions.add(self.version_1) process.versions.add(self.version_1)
with self.assertNumQueries(9): with self.assertNumQueries(9):
......
...@@ -75,8 +75,8 @@ from arkindex.process.api import ( ...@@ -75,8 +75,8 @@ from arkindex.process.api import (
DataFileRetrieve, DataFileRetrieve,
FilesProcess, FilesProcess,
ListProcessElements, ListProcessElements,
ProcessDatasetManage, ProcessDatasetSetManage,
ProcessDatasets, ProcessDatasetSets,
ProcessDetails, ProcessDetails,
ProcessList, ProcessList,
ProcessRetry, ProcessRetry,
...@@ -272,8 +272,8 @@ api = [ ...@@ -272,8 +272,8 @@ api = [
path("process/<uuid:pk>/apply/", ApplyProcessTemplate.as_view(), name="apply-process-template"), path("process/<uuid:pk>/apply/", ApplyProcessTemplate.as_view(), name="apply-process-template"),
path("process/<uuid:pk>/clear/", ClearProcess.as_view(), name="clear-process"), path("process/<uuid:pk>/clear/", ClearProcess.as_view(), name="clear-process"),
path("process/<uuid:pk>/select-failures/", SelectProcessFailures.as_view(), name="process-select-failures"), path("process/<uuid:pk>/select-failures/", SelectProcessFailures.as_view(), name="process-select-failures"),
path("process/<uuid:pk>/datasets/", ProcessDatasets.as_view(), name="process-datasets"), path("process/<uuid:pk>/sets/", ProcessDatasetSets.as_view(), name="process-sets"),
path("process/<uuid:process>/dataset/<uuid:dataset>/", ProcessDatasetManage.as_view(), name="process-dataset"), path("process/<uuid:process>/set/<uuid:set>/", ProcessDatasetSetManage.as_view(), name="process-set"),
# ML models training # ML models training
path("modelversion/<uuid:pk>/", ModelVersionsRetrieve.as_view(), name="model-version-retrieve"), path("modelversion/<uuid:pk>/", ModelVersionsRetrieve.as_view(), name="model-version-retrieve"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment