Skip to content

Commit 7df0c63

Browse files
committed
refactor
1 parent 18bb37a commit 7df0c63

1 file changed

Lines changed: 52 additions & 54 deletions

File tree

hapiplot/plot/heatmap.py

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from hapiplot.plot.datetick import datetick
1212

1313

14-
1514
def heatmap(x, y, z, **kwargs):
1615
"""Plot a heatmap using pcolormesh and do typical configuration.
1716
@@ -87,9 +86,6 @@ def heatmap(x, y, z, **kwargs):
8786
* nan.hatch
8887
* nan.hatch.color
8988
* nan.legend - Show legend entry for nans (True by default and if NaNs)
90-
91-
92-
9389
"""
9490

9591
###########################################################################
@@ -198,6 +194,27 @@ def adjustCenters(y, yc):
198194

199195
return yc, ycl
200196

197+
def boundaryInfo(x, coord):
198+
xlabels = None
199+
xcl = None
200+
N = Nx
201+
if coord == "y":
202+
N = Ny
203+
if len(x.shape) == 1 and len(x) == N:
204+
# Centers given. Calculate edges.
205+
if iscategorical(x):
206+
xlabels = x
207+
x = np.linspace(0, x.shape[0]-1, x.shape[0], dtype='int32')
208+
xedges = False
209+
xc = x
210+
x = calcEdges(x, coord)
211+
xc, xcl = adjustCenters(x, xc)
212+
else:
213+
xc = np.array([])
214+
xedges = True
215+
216+
return x, xc, xedges, xcl, xlabels
217+
201218
def iscategorical(x):
202219
return isinstance(x[0], np.character)
203220

@@ -218,8 +235,6 @@ def allint(x):
218235
else:
219236
return False
220237

221-
###########################################################################
222-
223238
opts = {
224239
'logging': False,
225240
'title': '',
@@ -289,11 +304,13 @@ def allint(x):
289304
opts['cmap.name'] = 'viridis'
290305

291306
if not opts['cmap']:
292-
opts['cmap'] = matplotlib.pyplot.get_cmap(opts['cmap.name'], opts['cmap.numcolors'])
307+
opts['cmap'] = matplotlib.pyplot.get_cmap(\
308+
opts['cmap.name'], opts['cmap.numcolors'])
293309

294310
if opts['returnimage']:
295311
fig = Figure()
296-
FigureCanvas(fig) # Not used directly, but calling attaches canvas to fig which is used later.
312+
# Calling FigureCanvas() attaches canvas to fig which is used later.
313+
FigureCanvas(fig)
297314
ax = fig.add_subplot(111)
298315
else:
299316
fig, ax = plt.subplots()
@@ -334,34 +351,6 @@ def allint(x):
334351
if y.ndim == 1 and not (len(y) == Ny or len(y) == Ny+1):
335352
raise ValueError('Required: len(y) == z.shape[0] or len(y) == z.shape[0] + 1.')
336353

337-
categoricalx = iscategorical(x)
338-
if len(x.shape) == 1 and len(x) == Nx:
339-
# Centers given. Calculate edges.
340-
if categoricalx:
341-
xlabels = x
342-
x = np.linspace(0, y.shape[0]-1, y.shape[0], dtype='int32')
343-
xedges = False
344-
xc = x
345-
x = calcEdges(x, 'x')
346-
xc, xcl = adjustCenters(x, xc)
347-
else:
348-
xc = np.array([])
349-
xedges = True
350-
351-
categoricaly = iscategorical(y)
352-
if len(y.shape) == 1 and len(y) == Ny:
353-
# Centers given. Calculate edges.
354-
if categoricaly:
355-
ylabels = y
356-
y = np.linspace(0, y.shape[0]-1, y.shape[0], dtype='int32')
357-
yedges = False
358-
yc = y
359-
y = calcEdges(y, 'y')
360-
yc, ycl = adjustCenters(y, yc)
361-
else:
362-
yc = np.array([])
363-
yedges = True
364-
365354
inan = np.where(np.isnan(z))
366355
havenans = False
367356
allnan = False
@@ -370,10 +359,17 @@ def allint(x):
370359
if np.all(np.isnan(z)):
371360
allnan = True
372361

362+
categoricalx = iscategorical(x)
363+
x, xc, xedges, xcl, xlabels = boundaryInfo(x,'x')
364+
365+
categoricaly = iscategorical(y)
366+
y, yc, yedges, ycl, ylabels = boundaryInfo(y,'y')
367+
373368
xgaps = np.array([], dtype=np.int32)
374-
ygaps = np.array([], dtype=np.int32)
375369
if len(x.shape) == 2: # x is an matrix
376370
x, z, xgaps = calcGaps(x, z, 'x')
371+
372+
ygaps = np.array([], dtype=np.int32)
377373
if len(y.shape) == 2: # y is an matrix
378374
y, z, ygaps = calcGaps(y, z, 'y')
379375

@@ -444,29 +440,31 @@ def allint(x):
444440
# Relabel y-ticks b/c nonuniform center spacing.
445441
ax.set_yticklabels(ycl[0:-1])
446442

447-
# Note: categoricalx and categoricaly are very similar.
448-
if categoricalx:
449-
xcategories, _ = categoryinfo(xlabels)
443+
def setTicks(labels, coord):
444+
categories, _ = categoryinfo(labels)
450445
# TODO: This will create too many ticks if # of categories is large
451-
ax.set_xticklabels(xcategories)
452-
ax.set_xticks(list(ax.get_xticks()) + [-0.5] + list(ax.get_xticks()+0.5))
446+
if coord == 'x':
447+
ax.set_xticklabels(categories)
448+
ticks = ax.get_xticks()
449+
ticklines = ax.get_xticklines()
450+
ax.set_xticks(list(ticks) + [-0.5] + list(ticks+0.5))
451+
else:
452+
ax.set_yticklabels(categories)
453+
ticks = ax.get_yticks()
454+
ticklines = ax.get_yticklines()
455+
ax.set_yticks(list(ticks) + [-0.5] + list(ticks+0.5))
456+
453457
k = 0
454-
l = ax.get_xticklines()
455-
for l in ax.get_xticklines():
456-
if k < 2*len(xcategories):
458+
for l in ticklines:
459+
if k < 2*len(categories):
457460
l.set_markeredgewidth(0)
458461
k = k+1
459462

463+
if categoricalx:
464+
setTicks(xlabels, 'x')
465+
460466
if categoricaly:
461-
ycategories, _ = categoryinfo(ylabels)
462-
ax.set_yticklabels(ycategories)
463-
ax.set_yticks(list(ax.get_yticks()) + [-0.5] + list(ax.get_yticks()+0.5))
464-
k = 0
465-
l = ax.get_yticklines()
466-
for l in ax.get_yticklines():
467-
if k < 2*len(ycategories):
468-
l.set_markeredgewidth(0)
469-
k = k+1
467+
setTicks(ylabels, 'y')
470468

471469
# TODO: categoricalz not implemented.
472470
categoricalz = iscategorical(z)

0 commit comments

Comments
 (0)