Skip to content
Snippets Groups Projects
Commit 5fe511cf authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie
Browse files

Remove deepcopy when I can

parent 87580f6e
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import copy
import os import os
import pickle import pickle
...@@ -62,7 +61,7 @@ class OCRDataset(GenericDataset): ...@@ -62,7 +61,7 @@ class OCRDataset(GenericDataset):
self.collate_function = OCRCollateFunction self.collate_function = OCRCollateFunction
def __getitem__(self, idx): def __getitem__(self, idx):
sample = copy.deepcopy(self.samples[idx]) sample = dict(**self.samples[idx])
if not self.load_in_memory: if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx) sample["img"] = self.get_sample_img(idx)
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import copy
import json import json
import os import os
import random import random
from copy import deepcopy
from time import time from time import time
import numpy as np import numpy as np
...@@ -481,8 +481,7 @@ class GenericTrainingManager: ...@@ -481,8 +481,7 @@ class GenericTrainingManager:
path = os.path.join(self.paths["results"], "params") path = os.path.join(self.paths["results"], "params")
if os.path.isfile(path): if os.path.isfile(path):
return return
params = copy.deepcopy(self.params) params = class_to_str_dict(my_dict=deepcopy(self.params))
params = class_to_str_dict(params)
total_params = 0 total_params = 0
for model_name in self.models.keys(): for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name]) current_params = compute_nb_params(self.models[model_name])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment