Compare revisions: arkindex/backend

Commits on Source (4)
Showing 876 additions and 684 deletions
......@@ -2,6 +2,8 @@ SELECT
dataset.id,
dataset.name,
dataset.state,
- ARRAY_TO_STRING(dataset.sets, ',', '')
+ string_agg(datasetset.name, ',')
FROM training_dataset dataset
+ INNER JOIN training_datasetset datasetset ON datasetset.dataset_id = dataset.id
WHERE dataset.corpus_id = '{corpus_id}'::uuid
+ GROUP BY dataset.id
SELECT
dataset_element.id,
dataset_element.element_id,
- dataset_element.dataset_id,
- dataset_element.set
+ dataset_set.dataset_id,
+ dataset_set.name
FROM training_datasetelement dataset_element
- INNER JOIN training_dataset dataset ON (dataset_element.dataset_id = dataset.id)
+ INNER JOIN training_datasetset dataset_set ON (dataset_element.set_id = dataset_set.id)
+ INNER JOIN training_dataset dataset ON (dataset_set.dataset_id = dataset.id)
WHERE dataset.corpus_id = '{corpus_id}'::uuid
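For orientation, the set names aggregated by the new dataset export query can also be expressed with the Django ORM. This is only a hedged sketch for illustration (the export command keeps using the raw SQL above; corpus_id is assumed to be defined, and the Dataset/DatasetSet relation comes from this merge request):

    # Hedged sketch: rough ORM equivalent of the new dataset export query.
    from django.contrib.postgres.aggregates import StringAgg
    from arkindex.training.models import Dataset

    datasets = (
        Dataset.objects.filter(corpus_id=corpus_id)
        .annotate(set_names=StringAgg("sets__name", delimiter=","))
        .values_list("id", "name", "state", "set_names")
    )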
This diff is collapsed (the file's changes are not shown below).
......@@ -20,6 +20,7 @@ from arkindex.process.models import (
WorkerVersionState,
)
from arkindex.project.tools import fake_now
+ from arkindex.training.models import DatasetSet
from arkindex.users.models import Group, Right, Role, User
......@@ -271,8 +272,15 @@ class Command(BaseCommand):
)
# Create 2 datasets
corpus.datasets.create(name="First Dataset", description="dataset number one", creator=user)
corpus.datasets.create(name="Second Dataset", description="dataset number two", creator=user)
dataset_1 = corpus.datasets.create(name="First Dataset", description="dataset number one", creator=user)
dataset_2 = corpus.datasets.create(name="Second Dataset", description="dataset number two", creator=user)
# Create their sets
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_1.id) for name in ["training", "validation", "test"]
)
DatasetSet.objects.bulk_create(
DatasetSet(name=name, dataset_id=dataset_2.id) for name in ["training", "validation", "test"]
)
# Create 2 volumes
vol1 = Element.objects.create(
......
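The same three default sets are created for both fixture datasets. A hypothetical helper, not part of this merge request, could factor that out; bulk_create accepts any iterable, including a generator expression:

    # Hypothetical helper (not in this diff): create the default sets for one dataset.
    from arkindex.training.models import DatasetSet

    DEFAULT_SET_NAMES = ("training", "validation", "test")

    def create_default_sets(dataset):
        return DatasetSet.objects.bulk_create(
            DatasetSet(name=name, dataset_id=dataset.id) for name in DEFAULT_SET_NAMES
        )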
......@@ -37,7 +37,7 @@ from arkindex.process.models import (
WorkerType,
WorkerVersion,
)
- from arkindex.training.models import Dataset, DatasetElement, Model
+ from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
from arkindex.users.models import Role, User
EXPORT_VERSION = 8
......@@ -320,17 +320,30 @@ class Command(BaseCommand):
id=row["id"],
corpus=self.corpus,
name=row["name"],
- sets=[r.strip() for r in row["sets"].split(",")],
creator=self.user,
description="Imported dataset",
)]
+ def convert_dataset_sets(self, row):
+ return [
+ DatasetSet(
+ name=set_name.strip(),
+ dataset_id=row["id"]
+ )
+ for set_name in row["sets"].split(",")
+ ]
+ def map_dataset_sets(self):
+ return {
+ (str(set.dataset_id), set.name): set.id
+ for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
+ }
def convert_dataset_elements(self, row):
return [DatasetElement(
id=row["id"],
element_id=row["element_id"],
dataset_id=row["dataset_id"],
set=row["set_name"],
set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
)]
def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
......@@ -603,6 +616,12 @@ class Command(BaseCommand):
# Create datasets
self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
+ # Create dataset sets
+ self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
+ # Create dataset sets mapping
+ self.dataset_sets_map = self.map_dataset_sets()
# Create dataset elements
self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
......
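The ordering of these steps matters: the DatasetSet rows must exist before map_dataset_sets builds its lookup, and the lookup must exist before convert_dataset_elements runs, since every element row resolves its set_id through it. The keys use str(set.dataset_id) because dataset_id values read back from the export database are plain strings. A minimal sketch of the lookup with made-up values:

    # Minimal sketch (made-up identifiers): how a (dataset_id, set_name) key resolves a set_id.
    dataset_id = "9f6ab9b2-0000-0000-0000-000000000000"  # string, as read from the SQLite export
    dataset_sets_map = {(dataset_id, "train"): "set-uuid-placeholder"}  # built by map_dataset_sets()
    row = {"dataset_id": dataset_id, "set_name": "train"}
    set_id = dataset_sets_map[(row["dataset_id"], row["set_name"])]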
......@@ -24,7 +24,7 @@ from arkindex.documents.models import (
)
from arkindex.ponos.models import Task
from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun
- from arkindex.training.models import DatasetElement
+ from arkindex.training.models import DatasetElement, DatasetSet
from arkindex.users.models import User
logger = logging.getLogger(__name__)
......@@ -73,7 +73,8 @@ def corpus_delete(corpus_id: str) -> None:
# ProcessDataset M2M
ProcessDataset.objects.filter(dataset__corpus_id=corpus_id),
ProcessDataset.objects.filter(process__corpus_id=corpus_id),
- DatasetElement.objects.filter(dataset__corpus_id=corpus_id),
+ DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
+ DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
corpus.datasets.all(),
# Delete the hidden M2M task parents table
Task.parents.through.objects.filter(from_task__process__corpus_id=corpus_id),
......
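The position of the new entries matters: rows referencing DatasetSet (the dataset elements) must be deleted before the sets, and the sets before the datasets, otherwise the database rejects the deletes on foreign key constraints. A hedged sketch of that dependency chain, reduced to the dataset-related querysets (the real task deletes many more tables in one loop):

    # Hedged sketch: delete children before parents to satisfy foreign key constraints.
    querysets = [
        DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),  # references DatasetSet
        DatasetSet.objects.filter(dataset__corpus_id=corpus_id),           # references Dataset
        corpus.datasets.all(),                                             # references Corpus
    ]
    for queryset in querysets:
        queryset.delete()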
......@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer
from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
from arkindex.project.tests import FixtureTestCase
+ from arkindex.training.models import Dataset, DatasetElement
BASE_DIR = Path(__file__).absolute().parent
......@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
+ dataset_set = Dataset.objects.first().sets.first()
+ DatasetElement.objects.create(set=dataset_set, element=element)
element.classifications.create(
ml_class=self.corpus.ml_classes.create(name="Blah"),
confidence=.55555555,
......@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
confidence=.55555555,
)
+ dataset_set = Dataset.objects.first().sets.first()
+ DatasetElement.objects.create(set=dataset_set, element=element)
person_type = EntityType.objects.get(
name="person",
corpus=self.corpus
......
......@@ -5,7 +5,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.ponos.models import Farm, State, Task
from arkindex.process.models import CorpusWorkerVersion, Process, ProcessDataset, ProcessMode, Repository, WorkerVersion
from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
- from arkindex.training.models import Dataset
+ from arkindex.training.models import Dataset, DatasetSet
class TestDeleteCorpus(FixtureTestCase):
......@@ -114,18 +114,25 @@ class TestDeleteCorpus(FixtureTestCase):
cls.corpus2 = Corpus.objects.create(name="Other corpus")
dataset1 = Dataset.objects.get(name="First Dataset")
- dataset1.dataset_elements.create(element=element, set="test")
+ test_set_1 = dataset1.sets.get(name="test")
+ test_set_1.set_elements.create(element=element)
cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2)
+ DatasetSet.objects.bulk_create(
+ DatasetSet(
+ dataset=cls.dataset2,
+ name=set_name
+ ) for set_name in ["test", "training", "validation"]
+ )
# Process on cls.corpus and with a dataset from cls.corpus
dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets)
+ ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
# Process on cls.corpus with a dataset from another corpus
dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets)
- ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets)
+ ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
+ ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
# Process on another corpus with a dataset from another corpus and none from cls.corpus
cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets)
+ ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
cls.rev = cls.repo.revisions.create(
hash="42",
......
......@@ -24,6 +24,7 @@ from arkindex.documents.models import (
from arkindex.images.models import Image, ImageServer
from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState
from arkindex.project.tests import FixtureTestCase
+ from arkindex.training.models import DatasetElement
TABLE_NAMES = {
"export_version",
......@@ -131,8 +132,9 @@ class TestExport(FixtureTestCase):
)
dataset = self.corpus.datasets.get(name="First Dataset")
dataset.dataset_elements.create(element=element, set="train")
dataset.dataset_elements.create(element=element, set="validation")
_, train_set, validation_set = dataset.sets.all().order_by("name")
train_set.set_elements.create(element=element)
validation_set.set_elements.create(element=element)
export = self.corpus.exports.create(user=self.user)
......@@ -488,7 +490,7 @@ class TestExport(FixtureTestCase):
(
str(dataset.id),
dataset.name,
",".join(dataset.sets),
",".join(list(dataset.sets.values_list("name", flat=True))),
) for dataset in self.corpus.datasets.all()
]
)
......@@ -506,9 +508,9 @@ class TestExport(FixtureTestCase):
(
str(dataset_element.id),
str(dataset_element.element_id),
- str(dataset_element.dataset_id),
- dataset_element.set
- ) for dataset_element in dataset.dataset_elements.all()
+ str(dataset_element.set.dataset_id),
+ dataset_element.set.name
+ ) for dataset_element in DatasetElement.objects.filter(set__dataset_id=dataset.id)
]
)
......
......@@ -183,7 +183,8 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase):
job_mock.return_value.user_id = self.user.id
self.page1.worker_version = self.version
self.page1.save()
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
dataset = Dataset.objects.get(name="First Dataset")
dataset.sets.get(name="test").set_elements.create(element=self.page1)
self.user.selected_elements.set([self.page1])
selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id)
......
......@@ -6,7 +6,7 @@ from arkindex.documents.models import Entity, EntityType, MLClass, Transcription
from arkindex.documents.tasks import worker_results_delete
from arkindex.process.models import ProcessMode, WorkerVersion
from arkindex.project.tests import FixtureTestCase
- from arkindex.training.models import Dataset, Model, ModelVersionState
+ from arkindex.training.models import DatasetSet, Model, ModelVersionState
class TestDeleteWorkerResults(FixtureTestCase):
......@@ -270,7 +270,7 @@ class TestDeleteWorkerResults(FixtureTestCase):
self.page1.worker_run = self.worker_run_1
self.page1.worker_version = self.version_1
self.page1.save()
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
DatasetSet.objects.get(name="test", dataset__name="First Dataset").set_elements.create(element=self.page1)
worker_results_delete(corpus_id=self.corpus.id)
# Prevent delaying constraints check at end of the test transaction
......
......@@ -148,7 +148,8 @@ class TestDestroyElements(FixtureAPITestCase):
"""
An element cannot be deleted via the API if linked to a dataset
"""
Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.vol, set="test")
dataset = Dataset.objects.get(name="First Dataset")
dataset.sets.get(name="test").set_elements.create(element=self.vol)
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
......@@ -179,9 +180,9 @@ class TestDestroyElements(FixtureAPITestCase):
"""
Elements that are part of a dataset cannot be deleted
"""
Dataset.objects.get(name="First Dataset").dataset_elements.create(
element=Element.objects.get_descending(self.vol.id).first(),
set="test",
dataset = Dataset.objects.get(name="First Dataset")
dataset.sets.get(name="test").set_elements.create(
element=Element.objects.get_descending(self.vol.id).first()
)
Element.objects.filter(id=self.vol.id).trash()
......
......@@ -111,19 +111,18 @@ class TestAPI(FixtureAPITestCase):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
- @expectedFailure
- def test_task_details_requires_process_guest(self):
- self.process.creator = self.superuser
- self.process.save()
- self.corpus.memberships.filter(user=self.user).delete()
- self.corpus.public = False
- self.corpus.save()
+ @patch("arkindex.project.mixins.get_max_level")
+ def test_task_details_requires_process_guest(self, get_max_level_mock):
+ get_max_level_mock.return_value = None
self.client.force_login(self.user)
- with self.assertNumQueries(5):
+ with self.assertNumQueries(3):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(get_max_level_mock.call_count, 1)
+ self.assertEqual(get_max_level_mock.call_args, call(self.user, self.corpus))
@patch("arkindex.project.aws.s3")
def test_task_details_process_level_corpus(self, s3_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
......
......@@ -706,6 +706,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
return (
ProcessDataset.objects.filter(process_id=self.process.id)
.select_related("process__creator", "dataset__creator")
.prefetch_related("dataset__sets")
.order_by("dataset__name")
)
......@@ -715,8 +716,6 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
if not self.kwargs:
return context
context["process"] = self.process
- # Disable set elements counts in serialized dataset
- context["sets_count"] = False
return context
......@@ -751,6 +750,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
process_dataset = get_object_or_404(
ProcessDataset.objects
.select_related("dataset__creator", "process__corpus")
.prefetch_related("dataset__sets")
# Required to check for a process that have already started
.annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"]
......@@ -759,12 +759,6 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
process_dataset.process.has_tasks = process_dataset.process_has_tasks
return process_dataset
- def get_serializer_context(self):
- context = super().get_serializer_context()
- # Disable set elements counts in serialized dataset
- context["sets_count"] = False
- return context
def destroy(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
......
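Adding prefetch_related("dataset__sets") means the sets of every listed dataset are fetched in one extra query instead of one query per dataset when the serializer reads the set names. A hedged sketch of the difference (process_id is assumed, serializer internals simplified):

    # Without the prefetch: one query per process dataset when reading set names.
    for pd in ProcessDataset.objects.filter(process_id=process_id).select_related("dataset"):
        names = [s.name for s in pd.dataset.sets.all()]  # hits the database each time

    # With the prefetch: the same loop reuses a single batched query.
    queryset = (
        ProcessDataset.objects.filter(process_id=process_id)
        .select_related("dataset")
        .prefetch_related("dataset__sets")
    )
    for pd in queryset:
        names = [s.name for s in pd.dataset.sets.all()]  # no extra query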
import django.core.validators
from django.db import migrations, models
import arkindex.process.models
import arkindex.project.fields
import arkindex.training.models
class Migration(migrations.Migration):
......@@ -37,7 +37,7 @@ class Migration(migrations.Migration):
validators=[django.core.validators.MinLengthValidator(1)]
),
size=None,
- validators=[django.core.validators.MinLengthValidator(1), arkindex.training.models.validate_unique_set_names]
+ validators=[django.core.validators.MinLengthValidator(1), arkindex.process.models.validate_unique_set_names]
),
),
]
......@@ -5,6 +5,7 @@ from typing import Optional
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator, MinValueValidator
from django.db import models, transaction
from django.db.models import F, Q
......@@ -27,8 +28,9 @@ from arkindex.process.managers import (
from arkindex.project.aws import S3FileMixin, S3FileStatus
from arkindex.project.fields import ArrayField, MD5HashField
from arkindex.project.models import IndexableModel
+ from arkindex.project.tools import is_prefetched
from arkindex.project.validators import MaxValueValidator
- from arkindex.training.models import ModelVersion, ModelVersionState, validate_unique_set_names
+ from arkindex.training.models import ModelVersion, ModelVersionState
from arkindex.users.models import Role
......@@ -40,6 +42,11 @@ def process_max_chunks():
return settings.MAX_CHUNKS
+ def validate_unique_set_names(sets):
+ if len(set(sets)) != len(sets):
+ raise ValidationError("Set names must be unique.")
class ActivityState(Enum):
"""
Store the state of the workers activity tracking for a process.
......@@ -218,11 +225,7 @@ class Process(IndexableModel):
See https://stackoverflow.com/a/19651840/5990435
"""
- return (
- hasattr(self, "_prefetched_objects_cache")
- and self.tasks.field.remote_field.get_cache_name()
- in self._prefetched_objects_cache
- )
+ return is_prefetched(self.tasks)
@property
def expiry(self):
......
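The is_prefetched helper imported from arkindex.project.tools is not shown in this diff. Judging from the code it replaces, it presumably wraps the same _prefetched_objects_cache check for an arbitrary reverse related manager; a sketch under that assumption, not the actual implementation:

    # Assumed sketch of arkindex.project.tools.is_prefetched, based on the removed code above.
    def is_prefetched(related_manager) -> bool:
        """Return True when the reverse relation was already prefetched on its instance."""
        instance = related_manager.instance
        return (
            hasattr(instance, "_prefetched_objects_cache")
            and related_manager.field.remote_field.get_cache_name() in instance._prefetched_objects_cache
        )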
......@@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
else:
dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
try:
dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"])
dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"])
except Dataset.DoesNotExist:
raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
else:
......@@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
sets = data.get("sets")
if not sets or len(sets) == 0:
if not self.instance:
data["sets"] = dataset.sets
data["sets"] = [item.name for item in list(dataset.sets.all())]
else:
errors["sets"].append("This field cannot be empty.")
else:
- if any(s not in dataset.sets for s in sets):
+ if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets):
errors["sets"].append("The specified sets must all exist in the specified dataset.")
if len(set(sets)) != len(sets):
errors["sets"].append("Sets must be unique.")
......
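Because the dataset is now fetched with its sets prefetched, the repeated dataset.sets.all() calls do not hit the database again, but the name list can still be computed once. A hedged sketch of equivalent validation logic, not the diff's exact code:

    # Hedged sketch: validate requested set names against the dataset's (prefetched) sets.
    set_names = [s.name for s in dataset.sets.all()]  # prefetched, no extra query
    sets = data.get("sets")
    if not sets:
        if not self.instance:
            data["sets"] = set_names
        else:
            errors["sets"].append("This field cannot be empty.")
    else:
        if any(name not in set_names for name in sets):
            errors["sets"].append("The specified sets must all exist in the specified dataset.")
        if len(set(sets)) != len(sets):
            errors["sets"].append("Sets must be unique.")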
......@@ -899,7 +899,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.client.force_login(self.user)
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first()
- ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
+ test_sets = list(dataset.sets.values_list("name", flat=True))
+ ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
process.versions.set([self.version_2, self.version_3])
with self.assertNumQueries(9):
......@@ -929,7 +930,8 @@ class TestCreateProcess(FixtureAPITestCase):
self.worker_1.save()
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
dataset = self.corpus.datasets.first()
- ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
+ test_sets = list(dataset.sets.values_list("name", flat=True))
+ ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
process.versions.add(self.version_1)
with self.assertNumQueries(9):
......
......@@ -2324,7 +2324,7 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset_requires_dataset_in_same_corpus(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
+ ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.assertFalse(process2.tasks.exists())
......@@ -2341,8 +2341,8 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset_unsupported_parameters(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
- ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.dataset2.sets)
+ ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+ ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True)))
process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user)
......@@ -2366,8 +2366,8 @@ class TestProcesses(FixtureAPITestCase):
def test_start_process_dataset(self):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
- ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
+ ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+ ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.assertFalse(process2.tasks.exists())
......@@ -2562,8 +2562,8 @@ class TestProcesses(FixtureAPITestCase):
It should be possible to pass chunks when starting a dataset process
"""
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
- ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets)
- ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets)
+ ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+ ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True)))
# Add a worker run to this process
run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
......
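The recurring pattern in these tests turns a dataset's related sets into the plain list of names that ProcessDataset.sets (still an array field) expects. As a standalone snippet, assuming the fixture data from this merge request and an existing dataset process:

    # values_list("name", flat=True) yields the set names; list() materialises the queryset.
    dataset = Dataset.objects.get(name="First Dataset")
    set_names = list(dataset.sets.values_list("name", flat=True))
    ProcessDataset.objects.create(process=process, dataset=dataset, sets=set_names)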