Commit fbb412bd authored by Bastien Abadie

Merge branch 'v0.1' into 'master'

Initial version of the base worker, with cookiecutter support

Closes #1

See merge request arkindex/base-worker!1
parents fde4b49f 2b99c7e5
Pipeline #77838 failed
Showing 842 additions and 1 deletion
.gitignore
*.egg-info
*.pyc
.gitlab-ci.yml
@@ -20,3 +20,88 @@ lint:
script:
- pre-commit run -a
test-cookiecutter:
image: python:3
stage: test
cache:
paths:
- .cache/pip
- .cache/pre-commit
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
PRE_COMMIT_HOME: "$CI_PROJECT_DIR/.cache/pre-commit"
ARKINDEX_API_SCHEMA_URL: schema.yml
before_script:
- pip install cookiecutter tox pre-commit
# Configure git to be able to commit in the hook
- git config --global user.email "crasher@teklia.com"
- git config --global user.name "Crash Test"
script:
- cookiecutter --no-input .
- cd worker-demo
- find
- tox
- pre-commit run -a
# Store demo build for later docker build
artifacts:
paths:
- worker-demo/
build-cookiecutter:
image: docker:19.03.1
stage: build
services:
- docker:dind
variables:
DOCKER_DRIVER: overlay2
DOCKER_HOST: tcp://docker:2375/
# Ensure artifacts are available
dependencies:
- test-cookiecutter
script:
- cd worker-demo
- docker build .
pypi-publication:
image: python:3
stage: release
only:
- tags
environment:
name: pypi
url: https://pypi.org/project/arkindex-base-worker
before_script:
- pip install twine setuptools wheel
- echo "[distutils]" > ~/.pypirc
- echo "index-servers =" >> ~/.pypirc
- echo " pypi" >> ~/.pypirc
- echo "[pypi]" >> ~/.pypirc
- echo "repository=https://upload.pypi.org/legacy/" >> ~/.pypirc
- echo "username=$PYPI_DEPLOY_USERNAME" >> ~/.pypirc
- echo "password=$PYPI_DEPLOY_PASSWORD" >> ~/.pypirc
script:
- python setup.py sdist bdist_wheel
- twine upload dist/* -r pypi
release-notes:
stage: release
image: registry.gitlab.com/teklia/devops:latest
only:
- tags
script:
- devops release-notes
.isort.cfg
@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =
known_third_party =PIL,apistar,requests,setuptools,tenacity
.pre-commit-config.yaml
@@ -31,6 +31,7 @@ repos:
- id: trailing-whitespace
- id: check-yaml
args: [--allow-multiple-documents]
exclude: "^worker-{{cookiecutter.slug}}/.arkindex.yml$"
- id: mixed-line-ending
- id: name-tests-test
args: ['--django']
README.md
@@ -2,3 +2,11 @@ Arkindex base Worker
====================
An easy-to-use, high-level Python 3 API client for building ML tasks.
## Create a new worker using our template
```
pip install --user cookiecutter
cookiecutter git@gitlab.com:arkindex/base-worker.git
```
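A generated worker boils down to a Python class overriding `process_element`. A minimal sketch, assuming the `ElementsWorker` API introduced by this merge request (the class name and messages are illustrative only):

```
from arkindex_worker.worker import ElementsWorker

class MyWorker(ElementsWorker):
    def process_element(self, element):
        # Called once per element given via --elements-list or --element
        print("Processing", element)

if __name__ == "__main__":
    MyWorker(description="My worker").run()
```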
VERSION 0 → 100644
0.1.0
arkindex_worker/__init__.py 0 → 100644
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
arkindex_worker/image.py 0 → 100644
# -*- coding: utf-8 -*-
from collections import namedtuple
from io import BytesIO
from math import ceil
import requests
from PIL import Image
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from arkindex_worker import logger
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])
def open_image(path, mode="RGB"):
"""
Open an image from a path or a URL
"""
try:
image = Image.open(path)
except IOError:
image = download_image(path)
if image.mode != mode:
image = image.convert(mode)
return image
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
# Open the downloaded image with Pillow
image = Image.open(BytesIO(resp.content))
logger.info(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def polygon_bounding_box(polygon):
x_coords, y_coords = zip(*polygon)
x, y = min(x_coords), min(y_coords)
width, height = max(x_coords) - x, max(y_coords) - y
return BoundingBox(x, y, width, height)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_tiles(url):
"""
Reconstruct a full IIIF image from tiles, for servers that cannot serve the image at full size.
"""
if not url.endswith("/"):
url += "/"
logger.debug("Downloading image information")
info = _retried_request(url + "info.json").json()
image_width, image_height = info.get("width"), info.get("height")
assert image_width and image_height, "Missing image dimensions in info.json"
assert info.get(
"tiles"
), "Image cannot be retrieved at full size and tiles are not supported"
# Take the biggest available tile size
tile = sorted(info["tiles"], key=lambda tile: tile.get("width", 0), reverse=True)[0]
tile_width = tile["width"]
# Tile height is optional and defaults to the width
tile_height = tile.get("height", tile_width)
full_image = Image.new("RGB", (image_width, image_height))
for tile_x in range(ceil(image_width / tile_width)):
for tile_y in range(ceil(image_height / tile_height)):
region_x = tile_x * tile_width
region_y = tile_y * tile_height
# Prevent trying to crop outside the bounds of an image
region_width = min(tile_width, image_width - region_x)
region_height = min(tile_height, image_height - region_y)
logger.debug(f"Downloading tile {tile_x},{tile_y}")
resp = _retried_request(
f"{url}{region_x},{region_y},{region_width},{region_height}/full/0/default.jpg"
)
tile_img = Image.open(BytesIO(resp.content))
# Some bad IIIF image server implementations may return tiles whose dimensions differ
# from the expected size by a few pixels, causing Pillow to raise ValueError('images do not match').
actual_width, actual_height = tile_img.size
if actual_width < region_width or actual_height < region_height:
# Fail when tiles are too small
raise ValueError(
f"Expected size {region_width}×{region_height} for tile {tile_x},{tile_y}, "
f"but got {actual_width}×{actual_height}"
)
if actual_width > region_width or actual_height > region_height:
# Warn and crop when tiles are too large
logger.warning(
f"Cropping tile {tile_x},{tile_y} from {actual_width}×{actual_height} "
f"to {region_width}×{region_height}"
)
tile_img = tile_img.crop((0, 0, region_width, region_height))
full_image.paste(
tile_img,
box=(
region_x,
region_y,
region_x + region_width,
region_y + region_height,
),
)
return full_image
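As a quick sanity check of the geometry helper above, a minimal example of `polygon_bounding_box` (the coordinates are made up):

```
from arkindex_worker.image import polygon_bounding_box

# A rectangle from (10, 20) to (110, 220)
box = polygon_bounding_box([[10, 20], [110, 20], [110, 220], [10, 220]])
assert box == (10, 20, 100, 200)  # BoundingBox(x, y, width, height)
```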
arkindex_worker/models.py 0 → 100644
# -*- coding: utf-8 -*-
import tempfile
from contextlib import contextmanager
from requests import HTTPError
from arkindex_worker import logger
from arkindex_worker.image import download_tiles, open_image, polygon_bounding_box
class MagicDict(dict):
"""
A dict whose items can be accessed like attributes.
"""
def _magify(self, item):
"""
Automagically convert lists and dicts to MagicDicts and lists of MagicDicts
Allows for nested access: foo.bar.baz
"""
if isinstance(item, list):
return list(map(self._magify, item))
if isinstance(item, dict):
return MagicDict(item)
return item
def __getitem__(self, item):
item = super().__getitem__(item)
return self._magify(item)
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(
"{} object has no attribute '{}'".format(self.__class__.__name__, name)
)
def __delattr__(self, name):
try:
return super().__delattr__(name)
except AttributeError:
try:
return super().__delitem__(name)
except KeyError:
pass
raise
def __dir__(self):
return super().__dir__() + list(self.keys())
class Element(MagicDict):
"""
Describes any kind of element.
"""
def image_url(self, size="full"):
"""
Build the image URL for this element. When possible, returns the S3 URL so that ML workers can bypass IIIF servers.
:param size: IIIF size parameter, e.g. "full" or a "w,h" value.
"""
if not self.get("zone"):
return
url = self.zone.image.get("s3_url")
if url:
return url
url = self.zone.image.url
if not url.endswith("/"):
url += "/"
return "{}full/{}/0/default.jpg".format(url, size)
@property
def requires_tiles(self):
if not self.get("zone") or self.zone.image.get("s3_url"):
return False
max_width = self.zone.image.server.max_width or float("inf")
max_height = self.zone.image.server.max_height or float("inf")
bounding_box = polygon_bounding_box(self.zone.polygon)
return bounding_box.width > max_width or bounding_box.height > max_height
def open_image(self, *args, max_size=None, **kwargs):
"""
Open this element's image as a Pillow image.
:param max_size: Maximum size, in pixels, of the largest dimension of the requested image.
"""
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
if self.requires_tiles:
if max_size is None:
return download_tiles(self.zone.image.url)
else:
raise NotImplementedError
if max_size is not None:
bounding_box = polygon_bounding_box(self.zone.polygon)
original_size = {"w": self.zone.image.width, "h": self.zone.image.height}
# Resizing is only supported on elements that cover the full image.
if (
bounding_box.width != original_size["w"]
or bounding_box.height != original_size["h"]
):
resize = "full"
logger.warning(
"Only full image size elements covered, "
+ "downloading full size image."
)
# No resizing if the image is smaller than the wanted size.
elif original_size["w"] <= max_size and original_size["h"] <= max_size:
resize = "full"
# Resizing if the image is bigger than the wanted size.
else:
ratio = max_size / max(original_size.values())
new_width, new_height = [int(x * ratio) for x in original_size.values()]
resize = "{},{}".format(new_width, new_height)
else:
resize = "full"
try:
return open_image(self.image_url(resize), *args, **kwargs)
except HTTPError as e:
if (
self.zone.image.get("s3_url") is not None
and e.response.status_code == 403
):
# This element uses an S3 URL: the URL may have expired.
# Fetching a fresh URL from the API and retrying is not implemented yet.
# TODO: this should be done by the worker
raise NotImplementedError
raise
@contextmanager
def open_image_tempfile(self, format="jpeg", *args, **kwargs):
"""
Get the element's image as a temporary file stored on the disk.
To be used as a context manager: with element.open_image_tempfile() as f: ...
"""
with tempfile.NamedTemporaryFile() as f:
self.open_image(*args, **kwargs).save(f, format=format)
yield f
def __str__(self):
if isinstance(self.type, dict):
type_name = self.type["display_name"]
else:
type_name = str(self.type)
return "{} {} ({})".format(type_name, self.name, self.id)
arkindex_worker/reporting.py 0 → 100644
# -*- coding: utf-8 -*-
import json
import traceback
from collections import Counter
from datetime import datetime
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
class Reporter(object):
def __init__(self, name):
# TODO: use real data from workers
self.report_data = {
"slug": name,
"version": "0.0",
"started": datetime.utcnow().isoformat(),
"elements": {},
}
logger.info(f"Starting ML report for {name}")
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.report_data["slug"])
def _get_element(self, element_id):
return self.report_data["elements"].setdefault(
str(element_id),
{
"started": datetime.utcnow().isoformat(),
# Created element counts, by type slug
"elements": {},
# Created transcription counts, by type
"transcriptions": {},
# Created classification counts, by class
"classifications": {},
"errors": [],
},
)
def process(self, element_id):
"""
Report that a specific element ID is being processed.
"""
# Just call the element initializer
self._get_element(element_id)
def add_element(self, parent_id, type):
"""
Report creating a single element with a parent.
"""
elements = self._get_element(parent_id)["elements"]
elements.setdefault(type, 0)
elements[type] += 1
def add_classification(self, element_id, class_name):
"""
Report creating a classification on an element.
"""
classifications = self._get_element(element_id)["classifications"]
classifications.setdefault(class_name, 0)
classifications[class_name] += 1
def add_classifications(self, element_id, classifications):
"""
Report one or more classifications at once.
"""
assert isinstance(
classifications, list
), "A list is required for classifications"
element = self._get_element(element_id)
# Retrieve the previous existing classification counts, if any
counter = Counter(**element["classifications"])
# Add the new ones
counter.update(
[classification["class_name"] for classification in classifications]
)
element["classifications"] = dict(counter)
def add_transcription(self, element_id, type):
"""
Report creating a transcription on an element.
"""
transcriptions = self._get_element(element_id)["transcriptions"]
transcriptions.setdefault(type, 0)
transcriptions[type] += 1
def add_transcriptions(self, element_id, transcriptions):
"""
Report one or more transcriptions at once.
"""
assert isinstance(transcriptions, list), "A list is required for transcriptions"
element = self._get_element(element_id)
# Retrieve the previous existing transcription counts, if any
counter = Counter(**element["transcriptions"])
# Add the new ones
counter.update([transcription["type"] for transcription in transcriptions])
element["transcriptions"] = dict(counter)
def add_entity(self, *args, **kwargs):
raise NotImplementedError
def add_entity_link(self, *args, **kwargs):
raise NotImplementedError
def add_entity_role(self, *args, **kwargs):
raise NotImplementedError
def error(self, element_id, exception):
error_data = {
"class": exception.__class__.__name__,
"message": str(exception),
}
if exception.__traceback__ is not None:
error_data["traceback"] = "\n".join(
traceback.format_tb(exception.__traceback__)
)
if isinstance(exception, ErrorResponse):
error_data["message"] = exception.title
error_data["status_code"] = exception.status_code
error_data["content"] = exception.content
self._get_element(element_id)["errors"].append(error_data)
def save(self, path):
logger.info(f"Saving ML report to {path}")
with open(path, "w") as f:
json.dump(self.report_data, f)
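A sketch of how a worker might drive the Reporter above (the slug, element ID and output path are made up):

```
from arkindex_worker.reporting import Reporter

report = Reporter("demo")
element_id = "11111111-1111-1111-1111-111111111111"
report.process(element_id)
report.add_element(element_id, "page")
report.add_classification(element_id, "handwritten")
report.save("ml_report.json")
```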
arkindex_worker/worker.py 0 → 100644
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import os
import sys
import uuid
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker"):
self.parser = argparse.ArgumentParser(description=description)
# Set up the working directory, either from the Ponos environment or under the host user's home
if os.environ.get("PONOS_DATA"):
self.work_dir = os.path.join(os.environ["PONOS_DATA"], "current")
else:
# We use the official XDG Base Directory convention to store files for developers
# https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html
xdg_data_home = os.environ.get(
"XDG_DATA_HOME", os.path.expanduser("~/.local/share")
)
self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True)
logger.info(f"Worker will use {self.work_dir} as working directory")
def configure(self):
"""
Configure the worker using CLI arguments and environment variables
"""
self.parser.add_argument(
"-v",
"--verbose",
help="Display more information on events and errors",
action="store_true",
default=False,
)
# Register any extra arguments defined by subclasses
self.add_arguments()
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
# Setup logging level
if self.args.verbose:
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
def add_arguments(self):
"""Override this method to add argparse argument to this worker"""
def run(self):
"""Override this method to implement your own process"""
class ElementsWorker(BaseWorker):
def __init__(self, description="Arkindex Elements Worker"):
super().__init__(description)
# Add report concerning elements
self.report = Reporter("unknown worker")
# Add mandatory argument to process elements
self.parser.add_argument(
"--elements-list",
help="JSON elements list to use",
type=open,
default=os.environ.get("TASK_ELEMENTS"),
)
self.parser.add_argument(
"--element",
type=uuid.UUID,
nargs="+",
help="One or more Arkindex element ID",
)
def list_elements(self):
out = []
# Process elements from JSON file
if self.args.elements_list:
data = json.load(self.args.elements_list)
assert isinstance(data, list), "Elements list must be a list"
assert len(data), "No elements in elements list"
out += list(filter(None, [element.get("id") for element in data]))
# Add any extra element from CLI
if self.args.element:
out += self.args.element
return out
def run(self):
"""
Process every element from the provided list
"""
self.configure()
# List all elements, either from the JSON file
# or from element IDs given directly on the CLI
elements = self.list_elements()
if not elements:
logger.warning("No elements to process, stopping.")
sys.exit(1)
# Process every element
count = len(elements)
failed = 0
for i, element_id in enumerate(elements, start=1):
try:
# Load element using Arkindex API
element = Element(
**self.api_client.request("RetrieveElement", id=element_id)
)
logger.info(f"Processing {element} ({i}/{count})")
self.process_element(element)
except Exception as e:
failed += 1
logger.warning(
"Failed running worker on {}: {!r}".format(element_id, e),
exc_info=e if self.args.verbose else None,
)
self.report.error(element_id, e)
# Save report as local artifact
self.report.save(os.path.join(self.work_dir, "ml_report.json"))
if failed:
logger.error(
"Ran on {} elements: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
def process_element(self, element):
"""Override this method to analyze an Arkindex element from the provided list"""
cookiecutter.json 0 → 100644
{
"slug": "demo",
"name": "My demo ML Worker",
"version": "0.1",
"worker_type": "classifier"
}
demo.py 0 → 100644
# -*- coding: utf-8 -*-
from arkindex_worker.worker import ElementsWorker
class Demo(ElementsWorker):
def process_element(self, element):
print("Demo processing element", element)
size = 200
w = element.zone.image.width / 2
h = element.zone.image.height / 2
self.api_client.request(
"CreateElement",
body={
"corpus": element.corpus.id,
"type": "crash",
"parent": element.id,
"name": "test bastien",
"image": element.zone.image.id,
"polygon": [
[w - size, h - size],
[w + size, h - size],
[w + size, h + size],
[w - size, h + size],
],
},
)
if __name__ == "__main__":
Demo(description="My demo worker!").run()
hooks/post_gen_project.sh 0 → 100644
#!/bin/bash -e
# Initialize a new git repository for this project
if [[ ! -d .git ]] ; then
git init .
git add .
git commit -m 'Generated files from arkindex/base-worker for {{ cookiecutter.name }}'
fi
requirements.txt 0 → 100644
arkindex-client==1.0.0
Pillow==7.2.0
tenacity==6.2.0
setup.py 0 → 100644
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os.path
from setuptools import find_packages, setup
def requirements(path):
assert os.path.exists(path), "Missing requirements {}".format(path)
with open(path) as f:
return list(map(str.strip, f.read().splitlines()))
with open("VERSION") as f:
VERSION = f.read()
install_requires = requirements("requirements.txt")
setup(
name="arkindex-base-worker",
version=VERSION,
description="Base Worker to easily build Arkindex ML workflows",
author="Teklia",
author_email="contact@teklia.com",
url="https://teklia.com",
python_requires=">=3.6",
install_requires=install_requires,
packages=find_packages(),
)
worker-{{cookiecutter.slug}}/.arkindex.yml 0 → 100644
---
version: 2
type: worker
workers:
- slug: {{ cookiecutter.slug }}
name: {{ cookiecutter.name }}
type: {{ cookiecutter.worker_type }}
docker:
build: Dockerfile
worker-{{cookiecutter.slug}}/.flake8 0 → 100644
[flake8]
max-line-length = 150
exclude = .git,__pycache__
ignore = E203,E501,W503
worker-{{cookiecutter.slug}}/.gitignore 0 → 100644
*.pyc
*.egg-info/
.tox/
worker-{{cookiecutter.slug}}/.gitlab-ci.yml 0 → 100644
stages:
- test
- build
- release
test:
image: python:3
stage: test
cache:
paths:
- .cache/pip
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
ARKINDEX_API_SCHEMA_URL: schema.yml
before_script:
- pip install tox
script:
- tox
lint:
image: python:3
cache:
paths:
- .cache/pip
- .cache/pre-commit
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
PRE_COMMIT_HOME: "$CI_PROJECT_DIR/.cache/pre-commit"
before_script:
- pip install pre-commit
script:
- pre-commit run -a
docker-build:
stage: build
image: docker:19.03.1
services:
- docker:dind
variables:
DOCKER_DRIVER: overlay2
DOCKER_HOST: tcp://docker:2375/
script:
- ci/build.sh
release-notes:
stage: release
image: registry.gitlab.com/teklia/devops:latest
only:
- tags
script:
- devops release-notes