Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (2)
......@@ -37,7 +37,7 @@ from arkindex.process.models import (
WorkerType,
WorkerVersion,
)
from arkindex.training.models import Dataset, DatasetElement, Model
from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
from arkindex.users.models import Role, User
EXPORT_VERSION = 8
......@@ -320,17 +320,30 @@ class Command(BaseCommand):
id=row["id"],
corpus=self.corpus,
name=row["name"],
sets=[r.strip() for r in row["sets"].split(",")],
creator=self.user,
description="Imported dataset",
)]
def convert_dataset_sets(self, row):
return [
DatasetSet(
name=set_name.strip(),
dataset_id=row["id"]
)
for set_name in row["sets"].split(",")
]
def map_dataset_sets(self):
return {
(str(set.dataset_id), set.name): set.id
for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
}
def convert_dataset_elements(self, row):
return [DatasetElement(
id=row["id"],
element_id=row["element_id"],
dataset_id=row["dataset_id"],
set=row["set_name"],
set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
)]
def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
......@@ -603,6 +616,12 @@ class Command(BaseCommand):
# Create datasets
self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
# Create dataset sets
self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
# Create dataset sets mapping
self.dataset_sets_map = self.map_dataset_sets()
# Create dataset elements
self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
......
......@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer
from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Dataset, DatasetElement
BASE_DIR = Path(__file__).absolute().parent
......@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
element.classifications.create(
ml_class=self.corpus.ml_classes.create(name="Blah"),
confidence=.55555555,
......@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
confidence=.55555555,
)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
person_type = EntityType.objects.get(
name="person",
corpus=self.corpus
......
......@@ -184,7 +184,7 @@ api = [
# Datasets
path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"),
path("element/<uuid:pk>/sets/", ElementDatasetSets.as_view(), name="element-sets"),
path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
......
......@@ -1633,7 +1633,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
self.client.force_login(self.user)
private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="t"), name="elt")
with self.assertNumQueries(2):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": private_elt.id}))
response = self.client.get(reverse("api:element-sets", kwargs={"pk": private_elt.id}))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(filter_rights_mock.call_count, 1)
......@@ -1645,7 +1645,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
for method in forbidden_methods:
with self.subTest(method=method):
client_method = getattr(self.client, method)
response = client_method(reverse("api:element-datasets", kwargs={"pk": str(self.vol.id)}))
response = client_method(reverse("api:element-sets", kwargs={"pk": str(self.vol.id)}))
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_element_datasets_public(self):
......@@ -1655,7 +1655,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
train_set = self.dataset.sets.get(name="training")
train_set.set_elements.create(element=self.vol)
with self.assertNumQueries(4):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.vol.id)}))
response = self.client.get(reverse("api:element-sets", kwargs={"pk": str(self.vol.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"count": 1,
......@@ -1696,7 +1696,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
train_set_2 = self.dataset2.sets.get(name="training")
train_set_2.set_elements.create(element=self.page1, set="train")
with self.assertNumQueries(6):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}))
response = self.client.get(reverse("api:element-sets", kwargs={"pk": str(self.page1.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"count": 3,
......@@ -1783,7 +1783,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
validation_set.set_elements.create(element=self.page1)
train_set_2.set_elements.create(element=self.page1)
with self.assertNumQueries(6):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
response = self.client.get(reverse("api:element-sets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"count": 3,
......@@ -1880,7 +1880,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))
with self.assertNumQueries(8):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
response = self.client.get(reverse("api:element-sets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
......