From 94822ae11e4e17657b2581afd9b3c3bc362068a0 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Tue, 12 Mar 2024 13:29:58 +0100
Subject: [PATCH] fix delete corpus

---
 arkindex/documents/export/dataset.sql         |  4 +++-
 arkindex/documents/export/dataset_element.sql |  7 ++++---
 arkindex/documents/tasks.py                   |  5 +++--
 .../tests/tasks/test_corpus_delete.py         | 19 +++++++++++++------
 arkindex/documents/tests/tasks/test_export.py | 14 ++++++++------
 .../test_selection_worker_results_delete.py   |  3 ++-
 .../tests/tasks/test_worker_results_delete.py |  4 ++--
 .../documents/tests/test_destroy_elements.py  |  9 +++++----
 arkindex/sql_validation/corpus_delete.sql     |  9 +++++++++
 .../corpus_delete_top_level_type.sql          | 13 +++++++++++--
 10 files changed, 60 insertions(+), 27 deletions(-)

diff --git a/arkindex/documents/export/dataset.sql b/arkindex/documents/export/dataset.sql
index fd7990752b..48d54910f0 100644
--- a/arkindex/documents/export/dataset.sql
+++ b/arkindex/documents/export/dataset.sql
@@ -2,6 +2,8 @@ SELECT
     dataset.id,
     dataset.name,
     dataset.state,
-    ARRAY_TO_STRING(dataset.sets, ',', '')
+    string_agg(datasetset.name, ',')
 FROM training_dataset dataset
+INNER JOIN training_datasetset datasetset ON datasetset.dataset_id = dataset.id
 WHERE dataset.corpus_id = '{corpus_id}'::uuid
+GROUP BY dataset.id
diff --git a/arkindex/documents/export/dataset_element.sql b/arkindex/documents/export/dataset_element.sql
index 4084e2e0cd..c75624c81d 100644
--- a/arkindex/documents/export/dataset_element.sql
+++ b/arkindex/documents/export/dataset_element.sql
@@ -1,8 +1,9 @@
 SELECT
     dataset_element.id,
     dataset_element.element_id,
-    dataset_element.dataset_id,
-    dataset_element.set
+    dataset_set.dataset_id,
+    dataset_set.name
 FROM training_datasetelement dataset_element
-INNER JOIN training_dataset dataset ON (dataset_element.dataset_id = dataset.id)
+INNER JOIN training_datasetset dataset_set ON (dataset_element.set_id = dataset_set.id)
+INNER JOIN training_dataset dataset ON (dataset_set.dataset_id = dataset.id)
 WHERE dataset.corpus_id = '{corpus_id}'::uuid
diff --git a/arkindex/documents/tasks.py b/arkindex/documents/tasks.py
index acd269aede..ee4c680a85 100644
--- a/arkindex/documents/tasks.py
+++ b/arkindex/documents/tasks.py
@@ -24,7 +24,7 @@ from arkindex.documents.models import (
 )
 from arkindex.ponos.models import Task
 from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun
-from arkindex.training.models import DatasetElement
+from arkindex.training.models import DatasetElement, DatasetSet
 from arkindex.users.models import User
 
 logger = logging.getLogger(__name__)
@@ -73,7 +73,8 @@ def corpus_delete(corpus_id: str) -> None:
         # ProcessDataset M2M
         ProcessDataset.objects.filter(dataset__corpus_id=corpus_id),
         ProcessDataset.objects.filter(process__corpus_id=corpus_id),
-        DatasetElement.objects.filter(dataset__corpus_id=corpus_id),
+        DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
+        DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
         corpus.datasets.all(),
         # Delete the hidden M2M task parents table
         Task.parents.through.objects.filter(from_task__process__corpus_id=corpus_id),
diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index 1685f99b41..cea863ab3c 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -5,7 +5,7 @@ from arkindex.documents.tasks import corpus_delete
 from arkindex.ponos.models import Farm, State, Task
 from arkindex.process.models import CorpusWorkerVersion, Process, ProcessDataset, ProcessMode, Repository, WorkerVersion
 from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
-from arkindex.training.models import Dataset
+from arkindex.training.models import Dataset, DatasetSet
 
 
 class TestDeleteCorpus(FixtureTestCase):
@@ -114,18 +114,25 @@ class TestDeleteCorpus(FixtureTestCase):
         cls.corpus2 = Corpus.objects.create(name="Other corpus")
 
         dataset1 = Dataset.objects.get(name="First Dataset")
-        dataset1.dataset_elements.create(element=element, set="test")
+        test_set_1 = dataset1.sets.get(name="test")
+        test_set_1.set_elements.create(element=element)
         cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2)
+        DatasetSet.objects.bulk_create(
+            DatasetSet(
+                dataset=cls.dataset2,
+                name=set_name
+            ) for set_name in ["test", "training", "validation"]
+        )
         # Process on cls.corpus and with a dataset from cls.corpus
         dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets)
+        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
         # Process on cls.corpus with a dataset from another corpus
         dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets)
-        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets)
+        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
+        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
         # Process on another corpus with a dataset from another corpus and none from cls.corpus
         cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets)
+        ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
 
         cls.rev = cls.repo.revisions.create(
             hash="42",
diff --git a/arkindex/documents/tests/tasks/test_export.py b/arkindex/documents/tests/tasks/test_export.py
index d04e63a147..4aae19ae02 100644
--- a/arkindex/documents/tests/tasks/test_export.py
+++ b/arkindex/documents/tests/tasks/test_export.py
@@ -24,6 +24,7 @@ from arkindex.documents.models import (
 from arkindex.images.models import Image, ImageServer
 from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState
 from arkindex.project.tests import FixtureTestCase
+from arkindex.training.models import DatasetElement
 
 TABLE_NAMES = {
     "export_version",
@@ -131,8 +132,9 @@ class TestExport(FixtureTestCase):
         )
 
         dataset = self.corpus.datasets.get(name="First Dataset")
-        dataset.dataset_elements.create(element=element, set="train")
-        dataset.dataset_elements.create(element=element, set="validation")
+        _, train_set, validation_set = dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=element)
+        validation_set.set_elements.create(element=element)
 
         export = self.corpus.exports.create(user=self.user)
 
@@ -488,7 +490,7 @@ class TestExport(FixtureTestCase):
                 (
                     str(dataset.id),
                     dataset.name,
-                    ",".join(dataset.sets),
+                    ",".join(list(dataset.sets.values_list("name", flat=True))),
                 ) for dataset in self.corpus.datasets.all()
             ]
         )
@@ -506,9 +508,9 @@ class TestExport(FixtureTestCase):
                 (
                     str(dataset_element.id),
                     str(dataset_element.element_id),
-                    str(dataset_element.dataset_id),
-                    dataset_element.set
-                ) for dataset_element in dataset.dataset_elements.all()
+                    str(dataset_element.set.dataset_id),
+                    dataset_element.set.name
+                ) for dataset_element in DatasetElement.objects.filter(set__dataset_id=dataset.id)
             ]
         )
 
diff --git a/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py b/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
index 1f25a1aa07..f766bc0b02 100644
--- a/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
+++ b/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
@@ -183,7 +183,8 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase):
         job_mock.return_value.user_id = self.user.id
         self.page1.worker_version = self.version
         self.page1.save()
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(element=self.page1)
         self.user.selected_elements.set([self.page1])
 
         selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id)
diff --git a/arkindex/documents/tests/tasks/test_worker_results_delete.py b/arkindex/documents/tests/tasks/test_worker_results_delete.py
index 0fb898cf35..52546e1bac 100644
--- a/arkindex/documents/tests/tasks/test_worker_results_delete.py
+++ b/arkindex/documents/tests/tasks/test_worker_results_delete.py
@@ -6,7 +6,7 @@ from arkindex.documents.models import Entity, EntityType, MLClass, Transcription
 from arkindex.documents.tasks import worker_results_delete
 from arkindex.process.models import ProcessMode, WorkerVersion
 from arkindex.project.tests import FixtureTestCase
-from arkindex.training.models import Dataset, Model, ModelVersionState
+from arkindex.training.models import DatasetSet, Model, ModelVersionState
 
 
 class TestDeleteWorkerResults(FixtureTestCase):
@@ -270,7 +270,7 @@ class TestDeleteWorkerResults(FixtureTestCase):
         self.page1.worker_run = self.worker_run_1
         self.page1.worker_version = self.version_1
         self.page1.save()
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
+        DatasetSet.objects.get(name="test", dataset__name="First Dataset").set_elements.create(element=self.page1)
 
         worker_results_delete(corpus_id=self.corpus.id)
         # Prevent delaying constraints check at end of the test transaction
diff --git a/arkindex/documents/tests/test_destroy_elements.py b/arkindex/documents/tests/test_destroy_elements.py
index da6312ac19..af9fe5fb44 100644
--- a/arkindex/documents/tests/test_destroy_elements.py
+++ b/arkindex/documents/tests/test_destroy_elements.py
@@ -148,7 +148,8 @@ class TestDestroyElements(FixtureAPITestCase):
         """
         An element cannot be deleted via the API if linked to a dataset
         """
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.vol, set="test")
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(element=self.vol)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
@@ -179,9 +180,9 @@ class TestDestroyElements(FixtureAPITestCase):
         """
         Elements that are part of a dataset cannot be deleted
         """
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(
-            element=Element.objects.get_descending(self.vol.id).first(),
-            set="test",
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(
+            element=Element.objects.get_descending(self.vol.id).first()
         )
 
         Element.objects.filter(id=self.vol.id).trash()
diff --git a/arkindex/sql_validation/corpus_delete.sql b/arkindex/sql_validation/corpus_delete.sql
index a454d05cdb..766566825d 100644
--- a/arkindex/sql_validation/corpus_delete.sql
+++ b/arkindex/sql_validation/corpus_delete.sql
@@ -185,6 +185,15 @@ FROM "training_datasetelement"
 WHERE "training_datasetelement"."id" IN
         (SELECT U0."id"
          FROM "training_datasetelement" U0
+         INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
+         INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
+         WHERE U2."corpus_id" = '{corpus_id}'::uuid);
+
+DELETE
+FROM "training_datasetset"
+WHERE "training_datasetset"."id" IN
+        (SELECT U0."id"
+         FROM "training_datasetset" U0
          INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
          WHERE U1."corpus_id" = '{corpus_id}'::uuid);
 
diff --git a/arkindex/sql_validation/corpus_delete_top_level_type.sql b/arkindex/sql_validation/corpus_delete_top_level_type.sql
index 9c7ae8cd60..d64cf0bb8b 100644
--- a/arkindex/sql_validation/corpus_delete_top_level_type.sql
+++ b/arkindex/sql_validation/corpus_delete_top_level_type.sql
@@ -189,8 +189,17 @@ FROM "training_datasetelement"
 WHERE "training_datasetelement"."id" IN
         (SELECT U0."id"
          FROM "training_datasetelement" U0
-         INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
-         WHERE U1."corpus_id" = '{corpus_id}'::uuid);
+         INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
+         INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
+         WHERE U2."corpus_id" = '{corpus_id}'::uuid);
+
+DELETE
+FROM "training_datasetset"
+WHERE "training_datasetset"."id" IN
+        (SELECT U0."id"
+         FROM "training_datasetset" U0
+         INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
+         WHERE U1."corpus_id" = '{corpus_id}'::uuid);
 
 DELETE
 FROM "training_dataset"
-- 
GitLab