Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ int EMD_wrap_sparse(
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out,
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
Expand All @@ -62,7 +63,11 @@ int EMD_wrap_lazy(
double *coords_b, // Target coordinates (n2 x dim)
int dim, // Dimension of coordinates
int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock
double *G, // Output: transport plan (n1 x n2)
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
Expand Down
191 changes: 133 additions & 58 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,86 @@ inline void extract_compressed_support(
}
}

struct SparseFlowWriter {
uint64_t* sources;
uint64_t* targets;
double* values;
uint64_t* count;
uint64_t capacity;

SparseFlowWriter(
uint64_t* sources_,
uint64_t* targets_,
double* values_,
uint64_t* count_,
uint64_t capacity_
) : sources(sources_),
targets(targets_),
values(values_),
count(count_),
capacity(capacity_) {}

bool push(uint64_t source, uint64_t target, double flow) {
if (*count >= capacity) return false;
sources[*count] = source;
targets[*count] = target;
values[*count] = flow;
++(*count);
return true;
}
};

template <
typename NetType,
typename DigraphType,
typename InvalidType,
typename SourceIndexVector,
typename TargetIndexVector,
typename CostAccessor
>
inline bool extract_sparse_solution(
const NetType& net,
DigraphType& di,
InvalidType invalid,
const SourceIndexVector& idx_a,
const TargetIndexVector& idx_b,
double* alpha,
double* beta,
double* cost,
SparseFlowWriter& writer,
CostAccessor cost_accessor,
double min_output_flow
) {
const int n = static_cast<int>(idx_a.size());

for (int i = 0; i < n; i++) {
alpha[static_cast<uint64_t>(idx_a[i])] = -net.potential(i);
}
for (int j = 0; j < static_cast<int>(idx_b.size()); j++) {
beta[static_cast<uint64_t>(idx_b[j])] = net.potential(j + n);
}

typename DigraphType::Arc a;
di.first(a);
for (; a != invalid; di.next(a)) {
const int i = di.source(a);
const int j = di.target(a) - n;
const double flow = net.flow(a);
if (flow != 0) {
*cost += flow * cost_accessor(a, i, j);
}
if (flow > min_output_flow) {
if (!writer.push(
static_cast<uint64_t>(idx_a[i]),
static_cast<uint64_t>(idx_b[j]),
flow)) {
return false;
}
}
}
return true;
}

} // namespace


Expand Down Expand Up @@ -186,9 +266,9 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
const SetupPolicy policy = make_setup_policy(n, m, n1, n2, true);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(
di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter
);
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(policy.use_arc_mixing);
Simplex net(di, simplex_options, (int) (n + m), n * m, maxIter);

// Set supply and demand, don't account for 0 values (faster)

Expand Down Expand Up @@ -341,6 +421,7 @@ int EMD_wrap_sparse(
uint64_t *flow_targets_out,
double *flow_values_out,
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha,
double *beta,
double *cost,
Expand Down Expand Up @@ -432,9 +513,9 @@ int EMD_wrap_sparse(

di.buildFromEdges(edges);

NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
di, true, (int)(n + m), di.arcNum(), maxIter
);
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(true);
Simplex net(di, simplex_options, (int)(n + m), di.arcNum(), maxIter);

net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m);

Expand Down Expand Up @@ -463,41 +544,34 @@ int EMD_wrap_sparse(
int ret = net.run();
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
*cost = 0;
*n_flows_out = 0;
*n_flows_out = 0;

SparseFlowWriter writer(
flow_sources_out,
flow_targets_out,
flow_values_out,
n_flows_out,
max_flows_out
);

Arc a;
di.first(a);
for (; a != INVALID; di.next(a)) {
uint64_t i = di.source(a);
uint64_t j = di.target(a);
double flow = net.flow(a);

uint64_t orig_i = indI[i];
uint64_t orig_j = indJ[j - n];


double arc_cost = arc_costs[a];

*cost += flow * arc_cost;


*(alpha + orig_i) = -net.potential(i);
*(beta + orig_j) = net.potential(j);

if (flow > 1e-15) {
flow_sources_out[*n_flows_out] = orig_i;
flow_targets_out[*n_flows_out] = orig_j;
flow_values_out[*n_flows_out] = flow;
(*n_flows_out)++;
}
auto sparse_cost = [&arc_costs](Arc a, int, int) {
return arc_costs[a];
};
if (!extract_sparse_solution(
net, di, INVALID, indI, indJ, alpha, beta, cost, writer,
sparse_cost, 1e-15)) {
return (int)net.MAX_ITER_REACHED;
}
}
return ret;
}

int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
int dim, int metric, double *G, double *alpha, double *beta,
double *cost, uint64_t maxIter, double *alpha_init, double *beta_init) {
int dim, int metric, uint64_t *flow_sources_out,
uint64_t *flow_targets_out, double *flow_values_out,
uint64_t *n_flows_out, uint64_t max_flows_out,
double *alpha, double *beta, double *cost, uint64_t maxIter,
double *alpha_init, double *beta_init) {
using namespace lemon;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
Expand Down Expand Up @@ -552,8 +626,15 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double
// Create full bipartite graph
Digraph di(n, m);

NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(false);
simplex_options.cost_storage_mode = Simplex::CostStorageMode::ArtificialOnly;
simplex_options.flow_storage_mode = Simplex::FlowStorageMode::SparseRealArcs;
simplex_options.endpoint_storage_mode = Simplex::EndpointStorageMode::ComputedRealArcs;
simplex_options.state_storage_mode = Simplex::StateStorageMode::Packed;

Simplex net(
di, simplex_options, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
);

// Set supplies
Expand Down Expand Up @@ -583,32 +664,26 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double

if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
*cost = 0;
*n_flows_out = 0;

// Initialize output arrays
for (int i = 0; i < n1 * n2; i++) G[i] = 0.0;
for (int i = 0; i < n1; i++) alpha[i] = 0.0;
for (int i = 0; i < n2; i++) beta[i] = 0.0;

// Extract solution
Arc a;
di.first(a);
for (; a != INVALID; di.next(a)) {
int i = di.source(a);
int j = di.target(a) - n;

int orig_i = idx_a[i];
int orig_j = idx_b[j];

double flow = net.flow(a);
G[orig_i * n2 + orig_j] = flow;

alpha[orig_i] = -net.potential(i);
beta[orig_j] = net.potential(j + n);

if (flow > 0) {
double c = net.computeLazyCost(i, j);
*cost += flow * c;
}

SparseFlowWriter writer(
flow_sources_out,
flow_targets_out,
flow_values_out,
n_flows_out,
max_flows_out
);
auto lazy_cost = [&net](Arc, int i, int j) {
return net.computeLazyCost(i, j);
};
if (!extract_sparse_solution(
net, di, INVALID, idx_a, idx_b, alpha, beta, cost, writer,
lazy_cost, 0.0)) {
return (int)net.MAX_ITER_REACHED;
}
}

Expand Down
14 changes: 11 additions & 3 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def emd2_lazy(
alpha_init_np = np.asarray(alpha_init_np, dtype=np.float64, order="C")
beta_init_np = np.asarray(beta_init_np, dtype=np.float64, order="C")

G, cost, u, v, result_code = emd_c_lazy(
flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_lazy(
a_np, b_np, X_a_np, X_b_np, metric, numItermax, alpha_init_np, beta_init_np
)

Expand All @@ -1053,8 +1053,6 @@ def emd2_lazy(
stacklevel=2,
)

G_backend = nx.from_numpy(G, type_as=type_as)

cost_backend = nx.set_gradients(
nx.from_numpy(cost, type_as=type_as),
(a0, b0),
Expand All @@ -1075,6 +1073,16 @@ def emd2_lazy(
"result_code": result_code,
}
if return_matrix:
flow_values_backend = nx.from_numpy(flow_values, type_as=type_as)
flow_sources_backend = nx.from_numpy(flow_sources.astype(np.int64))
flow_targets_backend = nx.from_numpy(flow_targets.astype(np.int64))
G_backend = nx.coo_matrix(
flow_values_backend,
flow_sources_backend,
flow_targets_backend,
shape=(n1, n2),
type_as=type_as,
)
log_dict["G"] = G_backend
return cost_backend, log_dict
else:
Expand Down
21 changes: 15 additions & 6 deletions ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED


Expand Down Expand Up @@ -306,7 +306,7 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
n_edges,
<uint64_t*> edge_sources.data, <uint64_t*> edge_targets.data, <double*> edge_costs.data,
<uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data,
&n_flows_out,
&n_flows_out, n_edges,
<double*> alpha.data, <double*> beta.data, &cost, max_iter,
alpha_init_ptr, beta_init_ptr
)
Expand All @@ -329,6 +329,8 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
cdef int result_code = 0
cdef double cost = 0
cdef int metric_code
cdef uint64_t n_flows_out = 0
cdef uint64_t max_flows_out = n1 + n2

# Validate dimension consistency
if coords_b.shape[1] != dim:
Expand All @@ -345,9 +347,11 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
except KeyError:
raise ValueError(f"Unknown metric: '{metric}'. Supported metrics are: {list(metric_map.keys())}")

cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(max_flows_out, dtype=np.uint64)
cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(max_flows_out, dtype=np.uint64)
cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(max_flows_out, dtype=np.float64)
cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2])
if not len(a):
a = np.ones((n1,)) / n1
if not len(b):
Expand All @@ -360,5 +364,10 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
beta_init_ptr = <double*> beta_init.data

with nogil:
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
return G, cost, alpha, beta, result_code
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data, &n_flows_out, max_flows_out, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)

flow_sources = flow_sources[:n_flows_out]
flow_targets = flow_targets[:n_flows_out]
flow_values = flow_values[:n_flows_out]

return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code
Loading
Loading