Commit 4c5d66cc authored by Erwan Rouchet

Merge branch 'load-export' into 'master'

Import the export

Closes #809

See merge request !1434
parents 5a6d8198 fa6ef080
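For reference, a minimal invocation sketch of the new command, mirroring the calls made in the tests below; the database path, email address and corpus name are placeholders:

    # Import an Arkindex export into a new corpus, giving admin rights to an existing user.
    from django.core.management import call_command

    call_command(
        'load_export',
        '/data/my corpus-20210324-120000.sqlite',  # SQLite file produced by an Arkindex export
        '--email', 'someone@example.com',          # existing user who is granted admin rights on the new corpus
        '--corpus-name', 'My corpus',              # optional; derived from the file name when omitted
    )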
In arkindex.documents.export:

@@ -102,7 +102,7 @@ def send_email(subject, template_name, corpus_export, **context):
 @job('high', timeout=settings.RQ_TIMEOUTS['export_corpus'])
 def export_corpus(corpus_export: CorpusExport) -> None:
-    _, db_path = tempfile.mkstemp(suffix='db')
+    _, db_path = tempfile.mkstemp(suffix='.db')
     try:
         rq_job = get_current_job()
         corpus_export.state = CorpusExportState.Running

New file: the load_export management command.

#!/usr/bin/env python3
import json
import logging
import os
import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path

from django.core.management.base import BaseCommand, CommandError

from teklia_toolbox.time import Timer

from arkindex.dataimport.models import Repository, RepositoryType, Revision, Worker, WorkerVersion
from arkindex.documents.export import BATCH_SIZE
from arkindex.documents.models import (
    Classification,
    Corpus,
    Element,
    ElementPath,
    ElementType,
    Entity,
    EntityLink,
    EntityRole,
    MetaData,
    MLClass,
    Transcription,
    TranscriptionEntity,
)
from arkindex.images.models import Image, ImageServer
from arkindex.users.models import Role, User

logger = logging.getLogger(__name__)

TABLE_NAMES = {
    'export_version',
    'image_server',
    'image',
    'worker_version',
    'element',
    'element_path',
    'entity',
    'entity_role',
    'entity_link',
    'transcription',
    'transcription_entity',
    'metadata',
    'classification',
}

SQL_TABLES_QUERY = "SELECT name FROM sqlite_master WHERE type ='table' AND name NOT LIKE 'sqlite_%'"
SQL_VERSION_QUERY = "SELECT version FROM export_version"
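# The following query infers which element types are folders: a type is treated as
# a folder when none of its elements has a polygon (the export schema has no
# separate table for element types).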
SQL_ELEMENT_TYPE_QUERY = """
SELECT DISTINCT element.type as slug, max(element.polygon) IS NULL as folder
FROM element
GROUP BY element.type
"""

SQL_ML_CLASS_QUERY = "SELECT DISTINCT classification.class_name as name FROM classification"
SQL_REPOSITORY_QUERY = "SELECT DISTINCT repository_url as url FROM worker_version"
SQL_WORKER_VERSION_QUERY = "SELECT * FROM worker_version"
SQL_IMAGE_SERVER_QUERY = "SELECT * FROM image_server"

SQL_IMAGE_QUERY = """
SELECT image.*, image_server.url as server_url
FROM image
INNER JOIN image_server ON (image_server.id = image.server_id)
"""

SQL_ELEMENT_QUERY = "SELECT * FROM element"
SQL_ELEMENT_PATH_QUERY = "SELECT * FROM element_path"
SQL_PARENT_QUERY = "SELECT * FROM element_path WHERE child_id = '{}'"
SQL_ENTITY_QUERY = "SELECT * FROM entity"
SQL_ENTITY_ROLE_QUERY = "SELECT * FROM entity_role"
SQL_ENTITY_LINK_QUERY = "SELECT * FROM entity_link"
SQL_TRANSCRIPTION_QUERY = "SELECT * FROM transcription"
SQL_TRANSCRIPTION_ENTITY_QUERY = "SELECT * FROM transcription_entity"
SQL_METADATA_QUERY = "SELECT * FROM metadata"
SQL_CLASSIFICATION_QUERY = "SELECT * FROM classification"


class Command(BaseCommand):
    help = "Import an SQLite database generated by an Arkindex export"

    # Maximum number of objects that will be stored in memory at once before their creation.
    BATCH_SIZE = 500

    def add_arguments(self, parser):
        super().add_arguments(parser)
        parser.add_argument(
            "db_path", help="The path to the SQL database file", type=Path,
        )
        parser.add_argument(
            "--email", help="The user email to give corpus rights", type=str, required=True
        )
        parser.add_argument("--corpus-name", help="The name of the corpus to create")

    def sql_chunk(self, sql_query):
        execute = self.cursor.execute(sql_query)
        while True:
            chunk = execute.fetchmany(BATCH_SIZE)
            if not chunk:
                return
            yield chunk

    def build_element_paths(self, child_id):
        """
        The SQL database only stores links to direct parents.
        It is necessary to be able to reconstruct the complete paths (with all grandparents).
        This function retrieves the complete paths by making a recursive call.
        """
        sql_query = SQL_PARENT_QUERY.format(child_id)
        paths = []
        for db_chunk in self.sql_chunk(sql_query):
            for row in db_chunk:
                parent_paths = self.build_element_paths(row["parent_id"])
                for parent_path in parent_paths:
                    parent_path.append(row["parent_id"])
                    paths.append(parent_path)
                if not parent_paths:
                    paths.append([row["parent_id"]])
        return paths
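        # Example: for a hierarchy A > B > C stored only as direct parent links,
        # calling this method with C's ID recurses on B, which returns [[A]];
        # B is then appended to give [[A, B]], C's full ancestor chain.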

    def convert_element_types(self, row):
        return [ElementType(
            display_name=row["slug"].title().replace("_", " "),
            slug=row["slug"],
            folder=row["folder"],
            corpus=self.corpus
        )]

    def convert_ml_classes(self, row):
        return [MLClass(
            name=row["name"],
            corpus=self.corpus
        )]

    def convert_repositories(self, row):
        return [Repository(
            url=row["url"],
            type=RepositoryType.Worker,
            hook_token=str(uuid.uuid4())
        )]

    def convert_worker_versions(self, row):
        worker, _ = Worker.objects.get_or_create(
            slug=row["slug"],
            repository__url=row["repository_url"],
            defaults={
                "type": row["type"],
                "name": row["name"],
            }
        )
        revision, _ = Revision.objects.get_or_create(
            repo__url=row["repository_url"],
            hash=row["revision"],
            defaults={"message": "Fake revision", "author": self.user.display_name}
        )
        return [WorkerVersion(
            id=row["id"],
            worker=worker,
            revision=revision,
            configuration={"configuration": {}}
        )]

    def convert_image_servers(self, row):
        return [ImageServer(
            id=row["id"],
            display_name=row["display_name"],
            url=row["url"],
            max_width=row["max_width"],
            max_height=row["max_height"],
        )]

    def convert_images(self, row):
        assert row["url"].startswith(row["server_url"]), "The url of the image does not start with the url of its server"
        path = row["url"][len(row["server_url"].strip('/')) + 1:]
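        # e.g. a url of "https://iiif.example/iiif/2/page1.jpg" on a server whose url
        # is "https://iiif.example/iiif/2/" yields the path "page1.jpg" (hypothetical
        # URLs): the slice drops the server prefix and the slash that follows it.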
        return [Image(
            id=row["id"],
            path=path,
            width=row["width"],
            height=row["height"],
            server_id=row["server_id"],
        )]

    def convert_elements(self, row):
        return [Element(
            id=row["id"],
            name=row["name"],
            type_id=self.element_types[row["type"]],
            polygon=json.loads(row["polygon"]) if row["polygon"] else None,
            image_id=row["image_id"],
            rotation_angle=row["rotation_angle"],
            mirrored=row["mirrored"],
            worker_version_id=row["worker_version_id"],
            corpus=self.corpus
        )]

    def convert_element_paths(self, row):
        paths = self.build_element_paths(row["child_id"])
        return [ElementPath(
            element_id=row["child_id"],
            path=path,
            ordering=row["ordering"]
        ) for path in paths]

    def convert_entities(self, row):
        return [Entity(
            id=row["id"],
            name=row["name"],
            type=row["type"],
            validated=row["validated"],
            moderator=User.objects.get(email=row["moderator"]) if row["moderator"] else None,
            metas=json.loads(row["metas"]) if row["metas"] else None,
            worker_version_id=row["worker_version_id"],
            corpus=self.corpus
        )]

    def convert_entity_roles(self, row):
        return [EntityRole(
            id=row["id"],
            parent_name=row["parent_name"],
            child_name=row["child_name"],
            parent_type=row["parent_type"],
            child_type=row["child_type"],
            corpus=self.corpus
        )]

    def convert_entity_links(self, row):
        return [EntityLink(
            id=row["id"],
            parent_id=row["parent_id"],
            child_id=row["child_id"],
            role_id=row["role_id"],
        )]

    def convert_transcriptions(self, row):
        return [Transcription(
            id=row["id"],
            element_id=row["element_id"],
            text=row["text"],
            confidence=row["confidence"],
            worker_version_id=row["worker_version_id"],
        )]

    def convert_transcription_entities(self, row):
        return [TranscriptionEntity(
            id=row["id"],
            transcription_id=row["transcription_id"],
            entity_id=row["entity_id"],
            offset=row["offset"],
            length=row["length"],
            worker_version_id=row["worker_version_id"],
        )]

    def convert_metadatas(self, row):
        return [MetaData(
            id=row["id"],
            element_id=row["element_id"],
            name=row["name"],
            type=row["type"],
            value=row["value"],
            entity_id=row["entity_id"],
            worker_version_id=row["worker_version_id"],
        )]

    def convert_classifications(self, row):
        return [Classification(
            id=row["id"],
            element_id=row["element_id"],
            ml_class_id=self.ml_class[row["class_name"]],
            state=row["state"],
            moderator=User.objects.get(email=row["moderator"]) if row["moderator"] else None,
            confidence=row["confidence"],
            high_confidence=row["high_confidence"],
            worker_version_id=row["worker_version_id"],
        )]

    def bulk_create_objects(self, ModelClass, convert_method, sql_query):
        # Model name for logs
        verbose_name_plural = ModelClass._meta.verbose_name_plural.lower()
        verbose_name = ModelClass._meta.verbose_name.capitalize()

        logger.info(f"Creating {verbose_name_plural}")
        with Timer() as t:
            count, failed = 0, 0
            for db_chunk in self.sql_chunk(sql_query):
                # Create instances
                objects = []
                for row in db_chunk:
                    try:
                        objects.extend(convert_method(row))
                    except Exception as e:
                        logger.warning(f"{verbose_name} creation failed: {e}")
                        failed += 1

                # Create objects in bulk
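                # ignore_conflicts=True makes the database skip rows that violate a
                # constraint (e.g. an object with the same primary key already exists)
                # instead of failing the whole batch.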
                try:
                    ModelClass.objects.bulk_create(objects, ignore_conflicts=True)
                    count += len(objects)
                except Exception as e:
                    logger.warning(f"{verbose_name_plural.title()} creation failed: {e}")
                    failed += len(objects)

        logger.info(f"Ran on {count+failed} rows: {count} completed, {failed} failed in {t.delta}")

    def handle(self, db_path, email, **options):
        # Check the database file
        if os.path.splitext(db_path.name)[1].lower() not in [".db", ".sqlite"]:
            raise CommandError(f"File {db_path} is not an SQLite database")
        if not os.path.isfile(db_path):
            raise CommandError(f"File {db_path} does not exist")

        # Check if user exists
        try:
            self.user = User.objects.get(email=email)
        except User.DoesNotExist:
            raise CommandError(f"User with the email {email} does not exist")

        db = sqlite3.connect(db_path)
        db.row_factory = sqlite3.Row
        self.cursor = db.cursor()

        # Check database tables
        db_results = self.cursor.execute(SQL_TABLES_QUERY).fetchall()
        if not set([table["name"] for table in db_results]) == TABLE_NAMES:
            raise CommandError(f"The SQLite database {db_path} is not a correct Arkindex export")

        # Check export version
        db_results = self.cursor.execute(SQL_VERSION_QUERY).fetchall()
        if len(db_results) != 1 or db_results[0]["version"] != 2:
            raise CommandError(f"The SQLite database {db_path} does not have the correct export version")

        # Retrieve corpus name
        date = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
        corpus_name = options.get("corpus_name")
        if not corpus_name:
            # An Arkindex export has a name like "corpus_name-%Y%m%d-%H%M%S.sqlite"
            split_name = db_path.name.split("-")
            corpus_name = " ".join(split_name[:len(split_name) - 2]).title()
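            # e.g. a hypothetical "my corpus-20210324-152431.sqlite" yields "My Corpus"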
        if not corpus_name:
            corpus_name = f"Corpus import {date}"

        logger.info(f"Creating corpus {corpus_name}")
        with Timer() as t:
            # Create corpus
            self.corpus = Corpus.objects.create(name=corpus_name, description=f"Corpus import {date}")
            self.corpus.memberships.create(
                user=self.user,
                level=Role.Admin.value
            )

            # Create element types
            self.bulk_create_objects(ElementType, self.convert_element_types, SQL_ELEMENT_TYPE_QUERY)
            # Keep the link between element type slug and element type ID in memory
            self.element_types = dict(ElementType.objects.filter(corpus=self.corpus).values_list("slug", "id"))

            # Create ml class
            self.bulk_create_objects(MLClass, self.convert_ml_classes, SQL_ML_CLASS_QUERY)
            # Keep the link between ml class name and ml class ID in memory
            self.ml_class = dict(MLClass.objects.filter(corpus=self.corpus).values_list("name", "id"))

            # Create worker versions
            self.bulk_create_objects(Repository, self.convert_repositories, SQL_REPOSITORY_QUERY)
            self.bulk_create_objects(WorkerVersion, self.convert_worker_versions, SQL_WORKER_VERSION_QUERY)

            # Create images and servers
            self.bulk_create_objects(ImageServer, self.convert_image_servers, SQL_IMAGE_SERVER_QUERY)
            self.bulk_create_objects(Image, self.convert_images, SQL_IMAGE_QUERY)

            # Create elements and paths
            self.bulk_create_objects(Element, self.convert_elements, SQL_ELEMENT_QUERY)
            self.bulk_create_objects(ElementPath, self.convert_element_paths, SQL_ELEMENT_PATH_QUERY)

            # Create entities, roles and links
            self.bulk_create_objects(Entity, self.convert_entities, SQL_ENTITY_QUERY)
            self.bulk_create_objects(EntityRole, self.convert_entity_roles, SQL_ENTITY_ROLE_QUERY)
            self.bulk_create_objects(EntityLink, self.convert_entity_links, SQL_ENTITY_LINK_QUERY)

            # Create transcriptions and transcription entities
            self.bulk_create_objects(Transcription, self.convert_transcriptions, SQL_TRANSCRIPTION_QUERY)
            self.bulk_create_objects(TranscriptionEntity, self.convert_transcription_entities, SQL_TRANSCRIPTION_ENTITY_QUERY)

            # Create metadatas
            self.bulk_create_objects(MetaData, self.convert_metadatas, SQL_METADATA_QUERY)

            # Create classifications
            self.bulk_create_objects(Classification, self.convert_classifications, SQL_CLASSIFICATION_QUERY)

        logger.info(f"Created corpus {corpus_name} in {t.delta}")
        db.close()

In arkindex.documents.tasks:

@@ -63,6 +63,7 @@ def corpus_delete(corpus_id: str) -> None:
         corpus.elements.all(),
         corpus.types.all(),
         corpus.memberships.all(),
+        corpus.exports.all(),
         Corpus.objects.filter(id=corpus_id),
     ]

New file: tests for the load_export command.

import json
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import patch

from django.contrib.gis.geos.linestring import LinearRing
from django.core.management import CommandError, call_command

from arkindex.dataimport.models import WorkerVersion
from arkindex.documents.export import export_corpus
from arkindex.documents.models import Corpus, Element, EntityType, Transcription
from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer
from arkindex.project.tests import FixtureTestCase

BASE_DIR = Path(__file__).absolute().parent


class TestLoadExport(FixtureTestCase):

    def clean_dump_data(self, path, corpus=None):
        # Some fields like id, created and updated are not comparable.
        # We keep only the comparable fields for each imported model.
        unexpected_fields_by_model = {
            'documents.elementtype': ['display_name', 'indexable'],
            'documents.mlclass': [],
            'dataimport.repository': ['hook_token', 'credentials', 'provider_name', 'git_ref_revisions'],
            'dataimport.worker': [],
            'dataimport.revision': ['message', 'author'],
            'dataimport.workerversion': ['configuration', 'state', 'docker_image', 'docker_image_iid'],
            'images.imageserver': ['s3_bucket', 's3_region', 'created', 'updated', 'validated', 'read_only'],
            'images.image': ['created', 'updated', 'hash', 'status'],
            'documents.element': ['created', 'updated', 'type'],
            'documents.elementpath': [],
            'documents.entity': [],
            'documents.entityrole': [],
            'documents.entitylink': [],
            'documents.transcription': [],
            'documents.transcriptionentity': [],
            'documents.metadata': ['index'],
            'documents.classification': ['ml_class'],
        }

        with open(path, 'r') as file:
            data = json.loads(file.read())

        results = []
        for object in data:
            if object['model'] not in unexpected_fields_by_model:
                continue
            unexpected_fields = unexpected_fields_by_model[object['model']]
            current_fields = list(object['fields'].keys())
            # Delete the id
            del object['pk']
            for field in current_fields:
                # Delete unexpected field
                if field in unexpected_fields:
                    del object['fields'][field]
                # Update corpus field with the new one
                if field == 'corpus' and corpus:
                    object['fields'][field] = corpus
            results.append(object)
        return results

    def test_invalid_file(self):
        with self.assertRaises(CommandError) as context:
            call_command('load_export', 'wrong.txt', '--email', self.user.email)
        self.assertEqual(str(context.exception), 'File wrong.txt is not an SQLite database')

    def test_file_not_exists(self):
        with self.assertRaises(CommandError) as context:
            call_command('load_export', 'wrong.db', '--email', self.user.email)
        self.assertEqual(str(context.exception), 'File wrong.db does not exist')

    def test_invalid_email(self):
        _, temp_file = tempfile.mkstemp(suffix=".db")
        with self.assertRaises(CommandError) as context:
            call_command('load_export', temp_file, '--email', 'wrong@email')
        self.assertEqual(str(context.exception), 'User with the email wrong@email does not exist')

    def test_invalid_database(self):
        _, temp_file = tempfile.mkstemp(suffix=".db")
        with self.assertRaises(CommandError) as context:
            call_command('load_export', temp_file, '--email', self.user.email)
        self.assertEqual(str(context.exception), f"The SQLite database {temp_file} is not a correct Arkindex export")

    def test_invalid_version(self):
        _, temp_file = tempfile.mkstemp(suffix=".db")
        db = sqlite3.connect(temp_file)
        cursor = db.cursor()
        cursor.executescript((BASE_DIR / '../../export/structure.sql').read_text())
        cursor.execute("UPDATE export_version SET version = 1")
        db.commit()
        db.close()
        with self.assertRaises(CommandError) as context:
            call_command('load_export', temp_file, '--email', self.user.email, '--corpus-name', 'My corpus')
        self.assertEqual(str(context.exception), f"The SQLite database {temp_file} does not have the correct export version")

    @patch('arkindex.documents.export.os.unlink')
    @patch('arkindex.project.aws.s3.Object')
    def test_run(self, s3_object_mock, unlink_mock):
        Element.objects.filter(type__folder=False, image__isnull=True).update(
            polygon=LinearRing((0, 0), (0, 1000), (1000, 1000), (1000, 0), (0, 0)),
            image=Image.objects.all().first()
        )
        ImageServer.objects.all().update(validated=True)
        element = self.corpus.elements.get(name='Volume 1')
        transcription = Transcription.objects.first()
        version = WorkerVersion.objects.get(worker__slug='reco')
        element.classifications.create(
            ml_class=self.corpus.ml_classes.create(name='Blah'),
            confidence=.55555555,
        )
        entity1 = self.corpus.entities.create(
            name='Arrokuda',
            type=EntityType.Location,
            metas={'subtype': 'pokemon'},
        )
        entity2 = self.corpus.entities.create(
            name='Stonjourner',
            type=EntityType.Person,
            validated=True,
            moderator=self.superuser,
        )
        role = self.corpus.roles.create(
            parent_name='parent',
            child_name='child',
            parent_type=EntityType.Location,
            child_type=EntityType.Person,
        )
        role.links.create(parent=entity1, child=entity2)
        transcription.transcription_entities.create(
            entity=entity1,
            offset=1,
            length=1,
            worker_version=version,
        )

        export = self.corpus.exports.create(user=self.user)
        export_corpus(export)

        # Retrieve the database path from the S3 upload argument
        args, _ = s3_object_mock().upload_file.call_args
        db_path = args[0]

        # Call dumpdata command before the deletion
        _, dump_path_before = tempfile.mkstemp(suffix='.json')
        call_command('dumpdata', output=dump_path_before)

        # Delete the existing corpus
        corpus_delete(self.corpus.id)
        Image.objects.all().delete()
        ImageServer.objects.all().delete()
        WorkerVersion.objects.filter(id=version.id).delete()

        call_command('load_export', db_path, '--email', self.user.email, '--corpus-name', 'My corpus')

        # Call dumpdata command after the import
        _, dump_path_after = tempfile.mkstemp(suffix='.json')
        call_command('dumpdata', output=dump_path_after)

        corpus = Corpus.objects.get(name='My corpus')
        data_before = self.clean_dump_data(dump_path_before, str(corpus.id))
        data_after = self.clean_dump_data(dump_path_after)
        self.assertCountEqual(data_before, data_after)

    def test_run_empty_database(self):
        _, temp_file = tempfile.mkstemp(suffix=".db")
        db = sqlite3.connect(temp_file)
        cursor = db.cursor()
        cursor.executescript((BASE_DIR / '../../export/structure.sql').read_text())
        db.commit()
        db.close()

        call_command('load_export', temp_file, '--email', self.user.email, '--corpus-name', 'My corpus')

        corpus = Corpus.objects.get(name='My corpus')
        self.assertEqual(corpus.types.all().count(), 0)
        self.assertEqual(corpus.ml_classes.all().count(), 0)
        self.assertEqual(corpus.elements.all().count(), 0)
        self.assertEqual(corpus.entities.all().count(), 0)
        self.assertEqual(corpus.roles.all().count(), 0)

    @patch('arkindex.documents.export.os.unlink')
    @patch('arkindex.project.aws.s3.Object')
    def test_run_objects_already_exist(self, s3_object_mock, unlink_mock):
        Element.objects.filter(type__folder=False, image__isnull=True).update(
            polygon=LinearRing((0, 0), (0, 1000), (1000, 1000), (1000, 0), (0, 0)),
            image=Image.objects.all().first()
        )
        ImageServer.objects.all().update(validated=True)
        element = self.corpus.elements.get(name='Volume 1')
        transcription = Transcription.objects.first()
        version = WorkerVersion.objects.get(worker__slug='reco')
        element.classifications.create(
            ml_class=self.corpus.ml_classes.create(name='Blah'),
            confidence=.55555555,
        )
        entity1 = self.corpus.entities.create(
            name='Arrokuda',
            type=EntityType.Location,
            metas={'subtype': 'pokemon'},
        )
        entity2 = self.corpus.entities.create(
            name='Stonjourner',
            type=EntityType.Person,
            validated=True,
            moderator=self.superuser,
        )
        role = self.corpus.roles.create(
            parent_name='parent',
            child_name='child',
            parent_type=EntityType.Location,
            child_type=EntityType.Person,
        )
        role.links.create(parent=entity1, child=entity2)
        transcription.transcription_entities.create(
            entity=entity1,
            offset=1,
            length=1,
            worker_version=version,
        )

        export = self.corpus.exports.create(user=self.user)
        export_corpus(export)

        # Retrieve the database path from the S3 upload argument
        args, _ = s3_object_mock().upload_file.call_args
        db_path = args[0]

        call_command('load_export', db_path, '--email', self.user.email, '--corpus-name', 'My corpus')

        corpus = Corpus.objects.get(name='My corpus')
        self.assertEqual(corpus.types.all().count(), 6)
        self.assertEqual(corpus.ml_classes.all().count(), 1)
        self.assertEqual(corpus.elements.all().count(), 0)
        self.assertEqual(corpus.entities.all().count(), 0)
        self.assertEqual(corpus.roles.all().count(), 0)

Two further hunks (SQL with a {corpus_id} placeholder, likely the expected deletion queries used by the corpus_delete tests):

@@ -170,6 +170,10 @@ FROM "users_right"
 WHERE ("users_right"."content_id" = '{corpus_id}'::uuid
        AND "users_right"."content_type_id" = 1);
+DELETE
+FROM "documents_corpusexport"
+WHERE "documents_corpusexport"."corpus_id" = '{corpus_id}'::uuid;
 DELETE
 FROM "documents_corpus"
 WHERE "documents_corpus"."id" = '{corpus_id}'::uuid

@@ -174,6 +174,10 @@ FROM "users_right"
 WHERE ("users_right"."content_id" = '{corpus_id}'::uuid
        AND "users_right"."content_type_id" = 1);
+DELETE
+FROM "documents_corpusexport"
+WHERE "documents_corpusexport"."corpus_id" = '{corpus_id}'::uuid;
 DELETE
 FROM "documents_corpus"
 WHERE "documents_corpus"."id" = '{corpus_id}'::uuid