diff --git a/basicsr/data/ffhq_blind_dataset.py b/basicsr/data/ffhq_blind_dataset.py index 9f900606..f0564375 100755 --- a/basicsr/data/ffhq_blind_dataset.py +++ b/basicsr/data/ffhq_blind_dataset.py @@ -47,10 +47,11 @@ def __init__(self, opt): self.crop_components = False if self.latent_gt_path is not None: - self.load_latent_gt = True - self.latent_gt_dict = torch.load(self.latent_gt_path) + self.load_latent_gt = True + self.latent_gt_dict = None # Lazy load to avoid memory issues with multi-process DataLoader else: - self.load_latent_gt = False + self.load_latent_gt = False + self.latent_gt_dict = None if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = self.gt_folder @@ -190,6 +191,9 @@ def __getitem__(self, index): img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) if self.load_latent_gt: + # Lazy load latent_gt_dict to avoid memory issues with multi-process DataLoader + if self.latent_gt_dict is None: + self.latent_gt_dict = torch.load(self.latent_gt_path) if status[0]: latent_gt = self.latent_gt_dict['hflip'][name] else: