Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (13)
0.3.0-rc1
0.3.0-rc5
......@@ -60,7 +60,7 @@ def open_image(path, mode="RGB", rotation_angle=0, mirrored=False):
image = image.convert(mode)
if mirrored:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
image = image.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
if rotation_angle:
image = image.rotate(-rotation_angle, expand=True)
......@@ -102,7 +102,9 @@ def download_image(url):
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
logger.info(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
"Downloaded image {} - size={}x{} in {}".format(
url, image.size[0], image.size[1], resp.elapsed
)
)
return image
......
......@@ -196,7 +196,7 @@ class BaseWorker(object):
)
# Retrieve initial configuration from API
self.config = worker_version["configuration"]["configuration"]
self.config = worker_version["configuration"].get("configuration")
if "user_configuration" in worker_version["configuration"]:
# Add default values (if set) to user_configuration
for key, value in worker_version["configuration"][
......@@ -211,7 +211,9 @@ class BaseWorker(object):
# Load worker run configuration when available
worker_configuration = worker_run.get("configuration")
self.user_configuration = worker_configuration.get("configuration")
self.user_configuration = (
worker_configuration.get("configuration") if worker_configuration else None
)
if self.user_configuration:
logger.info("Loaded user configuration from WorkerRun")
# if debug mode is set to true activate debug mode in logger
......
......@@ -60,6 +60,7 @@ class ElementMixin(object):
name: str,
polygon: List[List[Union[int, float]]],
confidence: Optional[float] = None,
slim_output: bool = True,
) -> str:
"""
Create a child element on the given element through the API.
......@@ -95,6 +96,7 @@ class ElementMixin(object):
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence should be None or a float in [0..1] range"
assert isinstance(slim_output, bool), "slim_output should be of type bool"
if self.is_read_only:
logger.warning("Cannot create element as this worker is in read-only mode")
......@@ -102,6 +104,7 @@ class ElementMixin(object):
sub_element = self.request(
"CreateElement",
slim_output=slim_output,
body={
"type": type,
"name": name,
......@@ -115,7 +118,7 @@ class ElementMixin(object):
)
self.report.add_element(element.id, type)
return sub_element["id"]
return sub_element["id"] if slim_output else sub_element
def create_elements(
self,
......
......@@ -167,8 +167,8 @@ class TrainingMixin(object):
def update_model_version(
self,
model_version_details: dict,
description: str = None,
configuration: dict = None,
description: str = "",
configuration: dict = {},
tag: str = None,
) -> None:
"""
......
......@@ -2,7 +2,7 @@ arkindex-client==1.0.9
peewee==3.14.10
Pillow>=9.0
python-gitlab==2.7.1
python-gnupg==0.4.8
python-gnupg==0.5.0
sh==1.14.2
shapely==1.8.2
tenacity==8.0.1
......
......@@ -130,7 +130,7 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
assert logger.level == logging.NOTSET
mocker.patch.object(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_DEBUG", True)
monkeypatch.setenv("ARKINDEX_DEBUG", "True")
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
......@@ -372,6 +372,53 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
assert worker.user_configuration is None
def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses):
"""
No configuration is provided but should not crash
"""
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": None,
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {},
},
"configuration": None,
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"},
}
responses.add(
responses.GET,
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.args = worker.parser.parse_args()
worker.configure()
assert worker.user_configuration is None
def test_load_missing_secret():
worker = BaseWorker()
worker.api_client = MockApiClient()
......
......@@ -447,7 +447,7 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
)
responses.add(
responses.POST,
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/?slim_output=True",
status=500,
)
......@@ -464,15 +464,16 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
]
def test_create_sub_element(responses, mock_elements_worker):
@pytest.mark.parametrize("slim_output", [True, False])
def test_create_sub_element(responses, mock_elements_worker, slim_output):
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
......@@ -480,25 +481,34 @@ def test_create_sub_element(responses, mock_elements_worker):
"zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}},
}
)
child_elt = {
"id": "12345678-1234-1234-1234-123456789123",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
"zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}},
}
responses.add(
responses.POST,
"http://testserver/api/v1/elements/create/",
f"http://testserver/api/v1/elements/create/?slim_output={slim_output}",
status=200,
json={"id": "12345678-1234-1234-1234-123456789123"},
json=child_elt,
)
sub_element_id = mock_elements_worker.create_sub_element(
element_creation_response = mock_elements_worker.create_sub_element(
element=elt,
type="something",
name="0",
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
slim_output=slim_output,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/elements/create/"),
(
"POST",
f"http://testserver/api/v1/elements/create/?slim_output={slim_output}",
),
]
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
......@@ -510,7 +520,10 @@ def test_create_sub_element(responses, mock_elements_worker):
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": None,
}
assert sub_element_id == "12345678-1234-1234-1234-123456789123"
if slim_output:
assert element_creation_response == "12345678-1234-1234-1234-123456789123"
else:
assert Element(element_creation_response) == Element(child_elt)
def test_create_sub_element_confidence(responses, mock_elements_worker):
......@@ -523,7 +536,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker):
)
responses.add(
responses.POST,
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/?slim_output=True",
status=200,
json={"id": "12345678-1234-1234-1234-123456789123"},
)
......@@ -540,7 +553,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker):
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
]
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
......
repos:
- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
rev: v5.10.1
hooks:
- id: isort
- repo: https://github.com/ambv/black
rev: 22.3.0
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
rev: 3.9.2
hooks:
- id: flake8
additional_dependencies:
......@@ -20,7 +16,7 @@ repos:
- 'flake8-copyright==0.2.2'
- 'flake8-debugger==3.1.0'
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v4.3.0
hooks:
- id: check-ast
- id: check-docstring-first
......@@ -37,7 +33,7 @@ repos:
- id: check-json
- id: requirements-txt-fixer
- repo: https://github.com/codespell-project/codespell
rev: v1.17.1
rev: v2.2.1
hooks:
- id: codespell
args: ['--write-changes']
......
......@@ -16,5 +16,9 @@ def setup_environment(responses):
)
responses.add_passthru(schema_url)
# Set schema url in environment
os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
# Setup a fake worker version ID
os.environ["WORKER_VERSION_ID"] = "1234-{{ cookiecutter.slug }}"
# Setup a fake worker run ID
os.environ["ARKINDEX_WORKER_RUN_ID"] = "1234-{{ cookiecutter.slug }}"