Skip to content
Snippets Groups Projects
Commit 462bf308 authored by Eva Bardou's avatar Eva Bardou
Browse files

Fix some review related code snippets

parent b38754b5
No related branches found
No related tags found
1 merge request!67Store created elements in a local SQLite database
Pipeline #78276 failed
# -*- coding: utf-8 -*-
import os
import sqlite3
from arkindex_worker import logger
SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
id VARCHAR(32) PRIMARY KEY,
parent_id VARCHAR(32),
name TEXT NOT NULL,
type TEXT NOT NULL,
polygon TEXT,
worker_version_id VARCHAR(32)
)"""
class LocalDB(object):
def __init__(self, path):
if not os.path.exists(path):
open(path, "x").close()
self.db = sqlite3.connect(path)
self.db.row_factory = sqlite3.Row
self.cursor = self.db.cursor()
logger.info(f"Connection to local cache {path} established.")
def create_elements_table(self):
try:
self.cursor.execute(
"""CREATE TABLE elements (
id TEXT PRIMARY KEY,
parent_id TEXT,
name TEXT NOT NULL,
type TEXT NOT NULL,
polygon TEXT,
worker_version_id TEXT
)"""
)
except sqlite3.OperationalError:
print("Table 'elements' already exists")
def create_tables(self):
self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
def insert(self, table, lines):
self.cursor.executemany(f"INSERT INTO {table} VALUES (?,?,?,?,?,?)", lines)
if not lines:
return
columns = ", ".join(lines[0].keys())
placeholders = ", ".join("?" * len(lines[0]))
values = [tuple(line.values()) for line in lines]
self.cursor.executemany(
f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
)
self.db.commit()
......@@ -3,6 +3,7 @@ import argparse
import json
import logging
import os
import sqlite3
import sys
import uuid
import warnings
......@@ -21,6 +22,7 @@ from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
MANUAL_SLUG = "manual"
CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
class BaseWorker(object):
......@@ -47,6 +49,14 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory")
if os.path.isdir(CACHE_DIR):
cache_path = CACHE_DIR + "/db.sqlite"
else:
cache_path = os.getcwd() + "/db.sqlite"
self.cache = LocalDB(cache_path)
self.cache.create_tables()
@property
def is_read_only(self):
"""Worker cannot publish anything without a worker version ID"""
......@@ -453,20 +463,21 @@ class ElementsWorker(BaseWorker):
self.report.add_element(parent.id, element["type"])
# Store elements in local cache
cache = LocalDB(f"/data/{os.environ.get('TASK_ID')}/db.sqlite")
cache.create_elements_table()
to_insert = [
(
created_ids[idx],
parent.id,
element["name"],
element["type"],
json.dumps(element["polygon"]),
self.worker_version_id,
)
for idx, element in enumerate(elements)
]
cache.insert("elements", to_insert)
try:
to_insert = [
{
"id": uuid.UUID(created_ids[idx]).hex,
"parent_id": uuid.UUID(parent.id).hex,
"name": element["name"],
"type": element["type"],
"polygon": json.dumps(element["polygon"]),
"worker_version_id": uuid.UUID(self.worker_version_id).hex,
}
for idx, element in enumerate(elements)
]
self.cache.insert("elements", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created elements in local cache: {e}")
return created_ids
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment