diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index fc7eeb2c0..1e8da15c8 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -44,7 +44,8 @@ int EMD_wrap_sparse( uint64_t *flow_sources_out, // Output: source indices of non-zero flows uint64_t *flow_targets_out, // Output: target indices of non-zero flows double *flow_values_out, // Output: flow values - uint64_t *n_flows_out, + uint64_t *n_flows_out, + uint64_t max_flows_out, double *alpha, // Output: dual variables for sources (n1) double *beta, // Output: dual variables for targets (n2) double *cost, // Output: total transportation cost @@ -62,7 +63,11 @@ int EMD_wrap_lazy( double *coords_b, // Target coordinates (n2 x dim) int dim, // Dimension of coordinates int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock - double *G, // Output: transport plan (n1 x n2) + uint64_t *flow_sources_out, // Output: source indices of non-zero flows + uint64_t *flow_targets_out, // Output: target indices of non-zero flows + double *flow_values_out, // Output: flow values + uint64_t *n_flows_out, + uint64_t max_flows_out, double *alpha, // Output: dual variables for sources (n1) double *beta, // Output: dual variables for targets (n2) double *cost, // Output: total transportation cost diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 10ae94190..148b145cb 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -147,6 +147,86 @@ inline void extract_compressed_support( } } +struct SparseFlowWriter { + uint64_t* sources; + uint64_t* targets; + double* values; + uint64_t* count; + uint64_t capacity; + + SparseFlowWriter( + uint64_t* sources_, + uint64_t* targets_, + double* values_, + uint64_t* count_, + uint64_t capacity_ + ) : sources(sources_), + targets(targets_), + values(values_), + count(count_), + capacity(capacity_) {} + + bool push(uint64_t source, uint64_t target, double flow) { + if (*count >= capacity) return false; + sources[*count] = source; + targets[*count] = target; + values[*count] = flow; + ++(*count); + return true; + } +}; + +template < + typename NetType, + typename DigraphType, + typename InvalidType, + typename SourceIndexVector, + typename TargetIndexVector, + typename CostAccessor +> +inline bool extract_sparse_solution( + const NetType& net, + DigraphType& di, + InvalidType invalid, + const SourceIndexVector& idx_a, + const TargetIndexVector& idx_b, + double* alpha, + double* beta, + double* cost, + SparseFlowWriter& writer, + CostAccessor cost_accessor, + double min_output_flow +) { + const int n = static_cast(idx_a.size()); + + for (int i = 0; i < n; i++) { + alpha[static_cast(idx_a[i])] = -net.potential(i); + } + for (int j = 0; j < static_cast(idx_b.size()); j++) { + beta[static_cast(idx_b[j])] = net.potential(j + n); + } + + typename DigraphType::Arc a; + di.first(a); + for (; a != invalid; di.next(a)) { + const int i = di.source(a); + const int j = di.target(a) - n; + const double flow = net.flow(a); + if (flow != 0) { + *cost += flow * cost_accessor(a, i, j); + } + if (flow > min_output_flow) { + if (!writer.push( + static_cast(idx_a[i]), + static_cast(idx_b[j]), + flow)) { + return false; + } + } + } + return true; +} + } // namespace @@ -186,9 +266,9 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector weights1(n), weights2(m); Digraph di(n, m); const SetupPolicy policy = make_setup_policy(n, m, n1, n2, true); - NetworkSimplexSimple net( - di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter - ); + typedef NetworkSimplexSimple Simplex; + Simplex::SimplexOptions simplex_options(policy.use_arc_mixing); + Simplex net(di, simplex_options, (int) (n + m), n * m, maxIter); // Set supply and demand, don't account for 0 values (faster) @@ -341,6 +421,7 @@ int EMD_wrap_sparse( uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, + uint64_t max_flows_out, double *alpha, double *beta, double *cost, @@ -432,9 +513,9 @@ int EMD_wrap_sparse( di.buildFromEdges(edges); - NetworkSimplexSimple net( - di, true, (int)(n + m), di.arcNum(), maxIter - ); + typedef NetworkSimplexSimple Simplex; + Simplex::SimplexOptions simplex_options(true); + Simplex net(di, simplex_options, (int)(n + m), di.arcNum(), maxIter); net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m); @@ -463,41 +544,34 @@ int EMD_wrap_sparse( int ret = net.run(); if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { *cost = 0; - *n_flows_out = 0; + *n_flows_out = 0; + + SparseFlowWriter writer( + flow_sources_out, + flow_targets_out, + flow_values_out, + n_flows_out, + max_flows_out + ); - Arc a; - di.first(a); - for (; a != INVALID; di.next(a)) { - uint64_t i = di.source(a); - uint64_t j = di.target(a); - double flow = net.flow(a); - - uint64_t orig_i = indI[i]; - uint64_t orig_j = indJ[j - n]; - - - double arc_cost = arc_costs[a]; - - *cost += flow * arc_cost; - - - *(alpha + orig_i) = -net.potential(i); - *(beta + orig_j) = net.potential(j); - - if (flow > 1e-15) { - flow_sources_out[*n_flows_out] = orig_i; - flow_targets_out[*n_flows_out] = orig_j; - flow_values_out[*n_flows_out] = flow; - (*n_flows_out)++; - } + auto sparse_cost = [&arc_costs](Arc a, int, int) { + return arc_costs[a]; + }; + if (!extract_sparse_solution( + net, di, INVALID, indI, indJ, alpha, beta, cost, writer, + sparse_cost, 1e-15)) { + return (int)net.MAX_ITER_REACHED; } } return ret; } int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, - int dim, int metric, double *G, double *alpha, double *beta, - double *cost, uint64_t maxIter, double *alpha_init, double *beta_init) { + int dim, int metric, uint64_t *flow_sources_out, + uint64_t *flow_targets_out, double *flow_values_out, + uint64_t *n_flows_out, uint64_t max_flows_out, + double *alpha, double *beta, double *cost, uint64_t maxIter, + double *alpha_init, double *beta_init) { using namespace lemon; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(Digraph); @@ -552,8 +626,15 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double // Create full bipartite graph Digraph di(n, m); - NetworkSimplexSimple net( - di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter + typedef NetworkSimplexSimple Simplex; + Simplex::SimplexOptions simplex_options(false); + simplex_options.cost_storage_mode = Simplex::CostStorageMode::ArtificialOnly; + simplex_options.flow_storage_mode = Simplex::FlowStorageMode::SparseRealArcs; + simplex_options.endpoint_storage_mode = Simplex::EndpointStorageMode::ComputedRealArcs; + simplex_options.state_storage_mode = Simplex::StateStorageMode::Packed; + + Simplex net( + di, simplex_options, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter ); // Set supplies @@ -583,32 +664,26 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { *cost = 0; + *n_flows_out = 0; // Initialize output arrays - for (int i = 0; i < n1 * n2; i++) G[i] = 0.0; for (int i = 0; i < n1; i++) alpha[i] = 0.0; for (int i = 0; i < n2; i++) beta[i] = 0.0; - - // Extract solution - Arc a; - di.first(a); - for (; a != INVALID; di.next(a)) { - int i = di.source(a); - int j = di.target(a) - n; - - int orig_i = idx_a[i]; - int orig_j = idx_b[j]; - - double flow = net.flow(a); - G[orig_i * n2 + orig_j] = flow; - - alpha[orig_i] = -net.potential(i); - beta[orig_j] = net.potential(j + n); - - if (flow > 0) { - double c = net.computeLazyCost(i, j); - *cost += flow * c; - } + + SparseFlowWriter writer( + flow_sources_out, + flow_targets_out, + flow_values_out, + n_flows_out, + max_flows_out + ); + auto lazy_cost = [&net](Arc, int i, int j) { + return net.computeLazyCost(i, j); + }; + if (!extract_sparse_solution( + net, di, INVALID, idx_a, idx_b, alpha, beta, cost, writer, + lazy_cost, 0.0)) { + return (int)net.MAX_ITER_REACHED; } } diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index c352f17e6..8b7870d15 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -1037,7 +1037,7 @@ def emd2_lazy( alpha_init_np = np.asarray(alpha_init_np, dtype=np.float64, order="C") beta_init_np = np.asarray(beta_init_np, dtype=np.float64, order="C") - G, cost, u, v, result_code = emd_c_lazy( + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_lazy( a_np, b_np, X_a_np, X_b_np, metric, numItermax, alpha_init_np, beta_init_np ) @@ -1053,8 +1053,6 @@ def emd2_lazy( stacklevel=2, ) - G_backend = nx.from_numpy(G, type_as=type_as) - cost_backend = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), (a0, b0), @@ -1075,6 +1073,16 @@ def emd2_lazy( "result_code": result_code, } if return_matrix: + flow_values_backend = nx.from_numpy(flow_values, type_as=type_as) + flow_sources_backend = nx.from_numpy(flow_sources.astype(np.int64)) + flow_targets_backend = nx.from_numpy(flow_targets.astype(np.int64)) + G_backend = nx.coo_matrix( + flow_values_backend, + flow_sources_backend, + flow_targets_backend, + shape=(n1, n2), + type_as=type_as, + ) log_dict["G"] = G_backend return cost_backend, log_dict else: diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 71bb7d9d5..a6d42bb50 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -22,8 +22,8 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil - int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil - int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil + int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil + int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -306,7 +306,7 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, n_edges, edge_sources.data, edge_targets.data, edge_costs.data, flow_sources.data, flow_targets.data, flow_values.data, - &n_flows_out, + &n_flows_out, n_edges, alpha.data, beta.data, &cost, max_iter, alpha_init_ptr, beta_init_ptr ) @@ -329,6 +329,8 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1 cdef int result_code = 0 cdef double cost = 0 cdef int metric_code + cdef uint64_t n_flows_out = 0 + cdef uint64_t max_flows_out = n1 + n2 # Validate dimension consistency if coords_b.shape[1] != dim: @@ -345,9 +347,11 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1 except KeyError: raise ValueError(f"Unknown metric: '{metric}'. Supported metrics are: {list(metric_map.keys())}") + cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(max_flows_out, dtype=np.uint64) + cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(max_flows_out, dtype=np.uint64) + cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(max_flows_out, dtype=np.float64) cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) - cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) if not len(a): a = np.ones((n1,)) / n1 if not len(b): @@ -360,5 +364,10 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1 beta_init_ptr = beta_init.data with nogil: - result_code = EMD_wrap_lazy(n1, n2, a.data, b.data, coords_a.data, coords_b.data, dim, metric_code, G.data, alpha.data, beta.data, &cost, max_iter, alpha_init_ptr, beta_init_ptr) - return G, cost, alpha, beta, result_code + result_code = EMD_wrap_lazy(n1, n2, a.data, b.data, coords_a.data, coords_b.data, dim, metric_code, flow_sources.data, flow_targets.data, flow_values.data, &n_flows_out, max_flows_out, alpha.data, beta.data, &cost, max_iter, alpha_init_ptr, beta_init_ptr) + + flow_sources = flow_sources[:n_flows_out] + flow_targets = flow_targets[:n_flows_out] + flow_values = flow_values[:n_flows_out] + + return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 6b0904d15..3292645ca 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -48,6 +48,7 @@ #include #include #include +#include #ifdef HASHMAP #include #else @@ -228,14 +229,62 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) : + enum class CostMode { + StoredArray, + DenseMatrix, + LazyGeometry + }; + + enum class CostStorageMode { + Dense, + ArtificialOnly + }; + + enum class FlowStorageMode { + Dense, + SparseRealArcs + }; + + enum class EndpointStorageMode { + Dense, + ComputedRealArcs + }; + + enum class StateStorageMode { + Dense, + Packed + }; + + struct SimplexOptions { + bool arc_mixing; + CostStorageMode cost_storage_mode; + FlowStorageMode flow_storage_mode; + EndpointStorageMode endpoint_storage_mode; + StateStorageMode state_storage_mode; + + explicit SimplexOptions(bool arc_mixing_ = false) + : arc_mixing(arc_mixing_), + cost_storage_mode(CostStorageMode::Dense), + flow_storage_mode(FlowStorageMode::Dense), + endpoint_storage_mode(EndpointStorageMode::Dense), + state_storage_mode(StateStorageMode::Dense) {} + }; + + NetworkSimplexSimple( + const GR& graph, SimplexOptions options, int nbnodes, + ArcsType nb_arcs, uint64_t maxiters) : _graph(graph), //_arc_id(graph), - _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), + _arc_mixing(options.arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), INF(std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : MAX), - _lazy_cost(false), _coords_a(nullptr), _coords_b(nullptr), _dim(0), _metric(0), _n1(0), _n2(0), - _dense_cost(false), _D_ptr(nullptr), _D_n2(0), + _cost_mode(CostMode::StoredArray), + _cost_storage_mode(options.cost_storage_mode), + _flow_storage_mode(options.flow_storage_mode), + _endpoint_storage_mode(options.endpoint_storage_mode), + _state_storage_mode(options.state_storage_mode), + _coords_a(nullptr), _coords_b(nullptr), _dim(0), _metric(0), _n1(0), _n2(0), + _D_ptr(nullptr), _D_n2(0), _warmstart_provided(false), _warmstart_tree_built(false), _max_cost(0), _has_max_cost(false) { @@ -310,6 +359,52 @@ namespace lemon { STATE_LOWER = 1 }; + class PackedStateVector { + public: + void resize(ArcsType n) { + _size = n; + _data.assign((static_cast(n) + 3) / 4, 0); + } + + void clear() { + _size = 0; + _data.clear(); + } + + void fill(ArcsType count, signed char state) { + for (ArcsType i = 0; i < count; ++i) { + set(i, state); + } + } + + signed char get(ArcsType index) const { + const uint8_t bits = (_data[static_cast(index) / 4] >> + (2 * (static_cast(index) % 4))) & 0x03; + if (bits == 0) return STATE_LOWER; + if (bits == 1) return STATE_TREE; + return STATE_UPPER; + } + + void set(ArcsType index, signed char state) { + const size_t byte_index = static_cast(index) / 4; + const size_t shift = 2 * (static_cast(index) % 4); + _data[byte_index] = static_cast( + (_data[byte_index] & ~(uint8_t(0x03) << shift)) | + (encode(state) << shift) + ); + } + + private: + static uint8_t encode(signed char state) { + if (state == STATE_LOWER) return 0; + if (state == STATE_TREE) return 1; + return 2; + } + + ArcsType _size; + std::vector _data; + }; + typedef std::vector StateVector; // Note: vector is used instead of vector for // efficiency reasons @@ -336,28 +431,38 @@ namespace lemon { // IntArcMap _arc_id; IntVector _source; // keep nodes as integers IntVector _target; + IntVector _artificial_source; + IntVector _artificial_target; bool _arc_mixing; public: // Node and arc data CostVector _cost; ValueVector _supply; ValueVector _flow; - //SparseValueVector _flow; + ValueVector _artificial_flow; CostVector _pi; - // Lazy cost computation support - bool _lazy_cost; + // Cost access support + CostMode _cost_mode; + CostStorageMode _cost_storage_mode; + FlowStorageMode _flow_storage_mode; + EndpointStorageMode _endpoint_storage_mode; + StateStorageMode _state_storage_mode; const double* _coords_a; const double* _coords_b; int _dim; int _metric; // 0: sqeuclidean, 1: euclidean, 2: cityblock // Dense cost matrix pointer (lazy access, no copy) - bool _dense_cost; const double* _D_ptr; // pointer to row-major cost matrix int _D_n2; // number of columns in D (original n2) private: + // Sparse real-flow mode is used by lazy full-bipartite solves: + // real arcs use _real_flow, artificial arcs use _artificial_flow. + // Dense mode keeps the original all-arc _flow vector. + std::map _real_flow; + // Warmstart data bool _warmstart_provided; // Flag indicating warmstart is available bool _warmstart_tree_built; // Flag: tree was built by warmstartInit() @@ -372,6 +477,7 @@ namespace lemon { IntVector _dirty_revs; BoolVector _forward; StateVector _state; + PackedStateVector _packed_state; ArcsType _root; // Temporary data used in the current pivot iteration @@ -423,6 +529,199 @@ namespace lemon { return n; } + inline bool usesStoredCost() const { + return _cost_mode == CostMode::StoredArray; + } + + inline bool usesDenseCost() const { + return _cost_mode == CostMode::DenseMatrix; + } + + inline bool usesLazyCost() const { + return _cost_mode == CostMode::LazyGeometry; + } + + inline bool usesArtificialCostOnly() const { + return _cost_storage_mode == CostStorageMode::ArtificialOnly; + } + + inline bool usesSparseRealFlow() const { + return _flow_storage_mode == FlowStorageMode::SparseRealArcs; + } + + inline bool usesComputedRealEndpoints() const { + return _endpoint_storage_mode == EndpointStorageMode::ComputedRealArcs; + } + + inline bool usesPackedState() const { + return _state_storage_mode == StateStorageMode::Packed; + } + + Cost computeLazyCostUpperBound() const { + Cost squared_range_sum = 0; + Cost l1_range_sum = 0; + + for (int d = 0; d < _dim; ++d) { + Cost min_value = _coords_a[d]; + Cost max_value = _coords_a[d]; + + for (int i = 0; i < _n1; ++i) { + const Cost value = _coords_a[i * _dim + d]; + if (value < min_value) min_value = value; + if (value > max_value) max_value = value; + } + for (int j = 0; j < _n2; ++j) { + const Cost value = _coords_b[j * _dim + d]; + if (value < min_value) min_value = value; + if (value > max_value) max_value = value; + } + + const Cost range = max_value - min_value; + squared_range_sum += range * range; + l1_range_sum += range; + } + + if (_metric == 0) return squared_range_sum; + if (_metric == 1) return std::sqrt(squared_range_sum); + return l1_range_sum; + } + + Cost maxRealArcCost() { + if (_has_max_cost) { + return _max_cost; + } + + Cost max_cost = 0; + for (ArcsType i = 0; i != _arc_num; ++i) { + Cost cost = getCostForArc(i); + if (i == 0 || cost > max_cost) { + max_cost = cost; + } + } + _max_cost = max_cost; + _has_max_cost = true; + return max_cost; + } + + inline int arcSource(ArcsType arc_id) const { + if (usesComputedRealEndpoints()) { + if (arc_id < _arc_num) { + const ArcsType graph_arc = _arc_num - arc_id - 1; + return _node_id(static_cast(graph_arc / _n2)); + } + return _artificial_source[arc_id - _arc_num]; + } + return _source[arc_id]; + } + + inline int arcTarget(ArcsType arc_id) const { + if (usesComputedRealEndpoints()) { + if (arc_id < _arc_num) { + const ArcsType graph_arc = _arc_num - arc_id - 1; + return _node_id(static_cast(graph_arc % _n2) + _n1); + } + return _artificial_target[arc_id - _arc_num]; + } + return _target[arc_id]; + } + + inline void setArcEndpoints(ArcsType arc_id, int source, int target) { + if (usesComputedRealEndpoints()) { + if (arc_id >= _arc_num) { + _artificial_source[arc_id - _arc_num] = source; + _artificial_target[arc_id - _arc_num] = target; + } + return; + } + _source[arc_id] = source; + _target[arc_id] = target; + } + + inline void setArcCost(ArcsType arc_id, Cost cost) { + if (usesArtificialCostOnly()) { + if (arc_id >= _arc_num) { + _cost[arc_id - _arc_num] = cost; + } + } else { + _cost[arc_id] = cost; + } + if (!_has_max_cost || cost > _max_cost) { + _max_cost = cost; + _has_max_cost = true; + } + } + + inline signed char arcState(ArcsType arc_id) const { + if (usesPackedState()) { + return _packed_state.get(arc_id); + } + return _state[arc_id]; + } + + inline void setArcState(ArcsType arc_id, signed char state) { + if (usesPackedState()) { + _packed_state.set(arc_id, state); + return; + } + _state[arc_id] = state; + } + + inline void flipArcState(ArcsType arc_id) { + setArcState(arc_id, -arcState(arc_id)); + } + + inline void fillArcStates(ArcsType count, signed char state) { + if (usesPackedState()) { + _packed_state.fill(count, state); + } else { + std::fill_n(_state.begin(), count, state); + } + } + + inline ArcsType flowArcCount() const { + return usesSparseRealFlow() ? _all_arc_num : static_cast(_flow.size()); + } + + inline Value arcFlow(ArcsType arc_id) const { + if (usesSparseRealFlow()) { + if (arc_id < _arc_num) { + typename std::map::const_iterator it = + _real_flow.find(arc_id); + return it == _real_flow.end() ? Value(0) : it->second; + } + return _artificial_flow[arc_id - _arc_num]; + } + return _flow[arc_id]; + } + + inline void setArcFlow(ArcsType arc_id, Value flow) { + if (usesSparseRealFlow()) { + if (arc_id < _arc_num) { + if (flow == 0) { + _real_flow.erase(arc_id); + } else { + _real_flow[arc_id] = flow; + } + } else { + _artificial_flow[arc_id - _arc_num] = flow; + } + return; + } + _flow[arc_id] = flow; + } + + inline void addArcFlow(ArcsType arc_id, Value delta) { + if (usesSparseRealFlow()) { + if (arc_id < _arc_num) { + setArcFlow(arc_id, arcFlow(arc_id) + delta); + } else { + _artificial_flow[arc_id - _arc_num] += delta; + } + return; + } + _flow[arc_id] += delta; + } + // finally unused because too slow inline ArcsType getSource(const ArcsType arc) const { @@ -459,10 +758,6 @@ namespace lemon { private: // References to the NetworkSimplexSimple class - const IntVector &_source; - const IntVector &_target; - const CostVector &_cost; - const StateVector &_state; const CostVector &_pi; ArcsType &_in_arc; ArcsType _search_arc_num; @@ -476,8 +771,7 @@ namespace lemon { // Constructor BlockSearchPivotRule(NetworkSimplexSimple &ns) : - _source(ns._source), _target(ns._target), - _cost(ns._cost), _state(ns._state), _pi(ns._pi), + _pi(ns._pi), _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num), _next_arc(0),_ns(ns) { @@ -490,40 +784,7 @@ namespace lemon { // Get cost for an arc (either from pre-computed array or compute lazily) inline Cost getCost(ArcsType e) const { - if (_ns._dense_cost) { - // Dense matrix mode: read directly from D pointer - return _ns._D_ptr[_ns._arc_num - e - 1]; - } else if (!_ns._lazy_cost) { - return _cost[e]; - } else { - // For lazy mode, compute cost from coordinates inline - // _source and _target use reversed node numbering - int i = _ns._node_num - _source[e] - 1; - int j = _ns._node_num - _target[e] - 1 - _ns._n1; - - const double* xa = _ns._coords_a + i * _ns._dim; - const double* xb = _ns._coords_b + j * _ns._dim; - Cost cost = 0; - - if (_ns._metric == 0) { // sqeuclidean - for (int d = 0; d < _ns._dim; ++d) { - Cost diff = xa[d] - xb[d]; - cost += diff * diff; - } - return cost; - } else if (_ns._metric == 1) { // euclidean - for (int d = 0; d < _ns._dim; ++d) { - Cost diff = xa[d] - xb[d]; - cost += diff * diff; - } - return std::sqrt(cost); - } else { // cityblock - for (int d = 0; d < _ns._dim; ++d) { - cost += std::abs(xa[d] - xb[d]); - } - return cost; - } - } + return _ns.getCostForArc(e); } // Find next entering arc @@ -533,32 +794,32 @@ namespace lemon { ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { - c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); + c = _ns.arcState(e) * (getCost(e) + _pi[_ns.arcSource(e)] - _pi[_ns.arcTarget(e)]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { - a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=fabs(_pi[_ns.arcSource(_in_arc)])>fabs(_pi[_ns.arcTarget(_in_arc)]) ? fabs(_pi[_ns.arcSource(_in_arc)]):fabs(_pi[_ns.arcTarget(_in_arc)]); a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } for (e = 0; e != _next_arc; ++e) { - c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); + c = _ns.arcState(e) * (getCost(e) + _pi[_ns.arcSource(e)] - _pi[_ns.arcTarget(e)]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { - a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=fabs(_pi[_ns.arcSource(_in_arc)])>fabs(_pi[_ns.arcTarget(_in_arc)]) ? fabs(_pi[_ns.arcSource(_in_arc)]):fabs(_pi[_ns.arcTarget(_in_arc)]); a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } - a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=fabs(_pi[_ns.arcSource(_in_arc)])>fabs(_pi[_ns.arcTarget(_in_arc)]) ? fabs(_pi[_ns.arcSource(_in_arc)]):fabs(_pi[_ns.arcTarget(_in_arc)]); a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min >= -EPSILON*a) return false; @@ -605,12 +866,7 @@ namespace lemon { NetworkSimplexSimple& costMap(const CostMap& map) { Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a)) { - Cost c = map[a]; - _cost[getArcID(a)] = c; - if (!_has_max_cost || c > _max_cost) { - _max_cost = c; - _has_max_cost = true; - } + setArcCost(getArcID(a), map[a]); } return *this; } @@ -627,11 +883,7 @@ namespace lemon { /// \return (*this) template NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) { - _cost[getArcID(arc)] = cost; - if (!_has_max_cost || cost > _max_cost) { - _max_cost = cost; - _has_max_cost = true; - } + setArcCost(getArcID(arc), cost); return *this; } @@ -649,13 +901,15 @@ namespace lemon { /// \return (*this) NetworkSimplexSimple& setLazyCost(const double* coords_a, const double* coords_b, int dim, int metric, int n1, int n2) { - _lazy_cost = true; + _cost_mode = CostMode::LazyGeometry; _coords_a = coords_a; _coords_b = coords_b; _dim = dim; _metric = metric; _n1 = n1; _n2 = n2; + _max_cost = computeLazyCostUpperBound(); + _has_max_cost = true; return *this; } @@ -670,7 +924,7 @@ namespace lemon { /// /// \return (*this) NetworkSimplexSimple& setDenseCostMatrix(const double* D, int n2) { - _dense_cost = true; + _cost_mode = CostMode::DenseMatrix; _D_ptr = D; _D_n2 = n2; // Precompute max cost once for reuse in init() @@ -726,28 +980,31 @@ namespace lemon { /// \param arc_id The arc ID /// \return Cost of the arc inline Cost getCostForArc(ArcsType arc_id) const { - if (_dense_cost) { + if (usesDenseCost()) { // Dense matrix mode: read directly from D pointer // For artificial arcs (>= _arc_num), read from _cost array if (arc_id >= _arc_num) { - return _cost[arc_id]; + return usesArtificialCostOnly() ? + _cost[arc_id - _arc_num] : _cost[arc_id]; } // Without arc mixing: internal arc_id maps to graph arc = _arc_num - arc_id - 1 // graph arc g encodes source i = g / m, target j = g % m // cost = D[i * _D_n2 + j] = D[g] (since m == _D_n2) return _D_ptr[_arc_num - arc_id - 1]; - } else if (!_lazy_cost) { - return _cost[arc_id]; + } else if (usesStoredCost()) { + return usesArtificialCostOnly() ? + _cost[arc_id - _arc_num] : _cost[arc_id]; } else { // For artificial arcs (>= _arc_num), return stored cost // (0 for positive supply, ART_COST for negative supply) if (arc_id >= _arc_num) { - return _cost[arc_id]; + return usesArtificialCostOnly() ? + _cost[arc_id - _arc_num] : _cost[arc_id]; } // Compute lazily from coordinates // Convert internal node IDs back to graph node IDs, then to coordinate indices - int i = _node_num - _source[arc_id] - 1; // graph source in [0, _n1-1] - int j = _node_num - _target[arc_id] - 1 - _n1; // graph target in [_n1, _node_num-1] -> [0, _n2-1] + int i = _node_num - arcSource(arc_id) - 1; // graph source in [0, _n1-1] + int j = _node_num - arcTarget(arc_id) - 1 - _n1; // graph target in [_n1, _node_num-1] -> [0, _n2-1] return computeLazyCost(i, j); } } @@ -954,7 +1211,7 @@ namespace lemon { std::fill_n(_supply.begin(), _node_num, Value(0)); // In dense/lazy modes, real-arc costs are not read from _cost. // Keep the default fill for the regular explicit-cost mode only. - if (!_dense_cost && !_lazy_cost) { + if (usesStoredCost() && !usesArtificialCostOnly()) { std::fill_n(_cost.begin(), _arc_num, Cost(1)); } _stype = GEQ; @@ -997,13 +1254,35 @@ namespace lemon { _arc_num = _init_nb_arcs; int all_node_num = _node_num + 1; ArcsType max_arc_num = _arc_num + 2 * _node_num; + _all_arc_num = max_arc_num; - _source.resize(max_arc_num); - _target.resize(max_arc_num); + if (usesComputedRealEndpoints()) { + _source.clear(); + _target.clear(); + _artificial_source.resize(2 * _node_num); + _artificial_target.resize(2 * _node_num); + } else { + _source.resize(max_arc_num); + _target.resize(max_arc_num); + _artificial_source.clear(); + _artificial_target.clear(); + } - _cost.resize(max_arc_num); + if (usesArtificialCostOnly()) { + _cost.resize(2 * _node_num); + } else { + _cost.resize(max_arc_num); + } _supply.resize(all_node_num); - _flow.resize(max_arc_num); + if (usesSparseRealFlow()) { + _flow.clear(); + _artificial_flow.assign(2 * _node_num, Value(0)); + _real_flow.clear(); + } else { + _flow.resize(max_arc_num); + _artificial_flow.clear(); + _real_flow.clear(); + } _pi.resize(all_node_num); _parent.resize(all_node_num); @@ -1013,11 +1292,18 @@ namespace lemon { _rev_thread.resize(all_node_num); _succ_num.resize(all_node_num); _last_succ.resize(all_node_num); - _state.resize(max_arc_num); + if (usesPackedState()) { + _state.clear(); + _packed_state.resize(max_arc_num); + } else { + _state.resize(max_arc_num); + _packed_state.clear(); + } - //_arc_mixing=false; - if (_arc_mixing) { + if (usesComputedRealEndpoints()) { + // Real full-bipartite arc endpoints are computed from arc ids. + } else if (_arc_mixing) { // Store the arcs in a mixed order const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); mixingCoeff = k; @@ -1028,8 +1314,11 @@ namespace lemon { ArcsType i = 0, j = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a)) { - _source[i] = _node_id(_graph.source(a)); - _target[i] = _node_id(_graph.target(a)); + setArcEndpoints( + i, + _node_id(_graph.source(a)), + _node_id(_graph.target(a)) + ); //_arc_id[a] = i; if ((i += k) >= _arc_num) i = ++j; } @@ -1038,8 +1327,11 @@ namespace lemon { ArcsType i = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a), ++i) { - _source[i] = _node_id(_graph.source(a)); - _target[i] = _node_id(_graph.target(a)); + setArcEndpoints( + i, + _node_id(_graph.source(a)), + _node_id(_graph.target(a)) + ); //_arc_id[a] = i; } } @@ -1096,24 +1388,9 @@ namespace lemon { c += Number(it->second) * Number(_cost[it->first]); return c;*/ - if (_dense_cost) { - // Dense matrix mode: compute cost from D pointer - for (ArcsType i=0; i<_flow.size(); i++) { - if (_flow[i] != 0) { - c += _flow[i] * Number(getCostForArc(i)); - } - } - } else if (!_lazy_cost) { - for (ArcsType i=0; i<_flow.size(); i++) - c += _flow[i] * Number(_cost[i]); - } else { - // Compute costs lazily - for (ArcsType i=0; i<_flow.size(); i++) { - if (_flow[i] != 0) { - int src = _node_num - _source[i] - 1; - int tgt = _node_num - _target[i] - 1 - _n1; - c += _flow[i] * Number(computeLazyCost(src, tgt)); - } + for (ArcsType i=0; i maxheap; for (ArcsType e = 0; e < _arc_num; ++e) { - _state[e] = STATE_LOWER; - Cost c; - if (_lazy_cost) { - // Compute cost on-the-fly for lazy mode - c = getCostForArc(e); - } else { - c = _cost[e]; - } + setArcState(e, STATE_LOWER); + Cost c = getCostForArc(e); if (c > ART_COST) ART_COST = c; - Cost rc = fabs(c + _pi[_source[e]] - _pi[_target[e]]); + Cost rc = fabs(c + _pi[arcSource(e)] - _pi[arcTarget(e)]); if ((ArcsType)maxheap.size() < K) { maxheap.push({rc, e}); } else if (rc < maxheap.top().first) { @@ -1247,8 +1518,8 @@ namespace lemon { for (ArcsType idx = 0; idx < (ArcsType)candidates.size() && tree_edges < _node_num - 1; ++idx) { ArcsType e = candidates[idx].second; - int s = _source[e]; - int t = _target[e]; + int s = arcSource(e); + int t = arcTarget(e); int rs = s, rt = t; while (uf_parent[rs] != rs) { uf_parent[rs] = uf_parent[uf_parent[rs]]; rs = uf_parent[rs]; } while (uf_parent[rt] != rt) { uf_parent[rt] = uf_parent[uf_parent[rt]]; rt = uf_parent[rt]; } @@ -1267,8 +1538,8 @@ namespace lemon { for (ArcsType e = 0; e < _arc_num && tree_edges < _node_num - 1; ++e) { if (considered[e]) continue; - int s = _source[e]; - int t = _target[e]; + int s = arcSource(e); + int t = arcTarget(e); int rs = s, rt = t; while (uf_parent[rs] != rs) { uf_parent[rs] = uf_parent[uf_parent[rs]]; rs = uf_parent[rs]; } while (uf_parent[rt] != rt) { uf_parent[rt] = uf_parent[uf_parent[rt]]; rt = uf_parent[rt]; } @@ -1285,8 +1556,8 @@ namespace lemon { std::vector tree_adj_deg(_node_num, 0); for (int k = 0; k < tree_edges; ++k) { ArcsType e = tree_arcs[k]; - tree_adj_deg[_source[e]]++; - tree_adj_deg[_target[e]]++; + tree_adj_deg[arcSource(e)]++; + tree_adj_deg[arcTarget(e)]++; } std::vector tree_adj_start(_node_num + 1, 0); for (int i = 0; i < _node_num; ++i) { @@ -1298,7 +1569,7 @@ namespace lemon { std::vector tree_adj_pos(_node_num, 0); for (int k = 0; k < tree_edges; ++k) { ArcsType e = tree_arcs[k]; - int s = _source[e], t = _target[e]; + int s = arcSource(e), t = arcTarget(e); int ps = tree_adj_start[s] + tree_adj_pos[s]++; tree_adj_node[ps] = t; tree_adj_arc[ps] = e; @@ -1313,17 +1584,15 @@ namespace lemon { _root = _node_num; for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { - _state[e] = STATE_TREE; + setArcState(e, STATE_TREE); if (_supply[u] >= 0) { - _source[e] = u; - _target[e] = _root; - _cost[e] = 0; - _flow[e] = _supply[u]; + setArcEndpoints(e, u, _root); + setArcCost(e, 0); + setArcFlow(e, _supply[u]); } else { - _source[e] = _root; - _target[e] = u; - _cost[e] = ART_COST; - _flow[e] = -_supply[u]; + setArcEndpoints(e, _root, u); + setArcCost(e, ART_COST); + setArcFlow(e, -_supply[u]); } } @@ -1344,7 +1613,7 @@ namespace lemon { _parent[u] = _root; _pred[u] = _arc_num + u; _forward[u] = (_supply[u] >= 0); // same as init() - _state[_arc_num + u] = STATE_TREE; + setArcState(_arc_num + u, STATE_TREE); visited[u] = true; std::queue bfs_queue; @@ -1360,11 +1629,11 @@ namespace lemon { _parent[w] = v; _pred[w] = arc_e; - _state[arc_e] = STATE_TREE; - _forward[w] = (_source[arc_e] == w); + setArcState(arc_e, STATE_TREE); + _forward[w] = (arcSource(arc_e) == w); - _state[_arc_num + w] = STATE_LOWER; - _flow[_arc_num + w] = 0; + setArcState(_arc_num + w, STATE_LOWER); + setArcFlow(_arc_num + w, 0); bfs_queue.push(w); } @@ -1441,25 +1710,25 @@ namespace lemon { Value f = _forward[u] ? net[u] : -net[u]; if (f >= 0) { - _flow[e] = f; + setArcFlow(e, f); net[_parent[u]] += net[u]; } else { if (e < _arc_num) { - _state[e] = STATE_LOWER; - _flow[e] = 0; + setArcState(e, STATE_LOWER); + setArcFlow(e, 0); } // Reconnect u to root via artificial arc ArcsType art_e = _arc_num + u; _parent[u] = _root; _pred[u] = art_e; - _forward[u] = (_source[art_e] == u); - _state[art_e] = STATE_TREE; + _forward[u] = (arcSource(art_e) == u); + setArcState(art_e, STATE_TREE); Value art_f = _forward[u] ? net[u] : -net[u]; - _flow[art_e] = art_f >= 0 ? art_f : -art_f; + setArcFlow(art_e, art_f >= 0 ? art_f : -art_f); if (art_f < 0) { _forward[u] = !_forward[u]; - _flow[art_e] = -art_f; + setArcFlow(art_e, -art_f); } net[_root] += net[u]; @@ -1544,33 +1813,11 @@ namespace lemon { if (std::numeric_limits::is_exact) { ART_COST = std::numeric_limits::max() / 2 + 1; } else { - // Prefer precomputed max to avoid rescans on repeated runs - Cost max_cost = 0; - if (_has_max_cost) { - max_cost = _max_cost; - } else if (_dense_cost) { - max_cost = *_D_ptr; - for (ArcsType i = 1; i != _arc_num; ++i) { - if (_D_ptr[i] > max_cost) max_cost = _D_ptr[i]; - } - } else if (!_lazy_cost) { - max_cost = _cost[0]; - for (ArcsType i = 1; i != _arc_num; ++i) { - if (_cost[i] > max_cost) max_cost = _cost[i]; - } - } else { - // Lazy cost: fall back to on-the-fly computation - for (ArcsType i = 0; i != _arc_num; ++i) { - Cost c = getCostForArc(i); - if (c > max_cost) max_cost = c; - } - } + Cost max_cost = maxRealArcCost(); ART_COST = (max_cost + 1) * _node_num; - _max_cost = max_cost; - _has_max_cost = true; } - memset(&_state[0], STATE_LOWER, _arc_num); + fillArcStates(_arc_num, STATE_LOWER); // Set data for the artificial root node _root = _node_num; @@ -1595,21 +1842,19 @@ namespace lemon { _rev_thread[u + 1] = u; _succ_num[u] = 1; _last_succ[u] = u; - _state[e] = STATE_TREE; + setArcState(e, STATE_TREE); if (_supply[u] >= 0) { _forward[u] = true; _pi[u] = 0; - _source[e] = u; - _target[e] = _root; - _flow[e] = _supply[u]; - _cost[e] = 0; + setArcEndpoints(e, u, _root); + setArcFlow(e, _supply[u]); + setArcCost(e, 0); } else { _forward[u] = false; _pi[u] = ART_COST; - _source[e] = _root; - _target[e] = u; - _flow[e] = -_supply[u]; - _cost[e] = ART_COST; + setArcEndpoints(e, _root, u); + setArcFlow(e, -_supply[u]); + setArcCost(e, ART_COST); } } } @@ -1627,25 +1872,22 @@ namespace lemon { _forward[u] = true; _pi[u] = 0; _pred[u] = e; - _source[e] = u; - _target[e] = _root; - _flow[e] = _supply[u]; - _cost[e] = 0; - _state[e] = STATE_TREE; + setArcEndpoints(e, u, _root); + setArcFlow(e, _supply[u]); + setArcCost(e, 0); + setArcState(e, STATE_TREE); } else { _forward[u] = false; _pi[u] = ART_COST; _pred[u] = f; - _source[f] = _root; - _target[f] = u; - _flow[f] = -_supply[u]; - _cost[f] = ART_COST; - _state[f] = STATE_TREE; - _source[e] = u; - _target[e] = _root; - //_flow[e] = 0; //by default, the sparse matrix is empty - _cost[e] = 0; - _state[e] = STATE_LOWER; + setArcEndpoints(f, _root, u); + setArcFlow(f, -_supply[u]); + setArcCost(f, ART_COST); + setArcState(f, STATE_TREE); + setArcEndpoints(e, u, _root); + // Flow is zero by default. + setArcCost(e, 0); + setArcState(e, STATE_LOWER); ++f; } } @@ -1665,25 +1907,22 @@ namespace lemon { _forward[u] = false; _pi[u] = 0; _pred[u] = e; - _source[e] = _root; - _target[e] = u; - _flow[e] = -_supply[u]; - _cost[e] = 0; - _state[e] = STATE_TREE; + setArcEndpoints(e, _root, u); + setArcFlow(e, -_supply[u]); + setArcCost(e, 0); + setArcState(e, STATE_TREE); } else { _forward[u] = true; _pi[u] = -ART_COST; _pred[u] = f; - _source[f] = u; - _target[f] = _root; - _flow[f] = _supply[u]; - _state[f] = STATE_TREE; - _cost[f] = ART_COST; - _source[e] = _root; - _target[e] = u; - //_flow[e] = 0; - _cost[e] = 0; - _state[e] = STATE_LOWER; + setArcEndpoints(f, u, _root); + setArcFlow(f, _supply[u]); + setArcState(f, STATE_TREE); + setArcCost(f, ART_COST); + setArcEndpoints(e, _root, u); + // Flow is zero by default. + setArcCost(e, 0); + setArcState(e, STATE_LOWER); ++f; } } @@ -1695,8 +1934,8 @@ namespace lemon { // Find the join node void findJoinNode() { - int u = _source[in_arc]; - int v = _target[in_arc]; + int u = arcSource(in_arc); + int v = arcTarget(in_arc); while (u != v) { if (_succ_num[u] < _succ_num[v]) { u = _parent[u]; @@ -1712,12 +1951,12 @@ namespace lemon { bool findLeavingArc() { // Initialize first and second nodes according to the direction // of the cycle - if (_state[in_arc] == STATE_LOWER) { - first = _source[in_arc]; - second = _target[in_arc]; + if (arcState(in_arc) == STATE_LOWER) { + first = arcSource(in_arc); + second = arcTarget(in_arc); } else { - first = _target[in_arc]; - second = _source[in_arc]; + first = arcTarget(in_arc); + second = arcSource(in_arc); } delta = INF; char result = 0; @@ -1727,7 +1966,7 @@ namespace lemon { // Search the cycle along the path form the first node to the root for (int u = first; u != join; u = _parent[u]) { e = _pred[u]; - d = _forward[u] ? _flow[e] : INF ; + d = _forward[u] ? arcFlow(e) : INF ; if (d < delta) { delta = d; u_out = u; @@ -1737,7 +1976,7 @@ namespace lemon { // Search the cycle along the path form the second node to the root for (int u = second; u != join; u = _parent[u]) { e = _pred[u]; - d = _forward[u] ? INF : _flow[e]; + d = _forward[u] ? INF : arcFlow(e); if (d <= delta) { delta = d; u_out = u; @@ -1755,26 +1994,28 @@ namespace lemon { return result != 0; } - // Change _flow and _state vectors + // Change flow and state vectors void changeFlow(bool change) { // Augment along the cycle if (delta > 0) { - Value val = _state[in_arc] * delta; - _flow[in_arc] += val; - for (int u = _source[in_arc]; u != join; u = _parent[u]) { - _flow[_pred[u]] += _forward[u] ? -val : val; + Value val = arcState(in_arc) * delta; + addArcFlow(in_arc, val); + for (int u = arcSource(in_arc); u != join; u = _parent[u]) { + addArcFlow(_pred[u], _forward[u] ? -val : val); } - for (int u = _target[in_arc]; u != join; u = _parent[u]) { - _flow[_pred[u]] += _forward[u] ? val : -val; + for (int u = arcTarget(in_arc); u != join; u = _parent[u]) { + addArcFlow(_pred[u], _forward[u] ? val : -val); } } // Update the state of the entering and leaving arcs if (change) { - _state[in_arc] = STATE_TREE; - _state[_pred[u_out]] = - (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER; + setArcState(in_arc, STATE_TREE); + setArcState( + _pred[u_out], + (arcFlow(_pred[u_out]) == 0) ? STATE_LOWER : STATE_UPPER + ); } else { - _state[in_arc] = -_state[in_arc]; + flipArcState(in_arc); } } @@ -1857,7 +2098,7 @@ namespace lemon { u = w; } _pred[u_in] = in_arc; - _forward[u_in] = (u_in == _source[in_arc]); + _forward[u_in] = (u_in == arcSource(in_arc)); _succ_num[u_in] = old_succ_num; // Set limits for updating _last_succ form v_in and v_out @@ -1995,8 +2236,8 @@ namespace lemon { for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... - if (_state[in_arc] * (getCostForArc(in_arc) + _pi[_source[in_arc]] - - _pi[_target[in_arc]]) >= 0) continue; + if (arcState(in_arc) * (getCostForArc(in_arc) + _pi[arcSource(in_arc)] - + _pi[arcTarget(in_arc)]) >= 0) continue; findJoinNode(); bool change = findLeavingArc(); if (delta >= MAX) return false; @@ -2048,11 +2289,11 @@ namespace lemon { // Check feasibility if( retVal == OPTIMAL){ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { - if (_flow[e] != 0){ - if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 + if (arcFlow(e) != 0){ + if (fabs(arcFlow(e)) > _EPSILON) // change of the original code following issue #126 return INFEASIBLE; else - _flow[e]=0; + setArcFlow(e, 0); } } diff --git a/ot/solvers.py b/ot/solvers.py index 25f3bd32f..88cf5c7ab 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1763,6 +1763,7 @@ def solve_sample( and X_b is not None ): # Use lazy EMD solver with coordinates (no regularization, balanced) + nx = get_backend(X_a, X_b, a, b) value_linear, log = emd2_lazy( X_a, X_b, @@ -1779,7 +1780,8 @@ def solve_sample( potentials=(log["u"], log["v"]), value=value_linear, value_linear=value_linear, - plan=log["G"], + sparse_plan=log["G"], + backend=nx, status=log["warning"] if log["warning"] is not None else "Converged", ) diff --git a/ot/utils.py b/ot/utils.py index 64bf1ace9..95cded2ee 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1192,8 +1192,12 @@ def plan(self): """Transport plan, encoded as a dense array.""" # N.B.: We may catch out-of-memory errors and suggest # the use of lazy_plan or sparse_plan when appropriate. - - return self._plan + if self._plan is not None: + return self._plan + elif self._sparse_plan is not None: + return self._backend.todense(self._sparse_plan) + else: + return None @property def sparse_plan(self): diff --git a/test/test_ot.py b/test/test_ot.py index dba8466e0..28e19639c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -1231,6 +1231,42 @@ def test_emd_lazy_warmstart(): np.testing.assert_allclose(cost_warm_emd2, log_cold["cost"], rtol=1e-7) +def test_emd2_lazy_returns_sparse_plan(): + from scipy.sparse import issparse + from ot.lp import emd2_lazy + + n_s = 12 + n_t = 15 + rng = np.random.RandomState(42) + + X_s = rng.randn(n_s, 2) + X_t = rng.randn(n_t, 2) + a = ot.utils.unif(n_s) + b = ot.utils.unif(n_t) + M = ot.dist(X_s, X_t, metric="sqeuclidean") + + cost_dense = ot.emd2(a, b, M) + cost_lazy, log_lazy = emd2_lazy( + X_s, + X_t, + a, + b, + metric="sqeuclidean", + log=True, + return_matrix=True, + ) + G_lazy = log_lazy["G"] + + assert issparse(G_lazy) + assert G_lazy.shape == (n_s, n_t) + assert G_lazy.nnz <= n_s + n_t + np.testing.assert_allclose(cost_lazy, cost_dense, rtol=1e-10, atol=1e-10) + + G_lazy_dense = G_lazy.toarray() + np.testing.assert_allclose(G_lazy_dense.sum(axis=1), a, rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(G_lazy_dense.sum(axis=0), b, rtol=1e-6, atol=1e-8) + + def test_emd_sparse_warmstart(): n = 100 rng = np.random.RandomState(42)