Skip to content
Snippets Groups Projects
Commit ded809b1 authored by Manon Blanco's avatar Manon Blanco
Browse files

Rename predict and evaluation functions

parent 8c0dfdeb
No related branches found
No related tags found
1 merge request!253Rename predict and evaluation functions
...@@ -670,7 +670,7 @@ class GenericTrainingManager: ...@@ -670,7 +670,7 @@ class GenericTrainingManager:
): ):
for valid_set_name in self.dataset.valid_loaders: for valid_set_name in self.dataset.valid_loaders:
# evaluate set and compute metrics # evaluate set and compute metrics
eval_values = self.evaluate( eval_values = self.validate(
valid_set_name, mlflow_logging=mlflow_logging valid_set_name, mlflow_logging=mlflow_logging
) )
# log valid metrics in tensorboard file # log valid metrics in tensorboard file
...@@ -702,7 +702,7 @@ class GenericTrainingManager: ...@@ -702,7 +702,7 @@ class GenericTrainingManager:
self.save_model(epoch=num_epoch, name="last") self.save_model(epoch=num_epoch, name="last")
self.writer.flush() self.writer.flush()
def evaluate(self, set_name, mlflow_logging=False, **kwargs): def validate(self, set_name, mlflow_logging=False, **kwargs):
""" """
Main loop for validation Main loop for validation
""" """
...@@ -720,7 +720,7 @@ class GenericTrainingManager: ...@@ -720,7 +720,7 @@ class GenericTrainingManager:
tokens=self.tokens, tokens=self.tokens,
) )
with tqdm(total=len(loader.dataset)) as pbar: with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Evaluation E{}".format(self.latest_epoch)) pbar.set_description("Validation E{}".format(self.latest_epoch))
with torch.no_grad(): with torch.no_grad():
# iterate over batch data # iterate over batch data
for ind_batch, batch_data in enumerate(loader): for ind_batch, batch_data in enumerate(loader):
...@@ -751,7 +751,7 @@ class GenericTrainingManager: ...@@ -751,7 +751,7 @@ class GenericTrainingManager:
) )
return display_values return display_values
def predict( def evaluate(
self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
): ):
""" """
...@@ -772,7 +772,7 @@ class GenericTrainingManager: ...@@ -772,7 +772,7 @@ class GenericTrainingManager:
) )
with tqdm(total=len(loader.dataset)) as pbar: with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Prediction") pbar.set_description("Evaluation")
with torch.no_grad(): with torch.no_grad():
for ind_batch, batch_data in enumerate(loader): for ind_batch, batch_data in enumerate(loader):
# iterates over batch data # iterates over batch data
......
...@@ -50,7 +50,7 @@ def train_and_test(rank, params, mlflow_logging=False): ...@@ -50,7 +50,7 @@ def train_and_test(rank, params, mlflow_logging=False):
metrics = ["cer", "wer", "wer_no_punct", "time"] metrics = ["cer", "wer", "wer_no_punct", "time"]
for dataset_name in params["dataset"]["datasets"]: for dataset_name in params["dataset"]["datasets"]:
for set_name in ["test", "val", "train"]: for set_name in ["test", "val", "train"]:
model.predict( model.evaluate(
"{}-{}".format(dataset_name, set_name), "{}-{}".format(dataset_name, set_name),
[ [
(dataset_name, set_name), (dataset_name, set_name),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment