diff --git a/README.md b/README.md index c2abfdf82d10e7cc97527cc35bdc7e989acd2836..310c5bc6408aeff81b610276e305646e1a891457 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ We provide various open-source models, stored on [HuggingFace](https://huggingfa from doc_ufcn import models from doc_ufcn.main import DocUFCN -model_path, parameters = models.download_model('generic-page') +model_path, parameters = models.download_model('Teklia/doc-ufcn-generic-page') model = DocUFCN(len(parameters['classes']), parameters['input_size'], 'cpu') model.load(model_path, parameters['mean'], parameters['std']) @@ -105,7 +105,7 @@ model.load(model_path, parameters['mean'], parameters['std']) By default, the most recent version of the model will be downloaded. One can also use a specific version using the following line: ```python -model_path, parameters = models.download_model('generic-page', version="main") +model_path, parameters = models.download_model('Teklia/doc-ufcn-generic-page', version="main") ``` ## Training diff --git a/doc_ufcn/models.py b/doc_ufcn/models.py index 045960516fd79ab95d948bdefe9ab446bf9f6280..8eaa91fc29b6aa3df0e5f30e1831d2d9f5a73be3 100644 --- a/doc_ufcn/models.py +++ b/doc_ufcn/models.py @@ -6,16 +6,20 @@ import yaml from huggingface_hub import hf_hub_download from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError -HUGGING_FACE_REPO_PREFIX = "Teklia/doc-ufcn-" +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +def download_model(name: str, version: str | None = None) -> tuple[Path, dict]: + """Downloads a model from Huggingface. + Args: + name (_type_): Name of the model to load on HuggingFace. + version (_type_, optional): Revision to load. Defaults to None which will load the default branch. -def download_model(name, version=None): - # Strip the model name prefix if provided - name = name.replace("doc-ufcn-", "") - logger.info(f"Will look for model @ {HUGGING_FACE_REPO_PREFIX + name}") + Returns: + _type_: The path to the model and its parsed configuration. + """ + logger.info(f"Will look for model @ {name}") cache_dir = Path(os.environ.get("XDG_CACHE_HOME", Path("~/.cache").expanduser())) dir_path = cache_dir / "doc-ufcn" / "models" / name @@ -23,14 +27,14 @@ def download_model(name, version=None): try: # Retrieve parameters.yml parameters_path = hf_hub_download( - repo_id=HUGGING_FACE_REPO_PREFIX + name, + repo_id=name, filename="parameters.yml", cache_dir=dir_path, revision=version, ) # Retrieve model.pth model_path = hf_hub_download( - repo_id=HUGGING_FACE_REPO_PREFIX + name, + repo_id=name, filename="model.pth", cache_dir=dir_path, revision=version, diff --git a/hugging_face/config.yaml b/hugging_face/config.yaml index a9494997028f81b4a5f7a3280b44b9bc69b45f7c..2d4bf7961504e925db0f55b7d0705f3dd2efd0ba 100644 --- a/hugging_face/config.yaml +++ b/hugging_face/config.yaml @@ -9,20 +9,20 @@ examples: - hugging_face/resource/hugging_face_4.jpg models: - - model_name: doc-ufcn-generic-historical-line + - model_name: Teklia/doc-ufcn-generic-historical-line title: Doc-UFCN Generic historical line detection description: >- The [generic historical line detection model](https://huggingface.co/Teklia/doc-ufcn-generic-historical-line) predicts text lines from document images. Please select an image from the examples below or upload your own image! classes_colors: - green - - model_name: doc-ufcn-huginmunin-line + - model_name: Teklia/doc-ufcn-huginmunin-line title: Doc-UFCN Hugin-Munin line detection description: >- The [Hugin-Munin line detection model](https://huggingface.co/Teklia/doc-ufcn-huginmunin-line) predicts horizontal and vertical text lines from Hugin-Munin document images. Please select an image from the examples below or upload your own image! classes_colors: - green - blue - - model_name: doc-ufcn-generic-page + - model_name: Teklia/doc-ufcn-generic-page title: Doc-UFCN Generic page detection description: >- The [generic page detection model](https://huggingface.co/Teklia/doc-ufcn-generic-page) predicts single pages from document images. Please select an image from the examples below or upload your own image! diff --git a/tests/test_models.py b/tests/test_models.py index 886e006637a602084cda812cc5b7898a7223478a..7bfb92ca687b49f1abb967e95166df11d68b4359 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,29 +12,14 @@ from doc_ufcn import models [ # Correct name and version ( - "generic-page", + "Teklia/doc-ufcn-generic-page", "main", ( Path("~/.cache") / "doc-ufcn" / "models" - / "generic-page" - / "models--Teklia--doc-ufcn-generic-page" - / "snapshots" - / "21f246af42990ca668e3c2bad69c6e4ae727a0cd" - ), - lazy_fixture("test_parameters"), - True, - ), - # Correct fullname and version - ( - "doc-ufcn-generic-page", - "main", - ( - Path("~/.cache") - / "doc-ufcn" - / "models" - / "generic-page" + / "Teklia" + / "doc-ufcn-generic-page" / "models--Teklia--doc-ufcn-generic-page" / "snapshots" / "21f246af42990ca668e3c2bad69c6e4ae727a0cd" @@ -43,16 +28,17 @@ from doc_ufcn import models True, ), # Correct name and incorrect version - ("generic-page", "version", None, None, True), + ("Teklia/doc-ufcn-generic-page", "version", None, None, True), # Correct name and no version ( - "generic-page", + "Teklia/doc-ufcn-generic-page", None, ( Path("~/.cache") / "doc-ufcn" / "models" - / "generic-page" + / "Teklia" + / "doc-ufcn-generic-page" / "models--Teklia--doc-ufcn-generic-page" / "snapshots" / "21f246af42990ca668e3c2bad69c6e4ae727a0cd" @@ -67,7 +53,11 @@ from doc_ufcn import models ], ) def test_download_model( - name, version, expected_model_path: Path, expected_parameters, correct_name + name: str, + version: str | None, + expected_model_path: Path | None, + expected_parameters: dict | None, + correct_name: bool, ): """ Test of the download_model function. @@ -83,6 +73,7 @@ def test_download_model( with pytest.raises(RepositoryNotFoundError): model_path, parameters = models.download_model(name, version) else: + assert expected_model_path model_path, parameters = models.download_model(name, version) assert model_path == expected_model_path.expanduser() / "model.pth" assert parameters == expected_parameters