Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (4)
Showing
with 876 additions and 684 deletions
...@@ -2,6 +2,8 @@ SELECT ...@@ -2,6 +2,8 @@ SELECT
dataset.id, dataset.id,
dataset.name, dataset.name,
dataset.state, dataset.state,
ARRAY_TO_STRING(dataset.sets, ',', '') string_agg(datasetset.name, ',')
FROM training_dataset dataset FROM training_dataset dataset
INNER JOIN training_datasetset datasetset ON datasetset.dataset_id = dataset.id
WHERE dataset.corpus_id = '{corpus_id}'::uuid WHERE dataset.corpus_id = '{corpus_id}'::uuid
GROUP BY dataset.id
SELECT SELECT
dataset_element.id, dataset_element.id,
dataset_element.element_id, dataset_element.element_id,
dataset_element.dataset_id, dataset_set.dataset_id,
dataset_element.set dataset_set.name
FROM training_datasetelement dataset_element FROM training_datasetelement dataset_element
INNER JOIN training_dataset dataset ON (dataset_element.dataset_id = dataset.id) INNER JOIN training_datasetset dataset_set ON (dataset_element.set_id = dataset_set.id)
INNER JOIN training_dataset dataset ON (dataset_set.dataset_id = dataset.id)
WHERE dataset.corpus_id = '{corpus_id}'::uuid WHERE dataset.corpus_id = '{corpus_id}'::uuid
This diff is collapsed.
...@@ -20,6 +20,7 @@ from arkindex.process.models import ( ...@@ -20,6 +20,7 @@ from arkindex.process.models import (
WorkerVersionState, WorkerVersionState,
) )
from arkindex.project.tools import fake_now from arkindex.project.tools import fake_now
from arkindex.training.models import DatasetSet
from arkindex.users.models import Group, Right, Role, User from arkindex.users.models import Group, Right, Role, User
...@@ -271,8 +272,15 @@ class Command(BaseCommand): ...@@ -271,8 +272,15 @@ class Command(BaseCommand):
) )
# Create 2 datasets # Create 2 datasets
corpus.datasets.create(name="First Dataset", description="dataset number one", creator=user) dataset_1 = corpus.datasets.create(name="First Dataset", description="dataset number one", creator=user)
corpus.datasets.create(name="Second Dataset", description="dataset number two", creator=user) dataset_2 = corpus.datasets.create(name="Second Dataset", description="dataset number two", creator=user)
# Create their sets
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_1.id) for name in ["training", "validation", "test"]
)
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_2.id) for name in ["training", "validation", "test"]
)
# Create 2 volumes # Create 2 volumes
vol1 = Element.objects.create( vol1 = Element.objects.create(
......
...@@ -37,7 +37,7 @@ from arkindex.process.models import ( ...@@ -37,7 +37,7 @@ from arkindex.process.models import (
WorkerType, WorkerType,
WorkerVersion, WorkerVersion,
) )
from arkindex.training.models import Dataset, DatasetElement, Model from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
from arkindex.users.models import Role, User from arkindex.users.models import Role, User
EXPORT_VERSION = 8 EXPORT_VERSION = 8
...@@ -320,17 +320,30 @@ class Command(BaseCommand): ...@@ -320,17 +320,30 @@ class Command(BaseCommand):
id=row["id"], id=row["id"],
corpus=self.corpus, corpus=self.corpus,
name=row["name"], name=row["name"],
sets=[r.strip() for r in row["sets"].split(",")],
creator=self.user, creator=self.user,
description="Imported dataset", description="Imported dataset",
)] )]
def convert_dataset_sets(self, row):
return [
DatasetSet(
name=set_name.strip(),
dataset_id=row["id"]
)
for set_name in row["sets"].split(",")
]
def map_dataset_sets(self):
return {
(str(set.dataset_id), set.name): set.id
for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
}
def convert_dataset_elements(self, row): def convert_dataset_elements(self, row):
return [DatasetElement( return [DatasetElement(
id=row["id"], id=row["id"],
element_id=row["element_id"], element_id=row["element_id"],
dataset_id=row["dataset_id"], set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
set=row["set_name"],
)] )]
def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True): def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
...@@ -603,6 +616,12 @@ class Command(BaseCommand): ...@@ -603,6 +616,12 @@ class Command(BaseCommand):
# Create datasets # Create datasets
self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY) self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
# Create dataset sets
self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
# Create dataset sets mapping
self.dataset_sets_map = self.map_dataset_sets()
# Create dataset elements # Create dataset elements
self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY) self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
......
...@@ -24,7 +24,7 @@ from arkindex.documents.models import ( ...@@ -24,7 +24,7 @@ from arkindex.documents.models import (
) )
from arkindex.ponos.models import Task from arkindex.ponos.models import Task
from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun
from arkindex.training.models import DatasetElement from arkindex.training.models import DatasetElement, DatasetSet
from arkindex.users.models import User from arkindex.users.models import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -73,7 +73,8 @@ def corpus_delete(corpus_id: str) -> None: ...@@ -73,7 +73,8 @@ def corpus_delete(corpus_id: str) -> None:
# ProcessDataset M2M # ProcessDataset M2M
ProcessDataset.objects.filter(dataset__corpus_id=corpus_id), ProcessDataset.objects.filter(dataset__corpus_id=corpus_id),
ProcessDataset.objects.filter(process__corpus_id=corpus_id), ProcessDataset.objects.filter(process__corpus_id=corpus_id),
DatasetElement.objects.filter(dataset__corpus_id=corpus_id), DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
corpus.datasets.all(), corpus.datasets.all(),
# Delete the hidden M2M task parents table # Delete the hidden M2M task parents table
Task.parents.through.objects.filter(from_task__process__corpus_id=corpus_id), Task.parents.through.objects.filter(from_task__process__corpus_id=corpus_id),
......
...@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete ...@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer from arkindex.images.models import Image, ImageServer
from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
from arkindex.project.tests import FixtureTestCase from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Dataset, DatasetElement
BASE_DIR = Path(__file__).absolute().parent BASE_DIR = Path(__file__).absolute().parent
...@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase): ...@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
dla_version = WorkerVersion.objects.get(worker__slug="dla") dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers) dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
element.classifications.create( element.classifications.create(
ml_class=self.corpus.ml_classes.create(name="Blah"), ml_class=self.corpus.ml_classes.create(name="Blah"),
confidence=.55555555, confidence=.55555555,
...@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase): ...@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
confidence=.55555555, confidence=.55555555,
) )
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
person_type = EntityType.objects.get( person_type = EntityType.objects.get(
name="person", name="person",
corpus=self.corpus corpus=self.corpus
......
...@@ -5,7 +5,7 @@ from arkindex.documents.tasks import corpus_delete ...@@ -5,7 +5,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.ponos.models import Farm, State, Task from arkindex.ponos.models import Farm, State, Task
from arkindex.process.models import CorpusWorkerVersion, Process, ProcessDataset, ProcessMode, Repository, WorkerVersion from arkindex.process.models import CorpusWorkerVersion, Process, ProcessDataset, ProcessMode, Repository, WorkerVersion
from arkindex.project.tests import FixtureTestCase, force_constraints_immediate from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
from arkindex.training.models import Dataset from arkindex.training.models import Dataset, DatasetSet
class TestDeleteCorpus(FixtureTestCase): class TestDeleteCorpus(FixtureTestCase):
...@@ -114,18 +114,25 @@ class TestDeleteCorpus(FixtureTestCase): ...@@ -114,18 +114,25 @@ class TestDeleteCorpus(FixtureTestCase):
cls.corpus2 = Corpus.objects.create(name="Other corpus") cls.corpus2 = Corpus.objects.create(name="Other corpus")
dataset1 = Dataset.objects.get(name="First Dataset") dataset1 = Dataset.objects.get(name="First Dataset")
dataset1.dataset_elements.create(element=element, set="test") test_set_1 = dataset1.sets.get(name="test")
test_set_1.set_elements.create(element=element)
cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2) cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2)
DatasetSet.objects.bulk_create(
DatasetSet(
dataset=cls.dataset2,
name=set_name
) for set_name in ["test", "training", "validation"]
)
# Process on cls.corpus and with a dataset from cls.corpus # Process on cls.corpus and with a dataset from cls.corpus
dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset) dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets) ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
# Process on cls.corpus with a dataset from another corpus # Process on cls.corpus with a dataset from another corpus
dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset) dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets) ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets) ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
# Process on another corpus with a dataset from another corpus and none from cls.corpus # Process on another corpus with a dataset from another corpus and none from cls.corpus
cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset) cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets) ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
cls.rev = cls.repo.revisions.create( cls.rev = cls.repo.revisions.create(
hash="42", hash="42",
......
...@@ -24,6 +24,7 @@ from arkindex.documents.models import ( ...@@ -24,6 +24,7 @@ from arkindex.documents.models import (
from arkindex.images.models import Image, ImageServer from arkindex.images.models import Image, ImageServer
from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState
from arkindex.project.tests import FixtureTestCase from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import DatasetElement
TABLE_NAMES = { TABLE_NAMES = {
"export_version", "export_version",
...@@ -131,8 +132,9 @@ class TestExport(FixtureTestCase): ...@@ -131,8 +132,9 @@ class TestExport(FixtureTestCase):
) )
dataset = self.corpus.datasets.get(name="First Dataset") dataset = self.corpus.datasets.get(name="First Dataset")
dataset.dataset_elements.create(element=element, set="train") _, train_set, validation_set = dataset.sets.all().order_by("name")
dataset.dataset_elements.create(element=element, set="validation") train_set.set_elements.create(element=element)
validation_set.set_elements.create(element=element)
export = self.corpus.exports.create(user=self.user) export = self.corpus.exports.create(user=self.user)
...@@ -488,7 +490,7 @@ class TestExport(FixtureTestCase): ...@@ -488,7 +490,7 @@ class TestExport(FixtureTestCase):
( (
str(dataset.id), str(dataset.id),
dataset.name, dataset.name,
",".join(dataset.sets), ",".join(list(dataset.sets.values_list("name", flat=True))),
) for dataset in self.corpus.datasets.all() ) for dataset in self.corpus.datasets.all()
] ]
) )
...@@ -506,9 +508,9 @@ class TestExport(FixtureTestCase): ...@@ -506,9 +508,9 @@ class TestExport(FixtureTestCase):
( (
str(dataset_element.id), str(dataset_element.id),
str(dataset_element.element_id), str(dataset_element.element_id),
str(dataset_element.dataset_id), str(dataset_element.set.dataset_id),
dataset_element.set dataset_element.set.name
) for dataset_element in dataset.dataset_elements.all() ) for dataset_element in DatasetElement.objects.filter(set__dataset_id=dataset.id)
] ]
) )
......
...@@ -183,7 +183,8 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase): ...@@ -183,7 +183,8 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase):
job_mock.return_value.user_id = self.user.id job_mock.return_value.user_id = self.user.id
self.page1.worker_version = self.version self.page1.worker_version = self.version
self.page1.save() self.page1.save()
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test") dataset = Dataset.objects.get(name="First Dataset")
dataset.sets.get(name="test").set_elements.create(element=self.page1)
self.user.selected_elements.set([self.page1]) self.user.selected_elements.set([self.page1])
selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id) selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id)
......
...@@ -6,7 +6,7 @@ from arkindex.documents.models import Entity, EntityType, MLClass, Transcription ...@@ -6,7 +6,7 @@ from arkindex.documents.models import Entity, EntityType, MLClass, Transcription
from arkindex.documents.tasks import worker_results_delete from arkindex.documents.tasks import worker_results_delete
from arkindex.process.models import ProcessMode, WorkerVersion from arkindex.process.models import ProcessMode, WorkerVersion
from arkindex.project.tests import FixtureTestCase from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Dataset, Model, ModelVersionState from arkindex.training.models import DatasetSet, Model, ModelVersionState
class TestDeleteWorkerResults(FixtureTestCase): class TestDeleteWorkerResults(FixtureTestCase):
...@@ -270,7 +270,7 @@ class TestDeleteWorkerResults(FixtureTestCase): ...@@ -270,7 +270,7 @@ class TestDeleteWorkerResults(FixtureTestCase):
self.page1.worker_run = self.worker_run_1 self.page1.worker_run = self.worker_run_1
self.page1.worker_version = self.version_1 self.page1.worker_version = self.version_1
self.page1.save() self.page1.save()
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test") DatasetSet.objects.get(name="test", dataset__name="First Dataset").set_elements.create(element=self.page1)
worker_results_delete(corpus_id=self.corpus.id) worker_results_delete(corpus_id=self.corpus.id)
# Prevent delaying constraints check at end of the test transaction # Prevent delaying constraints check at end of the test transaction
......
...@@ -148,7 +148,8 @@ class TestDestroyElements(FixtureAPITestCase): ...@@ -148,7 +148,8 @@ class TestDestroyElements(FixtureAPITestCase):
""" """
An element cannot be deleted via the API if linked to a dataset An element cannot be deleted via the API if linked to a dataset
""" """
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.vol, set="test") dataset = Dataset.objects.get(name="First Dataset")
dataset.sets.get(name="test").set_elements.create(element=self.vol)
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(3): with self.assertNumQueries(3):
response = self.client.delete(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)})) response = self.client.delete(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
...@@ -179,9 +180,9 @@ class TestDestroyElements(FixtureAPITestCase): ...@@ -179,9 +180,9 @@ class TestDestroyElements(FixtureAPITestCase):
""" """
Elements that are part of a dataset cannot be deleted Elements that are part of a dataset cannot be deleted
""" """
Dataset.objects.get(name="First Dataset").dataset_elements.create( dataset = Dataset.objects.get(name="First Dataset")
element=Element.objects.get_descending(self.vol.id).first(), dataset.sets.get(name="test").set_elements.create(
set="test", element=Element.objects.get_descending(self.vol.id).first()
) )
Element.objects.filter(id=self.vol.id).trash() Element.objects.filter(id=self.vol.id).trash()
......
...@@ -111,19 +111,18 @@ class TestAPI(FixtureAPITestCase): ...@@ -111,19 +111,18 @@ class TestAPI(FixtureAPITestCase):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id])) resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
@expectedFailure @patch("arkindex.project.mixins.get_max_level")
def test_task_details_requires_process_guest(self): def test_task_details_requires_process_guest(self, get_max_level_mock):
self.process.creator = self.superuser get_max_level_mock.return_value = None
self.process.save()
self.corpus.memberships.filter(user=self.user).delete()
self.corpus.public = False
self.corpus.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(5): with self.assertNumQueries(3):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id])) resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(get_max_level_mock.call_count, 1)
self.assertEqual(get_max_level_mock.call_args, call(self.user, self.corpus))
@patch("arkindex.project.aws.s3") @patch("arkindex.project.aws.s3")
def test_task_details_process_level_corpus(self, s3_mock): def test_task_details_process_level_corpus(self, s3_mock):
s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.bucket_name = "ponos"
......
...@@ -706,6 +706,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): ...@@ -706,6 +706,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
return ( return (
ProcessDataset.objects.filter(process_id=self.process.id) ProcessDataset.objects.filter(process_id=self.process.id)
.select_related("process__creator", "dataset__creator") .select_related("process__creator", "dataset__creator")
.prefetch_related("dataset__sets")
.order_by("dataset__name") .order_by("dataset__name")
) )
...@@ -715,8 +716,6 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): ...@@ -715,8 +716,6 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
if not self.kwargs: if not self.kwargs:
return context return context
context["process"] = self.process context["process"] = self.process
# Disable set elements counts in serialized dataset
context["sets_count"] = False
return context return context
...@@ -751,6 +750,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): ...@@ -751,6 +750,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
process_dataset = get_object_or_404( process_dataset = get_object_or_404(
ProcessDataset.objects ProcessDataset.objects
.select_related("dataset__creator", "process__corpus") .select_related("dataset__creator", "process__corpus")
.prefetch_related("dataset__sets")
# Required to check for a process that have already started # Required to check for a process that have already started
.annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))), .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"] dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"]
...@@ -759,12 +759,6 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): ...@@ -759,12 +759,6 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
process_dataset.process.has_tasks = process_dataset.process_has_tasks process_dataset.process.has_tasks = process_dataset.process_has_tasks
return process_dataset return process_dataset
def get_serializer_context(self):
context = super().get_serializer_context()
# Disable set elements counts in serialized dataset
context["sets_count"] = False
return context
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
......
import django.core.validators import django.core.validators
from django.db import migrations, models from django.db import migrations, models
import arkindex.process.models
import arkindex.project.fields import arkindex.project.fields
import arkindex.training.models
class Migration(migrations.Migration): class Migration(migrations.Migration):
...@@ -37,7 +37,7 @@ class Migration(migrations.Migration): ...@@ -37,7 +37,7 @@ class Migration(migrations.Migration):
validators=[django.core.validators.MinLengthValidator(1)] validators=[django.core.validators.MinLengthValidator(1)]
), ),
size=None, size=None,
validators=[django.core.validators.MinLengthValidator(1), arkindex.training.models.validate_unique_set_names] validators=[django.core.validators.MinLengthValidator(1), arkindex.process.models.validate_unique_set_names]
), ),
), ),
] ]
...@@ -5,6 +5,7 @@ from typing import Optional ...@@ -5,6 +5,7 @@ from typing import Optional
from django.conf import settings from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.fields import GenericRelation
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator, MinValueValidator from django.core.validators import MinLengthValidator, MinValueValidator
from django.db import models, transaction from django.db import models, transaction
from django.db.models import F, Q from django.db.models import F, Q
...@@ -27,8 +28,9 @@ from arkindex.process.managers import ( ...@@ -27,8 +28,9 @@ from arkindex.process.managers import (
from arkindex.project.aws import S3FileMixin, S3FileStatus from arkindex.project.aws import S3FileMixin, S3FileStatus
from arkindex.project.fields import ArrayField, MD5HashField from arkindex.project.fields import ArrayField, MD5HashField
from arkindex.project.models import IndexableModel from arkindex.project.models import IndexableModel
from arkindex.project.tools import is_prefetched
from arkindex.project.validators import MaxValueValidator from arkindex.project.validators import MaxValueValidator
from arkindex.training.models import ModelVersion, ModelVersionState, validate_unique_set_names from arkindex.training.models import ModelVersion, ModelVersionState
from arkindex.users.models import Role from arkindex.users.models import Role
...@@ -40,6 +42,11 @@ def process_max_chunks(): ...@@ -40,6 +42,11 @@ def process_max_chunks():
return settings.MAX_CHUNKS return settings.MAX_CHUNKS
def validate_unique_set_names(sets):
if len(set(sets)) != len(sets):
raise ValidationError("Set names must be unique.")
class ActivityState(Enum): class ActivityState(Enum):
""" """
Store the state of the workers activity tracking for a process. Store the state of the workers activity tracking for a process.
...@@ -218,11 +225,7 @@ class Process(IndexableModel): ...@@ -218,11 +225,7 @@ class Process(IndexableModel):
See https://stackoverflow.com/a/19651840/5990435 See https://stackoverflow.com/a/19651840/5990435
""" """
return ( return is_prefetched(self.tasks)
hasattr(self, "_prefetched_objects_cache")
and self.tasks.field.remote_field.get_cache_name()
in self._prefetched_objects_cache
)
@property @property
def expiry(self): def expiry(self):
......
...@@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): ...@@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
else: else:
dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user)) dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
try: try:
dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"]) dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"])
except Dataset.DoesNotExist: except Dataset.DoesNotExist:
raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']}) raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
else: else:
...@@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): ...@@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
sets = data.get("sets") sets = data.get("sets")
if not sets or len(sets) == 0: if not sets or len(sets) == 0:
if not self.instance: if not self.instance:
data["sets"] = dataset.sets data["sets"] = [item.name for item in list(dataset.sets.all())]
else: else:
errors["sets"].append("This field cannot be empty.") errors["sets"].append("This field cannot be empty.")
else: else:
if any(s not in dataset.sets for s in sets): if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets):
errors["sets"].append("The specified sets must all exist in the specified dataset.") errors["sets"].append("The specified sets must all exist in the specified dataset.")
if len(set(sets)) != len(sets): if len(set(sets)) != len(sets):
errors["sets"].append("Sets must be unique.") errors["sets"].append("Sets must be unique.")
......
...@@ -899,7 +899,8 @@ class TestCreateProcess(FixtureAPITestCase): ...@@ -899,7 +899,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first() dataset = self.corpus.datasets.first()
ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) test_sets = list(dataset.sets.values_list("name", flat=True))
ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
process.versions.set([self.version_2, self.version_3]) process.versions.set([self.version_2, self.version_3])
with self.assertNumQueries(9): with self.assertNumQueries(9):
...@@ -929,7 +930,8 @@ class TestCreateProcess(FixtureAPITestCase): ...@@ -929,7 +930,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.worker_1.save() self.worker_1.save()
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first() dataset = self.corpus.datasets.first()
ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) test_sets = list(dataset.sets.values_list("name", flat=True))
ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
process.versions.add(self.version_1) process.versions.add(self.version_1)
with self.assertNumQueries(9): with self.assertNumQueries(9):
......
...@@ -2324,7 +2324,7 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2324,7 +2324,7 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset_requires_dataset_in_same_corpus(self): def test_start_process_dataset_requires_dataset_in_same_corpus(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.assertFalse(process2.tasks.exists()) self.assertFalse(process2.tasks.exists())
...@@ -2341,8 +2341,8 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2341,8 +2341,8 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset_unsupported_parameters(self): def test_start_process_dataset_unsupported_parameters(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.dataset2.sets) ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True)))
process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user) self.client.force_login(self.user)
...@@ -2366,8 +2366,8 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2366,8 +2366,8 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset(self): def test_start_process_dataset(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.assertFalse(process2.tasks.exists()) self.assertFalse(process2.tasks.exists())
...@@ -2562,8 +2562,8 @@ class TestProcesses(FixtureAPITestCase): ...@@ -2562,8 +2562,8 @@ class TestProcesses(FixtureAPITestCase):
It should be possible to pass chunks when starting a dataset process It should be possible to pass chunks when starting a dataset process
""" """
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets) ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets) ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True)))
# Add a worker run to this process # Add a worker run to this process
run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None) run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
......