43 changes: 43 additions & 0 deletions pynumdiff/total_variation_regularization.py
@@ -52,7 +52,49 @@ def iterative_velocity(x, dt, params=None, options=None, num_iterations=None, ga

return x_hat, dxdt_hat

# N-d case:
def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None, axis=0):
[Collaborator comment] Let's keep only one public-facing function. If there needs to be a 1-d function, you can indent it and put it inside the outer one.
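The reviewer's suggested structure could look roughly like this. It is a sketch only: `_solve_1d` is a hypothetical stand-in for the cvxpy-based 1-d solve, and the solver-related parameters are dropped to keep the shape of the refactor visible:

```python
import numpy as np

def tvrdiff(x, dt, axis=0):
    """Single public entry point; the 1-d solve is a nested helper."""
    def _solve_1d(v):
        # hypothetical stand-in for the cvxpy-based 1-d solve
        return v.copy(), np.gradient(v, dt)

    x0 = np.moveaxis(np.asarray(x, dtype=float), axis, 0)
    if x0.ndim == 1:                      # base case: 1-d input
        return _solve_1d(x0)

    x_hat0 = np.empty_like(x0)
    dxdt0 = np.empty_like(x0)
    for i in np.ndindex(x0.shape[1:]):    # loop over the remaining axes
        idx = (slice(None),) + i
        x_hat0[idx], dxdt0[idx] = _solve_1d(x0[idx])
    return np.moveaxis(x_hat0, 0, axis), np.moveaxis(dxdt0, 0, axis)
```

With only one public name, the recursion/shadowing problem of having two module-level `tvrdiff` definitions also goes away.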

"""
Generalized total variation regularized derivatives (cvxpy). Supports multidimensionality by differentiating along
'axis', independently for each vector obtained by fixing all other indices.

:param np.array[float] x: data to differentiate
:param float dt: step size
:param int order: 1, 2, or 3, the derivative to regularize
:param float gamma: regularization parameter
:param float huberM: Huber loss parameter, in units of the scaled median absolute deviation of the input data.
:math:`M = \\infty` reduces to squared :math:`\\ell_2` loss on the first (fidelity) cost term, and
:math:`M = 0` reduces to :math:`\\ell_1` loss, which favors sparse residuals.
:param str solver: Solver to use. Solver options include: 'MOSEK', 'CVXOPT', 'CLARABEL', 'ECOS'.
If not given, fall back to CVXPY's default.

:return: - **x_hat** (np.array) -- estimated (smoothed) x
- **dxdt_hat** (np.array) -- estimated derivative of x
"""

x0 = np.moveaxis(x, axis, 0)
[Collaborator comment] Instead of np.moveaxis plus creating a new data array, I think np.apply_along_axis is more automatic and shorter. Strive to do things in as few calls as possible, because each addition becomes one more thing to parse when people read and maintain the code later.
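A sketch of the `np.apply_along_axis` approach the reviewer suggests. `_solve_1d` is again a hypothetical stand-in for the 1-d solver core; stacking its two outputs into one array lets a single `apply_along_axis` call carry both, after which the applied axis is replaced by the stacked pair:

```python
import numpy as np

def _solve_1d(v, dt):
    # hypothetical stand-in for the cvxpy-based 1-d solve;
    # stack the pair so apply_along_axis returns both at once
    return np.stack([v, np.gradient(v, dt)])   # shape (2, n)

x = np.arange(12.0).reshape(3, 4)
out = np.apply_along_axis(_solve_1d, 1, x, 0.5)   # shape (3, 2, 4)
x_hat, dxdt_hat = out[:, 0, :], out[:, 1, :]
```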


# Base case: a 1-d input goes straight to the solver
if x0.ndim == 1:
x_hat0, dxdt0 = tvrdiff(x0, dt, order, gamma, huberM, solver)
return x_hat0, dxdt0

x_hat0 = np.empty_like(x0, dtype=float)
dxdt0 = np.empty_like(x0, dtype=float)
rest = x0.shape[1:]

# The solver has no vectorized path, so loop over the remaining axes.
# Note: do not name the index tuple `slice`; shadowing the builtin while
# using slice(None) on the same line raises UnboundLocalError.
for i in np.ndindex(rest):
idx = (slice(None),) + i
x_hat0[idx], dxdt0[idx] = tvrdiff(x0[idx], dt, order, gamma, huberM, solver)

x_hat = np.moveaxis(x_hat0, 0, axis)
dxdt_hat = np.moveaxis(dxdt0, 0, axis)

return x_hat, dxdt_hat

# 1-d case:
def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None):
"""Generalized total variation regularized derivatives. Use convex optimization (cvxpy) to solve for a
total variation regularized derivative. Other convex-solver-based methods in this module call this function.
@@ -70,6 +112,7 @@ def tvrdiff(x, dt, order, gamma, huberM=float('inf'), solver=None):
:return: - **x_hat** (np.array) -- estimated (smoothed) x
- **dxdt_hat** (np.array) -- estimated derivative of x
"""

# Normalize for numerical consistency with convex solver
mu = np.mean(x)
sigma = median_abs_deviation(x, scale='normal') # robust alternative to std()
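For context, `scale='normal'` rescales the raw MAD by :math:`1/\Phi^{-1}(3/4) \approx 1/0.6745`, so it estimates the standard deviation of Gaussian data while staying robust to outliers. A pure-NumPy equivalent of the call above:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=100_000)
x[:100] = 1e6   # a handful of outliers barely moves the robust estimate

# MAD scaled by 1/0.67449, matching scipy's scale='normal'
sigma = np.median(np.abs(x - np.median(x))) / 0.67449
```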