Skip to content
Snippets Groups Projects
Commit e276294e authored by Solene Tarride's avatar Solene Tarride
Browse files

Simplify and document data extraction

parent da4652c8
No related branches found
No related tags found
1 merge request!287Support subword and word language models
......@@ -393,13 +393,13 @@ class ArkindexExtractor:
indent=4,
)
)
(self.output / "language_corpus.txt").write_text(
(self.output / "language_model" / "corpus.txt").write_text(
"\n".join(self.language_corpus)
)
(self.output / "language_tokens.txt").write_text(
(self.output / "language_model" / "tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_lexicon.txt").write_text(
(self.output / "language_model" / "lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes(
......
......@@ -92,12 +92,12 @@ class DAN:
self.decoder = decoder
self.lm_decoder = None
if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
if use_language_model and parameters["language_model"]["weight"] > 0:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
tokens_path=parameters["lm_decoder"]["tokens_path"],
language_model_weight=parameters["lm_decoder"]["language_model_weight"],
language_model_path=parameters["language_model"]["model"],
lexicon_path=parameters["language_model"]["lexicon"],
tokens_path=parameters["language_model"]["tokens"],
language_model_weight=parameters["language_model"]["weight"],
)
self.mean, self.std = (
......
......@@ -20,10 +20,10 @@ output/
│ ├── train
│ ├── val
│ └── test
├── language_model
├── corpus.txt
├── lexicon.txt
└── tokens.txt
├── language_model
├── language_corpus.txt
├── language_lexicon.txt
└── language_tokens.txt
```
## 2. Train
......
......@@ -22,8 +22,8 @@ parameters:
- type: "max_resize"
max_height: 1500
max_width: 1500
lm_decoder:
language_model_path: tests/data/prediction/language_model.arpa
lexicon_path: tests/data/prediction/language_lexicon.txt
tokens_path: tests/data/prediction/language_tokens.txt
language_model_weight: 1.0
language_model:
model: tests/data/prediction/language_model.arpa
lexicon: tests/data/prediction/language_lexicon.txt
tokens: tests/data/prediction/language_tokens.txt
weight: 1.0
......@@ -396,9 +396,10 @@ def test_extract(
VAL_DIR / "val-page_1-line_2.jpg",
VAL_DIR / "val-page_1-line_3.jpg",
output / "labels.json",
output / "language_corpus.txt",
output / "language_lexicon.txt",
output / "language_tokens.txt",
# Language resources
output / "language_model" / "corpus.txt",
output / "language_model" / "lexicon.txt",
output / "language_model" / "tokens.txt",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
......@@ -487,20 +488,22 @@ def test_extract(
"", expected_language_corpus
)
assert (output / "language_corpus.txt").read_text() == expected_language_corpus
assert (
output / "language_model" / "corpus.txt"
).read_text() == expected_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
t if t != " " else "" for t in sorted(list(expected_charset))
]
expected_language_tokens.append("")
assert (output / "language_tokens.txt").read_text() == "\n".join(
assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
expected_language_tokens
)
# Check "language_lexicon.txt"
expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (output / "language_lexicon.txt").read_text() == "\n".join(
assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
expected_language_lexicon
)
......
......@@ -739,11 +739,15 @@ def test_run_prediction_language_model(
# Update language_model_weight in parameters.yml
params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
    params["parameters"]["language_model"]["weight"] = language_model_weight
yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
run_prediction(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment