diff --git a/atr_data_generator/extract/__init__.py b/atr_data_generator/extract/__init__.py index f45e4484d109c1e676522b9ebafc38d1e16edef9..9caac8416d58138cbc881f5e6717a906059eff80 100644 --- a/atr_data_generator/extract/__init__.py +++ b/atr_data_generator/extract/__init__.py @@ -69,8 +69,8 @@ def get_parser(): filters.add_option("skip_vertical_lines", type=bool, default=False) # Select - select = parser.add_subparser("select", default={}) - select.add_option("dataset", type=str, default=None) + select = parser.add_subparser("select") + select.add_option("dataset", type=str) select.add_option("element_type", type=str, default=None) # Format specific diff --git a/atr_data_generator/extract/arguments.py b/atr_data_generator/extract/arguments.py index 9b91f1eb7b522ecdb9e05c454673c99a6347dc73..c70f0f3d1d596d4d5fc59e899e38864becff9a8b 100644 --- a/atr_data_generator/extract/arguments.py +++ b/atr_data_generator/extract/arguments.py @@ -53,6 +53,9 @@ class SelectArgs(BaseArgs): def __post_init__(self): assert UUID(self.dataset) + # Configuration parser issue: https://gitlab.teklia.com/tools/python-toolbox/-/issues/2 + if self.element_type == "None": + self.element_type = None @dataclass diff --git a/atr_data_generator/extract/base.py b/atr_data_generator/extract/base.py index 4c671056044b5fa7f97440d0c4256facbf3432b0..0e3a402e17f391abfcca5a3a4145f06075c4b7c6 100644 --- a/atr_data_generator/extract/base.py +++ b/atr_data_generator/extract/base.py @@ -181,9 +181,7 @@ class DataGenerator: # Iterate over sets for split in dataset.sets.split(","): # Find the dataset elements - for parent in get_dataset_elements( - dataset, split, self.select.element_type - ): + for parent in get_dataset_elements(dataset, split): self.process_parent(parent.element, split) assert sum( diff --git a/atr_data_generator/extract/db.py b/atr_data_generator/extract/db.py index 20f8d1b6731b7d7728147b7199e44aeeda7bc6fd..5675b7f7def693cec40db2e380c44b1467b09ded 100644 --- a/atr_data_generator/extract/db.py +++ b/atr_data_generator/extract/db.py @@ -8,16 +8,15 @@ from arkindex_export.queries import list_children from atr_data_generator.extract.arguments import MANUAL -def get_dataset_elements(dataset: Dataset, split: str, type: Optional[str]): +def get_dataset_elements(dataset: Dataset, split: str): """ Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus :param dataset: Dataset object from which the elements come. :param split: Set name of the dataset to use. - :param type: Optionally filter by element type. :return: The filtered list of dataset elements. """ - query = ( + return ( DatasetElement.select(DatasetElement.element) .join(Element) .where( @@ -25,10 +24,6 @@ def get_dataset_elements(dataset: Dataset, split: str, type: Optional[str]): DatasetElement.set_name == split, ) ) - if type: - query = query.where(Element.type == type) - - return query def parse_sources(sources: List[str]): diff --git a/docs/extract/configuration.md b/docs/extract/configuration.md index 6358d9f44681ee184ff12b501588514f302830b0..152c1253d201ef116a4265889a86e8ac3f92a4da 100644 --- a/docs/extract/configuration.md +++ b/docs/extract/configuration.md @@ -5,7 +5,7 @@ The YAML configuration for the `extract` subcommand has 5 sections: - `common`, - `image` (optional), - `filter` (optional), -- `select` (optional). +- `select`. An example configuration file, filled with the default values when there is one, is available at `examples/extraction.yml`.