-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvae_basic.py
More file actions
executable file
·164 lines (134 loc) · 5.76 KB
/
vae_basic.py
File metadata and controls
executable file
·164 lines (134 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from matplotlib import pyplot as plt
import numpy as np
from stacked_mnist import DataMode, StackedMNISTData
from variational_autoencoder import VariationalAutoEncoder
from verification_net import VerificationNet
class VAEBasic:
    """
    VAE basic task: reconstructing images.

    Trains a variational autoencoder on (stacked) MNIST, reconstructs the
    test images, scores the reconstructions with a verification net, and
    displays originals next to their reconstructions.
    """

    def __init__(self,
                 latent_dim=4,
                 three_colors=False,
                 save_weigths=False,
                 save_image=False) -> None:
        """
        Args:
            latent_dim: Dimensionality of the VAE latent space.
            three_colors: If True, work on stacked 3-channel MNIST;
                otherwise single-channel MNIST.
            save_weigths: Forwarded to VAE training to persist weights.
                NOTE: name intentionally kept misspelled for backward
                compatibility with existing callers.
            save_image: If True, save the result figure to disk.
        """
        self.var_autoencoder = VariationalAutoEncoder(latent_dim)
        self.three_colors = three_colors
        self.save_weigths = save_weigths
        self.save_image = save_image
        self.gen = self.get_generator(self.three_colors)
        self.ver_net = VerificationNet()

    def get_generator(self, three_colors):
        """
        Return the appropriate data generator.

        Args:
            three_colors: If True, use the stacked (color) MNIST mode.

        Returns:
            A StackedMNISTData generator in the matching binary-complete mode.
        """
        if three_colors:
            # Generator that yields stacked (3-channel) MNIST
            return StackedMNISTData(mode=DataMode.COLOR_BINARY_COMPLETE,
                                    default_batch_size=2048)
        # Generator that yields standard (1-channel) MNIST
        return StackedMNISTData(mode=DataMode.MONO_BINARY_COMPLETE,
                                default_batch_size=2048)

    def get_train_test(self, gen):
        """
        Fetch the full training and test sets from the generator.

        Args:
            gen: A StackedMNISTData generator.

        Returns:
            Tuple of (x_train, y_train, x_test, y_test).
        """
        x_train, y_train = gen.get_full_data_set(training=True)
        x_test, y_test = gen.get_full_data_set(training=False)
        return x_train, y_train, x_test, y_test

    def train_var_autoencoder(self):
        """
        Train the variational autoencoder on single-channel images.

        Only the first channel is used for training even in color mode;
        at reconstruction time each channel is passed through separately.
        """
        x_train, y_train, x_test, y_test = self.get_train_test(self.gen)
        # Keep the channel axis but select only channel 0
        x_train = x_train[:, :, :, [0]]
        x_test = x_test[:, :, :, [0]]
        # Training the VAE (input == target for reconstruction)
        self.var_autoencoder.train(x_train,
                                   x_train,
                                   batch_size=256,
                                   epochs=20,
                                   shuffle=True,
                                   validation_data=(x_test, x_test),
                                   verbose=True,
                                   save_weights=self.save_weigths)

    def run(self):
        """
        Reconstruct images and display the results.

        Trains the VAE, reconstructs the test set (channel by channel in
        color mode), scores predictability/accuracy with the verification
        net, and shows a comparison figure.
        """
        # Training the VAE
        self.train_var_autoencoder()
        x_train, y_train, x_test, y_test = self.get_train_test(self.gen)
        if self.three_colors:
            reconstructed = []
            for i in range(3):
                # Getting the specific color channel
                x_test_channel = x_test[:, :, :, [i]]
                # Reconstruct via the VAE (mode of the decoder distribution)
                encoded_imgs = self.var_autoencoder.encoder(x_test_channel)
                decoded_imgs = self.var_autoencoder.decoder(
                    encoded_imgs).mode().numpy()
                reconstructed.append(np.squeeze(decoded_imgs))
            # Combining the different color channel images to one stacked image
            reconstructed = np.stack(reconstructed, axis=-1)
            # Using VerNet to get predictability and accuracy
            pred, acc = self.ver_net.check_predictability(reconstructed,
                                                          y_test,
                                                          tolerance=0.5)
            print("Predictability: " + str(pred) + ", accuracy:" + str(acc))
            self.show_figure(10, x_test, reconstructed, y_test, pred, acc)
        else:
            # Keep the channel axis but select only channel 0
            x_test = x_test[:, :, :, [0]]
            # Reconstruct via the VAE (mode of the decoder distribution)
            encoded_imgs = self.var_autoencoder.encoder(x_test)
            decoded_imgs = self.var_autoencoder.decoder(
                encoded_imgs).mode().numpy()
            # Using VerNet to get predictability and accuracy
            pred, acc = self.ver_net.check_predictability(decoded_imgs, y_test)
            print("Predictability: " + str(pred) + ", accuracy:" + str(acc))
            self.show_figure(10, x_test, decoded_imgs, y_test, pred, acc)

    def show_figure(self, n, original, reconstructed, y_test, predictability,
                    accuracy):
        """
        Plot original images and their reconstructions side by side.

        Args:
            n: Number of image pairs to display.
            original: Original test images, shape (N, H, W, C).
            reconstructed: Reconstructed images, shape (N, H, W[, C]).
            y_test: Labels for the displayed images.
            predictability: Predictability score from the verification net.
            accuracy: Accuracy score from the verification net.
        """
        plt.figure(figsize=(20, 4))
        for i in range(n):
            # display original
            ax = plt.subplot(2, n, i + 1)
            # BUGFIX: imshow rejects (H, W, 1) arrays, so drop any singleton
            # channel axis (no-op for (H, W, 3) color images).
            plt.imshow(np.squeeze(original[i]).astype(np.float64))
            plt.title("Class " + str(y_test[i]))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            # display reconstruction
            ax = plt.subplot(2, n, i + 1 + n)
            plt.imshow(np.squeeze(reconstructed[i]))
            plt.title("Reconstruct")
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.suptitle("" + str(n) + " images reconstructed" +
                     " (Predictability: " + str(predictability) +
                     ", Accuracy: " + str(accuracy) + ")",
                     fontsize="x-large")
        # Choosing filepath
        if self.three_colors:
            path = "./results/vae-basic-color"
        else:
            path = "./results/vae-basic-mono"
        if self.save_image:
            # Save figure (must happen before show, which clears the canvas)
            plt.savefig(path)
        # Show image
        plt.show()
if __name__ == "__main__":
    # Run the basic VAE reconstruction task on stacked (color) MNIST,
    # without persisting the trained weights or the result figure.
    VAEBasic(three_colors=True,
             save_weigths=False,
             save_image=False).run()