Consider the simulation loop in the `Network.run()` function:

```python
# Simulate network activity for `time` timesteps.
for t in range(timesteps):
    for l in self.layers:  # ->
        # Update each layer of nodes.
        if isinstance(self.layers[l], AbstractInput):
            self.layers[l].step(inpts[l][t], self.dt)
        else:
            self.layers[l].step(inpts[l], self.dt)

        # Clamp neurons to spike.
        clamp = clamps.get(l, None)
        if clamp is not None:
            self.layers[l].s[clamp] = 1

    # Run synapse updates.
    for c in self.connections:  # ->
        self.connections[c].update(
            reward=reward, mask=masks.get(c, None), learning=self.learning
        )

    # Get input to all layers.
    inpts.update(self.get_inputs())

    # Record state variables of interest.
    for m in self.monitors:
        self.monitors[m].record()

# Re-normalize connections.
for c in self.connections:  # ->
    self.connections[c].normalize()
```
Where I've marked a `->`, there might be an opportunity to use `torch.multiprocessing`. Updates at time `t` depend only on the network state at time `t-1`, so the iterations within each marked loop are independent of one another. That means all `Nodes` updates (and, after them, all `Connection` updates) could each run in a separate process (or thread) at the same time. Let `k` be the number of layers and `m` the number of connections. Given enough CPU / GPU resources, the loops marked with `->` would then take O(1) wall-clock time instead of O(k) and O(m), respectively, ignoring the overhead of dispatching work to the workers.
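The following minimal sketch makes the independence argument concrete. It uses a hypothetical `ToyLayer` class, a pure-Python stand-in for BindsNET's `Nodes` with a simple leaky integrate-and-fire update. Each layer's `step` reads only its own state and an input fixed before the timestep begins. Stepping the layers forward, in reverse, or concurrently through a `concurrent.futures.ThreadPoolExecutor` should therefore give identical spikes. The names `ToyLayer`, `make_layers`, and `simulate` are illustrative only, and the per-timestep inputs are supplied externally rather than computed from spikes.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyLayer:
    """Hypothetical stand-in for a Nodes layer: leaky integrate-and-fire units."""

    def __init__(self, n, decay=0.9, thresh=1.0):
        self.v = [0.0] * n    # membrane voltages
        self.s = [False] * n  # spikes from the most recent step
        self.decay = decay
        self.thresh = thresh

    def step(self, x, dt):
        # Touches only this layer's own state and its input x.
        for i, xi in enumerate(x):
            self.v[i] = self.decay * self.v[i] + dt * xi
            self.s[i] = self.v[i] >= self.thresh
            if self.s[i]:
                self.v[i] = 0.0  # reset after a spike


def make_layers():
    return {"X": ToyLayer(3), "Y": ToyLayer(2)}


def simulate(order, inputs, dt=1.0, executor=None):
    layers = make_layers()
    for x in inputs:  # x maps layer name -> input vector for this timestep
        if executor is None:
            for name in order:
                layers[name].step(x[name], dt)
        else:
            futures = [executor.submit(layers[n].step, x[n], dt) for n in order]
            for f in futures:
                f.result()  # barrier: all layers finish step t before t + 1
    return {name: list(layer.s) for name, layer in layers.items()}


inputs = [{"X": [0.5, 1.2, 0.1], "Y": [0.7, 0.4]}] * 4
forward = simulate(["X", "Y"], inputs)
backward = simulate(["Y", "X"], inputs)
with ThreadPoolExecutor(max_workers=2) as pool:
    threaded = simulate(["X", "Y"], inputs, executor=pool)

print(forward == backward == threaded)  # True
```

The `f.result()` calls play the role of the sequential boundary between timesteps: no layer starts step `t + 1` until every layer has finished step `t`.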
I think it'd be good to keep two (?) `multiprocessing.Pool` objects around, one for `Nodes` objects and another for `Connection` objects. Instead of statements of the form:
```python
for l in self.layers:
    self.layers[l].step(...)
```
We might rewrite this as something like:
```python
self.nodes_pool.map(Nodes.step, self.layers)
```
Here, `nodes_pool` is defined as an attribute in the `Network` constructor. This last bit probably won't work straightaway; we'd need to figure out the right syntax (if it exists).
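On the syntax question, here is a sketch under a couple of assumptions. Mapping directly over `self.layers` would hand the workers the dict's keys (layer names), not the `Nodes` objects. A standard-library `multiprocessing.Pool` also sends pickled copies of its arguments to worker processes. Anything a worker rebinds on its copy of a layer (e.g. `self.v = ...`) therefore never reaches the parent's layer. `multiprocessing.pool.ThreadPool` offers the same `map`/`starmap` interface, but its workers are threads that share the parent's objects. An unbound method such as `Nodes.step` can be passed as the function, since it takes the layer as its first argument. `ToyNodes` below is a hypothetical stand-in for `Nodes`, not the real class.

```python
from multiprocessing.pool import ThreadPool


class ToyNodes:
    """Hypothetical stand-in for a BindsNET Nodes layer (not the real class)."""

    def __init__(self, n):
        self.v = [0.0] * n

    def step(self, x, dt):
        # Rebinds self.v; with threads, the parent sees the new list.
        self.v = [vi + dt * xi for vi, xi in zip(self.v, x)]


layers = {"X": ToyNodes(2), "Y": ToyNodes(3)}
inpts = {"X": [1.0, 2.0], "Y": [0.5, 0.5, 0.5]}
dt = 1.0

with ThreadPool(processes=len(layers)) as nodes_pool:
    # Map over the Nodes objects (dict values), not the dict itself (its keys).
    # starmap calls ToyNodes.step(layer, x, dt) for each tuple.
    nodes_pool.starmap(
        ToyNodes.step, [(layer, inpts[name], dt) for name, layer in layers.items()]
    )

print(layers["X"].v, layers["Y"].v)  # [1.0, 2.0] [0.5, 0.5, 0.5]
```

In `Network`, the pool could be created once in the constructor (e.g. `self.nodes_pool = ThreadPool(...)`) and reused every timestep. Whether threads actually help depends on how much of `step` runs outside Python's GIL. PyTorch can release the GIL inside its tensor kernels, so some real overlap is possible, but this would need to be measured.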
This same idea can also be applied in the `Network`'s `reset()` and `get_inputs()` functions.
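The sketch below assumes that `get_inputs()` sums, for each target layer, the output of every connection projecting into that layer. Under that assumption, the work splits into a parallel map (one output per connection) followed by a sequential reduce. It uses hypothetical `ToyConnection` objects with a dense weight matrix and a `compute(s)` method, which are illustrative rather than BindsNET's API. `reset()` is simpler still: each layer's and connection's reset is independent, so something like `pool.map(lambda obj: obj.reset(), objects)` would need no reduce step.

```python
from multiprocessing.pool import ThreadPool


class ToyConnection:
    """Hypothetical dense connection from a source layer to a target layer."""

    def __init__(self, source, target, w):
        self.source = source  # name of the presynaptic layer
        self.target = target  # name of the postsynaptic layer
        self.w = w            # w[i][j]: weight from source unit i to target unit j

    def compute(self, s):
        # Weighted sum of presynaptic spikes for each postsynaptic unit.
        n_post = len(self.w[0])
        return [sum(si * row[j] for si, row in zip(s, self.w)) for j in range(n_post)]


def get_inputs(connections, spikes, pool):
    # Map: each connection's output depends only on its source layer's spikes,
    # so all outputs can be computed concurrently (threads need no pickling,
    # so a lambda is fine here).
    outputs = pool.map(lambda c: c.compute(spikes[c.source]), connections)

    # Reduce: sum contributions per target layer sequentially, so no two
    # threads ever write to the same accumulator.
    inpts = {}
    for c, out in zip(connections, outputs):
        acc = inpts.get(c.target, [0.0] * len(out))
        inpts[c.target] = [a + o for a, o in zip(acc, out)]
    return inpts


connections = [
    ToyConnection("X", "Z", [[0.5, 1.0], [2.0, 2.0]]),
    ToyConnection("Y", "Z", [[0.25, 0.0], [0.25, 0.5]]),
]
spikes = {"X": [1, 0], "Y": [1, 1]}

with ThreadPool(processes=2) as pool:
    inpts = get_inputs(connections, spikes, pool)

print(inpts)  # {'Z': [1.0, 1.5]}
```

The reduce is kept sequential on purpose. If two workers added into the same target's accumulator at once, one of the read-modify-write updates could be lost.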