Skip to content

Commit f8d35a5

Browse files
committed
add kalman filter processor
1 parent 65c0cfe commit f8d35a5

3 files changed

Lines changed: 140 additions & 1 deletion

File tree

dlclive/processor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from dlclive.processor.processor import Processor
2+
from dlclive.processor.kalmanfilter import KalmanFilterPredictor

dlclive/processor/kalmanfilter.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
DeepLabCut2.0 Toolbox (deeplabcut.org)
3+
© A. & M. Mathis Labs
4+
https://github.com/AlexEMG/DeepLabCut
5+
6+
Please see AUTHORS for contributors.
7+
https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
8+
Licensed under GNU Lesser General Public License v3.0
9+
"""
10+
11+
12+
import time
13+
import numpy as np
14+
from dlclive.processor import Processor
15+
16+
17+
class KalmanFilterPredictor(Processor):
18+
19+
20+
def __init__(self,
21+
adapt=True,
22+
forward=0.002,
23+
fps=30,
24+
nderiv=2,
25+
priors=[10, 1],
26+
initial_var=1,
27+
process_var=1,
28+
dlc_var=20,
29+
**kwargs):
30+
31+
super().__init__(**kwargs)
32+
33+
self.adapt=adapt
34+
self.forward = forward
35+
self.dt = 1.0 / fps
36+
self.nderiv = nderiv
37+
self.priors = np.hstack(([1e5], priors))
38+
self.initial_var = initial_var
39+
self.process_var = process_var
40+
self.dlc_var = dlc_var
41+
self.is_initialized = False
42+
43+
44+
def _get_forward_model(self, dt):
45+
46+
F = np.zeros((self.n_states, self.n_states))
47+
for d in range(self.nderiv+1):
48+
for i in range(self.n_states - (d * self.bp * 2)):
49+
F[i, i + (2 * self.bp * d)] = (dt ** d) / max(1, d)
50+
51+
return F
52+
53+
54+
def _init_kf(self, pose):
55+
56+
# get number of body parts
57+
self.bp = pose.shape[0]
58+
self.n_states = self.bp * 2 * (self.nderiv+1)
59+
60+
# initialize state matrix, set position to first pose
61+
self.X = np.zeros((self.n_states, 1))
62+
self.X[:(self.bp * 2)] = pose[:, :2].reshape(self.bp * 2, 1)
63+
64+
# initialize covariance matrix, measurement noise and process noise
65+
self.P = np.eye(self.n_states) * self.initial_var
66+
self.R = np.eye(self.n_states) * self.dlc_var
67+
self.Q = np.eye(self.n_states) * self.process_var
68+
69+
# initialize forward model:
70+
self.F = self._get_forward_model(self.dt)
71+
72+
self.H = np.eye(self.n_states)
73+
self.K = np.zeros((self.n_states, self.n_states))
74+
self.I = np.eye(self.n_states)
75+
76+
# initialize priors for forward prediction step only
77+
B = np.repeat(self.priors, self.bp * 2)
78+
self.B = B.reshape(B.size, 1)
79+
80+
self.is_initialized = True
81+
82+
83+
def _predict(self):
84+
85+
self.Xp = np.dot(self.F, self.X)
86+
self.Pp = np.dot(np.dot(self.F, self.P), self.F.T) + self.Q
87+
88+
89+
def _get_residuals(self, pose):
90+
91+
z = np.zeros((self.n_states, 1))
92+
z[:(self.bp * 2)] = pose[:self.bp, :2].reshape(self.bp * 2, 1)
93+
for i in range(self.bp * 2, self.n_states):
94+
z[i] = (z[i - (self.bp * 2)] - self.X[i - (self.bp * 2)]) / self.dt
95+
self.y = z - np.dot(self.H, self.Xp)
96+
97+
98+
def _update(self):
99+
100+
S = np.dot(self.H, np.dot(self.Pp, self.H.T)) + self.R
101+
K = np.dot(np.dot(self.Pp, self.H.T), np.linalg.inv(S))
102+
self.X = self.Xp + np.dot(K, self.y)
103+
self.P = np.dot(self.I - np.dot(K, self.H), self.Pp)
104+
105+
106+
def _get_future_pose(self, dt):
107+
108+
print(dt)
109+
110+
Ff = self._get_forward_model(dt)
111+
112+
Pf = np.diag(self.P).reshape(self.P.shape[0], 1)
113+
Xf = (1 / ((1 / Pf) + (1 / self.B))) * (self.X / Pf)
114+
Xfp = np.dot(Ff, Xf)
115+
116+
future_pose = Xfp[:(self.bp * 2)].reshape(self.bp, 2)
117+
118+
return future_pose
119+
120+
121+
def process(self, pose, **kwargs):
122+
123+
if not self.is_initialized:
124+
125+
self._init_kf(pose)
126+
return pose
127+
128+
else:
129+
130+
self._predict()
131+
self._get_residuals(pose)
132+
self._update()
133+
134+
forward_time = (time.time() - kwargs['frame_time'] + self.forward) if self.adapt else self.forward
135+
future_pose = self._get_future_pose(forward_time)
136+
future_pose = np.hstack((future_pose, pose[:,2].reshape(self.bp,1)))
137+
138+
return future_pose

dlclive/processor/processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class Processor(object):
1818

19-
def __init__(self):
19+
def __init__(self, **kwargs):
2020
pass
2121

2222
def process(self, pose, **kwargs):

0 commit comments

Comments
 (0)