From 462bf30810d7a8bb51c1aedd5ed9cb5ccc0deb77 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 17 Mar 2021 17:43:44 +0100
Subject: [PATCH] Fix some review related code snippets

---
 arkindex_worker/cache.py  | 42 +++++++++++++++++++++------------------
 arkindex_worker/worker.py | 39 +++++++++++++++++++++++-------------
 2 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index d6c18c9d..3ae76148 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -1,31 +1,35 @@
 # -*- coding: utf-8 -*-
-import os
 import sqlite3
 
+from arkindex_worker import logger
+
+SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
+    id VARCHAR(32) PRIMARY KEY,
+    parent_id VARCHAR(32),
+    name TEXT NOT NULL,
+    type TEXT NOT NULL,
+    polygon TEXT,
+    worker_version_id VARCHAR(32)
+)"""
+
 
 class LocalDB(object):
     def __init__(self, path):
-        if not os.path.exists(path):
-            open(path, "x").close()
-
         self.db = sqlite3.connect(path)
+        self.db.row_factory = sqlite3.Row
         self.cursor = self.db.cursor()
+        logger.info(f"Connection to local cache {path} established.")
 
-    def create_elements_table(self):
-        try:
-            self.cursor.execute(
-                """CREATE TABLE elements (
-                id TEXT PRIMARY KEY,
-                parent_id TEXT,
-                name TEXT NOT NULL,
-                type TEXT NOT NULL,
-                polygon TEXT,
-                worker_version_id TEXT
-            )"""
-            )
-        except sqlite3.OperationalError:
-            print("Table 'elements' already exists")
+    def create_tables(self):
+        self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
 
     def insert(self, table, lines):
-        self.cursor.executemany(f"INSERT INTO {table} VALUES (?,?,?,?,?,?)", lines)
+        if not lines:
+            return
+        columns = ", ".join(lines[0].keys())
+        placeholders = ", ".join("?" * len(lines[0]))
+        values = [tuple(line.values()) for line in lines]
+        self.cursor.executemany(
+            f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
+        )
         self.db.commit()
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 830f7b51..b4f78a7b 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -3,6 +3,7 @@ import argparse
 import json
 import logging
 import os
+import sqlite3
 import sys
 import uuid
 import warnings
@@ -21,6 +22,7 @@ from arkindex_worker.models import Element
 from arkindex_worker.reporting import Reporter
 
 MANUAL_SLUG = "manual"
+CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
 
 
 class BaseWorker(object):
@@ -47,6 +49,14 @@ class BaseWorker(object):
 
         logger.info(f"Worker will use {self.work_dir} as working directory")
 
+        if os.path.isdir(CACHE_DIR):
+            cache_path = CACHE_DIR + "/db.sqlite"
+        else:
+            cache_path = os.getcwd() + "/db.sqlite"
+
+        self.cache = LocalDB(cache_path)
+        self.cache.create_tables()
+
     @property
     def is_read_only(self):
         """Worker cannot publish anything without a worker version ID"""
@@ -453,20 +463,21 @@ class ElementsWorker(BaseWorker):
             self.report.add_element(parent.id, element["type"])
 
         # Store elements in local cache
-        cache = LocalDB(f"/data/{os.environ.get('TASK_ID')}/db.sqlite")
-        cache.create_elements_table()
-        to_insert = [
-            (
-                created_ids[idx],
-                parent.id,
-                element["name"],
-                element["type"],
-                json.dumps(element["polygon"]),
-                self.worker_version_id,
-            )
-            for idx, element in enumerate(elements)
-        ]
-        cache.insert("elements", to_insert)
+        try:
+            to_insert = [
+                {
+                    "id": uuid.UUID(created_ids[idx]).hex,
+                    "parent_id": uuid.UUID(parent.id).hex,
+                    "name": element["name"],
+                    "type": element["type"],
+                    "polygon": json.dumps(element["polygon"]),
+                    "worker_version_id": uuid.UUID(self.worker_version_id).hex,
+                }
+                for idx, element in enumerate(elements)
+            ]
+            self.cache.insert("elements", to_insert)
+        except sqlite3.IntegrityError as e:
+            logger.warning(f"Couldn't save created elements in local cache: {e}")
 
         return created_ids
 
-- 
GitLab