-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
53 lines (38 loc) · 1.41 KB
/
dataset.py
File metadata and controls
53 lines (38 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import DatasetFolder
def load_img(path):
    """Read the ``.npy`` array stored at *path* and return it as ``float32``.

    Args:
        path: Location of a NumPy ``.npy`` file.

    Returns:
        The loaded array, converted to ``np.float32``.
    """
    return np.load(path).astype(np.float32)
class NormalizeRange(object):
    """Linearly rescale a tensor image into ``[min_val, max_val]``.

    Negative values are first clamped to zero. The tensor is then shifted so
    its minimum is 0, divided by its maximum, and mapped onto the target range.
    The minimum and maximum are taken over the whole tensor (all channels
    together), not per channel.

    Args:
        min_val: Value the smallest element is mapped to.
        max_val: Value the largest element is mapped to.
    """

    def __init__(self, min_val=0.0, max_val=1.0):
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, tensor):
        """Return a ``float32`` version of *tensor* rescaled to the target range.

        A tensor with no dynamic range (every element equal after clamping,
        e.g. a constant or entirely non-positive image) cannot be stretched.
        Instead of dividing by zero and producing NaNs, every element of such
        a tensor is mapped to ``min_val``.
        """
        tensor = tensor.float()
        # Remove negative values before normalizing.
        tensor = torch.clamp(tensor, min=0.0)
        tensor = tensor - torch.min(tensor)
        peak = torch.max(tensor)
        # Guard against 0/0 -> NaN when the tensor is constant.
        if peak > 0:
            tensor = tensor / peak
        return tensor * (self.max_val - self.min_val) + self.min_val

    def __repr__(self):
        return f"{type(self).__name__}(min_val={self.min_val}, max_val={self.max_val})"
def load_dataset(root, max_val=4.0, crop_size=240):
    """Build the train and test ``DatasetFolder`` datasets of ``.npy`` images.

    The training split is read from *root* and the test split from the sibling
    directory ``<root>_test``. Both use ``load_img`` as the loader and only
    pick up files ending in ``.npy``.

    Args:
        root: Path (``str`` or path-like) to the training split directory.
        max_val: Upper bound of the range images are rescaled into; the lower
            bound is 0.
        crop_size: Side length of the random square crop applied to training
            images (default 240). ``RandomCrop`` is used without padding, so
            training images must be at least this large. Test images are not
            cropped.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    # Training: random crop plus horizontal/vertical flips for augmentation.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        NormalizeRange(0.0, max_val),
    ])
    # Test: no augmentation, only tensor conversion and rescaling.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        NormalizeRange(0.0, max_val),
    ])

    train_path = root
    train_dataset = DatasetFolder(train_path, loader=load_img, extensions='.npy',
                                  transform=train_transform)
    # normpath drops a trailing separator so "data/" maps to "data_test"
    # rather than "data/_test"; fspath lets callers pass a pathlib.Path.
    test_path = os.path.normpath(os.fspath(root)) + '_test'
    test_dataset = DatasetFolder(test_path, loader=load_img, extensions='.npy',
                                 transform=test_transform)
    return train_dataset, test_dataset