diff --git a/GRPC_INTERFACE.md b/GRPC_INTERFACE.md new file mode 100644 index 0000000000..cdcffead97 --- /dev/null +++ b/GRPC_INTERFACE.md @@ -0,0 +1,392 @@ +# gRPC Interface Architecture + +## Overview + +The cuOpt remote execution system uses gRPC for client-server communication. The interface +supports arbitrarily large optimization problems (multi-GB) through a chunked array transfer +protocol that uses only unary (request-response) RPCs — no bidirectional streaming. + +All client-server serialization uses protocol buffers generated by `protoc` and +`grpc_cpp_plugin`. The internal server-to-worker pipe uses protobuf for metadata +headers and raw byte transfer for bulk array data (see Security Notes). + +## Directory Layout + +All gRPC-related C++ source lives under a single tree: + +``` +cpp/src/grpc/ +├── cuopt_remote.proto # Base protobuf messages (job status, settings, etc.) +├── cuopt_remote_service.proto # Service definition + messages (SubmitJob, ChunkedUpload, Incumbent, etc.) +├── grpc_problem_mapper.{hpp,cpp} # CPU problem ↔ proto (incl. chunked header) +├── grpc_solution_mapper.{hpp,cpp} # LP/MIP solution ↔ proto (unary + chunked) +├── grpc_settings_mapper.{hpp,cpp} # PDLP/MIP settings ↔ proto +├── grpc_service_mapper.{hpp,cpp} # Request/response builders (status, cancel, stream logs, etc.) 
+├── client/ +│ ├── grpc_client.{hpp,cpp} # High-level client: connect, submit, poll, get result +│ └── solve_remote.cpp # solve_lp_remote / solve_mip_remote (uses grpc_client) +└── server/ + ├── grpc_server_main.cpp # main(), argument parsing, gRPC server setup + ├── grpc_service_impl.cpp # CuOptRemoteServiceImpl — all RPC handlers + ├── grpc_server_types.hpp # Shared types, globals, forward declarations + ├── grpc_field_element_size.hpp # ArrayFieldId → element byte size (codegen target) + ├── grpc_pipe_serialization.hpp # Pipe I/O: protobuf headers + raw byte arrays (request/result) + ├── grpc_incumbent_proto.hpp # Incumbent proto build/parse (codegen target) + ├── grpc_worker.cpp # worker_process(), incumbent callback, store_simple_result + ├── grpc_worker_infra.cpp # Pipes, spawn, wait_for_workers, mark_worker_jobs_failed + ├── grpc_server_threads.cpp # result_retrieval, incumbent_retrieval, session_reaper + └── grpc_job_management.cpp # Pipe I/O, submit_job_async, check_status, cancel, etc. +``` + +- **Protos**: Live in `cpp/src/grpc/`. CMake generates C++ in the build dir (`cuopt_remote.pb.h`, `cuopt_remote_service.pb.h`, `cuopt_remote_service.grpc.pb.h`). +- **Mappers**: Shared by client and server; convert between host C++ types and protobuf. Used for unary and chunked paths. +- **Client**: Solver-level utility (not public API). Used by `solve_lp_remote`/`solve_mip_remote` and tests. +- **Server**: Standalone executable `cuopt_grpc_server`. See `GRPC_SERVER_ARCHITECTURE.md` for process model and file roles. + +## Protocol Files + +| File | Purpose | +|------|---------| +| `cpp/src/grpc/cuopt_remote.proto` | Message definitions (problems, settings, solutions, field IDs) | +| `cpp/src/grpc/cuopt_remote_service.proto` | gRPC service definition (RPCs) | + +Generated code is placed in the CMake build directory (not checked into source). 
+ +## Service Interface + +```protobuf +service CuOptRemoteService { + // Job submission (small problems, single message) + rpc SubmitJob(SubmitJobRequest) returns (SubmitJobResponse); + + // Chunked upload (large problems, multiple unary RPCs) + rpc StartChunkedUpload(StartChunkedUploadRequest) returns (StartChunkedUploadResponse); + rpc SendArrayChunk(SendArrayChunkRequest) returns (SendArrayChunkResponse); + rpc FinishChunkedUpload(FinishChunkedUploadRequest) returns (SubmitJobResponse); + + // Job management + rpc CheckStatus(StatusRequest) returns (StatusResponse); + rpc CancelJob(CancelRequest) returns (CancelResponse); + rpc DeleteResult(DeleteRequest) returns (DeleteResponse); + + // Result retrieval (small results, single message) + rpc GetResult(GetResultRequest) returns (ResultResponse); + + // Chunked download (large results, multiple unary RPCs) + rpc StartChunkedDownload(StartChunkedDownloadRequest) returns (StartChunkedDownloadResponse); + rpc GetResultChunk(GetResultChunkRequest) returns (GetResultChunkResponse); + rpc FinishChunkedDownload(FinishChunkedDownloadRequest) returns (FinishChunkedDownloadResponse); + + // Blocking wait (returns status only, use GetResult afterward) + rpc WaitForCompletion(WaitRequest) returns (WaitResponse); + + // Real-time streaming + rpc StreamLogs(StreamLogsRequest) returns (stream LogMessage); + rpc GetIncumbents(IncumbentRequest) returns (IncumbentResponse); +} +``` + +## Chunked Array Transfer Protocol + +### Why Chunking? + +gRPC has per-message size limits (configurable, default set to 256 MiB in cuOpt), and +protobuf has a hard 2 GB serialization limit. Optimization problems and their solutions +can exceed several gigabytes, so a chunked transfer mechanism is needed. + +The protocol uses only **unary RPCs** (no bidirectional streaming), which simplifies +error handling, load balancing, and proxy compatibility. 
+ +### Upload Protocol (Large Problems) + +When the estimated serialized problem size exceeds 75% of `max_message_bytes`, the client +splits large arrays into chunks and sends them via multiple unary RPCs: + +``` +Client Server + | | + |-- StartChunkedUpload(header, settings) -----> | + |<-- upload_id, max_message_bytes -------------- | + | | + |-- SendArrayChunk(upload_id, field, data) ----> | + |<-- ok ---------------------------------------- | + | | + |-- SendArrayChunk(upload_id, field, data) ----> | + |<-- ok ---------------------------------------- | + | ... | + | | + |-- FinishChunkedUpload(upload_id) ------------> | + |<-- job_id ------------------------------------ | +``` + +**Key features:** +- `StartChunkedUpload` sends a `ChunkedProblemHeader` with all scalar fields and + array metadata (`ArrayDescriptor` for each large array: field ID, total elements, + element size) +- Each `SendArrayChunk` carries one chunk of one array, identified by `ArrayFieldId` + and `element_offset` +- The server reports `max_message_bytes` so the client can adapt chunk sizing +- `FinishChunkedUpload` triggers server-side reassembly and job submission + +### Download Protocol (Large Results) + +When the result exceeds the gRPC max message size, the client fetches it via +chunked unary RPCs (mirrors the upload pattern): + +``` +Client Server + | | + |-- StartChunkedDownload(job_id) --------------> | + |<-- download_id, ChunkedResultHeader ---------- | + | | + |-- GetResultChunk(download_id, field, off) ----> | + |<-- data bytes --------------------------------- | + | | + |-- GetResultChunk(download_id, field, off) ----> | + |<-- data bytes --------------------------------- | + | ... 
| + | | + |-- FinishChunkedDownload(download_id) ---------> | + |<-- ok ----------------------------------------- | +``` + +**Key features:** +- `ChunkedResultHeader` carries all scalar fields (termination status, objectives, + residuals, solve time, warm start scalars) plus `ResultArrayDescriptor` entries + for each array (solution vectors, warm start arrays) +- Each `GetResultChunk` fetches a slice of one array, identified by `ResultFieldId` + and `element_offset` +- `FinishChunkedDownload` releases the server-side download session state +- LP results include PDLP warm start data (9 arrays + 8 scalars) for subsequent + warm-started solves + +### Automatic Routing + +The client handles size-based routing transparently: + +1. **Upload**: Estimate serialized problem size + - Below 75% of `max_message_bytes` → unary `SubmitJob` + - Above threshold → `StartChunkedUpload` + `SendArrayChunk` + `FinishChunkedUpload` +2. **Download**: Check `result_size_bytes` from `CheckStatus` + - Below `max_message_bytes` → unary `GetResult` + - Above limit (or `RESOURCE_EXHAUSTED`) → chunked download RPCs + +## Error Handling + +### gRPC Status Codes + +| Code | Meaning | Client Action | +|------|---------|---------------| +| `OK` | Success | Process result | +| `NOT_FOUND` | Job ID not found | Check job ID | +| `RESOURCE_EXHAUSTED` | Message too large | Use chunked transfer | +| `CANCELLED` | Job was cancelled | Handle gracefully | +| `DEADLINE_EXCEEDED` | Timeout | Retry or increase timeout | +| `UNAVAILABLE` | Server not reachable | Retry with backoff | +| `INTERNAL` | Server error | Report to user | +| `INVALID_ARGUMENT` | Bad request | Fix request | + +### Connection Handling + +- Client detects `context->IsCancelled()` for graceful disconnect +- Server cleans up job state on client disconnect during upload +- Automatic reconnection is NOT built-in (caller should retry) + +## Completion Strategy + +The `solve_lp` and `solve_mip` methods poll `CheckStatus` every `poll_interval_ms` 
+until the job reaches a terminal state (COMPLETED/FAILED/CANCELLED) or `timeout_seconds`
+is exceeded. During polling, MIP incumbent callbacks are invoked on the main thread.
+
+The `WaitForCompletion` RPC is available as a public async API primitive for callers
+managing jobs directly, but it is not used by the convenience `solve_*` methods because
+polling provides timeout protection and enables incumbent callbacks.
+
+## Client API (`grpc_client_t`)
+
+### Configuration
+
+```cpp
+struct grpc_client_config_t {
+  std::string server_address = "localhost:8765";
+  int poll_interval_ms = 1000;
+  int timeout_seconds = 3600;   // Max wait for job completion (1 hour)
+  bool stream_logs = false;     // Stream solver logs from server
+
+  // Callbacks
+  std::function<void(const std::string&)> log_callback;
+  std::function<void(const std::string&)> debug_log_callback;  // Internal client debug messages
+  std::function<bool(int64_t, double, const std::vector<double>&)> incumbent_callback;
+  int incumbent_poll_interval_ms = 1000;
+
+  // TLS configuration
+  bool enable_tls = false;
+  std::string tls_root_certs;   // CA certificate (PEM)
+  std::string tls_client_cert;  // Client certificate (mTLS)
+  std::string tls_client_key;   // Client private key (mTLS)
+
+  // Transfer configuration
+  int64_t max_message_bytes = 256 * 1024 * 1024;  // 256 MiB
+  int64_t chunk_size_bytes = 16 * 1024 * 1024;    // 16 MiB per chunk
+  // Chunked upload threshold is computed as 75% of max_message_bytes.
+ bool enable_transfer_hash = false; // FNV-1a hash logging +}; +``` + +### Synchronous Operations + +```cpp +// Blocking solve — handles chunked transfer automatically +auto result = client.solve_lp(problem, settings); +auto result = client.solve_mip(problem, settings, enable_incumbents); +``` + +### Asynchronous Operations + +```cpp +// Submit and get job ID +auto submit = client.submit_lp(problem, settings); +std::string job_id = submit.job_id; + +// Poll for status +auto status = client.check_status(job_id); + +// Get result when ready +auto result = client.get_lp_result(job_id); + +// Cancel or delete +client.cancel_job(job_id); +client.delete_job(job_id); +``` + +### Real-Time Streaming + +```cpp +// Log streaming (callback-based) +client.stream_logs(job_id, 0, [](const std::string& line, bool done) { + std::cout << line; + return true; // continue streaming +}); + +// Incumbent polling (during MIP solve) +config.incumbent_callback = [](int64_t idx, double obj, const auto& sol) { + std::cout << "Incumbent " << idx << ": " << obj << "\n"; + return true; // return false to cancel solve +}; +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `CUOPT_REMOTE_HOST` | `localhost` | Server hostname for remote solves | +| `CUOPT_REMOTE_PORT` | `8765` | Server port for remote solves | +| `CUOPT_CHUNK_SIZE` | 16 MiB | Override `chunk_size_bytes` | +| `CUOPT_MAX_MESSAGE_BYTES` | 256 MiB | Override `max_message_bytes` | +| `CUOPT_GRPC_DEBUG` | `0` | Enable client debug/throughput logging (`0` or `1`) | +| `CUOPT_TLS_ENABLED` | `0` | Enable TLS for client connections (`0` or `1`) | +| `CUOPT_TLS_ROOT_CERT` | *(none)* | Path to PEM root CA file (server verification) | +| `CUOPT_TLS_CLIENT_CERT` | *(none)* | Path to PEM client certificate file (for mTLS) | +| `CUOPT_TLS_CLIENT_KEY` | *(none)* | Path to PEM client private key file (for mTLS) | + +## TLS Configuration + +### Server-Side TLS + +```bash 
+./cuopt_grpc_server --port 8765 \ + --tls \ + --tls-cert server.crt \ + --tls-key server.key +``` + +### Mutual TLS (mTLS) + +Server requires client certificate: + +```bash +./cuopt_grpc_server --port 8765 \ + --tls \ + --tls-cert server.crt \ + --tls-key server.key \ + --tls-root ca.crt \ + --require-client-cert +``` + +Client provides certificate via environment variables (applies to Python, `cuopt_cli`, and C API): + +```bash +export CUOPT_TLS_ENABLED=1 +export CUOPT_TLS_ROOT_CERT=ca.crt +export CUOPT_TLS_CLIENT_CERT=client.crt +export CUOPT_TLS_CLIENT_KEY=client.key +``` + +Or programmatically via `grpc_client_config_t`: + +```cpp +config.enable_tls = true; +config.tls_root_certs = read_file("ca.crt"); +config.tls_client_cert = read_file("client.crt"); +config.tls_client_key = read_file("client.key"); +``` + +## Message Size Limits + +| Configuration | Default | Notes | +|---------------|---------|-------| +| Server `--max-message-mb` | 256 MiB | Per-message limit (also `--max-message-bytes` for exact byte values) | +| Server clamping | [4 KiB, ~2 GiB] | Enforced at startup to stay within protobuf's serialization limit | +| Client `max_message_bytes` | 256 MiB | Clamped to [4 MiB, ~2 GiB] at construction | +| Chunk size | 16 MiB | Payload per `SendArrayChunk`/`GetResultChunk` | +| Chunked threshold | 75% of max_message_bytes | Problems above this use chunked upload (e.g. 192 MiB when max is 256 MiB) | + +Chunked transfer allows unlimited total payload size; only individual +chunks must fit within the per-message limit. Neither client nor server +allows "unlimited" message size — both clamp to the protobuf 2 GiB ceiling. + +## Security Notes + +1. **gRPC Layer**: All client-server message parsing uses protobuf-generated code +2. **Internal Pipe**: The server-to-worker pipe uses protobuf for metadata headers + and length-prefixed raw `read()`/`write()` for bulk array data. 
This pipe is + internal to the server process (main → forked worker) and not exposed to clients. +3. **Standard gRPC Security**: HTTP/2 framing, flow control, standard status codes +4. **TLS Support**: Optional encryption with mutual authentication +5. **Input Validation**: Server validates all incoming gRPC messages before processing + +## Data Flow Summary + +``` +┌─────────┐ ┌─────────────┐ +│ Client │ │ Server │ +│ │ SubmitJob (small) │ │ +│ problem ├───────────────────────────────────►│ deserialize │ +│ │ -or- Chunked Upload (large) │ ↓ │ +│ │ │ worker │ +│ │ │ process │ +│ │ GetResult (small) │ ↓ │ +│ solution│◄───────────────────────────────────┤ serialize │ +│ │ -or- Chunked Download (large) │ │ +└─────────┘ └─────────────┘ +``` + +See `GRPC_SERVER_ARCHITECTURE.md` for details on internal server architecture. + +## Code Generation + +The `cpp/codegen` directory (optional) generates conversion snippets from `field_registry.yaml`. Targets include: + +- **Settings**: PDLP/MIP settings ↔ proto (replacing hand-written blocks in the settings mapper). +- **Result header/scalars/arrays**: ChunkedResultHeader and array field handling. +- **Field element size**: `grpc_field_element_size.hpp` (ArrayFieldId → byte size). +- **Incumbent**: `grpc_incumbent_proto.hpp` (build/parse `Incumbent` messages). + +Adding or changing a proto field can be done via YAML and regenerate instead of editing mapper code by hand. + +## Build + +- **libcuopt**: Includes the mapper `.cpp` files, `grpc_client.cpp`, and `solve_remote.cpp`. Requires `CUOPT_ENABLE_GRPC`, gRPC, and protobuf. Proto generation is done by CMake custom commands that depend on the `.proto` files in `cpp/src/grpc/`. +- **cuopt_grpc_server**: Executable built from `cpp/src/grpc/server/*.cpp`; links libcuopt, gRPC, protobuf. + +Tests that use the client (e.g. `grpc_client_test.cpp`, `grpc_integration_test.cpp`) get `cpp/src/grpc` and `cpp/src/grpc/client` in their include path. 
diff --git a/GRPC_QUICK_START.md b/GRPC_QUICK_START.md new file mode 100644 index 0000000000..a3864c101e --- /dev/null +++ b/GRPC_QUICK_START.md @@ -0,0 +1,248 @@ +# cuOpt gRPC Remote Execution Quick Start + +This guide shows how to start the cuOpt gRPC server and solve +optimization problems remotely from Python, `cuopt_cli`, or the C API. + +All three interfaces use the same environment variables for remote +configuration. Once the env vars are set, your code works exactly the +same as a local solve — no API changes required. + +## Prerequisites + +- A host with an NVIDIA GPU and cuOpt installed (server side). +- cuOpt client libraries installed on the client host (can be CPU-only). +- `cuopt_grpc_server` binary available (ships with the cuOpt package). + +## 1. Start the Server + +### Basic (no TLS) + +```bash +cuopt_grpc_server --port 8765 --workers 1 +``` + +### TLS (server authentication) + +```bash +cuopt_grpc_server --port 8765 \ + --tls \ + --tls-cert server.crt \ + --tls-key server.key +``` + +### mTLS (mutual authentication) + +```bash +cuopt_grpc_server --port 8765 \ + --tls \ + --tls-cert server.crt \ + --tls-key server.key \ + --tls-root ca.crt \ + --require-client-cert +``` + +See `GRPC_SERVER_ARCHITECTURE.md` for the full set of server flags. + +### How mTLS Works + +With mTLS the server verifies every client, and the client verifies the +server. The trust model is based on Certificate Authorities (CAs), not +individual certificates: + +- **`--tls-root ca.crt`** tells the server which CA to trust. Any client + presenting a certificate signed by this CA is accepted. The server + never sees or stores individual client certificates. +- **`--require-client-cert`** makes client verification mandatory. Without + it the server requests a client cert but still allows unauthenticated + connections. +- On the client side, `CUOPT_TLS_ROOT_CERT` is the CA that signed the + *server* certificate, so the client can verify the server's identity. 
+
+### Restricting Access with a Custom CA
+
+To limit which clients can reach your server, create a private CA and
+only issue client certificates to authorized users. Anyone without a
+certificate signed by your CA is rejected at the TLS handshake before
+any solver traffic is exchanged.
+
+**1. Create a private CA (one-time setup):**
+
+```bash
+# Generate CA private key
+openssl genrsa -out ca.key 4096
+
+# Generate self-signed CA certificate (valid 10 years)
+openssl req -new -x509 -key ca.key -sha256 -days 3650 \
+  -subj "/CN=cuopt-internal-ca" -out ca.crt
+```
+
+**2. Issue a client certificate:**
+
+```bash
+# Generate client key
+openssl genrsa -out client.key 2048
+
+# Create a certificate signing request
+openssl req -new -key client.key \
+  -subj "/CN=team-member-alice" -out client.csr
+
+# Sign with your CA
+openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key \
+  -CAcreateserial -days 365 -sha256 -out client.crt
+```
+
+Repeat step 2 for each authorized client. Keep `ca.key` private;
+distribute only `ca.crt` (to the server) and the per-client
+`client.crt` + `client.key` pairs.
+
+**3. Issue a server certificate (signed by the same CA):**
+
+```bash
+# Generate server key
+openssl genrsa -out server.key 2048
+
+# Create CSR with subjectAltName matching the hostname clients will use
+openssl req -new -key server.key \
+  -subj "/CN=server.example.com" -out server.csr
+
+# Write a SAN extension file (DNS and/or IP must match client's target)
+cat > server.ext <<EOF
+subjectAltName = DNS:server.example.com
+EOF
+
+# Sign with your CA, applying the SAN extension
+openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key \
+  -CAcreateserial -days 365 -sha256 \
+  -extfile server.ext -out server.crt
+```
+
+> **Note:** `server.crt` must be signed by the same CA distributed to
+> clients, and its `subjectAltName` must match the hostname or IP that
+> clients connect to. gRPC (BoringSSL) requires SAN — `CN` alone is
+> not sufficient for hostname verification.
+
+**4. Start the server with your CA:**
+
+```bash
+cuopt_grpc_server --port 8765 \
+  --tls \
+  --tls-cert server.crt \
+  --tls-key server.key \
+  --tls-root ca.crt \
+  --require-client-cert
+```
+
+**5. 
Configure an authorized client:** + +```bash +export CUOPT_REMOTE_HOST=server.example.com +export CUOPT_REMOTE_PORT=8765 +export CUOPT_TLS_ENABLED=1 +export CUOPT_TLS_ROOT_CERT=ca.crt # verifies the server +export CUOPT_TLS_CLIENT_CERT=client.crt # proves client identity +export CUOPT_TLS_CLIENT_KEY=client.key +``` + +**Revoking access:** gRPC's built-in TLS does not support Certificate +Revocation Lists (CRL) or OCSP. To revoke a client, either stop issuing +new certs from the compromised CA and rotate to a new one, or deploy a +reverse proxy (e.g., Envoy) in front of the server that supports CRL +checking. + +## 2. Configure the Client (All Interfaces) + +Set these environment variables before running any cuOpt client. +They apply identically to the Python API, `cuopt_cli`, and the C API. + +### Required + +```bash +export CUOPT_REMOTE_HOST= +export CUOPT_REMOTE_PORT=8765 +``` + +When both `CUOPT_REMOTE_HOST` and `CUOPT_REMOTE_PORT` are set, every +call to `solve_lp` / `solve_mip` is transparently forwarded to the +remote server. No code changes are needed. + +### TLS (optional) + +```bash +export CUOPT_TLS_ENABLED=1 +export CUOPT_TLS_ROOT_CERT=ca.crt # verify server certificate +``` + +For mTLS, also provide the client identity: + +```bash +export CUOPT_TLS_CLIENT_CERT=client.crt +export CUOPT_TLS_CLIENT_KEY=client.key +``` + +### Tuning (optional) + +| Variable | Default | Description | +|----------|---------|-------------| +| `CUOPT_CHUNK_SIZE` | 16 MiB | Bytes per chunk for large problem transfer | +| `CUOPT_MAX_MESSAGE_BYTES` | 256 MiB | Client-side gRPC max message size | +| `CUOPT_GRPC_DEBUG` | `0` | Enable debug / throughput logging (`1` to enable) | + +## 3. Usage Examples + +Once the env vars are set, write your solver code exactly as you would +for a local solve. The remote transport is handled automatically. 
+
+### Python
+
+```python
+import cuopt_mps_parser
+from cuopt import linear_programming
+
+# Parse an MPS file
+dm = cuopt_mps_parser.ParseMps("model.mps")
+
+# Solve (routed to remote server via env vars)
+solution = linear_programming.Solve(dm, linear_programming.SolverSettings())
+
+print("Objective:", solution.get_primal_objective())
+print("Primal:   ", solution.get_primal_solution()[:5], "...")
+```
+
+### cuopt_cli
+
+```bash
+cuopt_cli model.mps
+```
+
+With solver options:
+
+```bash
+cuopt_cli model.mps --time-limit 30 --relaxation
+```
+
+### C++ API
+
+```cpp
+#include <cuopt/linear_programming/optimization_problem.hpp>
+#include <cuopt/linear_programming/solve.hpp>
+
+// Build problem using cpu_optimization_problem_t ...
+auto solution = cuopt::linear_programming::solve_lp(cpu_problem, settings);
+```
+
+The same `solve_lp` / `solve_mip` functions automatically detect the
+`CUOPT_REMOTE_HOST` / `CUOPT_REMOTE_PORT` env vars and forward to the
+gRPC server when they are set.
+
+## Troubleshooting
+
+| Symptom | Check |
+|---------|-------|
+| Connection refused | Verify the server is running and the host/port are correct. |
+| TLS handshake failure | Ensure `CUOPT_TLS_ENABLED=1` is set and certificate paths are correct. |
+| `Cannot open TLS file: ...` | The path in the TLS env var does not exist or is not readable. |
+| Timeout on large problems | Increase the solver `time_limit` or the client `timeout_seconds`. |
+
+## Further Reading
+
+- `GRPC_INTERFACE.md` — Protocol details, chunked transfer, client config, message sizes.
+- `GRPC_SERVER_ARCHITECTURE.md` — Server process model, IPC, threads, job lifecycle.
diff --git a/GRPC_SERVER_ARCHITECTURE.md b/GRPC_SERVER_ARCHITECTURE.md new file mode 100644 index 0000000000..2d6c2c324b --- /dev/null +++ b/GRPC_SERVER_ARCHITECTURE.md @@ -0,0 +1,316 @@ +# Server Architecture + +## Overview + +The cuOpt gRPC server (`cuopt_grpc_server`) is a multi-process architecture designed for: +- **Isolation**: Each solve runs in a separate worker process for fault tolerance +- **Parallelism**: Multiple workers can process jobs concurrently +- **Large Payloads**: Handles multi-GB problems and solutions +- **Real-Time Feedback**: Log streaming and incumbent callbacks during solve + +For gRPC protocol and client API, see `GRPC_INTERFACE.md`. Server source files live under `cpp/src/grpc/server/`. + +## Process Model + +```text +┌────────────────────────────────────────────────────────────────────┐ +│ Main Server Process │ +│ │ +│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ │ +│ │ gRPC │ │ Job │ │ Background Threads │ │ +│ │ Service │ │ Tracker │ │ - Result retrieval │ │ +│ │ Handler │ │ (job status,│ │ - Incumbent retrieval │ │ +│ │ │ │ results) │ │ - Worker monitor │ │ +│ └─────────────┘ └──────────────┘ └─────────────────────────────┘ │ +│ │ ▲ │ +│ │ shared memory │ pipes │ +│ ▼ │ │ +│ ┌─────────────────────────────────────────────────────────────────┐│ +│ │ Shared Memory Queues ││ +│ │ ┌─────────────────┐ ┌─────────────────────┐ ││ +│ │ │ Job Queue │ │ Result Queue │ ││ +│ │ │ (MAX_JOBS=100) │ │ (MAX_RESULTS=100) │ ││ +│ │ └─────────────────┘ └─────────────────────┘ ││ +│ └─────────────────────────────────────────────────────────────────┘│ +└────────────────────────────────────────────────────────────────────┘ + │ ▲ + │ fork() │ + ▼ │ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Worker 0 │ │ Worker 1 │ │ Worker N │ +│ ┌───────────┐ │ │ ┌───────────┐ │ │ ┌───────────┐ │ +│ │ GPU Solve │ │ │ │ GPU Solve │ │ │ │ GPU Solve │ │ +│ └───────────┘ │ │ └───────────┘ │ │ └───────────┘ │ +│ (separate proc)│ │ 
(separate proc)│ │ (separate proc)│ +└─────────────────┘ └─────────────────┘ └─────────────────┘ +``` + +## Inter-Process Communication + +### Shared Memory Segments + +| Segment | Purpose | +|---------|---------| +| `/cuopt_job_queue` | Job metadata (ID, type, size, status) | +| `/cuopt_result_queue` | Result metadata (ID, status, size, error) | +| `/cuopt_control` | Server control flags (shutdown, worker count) | + +### Pipe Communication + +Each worker has dedicated pipes for data transfer: + +```cpp +struct WorkerPipes { + int to_worker_fd; // Main → Worker: job data (server writes) + int from_worker_fd; // Worker → Main: result data (server reads) + int worker_read_fd; // Worker end of input pipe (worker reads) + int worker_write_fd; // Worker end of output pipe (worker writes) + int incumbent_from_worker_fd; // Worker → Main: incumbent solutions (server reads) + int worker_incumbent_write_fd; // Worker end of incumbent pipe (worker writes) +}; +``` + +**Why pipes instead of shared memory for data?** +- Pipes handle backpressure naturally (blocking writes) +- No need to manage large shared memory segments +- Works well with streaming uploads (data flows through) + +### Source File Roles + +All paths below are under `cpp/src/grpc/server/`. + +| File | Role | +|------|------| +| `grpc_server_main.cpp` | `main()`, `print_usage()`, argument parsing, shared-memory init, gRPC server run/stop. | +| `grpc_service_impl.cpp` | `CuOptRemoteServiceImpl`: all 14 RPC handlers (SubmitJob, CheckStatus, GetResult, chunked upload/download, StreamLogs, GetIncumbents, CancelJob, DeleteResult, WaitForCompletion, Status probe). Uses mappers and job_management to enqueue jobs and trigger pipe I/O. | +| `grpc_server_types.hpp` | Shared structs (e.g. `JobQueueEntry`, `ResultQueueEntry`, `ServerConfig`, `JobInfo`), enums, globals (atomics, mutexes, condition variables), and forward declarations used across server .cpp files. 
| +| `grpc_field_element_size.hpp` | Maps `cuopt::remote::ArrayFieldId` to element byte size; used by pipe deserialization and chunked logic. | +| `grpc_pipe_serialization.hpp` | Streaming pipe I/O: write/read individual length-prefixed protobuf messages (ChunkedProblemHeader, ChunkedResultHeader, ArrayChunk) directly to/from pipe fds. Avoids large intermediate buffers. Also serializes SubmitJobRequest for unary pipe transfer. | +| `grpc_incumbent_proto.hpp` | Build `Incumbent` proto from (job_id, objective, assignment) and parse it back; used by worker when pushing incumbents and by main when reading from the incumbent pipe. | +| `grpc_worker.cpp` | `worker_process(worker_index)`: loop over job queue, receive job data via pipe (unary or chunked), call solver, send result (and optionally incumbents) back. Contains `IncumbentPipeCallback` and `store_simple_result`. | +| `grpc_worker_infra.cpp` | Pipe creation/teardown, `spawn_worker` / `spawn_workers`, `wait_for_workers`, `mark_worker_jobs_failed`, `cleanup_shared_memory`. | +| `grpc_server_threads.cpp` | `worker_monitor_thread`, `result_retrieval_thread`, `incumbent_retrieval_thread`, `session_reaper_thread`. | +| `grpc_job_management.cpp` | Low-level pipe read/write, `send_job_data_pipe` / `recv_job_data_pipe`, `submit_job_async`, `check_job_status`, `cancel_job`, `generate_job_id`, log-dir helpers. | + +### Large Payload Handling + +For large problems uploaded via chunked gRPC RPCs: + +1. Server holds chunked upload state in memory (`ChunkedUploadState`: header + array chunks per `upload_id`). +2. When `FinishChunkedUpload` is called, the header and chunks are stored in `pending_chunked_data`. The data dispatch thread streams them directly to the worker pipe as individual length-prefixed protobuf messages — no intermediate blob is created. +3. 
Worker reads the streamed messages from the pipe, reassembles arrays, runs the solver, and writes the result (and optionally incumbents) back via pipes using the same streaming format. +4. Main process result-retrieval thread reads the streamed result messages from the pipe and stores the result for `GetResult` or chunked download. + +This streaming approach avoids creating a single large buffer, eliminating the 2 GiB protobuf serialization limit for pipe transfers and reducing peak memory usage. Each individual protobuf message (max 64 MiB) is serialized with standard `SerializeToArray`/`ParseFromArray`. + +No disk spooling: chunked data is kept in memory in the main process until forwarded to the worker. + +## Job Lifecycle + +### 1. Submission + +```text +Client Server Worker + │ │ │ + │─── SubmitJob ──────────►│ │ + │ │ Create job entry │ + │ │ Store problem data │ + │ │ job_queue[slot].ready=true│ + │◄── job_id ──────────────│ │ +``` + +### 2. Processing + +```text +Client Server Worker + │ │ │ + │ │ │ Poll job_queue + │ │ │ Claim job (CAS) + │ │◄─────────────────────────│ Read problem via pipe + │ │ │ + │ │ │ Convert CPU→GPU + │ │ │ solve_lp/solve_mip + │ │ │ Convert GPU→CPU + │ │ │ + │ │ result_queue[slot].ready │◄────────────────── + │ │◄── result data via pipe ─│ +``` + +### 3. 
Result Retrieval + +```text +Client Server Worker + │ │ │ + │─── CheckStatus ────────►│ │ + │◄── COMPLETED ───────────│ │ + │ │ │ + │─── GetResult ──────────►│ │ + │ │ Look up job_tracker │ + │◄── solution ────────────│ │ +``` + +## Data Type Conversions + +Workers perform CPU↔GPU conversions to minimize client complexity: + +```text +Client Worker + │ │ + │ cpu_optimization_ │ + │ problem_t ──────►│ map_proto_to_problem() + │ │ ↓ + │ │ to_optimization_problem() + │ │ ↓ (GPU) + │ │ solve_lp() / solve_mip() + │ │ ↓ (GPU) + │ │ cudaMemcpy() to host + │ │ ↓ + │ cpu_lp_solution_t/ │ map_lp_solution_to_proto() / + │ cpu_mip_solution_t ◄────│ map_mip_solution_to_proto() +``` + +## Background Threads + +### Result Retrieval Thread + +- Monitors `result_queue` for completed jobs +- Reads result data from worker pipes +- Updates `job_tracker` with results +- Notifies waiting clients (via condition variable) + +### Incumbent Retrieval Thread + +- Monitors incumbent pipes from all workers +- Parses `Incumbent` protobuf messages +- Stores in `job_tracker[job_id].incumbents` +- Enables `GetIncumbents` RPC to return data + +### Worker Monitor Thread + +- Detects crashed workers (via `waitpid`) +- Marks affected jobs as FAILED +- Can respawn workers (optional) + +### Session Reaper Thread + +- Runs every 60 seconds +- Removes stale chunked upload and download sessions after 300 seconds of inactivity +- Prevents memory leaks from abandoned upload/download sessions + +## Log Streaming + +Workers write logs to per-job files: + +```text +/tmp/cuopt_logs/job_.log +``` + +The `StreamLogs` RPC: +1. Opens the log file +2. Reads and sends new content periodically +3. 
Closes when job completes + +## Job States + +```text +┌─────────┐ submit ┌───────────┐ claim ┌────────────┐ +│ QUEUED │──────────►│ PROCESSING│─────────►│ COMPLETED │ +└─────────┘ └───────────┘ └────────────┘ + │ │ + │ cancel │ error + ▼ ▼ +┌───────────┐ ┌─────────┐ +│ CANCELLED │ │ FAILED │ +└───────────┘ └─────────┘ +``` + +## Configuration Options + +```bash +cuopt_grpc_server [options] + + -p, --port PORT gRPC listen port (default: 8765) + -w, --workers NUM Number of worker processes (default: 1) + --max-message-mb N Max gRPC message size in MiB (default: 256; clamped to [4 KiB, ~2 GiB]) + --max-message-bytes N Max gRPC message size in bytes (exact; min 4096) + --enable-transfer-hash Log data hashes for streaming transfers (for testing) + --log-to-console Echo solver logs to server console + -q, --quiet Reduce verbosity (verbose is the default) + +TLS Options: + --tls Enable TLS encryption + --tls-cert PATH Server certificate (PEM) + --tls-key PATH Server private key (PEM) + --tls-root PATH Root CA certificate (for client verification) + --require-client-cert Require client certificate (mTLS) +``` + +## Fault Tolerance + +### Worker Crashes + +If a worker process crashes: +1. Monitor thread detects via `waitpid(WNOHANG)` +2. Any jobs the worker was processing are marked as FAILED +3. A replacement worker is automatically spawned (unless shutting down) +4. Other workers continue operating unaffected + +### Graceful Shutdown + +On SIGINT/SIGTERM: +1. Set `shm_ctrl->shutdown_requested = true` +2. Workers finish current job and exit +3. Main process waits for workers +4. Cleanup shared memory segments + +### Job Cancellation + +When `CancelJob` is called: +1. Set `job_queue[slot].cancelled = true` +2. Worker checks the flag before starting the solve +3. If cancelled, worker stores CANCELLED result and skips to the next job +4. 
If the solve has already started, it runs to completion (no mid-solve cancellation) + +## Memory Management + +| Resource | Location | Cleanup | +|----------|----------|---------| +| Job queue entries | Shared memory | Reused after completion | +| Result queue entries | Shared memory | Reused after retrieval | +| Problem data | Pipe (transient) | Consumed by worker | +| Chunked upload state | Main process memory | After `FinishChunkedUpload` (forwarded to worker) | +| Result data | `job_tracker` map | `DeleteResult` RPC | +| Log files | `/tmp/cuopt_logs/` | `DeleteResult` RPC | + +## Performance Considerations + +### Worker Count + +- Each worker needs a GPU (or shares with others) +- Too many workers: GPU memory contention +- Too few workers: Underutilized when jobs queue +- Recommendation: 1-2 workers per GPU + +### Pipe Buffering + +- Pipe buffer size is set to 1 MiB via `fcntl(F_SETPIPE_SZ)` (Linux default is 64 KiB) +- Large results block worker until main process reads +- Result retrieval thread should read promptly +- Deadlock prevention: Set `result.ready = true` BEFORE writing pipe + +### Shared Memory Limits + +- `MAX_JOBS = 100`: Maximum concurrent queued jobs +- `MAX_RESULTS = 100`: Maximum stored results +- Increase if needed for high-throughput scenarios + +## File Locations + +| Path | Purpose | +|------|---------| +| `/tmp/cuopt_logs/` | Per-job solver log files | +| `/cuopt_job_queue` | Shared memory (job metadata) | +| `/cuopt_result_queue` | Shared memory (result metadata) | +| `/cuopt_control` | Shared memory (server control) | + +Chunked upload state is held in memory in the main process (no upload directory). 
diff --git a/build.sh b/build.sh index b5c35f510b..5f9ac4071a 100755 --- a/build.sh +++ b/build.sh @@ -15,11 +15,12 @@ REPODIR=$(cd "$(dirname "$0")"; pwd) LIBCUOPT_BUILD_DIR=${LIBCUOPT_BUILD_DIR:=${REPODIR}/cpp/build} LIBMPS_PARSER_BUILD_DIR=${LIBMPS_PARSER_BUILD_DIR:=${REPODIR}/cpp/libmps_parser/build} -VALIDARGS="clean libcuopt libmps_parser cuopt_mps_parser cuopt cuopt_server cuopt_sh_client docs deb -a -b -g -fsanitize -tsan -msan -v -l= --verbose-pdlp --build-lp-only --no-fetch-rapids --skip-c-python-adapters --skip-tests-build --skip-routing-build --skip-fatbin-write --host-lineinfo [--cmake-args=\\\"\\\"] [--cache-tool=] -n --allgpuarch --ci-only-arch --show_depr_warn -h --help" +VALIDARGS="clean libcuopt cuopt_grpc_server libmps_parser cuopt_mps_parser cuopt cuopt_server cuopt_sh_client docs deb -a -b -g -fsanitize -tsan -msan -v -l= --verbose-pdlp --build-lp-only --no-fetch-rapids --skip-c-python-adapters --skip-tests-build --skip-routing-build --skip-fatbin-write --host-lineinfo [--cmake-args=\\\"\\\"] [--cache-tool=] -n --allgpuarch --ci-only-arch --show_depr_warn -h --help" HELP="$0 [ ...] [ ...] 
where is: clean - remove all existing build artifacts and configuration (start over) libcuopt - build the cuopt C++ code + cuopt_grpc_server - build only the gRPC server binary (configures + builds libcuopt as needed) libmps_parser - build the libmps_parser C++ code cuopt_mps_parser - build the cuopt_mps_parser python package cuopt - build the cuopt Python package @@ -358,8 +359,8 @@ if buildAll || hasArg libmps_parser; then fi ################################################################################ -# Configure, build, and install libcuopt -if buildAll || hasArg libcuopt; then +# Configure and build libcuopt (and optionally just the gRPC server) +if buildAll || hasArg libcuopt || hasArg cuopt_grpc_server; then mkdir -p "${LIBCUOPT_BUILD_DIR}" cd "${LIBCUOPT_BUILD_DIR}" cmake -DDEFINE_ASSERT=${DEFINE_ASSERT} \ @@ -386,7 +387,10 @@ if buildAll || hasArg libcuopt; then "${EXTRA_CMAKE_ARGS[@]}" \ "${REPODIR}"/cpp JFLAG="${PARALLEL_LEVEL:+-j${PARALLEL_LEVEL}}" - if hasArg -n; then + if hasArg cuopt_grpc_server && ! hasArg libcuopt && ! 
buildAll; then + # Build only the gRPC server (ninja resolves libcuopt as a dependency) + cmake --build "${LIBCUOPT_BUILD_DIR}" --target cuopt_grpc_server ${VERBOSE_FLAG} ${JFLAG} + elif hasArg -n; then # Manual make invocation to start its jobserver make ${JFLAG} -C "${REPODIR}/cpp" LIBCUOPT_BUILD_DIR="${LIBCUOPT_BUILD_DIR}" VERBOSE_FLAG="${VERBOSE_FLAG}" PARALLEL_LEVEL="${PARALLEL_LEVEL}" ninja-build else diff --git a/ci/build_wheel_libcuopt.sh b/ci/build_wheel_libcuopt.sh index a6dd64c7d8..bad18cb798 100755 --- a/ci/build_wheel_libcuopt.sh +++ b/ci/build_wheel_libcuopt.sh @@ -17,6 +17,16 @@ fi # Install Boost and TBB bash ci/utils/install_boost_tbb.sh +# Install libuuid (needed by cuopt_grpc_server) +if command -v dnf &> /dev/null; then + dnf install -y libuuid-devel +elif command -v apt-get &> /dev/null; then + apt-get update && apt-get install -y uuid-dev +fi + +# Install Protobuf + gRPC (protoc + grpc_cpp_plugin) +bash ci/utils/install_protobuf_grpc.sh + export SKBUILD_CMAKE_ARGS="-DCUOPT_BUILD_WHEELS=ON;-DDISABLE_DEPRECATION_WARNING=ON" # For pull requests we are enabling assert mode. diff --git a/ci/utils/install_protobuf_grpc.sh b/ci/utils/install_protobuf_grpc.sh new file mode 100755 index 0000000000..ebb432a205 --- /dev/null +++ b/ci/utils/install_protobuf_grpc.sh @@ -0,0 +1,234 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +# Install Protobuf and gRPC C++ development libraries from source. +# +# This script builds gRPC, Protobuf, and Abseil from source to ensure consistent +# ABI and avoid symbol issues (notably abseil-cpp#1624: Mutex::Dtor not exported +# from shared libabseil on Linux). 
+# +# Usage: +# ./install_protobuf_grpc.sh [OPTIONS] +# +# Options: +# --prefix=DIR Installation prefix (default: /usr/local) +# --build-dir=DIR Build directory for source builds (default: /tmp) +# --skip-deps Skip installing system dependencies (for conda builds) +# --help Show this help message +# +# Examples: +# # Wheel builds (install to /usr/local, installs system deps) +# ./install_protobuf_grpc.sh +# +# # Conda builds (install to custom prefix, deps already available) +# ./install_protobuf_grpc.sh --prefix=${GRPC_INSTALL_DIR} --build-dir=${SRC_DIR} --skip-deps + +# Configuration - single source of truth for gRPC version +GRPC_VERSION="v1.64.2" + +# Default values +PREFIX="/usr/local" +BUILD_DIR="/tmp" +SKIP_DEPS=false + +# Parse command-line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --prefix=*) + PREFIX="${1#*=}" + shift + ;; + --build-dir=*) + BUILD_DIR="${1#*=}" + shift + ;; + --skip-deps) + SKIP_DEPS=true + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Build and install gRPC ${GRPC_VERSION} and dependencies from source." 
+ echo "" + echo "Options:" + echo " --prefix=DIR Installation prefix (default: /usr/local)" + echo " --build-dir=DIR Build directory for source builds (default: /tmp)" + echo " --skip-deps Skip installing system dependencies (for conda builds)" + echo " --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +PREFIX=$(realpath -m "$PREFIX" 2>/dev/null || readlink -f "$PREFIX" 2>/dev/null || echo "$PREFIX") +BUILD_DIR=$(realpath -m "$BUILD_DIR" 2>/dev/null || readlink -f "$BUILD_DIR" 2>/dev/null || echo "$BUILD_DIR") + +if [[ -z "$PREFIX" || "$PREFIX" == "/" ]]; then + echo "ERROR: Invalid PREFIX: '$PREFIX'" >&2 + exit 1 +fi +if [[ -z "$BUILD_DIR" || "$BUILD_DIR" == "/" ]]; then + echo "ERROR: Invalid BUILD_DIR: '$BUILD_DIR'" >&2 + exit 1 +fi + +mkdir -p "$BUILD_DIR" + +echo "==============================================" +echo "Installing gRPC ${GRPC_VERSION} from source" +echo " Prefix: ${PREFIX}" +echo " Build dir: ${BUILD_DIR}" +echo " Skip deps: ${SKIP_DEPS}" +echo "==============================================" + +# Install system dependencies if not skipped +if [ "${SKIP_DEPS}" = false ]; then + echo "" + echo "Installing system dependencies..." + if [ -f /etc/os-release ]; then + . /etc/os-release + if [[ "$ID" == "rocky" || "$ID" == "centos" || "$ID" == "rhel" || "$ID" == "fedora" ]]; then + # Enable PowerTools (Rocky 8) or CRB (Rocky 9) for some packages + if [[ "${VERSION_ID%%.*}" == "8" ]]; then + dnf config-manager --set-enabled powertools || dnf config-manager --set-enabled PowerTools || true + elif [[ "${VERSION_ID%%.*}" == "9" ]]; then + dnf config-manager --set-enabled crb || true + fi + dnf install -y git cmake ninja-build gcc gcc-c++ openssl-devel zlib-devel c-ares-devel + elif [[ "$ID" == "ubuntu" || "$ID" == "debian" ]]; then + apt-get update + apt-get install -y git cmake ninja-build g++ libssl-dev zlib1g-dev libc-ares-dev + else + echo "Warning: Unknown OS '$ID'. 
Assuming build tools are already installed." + fi + else + echo "Warning: /etc/os-release not found. Assuming build tools are already installed." + fi +fi + +# Verify required tools are available +echo "" +echo "Checking required tools..." +for tool in git cmake ninja; do + if ! command -v "$tool" &> /dev/null; then + echo "Error: Required tool '$tool' not found. Please install it (e.g., via your package manager) and re-run this script." + exit 1 + fi +done +echo "All required tools found." + +# Clean up any previous installations to avoid ABI mismatches +# (notably Abseil LTS namespaces like absl::lts_20220623 vs absl::lts_20250512) +echo "Cleaning up previous installations..." +rm -rf \ + "${PREFIX}/lib/cmake/grpc" "${PREFIX}/lib64/cmake/grpc" \ + "${PREFIX}/lib/cmake/protobuf" "${PREFIX}/lib64/cmake/protobuf" \ + "${PREFIX}/lib/cmake/absl" "${PREFIX}/lib64/cmake/absl" \ + "${PREFIX}/include/absl" "${PREFIX}/include/google/protobuf" "${PREFIX}/include/grpc" \ + "${PREFIX}/bin/grpc_cpp_plugin" "${PREFIX}/bin/protoc" "${PREFIX}/bin/protoc-"* || true +rm -f \ + "${PREFIX}/lib/"libgrpc*.a "${PREFIX}/lib/"libgpr*.a "${PREFIX}/lib/"libaddress_sorting*.a "${PREFIX}/lib/"libre2*.a "${PREFIX}/lib/"libupb*.a \ + "${PREFIX}/lib64/"libabsl_*.a "${PREFIX}/lib64/"libprotobuf*.so* "${PREFIX}/lib64/"libprotoc*.so* \ + "${PREFIX}/lib/"libprotobuf*.a "${PREFIX}/lib/"libprotoc*.a || true + +# Build and install gRPC dependencies from source in a consistent way. +# +# IMPORTANT: Protobuf and gRPC both depend on Abseil, and the Abseil LTS +# namespace (e.g. absl::lts_20250512) is part of C++ symbol mangling. +# If Protobuf and gRPC are built against different Abseil versions, gRPC +# plugins can fail to link with undefined references (e.g. Printer::PrintImpl). +# +# To avoid that, we install Abseil first (from gRPC's submodule), then +# build Protobuf and gRPC against that same installed Abseil. 
+ +GRPC_SRC="${BUILD_DIR}/grpc-src" +ABSL_BUILD="${BUILD_DIR}/absl-build" +PROTOBUF_BUILD="${BUILD_DIR}/protobuf-build" +GRPC_BUILD="${BUILD_DIR}/grpc-build" + +rm -rf "${GRPC_SRC}" "${ABSL_BUILD}" "${PROTOBUF_BUILD}" "${GRPC_BUILD}" +mkdir -p "${PREFIX}" + +echo "Cloning gRPC ${GRPC_VERSION} with submodules..." +git clone --depth 1 --branch "${GRPC_VERSION}" --recurse-submodules --shallow-submodules \ + https://github.com/grpc/grpc.git "${GRPC_SRC}" + +# Ensure prefix is in PATH and CMAKE_PREFIX_PATH +export PATH="${PREFIX}/bin:${PATH}" +export CMAKE_PREFIX_PATH="${PREFIX}:${CMAKE_PREFIX_PATH:-}" + +# Ensure a consistent C++ standard across Abseil/Protobuf/gRPC. +# Abseil's options.h defaults to "auto" selection for std::string_view +# (ABSL_OPTION_USE_STD_STRING_VIEW=2). If one library is built in +# C++17+ and another in C++14, they will disagree on whether +# `absl::string_view` is a typedef to `std::string_view` or Abseil's +# own type, leading to link-time ABI mismatches. +CMAKE_STD_FLAGS="-DCMAKE_CXX_STANDARD=17 -DCMAKE_CXX_STANDARD_REQUIRED=ON" + +echo "" +echo "Building Abseil (from gRPC submodule)..." +cmake -S "${GRPC_SRC}/third_party/abseil-cpp" -B "${ABSL_BUILD}" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + ${CMAKE_STD_FLAGS} \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DCMAKE_INSTALL_PREFIX="${PREFIX}" +cmake --build "${ABSL_BUILD}" --parallel +cmake --install "${ABSL_BUILD}" + +echo "" +echo "Building Protobuf (using installed Abseil)..." +cmake -S "${GRPC_SRC}/third_party/protobuf" -B "${PROTOBUF_BUILD}" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + ${CMAKE_STD_FLAGS} \ + -Dprotobuf_BUILD_TESTS=OFF \ + -Dprotobuf_ABSL_PROVIDER=package \ + -DCMAKE_PREFIX_PATH="${PREFIX}" \ + -DCMAKE_INSTALL_PREFIX="${PREFIX}" +cmake --build "${PROTOBUF_BUILD}" --parallel +cmake --install "${PROTOBUF_BUILD}" + +echo "" +echo "Building gRPC (using installed Abseil and Protobuf)..." 
+cmake -S "${GRPC_SRC}" -B "${GRPC_BUILD}" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + ${CMAKE_STD_FLAGS} \ + -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_BUILD_CODEGEN=ON \ + -DgRPC_BUILD_GRPC_NODE_PLUGIN=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=module \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DCMAKE_PREFIX_PATH="${PREFIX}" \ + -DCMAKE_INSTALL_PREFIX="${PREFIX}" +cmake --build "${GRPC_BUILD}" --parallel +cmake --install "${GRPC_BUILD}" + +# For system-wide installs, update ldconfig +if [[ "${PREFIX}" == "/usr/local" ]]; then + echo "" + echo "Updating ldconfig for system-wide install..." + echo "${PREFIX}/lib64" > /etc/ld.so.conf.d/usr-local-lib64.conf 2>/dev/null || true + echo "${PREFIX}/lib" > /etc/ld.so.conf.d/usr-local-lib.conf 2>/dev/null || true + ldconfig || true +fi + +echo "" +echo "==============================================" +echo "gRPC ${GRPC_VERSION} installed successfully to ${PREFIX}" +echo "==============================================" diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index 3cee401c5c..0702557fba 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -7,6 +7,7 @@ channels: dependencies: - breathe - bzip2 +- c-ares - c-compiler - ccache - clang-tools=20.1.4 @@ -45,6 +46,7 @@ dependencies: - numba>=0.60.0 - numpy>=1.23.5,<3.0 - numpydoc +- openssl - pandas>=2.0 - pexpect - pip diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 5632c8c9c7..b2411fafb1 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -7,6 +7,7 @@ channels: dependencies: - breathe - bzip2 +- c-ares - c-compiler - ccache - 
clang-tools=20.1.4 @@ -45,6 +46,7 @@ dependencies: - numba>=0.60.0 - numpy>=1.23.5,<3.0 - numpydoc +- openssl - pandas>=2.0 - pexpect - pip diff --git a/conda/environments/all_cuda-131_arch-aarch64.yaml b/conda/environments/all_cuda-131_arch-aarch64.yaml index add21cbb2f..d2390dfdb5 100644 --- a/conda/environments/all_cuda-131_arch-aarch64.yaml +++ b/conda/environments/all_cuda-131_arch-aarch64.yaml @@ -7,6 +7,7 @@ channels: dependencies: - breathe - bzip2 +- c-ares - c-compiler - ccache - clang-tools=20.1.4 @@ -45,6 +46,7 @@ dependencies: - numba>=0.60.0 - numpy>=1.23.5,<3.0 - numpydoc +- openssl - pandas>=2.0 - pexpect - pip diff --git a/conda/environments/all_cuda-131_arch-x86_64.yaml b/conda/environments/all_cuda-131_arch-x86_64.yaml index 0fa31c7961..7372662c26 100644 --- a/conda/environments/all_cuda-131_arch-x86_64.yaml +++ b/conda/environments/all_cuda-131_arch-x86_64.yaml @@ -7,6 +7,7 @@ channels: dependencies: - breathe - bzip2 +- c-ares - c-compiler - ccache - clang-tools=20.1.4 @@ -45,6 +46,7 @@ dependencies: - numba>=0.60.0 - numpy>=1.23.5,<3.0 - numpydoc +- openssl - pandas>=2.0 - pexpect - pip diff --git a/conda/recipes/libcuopt/recipe.yaml b/conda/recipes/libcuopt/recipe.yaml index 8ee40d6f14..bf66617f50 100644 --- a/conda/recipes/libcuopt/recipe.yaml +++ b/conda/recipes/libcuopt/recipe.yaml @@ -29,7 +29,14 @@ cache: export CXXFLAGS=$(echo $CXXFLAGS | sed -E 's@\-fdebug\-prefix\-map[^ ]*@@g') set +x - ./build.sh -n -v ${BUILD_EXTRA_FLAGS} libmps_parser libcuopt deb --allgpuarch --cmake-args=\"-DCMAKE_INSTALL_LIBDIR=lib\" + # Workaround for abseil-cpp#1624: Mutex::Dtor() not exported from shared libabseil on Linux. + # Build gRPC/Protobuf/Abseil from source to ensure consistent ABI and avoid symbol issues. 
+ GRPC_INSTALL_DIR="${SRC_DIR}/grpc_install" + + bash ./ci/utils/install_protobuf_grpc.sh --prefix="${GRPC_INSTALL_DIR}" --build-dir="${SRC_DIR}" --skip-deps + + export CMAKE_PREFIX_PATH="${GRPC_INSTALL_DIR}:${CMAKE_PREFIX_PATH:-}" + ./build.sh -n -v ${BUILD_EXTRA_FLAGS} libmps_parser libcuopt deb --allgpuarch --cmake-args=\"-DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_PREFIX_PATH=${GRPC_INSTALL_DIR}\" secrets: - AWS_ACCESS_KEY_ID - AWS_SECRET_ACCESS_KEY @@ -73,6 +80,7 @@ cache: - cmake ${{ cmake_version }} - make - ninja + - git - tbb-devel - zlib - bzip2 @@ -90,6 +98,8 @@ cache: - tbb-devel - zlib - bzip2 + - openssl + - c-ares outputs: - package: @@ -164,6 +174,9 @@ outputs: - libcublas - libcudss-dev >=0.7 - libcusparse-dev + - openssl + - c-ares + - libuuid run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libmps-parser", exact=True) }} @@ -171,6 +184,9 @@ outputs: - librmm =${{ minor_version }} - cuda-nvrtc - libcudss + - openssl + - c-ares + - libuuid ignore_run_exports: by_name: - cuda-nvtx @@ -186,6 +202,7 @@ outputs: files: - lib/libcuopt.so - bin/cuopt_cli + - bin/cuopt_grpc_server about: homepage: ${{ load_from_file("python/cuopt/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/cuopt/pyproject.toml").project.license }} @@ -212,6 +229,8 @@ outputs: - libcublas - libcudss-dev >=0.7 - libcusparse-dev + - openssl + - c-ares run: - ${{ pin_subpackage("libcuopt", exact=True) }} - ${{ pin_subpackage("libmps-parser", exact=True) }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8225d93655..972fd8152f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -255,6 +255,92 @@ create_logger_macros(CUOPT "cuopt::default_logger()" include/cuopt) find_package(CUDSS REQUIRED) +# ################################################################################################## +# - gRPC and Protobuf setup (REQUIRED) ------------------------------------------------------------ + +# gRPC 
is required for this branch - it provides remote execution features +# gRPC can come from either: +# - an installed CMake package (gRPCConfig.cmake), or +# - an in-tree build (e.g. python/libcuopt uses FetchContent(grpc), which defines gRPC::grpc++). +if(NOT TARGET gRPC::grpc++) + find_package(gRPC CONFIG REQUIRED) +endif() + +# Find Protobuf (should come with gRPC, but verify) +if(NOT TARGET protobuf::libprotobuf) + find_package(protobuf CONFIG REQUIRED) +endif() + +set(CUOPT_ENABLE_GRPC ON) +add_compile_definitions(CUOPT_ENABLE_GRPC) +message(STATUS "gRPC enabled (target gRPC::grpc++ is available)") + +# Find protoc compiler (provided by config package or target) +if(TARGET protobuf::protoc) + get_target_property(_PROTOBUF_PROTOC protobuf::protoc IMPORTED_LOCATION_RELEASE) + if(NOT _PROTOBUF_PROTOC) + get_target_property(_PROTOBUF_PROTOC protobuf::protoc IMPORTED_LOCATION) + endif() +else() + find_package(protobuf CONFIG REQUIRED) + get_target_property(_PROTOBUF_PROTOC protobuf::protoc IMPORTED_LOCATION_RELEASE) + if(NOT _PROTOBUF_PROTOC) + get_target_property(_PROTOBUF_PROTOC protobuf::protoc IMPORTED_LOCATION) + endif() +endif() + +if(NOT _PROTOBUF_PROTOC) + message(FATAL_ERROR "protoc not found (Protobuf_PROTOC_EXECUTABLE is empty)") +endif() + +# Find grpc_cpp_plugin +if(TARGET grpc_cpp_plugin) + set(_GRPC_CPP_PLUGIN_EXECUTABLE "$<TARGET_FILE:grpc_cpp_plugin>") +else() + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + if(NOT _GRPC_CPP_PLUGIN_EXECUTABLE) + message(FATAL_ERROR "grpc_cpp_plugin not found") + endif() +endif() + +# Generate C++ code from cuopt_remote.proto (base message definitions) +set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/grpc/cuopt_remote.proto") +set(PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote.pb.cc") +set(PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote.pb.h") + +add_custom_command( + OUTPUT "${PROTO_SRCS}" "${PROTO_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} + --proto_path
${CMAKE_CURRENT_SOURCE_DIR}/src/grpc + ${PROTO_FILE} + DEPENDS ${PROTO_FILE} + COMMENT "Generating C++ code from cuopt_remote.proto" + VERBATIM +) + +# Generate gRPC service code from cuopt_remote_service.proto +set(GRPC_PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/grpc/cuopt_remote_service.proto") +set(GRPC_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote_service.pb.cc") +set(GRPC_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote_service.pb.h") +set(GRPC_SERVICE_SRCS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote_service.grpc.pb.cc") +set(GRPC_SERVICE_HDRS "${CMAKE_CURRENT_BINARY_DIR}/cuopt_remote_service.grpc.pb.h") + +add_custom_command( + OUTPUT "${GRPC_PROTO_SRCS}" "${GRPC_PROTO_HDRS}" "${GRPC_SERVICE_SRCS}" "${GRPC_SERVICE_HDRS}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} + --grpc_out ${CMAKE_CURRENT_BINARY_DIR} + --plugin=protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE} + --proto_path ${CMAKE_CURRENT_SOURCE_DIR}/src/grpc + ${GRPC_PROTO_FILE} + DEPENDS ${GRPC_PROTO_FILE} ${PROTO_FILE} + COMMENT "Generating gRPC C++ code from cuopt_remote_service.proto" + VERBATIM +) + +message(STATUS "gRPC protobuf code generation configured") + if(BUILD_TESTS) include(cmake/thirdparty/get_gtest.cmake) endif() @@ -264,6 +350,20 @@ add_subdirectory(src) if (HOST_LINEINFO) set_source_files_properties(${CUOPT_SRC_FILES} DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTIES COMPILE_OPTIONS "-g1") endif() + +# Add gRPC mapper files and generated protobuf sources +list(APPEND CUOPT_SRC_FILES + ${PROTO_SRCS} + ${GRPC_PROTO_SRCS} + ${GRPC_SERVICE_SRCS} + src/grpc/grpc_problem_mapper.cpp + src/grpc/grpc_solution_mapper.cpp + src/grpc/grpc_settings_mapper.cpp + src/grpc/grpc_service_mapper.cpp + src/grpc/client/grpc_client.cpp + src/grpc/client/solve_remote.cpp +) + add_library(cuopt SHARED ${CUOPT_SRC_FILES} ) @@ -315,6 +415,9 @@ target_include_directories(cuopt PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../thirdparty" "${CMAKE_CURRENT_SOURCE_DIR}/src" + 
"${CMAKE_CURRENT_SOURCE_DIR}/src/grpc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/grpc/client" + "${CMAKE_CURRENT_BINARY_DIR}" "${CUDSS_INCLUDE}" PUBLIC "$" @@ -384,6 +487,8 @@ target_link_libraries(cuopt ${CUDSS_LIB_FILE} PRIVATE ${CUOPT_PRIVATE_CUDA_LIBS} + protobuf::libprotobuf + gRPC::grpc++ ) @@ -606,6 +711,66 @@ if(BUILD_LP_BENCHMARKS) endif() endif() +# ################################################################################################## +# - cuopt_grpc_server - gRPC-based remote server -------------------------------------------------- + +add_executable(cuopt_grpc_server + src/grpc/server/grpc_server_main.cpp + src/grpc/server/grpc_worker.cpp + src/grpc/server/grpc_worker_infra.cpp + src/grpc/server/grpc_server_threads.cpp + src/grpc/server/grpc_pipe_io.cpp + src/grpc/server/grpc_job_management.cpp + src/grpc/server/grpc_service_impl.cpp +) + +set_target_properties(cuopt_grpc_server + PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON + CXX_SCAN_FOR_MODULES OFF +) + +target_compile_options(cuopt_grpc_server + PRIVATE "$<$:${CUOPT_CXX_FLAGS}>" +) + +target_include_directories(cuopt_grpc_server + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/src" + "${CMAKE_CURRENT_SOURCE_DIR}/src/grpc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/grpc/server" + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_SOURCE_DIR}/libmps_parser/include" + "${CMAKE_CURRENT_BINARY_DIR}" + PUBLIC + "$" + "$" +) + +find_library(UUID_LIBRARY uuid REQUIRED) + +target_link_libraries(cuopt_grpc_server + PUBLIC + cuopt + OpenMP::OpenMP_CXX + PRIVATE + protobuf::libprotobuf + gRPC::grpc++ + ${UUID_LIBRARY} +) + +# Use RUNPATH when building locally +target_link_options(cuopt_grpc_server PRIVATE -Wl,--enable-new-dtags) +set_property(TARGET cuopt_grpc_server PROPERTY INSTALL_RPATH "$ORIGIN/../${lib_dir}") + +# Install the grpc server executable +install(TARGETS cuopt_grpc_server + COMPONENT runtime + RUNTIME DESTINATION ${_BIN_DEST} +) + +message(STATUS "Building cuopt_grpc_server (gRPC-based 
remote solve server)") # ################################################################################################## # - CPack has to be the last item in the cmake file------------------------------------------------- diff --git a/cpp/src/grpc/client/grpc_client.cpp b/cpp/src/grpc/client/grpc_client.cpp new file mode 100644 index 0000000000..a287ba8089 --- /dev/null +++ b/cpp/src/grpc/client/grpc_client.cpp @@ -0,0 +1,1192 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#include "grpc_client.hpp" + +#include +#include +#include +#include "grpc_problem_mapper.hpp" +#include "grpc_service_mapper.hpp" +#include "grpc_settings_mapper.hpp" +#include "grpc_solution_mapper.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuopt::linear_programming { + +// ============================================================================= +// Constants +// ============================================================================= + +constexpr int kDefaultRpcTimeoutSeconds = 60; // per-RPC deadline for short operations + +constexpr int64_t kMinMessageBytes = 4LL * 1024 * 1024; // 4 MiB floor for max_message_bytes + +// Protobuf's hard serialization limit is 2 GiB (int32 sizes internally). +// Reserve 1 MiB headroom for gRPC framing and internal bookkeeping. +constexpr int64_t kMaxMessageBytes = 2LL * 1024 * 1024 * 1024 - 1LL * 1024 * 1024; + +// ============================================================================= +// Debug Logging Helper +// ============================================================================= + +// Helper macro to log to debug callback if configured, otherwise to std::cerr. +// Only emits output when enable_debug_log is true or a debug_log_callback is set. 
+#define GRPC_CLIENT_DEBUG_LOG(config, msg) \ + do { \ + if (!(config).enable_debug_log && !(config).debug_log_callback) break; \ + std::ostringstream _oss; \ + _oss << msg; \ + std::string _msg_str = _oss.str(); \ + if ((config).debug_log_callback) { (config).debug_log_callback(_msg_str); } \ + if ((config).enable_debug_log) { std::cerr << _msg_str << "\n"; } \ + } while (0) + +// Structured throughput log for benchmarking. Parseable format: +// [THROUGHPUT] phase=<name> bytes=<n> elapsed_ms=<ms> throughput_mb_s=<rate> +#define GRPC_CLIENT_THROUGHPUT_LOG(config, phase_name, byte_count, start_time) \ + do { \ + auto _end = std::chrono::steady_clock::now(); \ + auto _ms = std::chrono::duration_cast<std::chrono::microseconds>(_end - (start_time)).count(); \ + double _sec = _ms / 1e6; \ + double _mb = static_cast<double>(byte_count) / (1024.0 * 1024.0); \ + double _mbs = (_sec > 0.0) ? (_mb / _sec) : 0.0; \ + GRPC_CLIENT_DEBUG_LOG( \ + config, \ + "[THROUGHPUT] phase=" << (phase_name) << " bytes=" << (byte_count) << " elapsed_ms=" \ + << std::fixed << std::setprecision(1) << (_ms / 1000.0) \ + << " throughput_mb_s=" << std::setprecision(1) << _mbs); \ + } while (0) + +// Private implementation (PIMPL pattern to hide gRPC types) +struct grpc_client_t::impl_t { + std::shared_ptr<grpc::Channel> channel; + // Use StubInterface to support both real stubs and mock stubs for testing + std::shared_ptr<cuopt::remote::CuOptRemoteService::StubInterface> stub; + bool mock_mode = false; // Set to true when using injected mock stub +}; + +// All finite-duration RPCs (CheckStatus, SendArrayChunk, GetResult, etc.) +// use kDefaultRpcTimeoutSeconds (60s). Indefinite RPCs (StreamLogs, +// WaitForCompletion) omit the deadline entirely and rely on TryCancel or +// client-side polling for cancellation — a fixed deadline would kill +// legitimate long-running solves.
+static void set_rpc_deadline(grpc::ClientContext& ctx, int timeout_seconds) +{ + if (timeout_seconds > 0) { + ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(timeout_seconds)); + } +} + +// ============================================================================= +// Test Helper Functions (for mock stub injection) +// ============================================================================= + +void grpc_test_inject_mock_stub(grpc_client_t& client, std::shared_ptr stub) +{ + // Cast from void* to StubInterface* - caller must ensure correct type + client.impl_->stub = + std::static_pointer_cast(stub); + client.impl_->mock_mode = true; +} + +void grpc_test_mark_as_connected(grpc_client_t& client) { client.impl_->mock_mode = true; } + +grpc_client_t::grpc_client_t(const grpc_client_config_t& config) + : impl_(std::make_unique()), config_(config) +{ + config_.max_message_bytes = + std::clamp(config_.max_message_bytes, kMinMessageBytes, kMaxMessageBytes); + if (config_.chunked_array_threshold_bytes >= 0) { + chunked_array_threshold_bytes_ = config_.chunked_array_threshold_bytes; + } else { + chunked_array_threshold_bytes_ = config_.max_message_bytes * 3 / 4; + } +} + +grpc_client_t::grpc_client_t(const std::string& server_address) : impl_(std::make_unique()) +{ + config_.server_address = server_address; + config_.max_message_bytes = + std::clamp(config_.max_message_bytes, kMinMessageBytes, kMaxMessageBytes); + chunked_array_threshold_bytes_ = config_.max_message_bytes * 3 / 4; +} + +grpc_client_t::~grpc_client_t() { stop_log_streaming(); } + +bool grpc_client_t::connect() +{ + std::shared_ptr creds; + + if (config_.enable_tls) { + grpc::SslCredentialsOptions ssl_opts; + + // Root CA certificates for verifying the server + if (!config_.tls_root_certs.empty()) { ssl_opts.pem_root_certs = config_.tls_root_certs; } + + // Client certificate and key for mTLS + if (!config_.tls_client_cert.empty() && !config_.tls_client_key.empty()) { + 
ssl_opts.pem_cert_chain = config_.tls_client_cert; + ssl_opts.pem_private_key = config_.tls_client_key; + } + + creds = grpc::SslCredentials(ssl_opts); + } else { + creds = grpc::InsecureChannelCredentials(); + } + + grpc::ChannelArguments channel_args; + const int channel_limit = static_cast(config_.max_message_bytes); + channel_args.SetMaxReceiveMessageSize(channel_limit); + channel_args.SetMaxSendMessageSize(channel_limit); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, config_.keepalive_time_ms); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, config_.keepalive_timeout_ms); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + + impl_->channel = grpc::CreateCustomChannel(config_.server_address, creds, channel_args); + impl_->stub = cuopt::remote::CuOptRemoteService::NewStub(impl_->channel); + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] Connecting to " << config_.server_address + << (config_.enable_tls ? " (TLS)" : "")); + + // Verify connectivity with a lightweight RPC probe. Channel-level checks like + // WaitForConnected are unreliable (gRPC lazy connection on localhost can + // report READY even without a server). A real RPC with a deadline is the + // only reliable way to confirm the server is reachable. 
+ { + grpc::ClientContext probe_ctx; + probe_ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(5)); + cuopt::remote::StatusRequest probe_req; + probe_req.set_job_id("__connection_probe__"); + cuopt::remote::StatusResponse probe_resp; + auto probe_status = impl_->stub->CheckStatus(&probe_ctx, probe_req, &probe_resp); + + auto code = probe_status.error_code(); + if (code != grpc::StatusCode::OK && code != grpc::StatusCode::NOT_FOUND) { + last_error_ = "Failed to connect to server at " + config_.server_address + " (" + + probe_status.error_message() + ")"; + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] Connection failed: " << last_error_); + return false; + } + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] Connected successfully to " << config_.server_address); + return true; +} + +bool grpc_client_t::is_connected() const +{ + // In mock mode, we're always "connected" if a stub is present + if (impl_->mock_mode) { return impl_->stub != nullptr; } + + if (!impl_->channel) return false; + auto state = impl_->channel->GetState(false); + return state == GRPC_CHANNEL_READY || state == GRPC_CHANNEL_IDLE; +} + +void grpc_client_t::start_log_streaming(const std::string& job_id) +{ + if (!config_.stream_logs || !config_.log_callback) return; + + if (log_thread_ && log_thread_->joinable()) { + stop_logs_.store(true); + { + std::lock_guard lk(log_context_mutex_); + if (active_log_context_) { + static_cast(active_log_context_)->TryCancel(); + } + } + log_thread_->join(); + log_thread_.reset(); + } + + stop_logs_.store(false); + log_thread_ = std::make_unique([this, job_id]() { + stream_logs(job_id, 0, [this](const std::string& line, bool /*job_complete*/) { + if (stop_logs_.load()) return false; + if (config_.log_callback) { config_.log_callback(line); } + return true; + }); + }); +} + +void grpc_client_t::stop_log_streaming() +{ + stop_logs_.store(true); + // Cancel the in-flight streaming RPC so reader->Read() returns false + // immediately 
instead of blocking until the server sends a message. + { + std::lock_guard lk(log_context_mutex_); + if (active_log_context_) { + static_cast(active_log_context_)->TryCancel(); + } + } + // Move to local so we can join without racing against other callers. + // TryCancel above guarantees the thread will unblock promptly. + std::unique_ptr t; + std::swap(t, log_thread_); + if (t && t->joinable()) { t->join(); } +} + +// ============================================================================= +// Proto → Client Enum Conversion +// ============================================================================= + +job_status_t map_proto_job_status(cuopt::remote::JobStatus proto_status) +{ + switch (proto_status) { + case cuopt::remote::QUEUED: return job_status_t::QUEUED; + case cuopt::remote::PROCESSING: return job_status_t::PROCESSING; + case cuopt::remote::COMPLETED: return job_status_t::COMPLETED; + case cuopt::remote::FAILED: return job_status_t::FAILED; + case cuopt::remote::CANCELLED: return job_status_t::CANCELLED; + default: return job_status_t::NOT_FOUND; + } +} + +// ============================================================================= +// Async Job Management Operations +// ============================================================================= + +job_status_result_t grpc_client_t::check_status(const std::string& job_id) +{ + job_status_result_t result; + + if (!impl_->stub) { + result.error_message = "Not connected to server"; + return result; + } + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + auto request = build_status_request(job_id); + cuopt::remote::StatusResponse response; + auto status = impl_->stub->CheckStatus(&context, request, &response); + + if (!status.ok()) { + result.error_message = "CheckStatus failed: " + status.error_message(); + return result; + } + + result.success = true; + result.message = response.message(); + result.result_size_bytes = response.result_size_bytes(); + + // 
Track server max message size + if (response.max_message_bytes() > 0) { + server_max_message_bytes_.store(response.max_message_bytes(), std::memory_order_relaxed); + } + + result.status = map_proto_job_status(response.job_status()); + + return result; +} + +job_status_result_t grpc_client_t::wait_for_completion(const std::string& job_id) +{ + job_status_result_t result; + + if (!impl_->stub) { + result.error_message = "Not connected to server"; + return result; + } + + grpc::ClientContext context; + // No RPC deadline: WaitForCompletion blocks until the solver finishes, + // which may exceed any fixed timeout. The server detects client + // disconnect (context->IsCancelled), and the production path uses + // poll_for_completion which has its own config_.timeout_seconds loop. + cuopt::remote::WaitRequest request; + request.set_job_id(job_id); + cuopt::remote::WaitResponse response; + + auto status = impl_->stub->WaitForCompletion(&context, request, &response); + + if (!status.ok()) { + result.error_message = "WaitForCompletion failed: " + status.error_message(); + return result; + } + + result.success = true; + result.message = response.message(); + result.result_size_bytes = response.result_size_bytes(); + + result.status = map_proto_job_status(response.job_status()); + + return result; +} + +cancel_result_t grpc_client_t::cancel_job(const std::string& job_id) +{ + cancel_result_t result; + + if (!impl_->stub) { + result.error_message = "Not connected to server"; + return result; + } + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + auto request = build_cancel_request(job_id); + cuopt::remote::CancelResponse response; + auto status = impl_->stub->CancelJob(&context, request, &response); + + if (!status.ok()) { + result.error_message = "CancelJob failed: " + status.error_message(); + return result; + } + + result.success = (response.status() == cuopt::remote::SUCCESS); + result.message = response.message(); + + 
result.job_status = map_proto_job_status(response.job_status()); + + return result; +} + +bool grpc_client_t::delete_job(const std::string& job_id) +{ + if (!impl_->stub) { + last_error_ = "Not connected to server"; + return false; + } + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::DeleteRequest request; + request.set_job_id(job_id); + cuopt::remote::DeleteResponse response; + auto status = impl_->stub->DeleteResult(&context, request, &response); + + if (!status.ok()) { + last_error_ = "DeleteResult RPC failed: " + status.error_message(); + return false; + } + + // Check response status - job must exist to be deleted + if (response.status() == cuopt::remote::ERROR_NOT_FOUND) { + last_error_ = "Job not found: " + job_id; + return false; + } + + if (response.status() != cuopt::remote::SUCCESS) { + last_error_ = "DeleteResult failed: " + response.message(); + return false; + } + + return true; +} + +incumbents_result_t grpc_client_t::get_incumbents(const std::string& job_id, + int64_t from_index, + int32_t max_count) +{ + incumbents_result_t result; + + if (!impl_->stub) { + result.error_message = "Not connected to server"; + return result; + } + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::IncumbentRequest request; + request.set_job_id(job_id); + request.set_from_index(from_index); + request.set_max_count(max_count); + + cuopt::remote::IncumbentResponse response; + auto status = impl_->stub->GetIncumbents(&context, request, &response); + + if (!status.ok()) { + result.error_message = "GetIncumbents failed: " + status.error_message(); + return result; + } + + result.success = true; + result.next_index = response.next_index(); + result.job_complete = response.job_complete(); + + for (const auto& inc : response.incumbents()) { + incumbent_t entry; + entry.index = inc.index(); + entry.objective = inc.objective(); + 
entry.assignment.reserve(inc.assignment_size()); + for (int i = 0; i < inc.assignment_size(); ++i) { + entry.assignment.push_back(inc.assignment(i)); + } + result.incumbents.push_back(std::move(entry)); + } + + return result; +} + +bool grpc_client_t::stream_logs( + const std::string& job_id, + int64_t from_byte, + std::function callback) +{ + if (!impl_->stub) { + last_error_ = "Not connected to server"; + return false; + } + + grpc::ClientContext context; + // No RPC deadline here: this stream stays open for the entire solve, which + // can exceed any fixed timeout. Shutdown is via TryCancel from + // stop_log_streaming(), not a deadline. + cuopt::remote::StreamLogsRequest request; + request.set_job_id(job_id); + request.set_from_byte(from_byte); + + // Publish this context so stop_log_streaming() can TryCancel it from + // another thread. The mutex ensures the pointer is never dangling: + // we clear it under the same lock before `context` goes out of scope. + { + std::lock_guard lk(log_context_mutex_); + active_log_context_ = &context; + } + + auto reader = impl_->stub->StreamLogs(&context, request); + + cuopt::remote::LogMessage log_msg; + while (reader->Read(&log_msg)) { + bool should_continue = callback(log_msg.line(), log_msg.job_complete()); + if (!should_continue) { + context.TryCancel(); + break; + } + if (log_msg.job_complete()) { break; } + } + + auto status = reader->Finish(); + + { + std::lock_guard lk(log_context_mutex_); + active_log_context_ = nullptr; + } + + return status.ok() || status.error_code() == grpc::StatusCode::CANCELLED; +} + +bool grpc_client_t::submit_unary(const cuopt::remote::SubmitJobRequest& request, + std::string& job_id_out) +{ + job_id_out.clear(); + + if (!impl_->stub) { + last_error_ = "Not connected to server"; + return false; + } + + auto t0 = std::chrono::steady_clock::now(); + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::SubmitJobResponse response; + auto status 
= impl_->stub->SubmitJob(&context, request, &response); + + GRPC_CLIENT_THROUGHPUT_LOG(config_, "upload_unary", request.ByteSizeLong(), t0); + + if (!status.ok()) { + last_error_ = "SubmitJob failed: " + status.error_message(); + return false; + } + + job_id_out = response.job_id(); + if (job_id_out.empty()) { + last_error_ = "SubmitJob succeeded but no job_id returned"; + return false; + } + + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] Unary submit succeeded, job_id=" << job_id_out); + return true; +} + +// ============================================================================= +// Async Submit and Get Result +// ============================================================================= + +template +submit_result_t grpc_client_t::submit_lp(const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings) +{ + submit_result_t result; + + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] submit_lp: starting submission"); + + if (!is_connected()) { + result.error_message = "Not connected to server"; + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] submit_lp: not connected to server"); + return result; + } + + // Check if chunked array upload should be used + bool use_chunked = false; + if (chunked_array_threshold_bytes_ >= 0) { + size_t est = estimate_problem_proto_size(problem); + use_chunked = (static_cast(est) > chunked_array_threshold_bytes_); + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] submit_lp: estimated_size=" + << est << " threshold=" << chunked_array_threshold_bytes_ + << " use_chunked=" << use_chunked); + } + + if (use_chunked) { + cuopt::remote::ChunkedProblemHeader header; + populate_chunked_header_lp(problem, settings, &header); + if (!upload_chunked_arrays(problem, header, result.job_id)) { + result.error_message = last_error_; + return result; + } + } else { + auto submit_request = build_lp_submit_request(problem, settings); + if (!submit_unary(submit_request, result.job_id)) { + result.error_message = last_error_; + 
return result; + } + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] submit_lp: job submitted, job_id=" << result.job_id); + result.success = true; + return result; +} + +template +submit_result_t grpc_client_t::submit_mip(const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents) +{ + submit_result_t result; + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] submit_mip: starting submission" + << (enable_incumbents ? " (incumbents enabled)" : "")); + + if (!is_connected()) { + result.error_message = "Not connected to server"; + return result; + } + + bool use_chunked = false; + if (chunked_array_threshold_bytes_ >= 0) { + size_t est = estimate_problem_proto_size(problem); + use_chunked = (static_cast(est) > chunked_array_threshold_bytes_); + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] submit_mip: estimated_size=" + << est << " threshold=" << chunked_array_threshold_bytes_ + << " use_chunked=" << use_chunked); + } + + if (use_chunked) { + cuopt::remote::ChunkedProblemHeader header; + populate_chunked_header_mip(problem, settings, enable_incumbents, &header); + if (!upload_chunked_arrays(problem, header, result.job_id)) { + result.error_message = last_error_; + return result; + } + } else { + auto submit_request = build_mip_submit_request(problem, settings, enable_incumbents); + if (!submit_unary(submit_request, result.job_id)) { + result.error_message = last_error_; + return result; + } + } + + GRPC_CLIENT_DEBUG_LOG( + config_, "[grpc_client] submit_mip: job submitted successfully, job_id=" << result.job_id); + result.success = true; + return result; +} + +template +remote_lp_result_t grpc_client_t::get_lp_result(const std::string& job_id) +{ + remote_lp_result_t result; + + if (!is_connected()) { + result.error_message = "Not connected to server"; + return result; + } + + downloaded_result_t dl; + if (!get_result_or_download(job_id, dl)) { + result.error_message = last_error_; + return result; + } + + if 
(dl.was_chunked) { + result.solution = std::make_unique>( + chunked_result_to_lp_solution(*dl.chunked_header, dl.chunked_arrays)); + } else { + result.solution = std::make_unique>( + map_proto_to_lp_solution(dl.response->lp_solution())); + } + result.success = true; + return result; +} + +template +remote_mip_result_t grpc_client_t::get_mip_result(const std::string& job_id) +{ + remote_mip_result_t result; + + if (!is_connected()) { + result.error_message = "Not connected to server"; + return result; + } + + downloaded_result_t dl; + if (!get_result_or_download(job_id, dl)) { + result.error_message = last_error_; + return result; + } + + if (dl.was_chunked) { + result.solution = std::make_unique>( + chunked_result_to_mip_solution(*dl.chunked_header, dl.chunked_arrays)); + } else { + result.solution = std::make_unique>( + map_proto_to_mip_solution(dl.response->mip_solution())); + } + result.success = true; + return result; +} + +// ============================================================================= +// Polling helper +// ============================================================================= + +grpc_client_t::poll_result_t grpc_client_t::poll_for_completion(const std::string& job_id) +{ + poll_result_t poll_result; + + int poll_count = 0; + int poll_ms = std::max(config_.poll_interval_ms, 1); + // timeout_seconds <= 0 means "wait indefinitely" — the solver's own + // time_limit (passed via settings) is the authoritative bound. 
+ int max_polls; + if (config_.timeout_seconds > 0) { + int64_t total_ms = static_cast(config_.timeout_seconds) * 1000; + int64_t computed = total_ms / poll_ms; + max_polls = + static_cast(std::min(computed, static_cast(std::numeric_limits::max()))); + } else { + max_polls = std::numeric_limits::max(); + } + + int64_t incumbent_next_index = 0; + auto last_incumbent_poll = std::chrono::steady_clock::now(); + bool cancel_requested = false; + + while (poll_count < max_polls) { + std::this_thread::sleep_for(std::chrono::milliseconds(poll_ms)); + + if (cancel_requested) { + cancel_job(job_id); + poll_result.cancelled_by_callback = true; + poll_result.error_message = "Cancelled by incumbent callback"; + return poll_result; + } + + if (config_.incumbent_callback) { + auto now = std::chrono::steady_clock::now(); + auto ms_since_last = + std::chrono::duration_cast(now - last_incumbent_poll).count(); + if (ms_since_last >= config_.incumbent_poll_interval_ms) { + auto inc_result = get_incumbents(job_id, incumbent_next_index, 0); + if (inc_result.success) { + for (const auto& inc : inc_result.incumbents) { + bool should_continue = + config_.incumbent_callback(inc.index, inc.objective, inc.assignment); + if (!should_continue) { + cancel_requested = true; + break; + } + } + incumbent_next_index = inc_result.next_index; + } + last_incumbent_poll = now; + } + } + + auto status_result = check_status(job_id); + if (!status_result.success) { + poll_result.error_message = status_result.error_message; + return poll_result; + } + + switch (status_result.status) { + case job_status_t::COMPLETED: poll_result.completed = true; break; + case job_status_t::FAILED: + poll_result.error_message = "Job failed: " + status_result.message; + return poll_result; + case job_status_t::CANCELLED: + poll_result.error_message = "Job was cancelled"; + return poll_result; + default: break; + } + + if (poll_result.completed) break; + poll_count++; + } + + // Drain any incumbents that arrived between the 
last poll and job completion. + if (config_.incumbent_callback && poll_result.completed) { + auto inc_result = get_incumbents(job_id, incumbent_next_index, 0); + if (inc_result.success) { + for (const auto& inc : inc_result.incumbents) { + config_.incumbent_callback(inc.index, inc.objective, inc.assignment); + } + } + } + + if (!poll_result.completed && poll_result.error_message.empty()) { + poll_result.error_message = "Timeout waiting for job completion"; + } + + return poll_result; +} + +// ============================================================================= +// End-to-end solve helpers +// ============================================================================= + +template +remote_lp_result_t grpc_client_t::solve_lp( + const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings) +{ + auto solve_t0 = std::chrono::steady_clock::now(); + + auto sub = submit_lp(problem, settings); + if (!sub.success) { return {.error_message = sub.error_message}; } + + start_log_streaming(sub.job_id); + auto poll = poll_for_completion(sub.job_id); + stop_log_streaming(); + + if (!poll.completed) { return {.error_message = poll.error_message}; } + + auto result = get_lp_result(sub.job_id); + if (result.success) { delete_job(sub.job_id); } + + GRPC_CLIENT_THROUGHPUT_LOG(config_, "end_to_end_lp", 0, solve_t0); + + return result; +} + +template +remote_mip_result_t grpc_client_t::solve_mip( + const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents) +{ + auto solve_t0 = std::chrono::steady_clock::now(); + + bool track_incumbents = enable_incumbents || (config_.incumbent_callback != nullptr); + + auto sub = submit_mip(problem, settings, track_incumbents); + if (!sub.success) { return {.error_message = sub.error_message}; } + + start_log_streaming(sub.job_id); + auto poll = poll_for_completion(sub.job_id); + stop_log_streaming(); + + if (!poll.completed) { return {.error_message = 
poll.error_message}; } + + auto result = get_mip_result(sub.job_id); + if (result.success) { delete_job(sub.job_id); } + + GRPC_CLIENT_THROUGHPUT_LOG(config_, "end_to_end_mip", 0, solve_t0); + + return result; +} + +// ============================================================================= +// Chunked Transfer utils (upload and download) +// ============================================================================= + +template +bool grpc_client_t::upload_chunked_arrays(const cpu_optimization_problem_t& problem, + const cuopt::remote::ChunkedProblemHeader& header, + std::string& job_id_out) +{ + job_id_out.clear(); + auto upload_t0 = std::chrono::steady_clock::now(); + + // --- 1. StartChunkedUpload --- + std::string upload_id; + { + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::StartChunkedUploadRequest request; + *request.mutable_problem_header() = header; + + cuopt::remote::StartChunkedUploadResponse response; + auto status = impl_->stub->StartChunkedUpload(&context, request, &response); + + if (!status.ok()) { + last_error_ = "StartChunkedUpload failed: " + status.error_message(); + return false; + } + + upload_id = response.upload_id(); + if (response.max_message_bytes() > 0) { + server_max_message_bytes_.store(response.max_message_bytes(), std::memory_order_relaxed); + } + } + + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] ChunkedUpload started, upload_id=" << upload_id); + + // --- 2. Build chunk requests directly from problem arrays --- + int64_t chunk_data_budget = config_.chunk_size_bytes; + if (chunk_data_budget <= 0) { chunk_data_budget = 1LL * 1024 * 1024; } + int64_t srv_max = server_max_message_bytes_.load(std::memory_order_relaxed); + if (srv_max > 0 && chunk_data_budget > srv_max * 9 / 10) { chunk_data_budget = srv_max * 9 / 10; } + + auto chunk_requests = build_array_chunk_requests(problem, upload_id, chunk_data_budget); + + // --- 3. 
Send each chunk request --- + int total_chunks = 0; + int64_t total_bytes_sent = 0; + + for (auto& chunk_request : chunk_requests) { + grpc::ClientContext chunk_context; + set_rpc_deadline(chunk_context, kDefaultRpcTimeoutSeconds); + cuopt::remote::SendArrayChunkResponse chunk_response; + auto status = impl_->stub->SendArrayChunk(&chunk_context, chunk_request, &chunk_response); + + if (!status.ok()) { + last_error_ = "SendArrayChunk failed: " + status.error_message(); + return false; + } + + total_bytes_sent += chunk_request.chunk().data().size(); + ++total_chunks; + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] ChunkedUpload sent " << total_chunks << " chunk requests"); + + // --- 4. FinishChunkedUpload --- + { + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::FinishChunkedUploadRequest request; + request.set_upload_id(upload_id); + + cuopt::remote::SubmitJobResponse response; + auto status = impl_->stub->FinishChunkedUpload(&context, request, &response); + + if (!status.ok()) { + last_error_ = "FinishChunkedUpload failed: " + status.error_message(); + return false; + } + + job_id_out = response.job_id(); + } + + GRPC_CLIENT_THROUGHPUT_LOG(config_, "upload_chunked", total_bytes_sent, upload_t0); + GRPC_CLIENT_DEBUG_LOG( + config_, + "[grpc_client] ChunkedUpload complete: " << total_chunks << " chunks, job_id=" << job_id_out); + return true; +} + +bool grpc_client_t::get_result_or_download(const std::string& job_id, + downloaded_result_t& result_out) +{ + result_out = downloaded_result_t{}; + + if (!impl_->stub) { + last_error_ = "Not connected to server"; + return false; + } + + int64_t result_size_hint = 0; + { + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + auto request = build_status_request(job_id); + cuopt::remote::StatusResponse response; + auto status = impl_->stub->CheckStatus(&context, request, &response); + + if (status.ok()) { + result_size_hint = 
response.result_size_bytes(); + if (response.max_message_bytes() > 0) { + server_max_message_bytes_.store(response.max_message_bytes(), std::memory_order_relaxed); + } + } + } + + int64_t srv_max_msg = server_max_message_bytes_.load(std::memory_order_relaxed); + int64_t effective_max = config_.max_message_bytes; + if (srv_max_msg > 0 && srv_max_msg < effective_max) { effective_max = srv_max_msg; } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] get_result_or_download: result_size_hint=" + << result_size_hint << " bytes, client_max=" << config_.max_message_bytes + << ", server_max=" << srv_max_msg << ", effective_max=" << effective_max); + + if (result_size_hint > 0 && effective_max > 0 && result_size_hint > effective_max) { + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] Using chunked download directly (result_size_hint=" + << result_size_hint << " > effective_max=" << effective_max << ")"); + return download_chunked_result(job_id, result_out); + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] Attempting unary GetResult (result_size_hint=" + << result_size_hint << " <= effective_max=" << effective_max << ")"); + + auto download_t0 = std::chrono::steady_clock::now(); + + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + auto request = build_get_result_request(job_id); + auto response = std::make_unique(); + auto status = impl_->stub->GetResult(&context, request, response.get()); + + if (status.ok() && response->status() == cuopt::remote::SUCCESS) { + if (response->has_lp_solution() || response->has_mip_solution()) { + GRPC_CLIENT_THROUGHPUT_LOG(config_, "download_unary", response->ByteSizeLong(), download_t0); + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] Unary GetResult succeeded, result_size=" + << response->ByteSizeLong() << " bytes"); + result_out.was_chunked = false; + result_out.response = std::move(response); + return true; + } + last_error_ = "GetResult succeeded but no solution in response"; + return 
false; + } + + if (status.error_code() == grpc::StatusCode::RESOURCE_EXHAUSTED) { + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] GetResult rejected (RESOURCE_EXHAUSTED), " + "falling back to chunked download"); + return download_chunked_result(job_id, result_out); + } + + if (!status.ok()) { + last_error_ = "GetResult failed: " + status.error_message(); + } else if (response->status() != cuopt::remote::SUCCESS) { + last_error_ = "GetResult indicates failure: " + response->error_message(); + } + return false; +} + +bool grpc_client_t::download_chunked_result(const std::string& job_id, + downloaded_result_t& result_out) +{ + result_out.was_chunked = true; + result_out.chunked_arrays.clear(); + auto download_t0 = std::chrono::steady_clock::now(); + + GRPC_CLIENT_DEBUG_LOG(config_, "[grpc_client] Starting chunked download for job " << job_id); + + // --- 1. StartChunkedDownload --- + std::string download_id; + auto header = std::make_unique(); + { + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::StartChunkedDownloadRequest request; + request.set_job_id(job_id); + + cuopt::remote::StartChunkedDownloadResponse response; + auto status = impl_->stub->StartChunkedDownload(&context, request, &response); + + if (!status.ok()) { + last_error_ = "StartChunkedDownload failed: " + status.error_message(); + return false; + } + + download_id = response.download_id(); + *header = response.header(); + if (response.max_message_bytes() > 0) { + server_max_message_bytes_.store(response.max_message_bytes(), std::memory_order_relaxed); + } + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] ChunkedDownload started, download_id=" + << download_id << " arrays=" << header->arrays_size() + << " is_mip=" << header->is_mip()); + + // --- 2. 
Fetch each array via GetResultChunk RPCs --- + int64_t chunk_data_budget = config_.chunk_size_bytes; + if (chunk_data_budget <= 0) { chunk_data_budget = 1LL * 1024 * 1024; } + int64_t dl_srv_max = server_max_message_bytes_.load(std::memory_order_relaxed); + if (dl_srv_max > 0 && chunk_data_budget > dl_srv_max * 9 / 10) { + chunk_data_budget = dl_srv_max * 9 / 10; + } + + int total_chunks = 0; + int64_t total_bytes_received = 0; + + for (const auto& arr_desc : header->arrays()) { + auto field_id = arr_desc.field_id(); + int64_t total_elems = arr_desc.total_elements(); + int64_t elem_size = arr_desc.element_size_bytes(); + if (total_elems <= 0) continue; + + if (elem_size <= 0) { + last_error_ = "Invalid chunk metadata: non-positive element_size_bytes for field " + + std::to_string(field_id); + return false; + } + // Guard against total_elems * elem_size overflowing int64_t (both are + // positive at this point, so dividing INT64_MAX is safe and avoids the + // signed/unsigned pitfall of casting SIZE_MAX to int64_t). 
+ if (total_elems > std::numeric_limits::max() / elem_size) { + last_error_ = + "Invalid chunk metadata: total byte size overflow for field " + std::to_string(field_id); + return false; + } + + int64_t elems_per_chunk = chunk_data_budget / elem_size; + if (elems_per_chunk <= 0) elems_per_chunk = 1; + + std::vector array_bytes(static_cast(total_elems * elem_size)); + + for (int64_t elem_offset = 0; elem_offset < total_elems; elem_offset += elems_per_chunk) { + int64_t elems_wanted = std::min(elems_per_chunk, total_elems - elem_offset); + + grpc::ClientContext chunk_ctx; + set_rpc_deadline(chunk_ctx, kDefaultRpcTimeoutSeconds); + cuopt::remote::GetResultChunkRequest chunk_req; + chunk_req.set_download_id(download_id); + chunk_req.set_field_id(field_id); + chunk_req.set_element_offset(elem_offset); + chunk_req.set_max_elements(elems_wanted); + + cuopt::remote::GetResultChunkResponse chunk_resp; + auto status = impl_->stub->GetResultChunk(&chunk_ctx, chunk_req, &chunk_resp); + + if (!status.ok()) { + last_error_ = "GetResultChunk failed: " + status.error_message(); + return false; + } + + int64_t elems_received = chunk_resp.elements_in_chunk(); + const auto& data = chunk_resp.data(); + + if (elems_received < 0 || elems_received > elems_wanted || + elems_received > total_elems - elem_offset) { + last_error_ = "GetResultChunk: invalid element count"; + return false; + } + if (static_cast(data.size()) != elems_received * elem_size) { + last_error_ = "GetResultChunk: data size mismatch"; + return false; + } + + std::memcpy(array_bytes.data() + elem_offset * elem_size, data.data(), data.size()); + total_bytes_received += static_cast(data.size()); + ++total_chunks; + } + + result_out.chunked_arrays[static_cast(field_id)] = std::move(array_bytes); + } + + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] ChunkedDownload fetched " + << total_chunks << " chunks for " << header->arrays_size() << " arrays"); + + // --- 3. 
FinishChunkedDownload --- + { + grpc::ClientContext context; + set_rpc_deadline(context, kDefaultRpcTimeoutSeconds); + cuopt::remote::FinishChunkedDownloadRequest request; + request.set_download_id(download_id); + + cuopt::remote::FinishChunkedDownloadResponse response; + auto status = impl_->stub->FinishChunkedDownload(&context, request, &response); + + if (!status.ok()) { + GRPC_CLIENT_DEBUG_LOG( + config_, "[grpc_client] FinishChunkedDownload warning: " << status.error_message()); + } + } + + result_out.chunked_header = std::move(header); + + GRPC_CLIENT_THROUGHPUT_LOG(config_, "download_chunked", total_bytes_received, download_t0); + GRPC_CLIENT_DEBUG_LOG(config_, + "[grpc_client] ChunkedDownload complete: " + << total_chunks << " chunks, " << total_bytes_received << " bytes"); + + return true; +} + +// Explicit template instantiations +#if CUOPT_INSTANTIATE_FLOAT +template remote_lp_result_t grpc_client_t::solve_lp( + const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings); +template remote_mip_result_t grpc_client_t::solve_mip( + const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +template submit_result_t grpc_client_t::submit_lp( + const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings); +template submit_result_t grpc_client_t::submit_mip( + const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +template remote_lp_result_t grpc_client_t::get_lp_result(const std::string& job_id); +template remote_mip_result_t grpc_client_t::get_mip_result( + const std::string& job_id); +template bool grpc_client_t::upload_chunked_arrays( + const cpu_optimization_problem_t& problem, + const cuopt::remote::ChunkedProblemHeader& header, + std::string& job_id_out); +#endif + +#if CUOPT_INSTANTIATE_DOUBLE +template remote_lp_result_t grpc_client_t::solve_lp( + const cpu_optimization_problem_t& problem, + 
const pdlp_solver_settings_t& settings); +template remote_mip_result_t grpc_client_t::solve_mip( + const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +template submit_result_t grpc_client_t::submit_lp( + const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings); +template submit_result_t grpc_client_t::submit_mip( + const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +template remote_lp_result_t grpc_client_t::get_lp_result( + const std::string& job_id); +template remote_mip_result_t grpc_client_t::get_mip_result( + const std::string& job_id); +template bool grpc_client_t::upload_chunked_arrays( + const cpu_optimization_problem_t& problem, + const cuopt::remote::ChunkedProblemHeader& header, + std::string& job_id_out); +#endif + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/client/grpc_client.hpp b/cpp/src/grpc/client/grpc_client.hpp new file mode 100644 index 0000000000..f8579b3271 --- /dev/null +++ b/cpp/src/grpc/client/grpc_client.hpp @@ -0,0 +1,483 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declarations for gRPC types (to avoid exposing gRPC headers in public API) +namespace grpc { +class Channel; +} + +namespace cuopt::remote { +class CuOptRemoteService; +class ChunkedProblemHeader; +class ChunkedResultHeader; +class ResultResponse; +class SubmitJobRequest; +} // namespace cuopt::remote + +namespace cuopt::linear_programming { + +// Forward declarations for test helper functions (implemented in grpc_client.cpp) +void grpc_test_inject_mock_stub(class grpc_client_t& client, std::shared_ptr stub); +void grpc_test_mark_as_connected(class grpc_client_t& client); + +/** + * @brief Configuration options for the gRPC client + * + * Large Problem Handling: + * - Small problems use unary SubmitJob (single message). If the estimated + * serialized size exceeds 75% of max_message_bytes, the client automatically + * switches to the chunked array protocol: StartChunkedUpload + N × SendArrayChunk + * + FinishChunkedUpload (all unary RPCs). This bypasses the protobuf 2GB limit and + * reduces peak memory usage. + * - chunk_size_bytes controls the payload size of individual SendArrayChunk calls. + * - Result retrieval uses chunked download for results exceeding max_message_bytes. + */ +struct grpc_client_config_t { + std::string server_address = "localhost:8765"; + int poll_interval_ms = 1000; // How often to poll for job status + int timeout_seconds = 0; // Max time to wait for job completion (0 = no limit) + bool stream_logs = false; // Whether to stream logs from server + std::function log_callback = nullptr; // Called for each log line + + // Incumbent callback for MIP solves — invoked each time the server finds a + // new best-feasible (incumbent) solution. Parameters: index, objective value, + // solution vector. 
Return true to continue solving, or false to request + // early termination (e.g. the objective is good enough for the caller's + // purposes). + std::function& solution)> + incumbent_callback = nullptr; + int incumbent_poll_interval_ms = 1000; // How often to poll for new incumbents + + // TLS configuration + bool enable_tls = false; + std::string tls_root_certs; // PEM-encoded root CA certificates (for verifying server) + std::string tls_client_cert; // PEM-encoded client certificate (for mTLS) + std::string tls_client_key; // PEM-encoded client private key (for mTLS) + + // gRPC max message size (used for unary SubmitJob and result download decisions). + // Clamped at construction to [4 MiB, 2 GiB - 1 MiB] (protobuf serialization limit). + int64_t max_message_bytes = 256LL * 1024 * 1024; // 256 MiB + + // Chunk size for chunked array upload and chunked result download. + int64_t chunk_size_bytes = 16LL * 1024 * 1024; // 16 MiB + + // gRPC keepalive — periodic HTTP/2 PINGs to detect dead connections. + int keepalive_time_ms = 30000; // send PING every 30s of inactivity + int keepalive_timeout_ms = 10000; // wait 10s for PONG before declaring dead + + // --- Test / debug options (not intended for production use) ----------------- + + // Receives internal client debug messages (for test verification). + std::function debug_log_callback = nullptr; + + // Enable debug / throughput logging to stderr. + // Controlled by CUOPT_GRPC_DEBUG env var (0|1). Default: off. + bool enable_debug_log = false; + + // Log FNV-1a hashes of uploaded/downloaded data on both client and server. + // Comparing the two hashes confirms data was not corrupted in transit. + bool enable_transfer_hash = false; + + // Override for the chunked upload threshold (bytes). Normally computed + // automatically as 75% of max_message_bytes. Set to 0 to force chunked + // upload for all problems, or a positive value to override. -1 = auto. 
+ int64_t chunked_array_threshold_bytes = -1; +}; + +/** + * @brief Job status enum (transport-agnostic) + */ +enum class job_status_t { QUEUED, PROCESSING, COMPLETED, FAILED, CANCELLED, NOT_FOUND }; + +/** + * @brief Convert job status to string + */ +inline const char* job_status_to_string(job_status_t status) +{ + switch (status) { + case job_status_t::QUEUED: return "QUEUED"; + case job_status_t::PROCESSING: return "PROCESSING"; + case job_status_t::COMPLETED: return "COMPLETED"; + case job_status_t::FAILED: return "FAILED"; + case job_status_t::CANCELLED: return "CANCELLED"; + case job_status_t::NOT_FOUND: return "NOT_FOUND"; + default: return "UNKNOWN"; + } +} + +/** + * @brief Result of a job status check + */ +struct job_status_result_t { + bool success = false; + std::string error_message; + job_status_t status = job_status_t::NOT_FOUND; + std::string message; + int64_t result_size_bytes = 0; +}; + +/** + * @brief Result of a submit operation (job ID) + */ +struct submit_result_t { + bool success = false; + std::string error_message; + std::string job_id; +}; + +/** + * @brief Result of a cancel operation + */ +struct cancel_result_t { + bool success = false; + std::string error_message; + job_status_t job_status = job_status_t::NOT_FOUND; + std::string message; +}; + +/** + * @brief Incumbent solution entry + */ +struct incumbent_t { + int64_t index = 0; + double objective = 0.0; + std::vector assignment; +}; + +/** + * @brief Result of get incumbents operation + */ +struct incumbents_result_t { + bool success = false; + std::string error_message; + std::vector incumbents; + int64_t next_index = 0; + bool job_complete = false; +}; + +/** + * @brief Result of a remote solve operation + */ +template +struct remote_lp_result_t { + bool success = false; + std::string error_message; + std::unique_ptr> solution; +}; + +template +struct remote_mip_result_t { + bool success = false; + std::string error_message; + std::unique_ptr> solution; +}; + +/** + * @brief 
gRPC client for remote cuOpt solving + * + * This class provides a high-level interface for submitting optimization problems + * to a remote cuopt_grpc_server and retrieving results. It handles: + * - Connection management + * - Job submission + * - Status polling + * - Optional log streaming + * - Result retrieval and parsing + * + * Usage: + * @code + * grpc_client_t client("localhost:8765"); + * if (!client.connect()) { ... handle error ... } + * + * auto result = client.solve_lp(problem, settings); + * if (result.success) { + * // Use result.solution + * } + * @endcode + * + * This class is designed to be used by: + * - Test clients for validation + * - solve_lp_remote() and solve_mip_remote() for production use + */ +class grpc_client_t { + // Allow test helpers to access internal implementation for mock injection + friend void grpc_test_inject_mock_stub(grpc_client_t&, std::shared_ptr); + friend void grpc_test_mark_as_connected(grpc_client_t&); + + public: + /** + * @brief Construct a gRPC client with configuration + * @param config Client configuration options + */ + explicit grpc_client_t(const grpc_client_config_t& config = grpc_client_config_t{}); + + /** + * @brief Construct a gRPC client with just server address + * @param server_address Server address in "host:port" format + */ + explicit grpc_client_t(const std::string& server_address); + + ~grpc_client_t(); + + // Non-copyable, non-movable (due to atomic member and thread) + grpc_client_t(const grpc_client_t&) = delete; + grpc_client_t& operator=(const grpc_client_t&) = delete; + grpc_client_t(grpc_client_t&&) = delete; + grpc_client_t& operator=(grpc_client_t&&) = delete; + + /** + * @brief Connect to the gRPC server + * @return true if connection successful + */ + bool connect(); + + /** + * @brief Check if connected to server + */ + bool is_connected() const; + + /** + * @brief Solve an LP problem remotely + * + * This is a blocking call that: + * 1. Submits the problem to the server + * 2. 
Polls for completion (with optional log streaming) + * 3. Retrieves and parses the result + * + * @param problem The CPU optimization problem to solve + * @param settings Solver settings + * @return Result containing success status and solution (if successful) + */ + template + remote_lp_result_t solve_lp(const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings); + + /** + * @brief Solve a MIP problem remotely + * + * This is a blocking call that: + * 1. Submits the problem to the server + * 2. Polls for completion (with optional log streaming) + * 3. Retrieves and parses the result + * + * @param problem The CPU optimization problem to solve + * @param settings Solver settings + * @param enable_incumbents Whether to enable incumbent solution streaming + * @return Result containing success status and solution (if successful) + */ + template + remote_mip_result_t solve_mip(const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents = false); + + // ========================================================================= + // Async Operations (for manual job management) + // ========================================================================= + + /** + * @brief Submit an LP problem without waiting for result + * @return Result containing job_id if successful + */ + template + submit_result_t submit_lp(const cpu_optimization_problem_t& problem, + const pdlp_solver_settings_t& settings); + + /** + * @brief Submit a MIP problem without waiting for result + * @return Result containing job_id if successful + */ + template + submit_result_t submit_mip(const cpu_optimization_problem_t& problem, + const mip_solver_settings_t& settings, + bool enable_incumbents = false); + + /** + * @brief Check status of a submitted job + * @param job_id The job ID to check + * @return Status result including job state and optional result size + */ + job_status_result_t check_status(const std::string& job_id); 
+ + /** + * @brief Wait for a job to complete (blocking) + * + * This is more efficient than polling check_status() but does not + * return the result - call get_lp_result/get_mip_result afterward. + * + * @param job_id The job ID to wait for + * @return Status result when job completes (COMPLETED, FAILED, or CANCELLED) + */ + job_status_result_t wait_for_completion(const std::string& job_id); + + /** + * @brief Get LP result for a completed job + * @param job_id The job ID + * @return Result containing solution if successful + */ + template + remote_lp_result_t get_lp_result(const std::string& job_id); + + /** + * @brief Get MIP result for a completed job + * @param job_id The job ID + * @return Result containing solution if successful + */ + template + remote_mip_result_t get_mip_result(const std::string& job_id); + + /** + * @brief Cancel a running job + * @param job_id The job ID to cancel + * @return Cancel result with status + */ + cancel_result_t cancel_job(const std::string& job_id); + + /** + * @brief Delete a job and its results from server + * @param job_id The job ID to delete + * @return true if deletion successful + */ + bool delete_job(const std::string& job_id); + + /** + * @brief Get incumbent solutions for a MIP job + * @param job_id The job ID + * @param from_index Start from this incumbent index + * @param max_count Maximum number to return (0 = no limit) + * @return Incumbents result + */ + incumbents_result_t get_incumbents(const std::string& job_id, + int64_t from_index = 0, + int32_t max_count = 0); + + /** + * @brief Get the last error message + */ + const std::string& get_last_error() const { return last_error_; } + + // --- Test / debug public API ----------------------------------------------- + + /** + * @brief Stream logs for a job (blocking until job completes or callback returns false). + * + * This is a low-level, synchronous API for test tools and CLI utilities. 
+ * Production callers should use config_.stream_logs + config_.log_callback + * instead, which streams logs automatically on a background thread during + * solve_lp / solve_mip calls. + * + * @param job_id The job ID + * @param from_byte Starting byte offset in log + * @param callback Called for each log line; return false to stop streaming + * @return true if streaming completed normally + */ + bool stream_logs(const std::string& job_id, + int64_t from_byte, + std::function callback); + + private: + struct impl_t; + std::unique_ptr impl_; + + grpc_client_config_t config_; + std::string last_error_; + + // Track server-reported max message size (may differ from our config). + // Accessed from multiple RPC methods; atomic to avoid data races. + std::atomic server_max_message_bytes_{0}; + + // 75% of max_message_bytes — computed at construction time. + int64_t chunked_array_threshold_bytes_ = 0; + + // Background log streaming for solve_lp / solve_mip (production path). + // Activated when config_.stream_logs is true and config_.log_callback is set. + void start_log_streaming(const std::string& job_id); + void stop_log_streaming(); + + // Shared polling loop used by solve_lp and solve_mip. + struct poll_result_t { + bool completed = false; + bool cancelled_by_callback = false; + std::string error_message; + }; + poll_result_t poll_for_completion(const std::string& job_id); + + std::unique_ptr log_thread_; + std::atomic stop_logs_{false}; + mutable std::mutex log_context_mutex_; + // Points to the grpc::ClientContext* of the in-flight StreamLogs RPC (if + // any). Typed as void* to avoid exposing grpc headers in the public API. + // Protected by log_context_mutex_; stop_log_streaming() calls TryCancel() + // through this pointer to unblock a stuck reader->Read(). 
+ void* active_log_context_ = nullptr; + + // ========================================================================= + // Result Retrieval Support + // ========================================================================= + + /** + * @brief Result from get_result_or_download: either a unary ResultResponse or + * a chunked header + raw arrays map. Exactly one variant is populated. + */ + struct downloaded_result_t { + bool was_chunked = false; + + // Populated when was_chunked == false (unary path). + std::unique_ptr response; + + // Populated when was_chunked == true (chunked path). + std::unique_ptr chunked_header; + std::map> chunked_arrays; + }; + + /** + * @brief Get result, choosing unary GetResult or chunked download based on size. + * + * Returns a downloaded_result_t with either the unary ResultResponse or the + * chunked header + arrays map populated. + */ + bool get_result_or_download(const std::string& job_id, downloaded_result_t& result_out); + + /** + * @brief Download result via chunked unary RPCs (StartChunkedDownload + + * N × GetResultChunk + FinishChunkedDownload). + */ + bool download_chunked_result(const std::string& job_id, downloaded_result_t& result_out); + + // ========================================================================= + // Chunked Array Upload (for large problems) + // ========================================================================= + + /** + * @brief Submit a SubmitJobRequest via unary SubmitJob RPC. + */ + bool submit_unary(const cuopt::remote::SubmitJobRequest& request, std::string& job_id_out); + + /** + * @brief Upload a problem using chunked array RPCs (StartChunkedUpload + + * N × SendArrayChunk + FinishChunkedUpload). 
+ */ + template + bool upload_chunked_arrays(const cpu_optimization_problem_t& problem, + const cuopt::remote::ChunkedProblemHeader& header, + std::string& job_id_out); +}; + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/client/solve_remote.cpp b/cpp/src/grpc/client/solve_remote.cpp new file mode 100644 index 0000000000..859def795e --- /dev/null +++ b/cpp/src/grpc/client/solve_remote.cpp @@ -0,0 +1,261 @@ +/* clang-format off */ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +/* clang-format on */ + +#include +#include +#include +#include +#include +#include "grpc_client.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace cuopt::linear_programming { + +// Buffer added to the solver's time_limit to account for worker startup, +// GPU init, and result pipe transfer. +constexpr int kTimeoutBufferSeconds = 120; + +// ============================================================================ +// Helper function to get gRPC server address from environment variables +// ============================================================================ + +static std::string get_grpc_server_address() +{ + const char* host = std::getenv("CUOPT_REMOTE_HOST"); + const char* port = std::getenv("CUOPT_REMOTE_PORT"); + + if (host == nullptr || port == nullptr) { + throw std::runtime_error( + "Remote execution enabled but CUOPT_REMOTE_HOST and/or CUOPT_REMOTE_PORT not set"); + } + + return std::string(host) + ":" + std::string(port); +} + +static int64_t parse_env_int64(const char* name, int64_t default_value) +{ + const char* val = std::getenv(name); + if (val == nullptr) return default_value; + try { + return std::stoll(val); + } catch (...) { + return default_value; + } +} + +// Derive client-side polling timeout from the solver's time_limit. +// Returns 0 (no limit) when the solver has no finite time_limit. 
+template +static int solver_timeout_seconds(f_t time_limit) +{ + if (!std::isfinite(static_cast(time_limit)) || time_limit <= 0) { return 0; } + double secs = static_cast(time_limit) + kTimeoutBufferSeconds; + if (secs > static_cast(std::numeric_limits::max())) { return 0; } + return static_cast(std::ceil(secs)); +} + +static std::string read_pem_file(const char* path) +{ + std::ifstream in(path, std::ios::binary); + if (!in.is_open()) { throw std::runtime_error(std::string("Cannot open TLS file: ") + path); } + std::ostringstream ss; + ss << in.rdbuf(); + return ss.str(); +} + +static const char* get_env(const char* name) +{ + const char* v = std::getenv(name); + return (v && v[0] != '\0') ? v : nullptr; +} + +// Apply env-var overrides for transfer, debug, and TLS configuration. +static void apply_env_overrides(grpc_client_config_t& config) +{ + constexpr int64_t kMinChunkSize = 4096; + constexpr int64_t kMaxChunkSize = 2LL * 1024 * 1024 * 1024; // 2 GiB + constexpr int64_t kMinMessageSize = 4096; + constexpr int64_t kMaxMessageSize = 2LL * 1024 * 1024 * 1024; + + auto chunk = parse_env_int64("CUOPT_CHUNK_SIZE", config.chunk_size_bytes); + if (chunk >= kMinChunkSize && chunk <= kMaxChunkSize) { config.chunk_size_bytes = chunk; } + + auto msg = parse_env_int64("CUOPT_MAX_MESSAGE_BYTES", config.max_message_bytes); + if (msg >= kMinMessageSize && msg <= kMaxMessageSize) { config.max_message_bytes = msg; } + + config.enable_debug_log = (parse_env_int64("CUOPT_GRPC_DEBUG", 0) != 0); + + // TLS configuration from environment variables + if (parse_env_int64("CUOPT_TLS_ENABLED", 0) != 0) { + config.enable_tls = true; + + const char* root_cert = get_env("CUOPT_TLS_ROOT_CERT"); + if (root_cert) { config.tls_root_certs = read_pem_file(root_cert); } + + const char* client_cert = get_env("CUOPT_TLS_CLIENT_CERT"); + const char* client_key = get_env("CUOPT_TLS_CLIENT_KEY"); + if (client_cert && client_key) { + config.tls_client_cert = read_pem_file(client_cert); + 
config.tls_client_key = read_pem_file(client_key); + } + } + + CUOPT_LOG_INFO("gRPC client config: chunk_size=%lld max_message=%lld tls=%s", + static_cast(config.chunk_size_bytes), + static_cast(config.max_message_bytes), + config.enable_tls ? "on" : "off"); +} + +// ============================================================================ +// Remote execution via gRPC +// ============================================================================ + +template +std::unique_ptr> solve_lp_remote( + cpu_optimization_problem_t const& cpu_problem, + pdlp_solver_settings_t const& settings, + bool problem_checking, + bool use_pdlp_solver_mode) +{ + init_logger_t log(settings.log_file, settings.log_to_console); + + CUOPT_LOG_INFO("solve_lp_remote (CPU problem) - connecting to gRPC server"); + + // Build gRPC client configuration + grpc_client_config_t config; + config.server_address = get_grpc_server_address(); + config.timeout_seconds = solver_timeout_seconds(settings.time_limit); + apply_env_overrides(config); + + // Configure log streaming based on settings + if (settings.log_to_console) { + config.stream_logs = true; + config.log_callback = [](const std::string& line) { std::cout << line << std::endl; }; + } + + // Create client and connect + grpc_client_t client(config); + if (!client.connect()) { + throw std::runtime_error("Failed to connect to gRPC server: " + client.get_last_error()); + } + + CUOPT_LOG_INFO("solve_lp_remote - connected to %s, submitting problem (timeout=%ds)", + config.server_address.c_str(), + config.timeout_seconds); + + // Call the remote solver + auto result = client.solve_lp(cpu_problem, settings); + + if (!result.success) { + throw std::runtime_error("Remote LP solve failed: " + result.error_message); + } + + CUOPT_LOG_INFO("solve_lp_remote - solve completed successfully"); + + return std::move(result.solution); +} + +template +std::unique_ptr> solve_mip_remote( + cpu_optimization_problem_t const& cpu_problem, + mip_solver_settings_t 
const& settings) +{ + init_logger_t log(settings.log_file, settings.log_to_console); + + CUOPT_LOG_INFO("solve_mip_remote (CPU problem) - connecting to gRPC server"); + + // Build gRPC client configuration + grpc_client_config_t config; + config.server_address = get_grpc_server_address(); + config.timeout_seconds = solver_timeout_seconds(settings.time_limit); + apply_env_overrides(config); + + // Configure log streaming based on settings + if (settings.log_to_console) { + config.stream_logs = true; + config.log_callback = [](const std::string& line) { std::cout << line << std::endl; }; + } + + // Check if user has set incumbent callbacks + auto mip_callbacks = settings.get_mip_callbacks(); + bool has_incumbents = !mip_callbacks.empty(); + bool enable_tracking = has_incumbents; + + // Initialize callbacks with problem size (needed for Python callbacks to work correctly) + // The local MIP solver does this in solve.cu, but for remote solves we need to do it here + if (has_incumbents) { + size_t n_vars = cpu_problem.get_n_variables(); + for (auto* callback : mip_callbacks) { + if (callback != nullptr) { callback->template setup(n_vars); } + } + } + + // Set up incumbent callback forwarding + if (has_incumbents) { + CUOPT_LOG_INFO("solve_mip_remote - setting up inline incumbent callback forwarding"); + config.incumbent_callback = [&mip_callbacks](int64_t index, + double objective, + const std::vector& solution) -> bool { + // Forward incumbent to all user callbacks (invoked from main thread with GIL) + for (auto* callback : mip_callbacks) { + if (callback != nullptr && + callback->get_type() == internals::base_solution_callback_type::GET_SOLUTION) { + auto* get_callback = static_cast(callback); + // Copy solution to non-const buffer for callback interface + std::vector solution_copy = solution; + double obj_copy = objective; + double bound_copy = std::numeric_limits::quiet_NaN(); + get_callback->get_solution( + solution_copy.data(), &obj_copy, &bound_copy, 
callback->get_user_data()); + } + } + return true; // Continue solving + }; + } + + // Create client and connect + grpc_client_t client(config); + if (!client.connect()) { + throw std::runtime_error("Failed to connect to gRPC server: " + client.get_last_error()); + } + + CUOPT_LOG_INFO( + "solve_mip_remote - connected to %s, submitting problem (incumbents=%s, timeout=%ds)", + config.server_address.c_str(), + enable_tracking ? "enabled" : "disabled", + config.timeout_seconds); + + // Call the remote solver + auto result = client.solve_mip(cpu_problem, settings, enable_tracking); + + if (!result.success) { + throw std::runtime_error("Remote MIP solve failed: " + result.error_message); + } + + CUOPT_LOG_INFO("solve_mip_remote - solve completed successfully"); + + return std::move(result.solution); +} + +// Explicit template instantiations for remote execution stubs +template std::unique_ptr> solve_lp_remote( + cpu_optimization_problem_t const&, + pdlp_solver_settings_t const&, + bool, + bool); + +template std::unique_ptr> solve_mip_remote( + cpu_optimization_problem_t const&, mip_solver_settings_t const&); + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/cuopt_remote.proto b/cpp/src/grpc/cuopt_remote.proto new file mode 100644 index 0000000000..1ce26af191 --- /dev/null +++ b/cpp/src/grpc/cuopt_remote.proto @@ -0,0 +1,348 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package cuopt.remote; + +// Protocol version and metadata +message RequestHeader { + uint32 version = 1; // Protocol version (currently 1) + ProblemType problem_type = 2; // LP or MIP +} + +enum ProblemType { + LP = 0; + MIP = 1; +} + +// Optimization problem representation (field names match cpu_optimization_problem_t) +message OptimizationProblem { + // Problem metadata + string problem_name = 1; + string objective_name = 2; + bool maximize = 3; + double objective_scaling_factor = 4; + double objective_offset = 5; + + // Variable and row names (optional) + repeated string variable_names = 6; + repeated string row_names = 7; + + // Constraint matrix A in CSR format + repeated double A = 8; + repeated int32 A_indices = 9; + repeated int32 A_offsets = 10; + + // Problem vectors + repeated double c = 11; // objective coefficients + repeated double b = 12; // constraint bounds (RHS) + repeated double variable_lower_bounds = 13; + repeated double variable_upper_bounds = 14; + + // Constraint bounds (alternative to b + row_types) + repeated double constraint_lower_bounds = 15; + repeated double constraint_upper_bounds = 16; + bytes row_types = 17; // char array: 'E' (=), 'L' (<=), 'G' (>=), 'N' (objective) + + // Variable types + bytes variable_types = 18; // char array: 'C' (continuous), 'I' (integer), 'B' (binary) + + // Initial solutions + repeated double initial_primal_solution = 19; + repeated double initial_dual_solution = 20; + + // Quadratic objective matrix Q in CSR format + repeated double Q_values = 21; + repeated int32 Q_indices = 22; + repeated int32 Q_offsets = 23; +} + +// PDLP solver mode enum (matches cuOpt pdlp_solver_mode_t) +// Matches cuOpt pdlp_solver_mode_t enum values +enum PDLPSolverMode { + Stable1 = 0; + Stable2 = 1; + Methodical1 = 2; + Fast1 = 3; + Stable3 = 4; +} + +// Matches cuOpt method_t enum values +enum LPMethod { + Concurrent = 0; + PDLP = 1; + DualSimplex = 2; + 
Barrier = 3; +} + +// PDLP solver settings (field names match cuOpt Python/C++ API) +message PDLPSolverSettings { + // Termination tolerances + double absolute_gap_tolerance = 1; + double relative_gap_tolerance = 2; + double primal_infeasible_tolerance = 3; + double dual_infeasible_tolerance = 4; + double absolute_dual_tolerance = 5; + double relative_dual_tolerance = 6; + double absolute_primal_tolerance = 7; + double relative_primal_tolerance = 8; + + // Limits + double time_limit = 10; + // Iteration limit. Sentinel: set to -1 to mean "unset/use server defaults". + // Note: proto3 numeric fields default to 0 when omitted, so clients should + // explicitly use -1 (or a positive value) to avoid accidentally requesting 0 iterations. + int64 iteration_limit = 11; + + // Solver configuration + bool log_to_console = 20; + bool detect_infeasibility = 21; + bool strict_infeasibility = 22; + PDLPSolverMode pdlp_solver_mode = 23; + LPMethod method = 24; + int32 presolver = 25; + bool dual_postsolve = 26; + bool crossover = 27; + int32 num_gpus = 28; + + bool per_constraint_residual = 30; + bool cudss_deterministic = 31; + int32 folding = 32; + int32 augmented = 33; + int32 dualize = 34; + int32 ordering = 35; + int32 barrier_dual_initial_point = 36; + bool eliminate_dense_columns = 37; + bool save_best_primal_so_far = 38; + bool first_primal_feasible = 39; + int32 pdlp_precision = 40; + + // Warm start data (if provided) + PDLPWarmStartData warm_start_data = 50; +} + +message PDLPWarmStartData { + repeated double current_primal_solution = 1; + repeated double current_dual_solution = 2; + repeated double initial_primal_average = 3; + repeated double initial_dual_average = 4; + repeated double current_ATY = 5; + repeated double sum_primal_solutions = 6; + repeated double sum_dual_solutions = 7; + repeated double last_restart_duality_gap_primal_solution = 8; + repeated double last_restart_duality_gap_dual_solution = 9; + + double initial_primal_weight = 10; + double 
initial_step_size = 11;
  int32 total_pdlp_iterations = 12;
  int32 total_pdhg_iterations = 13;
  double last_candidate_kkt_score = 14;
  double last_restart_kkt_score = 15;
  double sum_solution_weight = 16;
  int32 iterations_since_last_restart = 17;
}

// MIP solver settings (field names match cuOpt Python/C++ API)
message MIPSolverSettings {
  // Limits
  double time_limit = 1;

  // Tolerances
  double relative_mip_gap = 2;
  double absolute_mip_gap = 3;
  double integrality_tolerance = 4;
  double absolute_tolerance = 5;
  double relative_tolerance = 6;
  double presolve_absolute_tolerance = 7;

  // Solver configuration
  bool log_to_console = 10;
  bool heuristics_only = 11;
  int32 num_cpu_threads = 12;
  int32 num_gpus = 13;
  int32 presolver = 14;
  bool mip_scaling = 15;
}

// LP solve request
message SolveLPRequest {
  RequestHeader header = 1;
  OptimizationProblem problem = 2;
  PDLPSolverSettings settings = 3;
}

// MIP solve request
message SolveMIPRequest {
  RequestHeader header = 1;
  OptimizationProblem problem = 2;
  MIPSolverSettings settings = 3;
  optional bool enable_incumbents = 4;
}

// LP solution
message LPSolution {
  // Solution vectors
  repeated double primal_solution = 1;
  repeated double dual_solution = 2;
  repeated double reduced_cost = 3;

  // Warm start data for next solve
  PDLPWarmStartData warm_start_data = 4;

  // Termination information
  PDLPTerminationStatus termination_status = 10;
  string error_message = 11;

  // Solution statistics
  double l2_primal_residual = 20;
  double l2_dual_residual = 21;
  double primal_objective = 22;
  double dual_objective = 23;
  double gap = 24;
  int32 nb_iterations = 25;
  double solve_time = 26;
  bool solved_by_pdlp = 27;
}

enum PDLPTerminationStatus {
  PDLP_NO_TERMINATION = 0;
  PDLP_NUMERICAL_ERROR = 1;
  PDLP_OPTIMAL = 2;
  PDLP_PRIMAL_INFEASIBLE = 3;
  PDLP_DUAL_INFEASIBLE = 4;
  PDLP_ITERATION_LIMIT = 5;
  PDLP_TIME_LIMIT = 6;
  PDLP_CONCURRENT_LIMIT = 7;
  PDLP_PRIMAL_FEASIBLE = 8;
}

// MIP solution
message MIPSolution {
  repeated double solution = 1;

  MIPTerminationStatus termination_status = 10;
  string error_message = 11;

  double objective = 20;
  double mip_gap = 21;
  double solution_bound = 22;
  double total_solve_time = 23;
  double presolve_time = 24;
  double max_constraint_violation = 25;
  double max_int_violation = 26;
  double max_variable_bound_violation = 27;
  int32 nodes = 28;
  int32 simplex_iterations = 29;
}

enum MIPTerminationStatus {
  MIP_NO_TERMINATION = 0;
  MIP_OPTIMAL = 1;
  MIP_FEASIBLE_FOUND = 2;
  MIP_INFEASIBLE = 3;
  MIP_UNBOUNDED = 4;
  MIP_TIME_LIMIT = 5;
  MIP_WORK_LIMIT = 6;
}

// Array field identifiers for chunked array transfers
// Used to identify which problem array a chunk belongs to
enum ArrayFieldId {
  FIELD_A_VALUES = 0;
  FIELD_A_INDICES = 1;
  FIELD_A_OFFSETS = 2;
  FIELD_C = 3;
  FIELD_B = 4;
  FIELD_VARIABLE_LOWER_BOUNDS = 5;
  FIELD_VARIABLE_UPPER_BOUNDS = 6;
  FIELD_CONSTRAINT_LOWER_BOUNDS = 7;
  FIELD_CONSTRAINT_UPPER_BOUNDS = 8;
  FIELD_ROW_TYPES = 9;
  FIELD_VARIABLE_TYPES = 10;
  FIELD_Q_VALUES = 11;
  FIELD_Q_INDICES = 12;
  FIELD_Q_OFFSETS = 13;
  FIELD_INITIAL_PRIMAL = 14;
  FIELD_INITIAL_DUAL = 15;
  // String arrays (null-separated bytes, sent as chunks alongside numeric data)
  FIELD_VARIABLE_NAMES = 20;
  FIELD_ROW_NAMES = 21;
}

// Result array field identifiers for chunked result downloads
// Used to identify which result array a chunk belongs to
enum ResultFieldId {
  RESULT_PRIMAL_SOLUTION = 0;
  RESULT_DUAL_SOLUTION = 1;
  RESULT_REDUCED_COST = 2;
  RESULT_MIP_SOLUTION = 3;
  // Warm start arrays (LP only)
  RESULT_WS_CURRENT_PRIMAL = 10;
  RESULT_WS_CURRENT_DUAL = 11;
  RESULT_WS_INITIAL_PRIMAL_AVG = 12;
  RESULT_WS_INITIAL_DUAL_AVG = 13;
  RESULT_WS_CURRENT_ATY = 14;
  RESULT_WS_SUM_PRIMAL = 15;
  RESULT_WS_SUM_DUAL = 16;
  RESULT_WS_LAST_RESTART_GAP_PRIMAL = 17;
  RESULT_WS_LAST_RESTART_GAP_DUAL = 18;
}

// Job status for async operations
enum JobStatus {
  QUEUED = 0;      // Job submitted, waiting in queue
  PROCESSING = 1;  // Job currently being solved
  COMPLETED = 2;   // Job completed successfully
  FAILED = 3;      // Job failed with error
  NOT_FOUND = 4;   // Job ID not found
  CANCELLED = 5;   // Job was cancelled by user
}

// Response for job submission
message SubmitResponse {
  ResponseStatus status = 1;
  bytes job_id = 2;    // Unique job identifier (bytes to avoid UTF-8 validation warnings)
  string message = 3;  // Success/error message
}

// Response for status check
message StatusResponse {
  JobStatus job_status = 1;
  string message = 2;
  double progress = 3;            // 0.0-1.0 (future enhancement)
  int64 result_size_bytes = 4;    // Size of result payload when COMPLETED (0 if unknown)
  int64 max_message_bytes = 5;    // Server gRPC max message size (-1 = unlimited)
}

// Response for get result
message ResultResponse {
  ResponseStatus status = 1;
  string error_message = 2;

  oneof solution {
    LPSolution lp_solution = 10;
    MIPSolution mip_solution = 11;
  }
}

// Response for delete
message DeleteResponse {
  ResponseStatus status = 1;
  string message = 2;
}

// Response for cancel job
message CancelResponse {
  ResponseStatus status = 1;
  string message = 2;
  JobStatus job_status = 3;  // Status of job after cancel attempt
}

enum ResponseStatus {
  SUCCESS = 0;
  ERROR_INVALID_REQUEST = 1;
  ERROR_SOLVE_FAILED = 2;
  ERROR_INTERNAL = 3;
  ERROR_NOT_FOUND = 4;
}
diff --git a/cpp/src/grpc/cuopt_remote_service.proto b/cpp/src/grpc/cuopt_remote_service.proto
new file mode 100644
index 0000000000..f469199348
--- /dev/null
+++ b/cpp/src/grpc/cuopt_remote_service.proto
@@ -0,0 +1,349 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

syntax = "proto3";

package cuopt.remote;

// Import the existing message definitions
import "cuopt_remote.proto";

// =============================================================================
// gRPC Service Definition
// =============================================================================

service CuOptRemoteService {
  // -------------------------
  // Async Job Management
  // -------------------------

  // Submit a new LP or MIP solve job (returns immediately with job_id).
  // For problems that fit within gRPC max message size (~256 MiB default).
  rpc SubmitJob(SubmitJobRequest) returns (SubmitJobResponse);

  // -------------------------
  // Chunked Array Upload (for large problems exceeding unary message limit)
  // -------------------------

  // Start a chunked upload session. Sends problem metadata (scalars, settings, string arrays)
  // in the header. Returns an upload_id for subsequent chunk calls.
  rpc StartChunkedUpload(StartChunkedUploadRequest) returns (StartChunkedUploadResponse);

  // Send one chunk of array data. Each chunk carries a slice of one numeric array,
  // identified by ArrayFieldId. The server accumulates chunks by upload_id.
  rpc SendArrayChunk(SendArrayChunkRequest) returns (SendArrayChunkResponse);

  // Finalize a chunked upload and submit the assembled problem as a job.
  // Returns the job_id, same as SubmitJob.
  rpc FinishChunkedUpload(FinishChunkedUploadRequest) returns (SubmitJobResponse);

  // -------------------------
  // Job Status and Results
  // -------------------------

  // Check the status of a submitted job
  rpc CheckStatus(StatusRequest) returns (StatusResponse);

  // Get the result of a completed job (unary, for results that fit in one message)
  rpc GetResult(GetResultRequest) returns (ResultResponse);

  // -------------------------
  // Chunked Result Download (for large results exceeding unary message limit)
  // -------------------------

  // Start a chunked download session. Returns result metadata (scalars, enums, strings)
  // and array descriptors (field_id, total_elements, element_size) for the caller to
  // fetch array data in subsequent GetResultChunk calls.
  rpc StartChunkedDownload(StartChunkedDownloadRequest) returns (StartChunkedDownloadResponse);

  // Get one chunk of result array data. Each call returns a slice of one array,
  // identified by ResultFieldId. The client controls pacing.
  rpc GetResultChunk(GetResultChunkRequest) returns (GetResultChunkResponse);

  // Finalize a chunked download session and release server-side state.
  rpc FinishChunkedDownload(FinishChunkedDownloadRequest) returns (FinishChunkedDownloadResponse);

  // Delete a result from server memory
  rpc DeleteResult(DeleteRequest) returns (DeleteResponse);

  // Cancel a queued or running job
  rpc CancelJob(CancelRequest) returns (CancelResponse);

  // Wait for a job to complete (blocking, returns status only - use GetResult for solution)
  rpc WaitForCompletion(WaitRequest) returns (WaitResponse);

  // -------------------------
  // Log Streaming
  // -------------------------

  // Stream log messages as they are produced (server-side streaming)
  rpc StreamLogs(StreamLogsRequest) returns (stream LogMessage);

  // -------------------------
  // Incumbent Solutions
  // -------------------------

  // Get any available incumbent solutions since a given index.
  rpc GetIncumbents(IncumbentRequest) returns (IncumbentResponse);
}

// =============================================================================
// Request Messages
// =============================================================================

// Request to submit a new job
message SubmitJobRequest {
  oneof job_data {
    SolveLPRequest lp_request = 1;
    SolveMIPRequest mip_request = 2;
  }
}

// Response when job is submitted
message SubmitJobResponse {
  string job_id = 1;   // Unique job identifier
  string message = 2;  // Optional message
}

// =============================================================================
// Chunked Array Upload (for problems exceeding unary message limit)
// =============================================================================

// Header for chunked array uploads: carries scalars, settings, and small string arrays.
// Numeric arrays are sent separately via SendArrayChunk RPCs.
message ChunkedProblemHeader {
  RequestHeader header = 1;

  // Problem scalars
  bool maximize = 2;
  double objective_scaling_factor = 3;
  double objective_offset = 4;
  string problem_name = 5;
  string objective_name = 6;

  // String arrays (included here since they are rarely the size bottleneck)
  repeated string variable_names = 7;
  repeated string row_names = 8;

  // Settings (one of LP or MIP)
  oneof settings {
    PDLPSolverSettings lp_settings = 20;
    MIPSolverSettings mip_settings = 21;
  }

  // MIP-specific
  bool enable_incumbents = 30;
}

// A chunk of typed array data for large problem transfers.
// Each chunk carries a slice of one array field, identified by field_id.
// The server pre-allocates the target array on the first chunk (using total_elements)
// and copies each chunk's data at the given element_offset.
message ArrayChunk {
  ArrayFieldId field_id = 1;  // Which array this chunk belongs to
  int64 element_offset = 2;   // Element index offset within the array
  int64 total_elements = 3;   // Total elements in the complete array (for pre-allocation)
  bytes data = 4;             // Raw bytes: sizeof(element_type) * elements_in_this_chunk
}

// Start a chunked upload: sends problem metadata and settings.
// Problem type (LP/MIP) is derived from problem_header.header.problem_type.
message StartChunkedUploadRequest {
  ChunkedProblemHeader problem_header = 1;
}

message StartChunkedUploadResponse {
  string upload_id = 1;
  int64 max_message_bytes = 2;  // Server gRPC max message size hint
}

// Send one array chunk as part of an in-progress chunked upload.
message SendArrayChunkRequest {
  string upload_id = 1;
  ArrayChunk chunk = 2;
}

message SendArrayChunkResponse {
  string upload_id = 1;
  int64 chunks_received = 2;  // Running count of chunks received for this upload
}

// Finalize a chunked upload and submit the assembled problem as a job.
message FinishChunkedUploadRequest {
  string upload_id = 1;
}

// Request to check job status
message StatusRequest {
  string job_id = 1;
}

// Request to get result
message GetResultRequest {
  string job_id = 1;
}

// Metadata about a single result array available for chunked download
message ResultArrayDescriptor {
  ResultFieldId field_id = 1;
  int64 total_elements = 2;
  int64 element_size_bytes = 3;  // 8 for double, 4 for int32, etc.
}

// Header for chunked result download - carries all scalar/enum/string fields
// from LPSolution or MIPSolution. Array data is sent via GetResultChunk.
message ChunkedResultHeader {
  bool is_mip = 1;

  // LP result scalars
  PDLPTerminationStatus lp_termination_status = 10;
  string error_message = 11;
  double l2_primal_residual = 12;
  double l2_dual_residual = 13;
  double primal_objective = 14;
  double dual_objective = 15;
  double gap = 16;
  int32 nb_iterations = 17;
  double solve_time = 18;
  bool solved_by_pdlp = 19;

  // MIP result scalars
  MIPTerminationStatus mip_termination_status = 30;
  string mip_error_message = 31;
  double mip_objective = 32;
  double mip_gap = 33;
  double solution_bound = 34;
  double total_solve_time = 35;
  double presolve_time = 36;
  double max_constraint_violation = 37;
  double max_int_violation = 38;
  double max_variable_bound_violation = 39;
  int32 nodes = 40;
  int32 simplex_iterations = 41;

  // LP warm start scalars (included in header since they are small)
  double ws_initial_primal_weight = 60;
  double ws_initial_step_size = 61;
  int32 ws_total_pdlp_iterations = 62;
  int32 ws_total_pdhg_iterations = 63;
  double ws_last_candidate_kkt_score = 64;
  double ws_last_restart_kkt_score = 65;
  double ws_sum_solution_weight = 66;
  int32 ws_iterations_since_last_restart = 67;

  // Array metadata so client knows what to fetch
  repeated ResultArrayDescriptor arrays = 50;
}

message StartChunkedDownloadRequest {
  string job_id = 1;
}

message StartChunkedDownloadResponse {
  string download_id = 1;
  ChunkedResultHeader header = 2;
  int64 max_message_bytes = 3;
}

message GetResultChunkRequest {
  string download_id = 1;
  ResultFieldId field_id = 2;
  int64 element_offset = 3;
  int64 max_elements = 4;
}

message GetResultChunkResponse {
  string download_id = 1;
  ResultFieldId field_id = 2;
  int64 element_offset = 3;
  int64 elements_in_chunk = 4;
  bytes data = 5;
}

message FinishChunkedDownloadRequest {
  string download_id = 1;
}

message FinishChunkedDownloadResponse {
  string download_id = 1;
}

// Request to delete result
message DeleteRequest {
  string job_id = 1;
}

// DeleteResponse is defined in cuopt_remote.proto (imported above)

// Request to cancel job
message CancelRequest {
  string job_id = 1;
}

// CancelResponse is defined in cuopt_remote.proto (imported above)

// Request to wait for completion (blocking)
message WaitRequest {
  string job_id = 1;
}

// Response for wait (status only, no solution - use GetResult afterward)
message WaitResponse {
  JobStatus job_status = 1;
  string message = 2;
  int64 result_size_bytes = 3;  // Size of result payload when COMPLETED
}

// Request to stream logs
message StreamLogsRequest {
  string job_id = 1;
  int64 from_byte = 2;  // Optional: start from this byte offset
}

// Individual log message (streamed)
message LogMessage {
  string line = 1;        // Single log line
  int64 byte_offset = 2;  // Byte offset of this line in log file
  bool job_complete = 3;  // True if this is the last message (job done)
}

// =============================================================================
// Incumbent Solutions
// =============================================================================

message IncumbentRequest {
  string job_id = 1;
  int64 from_index = 2;  // Return incumbents starting from this index
  int32 max_count = 3;   // Optional limit (0 or negative => no limit)
}

message Incumbent {
  int64 index = 1;
  double objective = 2;
  repeated double assignment = 3;
  string job_id = 4;
}

message IncumbentResponse {
  repeated Incumbent incumbents = 1;
  int64 next_index = 2;   // Next index the client should request
  bool job_complete = 3;  // True if job is complete (no more incumbents)
}

// =============================================================================
// Notes on gRPC Status Codes
// =============================================================================
//
// gRPC uses standard status codes instead of custom ResponseStatus enum:
//   OK (0)                 - Success
//   CANCELLED (1)          - Operation was cancelled
//   UNKNOWN (2)            - Unknown error
//   INVALID_ARGUMENT (3)   - Invalid request
//   DEADLINE_EXCEEDED (4)  - Timeout
//   NOT_FOUND (5)          - Job ID not found
//   ALREADY_EXISTS (6)     - Job already exists
//   RESOURCE_EXHAUSTED (8) - Queue full, out of memory, etc.
//   INTERNAL (13)          - Internal server error
//   UNAVAILABLE (14)       - Server unavailable
//
// Errors are returned via gRPC Status with a message, not in response message.
diff --git a/cpp/src/grpc/grpc_problem_mapper.cpp b/cpp/src/grpc/grpc_problem_mapper.cpp
new file mode 100644
index 0000000000..0bf1d13336
--- /dev/null
+++ b/cpp/src/grpc/grpc_problem_mapper.cpp
@@ -0,0 +1,758 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved.
SPDX-License-Identifier: Apache-2.0 + */ + +#include "grpc_problem_mapper.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include "grpc_settings_mapper.hpp" + +#include +#include +#include +#include +#include + +namespace cuopt::linear_programming { + +template +void map_problem_to_proto(const cpu_optimization_problem_t& cpu_problem, + cuopt::remote::OptimizationProblem* pb_problem) +{ + // Basic problem metadata + pb_problem->set_problem_name(cpu_problem.get_problem_name()); + pb_problem->set_objective_name(cpu_problem.get_objective_name()); + pb_problem->set_maximize(cpu_problem.get_sense()); + pb_problem->set_objective_scaling_factor(cpu_problem.get_objective_scaling_factor()); + pb_problem->set_objective_offset(cpu_problem.get_objective_offset()); + + // Get constraint matrix data from host memory + auto values = cpu_problem.get_constraint_matrix_values_host(); + auto indices = cpu_problem.get_constraint_matrix_indices_host(); + auto offsets = cpu_problem.get_constraint_matrix_offsets_host(); + + // Constraint matrix A in CSR format + for (const auto& val : values) { + pb_problem->add_a(static_cast(val)); + } + for (const auto& idx : indices) { + pb_problem->add_a_indices(static_cast(idx)); + } + for (const auto& off : offsets) { + pb_problem->add_a_offsets(static_cast(off)); + } + + // Objective coefficients + auto obj_coeffs = cpu_problem.get_objective_coefficients_host(); + for (const auto& c : obj_coeffs) { + pb_problem->add_c(static_cast(c)); + } + + // Variable bounds + auto var_lb = cpu_problem.get_variable_lower_bounds_host(); + auto var_ub = cpu_problem.get_variable_upper_bounds_host(); + for (const auto& lb : var_lb) { + pb_problem->add_variable_lower_bounds(static_cast(lb)); + } + for (const auto& ub : var_ub) { + pb_problem->add_variable_upper_bounds(static_cast(ub)); + } + + // Constraint bounds + auto con_lb = cpu_problem.get_constraint_lower_bounds_host(); + auto con_ub = 
cpu_problem.get_constraint_upper_bounds_host(); + + if (!con_lb.empty() && !con_ub.empty()) { + for (const auto& lb : con_lb) { + pb_problem->add_constraint_lower_bounds(static_cast(lb)); + } + for (const auto& ub : con_ub) { + pb_problem->add_constraint_upper_bounds(static_cast(ub)); + } + } + + // Row types (if available) + auto row_types = cpu_problem.get_row_types_host(); + if (!row_types.empty()) { + pb_problem->set_row_types(std::string(row_types.begin(), row_types.end())); + } + + // Constraint bounds (RHS) - if available + auto b = cpu_problem.get_constraint_bounds_host(); + if (!b.empty()) { + for (const auto& rhs : b) { + pb_problem->add_b(static_cast(rhs)); + } + } + + // Variable names + const auto& var_names = cpu_problem.get_variable_names(); + for (const auto& name : var_names) { + pb_problem->add_variable_names(name); + } + + // Row names + const auto& row_names = cpu_problem.get_row_names(); + for (const auto& name : row_names) { + pb_problem->add_row_names(name); + } + + // Variable types (for MIP problems) + auto var_types = cpu_problem.get_variable_types_host(); + if (!var_types.empty()) { + // Convert var_t enum to char representation + std::string var_types_str; + var_types_str.reserve(var_types.size()); + for (const auto& vt : var_types) { + switch (vt) { + case var_t::CONTINUOUS: var_types_str.push_back('C'); break; + case var_t::INTEGER: var_types_str.push_back('I'); break; + default: + throw std::runtime_error("map_problem_to_proto: unknown var_t value " + + std::to_string(static_cast(vt))); + } + } + pb_problem->set_variable_types(var_types_str); + } + + // Quadratic objective matrix Q (for QPS problems) + if (cpu_problem.has_quadratic_objective()) { + const auto& q_values = cpu_problem.get_quadratic_objective_values(); + const auto& q_indices = cpu_problem.get_quadratic_objective_indices(); + const auto& q_offsets = cpu_problem.get_quadratic_objective_offsets(); + + for (const auto& val : q_values) { + 
pb_problem->add_q_values(static_cast(val)); + } + for (const auto& idx : q_indices) { + pb_problem->add_q_indices(static_cast(idx)); + } + for (const auto& off : q_offsets) { + pb_problem->add_q_offsets(static_cast(off)); + } + } +} + +template +void map_proto_to_problem(const cuopt::remote::OptimizationProblem& pb_problem, + cpu_optimization_problem_t& cpu_problem) +{ + // Basic problem metadata + cpu_problem.set_problem_name(pb_problem.problem_name()); + cpu_problem.set_objective_name(pb_problem.objective_name()); + cpu_problem.set_maximize(pb_problem.maximize()); + cpu_problem.set_objective_scaling_factor(pb_problem.objective_scaling_factor()); + cpu_problem.set_objective_offset(pb_problem.objective_offset()); + + // Constraint matrix A in CSR format + std::vector values(pb_problem.a().begin(), pb_problem.a().end()); + std::vector indices(pb_problem.a_indices().begin(), pb_problem.a_indices().end()); + std::vector offsets(pb_problem.a_offsets().begin(), pb_problem.a_offsets().end()); + + cpu_problem.set_csr_constraint_matrix(values.data(), + static_cast(values.size()), + indices.data(), + static_cast(indices.size()), + offsets.data(), + static_cast(offsets.size())); + + // Objective coefficients + std::vector obj(pb_problem.c().begin(), pb_problem.c().end()); + cpu_problem.set_objective_coefficients(obj.data(), static_cast(obj.size())); + + // Variable bounds + std::vector var_lb(pb_problem.variable_lower_bounds().begin(), + pb_problem.variable_lower_bounds().end()); + std::vector var_ub(pb_problem.variable_upper_bounds().begin(), + pb_problem.variable_upper_bounds().end()); + cpu_problem.set_variable_lower_bounds(var_lb.data(), static_cast(var_lb.size())); + cpu_problem.set_variable_upper_bounds(var_ub.data(), static_cast(var_ub.size())); + + // Constraint bounds (prefer lower/upper bounds if available) + if (pb_problem.constraint_lower_bounds_size() > 0 && + pb_problem.constraint_upper_bounds_size() > 0 && + pb_problem.constraint_lower_bounds_size() == 
pb_problem.constraint_upper_bounds_size()) { + std::vector con_lb(pb_problem.constraint_lower_bounds().begin(), + pb_problem.constraint_lower_bounds().end()); + std::vector con_ub(pb_problem.constraint_upper_bounds().begin(), + pb_problem.constraint_upper_bounds().end()); + cpu_problem.set_constraint_lower_bounds(con_lb.data(), static_cast(con_lb.size())); + cpu_problem.set_constraint_upper_bounds(con_ub.data(), static_cast(con_ub.size())); + } else if (pb_problem.b_size() > 0) { + // Use b (RHS) + row_types format + std::vector b(pb_problem.b().begin(), pb_problem.b().end()); + cpu_problem.set_constraint_bounds(b.data(), static_cast(b.size())); + + if (!pb_problem.row_types().empty()) { + const std::string& row_types_str = pb_problem.row_types(); + cpu_problem.set_row_types(row_types_str.data(), static_cast(row_types_str.size())); + } + } + + // Variable names + if (pb_problem.variable_names_size() > 0) { + std::vector var_names(pb_problem.variable_names().begin(), + pb_problem.variable_names().end()); + cpu_problem.set_variable_names(var_names); + } + + // Row names + if (pb_problem.row_names_size() > 0) { + std::vector row_names(pb_problem.row_names().begin(), + pb_problem.row_names().end()); + cpu_problem.set_row_names(row_names); + } + + // Variable types + if (!pb_problem.variable_types().empty()) { + const std::string& var_types_str = pb_problem.variable_types(); + // Convert char representation to var_t enum + std::vector var_types; + var_types.reserve(var_types_str.size()); + for (char c : var_types_str) { + switch (c) { + case 'C': var_types.push_back(var_t::CONTINUOUS); break; + case 'I': + case 'B': var_types.push_back(var_t::INTEGER); break; + default: + throw std::runtime_error(std::string("Unknown variable type character '") + c + + "' in variable_types string (expected 'C', 'I', or 'B')"); + } + } + cpu_problem.set_variable_types(var_types.data(), static_cast(var_types.size())); + } + + // Quadratic objective matrix Q (for QPS problems) + if 
(pb_problem.q_values_size() > 0) { + std::vector q_values(pb_problem.q_values().begin(), pb_problem.q_values().end()); + std::vector q_indices(pb_problem.q_indices().begin(), pb_problem.q_indices().end()); + std::vector q_offsets(pb_problem.q_offsets().begin(), pb_problem.q_offsets().end()); + + cpu_problem.set_quadratic_objective_matrix(q_values.data(), + static_cast(q_values.size()), + q_indices.data(), + static_cast(q_indices.size()), + q_offsets.data(), + static_cast(q_offsets.size())); + } + + // Infer problem category from variable types + if (!pb_problem.variable_types().empty()) { + const std::string& var_types_str = pb_problem.variable_types(); + bool has_integers = false; + for (char c : var_types_str) { + if (c == 'I' || c == 'B') { + has_integers = true; + break; + } + } + cpu_problem.set_problem_category(has_integers ? problem_category_t::MIP + : problem_category_t::LP); + } else { + cpu_problem.set_problem_category(problem_category_t::LP); + } +} + +// ============================================================================ +// Size estimation +// ============================================================================ + +template +size_t estimate_problem_proto_size(const cpu_optimization_problem_t& cpu_problem) +{ + size_t est = 0; + + // Constraint matrix CSR arrays + auto values = cpu_problem.get_constraint_matrix_values_host(); + auto indices = cpu_problem.get_constraint_matrix_indices_host(); + auto offsets = cpu_problem.get_constraint_matrix_offsets_host(); + est += values.size() * sizeof(double); // packed repeated double + est += indices.size() * 5; // varint int32 (worst case 5 bytes each) + est += offsets.size() * 5; + + // Objective coefficients + est += cpu_problem.get_objective_coefficients_host().size() * sizeof(double); + + // Variable bounds + est += cpu_problem.get_variable_lower_bounds_host().size() * sizeof(double); + est += cpu_problem.get_variable_upper_bounds_host().size() * sizeof(double); + + // Constraint bounds + est 
+= cpu_problem.get_constraint_lower_bounds_host().size() * sizeof(double); + est += cpu_problem.get_constraint_upper_bounds_host().size() * sizeof(double); + est += cpu_problem.get_constraint_bounds_host().size() * sizeof(double); + + // Row types and variable types + est += cpu_problem.get_row_types_host().size(); + est += cpu_problem.get_variable_types_host().size(); + + // Quadratic objective + if (cpu_problem.has_quadratic_objective()) { + est += cpu_problem.get_quadratic_objective_values().size() * sizeof(double); + est += cpu_problem.get_quadratic_objective_indices().size() * 5; + est += cpu_problem.get_quadratic_objective_offsets().size() * 5; + } + + // String arrays (rough estimate) + for (const auto& name : cpu_problem.get_variable_names()) { + est += name.size() + 2; // string + tag + length varint + } + for (const auto& name : cpu_problem.get_row_names()) { + est += name.size() + 2; + } + + // Protobuf overhead for tags, submessage lengths, etc. + est += 512; + + return est; +} + +// ============================================================================ +// Chunked header population (client-side, for CHUNKED_ARRAYS upload) +// ============================================================================ + +template +void populate_chunked_header_lp(const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings, + cuopt::remote::ChunkedProblemHeader* header) +{ + // Request header + auto* rh = header->mutable_header(); + rh->set_version(1); + rh->set_problem_type(cuopt::remote::LP); + + header->set_maximize(cpu_problem.get_sense()); + header->set_objective_scaling_factor(cpu_problem.get_objective_scaling_factor()); + header->set_objective_offset(cpu_problem.get_objective_offset()); + header->set_problem_name(cpu_problem.get_problem_name()); + header->set_objective_name(cpu_problem.get_objective_name()); + + // Variable/row names are sent as chunked arrays, not in the header, + // to avoid the header exceeding gRPC max message 
size for large problems. + + // LP settings + map_pdlp_settings_to_proto(settings, header->mutable_lp_settings()); +} + +template +void populate_chunked_header_mip(const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents, + cuopt::remote::ChunkedProblemHeader* header) +{ + // Request header + auto* rh = header->mutable_header(); + rh->set_version(1); + rh->set_problem_type(cuopt::remote::MIP); + + header->set_maximize(cpu_problem.get_sense()); + header->set_objective_scaling_factor(cpu_problem.get_objective_scaling_factor()); + header->set_objective_offset(cpu_problem.get_objective_offset()); + header->set_problem_name(cpu_problem.get_problem_name()); + header->set_objective_name(cpu_problem.get_objective_name()); + + // Variable/row names are sent as chunked arrays, not in the header. + + // MIP settings + map_mip_settings_to_proto(settings, header->mutable_mip_settings()); + header->set_enable_incumbents(enable_incumbents); +} + +// ============================================================================ +// Chunked header reconstruction (server-side) +// ============================================================================ + +template +void map_chunked_header_to_problem(const cuopt::remote::ChunkedProblemHeader& header, + cpu_optimization_problem_t& cpu_problem) +{ + cpu_problem.set_problem_name(header.problem_name()); + cpu_problem.set_objective_name(header.objective_name()); + cpu_problem.set_maximize(header.maximize()); + cpu_problem.set_objective_scaling_factor(header.objective_scaling_factor()); + cpu_problem.set_objective_offset(header.objective_offset()); + + // String arrays + if (header.variable_names_size() > 0) { + std::vector var_names(header.variable_names().begin(), + header.variable_names().end()); + cpu_problem.set_variable_names(var_names); + } + if (header.row_names_size() > 0) { + std::vector row_names(header.row_names().begin(), header.row_names().end()); + 
cpu_problem.set_row_names(row_names); + } + + // Problem category inferred later when variable_types array is set +} + +// ============================================================================ +// Chunked array reconstruction (server-side, consolidates all array mapping) +// ============================================================================ + +template +void map_chunked_arrays_to_problem(const cuopt::remote::ChunkedProblemHeader& header, + const std::map>& arrays, + cpu_optimization_problem_t& cpu_problem) +{ + map_chunked_header_to_problem(header, cpu_problem); + + auto get_doubles = [&](int32_t field_id) -> std::vector { + auto it = arrays.find(field_id); + if (it == arrays.end() || it->second.empty()) return {}; + if (it->second.size() % sizeof(double) != 0) return {}; + size_t n = it->second.size() / sizeof(double); + if constexpr (std::is_same_v) { + std::vector v(n); + std::memcpy(v.data(), it->second.data(), n * sizeof(double)); + return v; + } else { + std::vector tmp(n); + std::memcpy(tmp.data(), it->second.data(), n * sizeof(double)); + return std::vector(tmp.begin(), tmp.end()); + } + }; + + auto get_ints = [&](int32_t field_id) -> std::vector { + auto it = arrays.find(field_id); + if (it == arrays.end() || it->second.empty()) return {}; + if (it->second.size() % sizeof(int32_t) != 0) return {}; + size_t n = it->second.size() / sizeof(int32_t); + if constexpr (std::is_same_v) { + std::vector v(n); + std::memcpy(v.data(), it->second.data(), n * sizeof(int32_t)); + return v; + } else { + std::vector tmp(n); + std::memcpy(tmp.data(), it->second.data(), n * sizeof(int32_t)); + return std::vector(tmp.begin(), tmp.end()); + } + }; + + auto get_bytes = [&](int32_t field_id) -> std::string { + auto it = arrays.find(field_id); + if (it == arrays.end() || it->second.empty()) return {}; + return std::string(reinterpret_cast(it->second.data()), it->second.size()); + }; + + auto get_string_list = [&](int32_t field_id) -> std::vector { + auto it = 
arrays.find(field_id); + if (it == arrays.end() || it->second.empty()) return {}; + std::vector names; + const char* s = reinterpret_cast(it->second.data()); + const char* s_end = s + it->second.size(); + while (s < s_end) { + const char* nul = static_cast(std::memchr(s, '\0', s_end - s)); + if (!nul) nul = s_end; + names.emplace_back(s, nul); + if (nul == s_end) break; + s = nul + 1; + } + return names; + }; + + // CSR constraint matrix + auto a_values = get_doubles(cuopt::remote::FIELD_A_VALUES); + auto a_indices = get_ints(cuopt::remote::FIELD_A_INDICES); + auto a_offsets = get_ints(cuopt::remote::FIELD_A_OFFSETS); + if (!a_values.empty() && !a_indices.empty() && !a_offsets.empty()) { + cpu_problem.set_csr_constraint_matrix(a_values.data(), + static_cast(a_values.size()), + a_indices.data(), + static_cast(a_indices.size()), + a_offsets.data(), + static_cast(a_offsets.size())); + } + + // Objective coefficients + auto c_vec = get_doubles(cuopt::remote::FIELD_C); + if (!c_vec.empty()) { + cpu_problem.set_objective_coefficients(c_vec.data(), static_cast(c_vec.size())); + } + + // Variable bounds + auto var_lb = get_doubles(cuopt::remote::FIELD_VARIABLE_LOWER_BOUNDS); + auto var_ub = get_doubles(cuopt::remote::FIELD_VARIABLE_UPPER_BOUNDS); + if (!var_lb.empty()) { + cpu_problem.set_variable_lower_bounds(var_lb.data(), static_cast(var_lb.size())); + } + if (!var_ub.empty()) { + cpu_problem.set_variable_upper_bounds(var_ub.data(), static_cast(var_ub.size())); + } + + // Constraint bounds + auto con_lb = get_doubles(cuopt::remote::FIELD_CONSTRAINT_LOWER_BOUNDS); + auto con_ub = get_doubles(cuopt::remote::FIELD_CONSTRAINT_UPPER_BOUNDS); + if (!con_lb.empty()) { + cpu_problem.set_constraint_lower_bounds(con_lb.data(), static_cast(con_lb.size())); + } + if (!con_ub.empty()) { + cpu_problem.set_constraint_upper_bounds(con_ub.data(), static_cast(con_ub.size())); + } + + auto b_vec = get_doubles(cuopt::remote::FIELD_B); + if (!b_vec.empty()) { + 
cpu_problem.set_constraint_bounds(b_vec.data(), static_cast(b_vec.size())); + } + + // Row types + auto row_types_str = get_bytes(cuopt::remote::FIELD_ROW_TYPES); + if (!row_types_str.empty()) { + cpu_problem.set_row_types(row_types_str.data(), static_cast(row_types_str.size())); + } + + // Variable types + problem category + auto var_types_str = get_bytes(cuopt::remote::FIELD_VARIABLE_TYPES); + if (!var_types_str.empty()) { + std::vector vtypes; + vtypes.reserve(var_types_str.size()); + bool has_ints = false; + for (char c : var_types_str) { + switch (c) { + case 'C': vtypes.push_back(var_t::CONTINUOUS); break; + case 'I': + case 'B': + vtypes.push_back(var_t::INTEGER); + has_ints = true; + break; + default: + throw std::runtime_error(std::string("Unknown variable type character '") + c + + "' in chunked variable_types (expected 'C', 'I', or 'B')"); + } + } + cpu_problem.set_variable_types(vtypes.data(), static_cast(vtypes.size())); + cpu_problem.set_problem_category(has_ints ? problem_category_t::MIP : problem_category_t::LP); + } else { + cpu_problem.set_problem_category(problem_category_t::LP); + } + + // Quadratic objective + auto q_values = get_doubles(cuopt::remote::FIELD_Q_VALUES); + auto q_indices = get_ints(cuopt::remote::FIELD_Q_INDICES); + auto q_offsets = get_ints(cuopt::remote::FIELD_Q_OFFSETS); + if (!q_values.empty() && !q_indices.empty() && !q_offsets.empty()) { + cpu_problem.set_quadratic_objective_matrix(q_values.data(), + static_cast(q_values.size()), + q_indices.data(), + static_cast(q_indices.size()), + q_offsets.data(), + static_cast(q_offsets.size())); + } + + // String arrays (may also be in header; these override if present as chunked arrays) + auto var_names = get_string_list(cuopt::remote::FIELD_VARIABLE_NAMES); + if (!var_names.empty()) { cpu_problem.set_variable_names(var_names); } + auto row_names = get_string_list(cuopt::remote::FIELD_ROW_NAMES); + if (!row_names.empty()) { cpu_problem.set_row_names(row_names); } +} + +// 
============================================================================= +// Chunked array request building (client-side) +// ============================================================================= + +namespace { + +template +void chunk_typed_array(std::vector& out, + cuopt::remote::ArrayFieldId field_id, + const std::vector& data, + const std::string& upload_id, + int64_t chunk_data_budget) +{ + if (data.empty()) return; + + const int64_t elem_size = static_cast(sizeof(T)); + const int64_t total_elements = static_cast(data.size()); + + int64_t elems_per_chunk = chunk_data_budget / elem_size; + if (elems_per_chunk <= 0) elems_per_chunk = 1; + + const auto* raw = reinterpret_cast(data.data()); + + for (int64_t offset = 0; offset < total_elements; offset += elems_per_chunk) { + int64_t count = std::min(elems_per_chunk, total_elements - offset); + int64_t byte_offset = offset * elem_size; + int64_t byte_count = count * elem_size; + + cuopt::remote::SendArrayChunkRequest req; + req.set_upload_id(upload_id); + auto* ac = req.mutable_chunk(); + ac->set_field_id(field_id); + ac->set_element_offset(offset); + ac->set_total_elements(total_elements); + ac->set_data(raw + byte_offset, byte_count); + out.push_back(std::move(req)); + } +} + +void chunk_byte_blob(std::vector& out, + cuopt::remote::ArrayFieldId field_id, + const std::vector& data, + const std::string& upload_id, + int64_t chunk_data_budget) +{ + chunk_typed_array(out, field_id, data, upload_id, chunk_data_budget); +} + +} // namespace + +template +std::vector build_array_chunk_requests( + const cpu_optimization_problem_t& problem, + const std::string& upload_id, + int64_t chunk_size_bytes) +{ + std::vector requests; + + auto values = problem.get_constraint_matrix_values_host(); + auto indices = problem.get_constraint_matrix_indices_host(); + auto offsets = problem.get_constraint_matrix_offsets_host(); + auto obj = problem.get_objective_coefficients_host(); + auto var_lb = 
problem.get_variable_lower_bounds_host(); + auto var_ub = problem.get_variable_upper_bounds_host(); + auto con_lb = problem.get_constraint_lower_bounds_host(); + auto con_ub = problem.get_constraint_upper_bounds_host(); + auto b = problem.get_constraint_bounds_host(); + + chunk_typed_array(requests, cuopt::remote::FIELD_A_VALUES, values, upload_id, chunk_size_bytes); + chunk_typed_array(requests, cuopt::remote::FIELD_A_INDICES, indices, upload_id, chunk_size_bytes); + chunk_typed_array(requests, cuopt::remote::FIELD_A_OFFSETS, offsets, upload_id, chunk_size_bytes); + chunk_typed_array(requests, cuopt::remote::FIELD_C, obj, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_VARIABLE_LOWER_BOUNDS, var_lb, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_VARIABLE_UPPER_BOUNDS, var_ub, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_CONSTRAINT_LOWER_BOUNDS, con_lb, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_CONSTRAINT_UPPER_BOUNDS, con_ub, upload_id, chunk_size_bytes); + chunk_typed_array(requests, cuopt::remote::FIELD_B, b, upload_id, chunk_size_bytes); + + auto row_types = problem.get_row_types_host(); + if (!row_types.empty()) { + std::vector rt_bytes(row_types.begin(), row_types.end()); + chunk_byte_blob( + requests, cuopt::remote::FIELD_ROW_TYPES, rt_bytes, upload_id, chunk_size_bytes); + } + + auto var_types = problem.get_variable_types_host(); + if (!var_types.empty()) { + std::vector vt_bytes; + vt_bytes.reserve(var_types.size()); + for (const auto& vt : var_types) { + switch (vt) { + case var_t::CONTINUOUS: vt_bytes.push_back('C'); break; + case var_t::INTEGER: vt_bytes.push_back('I'); break; + default: + throw std::runtime_error("chunk_problem_to_proto: unknown var_t value " + + std::to_string(static_cast(vt))); + } + } + chunk_byte_blob( + requests, cuopt::remote::FIELD_VARIABLE_TYPES, vt_bytes, upload_id, 
chunk_size_bytes); + } + + if (problem.has_quadratic_objective()) { + const auto& q_values = problem.get_quadratic_objective_values(); + const auto& q_indices = problem.get_quadratic_objective_indices(); + const auto& q_offsets = problem.get_quadratic_objective_offsets(); + chunk_typed_array( + requests, cuopt::remote::FIELD_Q_VALUES, q_values, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_Q_INDICES, q_indices, upload_id, chunk_size_bytes); + chunk_typed_array( + requests, cuopt::remote::FIELD_Q_OFFSETS, q_offsets, upload_id, chunk_size_bytes); + } + + auto names_to_blob = [](const std::vector& names) -> std::vector { + if (names.empty()) return {}; + size_t total = 0; + for (const auto& n : names) + total += n.size() + 1; + std::vector blob(total); + size_t pos = 0; + for (const auto& n : names) { + std::memcpy(blob.data() + pos, n.data(), n.size()); + pos += n.size(); + blob[pos++] = '\0'; + } + return blob; + }; + + auto var_names_blob = names_to_blob(problem.get_variable_names()); + auto row_names_blob = names_to_blob(problem.get_row_names()); + chunk_byte_blob( + requests, cuopt::remote::FIELD_VARIABLE_NAMES, var_names_blob, upload_id, chunk_size_bytes); + chunk_byte_blob( + requests, cuopt::remote::FIELD_ROW_NAMES, row_names_blob, upload_id, chunk_size_bytes); + + return requests; +} + +// Explicit template instantiations +#if CUOPT_INSTANTIATE_FLOAT +template void map_problem_to_proto(const cpu_optimization_problem_t& cpu_problem, + cuopt::remote::OptimizationProblem* pb_problem); +template void map_proto_to_problem(const cuopt::remote::OptimizationProblem& pb_problem, + cpu_optimization_problem_t& cpu_problem); +template size_t estimate_problem_proto_size( + const cpu_optimization_problem_t& cpu_problem); +template void populate_chunked_header_lp( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings, + cuopt::remote::ChunkedProblemHeader* header); +template void 
populate_chunked_header_mip( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents, + cuopt::remote::ChunkedProblemHeader* header); +template void map_chunked_header_to_problem( + const cuopt::remote::ChunkedProblemHeader& header, + cpu_optimization_problem_t& cpu_problem); +template void map_chunked_arrays_to_problem( + const cuopt::remote::ChunkedProblemHeader& header, + const std::map>& arrays, + cpu_optimization_problem_t& cpu_problem); +template std::vector build_array_chunk_requests( + const cpu_optimization_problem_t& problem, + const std::string& upload_id, + int64_t chunk_size_bytes); +#endif + +#if CUOPT_INSTANTIATE_DOUBLE +template void map_problem_to_proto(const cpu_optimization_problem_t& cpu_problem, + cuopt::remote::OptimizationProblem* pb_problem); +template void map_proto_to_problem(const cuopt::remote::OptimizationProblem& pb_problem, + cpu_optimization_problem_t& cpu_problem); +template size_t estimate_problem_proto_size( + const cpu_optimization_problem_t& cpu_problem); +template void populate_chunked_header_lp( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings, + cuopt::remote::ChunkedProblemHeader* header); +template void populate_chunked_header_mip( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents, + cuopt::remote::ChunkedProblemHeader* header); +template void map_chunked_header_to_problem( + const cuopt::remote::ChunkedProblemHeader& header, + cpu_optimization_problem_t& cpu_problem); +template void map_chunked_arrays_to_problem( + const cuopt::remote::ChunkedProblemHeader& header, + const std::map>& arrays, + cpu_optimization_problem_t& cpu_problem); +template std::vector build_array_chunk_requests( + const cpu_optimization_problem_t& problem, + const std::string& upload_id, + int64_t chunk_size_bytes); +#endif + +} // namespace cuopt::linear_programming diff --git 
a/cpp/src/grpc/grpc_problem_mapper.hpp b/cpp/src/grpc/grpc_problem_mapper.hpp new file mode 100644 index 0000000000..db113e2502 --- /dev/null +++ b/cpp/src/grpc/grpc_problem_mapper.hpp @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace cuopt::remote { +class ChunkedProblemHeader; +} + +namespace cuopt::linear_programming { + +// Forward declarations +template +class cpu_optimization_problem_t; + +template +struct pdlp_solver_settings_t; + +template +struct mip_solver_settings_t; + +/** + * @brief Map cpu_optimization_problem_t to protobuf OptimizationProblem message. + * + * Populates a protobuf message using the generated protobuf C++ API. + * Does not perform serialization — that is handled by the protobuf library. + */ +template +void map_problem_to_proto(const cpu_optimization_problem_t& cpu_problem, + cuopt::remote::OptimizationProblem* pb_problem); + +/** + * @brief Map protobuf OptimizationProblem message to cpu_optimization_problem_t. + * + * Reads from a protobuf message using the generated protobuf C++ API. + * Does not perform deserialization — that is handled by the protobuf library. + */ +template +void map_proto_to_problem(const cuopt::remote::OptimizationProblem& pb_problem, + cpu_optimization_problem_t& cpu_problem); + +/** + * @brief Estimate the serialized protobuf size of a SolveLPRequest/SolveMIPRequest. + * + * Computes an approximate upper bound on the serialized size without actually building + * the protobuf message. Used to decide whether to use chunked array transfer. + * + * @return Estimated size in bytes + */ +template +size_t estimate_problem_proto_size(const cpu_optimization_problem_t& cpu_problem); + +/** + * @brief Populate a ChunkedProblemHeader from a cpu_optimization_problem_t and LP settings. 
+ * + * Fills the header with problem scalars, string arrays, and LP settings. + * Numeric arrays are NOT included (they are sent as ArrayChunk messages). + */ +template +void populate_chunked_header_lp(const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings, + cuopt::remote::ChunkedProblemHeader* header); + +/** + * @brief Populate a ChunkedProblemHeader from a cpu_optimization_problem_t and MIP settings. + * + * Fills the header with problem scalars, string arrays, and MIP settings. + * Numeric arrays are NOT included (they are sent as ArrayChunk messages). + */ +template +void populate_chunked_header_mip(const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents, + cuopt::remote::ChunkedProblemHeader* header); + +/** + * @brief Reconstruct a cpu_optimization_problem_t from a ChunkedProblemHeader. + * + * Populates problem scalars and string arrays from the header. Numeric arrays + * must be populated separately from ArrayChunk data. + */ +template +void map_chunked_header_to_problem(const cuopt::remote::ChunkedProblemHeader& header, + cpu_optimization_problem_t& cpu_problem); + +/** + * @brief Reconstruct a cpu_optimization_problem_t from a ChunkedProblemHeader and raw array data. + * + * This is the single entry point for reconstructing a problem from chunked transfer data. + * It calls map_chunked_header_to_problem() for scalars/strings, then populates all numeric + * arrays from the raw byte data keyed by ArrayFieldId. 
+ * + * @param header The chunked problem header (scalars, settings metadata, string arrays) + * @param arrays Map of ArrayFieldId (as int32_t) to raw byte data for each array field + * @param cpu_problem The cpu_optimization_problem_t to populate (output parameter) + */ +template +void map_chunked_arrays_to_problem(const cuopt::remote::ChunkedProblemHeader& header, + const std::map>& arrays, + cpu_optimization_problem_t& cpu_problem); + +/** + * @brief Build SendArrayChunkRequest messages for chunked upload of problem arrays. + * + * Iterates the problem's host arrays directly and slices each array into + * chunk-sized SendArrayChunkRequest protobuf messages. The caller simply + * iterates the returned vector and sends each message via SendArrayChunk RPC. + * + * @param problem The problem whose arrays to chunk + * @param upload_id The upload session ID from StartChunkedUpload + * @param chunk_size_bytes Maximum raw data bytes per chunk message + * @return Vector of ready-to-send SendArrayChunkRequest protobuf messages + */ +template +std::vector build_array_chunk_requests( + const cpu_optimization_problem_t& problem, + const std::string& upload_id, + int64_t chunk_size_bytes); + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_service_mapper.cpp b/cpp/src/grpc/grpc_service_mapper.cpp new file mode 100644 index 0000000000..786d1b8f1c --- /dev/null +++ b/cpp/src/grpc/grpc_service_mapper.cpp @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 + */ + +#include "grpc_service_mapper.hpp" + +#include +#include +#include +#include +#include +#include "grpc_problem_mapper.hpp" +#include "grpc_settings_mapper.hpp" + +namespace cuopt::linear_programming { + +template +cuopt::remote::SubmitJobRequest build_lp_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings) +{ + cuopt::remote::SubmitJobRequest submit_request; + + // Get the lp_request from the oneof + auto* lp_request = submit_request.mutable_lp_request(); + + // Set header + auto* header = lp_request->mutable_header(); + header->set_version(1); + header->set_problem_type(cuopt::remote::LP); + + // Map problem data to protobuf + map_problem_to_proto(cpu_problem, lp_request->mutable_problem()); + + // Map settings to protobuf + map_pdlp_settings_to_proto(settings, lp_request->mutable_settings()); + + return submit_request; +} + +template +cuopt::remote::SubmitJobRequest build_mip_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents) +{ + cuopt::remote::SubmitJobRequest submit_request; + + // Get the mip_request from the oneof + auto* mip_request = submit_request.mutable_mip_request(); + + // Set header + auto* header = mip_request->mutable_header(); + header->set_version(1); + header->set_problem_type(cuopt::remote::MIP); + + // Map problem data to protobuf + map_problem_to_proto(cpu_problem, mip_request->mutable_problem()); + + // Map settings to protobuf + map_mip_settings_to_proto(settings, mip_request->mutable_settings()); + + // Set enable_incumbents flag + mip_request->set_enable_incumbents(enable_incumbents); + + return submit_request; +} + +// Explicit template instantiations +#if CUOPT_INSTANTIATE_FLOAT +template cuopt::remote::SubmitJobRequest build_lp_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings); +template 
cuopt::remote::SubmitJobRequest build_mip_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +#endif + +#if CUOPT_INSTANTIATE_DOUBLE +template cuopt::remote::SubmitJobRequest build_lp_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings); +template cuopt::remote::SubmitJobRequest build_mip_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents); +#endif + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_service_mapper.hpp b/cpp/src/grpc/grpc_service_mapper.hpp new file mode 100644 index 0000000000..ed438a0551 --- /dev/null +++ b/cpp/src/grpc/grpc_service_mapper.hpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include + +namespace cuopt::linear_programming { + +// Forward declarations +template +class cpu_optimization_problem_t; + +template +struct pdlp_solver_settings_t; + +template +struct mip_solver_settings_t; + +/** + * @brief Build a gRPC SubmitJobRequest for an LP problem. + * + * Creates a SubmitJobRequest containing the LP problem and settings using + * the problem and settings mappers. Serialization is handled by the protobuf library. + */ +template +cuopt::remote::SubmitJobRequest build_lp_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const pdlp_solver_settings_t& settings); + +/** + * @brief Build a gRPC SubmitJobRequest for a MIP problem. + * + * Creates a SubmitJobRequest containing the MIP problem and settings using + * the problem and settings mappers. Serialization is handled by the protobuf library. 
+ */ +template +cuopt::remote::SubmitJobRequest build_mip_submit_request( + const cpu_optimization_problem_t& cpu_problem, + const mip_solver_settings_t& settings, + bool enable_incumbents = false); + +/** + * @brief Build a gRPC StatusRequest. + * + * Simple helper to create a status check request. + * + * @param job_id The job ID to check status for + * @return StatusRequest protobuf message + */ +inline cuopt::remote::StatusRequest build_status_request(const std::string& job_id) +{ + cuopt::remote::StatusRequest request; + request.set_job_id(job_id); + return request; +} + +/** + * @brief Build a gRPC GetResultRequest. + * + * Simple helper to create a result retrieval request. + * + * @param job_id The job ID to get results for + * @return GetResultRequest protobuf message + */ +inline cuopt::remote::GetResultRequest build_get_result_request(const std::string& job_id) +{ + cuopt::remote::GetResultRequest request; + request.set_job_id(job_id); + return request; +} + +/** + * @brief Build a gRPC CancelRequest. + * + * Simple helper to create a job cancellation request. + * + * @param job_id The job ID to cancel + * @return CancelRequest protobuf message + */ +inline cuopt::remote::CancelRequest build_cancel_request(const std::string& job_id) +{ + cuopt::remote::CancelRequest request; + request.set_job_id(job_id); + return request; +} + +/** + * @brief Build a gRPC DeleteRequest. + * + * Simple helper to create a result deletion request. 
+ * + * @param job_id The job ID whose result should be deleted + * @return DeleteRequest protobuf message + */ +inline cuopt::remote::DeleteRequest build_delete_request(const std::string& job_id) +{ + cuopt::remote::DeleteRequest request; + request.set_job_id(job_id); + return request; +} + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_settings_mapper.cpp b/cpp/src/grpc/grpc_settings_mapper.cpp new file mode 100644 index 0000000000..fcede38cf8 --- /dev/null +++ b/cpp/src/grpc/grpc_settings_mapper.cpp @@ -0,0 +1,267 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#include "grpc_settings_mapper.hpp" + +#include +#include +#include +#include +#include + +#include + +namespace cuopt::linear_programming { + +namespace { + +// Convert cuOpt pdlp_solver_mode_t to protobuf enum +cuopt::remote::PDLPSolverMode to_proto_pdlp_mode(pdlp_solver_mode_t mode) +{ + switch (mode) { + case pdlp_solver_mode_t::Stable1: return cuopt::remote::Stable1; + case pdlp_solver_mode_t::Stable2: return cuopt::remote::Stable2; + case pdlp_solver_mode_t::Methodical1: return cuopt::remote::Methodical1; + case pdlp_solver_mode_t::Fast1: return cuopt::remote::Fast1; + case pdlp_solver_mode_t::Stable3: return cuopt::remote::Stable3; + default: return cuopt::remote::Stable3; + } +} + +// Convert protobuf enum to cuOpt pdlp_solver_mode_t +pdlp_solver_mode_t from_proto_pdlp_mode(cuopt::remote::PDLPSolverMode mode) +{ + switch (mode) { + case cuopt::remote::Stable1: return pdlp_solver_mode_t::Stable1; + case cuopt::remote::Stable2: return pdlp_solver_mode_t::Stable2; + case cuopt::remote::Methodical1: return pdlp_solver_mode_t::Methodical1; + case cuopt::remote::Fast1: return pdlp_solver_mode_t::Fast1; + case cuopt::remote::Stable3: return pdlp_solver_mode_t::Stable3; + default: return pdlp_solver_mode_t::Stable3; + } +} + +// Convert cuOpt method_t to protobuf enum 
+cuopt::remote::LPMethod to_proto_method(method_t method) +{ + switch (method) { + case method_t::Concurrent: return cuopt::remote::Concurrent; + case method_t::PDLP: return cuopt::remote::PDLP; + case method_t::DualSimplex: return cuopt::remote::DualSimplex; + case method_t::Barrier: return cuopt::remote::Barrier; + default: return cuopt::remote::Concurrent; + } +} + +// Convert protobuf enum to cuOpt method_t +method_t from_proto_method(cuopt::remote::LPMethod method) +{ + switch (method) { + case cuopt::remote::Concurrent: return method_t::Concurrent; + case cuopt::remote::PDLP: return method_t::PDLP; + case cuopt::remote::DualSimplex: return method_t::DualSimplex; + case cuopt::remote::Barrier: return method_t::Barrier; + default: return method_t::Concurrent; + } +} + +} // anonymous namespace + +template +void map_pdlp_settings_to_proto(const pdlp_solver_settings_t& settings, + cuopt::remote::PDLPSolverSettings* pb_settings) +{ + // Termination tolerances (all names match cuOpt API) + pb_settings->set_absolute_gap_tolerance(settings.tolerances.absolute_gap_tolerance); + pb_settings->set_relative_gap_tolerance(settings.tolerances.relative_gap_tolerance); + pb_settings->set_primal_infeasible_tolerance(settings.tolerances.primal_infeasible_tolerance); + pb_settings->set_dual_infeasible_tolerance(settings.tolerances.dual_infeasible_tolerance); + pb_settings->set_absolute_dual_tolerance(settings.tolerances.absolute_dual_tolerance); + pb_settings->set_relative_dual_tolerance(settings.tolerances.relative_dual_tolerance); + pb_settings->set_absolute_primal_tolerance(settings.tolerances.absolute_primal_tolerance); + pb_settings->set_relative_primal_tolerance(settings.tolerances.relative_primal_tolerance); + + // Limits + pb_settings->set_time_limit(settings.time_limit); + // Avoid emitting a huge number when the iteration limit is the library default. + // Use -1 sentinel for "unset/use server defaults". 
+ if (settings.iteration_limit == std::numeric_limits::max()) { + pb_settings->set_iteration_limit(-1); + } else { + pb_settings->set_iteration_limit(static_cast(settings.iteration_limit)); + } + + // Solver configuration + pb_settings->set_log_to_console(settings.log_to_console); + pb_settings->set_detect_infeasibility(settings.detect_infeasibility); + pb_settings->set_strict_infeasibility(settings.strict_infeasibility); + pb_settings->set_pdlp_solver_mode(to_proto_pdlp_mode(settings.pdlp_solver_mode)); + pb_settings->set_method(to_proto_method(settings.method)); + pb_settings->set_presolver(static_cast(settings.presolver)); + pb_settings->set_dual_postsolve(settings.dual_postsolve); + pb_settings->set_crossover(settings.crossover); + pb_settings->set_num_gpus(settings.num_gpus); + + pb_settings->set_per_constraint_residual(settings.per_constraint_residual); + pb_settings->set_cudss_deterministic(settings.cudss_deterministic); + pb_settings->set_folding(settings.folding); + pb_settings->set_augmented(settings.augmented); + pb_settings->set_dualize(settings.dualize); + pb_settings->set_ordering(settings.ordering); + pb_settings->set_barrier_dual_initial_point(settings.barrier_dual_initial_point); + pb_settings->set_eliminate_dense_columns(settings.eliminate_dense_columns); + pb_settings->set_pdlp_precision(static_cast(settings.pdlp_precision)); + pb_settings->set_save_best_primal_so_far(settings.save_best_primal_so_far); + pb_settings->set_first_primal_feasible(settings.first_primal_feasible); + + // TODO: Add warmstart data support + // if (settings.warm_start_data.has_value()) { + // auto* pb_warmstart = pb_settings->mutable_warm_start_data(); + // // Map warmstart data fields... 
+ // } +} + +template +void map_proto_to_pdlp_settings(const cuopt::remote::PDLPSolverSettings& pb_settings, + pdlp_solver_settings_t& settings) +{ + // Termination tolerances (all names match cuOpt API) + settings.tolerances.absolute_gap_tolerance = pb_settings.absolute_gap_tolerance(); + settings.tolerances.relative_gap_tolerance = pb_settings.relative_gap_tolerance(); + settings.tolerances.primal_infeasible_tolerance = pb_settings.primal_infeasible_tolerance(); + settings.tolerances.dual_infeasible_tolerance = pb_settings.dual_infeasible_tolerance(); + settings.tolerances.absolute_dual_tolerance = pb_settings.absolute_dual_tolerance(); + settings.tolerances.relative_dual_tolerance = pb_settings.relative_dual_tolerance(); + settings.tolerances.absolute_primal_tolerance = pb_settings.absolute_primal_tolerance(); + settings.tolerances.relative_primal_tolerance = pb_settings.relative_primal_tolerance(); + + // Limits + settings.time_limit = pb_settings.time_limit(); + // proto3 defaults numeric fields to 0; treat negative iteration_limit as "unset" + // so the server keeps the library default (typically max()). + if (pb_settings.iteration_limit() >= 0) { + const auto limit = pb_settings.iteration_limit(); + settings.iteration_limit = (limit > static_cast(std::numeric_limits::max())) + ? std::numeric_limits::max() + : static_cast(limit); + } + + // Solver configuration + settings.log_to_console = pb_settings.log_to_console(); + settings.detect_infeasibility = pb_settings.detect_infeasibility(); + settings.strict_infeasibility = pb_settings.strict_infeasibility(); + settings.pdlp_solver_mode = from_proto_pdlp_mode(pb_settings.pdlp_solver_mode()); + settings.method = from_proto_method(pb_settings.method()); + { + auto pv = pb_settings.presolver(); + settings.presolver = (pv >= CUOPT_PRESOLVE_DEFAULT && pv <= CUOPT_PRESOLVE_PSLP) + ? 
static_cast(pv) + : presolver_t::Default; + } + settings.dual_postsolve = pb_settings.dual_postsolve(); + settings.crossover = pb_settings.crossover(); + settings.num_gpus = pb_settings.num_gpus(); + + settings.per_constraint_residual = pb_settings.per_constraint_residual(); + settings.cudss_deterministic = pb_settings.cudss_deterministic(); + settings.folding = pb_settings.folding(); + settings.augmented = pb_settings.augmented(); + settings.dualize = pb_settings.dualize(); + settings.ordering = pb_settings.ordering(); + settings.barrier_dual_initial_point = pb_settings.barrier_dual_initial_point(); + settings.eliminate_dense_columns = pb_settings.eliminate_dense_columns(); + { + auto pv = pb_settings.pdlp_precision(); + settings.pdlp_precision = + (pv >= CUOPT_PDLP_DEFAULT_PRECISION && pv <= CUOPT_PDLP_MIXED_PRECISION) + ? static_cast(pv) + : pdlp_precision_t::DefaultPrecision; + } + settings.save_best_primal_so_far = pb_settings.save_best_primal_so_far(); + settings.first_primal_feasible = pb_settings.first_primal_feasible(); + + // TODO: Add warmstart data support + // if (pb_settings.has_warm_start_data()) { + // // Map warmstart data fields... 
+ // } +} + +template +void map_mip_settings_to_proto(const mip_solver_settings_t& settings, + cuopt::remote::MIPSolverSettings* pb_settings) +{ + // Limits + pb_settings->set_time_limit(settings.time_limit); + + // Tolerances (all names match cuOpt API) + pb_settings->set_relative_mip_gap(settings.tolerances.relative_mip_gap); + pb_settings->set_absolute_mip_gap(settings.tolerances.absolute_mip_gap); + pb_settings->set_integrality_tolerance(settings.tolerances.integrality_tolerance); + pb_settings->set_absolute_tolerance(settings.tolerances.absolute_tolerance); + pb_settings->set_relative_tolerance(settings.tolerances.relative_tolerance); + pb_settings->set_presolve_absolute_tolerance(settings.tolerances.presolve_absolute_tolerance); + + // Solver configuration + pb_settings->set_log_to_console(settings.log_to_console); + pb_settings->set_heuristics_only(settings.heuristics_only); + pb_settings->set_num_cpu_threads(settings.num_cpu_threads); + pb_settings->set_num_gpus(settings.num_gpus); + pb_settings->set_presolver(static_cast(settings.presolver)); + pb_settings->set_mip_scaling(settings.mip_scaling); +} + +template +void map_proto_to_mip_settings(const cuopt::remote::MIPSolverSettings& pb_settings, + mip_solver_settings_t& settings) +{ + // Limits + settings.time_limit = pb_settings.time_limit(); + + // Tolerances (all names match cuOpt API) + settings.tolerances.relative_mip_gap = pb_settings.relative_mip_gap(); + settings.tolerances.absolute_mip_gap = pb_settings.absolute_mip_gap(); + settings.tolerances.integrality_tolerance = pb_settings.integrality_tolerance(); + settings.tolerances.absolute_tolerance = pb_settings.absolute_tolerance(); + settings.tolerances.relative_tolerance = pb_settings.relative_tolerance(); + settings.tolerances.presolve_absolute_tolerance = pb_settings.presolve_absolute_tolerance(); + + // Solver configuration + settings.log_to_console = pb_settings.log_to_console(); + settings.heuristics_only = pb_settings.heuristics_only(); + 
settings.num_cpu_threads = pb_settings.num_cpu_threads(); + settings.num_gpus = pb_settings.num_gpus(); + { + auto pv = pb_settings.presolver(); + settings.presolver = (pv >= CUOPT_PRESOLVE_DEFAULT && pv <= CUOPT_PRESOLVE_PSLP) + ? static_cast(pv) + : presolver_t::Default; + } + settings.mip_scaling = pb_settings.mip_scaling(); +} + +// Explicit template instantiations +#if CUOPT_INSTANTIATE_FLOAT +template void map_pdlp_settings_to_proto(const pdlp_solver_settings_t& settings, + cuopt::remote::PDLPSolverSettings* pb_settings); +template void map_proto_to_pdlp_settings(const cuopt::remote::PDLPSolverSettings& pb_settings, + pdlp_solver_settings_t& settings); +template void map_mip_settings_to_proto(const mip_solver_settings_t& settings, + cuopt::remote::MIPSolverSettings* pb_settings); +template void map_proto_to_mip_settings(const cuopt::remote::MIPSolverSettings& pb_settings, + mip_solver_settings_t& settings); +#endif + +#if CUOPT_INSTANTIATE_DOUBLE +template void map_pdlp_settings_to_proto(const pdlp_solver_settings_t& settings, + cuopt::remote::PDLPSolverSettings* pb_settings); +template void map_proto_to_pdlp_settings(const cuopt::remote::PDLPSolverSettings& pb_settings, + pdlp_solver_settings_t& settings); +template void map_mip_settings_to_proto(const mip_solver_settings_t& settings, + cuopt::remote::MIPSolverSettings* pb_settings); +template void map_proto_to_mip_settings(const cuopt::remote::MIPSolverSettings& pb_settings, + mip_solver_settings_t& settings); +#endif + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_settings_mapper.hpp b/cpp/src/grpc/grpc_settings_mapper.hpp new file mode 100644 index 0000000000..6daf0d052b --- /dev/null +++ b/cpp/src/grpc/grpc_settings_mapper.hpp @@ -0,0 +1,61 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +namespace cuopt::linear_programming { + +// Forward declarations +template +struct pdlp_solver_settings_t; + +template +struct mip_solver_settings_t; + +/** + * @brief Map pdlp_solver_settings_t to protobuf PDLPSolverSettings message. + * + * Populates a protobuf message using the generated protobuf C++ API. + * Does not perform serialization — that is handled by the protobuf library. + */ +template +void map_pdlp_settings_to_proto(const pdlp_solver_settings_t& settings, + cuopt::remote::PDLPSolverSettings* pb_settings); + +/** + * @brief Map protobuf PDLPSolverSettings message to pdlp_solver_settings_t. + * + * Reads from a protobuf message using the generated protobuf C++ API. + * Does not perform deserialization — that is handled by the protobuf library. + */ +template +void map_proto_to_pdlp_settings(const cuopt::remote::PDLPSolverSettings& pb_settings, + pdlp_solver_settings_t& settings); + +/** + * @brief Map mip_solver_settings_t to protobuf MIPSolverSettings message. + * + * Populates a protobuf message using the generated protobuf C++ API. + * Does not perform serialization — that is handled by the protobuf library. + */ +template +void map_mip_settings_to_proto(const mip_solver_settings_t& settings, + cuopt::remote::MIPSolverSettings* pb_settings); + +/** + * @brief Map protobuf MIPSolverSettings message to mip_solver_settings_t. + * + * Reads from a protobuf message using the generated protobuf C++ API. + * Does not perform deserialization — that is handled by the protobuf library. 
+ */ +template +void map_proto_to_mip_settings(const cuopt::remote::MIPSolverSettings& pb_settings, + mip_solver_settings_t& settings); + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_solution_mapper.cpp b/cpp/src/grpc/grpc_solution_mapper.cpp new file mode 100644 index 0000000000..700fd12c98 --- /dev/null +++ b/cpp/src/grpc/grpc_solution_mapper.cpp @@ -0,0 +1,735 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#include "grpc_solution_mapper.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cuopt::linear_programming { + +// Convert cuOpt termination status to protobuf enum +cuopt::remote::PDLPTerminationStatus to_proto_pdlp_status(pdlp_termination_status_t status) +{ + switch (status) { + case pdlp_termination_status_t::NoTermination: return cuopt::remote::PDLP_NO_TERMINATION; + case pdlp_termination_status_t::NumericalError: return cuopt::remote::PDLP_NUMERICAL_ERROR; + case pdlp_termination_status_t::Optimal: return cuopt::remote::PDLP_OPTIMAL; + case pdlp_termination_status_t::PrimalInfeasible: return cuopt::remote::PDLP_PRIMAL_INFEASIBLE; + case pdlp_termination_status_t::DualInfeasible: return cuopt::remote::PDLP_DUAL_INFEASIBLE; + case pdlp_termination_status_t::IterationLimit: return cuopt::remote::PDLP_ITERATION_LIMIT; + case pdlp_termination_status_t::TimeLimit: return cuopt::remote::PDLP_TIME_LIMIT; + case pdlp_termination_status_t::ConcurrentLimit: return cuopt::remote::PDLP_CONCURRENT_LIMIT; + case pdlp_termination_status_t::PrimalFeasible: return cuopt::remote::PDLP_PRIMAL_FEASIBLE; + default: return cuopt::remote::PDLP_NO_TERMINATION; + } +} + +// Convert protobuf enum to cuOpt termination status +pdlp_termination_status_t from_proto_pdlp_status(cuopt::remote::PDLPTerminationStatus status) +{ + switch (status) { + case cuopt::remote::PDLP_NO_TERMINATION: return 
pdlp_termination_status_t::NoTermination; + case cuopt::remote::PDLP_NUMERICAL_ERROR: return pdlp_termination_status_t::NumericalError; + case cuopt::remote::PDLP_OPTIMAL: return pdlp_termination_status_t::Optimal; + case cuopt::remote::PDLP_PRIMAL_INFEASIBLE: return pdlp_termination_status_t::PrimalInfeasible; + case cuopt::remote::PDLP_DUAL_INFEASIBLE: return pdlp_termination_status_t::DualInfeasible; + case cuopt::remote::PDLP_ITERATION_LIMIT: return pdlp_termination_status_t::IterationLimit; + case cuopt::remote::PDLP_TIME_LIMIT: return pdlp_termination_status_t::TimeLimit; + case cuopt::remote::PDLP_CONCURRENT_LIMIT: return pdlp_termination_status_t::ConcurrentLimit; + case cuopt::remote::PDLP_PRIMAL_FEASIBLE: return pdlp_termination_status_t::PrimalFeasible; + default: return pdlp_termination_status_t::NoTermination; + } +} + +// Convert MIP termination status +cuopt::remote::MIPTerminationStatus to_proto_mip_status(mip_termination_status_t status) +{ + switch (status) { + case mip_termination_status_t::NoTermination: return cuopt::remote::MIP_NO_TERMINATION; + case mip_termination_status_t::Optimal: return cuopt::remote::MIP_OPTIMAL; + case mip_termination_status_t::FeasibleFound: return cuopt::remote::MIP_FEASIBLE_FOUND; + case mip_termination_status_t::Infeasible: return cuopt::remote::MIP_INFEASIBLE; + case mip_termination_status_t::Unbounded: return cuopt::remote::MIP_UNBOUNDED; + case mip_termination_status_t::TimeLimit: return cuopt::remote::MIP_TIME_LIMIT; + case mip_termination_status_t::WorkLimit: return cuopt::remote::MIP_WORK_LIMIT; + default: return cuopt::remote::MIP_NO_TERMINATION; + } +} + +mip_termination_status_t from_proto_mip_status(cuopt::remote::MIPTerminationStatus status) +{ + switch (status) { + case cuopt::remote::MIP_NO_TERMINATION: return mip_termination_status_t::NoTermination; + case cuopt::remote::MIP_OPTIMAL: return mip_termination_status_t::Optimal; + case cuopt::remote::MIP_FEASIBLE_FOUND: return 
mip_termination_status_t::FeasibleFound; + case cuopt::remote::MIP_INFEASIBLE: return mip_termination_status_t::Infeasible; + case cuopt::remote::MIP_UNBOUNDED: return mip_termination_status_t::Unbounded; + case cuopt::remote::MIP_TIME_LIMIT: return mip_termination_status_t::TimeLimit; + case cuopt::remote::MIP_WORK_LIMIT: return mip_termination_status_t::WorkLimit; + default: return mip_termination_status_t::NoTermination; + } +} + +template +void map_lp_solution_to_proto(const cpu_lp_solution_t& solution, + cuopt::remote::LPSolution* pb_solution) +{ + pb_solution->set_termination_status(to_proto_pdlp_status(solution.get_termination_status())); + pb_solution->set_error_message(solution.get_error_status().what()); + + // Solution vectors - CPU solution already has data in host memory + const auto& primal = solution.get_primal_solution_host(); + const auto& dual = solution.get_dual_solution_host(); + const auto& reduced_cost = solution.get_reduced_cost_host(); + + for (const auto& v : primal) { + pb_solution->add_primal_solution(static_cast(v)); + } + for (const auto& v : dual) { + pb_solution->add_dual_solution(static_cast(v)); + } + for (const auto& v : reduced_cost) { + pb_solution->add_reduced_cost(static_cast(v)); + } + + // Statistics + pb_solution->set_l2_primal_residual(solution.get_l2_primal_residual()); + pb_solution->set_l2_dual_residual(solution.get_l2_dual_residual()); + pb_solution->set_primal_objective(solution.get_objective_value()); + pb_solution->set_dual_objective(solution.get_dual_objective_value()); + pb_solution->set_gap(solution.get_gap()); + pb_solution->set_nb_iterations(solution.get_num_iterations()); + pb_solution->set_solve_time(solution.get_solve_time()); + pb_solution->set_solved_by_pdlp(solution.is_solved_by_pdlp()); + + if (solution.has_warm_start_data()) { + auto* pb_ws = pb_solution->mutable_warm_start_data(); + const auto& ws = solution.get_cpu_pdlp_warm_start_data(); + + for (const auto& v : ws.current_primal_solution_) + 
pb_ws->add_current_primal_solution(static_cast(v)); + for (const auto& v : ws.current_dual_solution_) + pb_ws->add_current_dual_solution(static_cast(v)); + for (const auto& v : ws.initial_primal_average_) + pb_ws->add_initial_primal_average(static_cast(v)); + for (const auto& v : ws.initial_dual_average_) + pb_ws->add_initial_dual_average(static_cast(v)); + for (const auto& v : ws.current_ATY_) + pb_ws->add_current_aty(static_cast(v)); + for (const auto& v : ws.sum_primal_solutions_) + pb_ws->add_sum_primal_solutions(static_cast(v)); + for (const auto& v : ws.sum_dual_solutions_) + pb_ws->add_sum_dual_solutions(static_cast(v)); + for (const auto& v : ws.last_restart_duality_gap_primal_solution_) + pb_ws->add_last_restart_duality_gap_primal_solution(static_cast(v)); + for (const auto& v : ws.last_restart_duality_gap_dual_solution_) + pb_ws->add_last_restart_duality_gap_dual_solution(static_cast(v)); + + pb_ws->set_initial_primal_weight(static_cast(ws.initial_primal_weight_)); + pb_ws->set_initial_step_size(static_cast(ws.initial_step_size_)); + pb_ws->set_total_pdlp_iterations(static_cast(ws.total_pdlp_iterations_)); + pb_ws->set_total_pdhg_iterations(static_cast(ws.total_pdhg_iterations_)); + pb_ws->set_last_candidate_kkt_score(static_cast(ws.last_candidate_kkt_score_)); + pb_ws->set_last_restart_kkt_score(static_cast(ws.last_restart_kkt_score_)); + pb_ws->set_sum_solution_weight(static_cast(ws.sum_solution_weight_)); + pb_ws->set_iterations_since_last_restart( + static_cast(ws.iterations_since_last_restart_)); + } +} + +template +cpu_lp_solution_t map_proto_to_lp_solution(const cuopt::remote::LPSolution& pb_solution) +{ + // Convert solution vectors + std::vector primal(pb_solution.primal_solution().begin(), + pb_solution.primal_solution().end()); + std::vector dual(pb_solution.dual_solution().begin(), pb_solution.dual_solution().end()); + std::vector reduced_cost(pb_solution.reduced_cost().begin(), + pb_solution.reduced_cost().end()); + + auto status = 
from_proto_pdlp_status(pb_solution.termination_status()); + auto obj = static_cast(pb_solution.primal_objective()); + auto dual_obj = static_cast(pb_solution.dual_objective()); + auto solve_t = pb_solution.solve_time(); + auto l2_pr = static_cast(pb_solution.l2_primal_residual()); + auto l2_dr = static_cast(pb_solution.l2_dual_residual()); + auto g = static_cast(pb_solution.gap()); + auto iters = static_cast(pb_solution.nb_iterations()); + auto by_pdlp = pb_solution.solved_by_pdlp(); + + if (pb_solution.has_warm_start_data()) { + const auto& pb_ws = pb_solution.warm_start_data(); + cpu_pdlp_warm_start_data_t ws; + + ws.current_primal_solution_.assign(pb_ws.current_primal_solution().begin(), + pb_ws.current_primal_solution().end()); + ws.current_dual_solution_.assign(pb_ws.current_dual_solution().begin(), + pb_ws.current_dual_solution().end()); + ws.initial_primal_average_.assign(pb_ws.initial_primal_average().begin(), + pb_ws.initial_primal_average().end()); + ws.initial_dual_average_.assign(pb_ws.initial_dual_average().begin(), + pb_ws.initial_dual_average().end()); + ws.current_ATY_.assign(pb_ws.current_aty().begin(), pb_ws.current_aty().end()); + ws.sum_primal_solutions_.assign(pb_ws.sum_primal_solutions().begin(), + pb_ws.sum_primal_solutions().end()); + ws.sum_dual_solutions_.assign(pb_ws.sum_dual_solutions().begin(), + pb_ws.sum_dual_solutions().end()); + ws.last_restart_duality_gap_primal_solution_.assign( + pb_ws.last_restart_duality_gap_primal_solution().begin(), + pb_ws.last_restart_duality_gap_primal_solution().end()); + ws.last_restart_duality_gap_dual_solution_.assign( + pb_ws.last_restart_duality_gap_dual_solution().begin(), + pb_ws.last_restart_duality_gap_dual_solution().end()); + + ws.initial_primal_weight_ = static_cast(pb_ws.initial_primal_weight()); + ws.initial_step_size_ = static_cast(pb_ws.initial_step_size()); + ws.total_pdlp_iterations_ = static_cast(pb_ws.total_pdlp_iterations()); + ws.total_pdhg_iterations_ = 
static_cast(pb_ws.total_pdhg_iterations()); + ws.last_candidate_kkt_score_ = static_cast(pb_ws.last_candidate_kkt_score()); + ws.last_restart_kkt_score_ = static_cast(pb_ws.last_restart_kkt_score()); + ws.sum_solution_weight_ = static_cast(pb_ws.sum_solution_weight()); + ws.iterations_since_last_restart_ = static_cast(pb_ws.iterations_since_last_restart()); + + return cpu_lp_solution_t(std::move(primal), + std::move(dual), + std::move(reduced_cost), + status, + obj, + dual_obj, + solve_t, + l2_pr, + l2_dr, + g, + iters, + by_pdlp, + std::move(ws)); + } + + return cpu_lp_solution_t(std::move(primal), + std::move(dual), + std::move(reduced_cost), + status, + obj, + dual_obj, + solve_t, + l2_pr, + l2_dr, + g, + iters, + by_pdlp); +} + +template +void map_mip_solution_to_proto(const cpu_mip_solution_t& solution, + cuopt::remote::MIPSolution* pb_solution) +{ + pb_solution->set_termination_status(to_proto_mip_status(solution.get_termination_status())); + pb_solution->set_error_message(solution.get_error_status().what()); + + // Solution vector - CPU solution already has data in host memory + const auto& sol_vec = solution.get_solution_host(); + for (const auto& v : sol_vec) { + pb_solution->add_solution(static_cast(v)); + } + + // Solution statistics + pb_solution->set_objective(solution.get_objective_value()); + pb_solution->set_mip_gap(solution.get_mip_gap()); + pb_solution->set_solution_bound(solution.get_solution_bound()); + pb_solution->set_total_solve_time(solution.get_solve_time()); + pb_solution->set_presolve_time(solution.get_presolve_time()); + pb_solution->set_max_constraint_violation(solution.get_max_constraint_violation()); + pb_solution->set_max_int_violation(solution.get_max_int_violation()); + pb_solution->set_max_variable_bound_violation(solution.get_max_variable_bound_violation()); + pb_solution->set_nodes(solution.get_num_nodes()); + pb_solution->set_simplex_iterations(solution.get_num_simplex_iterations()); +} + +template +cpu_mip_solution_t 
map_proto_to_mip_solution( + const cuopt::remote::MIPSolution& pb_solution) +{ + // Convert solution vector + std::vector solution_vec(pb_solution.solution().begin(), pb_solution.solution().end()); + + // Create CPU MIP solution with data + return cpu_mip_solution_t(std::move(solution_vec), + from_proto_mip_status(pb_solution.termination_status()), + static_cast(pb_solution.objective()), + static_cast(pb_solution.mip_gap()), + static_cast(pb_solution.solution_bound()), + pb_solution.total_solve_time(), + pb_solution.presolve_time(), + static_cast(pb_solution.max_constraint_violation()), + static_cast(pb_solution.max_int_violation()), + static_cast(pb_solution.max_variable_bound_violation()), + static_cast(pb_solution.nodes()), + static_cast(pb_solution.simplex_iterations())); +} + +// ============================================================================ +// Size estimation +// ============================================================================ + +template +size_t estimate_lp_solution_proto_size(const cpu_lp_solution_t& solution) +{ + size_t est = 0; + est += static_cast(solution.get_primal_solution_size()) * sizeof(double); + est += static_cast(solution.get_dual_solution_size()) * sizeof(double); + est += static_cast(solution.get_reduced_cost_size()) * sizeof(double); + if (solution.has_warm_start_data()) { + const auto& ws = solution.get_cpu_pdlp_warm_start_data(); + est += ws.current_primal_solution_.size() * sizeof(double); + est += ws.current_dual_solution_.size() * sizeof(double); + est += ws.initial_primal_average_.size() * sizeof(double); + est += ws.initial_dual_average_.size() * sizeof(double); + est += ws.current_ATY_.size() * sizeof(double); + est += ws.sum_primal_solutions_.size() * sizeof(double); + est += ws.sum_dual_solutions_.size() * sizeof(double); + est += ws.last_restart_duality_gap_primal_solution_.size() * sizeof(double); + est += ws.last_restart_duality_gap_dual_solution_.size() * sizeof(double); + } + est += 512; // scalars + 
tags overhead + return est; +} + +template +size_t estimate_mip_solution_proto_size(const cpu_mip_solution_t& solution) +{ + size_t est = 0; + est += static_cast(solution.get_solution_size()) * sizeof(double); + est += 256; // scalars + tags overhead + return est; +} + +// ============================================================================ +// Chunked result header population +// ============================================================================ + +namespace { +void add_result_array_descriptor(cuopt::remote::ChunkedResultHeader* header, + cuopt::remote::ResultFieldId fid, + int64_t count, + int64_t elem_size) +{ + if (count <= 0) return; + auto* desc = header->add_arrays(); + desc->set_field_id(fid); + desc->set_total_elements(count); + desc->set_element_size_bytes(elem_size); +} + +template +std::vector doubles_to_bytes(const std::vector& vec) +{ + std::vector tmp(vec.begin(), vec.end()); + std::vector bytes(tmp.size() * sizeof(double)); + std::memcpy(bytes.data(), tmp.data(), bytes.size()); + return bytes; +} +} // namespace + +template +void populate_chunked_result_header_lp(const cpu_lp_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header) +{ + header->set_is_mip(false); + header->set_lp_termination_status(to_proto_pdlp_status(solution.get_termination_status())); + header->set_error_message(solution.get_error_status().what()); + header->set_l2_primal_residual(solution.get_l2_primal_residual()); + header->set_l2_dual_residual(solution.get_l2_dual_residual()); + header->set_primal_objective(solution.get_objective_value()); + header->set_dual_objective(solution.get_dual_objective_value()); + header->set_gap(solution.get_gap()); + header->set_nb_iterations(solution.get_num_iterations()); + header->set_solve_time(solution.get_solve_time()); + header->set_solved_by_pdlp(solution.is_solved_by_pdlp()); + + const auto& primal = solution.get_primal_solution_host(); + const auto& dual = solution.get_dual_solution_host(); + const auto& 
reduced_cost = solution.get_reduced_cost_host(); + + add_result_array_descriptor( + header, cuopt::remote::RESULT_PRIMAL_SOLUTION, primal.size(), sizeof(double)); + add_result_array_descriptor( + header, cuopt::remote::RESULT_DUAL_SOLUTION, dual.size(), sizeof(double)); + add_result_array_descriptor( + header, cuopt::remote::RESULT_REDUCED_COST, reduced_cost.size(), sizeof(double)); + + if (solution.has_warm_start_data()) { + const auto& ws = solution.get_cpu_pdlp_warm_start_data(); + header->set_ws_initial_primal_weight(static_cast(ws.initial_primal_weight_)); + header->set_ws_initial_step_size(static_cast(ws.initial_step_size_)); + header->set_ws_total_pdlp_iterations(static_cast(ws.total_pdlp_iterations_)); + header->set_ws_total_pdhg_iterations(static_cast(ws.total_pdhg_iterations_)); + header->set_ws_last_candidate_kkt_score(static_cast(ws.last_candidate_kkt_score_)); + header->set_ws_last_restart_kkt_score(static_cast(ws.last_restart_kkt_score_)); + header->set_ws_sum_solution_weight(static_cast(ws.sum_solution_weight_)); + header->set_ws_iterations_since_last_restart( + static_cast(ws.iterations_since_last_restart_)); + + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_CURRENT_PRIMAL, + ws.current_primal_solution_.size(), + sizeof(double)); + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_CURRENT_DUAL, + ws.current_dual_solution_.size(), + sizeof(double)); + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_INITIAL_PRIMAL_AVG, + ws.initial_primal_average_.size(), + sizeof(double)); + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_INITIAL_DUAL_AVG, + ws.initial_dual_average_.size(), + sizeof(double)); + add_result_array_descriptor( + header, cuopt::remote::RESULT_WS_CURRENT_ATY, ws.current_ATY_.size(), sizeof(double)); + add_result_array_descriptor( + header, cuopt::remote::RESULT_WS_SUM_PRIMAL, ws.sum_primal_solutions_.size(), sizeof(double)); + add_result_array_descriptor( + header, 
cuopt::remote::RESULT_WS_SUM_DUAL, ws.sum_dual_solutions_.size(), sizeof(double)); + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_LAST_RESTART_GAP_PRIMAL, + ws.last_restart_duality_gap_primal_solution_.size(), + sizeof(double)); + add_result_array_descriptor(header, + cuopt::remote::RESULT_WS_LAST_RESTART_GAP_DUAL, + ws.last_restart_duality_gap_dual_solution_.size(), + sizeof(double)); + } +} + +template +void populate_chunked_result_header_mip(const cpu_mip_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header) +{ + header->set_is_mip(true); + header->set_mip_termination_status(to_proto_mip_status(solution.get_termination_status())); + header->set_mip_error_message(solution.get_error_status().what()); + header->set_mip_objective(solution.get_objective_value()); + header->set_mip_gap(solution.get_mip_gap()); + header->set_solution_bound(solution.get_solution_bound()); + header->set_total_solve_time(solution.get_solve_time()); + header->set_presolve_time(solution.get_presolve_time()); + header->set_max_constraint_violation(solution.get_max_constraint_violation()); + header->set_max_int_violation(solution.get_max_int_violation()); + header->set_max_variable_bound_violation(solution.get_max_variable_bound_violation()); + header->set_nodes(solution.get_num_nodes()); + header->set_simplex_iterations(solution.get_num_simplex_iterations()); + + add_result_array_descriptor(header, + cuopt::remote::RESULT_MIP_SOLUTION, + solution.get_solution_host().size(), + sizeof(double)); +} + +// ============================================================================ +// Collect solution arrays as raw bytes +// ============================================================================ + +template +std::map> collect_lp_solution_arrays( + const cpu_lp_solution_t& solution) +{ + std::map> arrays; + + const auto& primal = solution.get_primal_solution_host(); + const auto& dual = solution.get_dual_solution_host(); + const auto& reduced_cost = 
solution.get_reduced_cost_host(); + + if (!primal.empty()) { arrays[cuopt::remote::RESULT_PRIMAL_SOLUTION] = doubles_to_bytes(primal); } + if (!dual.empty()) { arrays[cuopt::remote::RESULT_DUAL_SOLUTION] = doubles_to_bytes(dual); } + if (!reduced_cost.empty()) { + arrays[cuopt::remote::RESULT_REDUCED_COST] = doubles_to_bytes(reduced_cost); + } + + if (solution.has_warm_start_data()) { + const auto& ws = solution.get_cpu_pdlp_warm_start_data(); + if (!ws.current_primal_solution_.empty()) { + arrays[cuopt::remote::RESULT_WS_CURRENT_PRIMAL] = + doubles_to_bytes(ws.current_primal_solution_); + } + if (!ws.current_dual_solution_.empty()) { + arrays[cuopt::remote::RESULT_WS_CURRENT_DUAL] = doubles_to_bytes(ws.current_dual_solution_); + } + if (!ws.initial_primal_average_.empty()) { + arrays[cuopt::remote::RESULT_WS_INITIAL_PRIMAL_AVG] = + doubles_to_bytes(ws.initial_primal_average_); + } + if (!ws.initial_dual_average_.empty()) { + arrays[cuopt::remote::RESULT_WS_INITIAL_DUAL_AVG] = + doubles_to_bytes(ws.initial_dual_average_); + } + if (!ws.current_ATY_.empty()) { + arrays[cuopt::remote::RESULT_WS_CURRENT_ATY] = doubles_to_bytes(ws.current_ATY_); + } + if (!ws.sum_primal_solutions_.empty()) { + arrays[cuopt::remote::RESULT_WS_SUM_PRIMAL] = doubles_to_bytes(ws.sum_primal_solutions_); + } + if (!ws.sum_dual_solutions_.empty()) { + arrays[cuopt::remote::RESULT_WS_SUM_DUAL] = doubles_to_bytes(ws.sum_dual_solutions_); + } + if (!ws.last_restart_duality_gap_primal_solution_.empty()) { + arrays[cuopt::remote::RESULT_WS_LAST_RESTART_GAP_PRIMAL] = + doubles_to_bytes(ws.last_restart_duality_gap_primal_solution_); + } + if (!ws.last_restart_duality_gap_dual_solution_.empty()) { + arrays[cuopt::remote::RESULT_WS_LAST_RESTART_GAP_DUAL] = + doubles_to_bytes(ws.last_restart_duality_gap_dual_solution_); + } + } + + return arrays; +} + +template +std::map> collect_mip_solution_arrays( + const cpu_mip_solution_t& solution) +{ + std::map> arrays; + const auto& sol_vec = 
solution.get_solution_host(); + if (!sol_vec.empty()) { arrays[cuopt::remote::RESULT_MIP_SOLUTION] = doubles_to_bytes(sol_vec); } + return arrays; +} + +// ============================================================================ +// Chunked result -> solution (client-side) +// ============================================================================ + +namespace { + +template +std::vector bytes_to_typed(const std::map>& arrays, + int32_t field_id) +{ + auto it = arrays.find(field_id); + if (it == arrays.end() || it->second.empty()) return {}; + + const auto& raw = it->second; + if constexpr (std::is_same_v) { + if (raw.size() % sizeof(double) != 0) return {}; + size_t n = raw.size() / sizeof(double); + std::vector tmp(n); + std::memcpy(tmp.data(), raw.data(), n * sizeof(double)); + return std::vector(tmp.begin(), tmp.end()); + } else if constexpr (std::is_same_v) { + if (raw.size() % sizeof(double) != 0) return {}; + size_t n = raw.size() / sizeof(double); + std::vector v(n); + std::memcpy(v.data(), raw.data(), n * sizeof(double)); + return v; + } else { + if (raw.size() % sizeof(T) != 0) return {}; + size_t n = raw.size() / sizeof(T); + std::vector v(n); + std::memcpy(v.data(), raw.data(), n * sizeof(T)); + return v; + } +} + +} // namespace + +template +cpu_lp_solution_t chunked_result_to_lp_solution( + const cuopt::remote::ChunkedResultHeader& h, + const std::map>& arrays) +{ + auto primal = bytes_to_typed(arrays, cuopt::remote::RESULT_PRIMAL_SOLUTION); + auto dual = bytes_to_typed(arrays, cuopt::remote::RESULT_DUAL_SOLUTION); + auto reduced_cost = bytes_to_typed(arrays, cuopt::remote::RESULT_REDUCED_COST); + + auto status = from_proto_pdlp_status(h.lp_termination_status()); + auto obj = static_cast(h.primal_objective()); + auto dual_obj = static_cast(h.dual_objective()); + auto solve_t = h.solve_time(); + auto l2_pr = static_cast(h.l2_primal_residual()); + auto l2_dr = static_cast(h.l2_dual_residual()); + auto g = static_cast(h.gap()); + auto iters = 
static_cast(h.nb_iterations()); + auto by_pdlp = h.solved_by_pdlp(); + + auto ws_primal = bytes_to_typed(arrays, cuopt::remote::RESULT_WS_CURRENT_PRIMAL); + if (!ws_primal.empty()) { + cpu_pdlp_warm_start_data_t ws; + ws.current_primal_solution_ = std::move(ws_primal); + ws.current_dual_solution_ = bytes_to_typed(arrays, cuopt::remote::RESULT_WS_CURRENT_DUAL); + ws.initial_primal_average_ = + bytes_to_typed(arrays, cuopt::remote::RESULT_WS_INITIAL_PRIMAL_AVG); + ws.initial_dual_average_ = + bytes_to_typed(arrays, cuopt::remote::RESULT_WS_INITIAL_DUAL_AVG); + ws.current_ATY_ = bytes_to_typed(arrays, cuopt::remote::RESULT_WS_CURRENT_ATY); + ws.sum_primal_solutions_ = bytes_to_typed(arrays, cuopt::remote::RESULT_WS_SUM_PRIMAL); + ws.sum_dual_solutions_ = bytes_to_typed(arrays, cuopt::remote::RESULT_WS_SUM_DUAL); + ws.last_restart_duality_gap_primal_solution_ = + bytes_to_typed(arrays, cuopt::remote::RESULT_WS_LAST_RESTART_GAP_PRIMAL); + ws.last_restart_duality_gap_dual_solution_ = + bytes_to_typed(arrays, cuopt::remote::RESULT_WS_LAST_RESTART_GAP_DUAL); + + ws.initial_primal_weight_ = static_cast(h.ws_initial_primal_weight()); + ws.initial_step_size_ = static_cast(h.ws_initial_step_size()); + ws.total_pdlp_iterations_ = static_cast(h.ws_total_pdlp_iterations()); + ws.total_pdhg_iterations_ = static_cast(h.ws_total_pdhg_iterations()); + ws.last_candidate_kkt_score_ = static_cast(h.ws_last_candidate_kkt_score()); + ws.last_restart_kkt_score_ = static_cast(h.ws_last_restart_kkt_score()); + ws.sum_solution_weight_ = static_cast(h.ws_sum_solution_weight()); + ws.iterations_since_last_restart_ = static_cast(h.ws_iterations_since_last_restart()); + + return cpu_lp_solution_t(std::move(primal), + std::move(dual), + std::move(reduced_cost), + status, + obj, + dual_obj, + solve_t, + l2_pr, + l2_dr, + g, + iters, + by_pdlp, + std::move(ws)); + } + + return cpu_lp_solution_t(std::move(primal), + std::move(dual), + std::move(reduced_cost), + status, + obj, + dual_obj, + solve_t, + 
l2_pr, + l2_dr, + g, + iters, + by_pdlp); +} + +template +cpu_mip_solution_t chunked_result_to_mip_solution( + const cuopt::remote::ChunkedResultHeader& h, + const std::map>& arrays) +{ + auto sol_vec = bytes_to_typed(arrays, cuopt::remote::RESULT_MIP_SOLUTION); + + return cpu_mip_solution_t(std::move(sol_vec), + from_proto_mip_status(h.mip_termination_status()), + static_cast(h.mip_objective()), + static_cast(h.mip_gap()), + static_cast(h.solution_bound()), + h.total_solve_time(), + h.presolve_time(), + static_cast(h.max_constraint_violation()), + static_cast(h.max_int_violation()), + static_cast(h.max_variable_bound_violation()), + static_cast(h.nodes()), + static_cast(h.simplex_iterations())); +} + +// ============================================================================ +// Build full protobuf from stored header + arrays (server-side GetResult RPC) +// ============================================================================ + +template +void build_lp_solution_proto(const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::LPSolution* proto) +{ + auto cpu_sol = chunked_result_to_lp_solution(header, arrays); + map_lp_solution_to_proto(cpu_sol, proto); +} + +template +void build_mip_solution_proto(const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::MIPSolution* proto) +{ + auto cpu_sol = chunked_result_to_mip_solution(header, arrays); + map_mip_solution_to_proto(cpu_sol, proto); +} + +// Explicit template instantiations +#if CUOPT_INSTANTIATE_FLOAT +template void map_lp_solution_to_proto(const cpu_lp_solution_t& solution, + cuopt::remote::LPSolution* pb_solution); +template cpu_lp_solution_t map_proto_to_lp_solution( + const cuopt::remote::LPSolution& pb_solution); +template void map_mip_solution_to_proto(const cpu_mip_solution_t& solution, + cuopt::remote::MIPSolution* pb_solution); +template cpu_mip_solution_t map_proto_to_mip_solution( + const cuopt::remote::MIPSolution& 
pb_solution); +template size_t estimate_lp_solution_proto_size(const cpu_lp_solution_t& solution); +template size_t estimate_mip_solution_proto_size( + const cpu_mip_solution_t& solution); +template void populate_chunked_result_header_lp(const cpu_lp_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header); +template void populate_chunked_result_header_mip(const cpu_mip_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header); +template std::map> collect_lp_solution_arrays( + const cpu_lp_solution_t& solution); +template std::map> collect_mip_solution_arrays( + const cpu_mip_solution_t& solution); +template cpu_lp_solution_t chunked_result_to_lp_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); +template cpu_mip_solution_t chunked_result_to_mip_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); +template void build_lp_solution_proto( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::LPSolution* proto); +template void build_mip_solution_proto( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::MIPSolution* proto); +#endif + +#if CUOPT_INSTANTIATE_DOUBLE +template void map_lp_solution_to_proto(const cpu_lp_solution_t& solution, + cuopt::remote::LPSolution* pb_solution); +template cpu_lp_solution_t map_proto_to_lp_solution( + const cuopt::remote::LPSolution& pb_solution); +template void map_mip_solution_to_proto(const cpu_mip_solution_t& solution, + cuopt::remote::MIPSolution* pb_solution); +template cpu_mip_solution_t map_proto_to_mip_solution( + const cuopt::remote::MIPSolution& pb_solution); +template size_t estimate_lp_solution_proto_size(const cpu_lp_solution_t& solution); +template size_t estimate_mip_solution_proto_size( + const cpu_mip_solution_t& solution); +template void populate_chunked_result_header_lp(const cpu_lp_solution_t& solution, + 
cuopt::remote::ChunkedResultHeader* header); +template void populate_chunked_result_header_mip( + const cpu_mip_solution_t& solution, cuopt::remote::ChunkedResultHeader* header); +template std::map> collect_lp_solution_arrays( + const cpu_lp_solution_t& solution); +template std::map> collect_mip_solution_arrays( + const cpu_mip_solution_t& solution); +template cpu_lp_solution_t chunked_result_to_lp_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); +template cpu_mip_solution_t chunked_result_to_mip_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); +template void build_lp_solution_proto( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::LPSolution* proto); +template void build_mip_solution_proto( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::MIPSolution* proto); +#endif + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/grpc_solution_mapper.hpp b/cpp/src/grpc/grpc_solution_mapper.hpp new file mode 100644 index 0000000000..127bdb2c96 --- /dev/null +++ b/cpp/src/grpc/grpc_solution_mapper.hpp @@ -0,0 +1,183 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace cuopt::linear_programming { + +/** + * @brief Map cpu_lp_solution_t to protobuf LPSolution message. + * + * Populates a protobuf message using the generated protobuf C++ API. + * Does not perform serialization — that is handled by the protobuf library. + */ +template +void map_lp_solution_to_proto(const cpu_lp_solution_t& solution, + cuopt::remote::LPSolution* pb_solution); + +/** + * @brief Map protobuf LPSolution message to cpu_lp_solution_t. 
+ * + * Reads from a protobuf message using the generated protobuf C++ API. + * Does not perform deserialization — that is handled by the protobuf library. + */ +template +cpu_lp_solution_t map_proto_to_lp_solution(const cuopt::remote::LPSolution& pb_solution); + +/** + * @brief Map cpu_mip_solution_t to protobuf MIPSolution message. + * + * Populates a protobuf message using the generated protobuf C++ API. + * Does not perform serialization — that is handled by the protobuf library. + */ +template +void map_mip_solution_to_proto(const cpu_mip_solution_t& solution, + cuopt::remote::MIPSolution* pb_solution); + +/** + * @brief Map protobuf MIPSolution message to cpu_mip_solution_t. + * + * Reads from a protobuf message using the generated protobuf C++ API. + * Does not perform deserialization — that is handled by the protobuf library. + */ +template +cpu_mip_solution_t map_proto_to_mip_solution( + const cuopt::remote::MIPSolution& pb_solution); + +/** + * @brief Convert cuOpt termination status to protobuf enum. + * @param status cuOpt PDLP termination status + * @return Protobuf PDLPTerminationStatus enum + */ +cuopt::remote::PDLPTerminationStatus to_proto_pdlp_status(pdlp_termination_status_t status); + +/** + * @brief Convert protobuf enum to cuOpt termination status. + * @param status Protobuf PDLPTerminationStatus enum + * @return cuOpt PDLP termination status + */ +pdlp_termination_status_t from_proto_pdlp_status(cuopt::remote::PDLPTerminationStatus status); + +/** + * @brief Convert cuOpt MIP termination status to protobuf enum. + * @param status cuOpt MIP termination status + * @return Protobuf MIPTerminationStatus enum + */ +cuopt::remote::MIPTerminationStatus to_proto_mip_status(mip_termination_status_t status); + +/** + * @brief Convert protobuf enum to cuOpt MIP termination status. 
+ * @param status Protobuf MIPTerminationStatus enum + * @return cuOpt MIP termination status + */ +mip_termination_status_t from_proto_mip_status(cuopt::remote::MIPTerminationStatus status); + +// ============================================================================ +// Chunked result support (for results exceeding gRPC max message size) +// ============================================================================ + +/** + * @brief Estimate serialized protobuf size of an LP solution. + */ +template +size_t estimate_lp_solution_proto_size(const cpu_lp_solution_t& solution); + +/** + * @brief Estimate serialized protobuf size of a MIP solution. + */ +template +size_t estimate_mip_solution_proto_size(const cpu_mip_solution_t& solution); + +/** + * @brief Populate a ChunkedResultHeader from an LP solution (scalar fields + array descriptors). + */ +template +void populate_chunked_result_header_lp(const cpu_lp_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header); + +/** + * @brief Populate a ChunkedResultHeader from a MIP solution (scalar fields + array descriptors). + */ +template +void populate_chunked_result_header_mip(const cpu_mip_solution_t& solution, + cuopt::remote::ChunkedResultHeader* header); + +/** + * @brief Collect LP solution arrays as raw bytes keyed by ResultFieldId. + * + * Returns a map of ResultFieldId -> raw byte data (doubles packed as bytes). + * Used by the worker to send chunked result data. + */ +template +std::map> collect_lp_solution_arrays( + const cpu_lp_solution_t& solution); + +/** + * @brief Collect MIP solution arrays as raw bytes keyed by ResultFieldId. 
+ */ +template +std::map> collect_mip_solution_arrays( + const cpu_mip_solution_t& solution); + +// ============================================================================ +// Chunked result -> solution (for gRPC client) +// ============================================================================ + +/** + * @brief Reconstruct a cpu_lp_solution_t from chunked result header and raw array data. + * + * This is the client-side counterpart to collect_lp_solution_arrays + + * populate_chunked_result_header_lp. It reads scalars from the header and typed arrays from the + * byte map. + */ +template +cpu_lp_solution_t chunked_result_to_lp_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); + +/** + * @brief Reconstruct a cpu_mip_solution_t from chunked result header and raw array data. + */ +template +cpu_mip_solution_t chunked_result_to_mip_solution( + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays); + +// ============================================================================ +// Build full protobuf solution from stored header + arrays (server-side GetResult RPC) +// ============================================================================ + +/** + * @brief Build a full LPSolution protobuf from a ChunkedResultHeader and raw arrays. + * + * Used by the server's GetResult RPC to serve unary responses. + * Composes chunked_result_to_lp_solution + map_lp_solution_to_proto. + */ +template +void build_lp_solution_proto(const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::LPSolution* proto); + +/** + * @brief Build a full MIPSolution protobuf from a ChunkedResultHeader and raw arrays. 
+ */ +template +void build_mip_solution_proto(const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays, + cuopt::remote::MIPSolution* proto); + +} // namespace cuopt::linear_programming diff --git a/cpp/src/grpc/server/grpc_field_element_size.hpp b/cpp/src/grpc/server/grpc_field_element_size.hpp new file mode 100644 index 0000000000..defc6b9674 --- /dev/null +++ b/cpp/src/grpc/server/grpc_field_element_size.hpp @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +// Codegen target: this file maps ArrayFieldId enum values to their element byte sizes. +// A future version of cpp/codegen/generate_conversions.py can produce this from +// a problem_arrays section in field_registry.yaml. + +#pragma once + +#ifdef CUOPT_ENABLE_GRPC + +#include +#include "cuopt_remote.pb.h" + +inline int64_t array_field_element_size(cuopt::remote::ArrayFieldId field_id) +{ + switch (field_id) { + case cuopt::remote::FIELD_A_INDICES: + case cuopt::remote::FIELD_A_OFFSETS: + case cuopt::remote::FIELD_Q_INDICES: + case cuopt::remote::FIELD_Q_OFFSETS: return 4; + case cuopt::remote::FIELD_ROW_TYPES: + case cuopt::remote::FIELD_VARIABLE_TYPES: + case cuopt::remote::FIELD_VARIABLE_NAMES: + case cuopt::remote::FIELD_ROW_NAMES: return 1; + default: return 8; + } +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_incumbent_proto.hpp b/cpp/src/grpc/server/grpc_incumbent_proto.hpp new file mode 100644 index 0000000000..f5c6f2e79e --- /dev/null +++ b/cpp/src/grpc/server/grpc_incumbent_proto.hpp @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +// Codegen target: this file builds and parses cuopt::remote::Incumbent protobuf messages. 
+// A future version of cpp/codegen/generate_conversions.py can produce this from +// an incumbent section in field_registry.yaml. + +#pragma once + +#ifdef CUOPT_ENABLE_GRPC + +#include +#include +#include +#include +#include "cuopt_remote.pb.h" +#include "cuopt_remote_service.pb.h" + +inline std::vector build_incumbent_proto(const std::string& job_id, + double objective, + const std::vector& assignment) +{ + cuopt::remote::Incumbent msg; + msg.set_job_id(job_id); + msg.set_objective(objective); + for (double v : assignment) { + msg.add_assignment(v); + } + auto size = msg.ByteSizeLong(); + if (size > static_cast(std::numeric_limits::max())) { return {}; } + std::vector buffer(size); + if (!msg.SerializeToArray(buffer.data(), static_cast(buffer.size()))) { return {}; } + return buffer; +} + +inline bool parse_incumbent_proto(const uint8_t* data, + size_t size, + std::string& job_id, + double& objective, + std::vector& assignment) +{ + cuopt::remote::Incumbent incumbent_msg; + if (!incumbent_msg.ParseFromArray(data, static_cast(size))) { return false; } + + job_id = incumbent_msg.job_id(); + objective = incumbent_msg.objective(); + assignment.clear(); + assignment.reserve(incumbent_msg.assignment_size()); + for (int i = 0; i < incumbent_msg.assignment_size(); ++i) { + assignment.push_back(incumbent_msg.assignment(i)); + } + return true; +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_job_management.cpp b/cpp/src/grpc/server/grpc_job_management.cpp new file mode 100644 index 0000000000..e731c2c1f5 --- /dev/null +++ b/cpp/src/grpc/server/grpc_job_management.cpp @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_pipe_serialization.hpp" +#include "grpc_server_types.hpp" + +// write_to_pipe / read_from_pipe are defined in grpc_pipe_io.cpp + +bool send_job_data_pipe(int worker_idx, const std::vector& data) +{ + int fd; + { + std::lock_guard lock(worker_pipes_mutex); + if (worker_idx < 0 || worker_idx >= static_cast(worker_pipes.size())) { return false; } + fd = worker_pipes[worker_idx].to_worker_fd; + } + if (fd < 0) return false; + + uint64_t size = data.size(); + if (!write_to_pipe(fd, &size, sizeof(size))) return false; + if (size > 0 && !write_to_pipe(fd, data.data(), data.size())) return false; + return true; +} + +bool recv_job_data_pipe(int fd, uint64_t expected_size, std::vector& data) +{ + uint64_t size; + if (!read_from_pipe(fd, &size, sizeof(size))) return false; + if (size != expected_size) { + std::cerr << "[Worker] Size mismatch: expected " << expected_size << ", got " << size << "\n"; + return false; + } + data.resize(size); + if (size > 0 && !read_from_pipe(fd, data.data(), size)) return false; + return true; +} + +bool send_incumbent_pipe(int fd, const std::vector& data) +{ + uint64_t size = data.size(); + if (!write_to_pipe(fd, &size, sizeof(size))) return false; + if (size > 0 && !write_to_pipe(fd, data.data(), data.size())) return false; + return true; +} + +bool recv_incumbent_pipe(int fd, std::vector& data) +{ + static constexpr uint64_t kMaxIncumbentBytes = 256ULL * 1024 * 1024; + uint64_t size; + if (!read_from_pipe(fd, &size, sizeof(size))) return false; + if (size > kMaxIncumbentBytes) return false; + data.resize(size); + if (size > 0 && !read_from_pipe(fd, data.data(), size)) return false; + return true; +} + +// ============================================================================= +// Job management +// ============================================================================= + +// Reserve a shared-memory job queue slot, store the serialized request 
data, +// and register the job in the tracker. Returns {true, job_id} on success. +// Uses CAS on `claimed` for lock-free slot reservation and release semantics +// on `ready` to publish all writes to the dispatch thread. +std::pair submit_job_async(std::vector&& request_data, bool is_mip) +{ + std::string job_id = generate_job_id(); + + // Atomically reserve a free slot. + int slot = -1; + for (size_t i = 0; i < MAX_JOBS; ++i) { + if (job_queue[i].ready.load()) continue; + bool expected = false; + if (job_queue[i].claimed.compare_exchange_strong(expected, true)) { + slot = static_cast(i); + break; + } + } + if (slot < 0) { return {false, "Job queue full"}; } + + // Populate the slot while we hold the `claimed` flag. + copy_cstr(job_queue[slot].job_id, job_id); + job_queue[slot].problem_type = is_mip ? 1 : 0; + job_queue[slot].data_size = request_data.size(); + job_queue[slot].cancelled.store(false); + job_queue[slot].worker_index.store(-1); + job_queue[slot].data_sent.store(false); + job_queue[slot].is_chunked = false; + job_queue[slot].worker_pid = 0; + + { + std::lock_guard lock(pending_data_mutex); + pending_job_data[job_id] = std::move(request_data); + } + + { + std::lock_guard lock(tracker_mutex); + job_tracker[job_id] = + JobInfo{job_id, JobStatus::QUEUED, std::chrono::steady_clock::now(), {}, is_mip, "", false}; + } + + // Publish: release makes all writes above visible to the dispatch thread. + job_queue[slot].ready.store(true, std::memory_order_release); + job_queue[slot].claimed.store(false, std::memory_order_release); + + if (config.verbose) { std::cout << "[Server] Job submitted (async): " << job_id << "\n"; } + + return {true, job_id}; +} + +// Same as submit_job_async but for the chunked upload path. Stores the +// header + chunks in pending_chunked_data and marks the slot as is_chunked +// so the dispatch thread calls write_chunked_request_to_pipe(). 
+std::pair submit_chunked_job_async(PendingChunkedUpload&& chunked_data, + bool is_mip) +{ + std::string job_id = generate_job_id(); + + int slot = -1; + for (size_t i = 0; i < MAX_JOBS; ++i) { + if (job_queue[i].ready.load()) continue; + bool expected = false; + if (job_queue[i].claimed.compare_exchange_strong(expected, true)) { + slot = static_cast(i); + break; + } + } + if (slot < 0) { return {false, "Job queue full"}; } + + copy_cstr(job_queue[slot].job_id, job_id); + job_queue[slot].problem_type = is_mip ? 1 : 0; + job_queue[slot].data_size = 0; + job_queue[slot].cancelled.store(false); + job_queue[slot].worker_index.store(-1); + job_queue[slot].data_sent.store(false); + job_queue[slot].is_chunked = true; + job_queue[slot].worker_pid = 0; + + { + std::lock_guard lock(pending_data_mutex); + pending_chunked_data[job_id] = std::move(chunked_data); + } + + { + std::lock_guard lock(tracker_mutex); + job_tracker[job_id] = + JobInfo{job_id, JobStatus::QUEUED, std::chrono::steady_clock::now(), {}, is_mip, "", false}; + } + + job_queue[slot].ready.store(true, std::memory_order_release); + job_queue[slot].claimed.store(false, std::memory_order_release); + + if (config.verbose) { std::cout << "[Server] Chunked job submitted (async): " << job_id << "\n"; } + + return {true, job_id}; +} + +JobStatus check_job_status(const std::string& job_id, std::string& message) +{ + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + + if (it == job_tracker.end()) { + message = "Job ID not found"; + return JobStatus::NOT_FOUND; + } + + if (it->second.status == JobStatus::QUEUED) { + for (size_t i = 0; i < MAX_JOBS; ++i) { + if (job_queue[i].ready && job_queue[i].claimed && + std::string(job_queue[i].job_id) == job_id) { + it->second.status = JobStatus::PROCESSING; + break; + } + } + } + + switch (it->second.status) { + case JobStatus::QUEUED: message = "Job is queued"; break; + case JobStatus::PROCESSING: message = "Job is being processed"; break; + case 
JobStatus::COMPLETED: message = "Job completed"; break; + case JobStatus::FAILED: message = "Job failed: " + it->second.error_message; break; + case JobStatus::CANCELLED: message = "Job was cancelled"; break; + default: message = "Unknown status"; + } + + return it->second.status; +} + +bool get_job_is_mip(const std::string& job_id) +{ + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + if (it == job_tracker.end()) { return false; } + return it->second.is_mip; +} + +void ensure_log_dir_exists() +{ + struct stat st; + if (stat(LOG_DIR.c_str(), &st) != 0) { mkdir(LOG_DIR.c_str(), 0755); } +} + +void delete_log_file(const std::string& job_id) +{ + std::string log_file = get_log_file_path(job_id); + unlink(log_file.c_str()); +} + +int cancel_job(const std::string& job_id, JobStatus& job_status_out, std::string& message) +{ + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + + if (it == job_tracker.end()) { + message = "Job ID not found"; + job_status_out = JobStatus::NOT_FOUND; + return 1; + } + + JobStatus current_status = it->second.status; + + if (current_status == JobStatus::COMPLETED) { + message = "Cannot cancel completed job"; + job_status_out = JobStatus::COMPLETED; + return 2; + } + + if (current_status == JobStatus::CANCELLED) { + message = "Job already cancelled"; + job_status_out = JobStatus::CANCELLED; + return 3; + } + + if (current_status == JobStatus::FAILED) { + message = "Cannot cancel failed job"; + job_status_out = JobStatus::FAILED; + return 2; + } + + for (size_t i = 0; i < MAX_JOBS; ++i) { + if (!job_queue[i].ready.load(std::memory_order_acquire)) continue; + if (strcmp(job_queue[i].job_id, job_id.c_str()) != 0) continue; + + // Re-validate the slot: the job_id could have changed between the + // initial check and now if the slot was recycled. Load ready with + // acquire so we see all writes that published it. 
+ if (!job_queue[i].ready.load(std::memory_order_acquire) || + strcmp(job_queue[i].job_id, job_id.c_str()) != 0) { + continue; + } + + pid_t worker_pid = job_queue[i].worker_pid.load(std::memory_order_relaxed); + + if (worker_pid > 0 && job_queue[i].claimed.load(std::memory_order_relaxed)) { + if (config.verbose) { + std::cout << "[Server] Cancelling running job " << job_id << " (killing worker " + << worker_pid << ")\n"; + } + job_queue[i].cancelled.store(true, std::memory_order_release); + kill(worker_pid, SIGKILL); + } else { + if (config.verbose) { std::cout << "[Server] Cancelling queued job " << job_id << "\n"; } + job_queue[i].cancelled.store(true, std::memory_order_release); + } + + it->second.status = JobStatus::CANCELLED; + it->second.error_message = "Job cancelled by user"; + job_status_out = JobStatus::CANCELLED; + message = "Job cancelled successfully"; + + delete_log_file(job_id); + + { + std::lock_guard wlock(waiters_mutex); + auto wit = waiting_threads.find(job_id); + if (wit != waiting_threads.end()) { + auto waiter = wit->second; + { + std::lock_guard waiter_lock(waiter->mutex); + waiter->error_message = "Job cancelled by user"; + waiter->success = false; + waiter->ready = true; + } + waiter->cv.notify_all(); + waiting_threads.erase(wit); + } + } + + return 0; + } + + if (it->second.status == JobStatus::COMPLETED) { + message = "Cannot cancel completed job"; + job_status_out = JobStatus::COMPLETED; + return 2; + } + + it->second.status = JobStatus::CANCELLED; + it->second.error_message = "Job cancelled by user"; + job_status_out = JobStatus::CANCELLED; + message = "Job cancelled"; + + { + std::lock_guard wlock(waiters_mutex); + auto wit = waiting_threads.find(job_id); + if (wit != waiting_threads.end()) { + auto waiter = wit->second; + { + std::lock_guard waiter_lock(waiter->mutex); + waiter->error_message = "Job cancelled by user"; + waiter->success = false; + waiter->ready = true; + } + waiter->cv.notify_all(); + waiting_threads.erase(wit); + } 
+ } + + return 0; +} + +std::string generate_job_id() +{ + uuid_t uuid; + uuid_generate_random(uuid); + char buf[37]; + uuid_unparse_lower(uuid, buf); + return std::string(buf); +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_pipe_io.cpp b/cpp/src/grpc/server/grpc_pipe_io.cpp new file mode 100644 index 0000000000..c9e70eff1d --- /dev/null +++ b/cpp/src/grpc/server/grpc_pipe_io.cpp @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include +#include +#include +#include + +#include +#include + +bool write_to_pipe(int fd, const void* data, size_t size) +{ + const uint8_t* ptr = static_cast(data); + size_t remaining = size; + while (remaining > 0) { + ssize_t written = ::write(fd, ptr, remaining); + if (written <= 0) { + if (errno == EINTR) continue; + return false; + } + ptr += written; + remaining -= written; + } + return true; +} + +bool read_from_pipe(int fd, void* data, size_t size, int timeout_ms) +{ + uint8_t* ptr = static_cast(data); + size_t remaining = size; + + // Poll once to enforce timeout before the first read. After data starts + // flowing, blocking read() is sufficient — if the writer dies the pipe + // closes and read() returns 0 (EOF). Avoids ~10k extra poll() syscalls + // per bulk transfer. 
+ struct pollfd pfd = {fd, POLLIN, 0}; + int pr; + do { + pr = poll(&pfd, 1, timeout_ms); + } while (pr < 0 && errno == EINTR); + if (pr < 0) { + std::cerr << "[Server] poll() failed on pipe: " << strerror(errno) << "\n"; + return false; + } + if (pr == 0) { + std::cerr << "[Server] Timeout waiting for pipe data (waited " << timeout_ms << "ms)\n"; + return false; + } + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) { + std::cerr << "[Server] Pipe error/hangup detected\n"; + return false; + } + + while (remaining > 0) { + ssize_t nread = ::read(fd, ptr, remaining); + if (nread > 0) { + ptr += nread; + remaining -= nread; + continue; + } + if (nread == 0) { + std::cerr << "[Server] Pipe EOF (writer closed)\n"; + return false; + } + if (errno == EINTR) continue; + std::cerr << "[Server] Pipe read error: " << strerror(errno) << "\n"; + return false; + } + return true; +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_pipe_serialization.hpp b/cpp/src/grpc/server/grpc_pipe_serialization.hpp new file mode 100644 index 0000000000..7e8726f35c --- /dev/null +++ b/cpp/src/grpc/server/grpc_pipe_serialization.hpp @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef CUOPT_ENABLE_GRPC + +#include "cuopt_remote.pb.h" +#include "cuopt_remote_service.pb.h" +#include "grpc_field_element_size.hpp" + +#include +#include +#include +#include +#include + +// Requested pipe buffer size (1 MiB). The kernel default is 64 KiB, which +// forces excessive context-switching on large transfers. fcntl(F_SETPIPE_SZ) +// may silently cap this to /proc/sys/fs/pipe-max-size. 
+static constexpr int kPipeBufferSize = 1024 * 1024; + +static constexpr uint64_t kMaxPipeArrayBytes = 4ULL * 1024 * 1024 * 1024; +static constexpr uint32_t kMaxPipeArrayFields = 10000; +static constexpr uint32_t kMaxProtobufMessageBytes = 64 * 1024 * 1024; // 64 MiB + +// Pipe I/O primitives defined in grpc_job_management.cpp. +bool write_to_pipe(int fd, const void* data, size_t size); +bool read_from_pipe(int fd, void* data, size_t size, int timeout_ms = 120000); + +// ============================================================================= +// Low-level: write/read a single protobuf message with a uint32 length prefix. +// Uses standard protobuf SerializeToArray / ParseFromArray for the payload. +// ============================================================================= + +inline bool write_protobuf_to_pipe(int fd, const google::protobuf::MessageLite& msg) +{ + size_t byte_size = msg.ByteSizeLong(); + if (byte_size > kMaxProtobufMessageBytes) return false; + uint32_t size = static_cast(byte_size); + if (!write_to_pipe(fd, &size, sizeof(size))) return false; + if (size == 0) return true; + std::vector buf(size); + if (!msg.SerializeToArray(buf.data(), static_cast(size))) return false; + return write_to_pipe(fd, buf.data(), size); +} + +inline bool read_protobuf_from_pipe(int fd, google::protobuf::MessageLite& msg) +{ + uint32_t size; + if (!read_from_pipe(fd, &size, sizeof(size))) return false; + if (size > kMaxProtobufMessageBytes) return false; + if (size == 0) return msg.ParseFromArray(nullptr, 0); + std::vector buf(size); + if (!read_from_pipe(fd, buf.data(), size)) return false; + return msg.ParseFromArray(buf.data(), static_cast(size)); +} + +// ============================================================================= +// Chunked request: server → worker pipe (ChunkedProblemHeader + raw arrays) +// +// Wire format (protobuf header + raw byte arrays): +// [uint32 hdr_size][protobuf header bytes] +// [uint32 num_arrays] +// per array: [int32 
field_id][uint64 total_bytes][raw bytes...] +// +// The protobuf ChunkedProblemHeader carries all metadata (settings, field +// types, element counts). Array data bypasses protobuf serialization and +// flows directly through the pipe as raw bytes. +// ============================================================================= + +inline bool write_chunked_request_to_pipe(int fd, + const cuopt::remote::ChunkedProblemHeader& header, + const std::vector& chunks) +{ + // Step 1: write the protobuf header (settings, scalars, string arrays). + if (!write_protobuf_to_pipe(fd, header)) return false; + + // Step 2: group incoming gRPC chunks by field_id. A single field may arrive + // as multiple chunks (the client splits large arrays at chunk_size_bytes). + struct FieldInfo { + std::vector chunks; + int64_t total_bytes = 0; + }; + std::map fields; + for (const auto& ac : chunks) { + int32_t fid = static_cast(ac.field_id()); + auto& fi = fields[fid]; + fi.chunks.push_back(&ac); + if (fi.total_bytes == 0 && ac.total_elements() > 0) { + auto elem_size = array_field_element_size(ac.field_id()); + if (elem_size > 0 && ac.total_elements() <= std::numeric_limits::max() / elem_size) { + fi.total_bytes = ac.total_elements() * elem_size; + } + } + } + + // Step 3: write per-field raw byte arrays. + uint32_t num_arrays = static_cast(fields.size()); + if (!write_to_pipe(fd, &num_arrays, sizeof(num_arrays))) return false; + + for (const auto& [fid, fi] : fields) { + int32_t field_id = fid; + uint64_t total_bytes = static_cast(fi.total_bytes); + if (!write_to_pipe(fd, &field_id, sizeof(field_id))) return false; + if (!write_to_pipe(fd, &total_bytes, sizeof(total_bytes))) return false; + if (total_bytes == 0) continue; + + // Fast path: field arrived in a single chunk that covers the whole array. + // Write directly from the protobuf bytes string, avoiding an assembly copy. 
+ if (fi.chunks.size() == 1 && fi.chunks[0]->element_offset() == 0 && + static_cast(fi.chunks[0]->data().size()) == fi.total_bytes) { + if (!write_to_pipe(fd, fi.chunks[0]->data().data(), fi.chunks[0]->data().size())) + return false; + } else { + // Slow path: stitch multiple chunks into a contiguous buffer, placing + // each chunk at its element_offset * elem_size byte position. + int64_t total_elements = fi.chunks[0]->total_elements(); + if (total_elements <= 0 || fi.total_bytes % total_elements != 0) return false; + int64_t elem_size = fi.total_bytes / total_elements; + if (elem_size <= 0) return false; + + std::vector assembled(static_cast(fi.total_bytes), 0); + // Per-element bitmap detects both overlaps (element written twice) + // and gaps (element never written). + std::vector covered(static_cast(total_elements), false); + + for (const auto* ac : fi.chunks) { + int64_t element_offset = ac->element_offset(); + const auto& chunk_data = ac->data(); + if (chunk_data.size() % static_cast(elem_size) != 0) return false; + int64_t chunk_elements = static_cast(chunk_data.size()) / elem_size; + if (element_offset < 0 || chunk_elements < 0) return false; + if (element_offset > total_elements - chunk_elements) return false; + + int64_t byte_offset = element_offset * elem_size; + if (byte_offset + static_cast(chunk_data.size()) > fi.total_bytes) return false; + + for (int64_t e = 0; e < chunk_elements; ++e) { + size_t idx = static_cast(element_offset + e); + if (covered[idx]) return false; // overlap + covered[idx] = true; + } + std::memcpy(assembled.data() + byte_offset, chunk_data.data(), chunk_data.size()); + } + // Every element must be covered exactly once (no gaps). 
+ for (size_t e = 0; e < static_cast(total_elements); ++e) { + if (!covered[e]) return false; + } + if (!write_to_pipe(fd, assembled.data(), assembled.size())) return false; + } + } + + return true; +} + +inline bool read_chunked_request_from_pipe(int fd, + cuopt::remote::ChunkedProblemHeader& header_out, + std::map>& arrays_out) +{ + if (!read_protobuf_from_pipe(fd, header_out)) return false; + + uint32_t num_arrays; + if (!read_from_pipe(fd, &num_arrays, sizeof(num_arrays))) return false; + if (num_arrays > kMaxPipeArrayFields) return false; + + // Read each field's raw bytes directly into the output map, keyed by field_id. + for (uint32_t i = 0; i < num_arrays; ++i) { + int32_t field_id; + uint64_t total_bytes; + if (!read_from_pipe(fd, &field_id, sizeof(field_id))) return false; + if (!read_from_pipe(fd, &total_bytes, sizeof(total_bytes))) return false; + if (total_bytes > kMaxPipeArrayBytes) return false; + auto& dest = arrays_out[field_id]; + dest.resize(static_cast(total_bytes)); + if (total_bytes > 0 && !read_from_pipe(fd, dest.data(), static_cast(total_bytes))) + return false; + } + + return true; +} + +// ============================================================================= +// Result: worker → server pipe (ChunkedResultHeader + raw arrays) +// +// Same wire format as the chunked request above. Unlike the request path, +// result arrays are already assembled into contiguous vectors by the worker, +// so no chunk grouping or assembly is needed. +// ============================================================================= + +inline bool write_result_to_pipe(int fd, + const cuopt::remote::ChunkedResultHeader& header, + const std::map>& arrays) +{ + if (!write_protobuf_to_pipe(fd, header)) return false; + + uint32_t num_arrays = static_cast(arrays.size()); + if (!write_to_pipe(fd, &num_arrays, sizeof(num_arrays))) return false; + + // Each array is already contiguous — write field_id, size, and raw bytes. 
+ for (const auto& [fid, data] : arrays) { + int32_t field_id = fid; + uint64_t total_bytes = data.size(); + if (!write_to_pipe(fd, &field_id, sizeof(field_id))) return false; + if (!write_to_pipe(fd, &total_bytes, sizeof(total_bytes))) return false; + if (total_bytes > 0 && !write_to_pipe(fd, data.data(), data.size())) return false; + } + + return true; +} + +inline bool read_result_from_pipe(int fd, + cuopt::remote::ChunkedResultHeader& header_out, + std::map>& arrays_out) +{ + if (!read_protobuf_from_pipe(fd, header_out)) return false; + + uint32_t num_arrays; + if (!read_from_pipe(fd, &num_arrays, sizeof(num_arrays))) return false; + if (num_arrays > kMaxPipeArrayFields) return false; + + for (uint32_t i = 0; i < num_arrays; ++i) { + int32_t field_id; + uint64_t total_bytes; + if (!read_from_pipe(fd, &field_id, sizeof(field_id))) return false; + if (!read_from_pipe(fd, &total_bytes, sizeof(total_bytes))) return false; + if (total_bytes > kMaxPipeArrayBytes) return false; + auto& dest = arrays_out[field_id]; + dest.resize(static_cast(total_bytes)); + if (total_bytes > 0 && !read_from_pipe(fd, dest.data(), static_cast(total_bytes))) + return false; + } + + return true; +} + +// Serialize a SubmitJobRequest directly to a pipe blob using standard protobuf. +// Used for unary submits only (always well under 2 GiB). 
+inline std::vector serialize_submit_request_to_pipe( + const cuopt::remote::SubmitJobRequest& request) +{ + size_t byte_size = request.ByteSizeLong(); + if (byte_size == 0 || byte_size > static_cast(std::numeric_limits::max())) return {}; + std::vector blob(byte_size); + request.SerializeToArray(blob.data(), static_cast(byte_size)); + return blob; +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_server_main.cpp b/cpp/src/grpc/server/grpc_server_main.cpp new file mode 100644 index 0000000000..4a1a28d047 --- /dev/null +++ b/cpp/src/grpc/server/grpc_server_main.cpp @@ -0,0 +1,347 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +/** + * @file grpc_server_main.cpp + * @brief gRPC-based remote solve server entry point + * + * This server uses gRPC for client communication with fork-based worker + * process infrastructure: + * - Worker processes with shared memory job queues + * - Pipe-based IPC for problem/result data + * - Result tracking and retrieval threads + * - Log streaming + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_server_types.hpp" + +// Defined in grpc_service_impl.cpp +std::unique_ptr create_cuopt_grpc_service(); + +void print_usage(const char* prog) +{ + std::cout + << "Usage: " << prog << " [options]\n" + << "Options:\n" + << " -p, --port PORT Listen port (default: 8765)\n" + << " -w, --workers NUM Number of worker processes (default: 1)\n" + << " --max-message-mb N gRPC max send/recv message size in MiB (default: 256)\n" + << " --max-message-bytes N Set max message size in exact bytes (min 4096, for testing)\n" + << " --chunk-timeout N Per-chunk timeout in seconds for streaming (default: 60, " + "0=disabled)\n" + << " --enable-transfer-hash Log data hashes for streaming transfers (for testing)\n" + << " --tls Enable TLS (requires --tls-cert and --tls-key)\n" + << " --tls-cert PATH Path to PEM-encoded server 
certificate\n" + << " --tls-key PATH Path to PEM-encoded server private key\n" + << " --tls-root PATH Path to PEM root certs for client verification\n" + << " --require-client-cert Require and verify client certs (mTLS)\n" + << " --log-to-console Enable solver log output to console (default: off)\n" + << " -v, --verbose Increase verbosity (default: on)\n" + << " -q, --quiet Reduce verbosity\n" + << " -h, --help Show this help\n"; +} + +int main(int argc, char** argv) +{ + auto require_arg = [&](int i, const std::string& flag) -> bool { + if (i + 1 >= argc) { + std::cerr << "Error: " << flag << " requires a value\n"; + print_usage(argv[0]); + return false; + } + return true; + }; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "-p" || arg == "--port") { + if (!require_arg(i, arg)) return 1; + config.port = std::stoi(argv[++i]); + } else if (arg == "-w" || arg == "--workers") { + if (!require_arg(i, arg)) return 1; + config.num_workers = std::stoi(argv[++i]); + } else if (arg == "--max-message-mb") { + if (!require_arg(i, arg)) return 1; + config.max_message_bytes = static_cast(std::stoi(argv[++i])) * kMiB; + } else if (arg == "--max-message-bytes") { + if (!require_arg(i, arg)) return 1; + config.max_message_bytes = std::max(4096LL, std::stoll(argv[++i])); + } else if (arg == "--chunk-timeout") { + if (!require_arg(i, arg)) return 1; + config.chunk_timeout_seconds = std::max(0, std::stoi(argv[++i])); + } else if (arg == "--enable-transfer-hash") { + config.enable_transfer_hash = true; + } else if (arg == "--tls") { + config.enable_tls = true; + } else if (arg == "--tls-cert") { + if (!require_arg(i, arg)) return 1; + config.tls_cert_path = argv[++i]; + } else if (arg == "--tls-key") { + if (!require_arg(i, arg)) return 1; + config.tls_key_path = argv[++i]; + } else if (arg == "--tls-root") { + if (!require_arg(i, arg)) return 1; + config.tls_root_path = argv[++i]; + } else if (arg == "--require-client-cert") { + config.require_client = 
true; + } else if (arg == "--log-to-console") { + config.log_to_console = true; + } else if (arg == "-v" || arg == "--verbose") { + config.verbose = true; + } else if (arg == "-q" || arg == "--quiet") { + config.verbose = false; + } else if (arg == "-h" || arg == "--help") { + print_usage(argv[0]); + return 0; + } else { + std::cerr << "Unknown option: " << arg << "\n"; + print_usage(argv[0]); + return 1; + } + } + + // Validate numeric ranges. + if (config.port < 1 || config.port > 65535) { + std::cerr << "Error: --port must be in range 1-65535\n"; + print_usage(argv[0]); + return 1; + } + if (config.num_workers < 1) { + std::cerr << "Error: --workers must be >= 1\n"; + print_usage(argv[0]); + return 1; + } + if (config.chunk_timeout_seconds < 0) { + std::cerr << "Error: --chunk-timeout must be >= 0\n"; + print_usage(argv[0]); + return 1; + } + + config.max_message_bytes = + std::clamp(config.max_message_bytes, kServerMinMessageBytes, kServerMaxMessageBytes); + + std::cout << "cuOpt gRPC Remote Solve Server\n" + << "==============================\n" + << "Port: " << config.port << "\n" + << "Workers: " << config.num_workers << "\n" + << std::endl; + std::cout.flush(); + + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + ensure_log_dir_exists(); + + shm_unlink(SHM_JOB_QUEUE); + shm_unlink(SHM_RESULT_QUEUE); + shm_unlink(SHM_CONTROL); + + int shm_fd = shm_open(SHM_JOB_QUEUE, O_CREAT | O_RDWR, 0600); + if (shm_fd < 0) { + std::cerr << "[Server] Failed to create shared memory for job queue: " << strerror(errno) + << "\n"; + return 1; + } + if (ftruncate(shm_fd, sizeof(JobQueueEntry) * MAX_JOBS) < 0) { + std::cerr << "[Server] Failed to ftruncate job queue: " << strerror(errno) << "\n"; + close(shm_fd); + return 1; + } + job_queue = static_cast<JobQueueEntry*>( + mmap(nullptr, sizeof(JobQueueEntry) * MAX_JOBS, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0)); + close(shm_fd); + + if (job_queue == MAP_FAILED) { + std::cerr << "[Server] Failed to mmap job queue: " 
<< strerror(errno) << "\n"; + return 1; + } + int result_shm_fd = shm_open(SHM_RESULT_QUEUE, O_CREAT | O_RDWR, 0600); + if (result_shm_fd < 0) { + std::cerr << "[Server] Failed to create result queue shm: " << strerror(errno) << "\n"; + return 1; + } + if (ftruncate(result_shm_fd, sizeof(ResultQueueEntry) * MAX_RESULTS) < 0) { + std::cerr << "[Server] Failed to ftruncate result queue: " << strerror(errno) << "\n"; + close(result_shm_fd); + return 1; + } + result_queue = static_cast<ResultQueueEntry*>(mmap(nullptr, + sizeof(ResultQueueEntry) * MAX_RESULTS, + PROT_READ | PROT_WRITE, + MAP_SHARED, + result_shm_fd, + 0)); + close(result_shm_fd); + if (result_queue == MAP_FAILED) { + std::cerr << "[Server] Failed to mmap result queue: " << strerror(errno) << "\n"; + return 1; + } + int ctrl_shm_fd = shm_open(SHM_CONTROL, O_CREAT | O_RDWR, 0600); + if (ctrl_shm_fd < 0) { + std::cerr << "[Server] Failed to create control shm: " << strerror(errno) << "\n"; + return 1; + } + if (ftruncate(ctrl_shm_fd, sizeof(SharedMemoryControl)) < 0) { + std::cerr << "[Server] Failed to ftruncate control: " << strerror(errno) << "\n"; + close(ctrl_shm_fd); + return 1; + } + shm_ctrl = static_cast<SharedMemoryControl*>( + mmap(nullptr, sizeof(SharedMemoryControl), PROT_READ | PROT_WRITE, MAP_SHARED, ctrl_shm_fd, 0)); + close(ctrl_shm_fd); + if (shm_ctrl == MAP_FAILED) { + std::cerr << "[Server] Failed to mmap control: " << strerror(errno) << "\n"; + return 1; + } + new (shm_ctrl) SharedMemoryControl{}; + + for (size_t i = 0; i < MAX_JOBS; ++i) { + new (&job_queue[i]) JobQueueEntry{}; + job_queue[i].ready.store(false); + job_queue[i].claimed.store(false); + job_queue[i].cancelled.store(false); + job_queue[i].worker_index.store(-1); + } + + for (size_t i = 0; i < MAX_RESULTS; ++i) { + new (&result_queue[i]) ResultQueueEntry{}; + result_queue[i].claimed.store(false); + result_queue[i].ready.store(false); + result_queue[i].retrieved.store(false); + } + + shm_ctrl->shutdown_requested.store(false); + shm_ctrl->active_workers.store(0); + 
+ // Build credentials before spawning workers so TLS validation failures + // don't leak worker processes or background threads. + std::string server_address = "0.0.0.0:" + std::to_string(config.port); + std::shared_ptr<grpc::ServerCredentials> creds; + if (config.enable_tls) { + if (config.tls_cert_path.empty() || config.tls_key_path.empty()) { + std::cerr << "[Server] TLS enabled but --tls-cert/--tls-key not provided\n"; + cleanup_shared_memory(); + return 1; + } + grpc::SslServerCredentialsOptions ssl_opts; + grpc::SslServerCredentialsOptions::PemKeyCertPair key_cert; + key_cert.cert_chain = read_file_to_string(config.tls_cert_path); + key_cert.private_key = read_file_to_string(config.tls_key_path); + if (key_cert.cert_chain.empty() || key_cert.private_key.empty()) { + std::cerr << "[Server] Failed to read TLS cert/key files\n"; + cleanup_shared_memory(); + return 1; + } + ssl_opts.pem_key_cert_pairs.push_back(key_cert); + + if (!config.tls_root_path.empty()) { + ssl_opts.pem_root_certs = read_file_to_string(config.tls_root_path); + if (ssl_opts.pem_root_certs.empty()) { + std::cerr << "[Server] Failed to read TLS root cert file\n"; + cleanup_shared_memory(); + return 1; + } + } + + if (config.require_client) { + if (ssl_opts.pem_root_certs.empty()) { + std::cerr << "[Server] --require-client-cert requires --tls-root\n"; + cleanup_shared_memory(); + return 1; + } + ssl_opts.client_certificate_request = + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + } else if (!ssl_opts.pem_root_certs.empty()) { + ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY; + } + + creds = grpc::SslServerCredentials(ssl_opts); + } else { + creds = grpc::InsecureServerCredentials(); + } + + signal(SIGPIPE, SIG_IGN); + spawn_workers(); + + if (worker_pids.empty()) { + std::cerr << "[Server] No workers started; exiting\n"; + cleanup_shared_memory(); + return 1; + } + + std::thread result_thread(result_retrieval_thread); + std::thread 
incumbent_thread(incumbent_retrieval_thread); + std::thread monitor_thread(worker_monitor_thread); + std::thread reaper_thread(session_reaper_thread); + + auto shutdown_all = [&]() { + keep_running = false; + shm_ctrl->shutdown_requested = true; + result_cv.notify_all(); + + if (result_thread.joinable()) result_thread.join(); + if (incumbent_thread.joinable()) incumbent_thread.join(); + if (monitor_thread.joinable()) monitor_thread.join(); + if (reaper_thread.joinable()) reaper_thread.join(); + + wait_for_workers(); + cleanup_shared_memory(); + }; + + auto service = create_cuopt_grpc_service(); + + ServerBuilder builder; + builder.AddListeningPort(server_address, creds); + builder.RegisterService(service.get()); + const int64_t max_bytes = server_max_message_bytes(); + const int channel_limit = + static_cast<int>(std::min<int64_t>(max_bytes, std::numeric_limits<int>::max())); + builder.SetMaxReceiveMessageSize(channel_limit); + builder.SetMaxSendMessageSize(channel_limit); + + std::unique_ptr<Server> server(builder.BuildAndStart()); + if (!server) { + std::cerr << "[Server] BuildAndStart() failed — could not bind to " << server_address << "\n"; + shutdown_all(); + return 1; + } + + std::cout << "[gRPC Server] Listening on " << server_address << std::endl; + std::cout << "[gRPC Server] Workers: " << config.num_workers << std::endl; + std::cout << "[gRPC Server] Max message size: " << server_max_message_bytes() << " bytes (" + << (server_max_message_bytes() / kMiB) << " MiB)" << std::endl; + std::cout << "[gRPC Server] Press Ctrl+C to shutdown" << std::endl; + + std::thread shutdown_thread([&server]() { + while (keep_running.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + if (server) { server->Shutdown(); } + }); + + server->Wait(); + if (shutdown_thread.joinable()) shutdown_thread.join(); + + std::cout << "\n[Server] Shutting down..." 
<< std::endl; + shutdown_all(); + + std::cout << "[Server] Shutdown complete" << std::endl; + return 0; +} + +#else // !CUOPT_ENABLE_GRPC + +#include <iostream> + +int main() +{ + std::cerr << "Error: cuopt_grpc_server requires gRPC support.\n" + << "Rebuild with gRPC enabled (CUOPT_ENABLE_GRPC=ON)" << std::endl; + return 1; +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_server_threads.cpp b/cpp/src/grpc/server/grpc_server_threads.cpp new file mode 100644 index 0000000000..a9beb1ca39 --- /dev/null +++ b/cpp/src/grpc/server/grpc_server_threads.cpp @@ -0,0 +1,423 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_incumbent_proto.hpp" +#include "grpc_pipe_serialization.hpp" +#include "grpc_server_types.hpp" + +void worker_monitor_thread() +{ + std::cout << "[Server] Worker monitor thread started\n"; + std::cout.flush(); + + while (keep_running) { + for (size_t i = 0; i < worker_pids.size(); ++i) { + pid_t pid = worker_pids[i]; + if (pid <= 0) continue; + + int status; + pid_t result = waitpid(pid, &status, WNOHANG); + + if (result == pid) { + int exit_code = WIFEXITED(status) ? WEXITSTATUS(status) : -1; + bool signaled = WIFSIGNALED(status); + int signal_num = signaled ? 
WTERMSIG(status) : 0; + + if (signaled) { + std::cerr << "[Server] Worker " << pid << " killed by signal " << signal_num << "\n"; + std::cerr.flush(); + } else if (exit_code != 0) { + std::cerr << "[Server] Worker " << pid << " exited with code " << exit_code << "\n"; + std::cerr.flush(); + } else { + if (shm_ctrl && shm_ctrl->shutdown_requested) { + worker_pids[i] = 0; + continue; + } + std::cerr << "[Server] Worker " << pid << " exited unexpectedly\n"; + std::cerr.flush(); + } + + mark_worker_jobs_failed(pid); + + if (keep_running && shm_ctrl && !shm_ctrl->shutdown_requested) { + pid_t new_pid = spawn_single_worker(static_cast(i)); + if (new_pid > 0) { + worker_pids[i] = new_pid; + std::cout << "[Server] Restarted worker " << i << " with PID " << new_pid << "\n"; + std::cout.flush(); + } else { + worker_pids[i] = 0; + } + } else { + worker_pids[i] = 0; + } + } + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + std::cout << "[Server] Worker monitor thread stopped\n"; + std::cout.flush(); +} + +void result_retrieval_thread() +{ + std::cout << "[Server] Result retrieval thread started\n"; + std::cout.flush(); + + while (keep_running) { + bool found = false; + + for (size_t i = 0; i < MAX_JOBS; ++i) { + if (job_queue[i].ready && job_queue[i].claimed && !job_queue[i].data_sent && + !job_queue[i].cancelled) { + std::string job_id(job_queue[i].job_id); + int worker_idx = job_queue[i].worker_index; + + if (worker_idx >= 0) { + bool is_chunked = job_queue[i].is_chunked.load(); + bool send_ok = false; + bool has_data = false; + + if (is_chunked) { + PendingChunkedUpload chunked; + { + std::lock_guard lock(pending_data_mutex); + auto it = pending_chunked_data.find(job_id); + if (it != pending_chunked_data.end()) { + chunked = std::move(it->second); + has_data = true; + pending_chunked_data.erase(it); + } + } + if (has_data) { + int to_fd; + { + std::lock_guard wpl(worker_pipes_mutex); + to_fd = worker_pipes[worker_idx].to_worker_fd; + } + auto 
pipe_t0 = std::chrono::steady_clock::now(); + send_ok = write_chunked_request_to_pipe(to_fd, chunked.header, chunked.chunks); + if (send_ok && config.verbose) { + auto pipe_us = std::chrono::duration_cast( + std::chrono::steady_clock::now() - pipe_t0) + .count(); + std::cout << "[THROUGHPUT] phase=pipe_chunked_send chunks=" << chunked.chunks.size() + << " elapsed_ms=" << std::fixed << std::setprecision(1) + << (pipe_us / 1000.0) << "\n"; + std::cout << "[Server] Streamed " << chunked.chunks.size() << " chunks to worker " + << worker_idx << " for job " << job_id << "\n"; + } + } + } else { + std::vector job_data; + { + std::lock_guard lock(pending_data_mutex); + auto it = pending_job_data.find(job_id); + if (it != pending_job_data.end()) { + job_data = std::move(it->second); + has_data = true; + pending_job_data.erase(it); + } + } + if (has_data) { + auto pipe_t0 = std::chrono::steady_clock::now(); + send_ok = send_job_data_pipe(worker_idx, job_data); + if (send_ok && config.verbose) { + auto pipe_us = std::chrono::duration_cast( + std::chrono::steady_clock::now() - pipe_t0) + .count(); + double pipe_sec = pipe_us / 1e6; + double pipe_mb = static_cast(job_data.size()) / (1024.0 * 1024.0); + double pipe_mbs = (pipe_sec > 0.0) ? 
(pipe_mb / pipe_sec) : 0.0; + std::cout << "[THROUGHPUT] phase=pipe_job_send bytes=" << job_data.size() + << " elapsed_ms=" << std::fixed << std::setprecision(1) + << (pipe_us / 1000.0) << " throughput_mb_s=" << std::setprecision(1) + << pipe_mbs << "\n"; + std::cout << "[Server] Sent " << job_data.size() << " bytes to worker " + << worker_idx << " for job " << job_id << "\n"; + } + } + } + + if (has_data) { + if (send_ok) { + job_queue[i].data_sent = true; + } else { + std::cerr << "[Server] Failed to send job data to worker " << worker_idx << "\n"; + job_queue[i].cancelled = true; + } + found = true; + } + } + } + } + + for (size_t i = 0; i < MAX_RESULTS; ++i) { + if (result_queue[i].ready && !result_queue[i].retrieved) { + std::string job_id(result_queue[i].job_id); + ResultStatus result_status = result_queue[i].status; + bool success = (result_status == RESULT_SUCCESS); + bool cancelled = (result_status == RESULT_CANCELLED); + int worker_idx = result_queue[i].worker_index; + if (config.verbose) { + std::cout << "[Server] Detected ready result_slot=" << i << " for job " << job_id + << " status=" << result_status << " data_size=" << result_queue[i].data_size + << " worker_idx=" << worker_idx << "\n"; + std::cout.flush(); + } + + std::string error_message; + + cuopt::remote::ChunkedResultHeader hdr; + std::map> arrays; + + if (success && result_queue[i].data_size > 0) { + if (config.verbose) { + std::cout << "[Server] Reading streamed result from worker pipe for job " << job_id + << "\n"; + std::cout.flush(); + } + int from_fd; + { + std::lock_guard wpl(worker_pipes_mutex); + from_fd = worker_pipes[worker_idx].from_worker_fd; + } + auto pipe_recv_t0 = std::chrono::steady_clock::now(); + bool read_ok = read_result_from_pipe(from_fd, hdr, arrays); + if (!read_ok) { + error_message = "Failed to read result data from pipe"; + success = false; + } + if (success && config.verbose) { + auto pipe_us = std::chrono::duration_cast( + std::chrono::steady_clock::now() - 
pipe_recv_t0) + .count(); + int64_t total_bytes = 0; + for (const auto& [fid, data] : arrays) { + total_bytes += data.size(); + } + double pipe_sec = pipe_us / 1e6; + double pipe_mb = static_cast(total_bytes) / (1024.0 * 1024.0); + double pipe_mbs = (pipe_sec > 0.0) ? (pipe_mb / pipe_sec) : 0.0; + std::cout << "[THROUGHPUT] phase=pipe_result_recv bytes=" << total_bytes + << " elapsed_ms=" << std::fixed << std::setprecision(1) << (pipe_us / 1000.0) + << " throughput_mb_s=" << std::setprecision(1) << pipe_mbs << "\n"; + std::cout.flush(); + } + } else if (!success) { + error_message = result_queue[i].error_message; + } + + { + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + if (it != job_tracker.end()) { + if (success) { + it->second.status = JobStatus::COMPLETED; + int64_t total_bytes = 0; + for (const auto& [fid, data] : arrays) { + total_bytes += data.size(); + } + it->second.result_header = std::move(hdr); + it->second.result_arrays = std::move(arrays); + it->second.result_size_bytes = total_bytes; + + if (config.verbose) { + std::cout << "[Server] Marked job COMPLETED in job_tracker: " << job_id + << " result_arrays=" << it->second.result_arrays.size() + << " result_size_bytes=" << it->second.result_size_bytes << "\n"; + std::cout.flush(); + } + } else if (cancelled) { + it->second.status = JobStatus::CANCELLED; + it->second.error_message = error_message; + if (config.verbose) { + std::cout << "[Server] Marked job CANCELLED in job_tracker: " << job_id + << " msg=" << error_message << "\n"; + std::cout.flush(); + } + } else { + it->second.status = JobStatus::FAILED; + it->second.error_message = error_message; + if (config.verbose) { + std::cout << "[Server] Marked job FAILED in job_tracker: " << job_id + << " msg=" << error_message << "\n"; + std::cout.flush(); + } + } + } + } + + { + std::lock_guard lock(waiters_mutex); + auto wit = waiting_threads.find(job_id); + if (wit != waiting_threads.end()) { + auto waiter = wit->second; + { + 
std::lock_guard waiter_lock(waiter->mutex); + waiter->error_message = error_message; + waiter->success = success; + waiter->ready = true; + } + waiter->cv.notify_all(); + waiting_threads.erase(wit); + } + } + + result_queue[i].retrieved = true; + result_queue[i].worker_index.store(-1, std::memory_order_relaxed); + // Clear claimed before ready: writers CAS claimed first, so clearing + // it here lets the slot be reused. ready=false must come last so + // writers don't see a half-recycled slot. + result_queue[i].claimed.store(false, std::memory_order_release); + result_queue[i].ready.store(false, std::memory_order_release); + found = true; + } + } + + if (!found) { usleep(10000); } + + result_cv.notify_all(); + } + + std::cout << "[Server] Result retrieval thread stopped\n"; + std::cout.flush(); +} + +void incumbent_retrieval_thread() +{ + std::cout << "[Server] Incumbent retrieval thread started\n"; + std::cout.flush(); + + while (keep_running) { + std::vector pfds; + { + std::lock_guard lock(worker_pipes_mutex); + pfds.reserve(worker_pipes.size()); + for (const auto& wp : worker_pipes) { + if (wp.incumbent_from_worker_fd >= 0) { + pollfd pfd; + pfd.fd = wp.incumbent_from_worker_fd; + pfd.events = POLLIN; + pfd.revents = 0; + pfds.push_back(pfd); + } + } + } + + if (pfds.empty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + int poll_result = poll(pfds.data(), pfds.size(), 100); + if (poll_result < 0) { + if (errno == EINTR) continue; + std::cerr << "[Server] poll() failed in incumbent thread: " << strerror(errno) << "\n"; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + if (poll_result == 0) { continue; } + + for (const auto& pfd : pfds) { + if (!(pfd.revents & POLLIN)) { continue; } + std::vector data; + if (!recv_incumbent_pipe(pfd.fd, data)) { continue; } + if (data.empty()) { continue; } + + std::string job_id; + double objective = 0.0; + std::vector assignment; + if 
(!parse_incumbent_proto(data.data(), data.size(), job_id, objective, assignment)) { + std::cerr << "[Server] Failed to parse incumbent payload\n"; + continue; + } + + if (job_id.empty()) { continue; } + + IncumbentEntry entry; + entry.objective = objective; + size_t num_vars = assignment.size(); + entry.assignment = std::move(assignment); + + { + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + if (it != job_tracker.end()) { + it->second.incumbents.push_back(std::move(entry)); + std::cout << "[Server] Stored incumbent job_id=" << job_id + << " idx=" << (it->second.incumbents.size() - 1) << " obj=" << objective + << " vars=" << num_vars << "\n"; + std::cout.flush(); + } + } + } + } + + std::cout << "[Server] Incumbent retrieval thread stopped\n"; + std::cout.flush(); +} + +void session_reaper_thread() +{ + if (config.verbose) { + std::cout << "[Server] Session reaper thread started (timeout=" << kSessionTimeoutSeconds + << "s)\n"; + std::cout.flush(); + } + + const auto timeout = std::chrono::seconds(kSessionTimeoutSeconds); + + while (keep_running) { + for (int i = 0; i < 60 && keep_running; ++i) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + if (!keep_running) break; + + auto now = std::chrono::steady_clock::now(); + + { + std::lock_guard lock(chunked_uploads_mutex); + for (auto it = chunked_uploads.begin(); it != chunked_uploads.end();) { + if (now - it->second.last_activity > timeout) { + if (config.verbose) { + std::cout << "[Server] Reaping stale upload session: " << it->first << "\n"; + std::cout.flush(); + } + it = chunked_uploads.erase(it); + } else { + ++it; + } + } + } + + { + std::lock_guard lock(chunked_downloads_mutex); + for (auto it = chunked_downloads.begin(); it != chunked_downloads.end();) { + if (now - it->second.created > timeout) { + if (config.verbose) { + std::cout << "[Server] Reaping stale download session: " << it->first << "\n"; + std::cout.flush(); + } + it = chunked_downloads.erase(it); + } 
else { + ++it; + } + } + } + } + + if (config.verbose) { + std::cout << "[Server] Session reaper thread stopped\n"; + std::cout.flush(); + } +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_server_types.hpp b/cpp/src/grpc/server/grpc_server_types.hpp new file mode 100644 index 0000000000..2408440fb0 --- /dev/null +++ b/cpp/src/grpc/server/grpc_server_types.hpp @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef CUOPT_ENABLE_GRPC + +#include +#include "cuopt_remote.pb.h" +#include "cuopt_remote_service.grpc.pb.h" + +#include +#include +#include +#include +#include "grpc_problem_mapper.hpp" +#include "grpc_settings_mapper.hpp" +#include "grpc_solution_mapper.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::ServerReaderWriter; +using grpc::ServerWriter; +using grpc::Status; +using grpc::StatusCode; + +using namespace cuopt::linear_programming; +// Note: NOT using "using namespace cuopt::remote" to avoid JobStatus enum conflict + +// ============================================================================= +// Shared Memory Structures (must match between main process and workers) +// ============================================================================= + +constexpr size_t MAX_JOBS = 100; +constexpr size_t MAX_RESULTS = 100; + +template +void copy_cstr(char (&dst)[N], const std::string& src) +{ + std::snprintf(dst, N, "%s", src.c_str()); +} + +template +void copy_cstr(char (&dst)[N], const char* src) +{ + std::snprintf(dst, N, "%s", src ? 
src : ""); +} + +struct JobQueueEntry { + char job_id[64]; + uint32_t problem_type; // 0 = LP, 1 = MIP + uint64_t data_size; // Size of problem data (uint64 for large problems) + std::atomic<bool> ready; // Job is ready to be processed + std::atomic<bool> claimed; // Worker has claimed this job + std::atomic<pid_t> worker_pid; // PID of worker that claimed this job (0 if none) + std::atomic<bool> cancelled; // Job has been cancelled (worker should skip) + std::atomic<int> worker_index; // Index of worker that claimed this job (-1 if none) + std::atomic<bool> data_sent; // Server has sent data to worker's pipe + std::atomic<bool> is_chunked; // True when data is in pending_chunked_data (streamed to pipe) +}; + +enum ResultStatus : uint32_t { + RESULT_SUCCESS = 0, + RESULT_ERROR = 1, + RESULT_CANCELLED = 2, +}; + +struct ResultQueueEntry { + char job_id[64]; + ResultStatus status; + uint64_t data_size; // Size of result data (uint64 for large results) + char error_message[1024]; + std::atomic<bool> claimed; // CAS guard: prevents two forked workers from + // writing the same slot simultaneously. + std::atomic<bool> ready; // Result is ready for reading (published last). 
+ std::atomic<bool> retrieved; // Result has been retrieved + std::atomic<int> worker_index; // Index of worker that produced this result +}; + +struct SharedMemoryControl { + std::atomic<bool> shutdown_requested; + std::atomic<int> active_workers; +}; + +// ============================================================================= +// Job status tracking (main process only) +// ============================================================================= + +enum class JobStatus { QUEUED, PROCESSING, COMPLETED, FAILED, NOT_FOUND, CANCELLED }; + +struct IncumbentEntry { + double objective = 0.0; + std::vector<double> assignment; +}; + +struct JobInfo { + std::string job_id; + JobStatus status; + std::chrono::steady_clock::time_point submit_time; + std::vector<IncumbentEntry> incumbents; + bool is_mip; + std::string error_message; + bool is_blocking; + cuopt::remote::ChunkedResultHeader result_header; + std::map<int32_t, std::vector<uint8_t>> result_arrays; + int64_t result_size_bytes = 0; +}; + +struct JobWaiter { + std::mutex mutex; + std::condition_variable cv; + std::vector<uint8_t> result_data; + std::string error_message; + bool success; + bool ready; + std::atomic<int> waiters{0}; + JobWaiter() : success(false), ready(false) {} +}; + +// ============================================================================= +// Server configuration +// ============================================================================= + +struct ServerConfig { + int port = 8765; + int num_workers = 1; + bool verbose = true; + bool log_to_console = false; + // Effective gRPC max send/recv message size in bytes. + // Set via --max-message-mb (MiB) or --max-message-bytes (exact, min 4096). + // Clamped at startup to [kServerMinMessageBytes, kServerMaxMessageBytes]. 
+ int64_t max_message_bytes = 256LL * 1024 * 1024; // 256 MiB + int chunk_timeout_seconds = 60; // 0 = disabled + bool enable_transfer_hash = false; + bool enable_tls = false; + bool require_client = false; + std::string tls_cert_path; + std::string tls_key_path; + std::string tls_root_path; +}; + +struct WorkerPipes { + int to_worker_fd; + int from_worker_fd; + int worker_read_fd; + int worker_write_fd; + int incumbent_from_worker_fd; + int worker_incumbent_write_fd; +}; + +// Chunked download session state (raw arrays from worker) +struct ChunkedDownloadState { + bool is_mip = false; + std::chrono::steady_clock::time_point created; + cuopt::remote::ChunkedResultHeader result_header; + std::map<int32_t, std::vector<uint8_t>> raw_arrays; // ResultFieldId -> raw bytes +}; + +// Per-array allocation cap for chunked uploads (4 GiB). +static constexpr int64_t kMaxChunkedArrayBytes = 4LL * 1024 * 1024 * 1024; + +// Maximum concurrent chunked upload + download sessions (global across all clients). +static constexpr size_t kMaxChunkedSessions = 16; + +// Stale session timeout: sessions with no activity for this long are reaped. +static constexpr int kSessionTimeoutSeconds = 300; + +struct ChunkedUploadState { + bool is_mip = false; + cuopt::remote::ChunkedProblemHeader header; + struct FieldMeta { + int64_t total_elements = 0; + int64_t element_size = 0; + int64_t received_bytes = 0; + }; + std::map<int32_t, FieldMeta> field_meta; + std::vector<std::string> chunks; + int64_t total_chunks = 0; + int64_t total_bytes = 0; + std::chrono::steady_clock::time_point last_activity; +}; + +// Holds header + chunks for a chunked upload, ready to stream to the worker pipe. 
+struct PendingChunkedUpload { + cuopt::remote::ChunkedProblemHeader header; + std::vector<std::string> chunks; +}; + +// ============================================================================= +// Global state +// ============================================================================= + +inline std::atomic<bool> keep_running{true}; +inline std::map<std::string, JobInfo> job_tracker; +inline std::mutex tracker_mutex; +inline std::condition_variable result_cv; + +inline std::map<std::string, std::shared_ptr<JobWaiter>> waiting_threads; +inline std::mutex waiters_mutex; + +inline JobQueueEntry* job_queue = nullptr; +inline ResultQueueEntry* result_queue = nullptr; +inline SharedMemoryControl* shm_ctrl = nullptr; + +inline std::vector<pid_t> worker_pids; + +inline ServerConfig config; + +inline std::vector<WorkerPipes> worker_pipes; +inline std::mutex worker_pipes_mutex; + +inline std::mutex pending_data_mutex; +inline std::map<std::string, std::vector<uint8_t>> pending_job_data; +inline std::map<std::string, PendingChunkedUpload> pending_chunked_data; + +inline std::mutex chunked_uploads_mutex; +inline std::map<std::string, ChunkedUploadState> chunked_uploads; + +inline std::mutex chunked_downloads_mutex; +inline std::map<std::string, ChunkedDownloadState> chunked_downloads; + +inline const char* SHM_JOB_QUEUE = "/cuopt_job_queue"; +inline const char* SHM_RESULT_QUEUE = "/cuopt_result_queue"; +inline const char* SHM_CONTROL = "/cuopt_control"; + +inline const std::string LOG_DIR = "/tmp/cuopt_logs"; + +constexpr int64_t kMiB = 1024LL * 1024; +constexpr int64_t kGiB = 1024LL * 1024 * 1024; + +// Floor: 4 KiB is enough for basic gRPC control messages. Values below this +// would risk rejecting even metadata-only RPCs like CheckStatus. +constexpr int64_t kServerMinMessageBytes = 4LL * 1024; // 4 KiB +// Protobuf's hard serialization limit is 2 GiB (int32 sizes internally). +// Reserve 1 MiB headroom for gRPC framing and internal bookkeeping. 
+constexpr int64_t kServerMaxMessageBytes = 2LL * 1024 * 1024 * 1024 - 1LL * 1024 * 1024; // ~2 GiB + +// ============================================================================= +// Inline utility functions +// ============================================================================= + +inline std::string get_log_file_path(const std::string& job_id) +{ + return LOG_DIR + "/job_" + job_id + ".log"; +} + +inline int64_t server_max_message_bytes() { return config.max_message_bytes; } + +inline std::string read_file_to_string(const std::string& path) +{ + std::ifstream in(path, std::ios::in | std::ios::binary); + if (!in.is_open()) { return ""; } + std::ostringstream ss; + ss << in.rdbuf(); + return ss.str(); +} + +// ============================================================================= +// Signal handling +// ============================================================================= + +inline void signal_handler(int signal) +{ + if (signal == SIGINT || signal == SIGTERM) { + keep_running = false; + if (shm_ctrl) { shm_ctrl->shutdown_requested = true; } + } +} + +// ============================================================================= +// Forward declarations +// ============================================================================= + +std::string generate_job_id(); +void ensure_log_dir_exists(); +void delete_log_file(const std::string& job_id); +void cleanup_shared_memory(); +void spawn_workers(); +void wait_for_workers(); +void worker_monitor_thread(); +void result_retrieval_thread(); +void incumbent_retrieval_thread(); +void session_reaper_thread(); + +bool write_to_pipe(int fd, const void* data, size_t size); +bool read_from_pipe(int fd, void* data, size_t size, int timeout_ms); +bool send_job_data_pipe(int worker_idx, const std::vector<uint8_t>& data); +bool recv_job_data_pipe(int fd, uint64_t expected_size, std::vector<uint8_t>& data); +bool send_incumbent_pipe(int fd, const std::vector<uint8_t>& data); +bool recv_incumbent_pipe(int fd, std::vector<uint8_t>& 
data); + +void worker_process(int worker_id); +pid_t spawn_single_worker(int worker_id); +void mark_worker_jobs_failed(pid_t dead_worker_pid); + +std::pair<bool, std::string> submit_job_async(std::vector<uint8_t>&& request_data, bool is_mip); +std::pair<bool, std::string> submit_chunked_job_async(PendingChunkedUpload&& chunked_data, + bool is_mip); +JobStatus check_job_status(const std::string& job_id, std::string& message); +bool get_job_is_mip(const std::string& job_id); +int cancel_job(const std::string& job_id, JobStatus& job_status_out, std::string& message); + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_service_impl.cpp b/cpp/src/grpc/server/grpc_service_impl.cpp new file mode 100644 index 0000000000..f63edb445e --- /dev/null +++ b/cpp/src/grpc/server/grpc_service_impl.cpp @@ -0,0 +1,879 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_field_element_size.hpp" +#include "grpc_pipe_serialization.hpp" +#include "grpc_server_types.hpp" + +class CuOptRemoteServiceImpl final : public cuopt::remote::CuOptRemoteService::Service { + public: + // Unary submit: the entire problem fits in a single gRPC message. + // Serializes the request and delegates slot reservation + tracking to + // submit_job_async (shared with the chunked path's submit_chunked_job_async). 
+  Status SubmitJob(ServerContext* context,
+                   const cuopt::remote::SubmitJobRequest* request,
+                   cuopt::remote::SubmitJobResponse* response) override
+  {
+    // Exactly one of lp_request / mip_request must be populated.
+    bool is_lp = request->has_lp_request();
+    if (!is_lp && !request->has_mip_request()) {
+      return Status(StatusCode::INVALID_ARGUMENT, "No problem data provided");
+    }
+
+    // Diagnostic dump of a few LP fields (stderr so it interleaves with solver logs).
+    if (config.verbose && is_lp) {
+      const auto& lp_req = request->lp_request();
+      std::cerr << "[gRPC] SubmitJob LP fields: bytes=" << lp_req.ByteSizeLong()
+                << " objective_scaling_factor=" << lp_req.problem().objective_scaling_factor()
+                << " objective_offset=" << lp_req.problem().objective_offset()
+                << " iteration_limit=" << lp_req.settings().iteration_limit()
+                << " method=" << lp_req.settings().method() << std::endl;
+    }
+
+    // Re-serialize the request into the internal pipe format for the worker.
+    auto job_data = serialize_submit_request_to_pipe(*request);
+    if (config.verbose) {
+      std::cout << "[gRPC] SubmitJob: UNARY " << (is_lp ? "LP" : "MIP")
+                << ", pipe payload=" << job_data.size() << " bytes\n";
+      std::cout.flush();
+    }
+
+    // On failure, submit_job_async returns the error text in place of the id.
+    auto [ok, job_id] = submit_job_async(std::move(job_data), !is_lp);
+    if (!ok) { return Status(StatusCode::RESOURCE_EXHAUSTED, job_id); }
+
+    response->set_job_id(job_id);
+    response->set_message("Job submitted successfully");
+
+    if (config.verbose) {
+      std::cout << "[gRPC] Job submitted: " << job_id << " (type=" << (is_lp ? "LP" : "MIP") << ")"
+                << std::endl;
+    }
+
+    return Status::OK;
+  }
+
+  // =========================================================================
+  // Chunked Array Upload
+  // =========================================================================
+
+  // Open a chunked upload session: allocate a session id, record the problem
+  // header, and report the server's max gRPC message size so the client can
+  // size its chunks. Chunk payloads arrive later via SendArrayChunk.
+  Status StartChunkedUpload(ServerContext* context,
+                            const cuopt::remote::StartChunkedUploadRequest* request,
+                            cuopt::remote::StartChunkedUploadResponse* response) override
+  {
+    (void)context;
+
+    // Upload sessions reuse the job-id generator; the id only becomes a job
+    // id after FinishChunkedUpload enqueues the work.
+    std::string upload_id = generate_job_id();
+    const auto& header = request->problem_header();
+    bool is_mip = (header.header().problem_type() == cuopt::remote::MIP);
+
+    if (config.verbose) {
+      std::cout << "[gRPC] StartChunkedUpload upload_id=" << upload_id << " is_mip=" << is_mip
+                << "\n";
+      std::cout.flush();
+    }
+
+    {
+      std::lock_guard lock(chunked_uploads_mutex);
+      // Enforce the session cap before operator[] below default-constructs
+      // the new session entry.
+      if (chunked_uploads.size() >= kMaxChunkedSessions) {
+        return Status(StatusCode::RESOURCE_EXHAUSTED,
+                      "Too many concurrent chunked upload sessions (limit " +
+                        std::to_string(kMaxChunkedSessions) + ")");
+      }
+      auto& state = chunked_uploads[upload_id];
+      state.is_mip = is_mip;
+      state.header = header;
+      state.total_chunks = 0;
+      // last_activity feeds the session reaper's idle-session cleanup.
+      state.last_activity = std::chrono::steady_clock::now();
+    }
+
+    response->set_upload_id(upload_id);
+    response->set_max_message_bytes(server_max_message_bytes());
+
+    return Status::OK;
+  }
+
+  // Receive one chunk of array data for a chunked upload session.
+  // Chunks are accumulated in memory until FinishChunkedUpload, which hands
+  // them to the dispatch thread for pipe serialization to the worker.
+  Status SendArrayChunk(ServerContext* context,
+                        const cuopt::remote::SendArrayChunkRequest* request,
+                        cuopt::remote::SendArrayChunkResponse* response) override
+  {
+    (void)context;
+
+    const std::string& upload_id = request->upload_id();
+    const auto& ac = request->chunk();
+
+    // The session map lock is held for the whole RPC; chunk validation and
+    // accumulation below are therefore serialized across concurrent senders.
+    std::lock_guard lock(chunked_uploads_mutex);
+    auto it = chunked_uploads.find(upload_id);
+    if (it == chunked_uploads.end()) {
+      return Status(StatusCode::NOT_FOUND, "Unknown upload_id: " + upload_id);
+    }
+
+    auto& state = it->second;
+    state.last_activity = std::chrono::steady_clock::now();
+
+    int32_t field_id = static_cast(ac.field_id());
+    int64_t elem_offset = ac.element_offset();
+    int64_t total_elems = ac.total_elements();
+    const auto& raw = ac.data();
+
+    // Reject malformed chunk metadata before touching per-field state.
+    if (!cuopt::remote::ArrayFieldId_IsValid(field_id)) {
+      return Status(StatusCode::INVALID_ARGUMENT,
+                    "Unknown array field_id: " + std::to_string(field_id));
+    }
+    if (elem_offset < 0) {
+      return Status(StatusCode::INVALID_ARGUMENT, "element_offset must be non-negative");
+    }
+    if (total_elems < 0) {
+      return Status(StatusCode::INVALID_ARGUMENT, "total_elements must be non-negative");
+    }
+
+    // On the first chunk for a field, record its total size and element width.
+    // Subsequent chunks for the same field reuse these values.
+    // NOTE(review): if the first chunk for a field declares total_elements == 0,
+    // element_size is never recorded and the alignment/bounds checks below
+    // degrade to a 1-byte element width — confirm clients never send that.
+    auto& meta = state.field_meta[field_id];
+    if (meta.total_elements == 0 && total_elems > 0) {
+      int64_t elem_size = array_field_element_size(ac.field_id());
+      // Cap the declared array size before any allocation can happen.
+      if (total_elems > kMaxChunkedArrayBytes / elem_size) {
+        return Status(StatusCode::RESOURCE_EXHAUSTED,
+                      "Array too large (" + std::to_string(total_elems) + " x " +
+                        std::to_string(elem_size) + " bytes exceeds " +
+                        std::to_string(kMaxChunkedArrayBytes) + " byte limit)");
+      }
+      meta.total_elements = total_elems;
+      meta.element_size = elem_size;
+    }
+
+    // Validate that the chunk's byte range falls within the declared array bounds.
+    int64_t elem_size = meta.element_size > 0 ? meta.element_size : 1;
+
+    if (elem_size > 1 && (raw.size() % static_cast(elem_size)) != 0) {
+      return Status(StatusCode::INVALID_ARGUMENT,
+                    "Chunk data size (" + std::to_string(raw.size()) +
+                      ") not aligned to element size (" + std::to_string(elem_size) + ")");
+    }
+
+    int64_t array_bytes = meta.total_elements * elem_size;
+    if (elem_offset > meta.total_elements) {
+      return Status(StatusCode::INVALID_ARGUMENT, "ArrayChunk offset exceeds array size");
+    }
+    int64_t byte_offset = elem_offset * elem_size;
+    if (byte_offset + static_cast(raw.size()) > array_bytes) {
+      return Status(StatusCode::INVALID_ARGUMENT, "ArrayChunk out of bounds");
+    }
+
+    // Accumulate: the raw ArrayChunk protobuf is stored as-is and will be
+    // assembled into contiguous arrays during pipe serialization.
+    meta.received_bytes += static_cast(raw.size());
+    state.total_bytes += static_cast(raw.size());
+    state.chunks.push_back(ac);
+    ++state.total_chunks;
+
+    response->set_upload_id(upload_id);
+    response->set_chunks_received(state.total_chunks);
+    return Status::OK;
+  }
+
+  // Finalize a chunked upload: move the accumulated header + chunks into a
+  // PendingChunkedUpload and enqueue it as a job. The dispatch thread will
+  // call write_chunked_request_to_pipe() to send it to a worker.
+  Status FinishChunkedUpload(ServerContext* context,
+                             const cuopt::remote::FinishChunkedUploadRequest* request,
+                             cuopt::remote::SubmitJobResponse* response) override
+  {
+    (void)context;
+
+    const std::string& upload_id = request->upload_id();
+
+    // Take ownership of the upload session and remove it from the active map.
+ ChunkedUploadState state; + { + std::lock_guard lock(chunked_uploads_mutex); + auto it = chunked_uploads.find(upload_id); + if (it == chunked_uploads.end()) { + return Status(StatusCode::NOT_FOUND, "Unknown upload_id: " + upload_id); + } + state = std::move(it->second); + chunked_uploads.erase(it); + } + + if (config.verbose) { + std::cout << "[gRPC] FinishChunkedUpload upload_id=" << upload_id + << " chunks=" << state.total_chunks << " fields=" << state.field_meta.size() + << "\n"; + std::cout.flush(); + } + + // Package the header and chunks for the dispatch thread. Field metadata + // was only needed for validation during SendArrayChunk and can be dropped. + PendingChunkedUpload pending; + pending.header = std::move(state.header); + pending.chunks = std::move(state.chunks); + state.field_meta.clear(); + + if (config.verbose) { + std::cout << "[gRPC] FinishChunkedUpload: CHUNKED path, " << state.total_chunks << " chunks, " + << state.total_bytes << " bytes, upload_id=" << upload_id << "\n"; + std::cout.flush(); + } + + auto [ok, job_id] = submit_chunked_job_async(std::move(pending), state.is_mip); + if (!ok) { return Status(StatusCode::RESOURCE_EXHAUSTED, job_id); } + + response->set_job_id(job_id); + response->set_message("Job submitted via chunked arrays"); + + if (config.verbose) { + std::cout << "[gRPC] FinishChunkedUpload enqueued job: " << job_id + << " (type=" << (state.is_mip ? 
"MIP" : "LP") << ")\n"; + std::cout.flush(); + } + + return Status::OK; + } + + // ========================================================================= + // Job Status and Result RPCs + // ========================================================================= + + Status CheckStatus(ServerContext* context, + const cuopt::remote::StatusRequest* request, + cuopt::remote::StatusResponse* response) override + { + (void)context; + std::string job_id = request->job_id(); + + std::string message; + JobStatus status = check_job_status(job_id, message); + + switch (status) { + case JobStatus::QUEUED: response->set_job_status(cuopt::remote::QUEUED); break; + case JobStatus::PROCESSING: response->set_job_status(cuopt::remote::PROCESSING); break; + case JobStatus::COMPLETED: response->set_job_status(cuopt::remote::COMPLETED); break; + case JobStatus::FAILED: response->set_job_status(cuopt::remote::FAILED); break; + case JobStatus::CANCELLED: response->set_job_status(cuopt::remote::CANCELLED); break; + default: response->set_job_status(cuopt::remote::NOT_FOUND); break; + } + response->set_message(message); + + response->set_max_message_bytes(server_max_message_bytes()); + + int64_t result_size_bytes = 0; + if (status == JobStatus::COMPLETED) { + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + if (it != job_tracker.end()) { result_size_bytes = it->second.result_size_bytes; } + } + response->set_result_size_bytes(result_size_bytes); + + return Status::OK; + } + + // Return the full result in a single gRPC response (unary path). + // If the result exceeds the server's max message size, the client must + // fall back to the chunked download RPCs instead. 
+  Status GetResult(ServerContext* context,
+                   const cuopt::remote::GetResultRequest* request,
+                   cuopt::remote::ResultResponse* response) override
+  {
+    (void)context;
+    std::string job_id = request->job_id();
+
+    std::lock_guard lock(tracker_mutex);
+    auto it = job_tracker.find(job_id);
+
+    if (it == job_tracker.end()) { return Status(StatusCode::NOT_FOUND, "Job not found"); }
+
+    if (it->second.status != JobStatus::COMPLETED && it->second.status != JobStatus::FAILED) {
+      return Status(StatusCode::UNAVAILABLE, "Result not ready");
+    }
+
+    // A failed job still returns Status::OK at the gRPC level; the failure is
+    // carried in the response's status/error_message fields.
+    if (it->second.status == JobStatus::FAILED) {
+      response->set_status(cuopt::remote::ERROR_SOLVE_FAILED);
+      response->set_error_message(it->second.error_message);
+      return Status::OK;
+    }
+
+    // Guard against results that would exceed gRPC/protobuf message limits.
+    // The client detects RESOURCE_EXHAUSTED and switches to chunked download.
+    int64_t total_result_bytes = it->second.result_size_bytes;
+    const int64_t max_bytes = server_max_message_bytes();
+    if (max_bytes > 0 && total_result_bytes > max_bytes) {
+      std::string msg = "Result size (~" + std::to_string(total_result_bytes) +
+                        " bytes) exceeds max message size (" + std::to_string(max_bytes) +
+                        " bytes). Use StartChunkedDownload/GetResultChunk RPCs instead.";
+      if (config.verbose) {
+        std::cout << "[gRPC] GetResult rejected for job " << job_id << ": " << msg << "\n";
+        std::cout.flush();
+      }
+      return Status(StatusCode::RESOURCE_EXHAUSTED, msg);
+    }
+
+    // Build the full protobuf solution from the raw arrays that were read
+    // back from the worker pipe by the result retrieval thread.
+    // NOTE(review): the (possibly large) solution proto is built while
+    // tracker_mutex is held, stalling every other RPC that touches the
+    // tracker — consider snapshotting the arrays and building outside the lock.
+    if (it->second.is_mip) {
+      cuopt::remote::MIPSolution mip_solution;
+      build_mip_solution_proto(it->second.result_header, it->second.result_arrays, &mip_solution);
+      // Swap avoids copying the freshly built message into the response.
+      response->mutable_mip_solution()->Swap(&mip_solution);
+    } else {
+      cuopt::remote::LPSolution lp_solution;
+      build_lp_solution_proto(it->second.result_header, it->second.result_arrays, &lp_solution);
+      response->mutable_lp_solution()->Swap(&lp_solution);
+    }
+
+    response->set_status(cuopt::remote::SUCCESS);
+    if (config.verbose) {
+      std::cout << "[gRPC] GetResult: UNARY response for job " << job_id << " ("
+                << total_result_bytes << " bytes, " << it->second.result_arrays.size()
+                << " arrays)\n";
+      std::cout.flush();
+    }
+
+    return Status::OK;
+  }
+
+  // =========================================================================
+  // Chunked Result Download RPCs
+  // =========================================================================
+
+  // Begin a chunked result download: snapshot the result arrays into a
+  // download session. The client calls GetResultChunk to fetch slices and
+  // FinishChunkedDownload when done (which frees the session).
+  Status StartChunkedDownload(ServerContext* context,
+                              const cuopt::remote::StartChunkedDownloadRequest* request,
+                              cuopt::remote::StartChunkedDownloadResponse* response) override
+  {
+    std::string job_id = request->job_id();
+
+    // Copy the result data into a download session. This snapshot lets the
+    // client fetch chunks at its own pace without holding the tracker lock.
+    bool is_mip = false;
+    ChunkedDownloadState state;
+    {
+      std::lock_guard lock(tracker_mutex);
+      auto it = job_tracker.find(job_id);
+      if (it == job_tracker.end()) {
+        return Status(StatusCode::NOT_FOUND, "Job not found: " + job_id);
+      }
+      if (it->second.status != JobStatus::COMPLETED) {
+        return Status(StatusCode::FAILED_PRECONDITION, "Result not ready for job: " + job_id);
+      }
+      is_mip = it->second.is_mip;
+      state.is_mip = is_mip;
+      state.created = std::chrono::steady_clock::now();
+      state.result_header = it->second.result_header;
+      // Deep copy of the raw arrays — the job can be deleted while the
+      // download session is still being consumed.
+      state.raw_arrays = it->second.result_arrays;
+    }
+
+    response->mutable_header()->CopyFrom(state.result_header);
+
+    std::string download_id = generate_job_id();
+    response->set_download_id(download_id);
+    response->set_max_message_bytes(server_max_message_bytes());
+
+    {
+      std::lock_guard lock(chunked_downloads_mutex);
+      // If the cap is hit, the partially filled response is discarded by the
+      // error status, so populating it above is harmless.
+      if (chunked_downloads.size() >= kMaxChunkedSessions) {
+        return Status(StatusCode::RESOURCE_EXHAUSTED,
+                      "Too many concurrent chunked download sessions (limit " +
+                        std::to_string(kMaxChunkedSessions) + ")");
+      }
+      chunked_downloads[download_id] = std::move(state);
+    }
+
+    if (config.verbose) {
+      std::cout << "[gRPC] StartChunkedDownload: CHUNKED response for job " << job_id
+                << ", download_id=" << download_id
+                << ", arrays=" << response->header().arrays_size() << ", is_mip=" << is_mip << "\n";
+      std::cout.flush();
+    }
+
+    return Status::OK;
+  }
+
+  // Serve one slice of a snapshotted result array. Offsets and sizes are in
+  // elements; the raw bytes are copied straight out of the session snapshot.
+  Status GetResultChunk(ServerContext* context,
+                        const cuopt::remote::GetResultChunkRequest* request,
+                        cuopt::remote::GetResultChunkResponse* response) override
+  {
+    std::string download_id = request->download_id();
+    auto field_id = request->field_id();
+    int64_t elem_offset = request->element_offset();
+    int64_t max_elements = request->max_elements();
+
+    std::lock_guard lock(chunked_downloads_mutex);
+    auto it = chunked_downloads.find(download_id);
+    if (it == chunked_downloads.end()) {
+      return Status(StatusCode::NOT_FOUND, "Unknown download_id: " + download_id);
+    }
+
+    const auto& state = it->second;
+
+    const uint8_t* raw_bytes = nullptr;
+    int64_t total_bytes = 0;
+    auto array_it = state.raw_arrays.find(static_cast(field_id));
+    if (array_it != state.raw_arrays.end() && !array_it->second.empty()) {
+      raw_bytes = array_it->second.data();
+      total_bytes = static_cast(array_it->second.size());
+    }
+
+    if (raw_bytes == nullptr || total_bytes == 0) {
+      return Status(StatusCode::INVALID_ARGUMENT,
+                    "Unknown or empty result field: " + std::to_string(field_id));
+    }
+
+    // NOTE(review): every result field is assumed to hold double elements,
+    // unlike the upload path which consults array_field_element_size().
+    // Confirm no non-8-byte result array ever goes through chunked download,
+    // otherwise element offsets/counts computed here are wrong.
+    const int64_t elem_size = sizeof(double);
+    const int64_t array_size = total_bytes / elem_size;
+
+    if (elem_offset < 0 || elem_offset >= array_size) {
+      return Status(StatusCode::OUT_OF_RANGE,
+                    "element_offset " + std::to_string(elem_offset) + " out of range [0, " +
+                      std::to_string(array_size) + ")");
+    }
+
+    // max_elements <= 0 means "send everything that remains".
+    int64_t elems_available = array_size - elem_offset;
+    int64_t elems_to_send =
+      (max_elements > 0) ? std::min(max_elements, elems_available) : elems_available;
+
+    response->set_download_id(download_id);
+    response->set_field_id(field_id);
+    response->set_element_offset(elem_offset);
+    response->set_elements_in_chunk(elems_to_send);
+    response->set_data(reinterpret_cast(raw_bytes + elem_offset * elem_size),
+                       static_cast(elems_to_send) * elem_size);
+
+    return Status::OK;
+  }
+
+  // Tear down a download session, releasing its array snapshot.
+  Status FinishChunkedDownload(ServerContext* context,
+                               const cuopt::remote::FinishChunkedDownloadRequest* request,
+                               cuopt::remote::FinishChunkedDownloadResponse* response) override
+  {
+    std::string download_id = request->download_id();
+    response->set_download_id(download_id);
+
+    std::lock_guard lock(chunked_downloads_mutex);
+    auto it = chunked_downloads.find(download_id);
+    if (it == chunked_downloads.end()) {
+      return Status(StatusCode::NOT_FOUND, "Unknown download_id: " + download_id);
+    }
+
+    if (config.verbose) {
+      auto elapsed_ms = std::chrono::duration_cast(
+                          std::chrono::steady_clock::now() - it->second.created)
+                          .count();
+      std::cout << "[gRPC] FinishChunkedDownload: download_id=" << download_id
+                << " elapsed_ms=" << elapsed_ms << "\n";
+      std::cout.flush();
+    }
+
+    chunked_downloads.erase(it);
+    return Status::OK;
+  }
+
+  // =========================================================================
+  // Delete, Cancel, Wait, StreamLogs, GetIncumbents
+  // =========================================================================
+
+  // Drop a job's tracked result and its on-disk log file.
+  Status DeleteResult(ServerContext* context,
+                      const cuopt::remote::DeleteRequest* request,
+                      cuopt::remote::DeleteResponse* response) override
+  {
+    std::string job_id = request->job_id();
+
+    size_t erased = 0;
+    {
+      std::lock_guard lock(tracker_mutex);
+      erased = job_tracker.erase(job_id);
+    }
+
+    if (erased == 0) {
+      response->set_status(cuopt::remote::ERROR_NOT_FOUND);
+      response->set_message("Job not found: " + job_id);
+      if (config.verbose) {
+        std::cout << "[gRPC] DeleteResult job not found: " << job_id << std::endl;
+      }
+      return Status::OK;
+    }
+
+    delete_log_file(job_id);
+
+    response->set_status(cuopt::remote::SUCCESS);
+    response->set_message("Result deleted");
+
+    if (config.verbose) { std::cout << "[gRPC] Result deleted for job: " << job_id << std::endl; }
+
+    return Status::OK;
+  }
+
+  Status CancelJob(ServerContext* context,
+                   const cuopt::remote::CancelRequest* request,
+                   cuopt::remote::CancelResponse* response) override
+  {
+    (void)context;
+    std::string job_id = request->job_id();
+
+    JobStatus internal_status = JobStatus::NOT_FOUND;
+    std::string message;
+    int rc = cancel_job(job_id, internal_status, message);
+
+    cuopt::remote::JobStatus pb_status = cuopt::remote::NOT_FOUND;
+    switch (internal_status) {
+      case JobStatus::QUEUED: pb_status = cuopt::remote::QUEUED; break;
+      case JobStatus::PROCESSING: pb_status = cuopt::remote::PROCESSING; break;
+      case JobStatus::COMPLETED: pb_status = cuopt::remote::COMPLETED; break;
+      case JobStatus::FAILED: pb_status = cuopt::remote::FAILED; break;
+      case JobStatus::CANCELLED: pb_status = cuopt::remote::CANCELLED; break;
+      case JobStatus::NOT_FOUND: pb_status = cuopt::remote::NOT_FOUND; break;
+    }
+
+    response->set_job_status(pb_status);
+    response->set_message(message);
+
+    // rc mapping: 0 and 3 are treated as success, 1 as not-found, anything
+    // else as an invalid request. NOTE(review): the meaning of rc==3
+    // (presumably "already in a terminal state") is defined by cancel_job in
+    // grpc_job_management.cpp — confirm there.
+    if (rc == 0 || rc == 3) {
+      response->set_status(cuopt::remote::SUCCESS);
+    } else if (rc == 1) {
+      response->set_status(cuopt::remote::ERROR_NOT_FOUND);
+    } else {
+      response->set_status(cuopt::remote::ERROR_INVALID_REQUEST);
+    }
+
+    if (config.verbose) {
+      std::cout << "[gRPC] CancelJob job_id=" << job_id << " rc=" << rc
+                << " status=" << static_cast(pb_status) << " msg=" << message << "\n";
+      std::cout.flush();
+    }
+
+    return Status::OK;
+  }
+
+  // Block until a job reaches a terminal state (COMPLETED / FAILED / CANCELLED).
+  // Uses a shared JobWaiter with a condition variable that the result retrieval
+  // thread signals when it processes the job's result. Falls back to polling
+  // every 200ms in case the signal is missed (e.g., worker crash recovery).
+  Status WaitForCompletion(ServerContext* context,
+                           const cuopt::remote::WaitRequest* request,
+                           cuopt::remote::WaitResponse* response) override
+  {
+    const std::string job_id = request->job_id();
+
+    // Fast path: if the job is already in a terminal state, return immediately.
+    {
+      std::lock_guard lock(tracker_mutex);
+      auto it = job_tracker.find(job_id);
+      if (it == job_tracker.end()) {
+        response->set_job_status(cuopt::remote::NOT_FOUND);
+        response->set_message("Job not found");
+        response->set_result_size_bytes(0);
+        return Status::OK;
+      }
+      if (it->second.status == JobStatus::COMPLETED) {
+        response->set_job_status(cuopt::remote::COMPLETED);
+        response->set_message("");
+        response->set_result_size_bytes(it->second.result_size_bytes);
+        return Status::OK;
+      }
+      if (it->second.status == JobStatus::FAILED) {
+        response->set_job_status(cuopt::remote::FAILED);
+        response->set_message(it->second.error_message);
+        response->set_result_size_bytes(0);
+        return Status::OK;
+      }
+      if (it->second.status == JobStatus::CANCELLED) {
+        response->set_job_status(cuopt::remote::CANCELLED);
+        response->set_message("Job was cancelled");
+        response->set_result_size_bytes(0);
+        return Status::OK;
+      }
+    }
+
+    // Slow path: register a waiter. Multiple concurrent WaitForCompletion
+    // RPCs for the same job share a single JobWaiter instance.
+    // NOTE(review): the waiting_threads entry is never erased in this RPC —
+    // presumably the result retrieval thread or session reaper removes it
+    // after signaling; confirm so entries don't accumulate.
+    std::shared_ptr waiter;
+    {
+      std::lock_guard lock(waiters_mutex);
+      auto it = waiting_threads.find(job_id);
+      if (it != waiting_threads.end()) {
+        waiter = it->second;
+      } else {
+        waiter = std::make_shared();
+        waiting_threads[job_id] = waiter;
+      }
+    }
+    waiter->waiters.fetch_add(1, std::memory_order_relaxed);
+
+    // Wait loop: cv is signaled by the result retrieval thread; we also
+    // poll check_job_status as a safety net and check for client disconnect.
+    // All exit paths (ready, terminal status, cancellation) break out of the
+    // loop so that cleanup (waiters decrement) happens in one place below.
+    bool client_cancelled = false;
+    {
+      std::unique_lock lock(waiter->mutex);
+      while (!waiter->ready) {
+        if (context->IsCancelled()) {
+          client_cancelled = true;
+          break;
+        }
+        // check_job_status is called with the waiter lock released so the
+        // signaling thread is never blocked behind the status poll.
+        lock.unlock();
+        std::string msg;
+        JobStatus current = check_job_status(job_id, msg);
+        lock.lock();
+        if (current == JobStatus::COMPLETED || current == JobStatus::FAILED ||
+            current == JobStatus::CANCELLED) {
+          break;
+        }
+        waiter->cv.wait_for(lock, std::chrono::milliseconds(200));
+      }
+    }
+
+    waiter->waiters.fetch_sub(1, std::memory_order_relaxed);
+
+    if (client_cancelled) {
+      if (config.verbose) {
+        std::cout << "[gRPC] WaitForCompletion cancelled by client, job_id=" << job_id << "\n";
+        std::cout.flush();
+      }
+      return Status(StatusCode::CANCELLED, "Client cancelled WaitForCompletion");
+    }
+
+    // Build the response from the final job state.
+    // The waiter's `success` flag is set by the result retrieval thread when it
+    // processes a successful result. It is true only for normal completion.
+    if (waiter->success) {
+      response->set_job_status(cuopt::remote::COMPLETED);
+      response->set_message("");
+      {
+        std::lock_guard lock(tracker_mutex);
+        auto job_it = job_tracker.find(job_id);
+        response->set_result_size_bytes(
+          (job_it != job_tracker.end()) ? job_it->second.result_size_bytes : 0);
+      }
+    } else {
+      // The waiter was not signaled with success. This happens when:
+      //   - The job failed (solver error, worker crash)
+      //   - The job was cancelled by the user
+      //   - The wait loop exited via the polling safety net (check_job_status
+      //     detected a terminal state before the cv was signaled)
+      // Re-check the authoritative job status to determine what happened.
+      std::string msg;
+      JobStatus status = check_job_status(job_id, msg);
+      switch (status) {
+        case JobStatus::COMPLETED: {
+          response->set_job_status(cuopt::remote::COMPLETED);
+          response->set_message("");
+          std::lock_guard lock(tracker_mutex);
+          auto job_it = job_tracker.find(job_id);
+          response->set_result_size_bytes(
+            (job_it != job_tracker.end()) ? job_it->second.result_size_bytes : 0);
+          break;
+        }
+        case JobStatus::FAILED: response->set_job_status(cuopt::remote::FAILED); break;
+        case JobStatus::CANCELLED: response->set_job_status(cuopt::remote::CANCELLED); break;
+        case JobStatus::NOT_FOUND: response->set_job_status(cuopt::remote::NOT_FOUND); break;
+        default: response->set_job_status(cuopt::remote::FAILED); break;
+      }
+      if (status != JobStatus::COMPLETED) {
+        response->set_message(msg);
+        response->set_result_size_bytes(0);
+      }
+    }
+
+    if (config.verbose) {
+      std::cout << "[gRPC] WaitForCompletion finished job_id=" << job_id << "\n";
+      std::cout.flush();
+    }
+
+    return Status::OK;
+  }
+
+  // Server-streaming RPC: tails the solver log file for a job, sending one
+  // LogMessage per line as new output appears (like `tail -f` over gRPC).
+  // The client supplies a byte offset so it can resume after reconnection.
+  // The stream ends with a sentinel message (job_complete=true) once the
+  // job reaches a terminal state and all remaining log content is flushed.
+  Status StreamLogs(ServerContext* context,
+                    const cuopt::remote::StreamLogsRequest* request,
+                    ServerWriter* writer) override
+  {
+    const std::string job_id = request->job_id();
+    int64_t from_byte = request->from_byte();
+    const std::string log_path = get_log_file_path(job_id);
+
+    // Phase 1: Wait for the log file to appear on disk.
+    // The worker may not have created it yet, so poll with a short sleep.
+    // Every 2 s, verify the job still exists to avoid waiting forever on
+    // a deleted/unknown job.
+ int waited_ms = 0; + while (!context->IsCancelled()) { + struct stat st; + if (stat(log_path.c_str(), &st) == 0) { break; } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + waited_ms += 50; + if (waited_ms >= 2000) { + std::string msg; + JobStatus s = check_job_status(job_id, msg); + if (s == JobStatus::NOT_FOUND) { + if (config.verbose) { + std::cout << "[gRPC] StreamLogs job not found: " << job_id << std::endl; + } + return Status(grpc::StatusCode::NOT_FOUND, "Job not found: " + job_id); + } + if (s == JobStatus::COMPLETED || s == JobStatus::FAILED || s == JobStatus::CANCELLED) { + cuopt::remote::LogMessage done; + done.set_line(""); + done.set_byte_offset(from_byte); + done.set_job_complete(true); + writer->Write(done); + return Status::OK; + } + waited_ms = 0; + } + } + + // Phase 2: Open the file and seek to the caller's resume point. + std::ifstream in(log_path, std::ios::in | std::ios::binary); + if (!in.is_open()) { + cuopt::remote::LogMessage m; + m.set_line("Failed to open log file"); + m.set_byte_offset(from_byte); + m.set_job_complete(true); + writer->Write(m); + return Status::OK; + } + + if (from_byte > 0) { in.seekg(from_byte, std::ios::beg); } + + int64_t current_offset = from_byte; + std::string line; + + // Phase 3: Tail loop — read available lines, stream each one, then + // poll for more. Each LogMessage carries the byte offset of the *next* + // unread byte so the client can resume from that point. + while (!context->IsCancelled()) { + std::streampos before = in.tellg(); + if (before >= 0) { current_offset = static_cast(before); } + + if (std::getline(in, line)) { + std::streampos after = in.tellg(); + int64_t next_offset = current_offset; + if (after >= 0) { + next_offset = static_cast(after); + } else { + // tellg() can return -1 after the last line when there is no + // trailing newline; fall back to estimating from line length. 
+ next_offset = current_offset + static_cast(line.size()); + } + + cuopt::remote::LogMessage m; + m.set_line(line); + m.set_byte_offset(next_offset); + m.set_job_complete(false); + if (!writer->Write(m)) { break; } + continue; + } + + // Caught up to the current end of file — clear the EOF/fail bit + // so the next getline attempt can see newly appended data. + if (in.eof()) { + in.clear(); + } else if (in.fail()) { + in.clear(); + } + + // Check whether the job has finished. If so, drain any final + // bytes the solver may have flushed after our last read, then + // send the job_complete sentinel and close the stream. + std::string msg; + JobStatus s = check_job_status(job_id, msg); + if (s == JobStatus::COMPLETED || s == JobStatus::FAILED || s == JobStatus::CANCELLED) { + std::streampos before2 = in.tellg(); + if (before2 >= 0) { current_offset = static_cast(before2); } + if (std::getline(in, line)) { + std::streampos after2 = in.tellg(); + int64_t next_offset2 = current_offset + static_cast(line.size()); + if (after2 >= 0) { next_offset2 = static_cast(after2); } + cuopt::remote::LogMessage m; + m.set_line(line); + m.set_byte_offset(next_offset2); + m.set_job_complete(false); + writer->Write(m); + } + + cuopt::remote::LogMessage done; + done.set_line(""); + done.set_byte_offset(current_offset); + done.set_job_complete(true); + writer->Write(done); + return Status::OK; + } + + // Job still running but no new data yet — back off briefly before + // retrying so we don't spin-wait on the file. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + return Status::OK; + } + + Status GetIncumbents(ServerContext* context, + const cuopt::remote::IncumbentRequest* request, + cuopt::remote::IncumbentResponse* response) override + { + (void)context; + const std::string job_id = request->job_id(); + int64_t from_index = request->from_index(); + int32_t max_count = request->max_count(); + + if (from_index < 0) { from_index = 0; } + + std::lock_guard lock(tracker_mutex); + auto it = job_tracker.find(job_id); + if (it == job_tracker.end()) { return Status(StatusCode::NOT_FOUND, "Job not found"); } + + const auto& incumbents = it->second.incumbents; + int64_t available = static_cast(incumbents.size()); + if (from_index > available) { from_index = available; } + + int64_t count = available - from_index; + if (max_count > 0 && count > max_count) { count = max_count; } + + for (int64_t i = 0; i < count; ++i) { + const auto& inc = incumbents[static_cast(from_index + i)]; + auto* out = response->add_incumbents(); + out->set_index(from_index + i); + out->set_objective(inc.objective); + for (double v : inc.assignment) { + out->add_assignment(v); + } + out->set_job_id(job_id); + } + + // next_index is the resume cursor: the client passes it back as from_index + // on the next call. Must be from_index + count (the first unsent entry), + // NOT available (total size), or the client skips entries when max_count + // limits the batch. + response->set_next_index(from_index + count); + bool done = + (it->second.status == JobStatus::COMPLETED || it->second.status == JobStatus::FAILED || + it->second.status == JobStatus::CANCELLED); + response->set_job_complete(done); + if (config.verbose) { + std::cout << "[gRPC] GetIncumbents job_id=" << job_id << " from=" << from_index + << " returned=" << response->incumbents_size() << " next=" << (from_index + count) + << " done=" << (done ? 
1 : 0) << "\n"; + std::cout.flush(); + } + return Status::OK; + } +}; + +// Provide access to the service implementation type from grpc_server_main.cpp. +// This avoids exposing the class definition in a header (it's only needed once in main). +std::unique_ptr create_cuopt_grpc_service() +{ + return std::make_unique(); +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_worker.cpp b/cpp/src/grpc/server/grpc_worker.cpp new file mode 100644 index 0000000000..fb758f0b3e --- /dev/null +++ b/cpp/src/grpc/server/grpc_worker.cpp @@ -0,0 +1,542 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_incumbent_proto.hpp" +#include "grpc_pipe_serialization.hpp" +#include "grpc_server_types.hpp" + +// --------------------------------------------------------------------------- +// Data-transfer structs used to pass results between decomposed functions. +// --------------------------------------------------------------------------- + +struct DeserializedJob { + cpu_optimization_problem_t problem; + pdlp_solver_settings_t lp_settings; + mip_solver_settings_t mip_settings; + bool enable_incumbents = true; + bool success = false; +}; + +struct SolveResult { + cuopt::remote::ChunkedResultHeader header; + std::map> arrays; + std::string error_message; + bool success = false; +}; + +// --------------------------------------------------------------------------- +// Solver callback that forwards each new MIP incumbent to the server thread +// via a pipe. A fresh instance is created per solve (as a unique_ptr scoped +// to run_mip_solve) and registered with mip_settings.set_mip_callback(). +// The solver calls get_solution() every time it finds a better integer-feasible +// solution; we serialize the objective + variable assignment into a protobuf +// and push it down the incumbent pipe FD. 
The server thread reads the other +// end to serve GetIncumbents RPCs. +// --------------------------------------------------------------------------- + +class IncumbentPipeCallback : public cuopt::internals::get_solution_callback_t { + public: + IncumbentPipeCallback(std::string job_id, int fd, size_t num_vars, bool is_float) + : job_id_(std::move(job_id)), fd_(fd) + { + n_variables = num_vars; + isFloat = is_float; + } + + // Called by the MIP solver each time a new incumbent is found. + // data/objective_value arrive as raw void* whose actual type depends on + // isFloat; we normalize everything to double before serializing. + void get_solution(void* data, + void* objective_value, + void* /*solution_bound*/, + void* /*user_data*/) override + { + if (fd_ < 0 || n_variables == 0) { return; } + + double objective = 0.0; + std::vector assignment; + assignment.resize(n_variables); + + if (isFloat) { + const float* float_data = static_cast(data); + for (size_t i = 0; i < n_variables; ++i) { + assignment[i] = static_cast(float_data[i]); + } + objective = static_cast(*static_cast(objective_value)); + } else { + const double* double_data = static_cast(data); + std::copy(double_data, double_data + n_variables, assignment.begin()); + objective = *static_cast(objective_value); + } + + auto buffer = build_incumbent_proto(job_id_, objective, assignment); + if (!send_incumbent_pipe(fd_, buffer)) { + std::cerr << "[Worker] Incumbent pipe write failed for job " << job_id_ + << ", disabling further sends\n"; + fd_ = -1; + return; + } + } + + private: + std::string job_id_; + int fd_; +}; + +// --------------------------------------------------------------------------- +// Small utility helpers +// --------------------------------------------------------------------------- + +// Reset every field in a job slot so it can be reused by the next submission. 
static void reset_job_slot(JobQueueEntry& job)
{
  job.worker_pid   = 0;
  job.worker_index = -1;
  job.data_sent    = false;
  job.is_chunked   = false;
  // ready/claimed last: once ready==false the slot is visible as free again.
  job.ready     = false;
  job.claimed   = false;
  job.cancelled = false;
}

// Log pipe throughput when config.verbose is enabled.
// phase       : label emitted in the [THROUGHPUT] line (e.g. "pipe_job_recv").
// total_bytes : payload size moved through the pipe.
// t0          : steady-clock timestamp taken just before the transfer began.
static void log_pipe_throughput(const char* phase,
                                int64_t total_bytes,
                                std::chrono::steady_clock::time_point t0)
{
  auto pipe_us =
    std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - t0)
      .count();
  double pipe_sec = pipe_us / 1e6;
  double pipe_mb  = static_cast<double>(total_bytes) / (1024.0 * 1024.0);
  // Guard against a zero elapsed time so we never divide by zero.
  double pipe_mbs = (pipe_sec > 0.0) ? (pipe_mb / pipe_sec) : 0.0;
  std::cout << "[THROUGHPUT] phase=" << phase << " bytes=" << total_bytes
            << " elapsed_ms=" << std::fixed << std::setprecision(1) << (pipe_us / 1000.0)
            << " throughput_mb_s=" << std::setprecision(1) << pipe_mbs << "\n";
  std::cout.flush();
}

// Copy a device vector of T to a newly allocated host std::vector.
// Throws std::runtime_error when the cudaMemcpy fails; the caller's
// try/catch converts that into an error result for the client.
// NOTE(review): the parameter uses a C++20 abbreviated template (auto) in the
// recovered text — device_vec is presumably an rmm/raft device container
// exposing size()/data(); confirm against the committed file.
template <typename T>
static std::vector<T> device_to_host(const auto& device_vec)
{
  std::vector<T> host(device_vec.size());
  cudaError_t err = cudaMemcpy(
    host.data(), device_vec.data(), device_vec.size() * sizeof(T), cudaMemcpyDeviceToHost);
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("cudaMemcpy device-to-host failed: ") +
                             cudaGetErrorString(err));
  }
  return host;
}

// Write a result entry with no payload (error, cancellation, etc.) into the
// first free slot in the shared-memory result_queue.
//
// Lock-free protocol for cross-process writes (workers are forked):
//   1. Skip slots where ready==true (still being consumed by the reader).
//   2. CAS claimed false→true to get exclusive write access. Another
//      writer (different worker process) that races on the same slot will
//      see the CAS fail and move to the next slot.
//   3. Re-check ready after claiming, in case the reader set ready=true
//      between step 1 and step 2.
//   4. Write all non-atomic fields, then publish with ready=true (release)
//      so the reader sees a consistent entry.
//   5. Clear claimed so the slot can be recycled after the reader is done.
//
// The same protocol is used by publish_result() and the crash-recovery
// path in grpc_worker_infra.cpp.
static void store_simple_result(const std::string& job_id,
                                int worker_id,
                                ResultStatus status,
                                const char* error_message)
{
  for (size_t i = 0; i < MAX_RESULTS; ++i) {
    if (result_queue[i].ready.load(std::memory_order_acquire)) continue;
    bool expected = false;
    if (!result_queue[i].claimed.compare_exchange_strong(
          expected, true, std::memory_order_acq_rel)) {
      continue;
    }
    // A reader may have published into this slot between our ready check
    // and the CAS — back off and try the next slot.
    if (result_queue[i].ready.load(std::memory_order_acquire)) {
      result_queue[i].claimed.store(false, std::memory_order_release);
      continue;
    }
    copy_cstr(result_queue[i].job_id, job_id);
    result_queue[i].status    = status;
    result_queue[i].data_size = 0;  // no payload follows on the result pipe
    result_queue[i].worker_index.store(worker_id, std::memory_order_relaxed);
    copy_cstr(result_queue[i].error_message, error_message);
    // Belt-and-braces NUL termination of the fixed-size message buffer.
    result_queue[i].error_message[sizeof(result_queue[i].error_message) - 1] = '\0';
    result_queue[i].retrieved.store(false, std::memory_order_relaxed);
    // Publish (release) so the reader observes all fields above, then
    // release our claim so the slot can be recycled later.
    result_queue[i].ready.store(true, std::memory_order_release);
    result_queue[i].claimed.store(false, std::memory_order_release);
    break;
  }
}

// ---------------------------------------------------------------------------
// Stage functions called from the worker_process main loop
// ---------------------------------------------------------------------------

// Atomically claim the first ready-but-unclaimed job slot, stamping it with
// this worker's PID and index. Returns the slot index, or -1 if none found.
static int claim_job_slot(int worker_id)
{
  for (size_t i = 0; i < MAX_JOBS; ++i) {
    // Cheap pre-check before the CAS; both fields are shared-memory atomics.
    if (job_queue[i].ready && !job_queue[i].claimed) {
      bool expected = false;
      if (job_queue[i].claimed.compare_exchange_strong(expected, true)) {
        // Stamp ownership so the monitor can attribute this job to us if we die.
        job_queue[i].worker_pid   = getpid();
        job_queue[i].worker_index = worker_id;
        return static_cast<int>(i);
      }
    }
  }
  return -1;  // nothing to do right now
}

// Deserialize the problem from the worker's pipe. Handles both chunked and
// unary IPC formats. Returns a DeserializedJob with success=false on error.
// NOTE(review): template arguments below were lost in this diff rendering and
// reconstructed — in particular the arrays map key type (assumed to be the
// ArrayFieldId used by grpc_field_element_size.hpp); confirm.
static DeserializedJob read_problem_from_pipe(int worker_id, const JobQueueEntry& job)
{
  DeserializedJob dj;

  int read_fd         = worker_pipes[worker_id].worker_read_fd;
  bool is_chunked_job = job.is_chunked.load();

  auto pipe_recv_t0 = std::chrono::steady_clock::now();

  if (is_chunked_job) {
    // Chunked path: the server wrote a ChunkedProblemHeader followed by
    // a set of raw typed arrays (constraint matrix, bounds, etc.).
    // This avoids a single giant protobuf allocation for large problems.
    cuopt::remote::ChunkedProblemHeader chunked_header;
    std::map<cuopt::remote::ArrayFieldId, std::vector<uint8_t>> arrays;
    if (!read_chunked_request_from_pipe(read_fd, chunked_header, arrays)) { return dj; }

    if (config.verbose) {
      int64_t total_bytes = 0;
      for (const auto& [fid, data] : arrays) {
        total_bytes += data.size();
      }
      log_pipe_throughput("pipe_job_recv", total_bytes, pipe_recv_t0);
      std::cout << "[Worker] IPC path: CHUNKED (" << arrays.size() << " arrays, " << total_bytes
                << " bytes)\n";
      std::cout.flush();
    }
    // Settings are optional header fields; leave the defaults when absent.
    if (chunked_header.has_lp_settings()) {
      map_proto_to_pdlp_settings(chunked_header.lp_settings(), dj.lp_settings);
    }
    if (chunked_header.has_mip_settings()) {
      map_proto_to_mip_settings(chunked_header.mip_settings(), dj.mip_settings);
    }
    dj.enable_incumbents = chunked_header.enable_incumbents();
    map_chunked_arrays_to_problem(chunked_header, arrays, dj.problem);
  } else {
    // Unary path: the entire SubmitJobRequest was serialized as a single
    // protobuf blob. Simpler but copies more memory for large problems.
    std::vector<uint8_t> request_data;
    if (!recv_job_data_pipe(read_fd, job.data_size, request_data)) { return dj; }

    if (config.verbose) {
      log_pipe_throughput("pipe_job_recv", static_cast<int64_t>(request_data.size()), pipe_recv_t0);
    }
    cuopt::remote::SubmitJobRequest submit_request;
    // A parse failure or a request carrying neither an LP nor a MIP payload
    // is treated identically: dj.success stays false.
    if (!submit_request.ParseFromArray(request_data.data(),
                                       static_cast<int>(request_data.size())) ||
        (!submit_request.has_lp_request() && !submit_request.has_mip_request())) {
      return dj;
    }
    if (submit_request.has_lp_request()) {
      const auto& req = submit_request.lp_request();
      std::cout << "[Worker] IPC path: UNARY LP (" << request_data.size() << " bytes)\n"
                << std::flush;
      map_proto_to_problem(req.problem(), dj.problem);
      map_proto_to_pdlp_settings(req.settings(), dj.lp_settings);
    } else {
      const auto& req = submit_request.mip_request();
      std::cout << "[Worker] IPC path: UNARY MIP (" << request_data.size() << " bytes)\n"
                << std::flush;
      map_proto_to_problem(req.problem(), dj.problem);
      map_proto_to_mip_settings(req.settings(), dj.mip_settings);
      // Incumbent streaming defaults to on when the client did not say.
      dj.enable_incumbents = req.has_enable_incumbents() ? req.enable_incumbents() : true;
    }
  }

  dj.success = true;
  return dj;
}

// Run the MIP solver on the GPU and serialize the solution into chunked format.
// The incumbent callback is created and scoped here so it lives exactly as
// long as the solve. Exceptions are caught and returned as error messages.
static SolveResult run_mip_solve(DeserializedJob& dj,
                                 raft::handle_t& handle,
                                 const std::string& log_file,
                                 const std::string& job_id,
                                 int worker_id)
{
  SolveResult sr;
  try {
    dj.mip_settings.log_file       = log_file;
    dj.mip_settings.log_to_console = config.log_to_console;

    // Create a per-solve incumbent callback wired to this worker's
    // incumbent pipe. Destroyed automatically when sr is returned.
    std::unique_ptr<IncumbentPipeCallback> incumbent_cb;
    if (dj.enable_incumbents) {
      incumbent_cb =
        std::make_unique<IncumbentPipeCallback>(job_id,
                                                worker_pipes[worker_id].worker_incumbent_write_fd,
                                                dj.problem.get_n_variables(),
                                                false);
      // Settings hold a raw pointer; incumbent_cb must outlive solve_mip below.
      dj.mip_settings.set_mip_callback(incumbent_cb.get());
      std::cout << "[Worker] Registered incumbent callback for job_id=" << job_id
                << " n_vars=" << dj.problem.get_n_variables() << "\n";
      std::cout.flush();
    }

    std::cout << "[Worker] Converting CPU problem to GPU problem...\n" << std::flush;
    auto gpu_problem = dj.problem.to_optimization_problem(&handle);

    std::cout << "[Worker] Calling solve_mip...\n" << std::flush;
    auto gpu_solution = solve_mip(*gpu_problem, dj.mip_settings);
    std::cout << "[Worker] solve_mip done\n" << std::flush;

    std::cout << "[Worker] Converting solution to CPU format...\n" << std::flush;

    // Assignment vector is the only device-resident array to copy back.
    auto host_solution = device_to_host<double>(gpu_solution.get_solution());

    cpu_mip_solution_t cpu_solution(std::move(host_solution),
                                    gpu_solution.get_termination_status(),
                                    gpu_solution.get_objective_value(),
                                    gpu_solution.get_mip_gap(),
                                    gpu_solution.get_solution_bound(),
                                    gpu_solution.get_total_solve_time(),
                                    gpu_solution.get_presolve_time(),
                                    gpu_solution.get_max_constraint_violation(),
                                    gpu_solution.get_max_int_violation(),
                                    gpu_solution.get_max_variable_bound_violation(),
                                    gpu_solution.get_num_nodes(),
                                    gpu_solution.get_num_simplex_iterations());

    // Split the solution into a protobuf header + raw arrays for the
    // chunked result pipe (see grpc_pipe_serialization.hpp).
    populate_chunked_result_header_mip(cpu_solution, &sr.header);
    sr.arrays = collect_mip_solution_arrays(cpu_solution);
    std::cout << "[Worker] Result path: MIP solution -> " << sr.arrays.size() << " array(s)\n"
              << std::flush;
    sr.success = true;
  } catch (const std::exception& e) {
    sr.error_message = std::string("Exception: ") + e.what();
  }
  return sr;
}

// Run the LP solver on the GPU and serialize the solution into chunked format.
// No incumbent callback (LP solvers don't produce intermediate solutions).
// Exceptions are caught and returned as error messages.
static SolveResult run_lp_solve(DeserializedJob& dj,
                                raft::handle_t& handle,
                                const std::string& log_file)
{
  SolveResult sr;
  try {
    dj.lp_settings.log_file       = log_file;
    dj.lp_settings.log_to_console = config.log_to_console;

    std::cout << "[Worker] Converting CPU problem to GPU problem...\n" << std::flush;
    auto gpu_problem = dj.problem.to_optimization_problem(&handle);

    std::cout << "[Worker] Calling solve_lp...\n" << std::flush;
    auto gpu_solution = solve_lp(*gpu_problem, dj.lp_settings);
    std::cout << "[Worker] solve_lp done\n" << std::flush;

    std::cout << "[Worker] Converting solution to CPU format...\n" << std::flush;

    // Copy the three device-resident solution arrays back to the host.
    // NOTE(review): the <double> template arguments were reconstructed from a
    // mangled rendering — confirm against the committed file.
    auto host_primal       = device_to_host<double>(gpu_solution.get_primal_solution());
    auto host_dual         = device_to_host<double>(gpu_solution.get_dual_solution());
    auto host_reduced_cost = device_to_host<double>(gpu_solution.get_reduced_cost());

    auto term_info = gpu_solution.get_additional_termination_information();

    // Warm-start data lets clients resume an interrupted LP solve from
    // where it left off without starting over.
    auto cpu_ws =
      convert_to_cpu_warmstart(gpu_solution.get_pdlp_warm_start_data(), handle.get_stream());

    cpu_lp_solution_t cpu_solution(std::move(host_primal),
                                   std::move(host_dual),
                                   std::move(host_reduced_cost),
                                   gpu_solution.get_termination_status(),
                                   gpu_solution.get_objective_value(),
                                   gpu_solution.get_dual_objective_value(),
                                   term_info.solve_time,
                                   term_info.l2_primal_residual,
                                   term_info.l2_dual_residual,
                                   term_info.gap,
                                   term_info.number_of_steps_taken,
                                   term_info.solved_by_pdlp,
                                   std::move(cpu_ws));

    // Split the solution into a protobuf header + raw arrays for the
    // chunked result pipe (see grpc_pipe_serialization.hpp).
    populate_chunked_result_header_lp(cpu_solution, &sr.header);
    sr.arrays = collect_lp_solution_arrays(cpu_solution);
    std::cout << "[Worker] Result path: LP solution -> " << sr.arrays.size() << " array(s)\n"
              << std::flush;
    sr.success = true;
  } catch (const std::exception& e) {
    sr.error_message = std::string("Exception: ") + e.what();
  }
  return sr;
}

// Publish a solve result: claim a slot in the shared-memory result_queue
// (metadata) and, for successful solves, stream the full solution payload
// through the worker's result pipe for the server thread to read.
static void publish_result(const SolveResult& sr, const std::string& job_id, int worker_id)
{
  // Total payload size, used both for the metadata entry and throughput logs.
  int64_t result_total_bytes = 0;
  if (sr.success) {
    for (const auto& [fid, data] : sr.arrays) {
      result_total_bytes += data.size();
    }
  }

  // Same CAS protocol as store_simple_result (see comment there).
  int result_slot = -1;
  for (size_t i = 0; i < MAX_RESULTS; ++i) {
    if (result_queue[i].ready.load(std::memory_order_acquire)) continue;
    bool expected = false;
    if (!result_queue[i].claimed.compare_exchange_strong(
          expected, true, std::memory_order_acq_rel)) {
      continue;
    }
    if (result_queue[i].ready.load(std::memory_order_acquire)) {
      result_queue[i].claimed.store(false, std::memory_order_release);
      continue;
    }
    result_slot              = static_cast<int>(i);
    ResultQueueEntry& result = result_queue[i];
    copy_cstr(result.job_id, job_id);
    result.status = sr.success ? RESULT_SUCCESS : RESULT_ERROR;
    // A successful result always advertises a non-zero size so the reader
    // knows a payload follows on the pipe, even for an empty arrays map.
    // NOTE(review): the std::max<int64_t> template argument was reconstructed
    // from a mangled rendering — confirm against the committed file.
    result.data_size = sr.success ? std::max<int64_t>(result_total_bytes, 1) : 0;
    result.worker_index.store(worker_id, std::memory_order_relaxed);
    if (!sr.success) { copy_cstr(result.error_message, sr.error_message); }
    result.retrieved.store(false, std::memory_order_relaxed);
    result.ready.store(true, std::memory_order_release);
    result.claimed.store(false, std::memory_order_release);
    if (config.verbose) {
      std::cout << "[Worker " << worker_id << "] Enqueued result metadata for job " << job_id
                << " in result_slot=" << result_slot << " status=" << result.status
                << " data_size=" << result.data_size << "\n";
      std::cout.flush();
    }
    break;
  }

  // Stream the full solution payload through the worker's result pipe.
  // The server thread reads the other end when the client calls
  // GetResult / DownloadChunk.
  if (sr.success && result_slot >= 0) {
    int write_fd = worker_pipes[worker_id].worker_write_fd;
    if (config.verbose) {
      std::cout << "[Worker " << worker_id << "] Streaming result (" << sr.arrays.size()
                << " arrays, " << result_total_bytes << " bytes) to pipe for job " << job_id
                << "\n";
      std::cout.flush();
    }
    auto pipe_result_t0 = std::chrono::steady_clock::now();
    bool write_success  = write_result_to_pipe(write_fd, sr.header, sr.arrays);
    if (write_success && config.verbose) {
      log_pipe_throughput("pipe_result_send", result_total_bytes, pipe_result_t0);
    }
    if (!write_success) {
      // Metadata is already published; downgrade the slot to an error so the
      // client gets a failure instead of waiting on a payload that never comes.
      std::cerr << "[Worker " << worker_id << "] Failed to write result to pipe\n";
      std::cerr.flush();
      result_queue[result_slot].status = RESULT_ERROR;
      copy_cstr(result_queue[result_slot].error_message, "Failed to write result to pipe");
    } else if (config.verbose) {
      std::cout << "[Worker " << worker_id << "] Finished writing result payload for job " << job_id
                << "\n";
      std::cout.flush();
    }
  } else if (config.verbose) {
    std::cout << "[Worker " << worker_id << "] No result payload write needed for job " << job_id
              << " (success=" << sr.success << ", result_slot=" << result_slot
              << ", payload_bytes=" << result_total_bytes << ")\n";
    std::cout.flush();
  }
}

// ---------------------------------------------------------------------------
// Main worker loop — pure policy. All implementation detail is in the
// stage functions above.
+// --------------------------------------------------------------------------- + +void worker_process(int worker_id) +{ + std::cout << "[Worker " << worker_id << "] Started (PID: " << getpid() << ")\n"; + + shm_ctrl->active_workers++; + + while (!shm_ctrl->shutdown_requested) { + int job_slot = claim_job_slot(worker_id); + if (job_slot < 0) { + usleep(10000); + continue; + } + + JobQueueEntry& job = job_queue[job_slot]; + std::string job_id(job.job_id); + bool is_mip = (job.problem_type == 1); + + if (job.cancelled) { + std::cout << "[Worker " << worker_id << "] Job cancelled before processing: " << job_id + << "\n" + << std::flush; + store_simple_result(job_id, worker_id, RESULT_CANCELLED, "Job was cancelled"); + reset_job_slot(job); + continue; + } + + std::cout << "[Worker " << worker_id << "] Processing job: " << job_id + << " (type: " << (is_mip ? "MIP" : "LP") << ")\n" + << std::flush; + + auto deserialized = read_problem_from_pipe(worker_id, job); + if (!deserialized.success) { + std::cerr << "[Worker " << worker_id << "] Failed to read job data from pipe\n"; + store_simple_result(job_id, worker_id, RESULT_ERROR, "Failed to read job data"); + reset_job_slot(job); + continue; + } + + std::cout << "[Worker] Problem reconstructed: " << deserialized.problem.get_n_constraints() + << " constraints, " << deserialized.problem.get_n_variables() << " variables, " + << deserialized.problem.get_nnz() << " nonzeros\n" + << std::flush; + + std::string log_file = get_log_file_path(job_id); + raft::handle_t handle; + + SolveResult result = is_mip ? 
run_mip_solve(deserialized, handle, log_file, job_id, worker_id) + : run_lp_solve(deserialized, handle, log_file); + + publish_result(result, job_id, worker_id); + reset_job_slot(job); + + std::cout << "[Worker " << worker_id << "] Completed job: " << job_id + << " (success: " << result.success << ")\n"; + } + + shm_ctrl->active_workers--; + std::cout << "[Worker " << worker_id << "] Stopped\n"; + // _exit() instead of exit() to avoid running atexit handlers or flushing + // parent-inherited stdio buffers a second time in the forked child. + _exit(0); +} + +#endif // CUOPT_ENABLE_GRPC diff --git a/cpp/src/grpc/server/grpc_worker_infra.cpp b/cpp/src/grpc/server/grpc_worker_infra.cpp new file mode 100644 index 0000000000..e019ca4aa9 --- /dev/null +++ b/cpp/src/grpc/server/grpc_worker_infra.cpp @@ -0,0 +1,239 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef CUOPT_ENABLE_GRPC + +#include "grpc_pipe_serialization.hpp" +#include "grpc_server_types.hpp" + +void cleanup_shared_memory() +{ + if (job_queue) { + munmap(job_queue, sizeof(JobQueueEntry) * MAX_JOBS); + shm_unlink(SHM_JOB_QUEUE); + } + if (result_queue) { + munmap(result_queue, sizeof(ResultQueueEntry) * MAX_RESULTS); + shm_unlink(SHM_RESULT_QUEUE); + } + if (shm_ctrl) { + munmap(shm_ctrl, sizeof(SharedMemoryControl)); + shm_unlink(SHM_CONTROL); + } +} + +static void close_and_reset(int& fd) +{ + if (fd >= 0) { + close(fd); + fd = -1; + } +} + +static void close_all_worker_pipes(WorkerPipes& wp) +{ + close_and_reset(wp.worker_read_fd); + close_and_reset(wp.to_worker_fd); + close_and_reset(wp.from_worker_fd); + close_and_reset(wp.worker_write_fd); + close_and_reset(wp.incumbent_from_worker_fd); + close_and_reset(wp.worker_incumbent_write_fd); +} + +bool create_worker_pipes(int worker_id) +{ + while (static_cast(worker_pipes.size()) <= worker_id) { + worker_pipes.push_back({-1, -1, -1, -1, -1, 
-1}); + } + + WorkerPipes& wp = worker_pipes[worker_id]; + + int fds[2]; + + if (pipe(fds) < 0) { + std::cerr << "[Server] Failed to create input pipe for worker " << worker_id << "\n"; + return false; + } + wp.worker_read_fd = fds[0]; + wp.to_worker_fd = fds[1]; + fcntl(wp.to_worker_fd, F_SETPIPE_SZ, kPipeBufferSize); + + if (pipe(fds) < 0) { + std::cerr << "[Server] Failed to create output pipe for worker " << worker_id << "\n"; + close_all_worker_pipes(wp); + return false; + } + wp.from_worker_fd = fds[0]; + wp.worker_write_fd = fds[1]; + fcntl(wp.worker_write_fd, F_SETPIPE_SZ, kPipeBufferSize); + + if (pipe(fds) < 0) { + std::cerr << "[Server] Failed to create incumbent pipe for worker " << worker_id << "\n"; + close_all_worker_pipes(wp); + return false; + } + wp.incumbent_from_worker_fd = fds[0]; + wp.worker_incumbent_write_fd = fds[1]; + + return true; +} + +void close_worker_pipes_server(int worker_id) +{ + if (worker_id < 0 || worker_id >= static_cast(worker_pipes.size())) return; + + WorkerPipes& wp = worker_pipes[worker_id]; + close_and_reset(wp.to_worker_fd); + close_and_reset(wp.from_worker_fd); + close_and_reset(wp.incumbent_from_worker_fd); +} + +void close_worker_pipes_child_ends(int worker_id) +{ + if (worker_id < 0 || worker_id >= static_cast(worker_pipes.size())) return; + + WorkerPipes& wp = worker_pipes[worker_id]; + close_and_reset(wp.worker_read_fd); + close_and_reset(wp.worker_write_fd); + close_and_reset(wp.worker_incumbent_write_fd); +} + +pid_t spawn_worker(int worker_id, bool is_replacement) +{ + std::lock_guard lock(worker_pipes_mutex); + + if (is_replacement) { close_worker_pipes_server(worker_id); } + + if (!create_worker_pipes(worker_id)) { + std::cerr << "[Server] Failed to create pipes for " + << (is_replacement ? "replacement worker " : "worker ") << worker_id << "\n"; + return -1; + } + + pid_t pid = fork(); + if (pid < 0) { + std::cerr << "[Server] Failed to fork " << (is_replacement ? 
"replacement worker " : "worker ") + << worker_id << "\n"; + close_all_worker_pipes(worker_pipes[worker_id]); + return -1; + } else if (pid == 0) { + // Child: close all fds belonging to other workers. + for (int j = 0; j < static_cast(worker_pipes.size()); ++j) { + if (j != worker_id) { close_all_worker_pipes(worker_pipes[j]); } + } + // Close the server-side ends of this worker's pipes (child uses the other ends). + close_and_reset(worker_pipes[worker_id].to_worker_fd); + close_and_reset(worker_pipes[worker_id].from_worker_fd); + close_and_reset(worker_pipes[worker_id].incumbent_from_worker_fd); + worker_process(worker_id); + _exit(0); + } + + close_worker_pipes_child_ends(worker_id); + return pid; +} + +void spawn_workers() +{ + for (int i = 0; i < config.num_workers; ++i) { + pid_t pid = spawn_worker(i, false); + if (pid < 0) { continue; } + worker_pids.push_back(pid); + } +} + +void wait_for_workers() +{ + for (pid_t pid : worker_pids) { + if (pid <= 0) continue; + int status; + while (waitpid(pid, &status, 0) < 0 && errno == EINTR) {} + } + worker_pids.clear(); +} + +pid_t spawn_single_worker(int worker_id) { return spawn_worker(worker_id, true); } + +// Called by the worker-monitor thread when waitpid() detects a dead worker. +// Scans the shared-memory job queue for any job that was assigned to the dead +// worker and transitions it to FAILED (or CANCELLED if it was a user-initiated +// cancel that killed the worker). Three data structures must be updated: +// 1. pending_job_data — discard the serialized request bytes +// 2. result_queue — post a synthetic error result so the client unblocks +// 3. 
//      job_queue + job_tracker — mark the slot free and record final status
void mark_worker_jobs_failed(pid_t dead_worker_pid)
{
  for (size_t i = 0; i < MAX_JOBS; ++i) {
    // Only touch slots the dead worker had actually claimed.
    if (job_queue[i].ready && job_queue[i].claimed && job_queue[i].worker_pid == dead_worker_pid) {
      std::string job_id(job_queue[i].job_id);
      bool was_cancelled = job_queue[i].cancelled;

      if (was_cancelled) {
        std::cerr << "[Server] Worker " << dead_worker_pid
                  << " killed for cancelled job: " << job_id << "\n";
      } else {
        std::cerr << "[Server] Worker " << dead_worker_pid
                  << " died while processing job: " << job_id << "\n";
      }

      // 1. Drop the buffered request data (no longer needed).
      {
        std::lock_guard lock(pending_data_mutex);
        pending_job_data.erase(job_id);
      }

      // 2. Post a synthetic error result into the first free result_queue slot
      //    so that any client polling for results gets a clear failure message.
      //    Uses the same CAS protocol as store_simple_result (see comment there).
      for (size_t j = 0; j < MAX_RESULTS; ++j) {
        if (result_queue[j].ready.load(std::memory_order_acquire)) continue;
        bool exp = false;
        if (!result_queue[j].claimed.compare_exchange_strong(
              exp, true, std::memory_order_acq_rel)) {
          continue;
        }
        if (result_queue[j].ready.load(std::memory_order_acquire)) {
          result_queue[j].claimed.store(false, std::memory_order_release);
          continue;
        }
        copy_cstr(result_queue[j].job_id, job_id);
        result_queue[j].status    = was_cancelled ? RESULT_CANCELLED : RESULT_ERROR;
        result_queue[j].data_size = 0;  // no payload — the worker is gone
        result_queue[j].worker_index.store(-1, std::memory_order_relaxed);
        copy_cstr(result_queue[j].error_message,
                  was_cancelled ? "Job was cancelled" : "Worker process died unexpectedly");
        result_queue[j].retrieved.store(false, std::memory_order_relaxed);
        result_queue[j].ready.store(true, std::memory_order_release);
        result_queue[j].claimed.store(false, std::memory_order_release);
        break;
      }

      // 3. Release the job queue slot and update the in-process job tracker.
      job_queue[i].worker_pid   = 0;
      job_queue[i].worker_index = -1;
      job_queue[i].data_sent    = false;
      job_queue[i].is_chunked   = false;
      job_queue[i].ready        = false;
      job_queue[i].claimed      = false;
      job_queue[i].cancelled    = false;

      {
        std::lock_guard lock(tracker_mutex);
        auto it = job_tracker.find(job_id);
        if (it != job_tracker.end()) {
          if (was_cancelled) {
            it->second.status        = JobStatus::CANCELLED;
            it->second.error_message = "Job was cancelled";
          } else {
            it->second.status        = JobStatus::FAILED;
            it->second.error_message = "Worker process died unexpectedly";
          }
        }
      }
    }
  }
}

#endif  // CUOPT_ENABLE_GRPC
diff --git a/cpp/src/pdlp/CMakeLists.txt b/cpp/src/pdlp/CMakeLists.txt
index 30fc3cd3ff..f5f26837b6 100644
--- a/cpp/src/pdlp/CMakeLists.txt
+++ b/cpp/src/pdlp/CMakeLists.txt
@@ -11,7 +11,6 @@ set(LP_CORE_FILES
     ${CMAKE_CURRENT_SOURCE_DIR}/backend_selection.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/utilities/problem_checking.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/solve.cu
-    ${CMAKE_CURRENT_SOURCE_DIR}/solve_remote.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/pdlp.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/pdhg.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/solver_solution.cu
diff --git a/cpp/src/pdlp/solve_remote.cu b/cpp/src/pdlp/solve_remote.cu
deleted file mode 100644
index a9bf7e3989..0000000000
--- a/cpp/src/pdlp/solve_remote.cu
+++ /dev/null
@@ -1,120 +0,0 @@
-/* clang-format off */
-/*
- * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0 - */ -/* clang-format on */ - -#include -#include -#include -#include -#include - -namespace cuopt::linear_programming { - -// ============================================================================ -// Remote execution stubs (placeholder implementations) -// ============================================================================ - -template -std::unique_ptr> solve_lp_remote( - cpu_optimization_problem_t const& cpu_problem, - pdlp_solver_settings_t const& settings, - bool problem_checking, - bool use_pdlp_solver_mode) -{ - init_logger_t log(settings.log_file, settings.log_to_console); - CUOPT_LOG_INFO( - "solve_lp_remote (CPU problem) stub called - returning dummy solution for testing"); - - // TODO: Implement actual remote LP solving via gRPC - // For now, return a dummy solution with fake data (allows testing the full flow) - i_t n_vars = cpu_problem.get_n_variables(); - i_t n_constraints = cpu_problem.get_n_constraints(); - - std::vector primal_solution(n_vars, 0.0); - std::vector dual_solution(n_constraints, 0.0); - std::vector reduced_cost(n_vars, 0.0); - - // Create fake warm start data struct with recognizable non-zero values for testing - cpu_pdlp_warm_start_data_t warmstart; - warmstart.current_primal_solution_ = std::vector(n_vars, 1.1); - warmstart.current_dual_solution_ = std::vector(n_constraints, 2.2); - warmstart.initial_primal_average_ = std::vector(n_vars, 3.3); - warmstart.initial_dual_average_ = std::vector(n_constraints, 4.4); - warmstart.current_ATY_ = std::vector(n_vars, 5.5); - warmstart.sum_primal_solutions_ = std::vector(n_vars, 6.6); - warmstart.sum_dual_solutions_ = std::vector(n_constraints, 7.7); - warmstart.last_restart_duality_gap_primal_solution_ = std::vector(n_vars, 8.8); - warmstart.last_restart_duality_gap_dual_solution_ = std::vector(n_constraints, 9.9); - warmstart.initial_primal_weight_ = 99.1; - warmstart.initial_step_size_ = 99.2; - warmstart.total_pdlp_iterations_ = 100; - 
warmstart.total_pdhg_iterations_ = 200; - warmstart.last_candidate_kkt_score_ = 99.3; - warmstart.last_restart_kkt_score_ = 99.4; - warmstart.sum_solution_weight_ = 99.5; - warmstart.iterations_since_last_restart_ = 10; - - auto solution = std::make_unique>( - std::move(primal_solution), - std::move(dual_solution), - std::move(reduced_cost), - pdlp_termination_status_t::Optimal, // Fake optimal status - 0.0, // Primal objective (zero solution) - 0.0, // Dual objective (zero solution) - 0.01, // Dummy solve time - 0.001, // l2_primal_residual - 0.002, // l2_dual_residual - 0.003, // gap - 42, // num_iterations - true, // solved_by_pdlp - std::move(warmstart) // warmstart data - ); - - return solution; -} - -template -std::unique_ptr> solve_mip_remote( - cpu_optimization_problem_t const& cpu_problem, - mip_solver_settings_t const& settings) -{ - init_logger_t log(settings.log_file, settings.log_to_console); - CUOPT_LOG_INFO( - "solve_mip_remote (CPU problem) stub called - returning dummy solution for testing"); - - // TODO: Implement actual remote MIP solving via gRPC - // For now, return a dummy solution with fake data (allows testing the full flow) - i_t n_vars = cpu_problem.get_n_variables(); - - std::vector solution(n_vars, 0.0); - auto mip_solution = std::make_unique>( - std::move(solution), - mip_termination_status_t::Optimal, // Fake optimal status - 0.0, // Objective value (zero solution) - 0.0, // MIP gap - 0.0, // Solution bound - 0.01, // Total solve time - 0.0, // Presolve time - 0.0, // Max constraint violation - 0.0, // Max int violation - 0.0, // Max variable bound violation - 0, // Number of nodes - 0); // Number of simplex iterations - - return mip_solution; -} - -// Explicit template instantiations for remote execution stubs -template std::unique_ptr> solve_lp_remote( - cpu_optimization_problem_t const&, - pdlp_solver_settings_t const&, - bool, - bool); - -template std::unique_ptr> solve_mip_remote( - cpu_optimization_problem_t const&, 
mip_solver_settings_t const&); - -} // namespace cuopt::linear_programming diff --git a/cpp/tests/linear_programming/CMakeLists.txt b/cpp/tests/linear_programming/CMakeLists.txt index 434af9ed39..677ae8cb70 100644 --- a/cpp/tests/linear_programming/CMakeLists.txt +++ b/cpp/tests/linear_programming/CMakeLists.txt @@ -72,3 +72,9 @@ if (NOT SKIP_C_PYTHON_ADAPTERS) EXCLUDE_FROM_ALL ) endif() + +# ################################################################################################## +# - gRPC Tests ------------------------------------------------------------------------------------- +if(CUOPT_ENABLE_GRPC) + add_subdirectory(grpc) +endif() diff --git a/cpp/tests/linear_programming/c_api_tests/c_api_tests.cpp b/cpp/tests/linear_programming/c_api_tests/c_api_tests.cpp index d39a970763..0f7a3200a0 100644 --- a/cpp/tests/linear_programming/c_api_tests/c_api_tests.cpp +++ b/cpp/tests/linear_programming/c_api_tests/c_api_tests.cpp @@ -323,77 +323,233 @@ TEST(c_api, mip_solution_lp_methods) { EXPECT_EQ(test_mip_solution_lp_methods(), // ============================================================================= // CPU-Only Execution Tests // These tests verify that cuOpt can run on a CPU-only host with remote execution -// enabled. The remote solve stubs return dummy results. +// enabled, forwarding solves to a real cuopt_grpc_server over gRPC. +// +// A single shared server is started once for all tests in this fixture +// (SetUpTestSuite / TearDownTestSuite) to avoid per-test startup overhead. 
// =============================================================================
-// Helper to set environment variables for CPU-only mode
-class CPUOnlyTestEnvironment {
- public:
-  CPUOnlyTestEnvironment()
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <signal.h>
+#include <sys/socket.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <chrono>
+#include <thread>
+
+namespace {
+
+std::string find_in_path(const std::string& name)
+{
+  const char* path_env = std::getenv("PATH");
+  if (!path_env) return "";
+
+  std::string path_str(path_env);
+  std::string::size_type start = 0;
+  std::string::size_type end;
+
+  while ((end = path_str.find(':', start)) != std::string::npos || start < path_str.size()) {
+    std::string dir;
+    if (end != std::string::npos) {
+      dir = path_str.substr(start, end - start);
+      start = end + 1;
+    } else {
+      dir = path_str.substr(start);
+      start = path_str.size();
+    }
+    if (dir.empty()) continue;
+    std::string full_path = dir + "/" + name;
+    if (access(full_path.c_str(), X_OK) == 0) { return full_path; }
+  }
+  return "";
+}
+
+std::string find_server_binary()
+{
+  const char* env_path = std::getenv("CUOPT_GRPC_SERVER_PATH");
+  if (env_path && access(env_path, X_OK) == 0) { return env_path; }
+
+  std::string path_result = find_in_path("cuopt_grpc_server");
+  if (!path_result.empty()) { return path_result; }
+
+  std::vector<std::string> paths = {
+    "./cuopt_grpc_server",
+    "../cuopt_grpc_server",
+    "../../cuopt_grpc_server",
+    "./build/cuopt_grpc_server",
+    "../build/cuopt_grpc_server",
+  };
+  for (const auto& path : paths) {
+    if (access(path.c_str(), X_OK) == 0) { return path; }
+  }
+  return "";
+}
+
+bool tcp_connect_check(int port, int timeout_ms)
+{
+  auto start = std::chrono::steady_clock::now();
+  while (true) {
+    int sock = socket(AF_INET, SOCK_STREAM, 0);
+    if (sock < 0) return false;
+
+    struct sockaddr_in addr{};
+    addr.sin_family = AF_INET;
+    addr.sin_port = htons(port);
+    addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+
+    if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == 0) {
+      close(sock);
+      return true;
+    }
+    close(sock);
+
+    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+      std::chrono::steady_clock::now() - start);
+    if (elapsed.count() >= timeout_ms) return false;
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
+  }
+}
+
+}  // namespace
+
+class CpuOnlyWithServerTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite()
   {
-    // Save original values
-    const char* cuda_visible = getenv("CUDA_VISIBLE_DEVICES");
-    const char* remote_host = getenv("CUOPT_REMOTE_HOST");
-    const char* remote_port = getenv("CUOPT_REMOTE_PORT");
-
-    orig_cuda_visible_ = cuda_visible ? cuda_visible : "";
-    orig_remote_host_ = remote_host ? remote_host : "";
-    orig_remote_port_ = remote_port ? remote_port : "";
-    cuda_was_set_ = (cuda_visible != nullptr);
-    host_was_set_ = (remote_host != nullptr);
-    port_was_set_ = (remote_port != nullptr);
-
-    // Set CPU-only environment
+    server_path_ = find_server_binary();
+    if (server_path_.empty()) {
+      skip_reason_ = "cuopt_grpc_server binary not found";
+      return;
+    }
+
+    port_ = 18500;
+    const char* env_base = std::getenv("CUOPT_TEST_PORT_BASE");
+    if (env_base) { port_ = std::atoi(env_base) + 500; }
+
+    server_pid_ = fork();
+    if (server_pid_ < 0) {
+      skip_reason_ = "fork() failed";
+      return;
+    }
+
+    if (server_pid_ == 0) {
+      std::string port_str = std::to_string(port_);
+      std::string log_file = "/tmp/cuopt_c_api_test_server_" + port_str + ".log";
+      int fd = open(log_file.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
+      if (fd >= 0) {
+        dup2(fd, STDOUT_FILENO);
+        dup2(fd, STDERR_FILENO);
+        close(fd);
+      }
+      execl(server_path_.c_str(),
+            server_path_.c_str(),
+            "--port",
+            port_str.c_str(),
+            "--workers",
+            "1",
+            nullptr);
+      _exit(127);
+    }
+
+    if (!tcp_connect_check(port_, 15000)) {
+      skip_reason_ = "cuopt_grpc_server failed to start within 15 seconds";
+      kill(server_pid_, SIGKILL);
+      waitpid(server_pid_, nullptr, 0);
+      server_pid_ = -1;
+      return;
+    }
+
+    const char* cv = getenv("CUDA_VISIBLE_DEVICES");
+    const char* 
rh = getenv("CUOPT_REMOTE_HOST"); + const char* rp = getenv("CUOPT_REMOTE_PORT"); + orig_cuda_visible_ = cv ? cv : ""; + orig_remote_host_ = rh ? rh : ""; + orig_remote_port_ = rp ? rp : ""; + cuda_was_set_ = (cv != nullptr); + host_was_set_ = (rh != nullptr); + port_was_set_ = (rp != nullptr); + setenv("CUDA_VISIBLE_DEVICES", "", 1); setenv("CUOPT_REMOTE_HOST", "localhost", 1); - setenv("CUOPT_REMOTE_PORT", "12345", 1); + setenv("CUOPT_REMOTE_PORT", std::to_string(port_).c_str(), 1); } - ~CPUOnlyTestEnvironment() + static void TearDownTestSuite() { - // Restore original values if (cuda_was_set_) { setenv("CUDA_VISIBLE_DEVICES", orig_cuda_visible_.c_str(), 1); } else { unsetenv("CUDA_VISIBLE_DEVICES"); } - if (host_was_set_) { setenv("CUOPT_REMOTE_HOST", orig_remote_host_.c_str(), 1); } else { unsetenv("CUOPT_REMOTE_HOST"); } - if (port_was_set_) { setenv("CUOPT_REMOTE_PORT", orig_remote_port_.c_str(), 1); } else { unsetenv("CUOPT_REMOTE_PORT"); } + + if (server_pid_ > 0) { + kill(server_pid_, SIGTERM); + int status; + int wait_ms = 0; + while (wait_ms < 5000) { + if (waitpid(server_pid_, &status, WNOHANG) != 0) break; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + wait_ms += 100; + } + if (waitpid(server_pid_, &status, WNOHANG) == 0) { + kill(server_pid_, SIGKILL); + waitpid(server_pid_, &status, 0); + } + server_pid_ = -1; + } + } + + void SetUp() override + { + if (!skip_reason_.empty()) { GTEST_SKIP() << skip_reason_; } } - private: - std::string orig_cuda_visible_; - std::string orig_remote_host_; - std::string orig_remote_port_; - bool cuda_was_set_; - bool host_was_set_; - bool port_was_set_; + static std::string server_path_; + static std::string skip_reason_; + static pid_t server_pid_; + static int port_; + + static std::string orig_cuda_visible_; + static std::string orig_remote_host_; + static std::string orig_remote_port_; + static bool cuda_was_set_; + static bool host_was_set_; + static bool port_was_set_; }; -// TODO: Add numerical 
assertions once gRPC remote solver replaces the stub implementation. -// Currently validates that the CPU-only C API path completes without errors. -TEST(c_api_cpu_only, lp_solve) +std::string CpuOnlyWithServerTest::server_path_; +std::string CpuOnlyWithServerTest::skip_reason_; +pid_t CpuOnlyWithServerTest::server_pid_ = -1; +int CpuOnlyWithServerTest::port_ = 0; +std::string CpuOnlyWithServerTest::orig_cuda_visible_; +std::string CpuOnlyWithServerTest::orig_remote_host_; +std::string CpuOnlyWithServerTest::orig_remote_port_; +bool CpuOnlyWithServerTest::cuda_was_set_ = false; +bool CpuOnlyWithServerTest::host_was_set_ = false; +bool CpuOnlyWithServerTest::port_was_set_ = false; + +TEST_F(CpuOnlyWithServerTest, lp_solve) { - CPUOnlyTestEnvironment env; const std::string& rapidsDatasetRootDir = cuopt::test::get_rapids_dataset_root_dir(); std::string lp_file = rapidsDatasetRootDir + "/linear_programming/afiro_original.mps"; EXPECT_EQ(test_cpu_only_execution(lp_file.c_str()), CUOPT_SUCCESS); } -// TODO: Add numerical assertions once gRPC remote solver replaces the stub implementation. -TEST(c_api_cpu_only, mip_solve) +TEST_F(CpuOnlyWithServerTest, mip_solve) { - CPUOnlyTestEnvironment env; const std::string& rapidsDatasetRootDir = cuopt::test::get_rapids_dataset_root_dir(); std::string mip_file = rapidsDatasetRootDir + "/mip/bb_optimality.mps"; EXPECT_EQ(test_cpu_only_mip_execution(mip_file.c_str()), CUOPT_SUCCESS); diff --git a/cpp/tests/linear_programming/grpc/CMakeLists.txt b/cpp/tests/linear_programming/grpc/CMakeLists.txt new file mode 100644 index 0000000000..3a84ddd0bd --- /dev/null +++ b/cpp/tests/linear_programming/grpc/CMakeLists.txt @@ -0,0 +1,130 @@ +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on + +# ################################################################################################## +# - gRPC Client Unit Tests ------------------------------------------------------------------------- +# These tests use mock stubs and don't require a running server + +add_executable(GRPC_CLIENT_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/grpc_client_test.cpp +) + +target_include_directories(GRPC_CLIENT_TEST + PRIVATE + "${CUOPT_SOURCE_DIR}/include" + "${CUOPT_SOURCE_DIR}/src/grpc" + "${CUOPT_SOURCE_DIR}/src/grpc/client" + "${CUOPT_TEST_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}" # For grpc_client_test_helper.hpp + "${CMAKE_BINARY_DIR}" # For generated protobuf headers +) + +target_link_libraries(GRPC_CLIENT_TEST + PRIVATE + cuopt + GTest::gmock + GTest::gmock_main + gRPC::grpc++ + protobuf::libprotobuf +) + +if(NOT DEFINED INSTALL_TARGET OR "${INSTALL_TARGET}" STREQUAL "") + target_link_options(GRPC_CLIENT_TEST PRIVATE -Wl,--enable-new-dtags) +endif() + +add_test(NAME GRPC_CLIENT_TEST COMMAND GRPC_CLIENT_TEST) + +install( + TARGETS GRPC_CLIENT_TEST + COMPONENT testing + DESTINATION bin/gtests/libcuopt + EXCLUDE_FROM_ALL +) + +# ################################################################################################## +# - gRPC Pipe Serialization Unit Tests ------------------------------------------------------------- +# These tests verify pipe serialization round-trips using real pipe(2) FDs. +# No running server is required; pipe I/O functions are self-contained in the test. 
+ +add_executable(GRPC_PIPE_SERIALIZATION_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/grpc_pipe_serialization_test.cpp + "${CUOPT_SOURCE_DIR}/src/grpc/server/grpc_pipe_io.cpp" +) + +target_include_directories(GRPC_PIPE_SERIALIZATION_TEST + PRIVATE + "${CUOPT_SOURCE_DIR}/include" + "${CUOPT_SOURCE_DIR}/src/grpc" + "${CUOPT_SOURCE_DIR}/src/grpc/server" + "${CMAKE_BINARY_DIR}" # For generated protobuf headers +) + +target_link_libraries(GRPC_PIPE_SERIALIZATION_TEST + PRIVATE + cuopt + GTest::gtest + GTest::gtest_main + protobuf::libprotobuf +) + +if(NOT DEFINED INSTALL_TARGET OR "${INSTALL_TARGET}" STREQUAL "") + target_link_options(GRPC_PIPE_SERIALIZATION_TEST PRIVATE -Wl,--enable-new-dtags) +endif() + +add_test(NAME GRPC_PIPE_SERIALIZATION_TEST COMMAND GRPC_PIPE_SERIALIZATION_TEST) + +install( + TARGETS GRPC_PIPE_SERIALIZATION_TEST + COMPONENT testing + DESTINATION bin/gtests/libcuopt + EXCLUDE_FROM_ALL +) + +# ################################################################################################## +# - gRPC Integration Tests ------------------------------------------------------------------------- +# These tests start a real server process and run end-to-end tests + +add_executable(GRPC_INTEGRATION_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/grpc_integration_test.cpp +) + +target_include_directories(GRPC_INTEGRATION_TEST + PRIVATE + "${CUOPT_SOURCE_DIR}/include" + "${CUOPT_SOURCE_DIR}/src/grpc" + "${CUOPT_SOURCE_DIR}/src/grpc/client" + "${CUOPT_TEST_DIR}" + "${CMAKE_BINARY_DIR}" # For generated protobuf headers +) + +target_link_libraries(GRPC_INTEGRATION_TEST + PRIVATE + cuopt + GTest::gmock + GTest::gmock_main + gRPC::grpc++ + protobuf::libprotobuf +) + +if(NOT DEFINED INSTALL_TARGET OR "${INSTALL_TARGET}" STREQUAL "") + target_link_options(GRPC_INTEGRATION_TEST PRIVATE -Wl,--enable-new-dtags) +endif() + +# Integration tests need the server binary to be built first +add_dependencies(GRPC_INTEGRATION_TEST cuopt_grpc_server) + +add_test( + NAME GRPC_INTEGRATION_TEST + 
COMMAND ${CMAKE_COMMAND} -E env
          "CUOPT_GRPC_SERVER_PATH=$<TARGET_FILE:cuopt_grpc_server>"
          $<TARGET_FILE:GRPC_INTEGRATION_TEST>
+)
+
+install(
+  TARGETS GRPC_INTEGRATION_TEST
+  COMPONENT testing
+  DESTINATION bin/gtests/libcuopt
+  EXCLUDE_FROM_ALL
+)
diff --git a/cpp/tests/linear_programming/grpc/grpc_client_test.cpp b/cpp/tests/linear_programming/grpc/grpc_client_test.cpp
new file mode 100644
index 0000000000..46a18dc026
--- /dev/null
+++ b/cpp/tests/linear_programming/grpc/grpc_client_test.cpp
@@ -0,0 +1,1634 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: Apache-2.0
+ */
+
+/**
+ * @file grpc_client_test.cpp
+ * @brief Unit tests for grpc_client_t using mock stubs
+ *
+ * These tests verify client-side error handling without requiring a real server.
+ * For integration tests with a real server, see grpc_integration_test.cpp.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "grpc_client_test_helper.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include "grpc_client.hpp"
+#include "grpc_service_mapper.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+using namespace cuopt::linear_programming;
+using namespace ::testing;
+
+/**
+ * @brief Mock stub for CuOptRemoteService
+ *
+ * This mock allows us to control exactly what the "server" returns
+ * without running an actual server.
+ */ +class MockCuOptStub : public cuopt::remote::CuOptRemoteService::StubInterface { + public: + // Unary RPCs + MOCK_METHOD(grpc::Status, + SubmitJob, + (grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + CheckStatus, + (grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + GetResult, + (grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + DeleteResult, + (grpc::ClientContext*, + const cuopt::remote::DeleteRequest&, + cuopt::remote::DeleteResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + CancelJob, + (grpc::ClientContext*, + const cuopt::remote::CancelRequest&, + cuopt::remote::CancelResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + WaitForCompletion, + (grpc::ClientContext*, + const cuopt::remote::WaitRequest&, + cuopt::remote::WaitResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + GetIncumbents, + (grpc::ClientContext*, + const cuopt::remote::IncumbentRequest&, + cuopt::remote::IncumbentResponse*), + (override)); + + // Streaming RPCs - these need special handling + // Chunked result download RPCs + MOCK_METHOD(grpc::Status, + StartChunkedDownload, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedDownloadRequest&, + cuopt::remote::StartChunkedDownloadResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + GetResultChunk, + (grpc::ClientContext*, + const cuopt::remote::GetResultChunkRequest&, + cuopt::remote::GetResultChunkResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + FinishChunkedDownload, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedDownloadRequest&, + cuopt::remote::FinishChunkedDownloadResponse*), + (override)); + + MOCK_METHOD(grpc::ClientReaderInterface*, + StreamLogsRaw, + (grpc::ClientContext*, const 
cuopt::remote::StreamLogsRequest&), + (override)); + + // Chunked upload RPCs + MOCK_METHOD(grpc::Status, + StartChunkedUpload, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedUploadRequest&, + cuopt::remote::StartChunkedUploadResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + SendArrayChunk, + (grpc::ClientContext*, + const cuopt::remote::SendArrayChunkRequest&, + cuopt::remote::SendArrayChunkResponse*), + (override)); + + MOCK_METHOD(grpc::Status, + FinishChunkedUpload, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedUploadRequest&, + cuopt::remote::SubmitJobResponse*), + (override)); + + // Required by interface - async versions (not used in our client but required for interface) + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncSubmitJobRaw, + (grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncSubmitJobRaw, + (grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncCheckStatusRaw, + (grpc::ClientContext*, const cuopt::remote::StatusRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncCheckStatusRaw, + (grpc::ClientContext*, const cuopt::remote::StatusRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncGetResultRaw, + (grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncGetResultRaw, + (grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncDeleteResultRaw, + (grpc::ClientContext*, const 
cuopt::remote::DeleteRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncDeleteResultRaw, + (grpc::ClientContext*, const cuopt::remote::DeleteRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncCancelJobRaw, + (grpc::ClientContext*, const cuopt::remote::CancelRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncCancelJobRaw, + (grpc::ClientContext*, const cuopt::remote::CancelRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncWaitForCompletionRaw, + (grpc::ClientContext*, const cuopt::remote::WaitRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncWaitForCompletionRaw, + (grpc::ClientContext*, const cuopt::remote::WaitRequest&, grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncGetIncumbentsRaw, + (grpc::ClientContext*, + const cuopt::remote::IncumbentRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncGetIncumbentsRaw, + (grpc::ClientContext*, + const cuopt::remote::IncumbentRequest&, + grpc::CompletionQueue*), + (override)); + + // Async chunked result download RPCs + MOCK_METHOD( + grpc::ClientAsyncResponseReaderInterface*, + AsyncStartChunkedDownloadRaw, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedDownloadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD( + grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncStartChunkedDownloadRaw, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedDownloadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncGetResultChunkRaw, + (grpc::ClientContext*, + 
const cuopt::remote::GetResultChunkRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncGetResultChunkRaw, + (grpc::ClientContext*, + const cuopt::remote::GetResultChunkRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD( + grpc::ClientAsyncResponseReaderInterface*, + AsyncFinishChunkedDownloadRaw, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedDownloadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD( + grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncFinishChunkedDownloadRaw, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedDownloadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD( + grpc::ClientAsyncReaderInterface*, + AsyncStreamLogsRaw, + (grpc::ClientContext*, const cuopt::remote::StreamLogsRequest&, grpc::CompletionQueue*, void*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncReaderInterface*, + PrepareAsyncStreamLogsRaw, + (grpc::ClientContext*, + const cuopt::remote::StreamLogsRequest&, + grpc::CompletionQueue*), + (override)); + + // Async chunked upload RPCs + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncStartChunkedUploadRaw, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedUploadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncStartChunkedUploadRaw, + (grpc::ClientContext*, + const cuopt::remote::StartChunkedUploadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncSendArrayChunkRaw, + (grpc::ClientContext*, + const cuopt::remote::SendArrayChunkRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncSendArrayChunkRaw, + (grpc::ClientContext*, + const cuopt::remote::SendArrayChunkRequest&, + grpc::CompletionQueue*), + (override)); + + 
MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + AsyncFinishChunkedUploadRaw, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedUploadRequest&, + grpc::CompletionQueue*), + (override)); + + MOCK_METHOD(grpc::ClientAsyncResponseReaderInterface*, + PrepareAsyncFinishChunkedUploadRaw, + (grpc::ClientContext*, + const cuopt::remote::FinishChunkedUploadRequest&, + grpc::CompletionQueue*), + (override)); +}; + +/** + * @brief Test fixture for grpc_client_t tests with mock stub injection + */ +class GrpcClientTest : public ::testing::Test { + protected: + std::shared_ptr> mock_stub_; + std::unique_ptr client_; + + void SetUp() override + { + mock_stub_ = std::make_shared>(); + + // Create a client and inject the mock stub + grpc_client_config_t config; + config.server_address = "mock://test"; + client_ = std::make_unique(config); + + // Inject the mock stub using typed helper + grpc_test_inject_mock_stub_typed(*client_, mock_stub_); + } + + void TearDown() override + { + client_.reset(); + mock_stub_.reset(); + } +}; + +// ============================================================================= +// CheckStatus Tests +// ============================================================================= + +TEST_F(GrpcClientTest, CheckStatus_Success_Completed) +{ + // Setup mock to return COMPLETED status + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest& req, + cuopt::remote::StatusResponse* resp) { + EXPECT_EQ(req.job_id(), "test-job-123"); + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_message("Job completed successfully"); + resp->set_result_size_bytes(1024); + return grpc::Status::OK; + }); + + auto result = client_->check_status("test-job-123"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.status, job_status_t::COMPLETED); + EXPECT_EQ(result.message, "Job completed successfully"); + EXPECT_EQ(result.result_size_bytes, 1024); +} + +TEST_F(GrpcClientTest, 
CheckStatus_Success_Processing) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::PROCESSING); + resp->set_message("Solving..."); + return grpc::Status::OK; + }); + + auto result = client_->check_status("test-job-456"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.status, job_status_t::PROCESSING); +} + +TEST_F(GrpcClientTest, CheckStatus_JobNotFound) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::NOT_FOUND); + resp->set_message("Job not found"); + return grpc::Status::OK; + }); + + auto result = client_->check_status("nonexistent-job"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.status, job_status_t::NOT_FOUND); +} + +TEST_F(GrpcClientTest, CheckStatus_RpcFailure_Unavailable) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse*) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Server unavailable"); + }); + + auto result = client_->check_status("test-job"); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Server unavailable") != std::string::npos); +} + +TEST_F(GrpcClientTest, CheckStatus_RpcFailure_Internal) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse*) { + return grpc::Status(grpc::StatusCode::INTERNAL, "Internal server error"); + }); + + auto result = client_->check_status("test-job"); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Internal server error") != std::string::npos); +} + +// 
============================================================================= +// CancelJob Tests +// ============================================================================= + +TEST_F(GrpcClientTest, CancelJob_Success) +{ + EXPECT_CALL(*mock_stub_, CancelJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::CancelRequest& req, + cuopt::remote::CancelResponse* resp) { + EXPECT_EQ(req.job_id(), "job-to-cancel"); + resp->set_job_status(cuopt::remote::CANCELLED); + resp->set_message("Job cancelled"); + return grpc::Status::OK; + }); + + auto result = client_->cancel_job("job-to-cancel"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.job_status, job_status_t::CANCELLED); +} + +TEST_F(GrpcClientTest, CancelJob_AlreadyCompleted) +{ + EXPECT_CALL(*mock_stub_, CancelJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::CancelRequest&, + cuopt::remote::CancelResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_message("Job already completed"); + return grpc::Status::OK; + }); + + auto result = client_->cancel_job("completed-job"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.job_status, job_status_t::COMPLETED); +} + +TEST_F(GrpcClientTest, CancelJob_RpcFailure) +{ + EXPECT_CALL(*mock_stub_, CancelJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::CancelRequest&, + cuopt::remote::CancelResponse*) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Server down"); + }); + + auto result = client_->cancel_job("job-id"); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Server down") != std::string::npos); +} + +// ============================================================================= +// DeleteJob Tests +// ============================================================================= + +TEST_F(GrpcClientTest, DeleteJob_Success) +{ + EXPECT_CALL(*mock_stub_, DeleteResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const 
cuopt::remote::DeleteRequest& req, + cuopt::remote::DeleteResponse* resp) { + EXPECT_EQ(req.job_id(), "job-to-delete"); + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + bool result = client_->delete_job("job-to-delete"); + + EXPECT_TRUE(result); +} + +TEST_F(GrpcClientTest, DeleteJob_NotFound) +{ + EXPECT_CALL(*mock_stub_, DeleteResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::DeleteRequest&, + cuopt::remote::DeleteResponse* resp) { + resp->set_status(cuopt::remote::ERROR_NOT_FOUND); + return grpc::Status::OK; + }); + + bool result = client_->delete_job("nonexistent-job"); + + // Job not found should return false to prevent silent failures + EXPECT_FALSE(result); +} + +TEST_F(GrpcClientTest, DeleteJob_RpcFailure) +{ + EXPECT_CALL(*mock_stub_, DeleteResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::DeleteRequest&, + cuopt::remote::DeleteResponse*) { + return grpc::Status(grpc::StatusCode::INTERNAL, "Delete failed"); + }); + + bool result = client_->delete_job("job-id"); + + EXPECT_FALSE(result); +} + +// ============================================================================= +// WaitForCompletion Tests +// ============================================================================= + +TEST_F(GrpcClientTest, WaitForCompletion_Success) +{ + EXPECT_CALL(*mock_stub_, WaitForCompletion(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::WaitRequest& req, + cuopt::remote::WaitResponse* resp) { + EXPECT_EQ(req.job_id(), "wait-job"); + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_message("Done"); + resp->set_result_size_bytes(2048); + return grpc::Status::OK; + }); + + auto result = client_->wait_for_completion("wait-job"); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.status, job_status_t::COMPLETED); + EXPECT_EQ(result.result_size_bytes, 2048); +} + +TEST_F(GrpcClientTest, WaitForCompletion_Failed) +{ + EXPECT_CALL(*mock_stub_, 
WaitForCompletion(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::WaitRequest&, + cuopt::remote::WaitResponse* resp) { + resp->set_job_status(cuopt::remote::FAILED); + resp->set_message("Solve failed: out of memory"); + return grpc::Status::OK; + }); + + auto result = client_->wait_for_completion("failed-job"); + + EXPECT_TRUE(result.success); // RPC succeeded, job failed + EXPECT_EQ(result.status, job_status_t::FAILED); + EXPECT_TRUE(result.message.find("out of memory") != std::string::npos); +} + +TEST_F(GrpcClientTest, WaitForCompletion_RpcTimeout) +{ + EXPECT_CALL(*mock_stub_, WaitForCompletion(_, _, _)) + .WillOnce( + [](grpc::ClientContext*, const cuopt::remote::WaitRequest&, cuopt::remote::WaitResponse*) { + return grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, "Deadline exceeded"); + }); + + auto result = client_->wait_for_completion("timeout-job"); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Deadline exceeded") != std::string::npos); +} + +// ============================================================================= +// GetIncumbents Tests +// ============================================================================= + +TEST_F(GrpcClientTest, GetIncumbents_Success) +{ + EXPECT_CALL(*mock_stub_, GetIncumbents(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::IncumbentRequest& req, + cuopt::remote::IncumbentResponse* resp) { + EXPECT_EQ(req.job_id(), "mip-job"); + EXPECT_EQ(req.from_index(), 0); + + auto* inc1 = resp->add_incumbents(); + inc1->set_index(0); + inc1->set_objective(100.5); + inc1->add_assignment(1.0); + inc1->add_assignment(0.0); + + auto* inc2 = resp->add_incumbents(); + inc2->set_index(1); + inc2->set_objective(95.3); + inc2->add_assignment(1.0); + inc2->add_assignment(1.0); + + resp->set_next_index(2); + resp->set_job_complete(false); + return grpc::Status::OK; + }); + + auto result = client_->get_incumbents("mip-job", 0, 10); + + EXPECT_TRUE(result.success); + 
EXPECT_EQ(result.incumbents.size(), 2); + EXPECT_EQ(result.incumbents[0].index, 0); + EXPECT_DOUBLE_EQ(result.incumbents[0].objective, 100.5); + EXPECT_EQ(result.incumbents[1].index, 1); + EXPECT_DOUBLE_EQ(result.incumbents[1].objective, 95.3); + EXPECT_EQ(result.next_index, 2); + EXPECT_FALSE(result.job_complete); +} + +TEST_F(GrpcClientTest, GetIncumbents_NoNewIncumbents) +{ + EXPECT_CALL(*mock_stub_, GetIncumbents(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::IncumbentRequest& req, + cuopt::remote::IncumbentResponse* resp) { + resp->set_next_index(req.from_index()); // No new incumbents + resp->set_job_complete(false); + return grpc::Status::OK; + }); + + auto result = client_->get_incumbents("mip-job", 5, 10); + + EXPECT_TRUE(result.success); + EXPECT_TRUE(result.incumbents.empty()); + EXPECT_EQ(result.next_index, 5); +} + +// ============================================================================= +// Connection Test (without mock - tests real connection failure) +// ============================================================================= + +TEST(GrpcClientConnectionTest, Connect_ServerUnavailable) +{ + grpc_client_config_t config; + config.server_address = "localhost:1"; // Invalid port + config.timeout_seconds = 1; + + grpc_client_t client(config); + EXPECT_FALSE(client.connect()); + EXPECT_FALSE(client.get_last_error().empty()); +} + +TEST(GrpcClientConnectionTest, IsConnected_BeforeConnect) +{ + grpc_client_config_t config; + config.server_address = "localhost:9999"; + + grpc_client_t client(config); + EXPECT_FALSE(client.is_connected()); +} + +// ============================================================================= +// Transient Failure / Retry Behavior Tests +// ============================================================================= + +TEST_F(GrpcClientTest, CheckStatus_TransientFailureThenSuccess) +{ + // First call fails with UNAVAILABLE (transient), second succeeds + EXPECT_CALL(*mock_stub_, 
CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse*) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Temporary failure"); + }) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + return grpc::Status::OK; + }); + + // First call should fail + auto result1 = client_->check_status("retry-job"); + EXPECT_FALSE(result1.success); + + // Second call should succeed (simulates retry at higher level) + auto result2 = client_->check_status("retry-job"); + EXPECT_TRUE(result2.success); + EXPECT_EQ(result2.status, job_status_t::COMPLETED); +} + +TEST_F(GrpcClientTest, GetResult_InternalError) +{ + // Server reports internal error + EXPECT_CALL(*mock_stub_, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse*) { + return grpc::Status(grpc::StatusCode::INTERNAL, "Internal server error"); + }); + + auto result = client_->get_lp_result("error-job"); + EXPECT_FALSE(result.success); + EXPECT_FALSE(result.error_message.empty()); +} + +TEST_F(GrpcClientTest, CancelJob_DeadlineExceeded) +{ + EXPECT_CALL(*mock_stub_, CancelJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::CancelRequest&, + cuopt::remote::CancelResponse*) { + return grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, "Request timeout"); + }); + + auto result = client_->cancel_job("timeout-job"); + EXPECT_FALSE(result.success); +} + +// ============================================================================= +// Malformed Response Tests +// ============================================================================= + +TEST_F(GrpcClientTest, CheckStatus_MalformedResponse_InvalidStatus) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + 
cuopt::remote::StatusResponse* resp) { + // Set an invalid/unexpected status value + resp->set_job_status(static_cast(999)); + return grpc::Status::OK; + }); + + auto result = client_->check_status("malformed-job"); + + // Should handle gracefully - either map to unknown or report error + EXPECT_TRUE(result.success); // RPC succeeded +} + +TEST_F(GrpcClientTest, GetIncumbents_MalformedResponse_NegativeIndex) +{ + EXPECT_CALL(*mock_stub_, GetIncumbents(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::IncumbentRequest&, + cuopt::remote::IncumbentResponse* resp) { + auto* inc = resp->add_incumbents(); + inc->set_index(-1); // Invalid negative index + inc->set_objective(100.0); + resp->set_next_index(-5); // Invalid + return grpc::Status::OK; + }); + + auto result = client_->get_incumbents("malformed-job", 0, 10); + + // Should handle gracefully + EXPECT_TRUE(result.success); +} + +TEST_F(GrpcClientTest, WaitForCompletion_EmptyMessage) +{ + EXPECT_CALL(*mock_stub_, WaitForCompletion(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::WaitRequest&, + cuopt::remote::WaitResponse* resp) { + // Don't set any fields - empty response + return grpc::Status::OK; + }); + + auto result = client_->wait_for_completion("empty-response-job"); + + // Should handle gracefully with default values + EXPECT_TRUE(result.success); +} + +// ============================================================================= +// Chunked Download Tests (Mock) +// ============================================================================= + +TEST_F(GrpcClientTest, ChunkedDownload_FallbackOnResourceExhausted) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_result_size_bytes(500); + resp->set_max_message_bytes(256 * 1024 * 1024); + return grpc::Status::OK; + }); + + 
EXPECT_CALL(*mock_stub_, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse*) { + return grpc::Status(grpc::StatusCode::RESOURCE_EXHAUSTED, "Too large"); + }); + + EXPECT_CALL(*mock_stub_, StartChunkedDownload(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StartChunkedDownloadRequest& req, + cuopt::remote::StartChunkedDownloadResponse* resp) { + resp->set_download_id("dl-001"); + auto* h = resp->mutable_header(); + h->set_is_mip(false); + h->set_lp_termination_status(cuopt::remote::PDLP_OPTIMAL); + h->set_primal_objective(-464.753); + auto* arr = h->add_arrays(); + arr->set_field_id(cuopt::remote::RESULT_PRIMAL_SOLUTION); + arr->set_total_elements(2); + arr->set_element_size_bytes(8); + resp->set_max_message_bytes(4 * 1024 * 1024); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock_stub_, GetResultChunk(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultChunkRequest& req, + cuopt::remote::GetResultChunkResponse* resp) { + EXPECT_EQ(req.download_id(), "dl-001"); + EXPECT_EQ(req.field_id(), cuopt::remote::RESULT_PRIMAL_SOLUTION); + resp->set_download_id("dl-001"); + resp->set_field_id(req.field_id()); + resp->set_element_offset(0); + resp->set_elements_in_chunk(2); + double vals[2] = {1.5, 2.5}; + resp->set_data(reinterpret_cast(vals), sizeof(vals)); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock_stub_, FinishChunkedDownload(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::FinishChunkedDownloadRequest& req, + cuopt::remote::FinishChunkedDownloadResponse* resp) { + resp->set_download_id(req.download_id()); + return grpc::Status::OK; + }); + + auto lp_result = client_->get_lp_result("test-job"); + + EXPECT_TRUE(lp_result.success) << lp_result.error_message; + ASSERT_NE(lp_result.solution, nullptr); + EXPECT_NEAR(lp_result.solution->get_objective_value(), -464.753, 0.01); +} + +TEST_F(GrpcClientTest, 
ChunkedDownload_StartFails) +{ + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_result_size_bytes(1000000); + resp->set_max_message_bytes(100); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock_stub_, StartChunkedDownload(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StartChunkedDownloadRequest&, + cuopt::remote::StartChunkedDownloadResponse*) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, "Job not found"); + }); + + auto lp_result = client_->get_lp_result("test-job"); + + EXPECT_FALSE(lp_result.success); + EXPECT_TRUE(lp_result.error_message.find("StartChunkedDownload") != std::string::npos); +} + +// ============================================================================= +// Helper: Build minimal test problems +// ============================================================================= + +namespace { + +cpu_optimization_problem_t create_test_lp_problem() +{ + cpu_optimization_problem_t problem; + + // minimize x subject to x >= 1 + std::vector obj = {1.0}; + std::vector var_lb = {0.0}; + std::vector var_ub = {10.0}; + std::vector con_lb = {1.0}; + std::vector con_ub = {1e20}; + std::vector A_vals = {1.0}; + std::vector A_idx = {0}; + std::vector A_off = {0, 1}; + + problem.set_objective_coefficients(obj.data(), 1); + problem.set_maximize(false); + problem.set_variable_lower_bounds(var_lb.data(), 1); + problem.set_variable_upper_bounds(var_ub.data(), 1); + problem.set_csr_constraint_matrix(A_vals.data(), 1, A_idx.data(), 1, A_off.data(), 2); + problem.set_constraint_lower_bounds(con_lb.data(), 1); + problem.set_constraint_upper_bounds(con_ub.data(), 1); + + return problem; +} + +cpu_optimization_problem_t create_test_mip_problem() +{ + cpu_optimization_problem_t problem; + + // minimize x subject to x >= 1, x integer + std::vector obj = {1.0}; 
+ std::vector var_lb = {0.0}; + std::vector var_ub = {10.0}; + std::vector var_ty = {var_t::INTEGER}; + std::vector con_lb = {1.0}; + std::vector con_ub = {1e20}; + std::vector A_vals = {1.0}; + std::vector A_idx = {0}; + std::vector A_off = {0, 1}; + + problem.set_objective_coefficients(obj.data(), 1); + problem.set_maximize(false); + problem.set_variable_lower_bounds(var_lb.data(), 1); + problem.set_variable_upper_bounds(var_ub.data(), 1); + problem.set_variable_types(var_ty.data(), 1); + problem.set_csr_constraint_matrix(A_vals.data(), 1, A_idx.data(), 1, A_off.data(), 2); + problem.set_constraint_lower_bounds(con_lb.data(), 1); + problem.set_constraint_upper_bounds(con_ub.data(), 1); + + return problem; +} + +} // namespace + +// ============================================================================= +// SubmitLP / SubmitMIP Tests +// ============================================================================= + +TEST_F(GrpcClientTest, SubmitLP_Success) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest& req, + cuopt::remote::SubmitJobResponse* resp) { + EXPECT_TRUE(req.has_lp_request()); + resp->set_job_id("lp-job-001"); + resp->set_message("Job submitted"); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client_->submit_lp(problem, settings); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.job_id, "lp-job-001"); + EXPECT_TRUE(result.error_message.empty()); +} + +TEST_F(GrpcClientTest, SubmitLP_NotConnected) +{ + // Create a fresh client that is NOT marked as connected + grpc_client_config_t config; + config.server_address = "mock://disconnected"; + grpc_client_t disconnected_client(config); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + + auto result = disconnected_client.submit_lp(problem, settings); + + 
EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Not connected") != std::string::npos); +} + +TEST_F(GrpcClientTest, SubmitLP_RpcFailure) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse*) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Server unreachable"); + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + + auto result = client_->submit_lp(problem, settings); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Server unreachable") != std::string::npos); +} + +TEST_F(GrpcClientTest, SubmitLP_EmptyJobId) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse* resp) { + resp->set_job_id(""); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + + auto result = client_->submit_lp(problem, settings); + + EXPECT_FALSE(result.success); +} + +TEST_F(GrpcClientTest, SubmitMIP_Success) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest& req, + cuopt::remote::SubmitJobResponse* resp) { + EXPECT_TRUE(req.has_mip_request()); + resp->set_job_id("mip-job-001"); + resp->set_message("MIP job submitted"); + return grpc::Status::OK; + }); + + auto problem = create_test_mip_problem(); + mip_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client_->submit_mip(problem, settings); + + EXPECT_TRUE(result.success); + EXPECT_EQ(result.job_id, "mip-job-001"); +} + +TEST_F(GrpcClientTest, SubmitMIP_RpcFailure) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse*) { + return 
grpc::Status(grpc::StatusCode::UNAVAILABLE, "Server unreachable"); + }); + + auto problem = create_test_mip_problem(); + mip_solver_settings_t settings; + + auto result = client_->submit_mip(problem, settings); + + EXPECT_FALSE(result.success); +} + +// ============================================================================= +// SolveLP / SolveMIP Tests (end-to-end mock flow) +// ============================================================================= + +TEST_F(GrpcClientTest, SolveLP_SuccessWithPolling) +{ + // 1. SubmitJob succeeds + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.poll_interval_ms = 10; + cfg.timeout_seconds = 5; + + auto client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest& req, + cuopt::remote::SubmitJobResponse* resp) { + EXPECT_TRUE(req.has_lp_request()); + resp->set_job_id("solve-lp-001"); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, CheckStatus(_, _, _)) + .WillRepeatedly([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_result_size_bytes(64); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest& req, + cuopt::remote::ResultResponse* resp) { + EXPECT_EQ(req.job_id(), "solve-lp-001"); + cuopt::remote::LPSolution solution; + solution.add_primal_solution(1.0); + solution.set_primal_objective(1.0); + solution.set_termination_status(cuopt::remote::PDLP_OPTIMAL); + resp->mutable_lp_solution()->CopyFrom(solution); + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + auto 
result = client->solve_lp(problem, settings); + + EXPECT_TRUE(result.success) << "Error: " << result.error_message; + EXPECT_NE(result.solution, nullptr); + if (result.solution) { EXPECT_DOUBLE_EQ(result.solution->get_objective_value(), 1.0); } +} + +TEST_F(GrpcClientTest, SolveLP_SuccessWithWait) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.poll_interval_ms = 10; + cfg.timeout_seconds = 5; + + auto client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse* resp) { + resp->set_job_id("wait-lp-001"); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, CheckStatus(_, _, _)) + .WillRepeatedly([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_result_size_bytes(64); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse* resp) { + cuopt::remote::LPSolution solution; + solution.add_primal_solution(1.0); + solution.set_primal_objective(1.0); + solution.set_termination_status(cuopt::remote::PDLP_OPTIMAL); + resp->mutable_lp_solution()->CopyFrom(solution); + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + + EXPECT_TRUE(result.success) << "Error: " << result.error_message; + EXPECT_NE(result.solution, nullptr); +} + +TEST_F(GrpcClientTest, SolveLP_JobFails) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.poll_interval_ms = 10; + cfg.timeout_seconds = 5; + + auto 
client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse* resp) { + resp->set_job_id("fail-lp-001"); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::FAILED); + resp->set_message("Out of GPU memory"); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Out of GPU memory") != std::string::npos) + << "Error: " << result.error_message; +} + +TEST_F(GrpcClientTest, SolveLP_SubmitFails) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.timeout_seconds = 5; + + auto client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse*) { + return grpc::Status(grpc::StatusCode::INTERNAL, "Server crashed"); + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + + auto result = client->solve_lp(problem, settings); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Server crashed") != std::string::npos) + << "Error: " << result.error_message; +} + +TEST_F(GrpcClientTest, SolveLP_NotConnected) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://disconnected"; + + grpc_client_t client(cfg); + // Don't inject mock or mark as connected + + auto problem = 
create_test_lp_problem(); + pdlp_solver_settings_t settings; + + auto result = client.solve_lp(problem, settings); + + EXPECT_FALSE(result.success); + EXPECT_TRUE(result.error_message.find("Not connected") != std::string::npos); +} + +TEST_F(GrpcClientTest, SolveMIP_Success) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.poll_interval_ms = 10; + cfg.timeout_seconds = 5; + + auto client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest& req, + cuopt::remote::SubmitJobResponse* resp) { + EXPECT_TRUE(req.has_mip_request()); + resp->set_job_id("mip-solve-001"); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, CheckStatus(_, _, _)) + .WillRepeatedly([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::COMPLETED); + resp->set_result_size_bytes(64); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse* resp) { + cuopt::remote::MIPSolution solution; + solution.add_solution(1.0); + solution.set_objective(1.0); + solution.set_termination_status(cuopt::remote::MIP_OPTIMAL); + resp->mutable_mip_solution()->CopyFrom(solution); + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + auto problem = create_test_mip_problem(); + mip_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_mip(problem, settings); + + EXPECT_TRUE(result.success) << "Error: " << result.error_message; + EXPECT_NE(result.solution, nullptr); + if (result.solution) { EXPECT_DOUBLE_EQ(result.solution->get_objective_value(), 1.0); } +} + +// ============================================================================= 
+// GetResult on PROCESSING job +// ============================================================================= + +TEST_F(GrpcClientTest, GetResult_ProcessingJobReturnsError) +{ + // When a job is still PROCESSING, GetResult returns UNAVAILABLE. + // The client's get_result_or_stream first calls CheckStatus; if the job + // is not complete, it should not attempt GetResult at all. + // Here we test the lower-level get_lp_result path with a CheckStatus + // returning PROCESSING (small result size so no streaming fallback). + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::PROCESSING); + resp->set_result_size_bytes(0); + return grpc::Status::OK; + }); + + // GetResult should be called because CheckStatus doesn't show large result + EXPECT_CALL(*mock_stub_, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse*) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Result not ready"); + }); + + auto result = client_->get_lp_result("processing-job"); + EXPECT_FALSE(result.success); +} + +// ============================================================================= +// DeleteJob then verify subsequent operations fail +// ============================================================================= + +TEST_F(GrpcClientTest, DeleteJob_ThenCheckStatusNotFound) +{ + // Delete succeeds + EXPECT_CALL(*mock_stub_, DeleteResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::DeleteRequest& req, + cuopt::remote::DeleteResponse* resp) { + EXPECT_EQ(req.job_id(), "delete-then-check"); + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + bool deleted = client_->delete_job("delete-then-check"); + EXPECT_TRUE(deleted); + + // Subsequent CheckStatus returns NOT_FOUND + EXPECT_CALL(*mock_stub_, 
CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest& req, + cuopt::remote::StatusResponse* resp) { + EXPECT_EQ(req.job_id(), "delete-then-check"); + resp->set_job_status(cuopt::remote::NOT_FOUND); + resp->set_message("Job not found"); + return grpc::Status::OK; + }); + + auto status = client_->check_status("delete-then-check"); + EXPECT_TRUE(status.success); + EXPECT_EQ(status.status, job_status_t::NOT_FOUND); +} + +TEST_F(GrpcClientTest, DeleteJob_ThenGetResultFails) +{ + EXPECT_CALL(*mock_stub_, DeleteResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::DeleteRequest&, + cuopt::remote::DeleteResponse* resp) { + resp->set_status(cuopt::remote::SUCCESS); + return grpc::Status::OK; + }); + + client_->delete_job("deleted-job"); + + // GetResult after deletion + EXPECT_CALL(*mock_stub_, CheckStatus(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StatusRequest&, + cuopt::remote::StatusResponse* resp) { + resp->set_job_status(cuopt::remote::NOT_FOUND); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock_stub_, GetResult(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::GetResultRequest&, + cuopt::remote::ResultResponse*) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, "Job not found"); + }); + + auto result = client_->get_lp_result("deleted-job"); + EXPECT_FALSE(result.success); +} + +// ============================================================================= +// WaitForCompletion with cancelled job +// ============================================================================= + +TEST_F(GrpcClientTest, WaitForCompletion_Cancelled) +{ + EXPECT_CALL(*mock_stub_, WaitForCompletion(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::WaitRequest& req, + cuopt::remote::WaitResponse* resp) { + EXPECT_EQ(req.job_id(), "cancelled-job"); + resp->set_job_status(cuopt::remote::CANCELLED); + resp->set_message("Job was cancelled"); + return 
grpc::Status::OK; + }); + + auto result = client_->wait_for_completion("cancelled-job"); + + EXPECT_TRUE(result.success); // RPC succeeded + EXPECT_EQ(result.status, job_status_t::CANCELLED); + EXPECT_TRUE(result.message.find("cancelled") != std::string::npos); +} + +// ============================================================================= +// StreamLogs Tests (Mock) +// ============================================================================= + +class MockLogStream : public grpc::ClientReaderInterface { + public: + explicit MockLogStream(std::vector msgs) + : messages_(std::move(msgs)), idx_(0) + { + } + + bool Read(cuopt::remote::LogMessage* msg) override + { + if (idx_ >= messages_.size()) return false; + *msg = messages_[idx_++]; + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + bool NextMessageSize(uint32_t* sz) override + { + if (idx_ >= messages_.size()) return false; + *sz = messages_[idx_].ByteSizeLong(); + return true; + } + void WaitForInitialMetadata() override {} + + private: + std::vector messages_; + size_t idx_; +}; + +TEST_F(GrpcClientTest, StreamLogs_ReceivesLogLines) +{ + std::vector msgs; + + cuopt::remote::LogMessage msg1; + msg1.set_line("Iteration 1: obj=100.0"); + msg1.set_job_complete(false); + msgs.push_back(msg1); + + cuopt::remote::LogMessage msg2; + msg2.set_line("Iteration 2: obj=50.0"); + msg2.set_job_complete(false); + msgs.push_back(msg2); + + cuopt::remote::LogMessage msg3; + msg3.set_line("Solve complete"); + msg3.set_job_complete(true); + msgs.push_back(msg3); + + auto* mock_reader = new MockLogStream(msgs); + EXPECT_CALL(*mock_stub_, StreamLogsRaw(_, _)).WillOnce(Return(mock_reader)); + + std::vector received_lines; + bool result = client_->stream_logs("log-job", 0, [&](const std::string& line, bool complete) { + received_lines.push_back(line); + return true; // keep streaming + }); + + EXPECT_TRUE(result); + EXPECT_EQ(received_lines.size(), 3); + EXPECT_EQ(received_lines[0], 
"Iteration 1: obj=100.0"); + EXPECT_EQ(received_lines[2], "Solve complete"); +} + +TEST_F(GrpcClientTest, StreamLogs_CallbackStopsEarly) +{ + std::vector msgs; + + cuopt::remote::LogMessage msg1; + msg1.set_line("Line 1"); + msg1.set_job_complete(false); + msgs.push_back(msg1); + + cuopt::remote::LogMessage msg2; + msg2.set_line("Line 2"); + msg2.set_job_complete(false); + msgs.push_back(msg2); + + auto* mock_reader = new MockLogStream(msgs); + EXPECT_CALL(*mock_stub_, StreamLogsRaw(_, _)).WillOnce(Return(mock_reader)); + + int count = 0; + client_->stream_logs("log-job", 0, [&](const std::string&, bool) { + count++; + return false; // stop after first line + }); + + EXPECT_EQ(count, 1); +} + +TEST_F(GrpcClientTest, StreamLogs_EmptyStream) +{ + std::vector msgs; // empty + + auto* mock_reader = new MockLogStream(msgs); + EXPECT_CALL(*mock_stub_, StreamLogsRaw(_, _)).WillOnce(Return(mock_reader)); + + int count = 0; + bool result = client_->stream_logs("log-job", 0, [&](const std::string&, bool) { + count++; + return true; + }); + + EXPECT_TRUE(result); + EXPECT_EQ(count, 0); +} + +// ============================================================================= +// Chunked Upload Tests (Mock) +// ============================================================================= + +TEST_F(GrpcClientTest, SubmitLP_ChunkedUploadForLargePayload) +{ + grpc_client_config_t cfg; + cfg.server_address = "mock://test"; + cfg.chunked_array_threshold_bytes = 0; // Force chunked upload for all sizes + cfg.chunk_size_bytes = 4 * 1024; + + auto client = std::make_unique(cfg); + auto mock = std::make_shared>(); + grpc_test_inject_mock_stub_typed(*client, mock); + + EXPECT_CALL(*mock, StartChunkedUpload(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::StartChunkedUploadRequest& req, + cuopt::remote::StartChunkedUploadResponse* resp) { + EXPECT_TRUE(req.has_problem_header()); + resp->set_upload_id("chunked-upload-001"); + resp->set_max_message_bytes(4 * 1024 * 1024); 
+ return grpc::Status::OK; + }); + + int chunk_count = 0; + EXPECT_CALL(*mock, SendArrayChunk(_, _, _)) + .WillRepeatedly([&chunk_count](grpc::ClientContext*, + const cuopt::remote::SendArrayChunkRequest& req, + cuopt::remote::SendArrayChunkResponse* resp) { + EXPECT_EQ(req.upload_id(), "chunked-upload-001"); + EXPECT_TRUE(req.has_chunk()); + chunk_count++; + resp->set_upload_id("chunked-upload-001"); + resp->set_chunks_received(chunk_count); + return grpc::Status::OK; + }); + + EXPECT_CALL(*mock, FinishChunkedUpload(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::FinishChunkedUploadRequest& req, + cuopt::remote::SubmitJobResponse* resp) { + EXPECT_EQ(req.upload_id(), "chunked-upload-001"); + resp->set_job_id("chunked-job-001"); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->submit_lp(problem, settings); + + EXPECT_TRUE(result.success) << "Error: " << result.error_message; + EXPECT_EQ(result.job_id, "chunked-job-001"); + EXPECT_GT(chunk_count, 0) << "Should have sent at least one array chunk"; +} + +TEST_F(GrpcClientTest, SubmitLP_UnaryForSmallPayload) +{ + EXPECT_CALL(*mock_stub_, SubmitJob(_, _, _)) + .WillOnce([](grpc::ClientContext*, + const cuopt::remote::SubmitJobRequest&, + cuopt::remote::SubmitJobResponse* resp) { + resp->set_job_id("unary-lp-001"); + return grpc::Status::OK; + }); + + auto problem = create_test_lp_problem(); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client_->submit_lp(problem, settings); + + EXPECT_TRUE(result.success) << "Error: " << result.error_message; + EXPECT_EQ(result.job_id, "unary-lp-001"); +} diff --git a/cpp/tests/linear_programming/grpc/grpc_client_test_helper.hpp b/cpp/tests/linear_programming/grpc/grpc_client_test_helper.hpp new file mode 100644 index 0000000000..de6d391dce --- /dev/null +++ 
b/cpp/tests/linear_programming/grpc/grpc_client_test_helper.hpp @@ -0,0 +1,65 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +/** + * @file grpc_client_test_helper.hpp + * @brief Test helper for injecting mock stubs into grpc_client_t + * + * This header is for unit testing only - it exposes internal gRPC types + * that are normally hidden by the PIMPL pattern. Include this only in test code. + */ + +#include + +#include +#include + +#include "grpc_client.hpp" + +namespace cuopt::linear_programming { + +/** + * @brief Inject a mock stub into a grpc_client_t instance for testing + * + * This allows unit tests to provide mock stubs that simulate various + * server responses and error conditions without needing a real server. + * + * Usage: + * @code + * grpc_client_config_t config; + * config.server_address = "mock://test"; + * grpc_client_t client(config); + * + * auto mock_stub = std::make_shared(); + * grpc_test_inject_mock_stub(client, mock_stub); + * + * // Now client.check_status() etc. will use the mock stub + * @endcode + * + * @param client The client to inject the stub into + * @param stub The mock stub to inject (takes ownership via shared_ptr) + */ +void grpc_test_inject_mock_stub(grpc_client_t& client, std::shared_ptr stub); + +/** + * @brief Mark a client as "connected" without actually connecting + * + * For mock testing, we don't have a real channel but need the client + * to think it's connected so it will use the stub. 
+ */ +void grpc_test_mark_as_connected(grpc_client_t& client); + +/** + * @brief Helper template to cast mock stub to void pointer for injection + */ +template +inline void grpc_test_inject_mock_stub_typed(grpc_client_t& client, std::shared_ptr stub) +{ + grpc_test_inject_mock_stub(client, std::static_pointer_cast(stub)); +} + +} // namespace cuopt::linear_programming diff --git a/cpp/tests/linear_programming/grpc/grpc_integration_test.cpp b/cpp/tests/linear_programming/grpc/grpc_integration_test.cpp new file mode 100644 index 0000000000..0625e471c9 --- /dev/null +++ b/cpp/tests/linear_programming/grpc/grpc_integration_test.cpp @@ -0,0 +1,1939 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +/** + * @file grpc_integration_test.cpp + * @brief Integration tests for gRPC client-server communication + * + * Tests are organized into shared-server fixtures to minimize server startup overhead. + * Total target runtime: ~3 minutes. 
+ * + * Fixture layout: + * NoServerTests - Tests that don't need a server + * DefaultServerTests - Shared server with default config (~21 tests) + * ChunkedUploadTests - Shared server with --max-message-mb 256 (4 tests) + * PathSelectionTests - Shared server with --max-message-bytes 4096 --verbose (4 tests) + * ErrorRecoveryTests - Per-test server lifecycle (4 tests) + * TlsServerTests - Shared TLS server (2 tests) + * MtlsServerTests - Shared mTLS server (2 tests) + * + * Environment variables: + * CUOPT_GRPC_SERVER_PATH - Path to cuopt_grpc_server binary + * CUOPT_TEST_PORT_BASE - Base port for test servers (default: 19000) + * RAPIDS_DATASET_ROOT_DIR - Path to test datasets + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include "grpc_client.hpp" + +#include "grpc_test_log_capture.hpp" + +#include +#include + +#include "grpc_service_mapper.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace cuopt::linear_programming; +using cuopt::linear_programming::testing::GrpcTestLogCapture; + +namespace { + +// ============================================================================= +// Server Process Manager +// ============================================================================= + +class ServerProcess { + public: + ServerProcess() : pid_(-1), port_(0) {} + ~ServerProcess() { stop(); } + + void set_tls_config(const std::string& root_certs, + const std::string& client_cert = "", + const std::string& client_key = "") + { + tls_root_certs_ = root_certs; + tls_client_cert_ = client_cert; + tls_client_key_ = client_key; + } + + bool start(int port, const std::vector& extra_args = {}) + { + port_ = port; + + std::string server_path = find_server_binary(); + if (server_path.empty()) { + std::cerr << "Could not find cuopt_grpc_server binary\n"; + return false; + } + + 
+    pid_ = fork();
+    if (pid_ < 0) {
+      std::cerr << "fork() failed\n";
+      return false;
+    }
+
+    if (pid_ == 0) {
+      std::vector<const char*> args;
+      args.push_back(server_path.c_str());
+      args.push_back("--port");
+      std::string port_str = std::to_string(port);
+      args.push_back(port_str.c_str());
+      args.push_back("--workers");
+      args.push_back("1");
+
+      for (const auto& arg : extra_args) {
+        args.push_back(arg.c_str());
+      }
+      args.push_back(nullptr);
+
+      std::string log_file = "/tmp/cuopt_test_server_" + std::to_string(port) + ".log";
+      int fd = open(log_file.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
+      if (fd >= 0) {
+        dup2(fd, STDOUT_FILENO);
+        dup2(fd, STDERR_FILENO);
+        close(fd);
+      }
+
+      execv(server_path.c_str(), const_cast<char* const*>(args.data()));
+      _exit(127);
+    }
+
+    return wait_for_ready(15000);
+  }
+
+  void stop()
+  {
+    if (pid_ > 0) {
+      kill(pid_, SIGTERM);
+
+      int status;
+      int wait_ms = 0;
+      while (wait_ms < 5000) {
+        int ret = waitpid(pid_, &status, WNOHANG);
+        if (ret != 0) break;
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        wait_ms += 100;
+      }
+
+      if (waitpid(pid_, &status, WNOHANG) == 0) {
+        kill(pid_, SIGKILL);
+        waitpid(pid_, &status, 0);
+      }
+
+      pid_ = -1;
+    }
+  }
+
+  int port() const { return port_; }
+
+  bool is_running() const
+  {
+    if (pid_ <= 0) return false;
+    return kill(pid_, 0) == 0;
+  }
+
+  std::string log_path() const
+  {
+    if (port_ <= 0) return "";
+    return "/tmp/cuopt_test_server_" + std::to_string(port_) + ".log";
+  }
+
+ private:
+  std::string find_in_path(const std::string& name)
+  {
+    const char* path_env = std::getenv("PATH");
+    if (!path_env) return "";
+
+    std::string path_str(path_env);
+    std::string::size_type start = 0;
+    std::string::size_type end;
+
+    while ((end = path_str.find(':', start)) != std::string::npos || start < path_str.size()) {
+      std::string dir;
+      if (end != std::string::npos) {
+        dir = path_str.substr(start, end - start);
+        start = end + 1;
+      } else {
+        dir = path_str.substr(start);
+        start = path_str.size();
+
}
+
+      if (dir.empty()) continue;
+
+      std::string full_path = dir + "/" + name;
+      if (access(full_path.c_str(), X_OK) == 0) { return full_path; }
+    }
+
+    return "";
+  }
+
+  std::string find_server_binary()
+  {
+    const char* env_path = std::getenv("CUOPT_GRPC_SERVER_PATH");
+    if (env_path && access(env_path, X_OK) == 0) { return env_path; }
+
+    std::string path_result = find_in_path("cuopt_grpc_server");
+    if (!path_result.empty()) { return path_result; }
+
+    std::vector<std::string> paths = {
+      "./cuopt_grpc_server",
+      "../cuopt_grpc_server",
+      "../../cuopt_grpc_server",
+      "./build/cuopt_grpc_server",
+      "../build/cuopt_grpc_server",
+    };
+
+    for (const auto& path : paths) {
+      if (access(path.c_str(), X_OK) == 0) { return path; }
+    }
+
+    return "";
+  }
+
+  bool wait_for_ready(int timeout_ms)
+  {
+    auto start = std::chrono::steady_clock::now();
+
+    while (true) {
+      auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::steady_clock::now() - start);
+
+      if (elapsed.count() >= timeout_ms) { return false; }
+
+      grpc_client_config_t config;
+      config.server_address = "localhost:" + std::to_string(port_);
+
+      if (!tls_root_certs_.empty()) {
+        config.enable_tls = true;
+        config.tls_root_certs = tls_root_certs_;
+        config.tls_client_cert = tls_client_cert_;
+        config.tls_client_key = tls_client_key_;
+      }
+
+      grpc_client_t client(config);
+      if (client.connect()) { return true; }
+
+      int status;
+      if (waitpid(pid_, &status, WNOHANG) != 0) {
+        std::cerr << "Server process died during startup\n";
+        return false;
+      }
+
+      std::this_thread::sleep_for(std::chrono::milliseconds(200));
+    }
+  }
+
+  pid_t pid_;
+  int port_;
+  std::string tls_root_certs_;
+  std::string tls_client_cert_;
+  std::string tls_client_key_;
+};
+
+int get_test_port()
+{
+  static std::atomic<int> port_counter{0};
+
+  int base_port = 19000;
+  const char* env_base = std::getenv("CUOPT_TEST_PORT_BASE");
+  if (env_base) { base_port = std::atoi(env_base); }
+
+  return base_port + port_counter.fetch_add(1);
+}
+
+//
============================================================================= +// TLS Certificate Generation (shared across TLS fixtures) +// ============================================================================= + +std::string g_tls_certs_dir; +bool g_tls_certs_ready = false; + +bool ensure_test_certs() +{ + if (g_tls_certs_ready) return true; + + // Check for CI-provided certs + const char* cert_folder = std::getenv("CERT_FOLDER"); + if (cert_folder) { + g_tls_certs_dir = cert_folder; + g_tls_certs_ready = true; + return true; + } + + const char* ssl_certfile = std::getenv("CUOPT_SSL_CERTFILE"); + if (ssl_certfile) { + g_tls_certs_dir = std::filesystem::path(ssl_certfile).parent_path().string(); + g_tls_certs_ready = true; + return true; + } + + g_tls_certs_dir = "/tmp/cuopt_test_certs_" + std::to_string(getpid()); + std::filesystem::create_directories(g_tls_certs_dir); + + auto run = [](const std::string& cmd) { return std::system(cmd.c_str()) == 0; }; + + std::string ca_key = g_tls_certs_dir + "/ca.key"; + std::string ca_crt = g_tls_certs_dir + "/ca.crt"; + if (!run("openssl req -x509 -newkey rsa:2048 -keyout " + ca_key + " -out " + ca_crt + + " -days 1 -nodes -subj '/CN=TestCA' 2>/dev/null")) + return false; + + std::string server_key = g_tls_certs_dir + "/server.key"; + std::string server_csr = g_tls_certs_dir + "/server.csr"; + std::string server_crt = g_tls_certs_dir + "/server.crt"; + if (!run("openssl req -newkey rsa:2048 -keyout " + server_key + " -out " + server_csr + + " -nodes -subj '/CN=localhost' 2>/dev/null")) + return false; + if (!run("openssl x509 -req -in " + server_csr + " -CA " + ca_crt + " -CAkey " + ca_key + + " -CAcreateserial -out " + server_crt + " -days 1 2>/dev/null")) + return false; + + std::string client_key = g_tls_certs_dir + "/client.key"; + std::string client_csr = g_tls_certs_dir + "/client.csr"; + std::string client_crt = g_tls_certs_dir + "/client.crt"; + if (!run("openssl req -newkey rsa:2048 -keyout " + client_key + 
" -out " + client_csr + + " -nodes -subj '/CN=TestClient' 2>/dev/null")) + return false; + if (!run("openssl x509 -req -in " + client_csr + " -CA " + ca_crt + " -CAkey " + ca_key + + " -CAcreateserial -out " + client_crt + " -days 1 2>/dev/null")) + return false; + + g_tls_certs_ready = true; + return true; +} + +std::string read_file_contents(const std::string& path) +{ + std::ifstream file(path); + if (!file) return ""; + std::stringstream buffer; + buffer << file.rdbuf(); + return buffer.str(); +} + +// ============================================================================= +// Base Test Class +// ============================================================================= + +class GrpcIntegrationTestBase : public ::testing::Test { + protected: + std::unique_ptr create_client(grpc_client_config_t config = {}) + { + config.server_address = "localhost:" + std::to_string(port_); + config.poll_interval_ms = 100; + + if (config.timeout_seconds == 3600) { config.timeout_seconds = 60; } + + config.enable_transfer_hash = true; + + auto client = std::make_unique(config); + if (!client->connect()) { return nullptr; } + return client; + } + + std::string get_test_data_path(const std::string& subdir, const std::string& filename) + { + const char* env_var = std::getenv("RAPIDS_DATASET_ROOT_DIR"); + std::string dataset_root = env_var ? 
env_var : "./datasets";
+    return dataset_root + "/" + subdir + "/" + filename;
+  }
+
+  std::string get_test_lp_path(const std::string& filename)
+  {
+    return get_test_data_path("linear_programming", filename);
+  }
+
+  std::string get_test_mip_path(const std::string& filename)
+  {
+    return get_test_data_path("mip", filename);
+  }
+
+  cpu_optimization_problem_t load_problem_from_mps(const std::string& mps_path)
+  {
+    auto mps_data = cuopt::mps_parser::parse_mps(mps_path);
+    cpu_optimization_problem_t problem;
+    populate_from_mps_data_model(&problem, mps_data);
+    return problem;
+  }
+
+  cpu_optimization_problem_t create_simple_mip()
+  {
+    cpu_optimization_problem_t problem;
+
+    std::vector<double> c = {1.0, 2.0};
+    problem.set_objective_coefficients(c.data(), 2);
+    problem.set_maximize(false);
+
+    std::vector<double> A_values = {1.0, 1.0};
+    std::vector<int> A_indices = {0, 1};
+    std::vector<int> A_offsets = {0, 2};
+    problem.set_csr_constraint_matrix(A_values.data(), 2, A_indices.data(), 2, A_offsets.data(), 2);
+
+    std::vector<double> var_lb = {0.0, 0.0};
+    std::vector<double> var_ub = {1.0, 1.0};
+    problem.set_variable_lower_bounds(var_lb.data(), 2);
+    problem.set_variable_upper_bounds(var_ub.data(), 2);
+
+    std::vector<var_t> var_types = {var_t::INTEGER, var_t::INTEGER};
+    problem.set_variable_types(var_types.data(), 2);
+
+    std::vector<double> con_lb = {1.0};
+    std::vector<double> con_ub = {1e20};
+    problem.set_constraint_lower_bounds(con_lb.data(), 1);
+    problem.set_constraint_upper_bounds(con_ub.data(), 1);
+
+    return problem;
+  }
+
+  void wait_for_job_done(grpc_client_t* client, const std::string& job_id, int max_seconds = 30)
+  {
+    for (int i = 0; i < max_seconds * 2; ++i) {
+      auto status = client->check_status(job_id);
+      if (status.status == job_status_t::COMPLETED || status.status == job_status_t::FAILED ||
+          status.status == job_status_t::CANCELLED) {
+        return;
+      }
+      std::this_thread::sleep_for(std::chrono::milliseconds(500));
+    }
+  }
+
+  int port_ = 0;
+};
+
+//
============================================================================= +// No-Server Tests +// ============================================================================= + +class NoServerTests : public GrpcIntegrationTestBase { + protected: + void SetUp() override { port_ = get_test_port(); } +}; + +TEST_F(NoServerTests, ConnectToNonexistentServer) +{ + grpc_client_config_t config; + config.server_address = "localhost:" + std::to_string(port_); + + GrpcTestLogCapture log_capture; + config.debug_log_callback = log_capture.client_callback(); + + grpc_client_t client(config); + EXPECT_FALSE(client.connect()); + EXPECT_FALSE(client.get_last_error().empty()); + + EXPECT_TRUE(log_capture.client_log_contains("Connection failed")) + << "Expected failure log. Captured logs:\n" + << log_capture.get_client_logs(); +} + +TEST_F(NoServerTests, LogCaptureInfrastructure) +{ + GrpcTestLogCapture log_capture; + + // -- client_log_count -- + log_capture.add_client_log("Test message 1"); + log_capture.add_client_log("Test message 2"); + log_capture.add_client_log("Another test"); + log_capture.add_client_log("Test message 3"); + + EXPECT_EQ(log_capture.client_log_count("Test message"), 3); + EXPECT_EQ(log_capture.client_log_count("Another"), 1); + EXPECT_EQ(log_capture.client_log_count("Not found"), 0); + + // -- mark_test_start isolation with server log file -- + std::string tmp_log = "/tmp/cuopt_test_log_infra_" + std::to_string(getpid()) + ".log"; + { + std::ofstream f(tmp_log); + f << "[Phase1] Before mark\n"; + f.flush(); + } + + log_capture.set_server_log_path(tmp_log); + log_capture.mark_test_start(); + + // Logs written before the mark should be invisible + EXPECT_FALSE(log_capture.server_log_contains("[Phase1]")) + << "Should NOT see logs from before mark_test_start()"; + + // Append new content after the mark + { + std::ofstream f(tmp_log, std::ios::app); + f << "[Phase2] After mark\n"; + f.flush(); + } + + EXPECT_TRUE(log_capture.server_log_contains("[Phase2]")) 
+ << "Should see logs from after mark_test_start()"; + + // get_all_server_logs bypasses the mark + std::string all = log_capture.get_all_server_logs(); + EXPECT_TRUE(all.find("[Phase1]") != std::string::npos); + EXPECT_TRUE(all.find("[Phase2]") != std::string::npos); + + // -- wait_for_server_log (content already present -> immediate return) -- + EXPECT_TRUE(log_capture.wait_for_server_log("[Phase2]", 500)); + EXPECT_FALSE(log_capture.wait_for_server_log("never_appears", 200)); + + // Clean up + std::filesystem::remove(tmp_log); +} + +// ============================================================================= +// Default Server Tests (shared server, default config) +// ============================================================================= + +class DefaultServerTests : public GrpcIntegrationTestBase { + protected: + static void SetUpTestSuite() + { + s_port_ = get_test_port(); + s_server_ = std::make_unique(); + ASSERT_TRUE(s_server_->start(s_port_, {"--enable-transfer-hash"})) + << "Failed to start shared default server on port " << s_port_; + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + ASSERT_NE(s_server_, nullptr) << "Shared server not running"; + port_ = s_port_; + } + + std::string server_log_path() const { return s_server_->log_path(); } + + static std::unique_ptr s_server_; + static int s_port_; +}; + +std::unique_ptr DefaultServerTests::s_server_; +int DefaultServerTests::s_port_ = 0; + +// -- Connectivity -- + +TEST_F(DefaultServerTests, ServerAcceptsConnections) +{ + ASSERT_TRUE(s_server_->is_running()); + + GrpcTestLogCapture log_capture; + grpc_client_config_t config; + config.debug_log_callback = log_capture.client_callback(); + + auto client = create_client(config); + ASSERT_NE(client, nullptr) << "Failed to connect to server"; + EXPECT_TRUE(client->is_connected()); + + EXPECT_TRUE(log_capture.client_log_contains("Connecting to")) + << "Expected 
connection log. Logs:\n" + << log_capture.get_client_logs(); + EXPECT_TRUE(log_capture.client_log_contains("Connected successfully")) + << "Expected success log. Logs:\n" + << log_capture.get_client_logs(); +} + +// -- Status / Cancel / Delete on nonexistent jobs -- + +TEST_F(DefaultServerTests, CheckStatusNotFound) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + auto status = client->check_status("nonexistent-job-id"); + EXPECT_TRUE(status.success); + EXPECT_EQ(status.status, job_status_t::NOT_FOUND); +} + +TEST_F(DefaultServerTests, CancelNonexistentJob) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + auto result = client->cancel_job("nonexistent-job-id"); + EXPECT_EQ(result.job_status, job_status_t::NOT_FOUND); +} + +TEST_F(DefaultServerTests, DeleteNonexistentJob) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + bool deleted = client->delete_job("nonexistent-job-id"); + EXPECT_FALSE(deleted); + EXPECT_FALSE(client->get_last_error().empty()); +} + +TEST_F(DefaultServerTests, StreamLogsNotFound) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + bool callback_called = false; + bool result = + client->stream_logs("nonexistent-job-id", 0, [&callback_called](const std::string&, bool) { + callback_called = true; + return true; + }); + + EXPECT_FALSE(callback_called); + EXPECT_FALSE(result); +} + +TEST_F(DefaultServerTests, GetResultNonexistentJob) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + auto result = client->get_lp_result("nonexistent-job-12345"); + EXPECT_FALSE(result.success); +} + +// -- LP Solves -- + +TEST_F(DefaultServerTests, SolveLPPolling) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto submit_result = client->submit_lp(problem, settings); + 
ASSERT_TRUE(submit_result.success) << submit_result.error_message; + EXPECT_FALSE(submit_result.job_id.empty()); + + job_status_t final_status = job_status_t::QUEUED; + for (int i = 0; i < 60; ++i) { + auto status = client->check_status(submit_result.job_id); + ASSERT_TRUE(status.success) << status.error_message; + final_status = status.status; + if (final_status == job_status_t::COMPLETED || final_status == job_status_t::FAILED) break; + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + + EXPECT_EQ(final_status, job_status_t::COMPLETED) + << "Status: " << job_status_to_string(final_status); + + auto result = client->get_lp_result(submit_result.job_id); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); +} + +TEST_F(DefaultServerTests, SolveLPWaitRPC) +{ + grpc_client_config_t config; + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); +} + +TEST_F(DefaultServerTests, SolveInfeasibleLP) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + cpu_optimization_problem_t problem; + std::vector var_lb = {1.0}; + std::vector var_ub = {0.0}; + std::vector obj = {1.0}; + std::vector offsets = {0}; + + problem.set_variable_lower_bounds(var_lb.data(), 1); + problem.set_variable_upper_bounds(var_ub.data(), 1); + problem.set_objective_coefficients(obj.data(), 1); + problem.set_maximize(false); + problem.set_csr_constraint_matrix(nullptr, 0, nullptr, 0, offsets.data(), 1); + problem.set_constraint_lower_bounds(nullptr, 0); + 
problem.set_constraint_upper_bounds(nullptr, 0); + + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + ASSERT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + auto status = result.solution->get_termination_status(); + EXPECT_NE(status, pdlp_termination_status_t::Optimal) + << "Expected non-optimal termination for infeasible problem"; +} + +// -- MIP Solve -- + +TEST_F(DefaultServerTests, SolveMIPBlocking) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + auto problem = create_simple_mip(); + + mip_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_EQ(result.solution->get_termination_status(), mip_termination_status_t::Optimal); + EXPECT_NEAR(result.solution->get_objective_value(), 1.0, 0.01); +} + +// -- Explicit Async LP Flow (submit/poll/get/delete) -- + +TEST_F(DefaultServerTests, ExplicitAsyncLPFlow) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto submit_result = client->submit_lp(problem, settings); + ASSERT_TRUE(submit_result.success) << submit_result.error_message; + ASSERT_FALSE(submit_result.job_id.empty()); + std::string job_id = submit_result.job_id; + + wait_for_job_done(client.get(), job_id, 30); + + auto result = client->get_lp_result(job_id); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); + + bool deleted = client->delete_job(job_id); + EXPECT_TRUE(deleted); +} + +// -- Log Verification -- + +TEST_F(DefaultServerTests, 
ServerLogsJobProcessing) +{ + GrpcTestLogCapture log_capture; + log_capture.set_server_log_path(server_log_path()); + log_capture.mark_test_start(); + + auto client = create_client(); + ASSERT_NE(client, nullptr); + + auto problem = create_simple_mip(); + mip_solver_settings_t settings; + settings.time_limit = 10.0; + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + EXPECT_TRUE(log_capture.server_log_contains("[Worker")) + << "Expected worker logs. Server log: " << server_log_path(); +} + +TEST_F(DefaultServerTests, ClientDebugLogsSubmission) +{ + GrpcTestLogCapture log_capture; + grpc_client_config_t config; + config.debug_log_callback = log_capture.client_callback(); + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + + EXPECT_TRUE(log_capture.client_log_contains("submit")) << "Expected submit log. Logs:\n" + << log_capture.get_client_logs(); + EXPECT_TRUE(log_capture.client_log_contains_pattern("job_id=[a-f0-9-]+")) + << "Expected job_id pattern. 
Logs:\n" + << log_capture.get_client_logs(); +} + +// -- Multiple & Concurrent Solves -- + +TEST_F(DefaultServerTests, MultipleSequentialSolves) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + for (int i = 0; i < 3; ++i) { + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << "Solve #" << i << " failed: " << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); + } +} + +TEST_F(DefaultServerTests, ConcurrentJobSubmission) +{ + auto client1 = create_client(); + auto client2 = create_client(); + ASSERT_NE(client1, nullptr); + ASSERT_NE(client2, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + std::vector> jobs; + + auto s1 = client1->submit_lp(problem, settings); + ASSERT_TRUE(s1.success); + jobs.push_back({client1.get(), s1.job_id}); + + auto s2 = client2->submit_lp(problem, settings); + ASSERT_TRUE(s2.success); + jobs.push_back({client2.get(), s2.job_id}); + + auto s3 = client1->submit_lp(problem, settings); + ASSERT_TRUE(s3.success); + jobs.push_back({client1.get(), s3.job_id}); + + std::vector completed(3, false); + int completed_count = 0; + + for (int poll = 0; poll < 120 && completed_count < 3; ++poll) { + for (size_t i = 0; i < jobs.size(); ++i) { + if (completed[i]) continue; + auto status = jobs[i].first->check_status(jobs[i].second); + ASSERT_TRUE(status.success); + if (status.status == job_status_t::COMPLETED) { + completed[i] = true; + completed_count++; + } else if (status.status == job_status_t::FAILED) { + FAIL() << "Job " << i << " failed: " << status.message; + } + } + if (completed_count < 3) { 
std::this_thread::sleep_for(std::chrono::milliseconds(500)); } + } + + ASSERT_EQ(completed_count, 3) << "Not all jobs completed in time"; + + for (size_t i = 0; i < jobs.size(); ++i) { + auto result = jobs[i].first->get_lp_result(jobs[i].second); + EXPECT_TRUE(result.success); + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); + jobs[i].first->delete_job(jobs[i].second); + } +} + +// -- Unary Path Verification -- + +TEST_F(DefaultServerTests, VerifyUnaryUploadSmallProblem) +{ + GrpcTestLogCapture log_capture; + grpc_client_config_t config; + config.debug_log_callback = log_capture.client_callback(); + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + + EXPECT_TRUE(log_capture.client_log_contains("Unary submit succeeded")) + << "Logs:\n" + << log_capture.get_client_logs(); + EXPECT_FALSE(log_capture.client_log_contains("Starting streaming upload")) + << "Logs:\n" + << log_capture.get_client_logs(); +} + +TEST_F(DefaultServerTests, VerifyUnaryDownloadSmallResult) +{ + GrpcTestLogCapture log_capture; + grpc_client_config_t config; + config.debug_log_callback = log_capture.client_callback(); + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + + EXPECT_TRUE(log_capture.client_log_contains("Attempting unary GetResult")) + << "Logs:\n" + << log_capture.get_client_logs(); + 
EXPECT_TRUE(log_capture.client_log_contains("Unary GetResult succeeded")) + << "Logs:\n" + << log_capture.get_client_logs(); +} + +TEST_F(DefaultServerTests, SolveLPReturnsWarmStartData) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + + EXPECT_TRUE(result.solution->has_warm_start_data()) + << "LP solution should contain PDLP warm start data"; + + const auto& ws = result.solution->get_cpu_pdlp_warm_start_data(); + + EXPECT_FALSE(ws.current_primal_solution_.empty()) + << "current_primal_solution should be populated"; + EXPECT_FALSE(ws.current_dual_solution_.empty()) << "current_dual_solution should be populated"; + EXPECT_FALSE(ws.initial_primal_average_.empty()) << "initial_primal_average should be populated"; + EXPECT_FALSE(ws.initial_dual_average_.empty()) << "initial_dual_average should be populated"; + EXPECT_FALSE(ws.current_ATY_.empty()) << "current_ATY should be populated"; + EXPECT_FALSE(ws.sum_primal_solutions_.empty()) << "sum_primal_solutions should be populated"; + EXPECT_FALSE(ws.sum_dual_solutions_.empty()) << "sum_dual_solutions should be populated"; + EXPECT_FALSE(ws.last_restart_duality_gap_primal_solution_.empty()) + << "last_restart_duality_gap_primal_solution should be populated"; + EXPECT_FALSE(ws.last_restart_duality_gap_dual_solution_.empty()) + << "last_restart_duality_gap_dual_solution should be populated"; + + EXPECT_GT(ws.initial_primal_weight_, 0.0) << "initial_primal_weight should be positive"; + EXPECT_GT(ws.initial_step_size_, 0.0) << "initial_step_size should be positive"; + EXPECT_GE(ws.total_pdlp_iterations_, 0) << "total_pdlp_iterations should be non-negative"; + 
EXPECT_GE(ws.total_pdhg_iterations_, 0) << "total_pdhg_iterations should be non-negative"; +} + +// -- MIP Log Callback -- + +TEST_F(DefaultServerTests, SolveMIPWithLogCallback) +{ + std::vector received_logs; + std::mutex log_mutex; + + grpc_client_config_t config; + config.timeout_seconds = 30; + config.stream_logs = true; + config.log_callback = [&](const std::string& line) { + std::lock_guard lock(log_mutex); + received_logs.push_back(line); + }; + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("bb_optimality.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 10.0; + settings.log_to_console = true; + + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; +} + +// -- Incumbent Callbacks -- + +TEST_F(DefaultServerTests, IncumbentCallbacksMIP) +{ + std::vector incumbent_objectives; + std::mutex incumbent_mutex; + + grpc_client_config_t config; + config.timeout_seconds = 30; + config.incumbent_callback = [&](int64_t, double objective, const std::vector&) { + std::lock_guard lock(incumbent_mutex); + incumbent_objectives.push_back(objective); + return true; + }; + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("neos5-free-bound.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_mip(problem, settings, true); + EXPECT_TRUE(result.success) << result.error_message; + + if (incumbent_objectives.size() > 1) { + for (size_t i = 1; i < incumbent_objectives.size(); ++i) { + EXPECT_LE(incumbent_objectives[i], incumbent_objectives[i - 1] + 1e-6); + } + } +} + +TEST_F(DefaultServerTests, IncumbentCallbackCancelsSolve) +{ + int callback_count = 0; + + grpc_client_config_t config; + config.timeout_seconds = 30; + 
config.incumbent_callback = [&](int64_t, double, const std::vector&) { + return ++callback_count < 2; + }; + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("neos5-free-bound.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 30.0; + + auto start = std::chrono::steady_clock::now(); + auto result = client->solve_mip(problem, settings, true); + auto elapsed = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + + EXPECT_LT(elapsed.count(), 25) << "Solve should have cancelled early"; +} + +// -- Cancel Running Job -- + +TEST_F(DefaultServerTests, CancelRunningJob) +{ + auto client = create_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("neos5-free-bound.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 120.0; + + auto submit_result = client->submit_mip(problem, settings); + ASSERT_TRUE(submit_result.success); + std::string job_id = submit_result.job_id; + + std::this_thread::sleep_for(std::chrono::seconds(2)); + + auto cancel_result = client->cancel_job(job_id); + EXPECT_TRUE(cancel_result.job_status == job_status_t::CANCELLED || + cancel_result.job_status == job_status_t::COMPLETED || + cancel_result.job_status == job_status_t::PROCESSING || + cancel_result.job_status == job_status_t::FAILED) + << "Unexpected job_status=" << static_cast(cancel_result.job_status) + << " message=" << cancel_result.message; + + // Wait for worker to free up before next test + wait_for_job_done(client.get(), job_id, 15); + client->delete_job(job_id); +} + +// ============================================================================= +// Chunked Upload Tests (--max-message-mb 256) +// ============================================================================= + +class ChunkedUploadTests : public GrpcIntegrationTestBase { + 
protected: + static void SetUpTestSuite() + { + s_port_ = get_test_port(); + s_server_ = std::make_unique<ServerProcess>(); + ASSERT_TRUE(s_server_->start(s_port_, {"--max-message-mb", "256"})) + << "Failed to start chunked upload server"; + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + ASSERT_NE(s_server_, nullptr); + port_ = s_port_; + } + + void TearDown() override + { + if (HasFailure() && s_server_) { + std::string log_file = s_server_->log_path(); + if (!log_file.empty()) { + std::ifstream f(log_file); + if (f) { + std::cerr << "\n=== Server log (" << log_file << ") ===\n" + << f.rdbuf() << "\n=== End server log ===\n"; + } + } + } + } + + static std::unique_ptr<ServerProcess> s_server_; + static int s_port_; +}; + +std::unique_ptr<ServerProcess> ChunkedUploadTests::s_server_; +int ChunkedUploadTests::s_port_ = 0; + +TEST_F(ChunkedUploadTests, ChunkedUploadLP) +{ + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunk_size_bytes = 8 * 1024; + config.chunked_array_threshold_bytes = 0; // Force chunked upload + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); +} + +TEST_F(ChunkedUploadTests, ChunkedUploadMIP) +{ + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunk_size_bytes = 4 * 1024; + config.chunked_array_threshold_bytes = 0; // Force chunked upload + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("sudoku.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t
settings; + settings.time_limit = 10.0; + + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; +} + +TEST_F(ChunkedUploadTests, ConcurrentChunkedUploads) +{ + const int num_clients = 3; + std::vector> clients; + + for (int i = 0; i < num_clients; ++i) { + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunk_size_bytes = 4 * 1024; + config.chunked_array_threshold_bytes = 0; + auto client = create_client(config); + ASSERT_NE(client, nullptr); + clients.push_back(std::move(client)); + } + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + std::atomic success_count{0}; + + auto solve_task = [&](int idx) -> bool { + auto result = clients[idx]->solve_lp(problem, settings); + if (result.success && result.solution && + std::abs(result.solution->get_objective_value() - (-464.753)) < 1.0) { + success_count++; + return true; + } + return false; + }; + + std::vector> futures; + for (int i = 0; i < num_clients; ++i) { + futures.push_back(std::async(std::launch::async, solve_task, i)); + } + + for (int i = 0; i < num_clients; ++i) { + EXPECT_TRUE(futures[i].get()) << "Client " << i << " failed"; + } + + EXPECT_EQ(success_count.load(), num_clients); +} + +TEST_F(ChunkedUploadTests, UnaryFallbackSmallProblem) +{ + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunked_array_threshold_bytes = 100 * 1024 * 1024; // 100 MiB, well above afiro size + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + 
EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); +} + +// ============================================================================= +// Path Selection Tests (unary vs chunked IPC and result retrieval) +// +// Uses --max-message-bytes to set a very low logical threshold so that even +// small test problems exercise both unary and chunked code paths. +// Uses --verbose so the server emits IPC path tags we can verify in logs. +// ============================================================================= + +class PathSelectionTests : public GrpcIntegrationTestBase { + protected: + static void SetUpTestSuite() + { + s_port_ = get_test_port(); + s_server_ = std::make_unique(); + // Small threshold (clamped to 4 KiB) forces chunked result downloads for + // anything larger than ~4 KB, exercising the chunked download path. + ASSERT_TRUE(s_server_->start(s_port_, {"--max-message-bytes", "4096", "--verbose"})) + << "Failed to start path-selection server"; + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + ASSERT_NE(s_server_, nullptr); + port_ = s_port_; + } + + std::string server_log_path() const { return s_server_->log_path(); } + + static std::unique_ptr s_server_; + static int s_port_; +}; + +std::unique_ptr PathSelectionTests::s_server_; +int PathSelectionTests::s_port_ = 0; + +// Unary upload for a small LP (afiro). The result is small enough that +// the server returns it via unary GetResult. We verify the upload and +// result paths in the server logs but don't assert the download method +// since the result size may or may not exceed the 4 KiB threshold. 
+TEST_F(PathSelectionTests, UnaryUploadLPWithPathLogging) +{ + GrpcTestLogCapture log_capture; + log_capture.set_server_log_path(server_log_path()); + log_capture.mark_test_start(); + + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunked_array_threshold_bytes = 100 * 1024 * 1024; // high threshold => unary upload + config.debug_log_callback = log_capture.client_callback(); + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + // Worker should have received via the UNARY path + EXPECT_TRUE(log_capture.wait_for_server_log("[Worker] IPC path: UNARY LP", 5000)) + << "Expected UNARY LP path in server log.\nServer log:\n" + << log_capture.get_server_logs(); + + // Worker should have serialized the result + EXPECT_TRUE(log_capture.server_log_contains("[Worker] Result path: LP solution")) + << "Expected LP result path in server log.\nServer log:\n" + << log_capture.get_server_logs(); +} + +// Chunked upload, verify server receives via CHUNKED path +TEST_F(PathSelectionTests, ChunkedUploadLPWithPathLogging) +{ + GrpcTestLogCapture log_capture; + log_capture.set_server_log_path(server_log_path()); + log_capture.mark_test_start(); + + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunk_size_bytes = 4 * 1024; + config.chunked_array_threshold_bytes = 0; // force chunked upload + config.debug_log_callback = log_capture.client_callback(); + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = 
get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + // Worker should have received via the CHUNKED path + EXPECT_TRUE(log_capture.wait_for_server_log("[Worker] IPC path: CHUNKED", 5000)) + << "Expected CHUNKED path in server log.\nServer log:\n" + << log_capture.get_server_logs(); + + // Server main process should have logged FinishChunkedUpload + EXPECT_TRUE(log_capture.server_log_contains("FinishChunkedUpload: CHUNKED path")) + << "Expected FinishChunkedUpload log.\nServer log:\n" + << log_capture.get_server_logs(); +} + +// Chunked upload + chunked result download for MIP. +// sudoku.mps produces ~5.8 KB result which exceeds the 4 KB threshold, +// so the client should use chunked download. 
+TEST_F(PathSelectionTests, ChunkedUploadAndChunkedDownloadMIP) +{ + GrpcTestLogCapture log_capture; + log_capture.set_server_log_path(server_log_path()); + log_capture.mark_test_start(); + + grpc_client_config_t config; + config.timeout_seconds = 60; + config.chunk_size_bytes = 4 * 1024; + config.chunked_array_threshold_bytes = 0; // force chunked upload + config.debug_log_callback = log_capture.client_callback(); + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("sudoku.mps"); + auto problem = load_problem_from_mps(mps_path); + mip_solver_settings_t settings; + settings.time_limit = 30.0; + + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + // Upload should have gone through the CHUNKED path + EXPECT_TRUE(log_capture.wait_for_server_log("[Worker] IPC path: CHUNKED", 5000)) + << "Expected CHUNKED upload path in server log.\nServer log:\n" + << log_capture.get_server_logs(); + + // Client should have used chunked download (result > 4096 bytes) + EXPECT_TRUE(log_capture.client_log_contains("chunked download") || + log_capture.client_log_contains("ChunkedDownload")) + << "Expected chunked download in client log.\nClient log:\n" + << log_capture.get_client_logs(); + + // Server should log CHUNKED response + EXPECT_TRUE(log_capture.wait_for_server_log("StartChunkedDownload: CHUNKED response", 5000)) + << "Expected chunked download path in server log.\nServer log:\n" + << log_capture.get_server_logs(); +} + +// MIP path: unary upload, verify UNARY MIP tag +TEST_F(PathSelectionTests, UnaryUploadMIPWithPathLogging) +{ + GrpcTestLogCapture log_capture; + log_capture.set_server_log_path(server_log_path()); + log_capture.mark_test_start(); + + grpc_client_config_t config; + config.timeout_seconds = 60; + config.debug_log_callback = log_capture.client_callback(); + + auto client 
= create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("bb_optimality.mps"); + auto problem = load_problem_from_mps(mps_path); + mip_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result.success) << result.error_message; + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + EXPECT_TRUE(log_capture.wait_for_server_log("[Worker] IPC path: UNARY MIP", 5000)) + << "Expected UNARY MIP path in server log.\nServer log:\n" + << log_capture.get_server_logs(); + + EXPECT_TRUE(log_capture.server_log_contains("[Worker] Result path: MIP solution")) + << "Expected MIP result path in server log.\nServer log:\n" + << log_capture.get_server_logs(); +} + +// ============================================================================= +// Error Recovery Tests (per-test server lifecycle) +// ============================================================================= + +class ErrorRecoveryTests : public GrpcIntegrationTestBase { + protected: + void SetUp() override { port_ = get_test_port(); } + void TearDown() override { server_.stop(); } + + bool start_server(const std::vector& extra_args = {}) + { + return server_.start(port_, extra_args); + } + + ServerProcess server_; +}; + +TEST_F(ErrorRecoveryTests, ClientReconnectsAfterServerRestart) +{ + ASSERT_TRUE(start_server()); + auto client = create_client(); + ASSERT_NE(client, nullptr); + + auto status_before = client->check_status("test-job"); + EXPECT_TRUE(status_before.success); + + server_.stop(); + EXPECT_FALSE(server_.is_running()); + + auto status_down = client->check_status("test-job"); + EXPECT_FALSE(status_down.success); + + ASSERT_TRUE(start_server()); + + auto status_after = client->check_status("test-job"); + EXPECT_TRUE(status_after.success) << "Should auto-reconnect: " << status_after.error_message; +} + +TEST_F(ErrorRecoveryTests, ClientHandlesServerCrashDuringSolve) +{ + 
ASSERT_TRUE(start_server()); + auto client = create_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("neos5-free-bound.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 120.0; + + auto submit_result = client->submit_mip(problem, settings); + ASSERT_TRUE(submit_result.success); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + server_.stop(); + + auto status_result = client->check_status(submit_result.job_id); + EXPECT_FALSE(status_result.success); + EXPECT_FALSE(status_result.error_message.empty()); +} + +TEST_F(ErrorRecoveryTests, ClientTimeoutConfiguration) +{ + ASSERT_TRUE(start_server()); + + grpc_client_config_t config; + config.timeout_seconds = 1; + config.poll_interval_ms = 100; + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("neos5-free-bound.mps"); + auto problem = load_problem_from_mps(mps_path); + + mip_solver_settings_t settings; + settings.time_limit = 60.0; + + auto submit_result = client->submit_mip(problem, settings); + ASSERT_TRUE(submit_result.success); + + auto start = std::chrono::steady_clock::now(); + bool completed = false; + while (!completed) { + auto elapsed = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + if (elapsed.count() >= config.timeout_seconds) break; + auto status = client->check_status(submit_result.job_id); + if (status.status == job_status_t::COMPLETED || status.status == job_status_t::FAILED) { + completed = true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + EXPECT_FALSE(completed) << "Complex MIP should not complete in 1 second"; + client->cancel_job(submit_result.job_id); +} + +TEST_F(ErrorRecoveryTests, ChunkedUploadAfterServerRestart) +{ + ASSERT_TRUE(start_server({"--max-message-mb", "256"})); + + grpc_client_config_t config; + config.timeout_seconds = 30; + 
config.chunk_size_bytes = 4 * 1024; + config.chunked_array_threshold_bytes = 0; + + auto client = create_client(config); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_mip_path("sudoku.mps"); + auto problem = load_problem_from_mps(mps_path); + mip_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result1 = client->solve_mip(problem, settings, false); + EXPECT_TRUE(result1.success) << result1.error_message; + + server_.stop(); + ASSERT_TRUE(start_server({"--max-message-mb", "256"})); + + auto client2 = create_client(config); + ASSERT_NE(client2, nullptr); + + auto result2 = client2->solve_mip(problem, settings, false); + EXPECT_TRUE(result2.success) << result2.error_message; +} + +// ============================================================================= +// TLS Tests +// ============================================================================= + +class TlsServerTests : public GrpcIntegrationTestBase { + protected: + static void SetUpTestSuite() + { + if (!ensure_test_certs()) { + s_certs_available_ = false; + return; + } + + s_certs_available_ = std::filesystem::exists(g_tls_certs_dir + "/server.crt") && + std::filesystem::exists(g_tls_certs_dir + "/server.key") && + std::filesystem::exists(g_tls_certs_dir + "/ca.crt"); + + if (!s_certs_available_) return; + + s_port_ = get_test_port(); + s_server_ = std::make_unique(); + + std::string root_certs = read_file_contents(g_tls_certs_dir + "/ca.crt"); + s_server_->set_tls_config(root_certs); + + std::vector args = {"--tls", + "--tls-cert", + g_tls_certs_dir + "/server.crt", + "--tls-key", + g_tls_certs_dir + "/server.key", + "--tls-root", + g_tls_certs_dir + "/ca.crt", + "--enable-transfer-hash"}; + + if (!s_server_->start(s_port_, args)) { + s_server_.reset(); + s_certs_available_ = false; + } + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + if (!s_certs_available_) { GTEST_SKIP() << 
"TLS certificates not available"; } + ASSERT_NE(s_server_, nullptr) << "TLS server not running"; + port_ = s_port_; + } + + std::unique_ptr create_tls_client() + { + grpc_client_config_t config; + config.server_address = "localhost:" + std::to_string(port_); + config.timeout_seconds = 30; + config.enable_tls = true; + config.tls_root_certs = read_file_contents(g_tls_certs_dir + "/ca.crt"); + + auto client = std::make_unique(config); + if (!client->connect()) return nullptr; + return client; + } + + static std::unique_ptr s_server_; + static int s_port_; + static bool s_certs_available_; +}; + +std::unique_ptr TlsServerTests::s_server_; +int TlsServerTests::s_port_ = 0; +bool TlsServerTests::s_certs_available_ = false; + +TEST_F(TlsServerTests, BasicConnection) +{ + auto client = create_tls_client(); + ASSERT_NE(client, nullptr) << "Failed to connect with TLS"; + EXPECT_TRUE(client->is_connected()); +} + +TEST_F(TlsServerTests, SolveLP) +{ + auto client = create_tls_client(); + ASSERT_NE(client, nullptr); + + std::string mps_path = get_test_lp_path("afiro_original.mps"); + auto problem = load_problem_from_mps(mps_path); + pdlp_solver_settings_t settings; + settings.time_limit = 10.0; + + auto result = client->solve_lp(problem, settings); + EXPECT_TRUE(result.success) << result.error_message; + ASSERT_NE(result.solution, nullptr); + EXPECT_NEAR(result.solution->get_objective_value(), -464.753, 1.0); +} + +// ============================================================================= +// mTLS Tests +// ============================================================================= + +class MtlsServerTests : public GrpcIntegrationTestBase { + protected: + static void SetUpTestSuite() + { + if (!ensure_test_certs()) { + s_certs_available_ = false; + return; + } + + s_certs_available_ = std::filesystem::exists(g_tls_certs_dir + "/client.crt") && + std::filesystem::exists(g_tls_certs_dir + "/client.key") && + std::filesystem::exists(g_tls_certs_dir + "/server.crt") && + 
std::filesystem::exists(g_tls_certs_dir + "/ca.crt"); + + if (!s_certs_available_) return; + + s_port_ = get_test_port(); + s_server_ = std::make_unique(); + + std::string root_certs = read_file_contents(g_tls_certs_dir + "/ca.crt"); + std::string client_cert = read_file_contents(g_tls_certs_dir + "/client.crt"); + std::string client_key = read_file_contents(g_tls_certs_dir + "/client.key"); + s_server_->set_tls_config(root_certs, client_cert, client_key); + + std::vector args = {"--tls", + "--tls-cert", + g_tls_certs_dir + "/server.crt", + "--tls-key", + g_tls_certs_dir + "/server.key", + "--tls-root", + g_tls_certs_dir + "/ca.crt", + "--require-client-cert", + "--enable-transfer-hash"}; + + if (!s_server_->start(s_port_, args)) { + s_server_.reset(); + s_certs_available_ = false; + } + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + if (!s_certs_available_) { GTEST_SKIP() << "mTLS certificates not available"; } + ASSERT_NE(s_server_, nullptr) << "mTLS server not running"; + port_ = s_port_; + } + + std::unique_ptr create_mtls_client(bool with_client_cert = true) + { + grpc_client_config_t config; + config.server_address = "localhost:" + std::to_string(port_); + config.timeout_seconds = 30; + config.enable_tls = true; + config.tls_root_certs = read_file_contents(g_tls_certs_dir + "/ca.crt"); + + if (with_client_cert) { + config.tls_client_cert = read_file_contents(g_tls_certs_dir + "/client.crt"); + config.tls_client_key = read_file_contents(g_tls_certs_dir + "/client.key"); + } + + auto client = std::make_unique(config); + if (!client->connect()) return nullptr; + return client; + } + + static std::unique_ptr s_server_; + static int s_port_; + static bool s_certs_available_; +}; + +std::unique_ptr MtlsServerTests::s_server_; +int MtlsServerTests::s_port_ = 0; +bool MtlsServerTests::s_certs_available_ = false; + +TEST_F(MtlsServerTests, ConnectionWithClientCert) +{ + auto client 
= create_mtls_client(true); + ASSERT_NE(client, nullptr) << "Failed to connect with mTLS"; + EXPECT_TRUE(client->is_connected()); +} + +TEST_F(MtlsServerTests, RejectsClientWithoutCert) +{ + auto client = create_mtls_client(false); + EXPECT_EQ(client, nullptr) << "Server should reject client without certificate"; +} + +// ============================================================================= +// Chunk Validation Tests +// +// Uses a raw gRPC stub to send malformed chunk requests and verify the server +// rejects them with appropriate error codes. Exercises items 1-8 from the +// chunked transfer hardening work. +// ============================================================================= + +class ChunkValidationTests : public GrpcIntegrationTestBase { + protected: + static void SetUpTestSuite() + { + s_port_ = get_test_port(); + s_server_ = std::make_unique(); + ASSERT_TRUE(s_server_->start(s_port_, {"--verbose"})) + << "Failed to start chunk validation server"; + } + + static void TearDownTestSuite() + { + if (s_server_) s_server_->stop(); + s_server_.reset(); + } + + void SetUp() override + { + ASSERT_NE(s_server_, nullptr); + port_ = s_port_; + + auto channel = + grpc::CreateChannel("localhost:" + std::to_string(port_), grpc::InsecureChannelCredentials()); + stub_ = cuopt::remote::CuOptRemoteService::NewStub(channel); + } + + std::string start_upload() + { + grpc::ClientContext ctx; + cuopt::remote::StartChunkedUploadRequest req; + auto* hdr = req.mutable_problem_header()->mutable_header(); + hdr->set_version(1); + hdr->set_problem_type(cuopt::remote::LP); + cuopt::remote::StartChunkedUploadResponse resp; + auto status = stub_->StartChunkedUpload(&ctx, req, &resp); + EXPECT_TRUE(status.ok()) << status.error_message(); + return resp.upload_id(); + } + + grpc::Status send_chunk(const std::string& upload_id, + cuopt::remote::ArrayFieldId field_id, + int64_t element_offset, + int64_t total_elements, + const std::string& data) + { + grpc::ClientContext 
ctx; + cuopt::remote::SendArrayChunkRequest req; + req.set_upload_id(upload_id); + auto* ac = req.mutable_chunk(); + ac->set_field_id(field_id); + ac->set_element_offset(element_offset); + ac->set_total_elements(total_elements); + ac->set_data(data); + cuopt::remote::SendArrayChunkResponse resp; + return stub_->SendArrayChunk(&ctx, req, &resp); + } + + std::unique_ptr<cuopt::remote::CuOptRemoteService::Stub> stub_; + static std::unique_ptr<ServerProcess> s_server_; + static int s_port_; +}; + +std::unique_ptr<ServerProcess> ChunkValidationTests::s_server_; +int ChunkValidationTests::s_port_ = 0; + +TEST_F(ChunkValidationTests, RejectsNegativeElementOffset) +{ + auto uid = start_upload(); + std::string data(8, '\0'); // 1 double + auto status = send_chunk(uid, cuopt::remote::FIELD_C, -1, 10, data); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("non-negative")); +} + +TEST_F(ChunkValidationTests, RejectsNegativeTotalElements) +{ + auto uid = start_upload(); + std::string data(8, '\0'); + auto status = send_chunk(uid, cuopt::remote::FIELD_C, 0, -5, data); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("non-negative")); +} + +TEST_F(ChunkValidationTests, RejectsHugeTotalElements) +{ + auto uid = start_upload(); + std::string data(8, '\0'); + auto status = send_chunk(uid, cuopt::remote::FIELD_C, 0, int64_t(1) << 60, data); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::RESOURCE_EXHAUSTED); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("too large")); +} + +TEST_F(ChunkValidationTests, RejectsInvalidFieldId) +{ + auto uid = start_upload(); + std::string data(8, '\0'); + auto status = send_chunk(uid, static_cast<cuopt::remote::ArrayFieldId>(999), 0, 10, data); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_THAT(status.error_message(),
::testing::HasSubstr("field_id")); +} + +TEST_F(ChunkValidationTests, RejectsUnalignedChunkData) +{ + auto uid = start_upload(); + // First chunk to allocate the array (doubles, elem_size=8) + std::string good_data(80, '\0'); // 10 doubles + auto s1 = send_chunk(uid, cuopt::remote::FIELD_C, 0, 10, good_data); + EXPECT_TRUE(s1.ok()) << s1.error_message(); + + // Send a misaligned chunk (7 bytes, not a multiple of 8) + std::string bad_data(7, '\0'); + auto s2 = send_chunk(uid, cuopt::remote::FIELD_C, 0, 10, bad_data); + EXPECT_FALSE(s2.ok()); + EXPECT_EQ(s2.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_THAT(s2.error_message(), ::testing::HasSubstr("aligned")); +} + +TEST_F(ChunkValidationTests, RejectsOffsetBeyondArraySize) +{ + auto uid = start_upload(); + // Allocate array of 10 doubles + std::string data(80, '\0'); + auto s1 = send_chunk(uid, cuopt::remote::FIELD_C, 0, 10, data); + EXPECT_TRUE(s1.ok()) << s1.error_message(); + + // Offset 100 is way past the 10-element array + std::string small_data(8, '\0'); + auto s2 = send_chunk(uid, cuopt::remote::FIELD_C, 100, 10, small_data); + EXPECT_FALSE(s2.ok()); + EXPECT_EQ(s2.error_code(), grpc::StatusCode::INVALID_ARGUMENT); +} + +TEST_F(ChunkValidationTests, RejectsChunkOverflow) +{ + auto uid = start_upload(); + // Allocate array of 4 doubles (32 bytes) + std::string init_data(32, '\0'); + auto s1 = send_chunk(uid, cuopt::remote::FIELD_C, 0, 4, init_data); + EXPECT_TRUE(s1.ok()) << s1.error_message(); + + // Offset 3 + 2 doubles = writes past end + std::string over_data(16, '\0'); // 2 doubles + auto s2 = send_chunk(uid, cuopt::remote::FIELD_C, 3, 4, over_data); + EXPECT_FALSE(s2.ok()); + EXPECT_EQ(s2.error_code(), grpc::StatusCode::INVALID_ARGUMENT); +} + +TEST_F(ChunkValidationTests, RejectsUnknownUploadId) +{ + std::string data(8, '\0'); + auto status = send_chunk("nonexistent-upload-id", cuopt::remote::FIELD_C, 0, 10, data); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), 
grpc::StatusCode::NOT_FOUND); +} + +TEST_F(ChunkValidationTests, AcceptsValidChunk) +{ + auto uid = start_upload(); + // 10 doubles = 80 bytes + std::string data(80, '\x42'); + auto status = send_chunk(uid, cuopt::remote::FIELD_C, 0, 10, data); + EXPECT_TRUE(status.ok()) << status.error_message(); +} + +} // anonymous namespace + +// ============================================================================= +// Main +// ============================================================================= + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/cpp/tests/linear_programming/grpc/grpc_pipe_serialization_test.cpp b/cpp/tests/linear_programming/grpc/grpc_pipe_serialization_test.cpp new file mode 100644 index 0000000000..5d6b480d9b --- /dev/null +++ b/cpp/tests/linear_programming/grpc/grpc_pipe_serialization_test.cpp @@ -0,0 +1,471 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ + +/** + * @file grpc_pipe_serialization_test.cpp + * @brief Round-trip unit tests for the hybrid pipe serialization format. + * + * Tests write data through a real pipe(2) and read it back, verifying that + * protobuf headers and raw array bytes survive the round trip intact. + * A writer thread is used because pipe buffers are finite; blocking writes + * would deadlock if the reader isn't draining concurrently. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +// write_to_pipe / read_from_pipe are the real implementations from +// grpc_pipe_io.cpp, compiled directly into this test target. +#include "grpc_pipe_serialization.hpp" + +using namespace cuopt::remote; + +// --------------------------------------------------------------------------- +// RAII wrapper for a pipe(2) pair. 
+// --------------------------------------------------------------------------- +class PipePair { + public: + PipePair() + { + if (::pipe(fds_) != 0) { throw std::runtime_error("pipe() failed"); } + } + ~PipePair() + { + if (fds_[0] >= 0) ::close(fds_[0]); + if (fds_[1] >= 0) ::close(fds_[1]); + } + int read_fd() const { return fds_[0]; } + int write_fd() const { return fds_[1]; } + + private: + int fds_[2]{-1, -1}; +}; + +// --------------------------------------------------------------------------- +// Helpers to build test data. +// --------------------------------------------------------------------------- +namespace { + +std::vector<uint8_t> make_pattern(size_t num_bytes, uint8_t seed = 0) +{ + std::vector<uint8_t> v(num_bytes); + for (size_t i = 0; i < num_bytes; ++i) { + v[i] = static_cast<uint8_t>((i + seed) & 0xFF); + } + return v; +} + +ArrayChunk make_whole_chunk(ArrayFieldId field_id, + int64_t total_elements, + const std::vector<uint8_t>& data) +{ + ArrayChunk ac; + ac.set_field_id(field_id); + ac.set_element_offset(0); + ac.set_total_elements(total_elements); + ac.set_data(std::string(reinterpret_cast<const char*>(data.data()), data.size())); + return ac; +} + +ArrayChunk make_partial_chunk(ArrayFieldId field_id, + int64_t element_offset, + int64_t total_elements, + const uint8_t* data, + size_t data_size) +{ + ArrayChunk ac; + ac.set_field_id(field_id); + ac.set_element_offset(element_offset); + ac.set_total_elements(total_elements); + ac.set_data(std::string(reinterpret_cast<const char*>(data), data_size)); + return ac; +} + +} // namespace + +// ============================================================================= +// Chunked request round-trip tests +// ============================================================================= + +TEST(PipeSerialization, ChunkedRequest_SingleChunkPerField) +{ + PipePair pp; + + ChunkedProblemHeader header; + header.set_maximize(true); + header.set_objective_scaling_factor(2.5); + header.set_problem_name("test_lp"); + + // Two fields: FIELD_C (8-byte doubles, 100
elements) and FIELD_A_INDICES (4-byte ints, 50 + // elements) + auto c_data = make_pattern(100 * 8, 0xAA); + auto i_data = make_pattern(50 * 4, 0xBB); + + std::vector<ArrayChunk> chunks; + chunks.push_back(make_whole_chunk(FIELD_C, 100, c_data)); + chunks.push_back(make_whole_chunk(FIELD_A_INDICES, 50, i_data)); + + // Write in a thread (pipe buffer is finite). + bool write_ok = false; + std::thread writer( + [&] { write_ok = write_chunked_request_to_pipe(pp.write_fd(), header, chunks); }); + + ChunkedProblemHeader header_out; + std::map<int32_t, std::vector<uint8_t>> arrays_out; + bool read_ok = read_chunked_request_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + + EXPECT_TRUE(header_out.maximize()); + EXPECT_DOUBLE_EQ(header_out.objective_scaling_factor(), 2.5); + EXPECT_EQ(header_out.problem_name(), "test_lp"); + + ASSERT_EQ(arrays_out.size(), 2u); + EXPECT_EQ(arrays_out[FIELD_C], c_data); + EXPECT_EQ(arrays_out[FIELD_A_INDICES], i_data); +} + +TEST(PipeSerialization, ChunkedRequest_MultiChunkAssembly) +{ + PipePair pp; + + ChunkedProblemHeader header; + header.set_maximize(false); + + // Split a 200-element double array (FIELD_C, 8 bytes each = 1600 bytes) into two chunks.
+ constexpr int64_t total_elements = 200; + constexpr int64_t elem_size = 8; + auto full_data = make_pattern(total_elements * elem_size, 0x42); + + int64_t split = 120; + std::vector chunks; + chunks.push_back(make_partial_chunk( + FIELD_C, 0, total_elements, full_data.data(), static_cast(split * elem_size))); + chunks.push_back(make_partial_chunk(FIELD_C, + split, + total_elements, + full_data.data() + split * elem_size, + static_cast((total_elements - split) * elem_size))); + + bool write_ok = false; + std::thread writer( + [&] { write_ok = write_chunked_request_to_pipe(pp.write_fd(), header, chunks); }); + + ChunkedProblemHeader header_out; + std::map> arrays_out; + bool read_ok = read_chunked_request_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + ASSERT_EQ(arrays_out.size(), 1u); + EXPECT_EQ(arrays_out[FIELD_C], full_data); +} + +TEST(PipeSerialization, ChunkedRequest_EmptyArrays) +{ + PipePair pp; + + ChunkedProblemHeader header; + header.set_problem_name("empty"); + + // A field with total_elements=0 should produce a zero-length array entry. 
+ ArrayChunk empty_chunk; + empty_chunk.set_field_id(FIELD_C); + empty_chunk.set_element_offset(0); + empty_chunk.set_total_elements(0); + empty_chunk.set_data(""); + + std::vector chunks = {empty_chunk}; + + bool write_ok = false; + std::thread writer( + [&] { write_ok = write_chunked_request_to_pipe(pp.write_fd(), header, chunks); }); + + ChunkedProblemHeader header_out; + std::map> arrays_out; + bool read_ok = read_chunked_request_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + EXPECT_EQ(header_out.problem_name(), "empty"); + ASSERT_EQ(arrays_out.size(), 1u); + EXPECT_TRUE(arrays_out[FIELD_C].empty()); +} + +TEST(PipeSerialization, ChunkedRequest_NoChunks) +{ + PipePair pp; + + ChunkedProblemHeader header; + header.set_problem_name("header_only"); + + std::vector chunks; // no chunks at all + + bool write_ok = false; + std::thread writer( + [&] { write_ok = write_chunked_request_to_pipe(pp.write_fd(), header, chunks); }); + + ChunkedProblemHeader header_out; + std::map> arrays_out; + bool read_ok = read_chunked_request_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + EXPECT_EQ(header_out.problem_name(), "header_only"); + EXPECT_TRUE(arrays_out.empty()); +} + +TEST(PipeSerialization, ChunkedRequest_ManyFields) +{ + PipePair pp; + + ChunkedProblemHeader header; + header.set_maximize(true); + + // Build one whole chunk per field for several different field types. 
+ struct TestField { + ArrayFieldId id; + int64_t elements; + }; + std::vector test_fields = { + {FIELD_A_VALUES, 500}, + {FIELD_A_INDICES, 500}, + {FIELD_A_OFFSETS, 101}, + {FIELD_C, 100}, + {FIELD_VARIABLE_LOWER_BOUNDS, 100}, + {FIELD_VARIABLE_UPPER_BOUNDS, 100}, + {FIELD_CONSTRAINT_LOWER_BOUNDS, 100}, + {FIELD_CONSTRAINT_UPPER_BOUNDS, 100}, + }; + + std::map> expected; + std::vector chunks; + for (size_t i = 0; i < test_fields.size(); ++i) { + auto& tf = test_fields[i]; + int64_t es = array_field_element_size(tf.id); + auto data = make_pattern(static_cast(tf.elements * es), static_cast(i)); + expected[static_cast(tf.id)] = data; + chunks.push_back(make_whole_chunk(tf.id, tf.elements, data)); + } + + bool write_ok = false; + std::thread writer( + [&] { write_ok = write_chunked_request_to_pipe(pp.write_fd(), header, chunks); }); + + ChunkedProblemHeader header_out; + std::map> arrays_out; + bool read_ok = read_chunked_request_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + ASSERT_EQ(arrays_out.size(), expected.size()); + for (const auto& [fid, data] : expected) { + ASSERT_TRUE(arrays_out.count(fid)) << "Missing field_id " << fid; + EXPECT_EQ(arrays_out[fid], data) << "Mismatch for field_id " << fid; + } +} + +// ============================================================================= +// Result round-trip tests +// ============================================================================= + +TEST(PipeSerialization, Result_RoundTrip) +{ + PipePair pp; + + ChunkedResultHeader header; + header.set_is_mip(false); + header.set_lp_termination_status(PDLP_OPTIMAL); + header.set_primal_objective(42.5); + header.set_solve_time(1.23); + + // Two result arrays: primal solution and dual solution. 
+ auto primal = make_pattern(1000 * 8, 0x11); + auto dual = make_pattern(500 * 8, 0x22); + + std::map> arrays; + arrays[RESULT_PRIMAL_SOLUTION] = primal; + arrays[RESULT_DUAL_SOLUTION] = dual; + + bool write_ok = false; + std::thread writer([&] { write_ok = write_result_to_pipe(pp.write_fd(), header, arrays); }); + + ChunkedResultHeader header_out; + std::map> arrays_out; + bool read_ok = read_result_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + + EXPECT_FALSE(header_out.is_mip()); + EXPECT_EQ(header_out.lp_termination_status(), PDLP_OPTIMAL); + EXPECT_DOUBLE_EQ(header_out.primal_objective(), 42.5); + EXPECT_DOUBLE_EQ(header_out.solve_time(), 1.23); + + ASSERT_EQ(arrays_out.size(), 2u); + EXPECT_EQ(arrays_out[RESULT_PRIMAL_SOLUTION], primal); + EXPECT_EQ(arrays_out[RESULT_DUAL_SOLUTION], dual); +} + +TEST(PipeSerialization, Result_MIPFields) +{ + PipePair pp; + + ChunkedResultHeader header; + header.set_is_mip(true); + header.set_mip_termination_status(MIP_OPTIMAL); + header.set_mip_objective(99.0); + header.set_mip_gap(0.001); + header.set_error_message(""); + + auto solution = make_pattern(2000 * 8, 0x33); + std::map> arrays; + arrays[RESULT_MIP_SOLUTION] = solution; + + bool write_ok = false; + std::thread writer([&] { write_ok = write_result_to_pipe(pp.write_fd(), header, arrays); }); + + ChunkedResultHeader header_out; + std::map> arrays_out; + bool read_ok = read_result_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + + EXPECT_TRUE(header_out.is_mip()); + EXPECT_EQ(header_out.mip_termination_status(), MIP_OPTIMAL); + EXPECT_DOUBLE_EQ(header_out.mip_objective(), 99.0); + + ASSERT_EQ(arrays_out.size(), 1u); + EXPECT_EQ(arrays_out[RESULT_MIP_SOLUTION], solution); +} + +TEST(PipeSerialization, Result_EmptyArrays) +{ + PipePair pp; + + ChunkedResultHeader header; + header.set_is_mip(false); + 
header.set_error_message("solver failed"); + + std::map<int32_t, std::vector<uint8_t>> arrays;  // no arrays (error case) + + bool write_ok = false; + std::thread writer([&] { write_ok = write_result_to_pipe(pp.write_fd(), header, arrays); }); + + ChunkedResultHeader header_out; + std::map<int32_t, std::vector<uint8_t>> arrays_out; + bool read_ok = read_result_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + EXPECT_EQ(header_out.error_message(), "solver failed"); + EXPECT_TRUE(arrays_out.empty()); +} + +// ============================================================================= +// Protobuf-only round-trip (write_protobuf_to_pipe / read_protobuf_from_pipe) +// ============================================================================= + +TEST(PipeSerialization, ProtobufRoundTrip) +{ + PipePair pp; + + ChunkedResultHeader msg; + msg.set_is_mip(true); + msg.set_primal_objective(3.14); + msg.set_error_message("hello"); + + bool write_ok = false; + std::thread writer([&] { write_ok = write_protobuf_to_pipe(pp.write_fd(), msg); }); + + ChunkedResultHeader msg_out; + bool read_ok = read_protobuf_from_pipe(pp.read_fd(), msg_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + EXPECT_TRUE(msg_out.is_mip()); + EXPECT_DOUBLE_EQ(msg_out.primal_objective(), 3.14); + EXPECT_EQ(msg_out.error_message(), "hello"); +} + +// ============================================================================= +// Larger transfer to exercise multi-iteration pipe I/O +// ============================================================================= + +TEST(PipeSerialization, Result_LargeArray) +{ + PipePair pp; + + ChunkedResultHeader header; + header.set_is_mip(false); + header.set_primal_objective(0.0); + + // ~4 MiB array — large enough to require many kernel-level pipe iterations. 
+ constexpr size_t large_size = 4 * 1024 * 1024; + auto large_data = make_pattern(large_size, 0x77); + + std::map<int32_t, std::vector<uint8_t>> arrays; + arrays[RESULT_PRIMAL_SOLUTION] = large_data; + + bool write_ok = false; + std::thread writer([&] { write_ok = write_result_to_pipe(pp.write_fd(), header, arrays); }); + + ChunkedResultHeader header_out; + std::map<int32_t, std::vector<uint8_t>> arrays_out; + bool read_ok = read_result_from_pipe(pp.read_fd(), header_out, arrays_out); + + writer.join(); + + ASSERT_TRUE(write_ok); + ASSERT_TRUE(read_ok); + ASSERT_EQ(arrays_out.size(), 1u); + EXPECT_EQ(arrays_out[RESULT_PRIMAL_SOLUTION], large_data); +} + +// ============================================================================= +// serialize_submit_request_to_pipe (pure function, no pipe needed) +// ============================================================================= + +TEST(PipeSerialization, SerializeSubmitRequest) +{ + SubmitJobRequest request; + auto* lp = request.mutable_lp_request(); + lp->mutable_header()->set_problem_type(LP); + + auto blob = serialize_submit_request_to_pipe(request); + ASSERT_FALSE(blob.empty()); + + SubmitJobRequest parsed; + ASSERT_TRUE(parsed.ParseFromArray(blob.data(), static_cast<int>(blob.size()))); + EXPECT_TRUE(parsed.has_lp_request()); + EXPECT_EQ(parsed.lp_request().header().problem_type(), LP); +} diff --git a/cpp/tests/linear_programming/grpc/grpc_test_log_capture.hpp b/cpp/tests/linear_programming/grpc/grpc_test_log_capture.hpp new file mode 100644 index 0000000000..37b1f698cf --- /dev/null +++ b/cpp/tests/linear_programming/grpc/grpc_test_log_capture.hpp @@ -0,0 +1,381 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +/** + * @file grpc_test_log_capture.hpp + * @brief Test utility for capturing and verifying logs from gRPC client and server + * + * This utility provides a unified way to capture logs from both client and server + * during integration tests, and provides assertion methods to verify expected log entries. + * + * Usage: + * @code + * GrpcTestLogCapture log_capture; + * + * // Configure client to capture debug logs + * grpc_client_config_t config; + * config.debug_log_callback = log_capture.client_callback(); + * + * // Set server log file path + * log_capture.set_server_log_path("/tmp/cuopt_test_server_19000.log"); + * + * // ... run test ... + * + * // Verify client logs + * EXPECT_TRUE(log_capture.client_log_contains("Connected to server")); + * EXPECT_TRUE(log_capture.client_log_contains_pattern("job_id=.*-.*-.*")); + * + * // Verify server logs + * EXPECT_TRUE(log_capture.server_log_contains("[Worker 0] Processing job")); + * EXPECT_TRUE(log_capture.server_log_contains_pattern("solve_.*p done")); + * @endcode + */ + +#include <atomic> +#include <chrono> +#include <fstream> +#include <functional> +#include <iostream> +#include <mutex> +#include <regex> +#include <sstream> +#include <string> +#include <thread> +#include <vector> + +namespace cuopt::linear_programming::testing { + +/** + * @brief Log entry with metadata + */ +struct LogEntry { + std::string message; + std::chrono::steady_clock::time_point timestamp; + std::string source;  // "client" or "server" +}; + +/** + * @brief Log capture and verification utility for gRPC integration tests + * + * This class tracks log positions to ensure tests only see logs from the current test, + * not from previous tests. Call mark_test_start() at the beginning of each test. + */ +class GrpcTestLogCapture { + public: + GrpcTestLogCapture() = default; + + /** + * @brief Clear all captured client logs and reset server log position + * + * Call this at the start of each test to ensure you only see logs from the current test. 
+ */ + void clear() + { + std::lock_guard lock(mutex_); + client_logs_.clear(); + + if (!server_log_path_.empty()) { + std::ifstream file(server_log_path_, std::ios::ate); + if (file.is_open()) { server_log_start_pos_ = file.tellg(); } + } + test_start_marked_ = true; + } + + /** + * @brief Mark the start of a test - records current server log file position + * + * After calling this, get_server_logs() will only return logs written after this point. + * This ensures tests don't see log entries from previous tests. + * + * Call this AFTER setting the server log path and AFTER the server has started. + */ + void mark_test_start() + { + std::lock_guard lock(mutex_); + client_logs_.clear(); + + // Record current end position of server log file + if (!server_log_path_.empty()) { + std::ifstream file(server_log_path_, std::ios::ate); + if (file.is_open()) { + server_log_start_pos_ = file.tellg(); + } else { + server_log_start_pos_ = 0; + } + } else { + server_log_start_pos_ = 0; + } + test_start_marked_ = true; + } + + // ========================================================================= + // Client Log Capture + // ========================================================================= + + /** + * @brief Get a callback function to capture client debug logs + * + * Use this with grpc_client_config_t::debug_log_callback: + * @code + * config.debug_log_callback = log_capture.client_callback(); + * @endcode + */ + std::function<void(const std::string&)> client_callback() + { + return [this](const std::string& msg) { add_client_log(msg); }; + } + + /** + * @brief Manually add a client log entry + */ + void add_client_log(const std::string& message) + { + std::lock_guard lock(mutex_); + LogEntry entry; + entry.message = message; + entry.timestamp = std::chrono::steady_clock::now(); + entry.source = "client"; + client_logs_.push_back(entry); + } + + /** + * @brief Get all captured client logs as a single string + */ + std::string get_client_logs() const + { + std::lock_guard lock(mutex_); + 
std::ostringstream oss; + for (const auto& entry : client_logs_) { + oss << entry.message << "\n"; + } + return oss.str(); + } + + /** + * @brief Get client log entries + */ + std::vector<LogEntry> get_client_log_entries() const + { + std::lock_guard lock(mutex_); + return client_logs_; + } + + /** + * @brief Check if client logs contain a substring + */ + bool client_log_contains(const std::string& substring) const + { + std::lock_guard lock(mutex_); + for (const auto& entry : client_logs_) { + if (entry.message.find(substring) != std::string::npos) { return true; } + } + return false; + } + + /** + * @brief Check if client logs contain a pattern (regex) + */ + bool client_log_contains_pattern(const std::string& pattern) const + { + std::lock_guard lock(mutex_); + std::regex re(pattern); + for (const auto& entry : client_logs_) { + if (std::regex_search(entry.message, re)) { return true; } + } + return false; + } + + /** + * @brief Count occurrences of a substring in client logs + */ + int client_log_count(const std::string& substring) const + { + std::lock_guard lock(mutex_); + int count = 0; + for (const auto& entry : client_logs_) { + if (entry.message.find(substring) != std::string::npos) { ++count; } + } + return count; + } + + // ========================================================================= + // Server Log Capture + // ========================================================================= + + /** + * @brief Set the path to the server log file + * + * The server process redirects stdout/stderr to a log file. This method + * sets the path so that server logs can be read for verification. + * + * Note: Call mark_test_start() after this to record the starting position. 
+ */ + void set_server_log_path(const std::string& path) + { + std::lock_guard lock(mutex_); + server_log_path_ = path; + server_log_start_pos_ = 0; + test_start_marked_ = false; + } + + /** + * @brief Read server logs from the configured file path + * + * If mark_test_start() was called, this only returns logs written after that point. + * Otherwise, returns all logs in the file. + * + * @param since_test_start If true (default), only return logs since mark_test_start(). + * If false, return all logs in the file. + */ + std::string get_server_logs(bool since_test_start = true) const + { + std::string path; + std::streampos start_pos; + bool marked; + { + std::lock_guard lock(mutex_); + path = server_log_path_; + start_pos = server_log_start_pos_; + marked = test_start_marked_; + } + if (path.empty()) { return ""; } + + std::ifstream file(path); + if (!file.is_open()) { return ""; } + + if (since_test_start && marked && start_pos > 0) { file.seekg(start_pos); } + + std::ostringstream oss; + oss << file.rdbuf(); + return oss.str(); + } + + /** + * @brief Read all server logs (ignoring test start marker) + * + * Useful for debugging when you need to see the full log history. 
+ */ + std::string get_all_server_logs() const { return get_server_logs(false); } + + /** + * @brief Check if server logs contain a substring + */ + bool server_log_contains(const std::string& substring) const + { + std::string logs = get_server_logs(); + return logs.find(substring) != std::string::npos; + } + + /** + * @brief Check if server logs contain a pattern (regex) + */ + bool server_log_contains_pattern(const std::string& pattern) const + { + std::string logs = get_server_logs(); + std::regex re(pattern); + return std::regex_search(logs, re); + } + + /** + * @brief Count occurrences of a substring in server logs + */ + int server_log_count(const std::string& substring) const + { + if (substring.empty()) { return 0; } + std::string logs = get_server_logs(); + int count = 0; + size_t pos = 0; + while ((pos = logs.find(substring, pos)) != std::string::npos) { + ++count; + pos += substring.length(); + } + return count; + } + + /** + * @brief Wait for a specific string to appear in server logs + * + * Polls the server log file until the string appears or timeout. + * Only searches logs written after mark_test_start() was called. 
+ * + * @param substring The string to wait for + * @param timeout_ms Maximum time to wait in milliseconds + * @param poll_interval_ms How often to check (default 100ms) + * @return true if the string was found, false if timeout + */ + bool wait_for_server_log(const std::string& substring, + int timeout_ms, + int poll_interval_ms = 100) const + { + auto start = std::chrono::steady_clock::now(); + while (true) { + // server_log_contains() respects the test start marker + if (server_log_contains(substring)) { return true; } + + auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now() - start); + if (elapsed.count() >= timeout_ms) { return false; } + + std::this_thread::sleep_for(std::chrono::milliseconds(poll_interval_ms)); + } + } + + // ========================================================================= + // Combined Log Verification + // ========================================================================= + + /** + * @brief Check if either client or server logs contain a substring + */ + bool any_log_contains(const std::string& substring) const + { + return client_log_contains(substring) || server_log_contains(substring); + } + + /** + * @brief Print all captured logs for debugging + * + * @param include_all_server_logs If true, print all server logs (not just since test start) + */ + void dump_logs(std::ostream& os = std::cout, bool include_all_server_logs = false) const + { + os << "=== Client Logs ===\n"; + os << get_client_logs(); + os << "\n=== Server Logs"; + if (test_start_marked_ && !include_all_server_logs) { + os << " (since test start)"; + } else { + os << " (all)"; + } + os << " ===\n"; + os << get_server_logs(!include_all_server_logs); + os << "\n==================\n"; + } + + /** + * @brief Check if mark_test_start() has been called + */ + bool is_test_start_marked() const { return test_start_marked_; } + + /** + * @brief Get the server log file path + */ + std::string server_log_path() const + { + std::lock_guard 
lock(mutex_); + return server_log_path_; + } + + private: + mutable std::mutex mutex_; + std::vector<LogEntry> client_logs_; + std::string server_log_path_; + std::streampos server_log_start_pos_ = 0;  // Position in server log file when test started + std::atomic<bool> test_start_marked_{false}; +}; + +} // namespace cuopt::linear_programming::testing diff --git a/dependencies.yaml b/dependencies.yaml index 011dfbcee6..54ddd36856 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -302,6 +302,8 @@ dependencies: - tbb-devel - zlib - bzip2 + - openssl + - c-ares test_cpp: common: - output_types: [conda] diff --git a/python/cuopt/cuopt/tests/linear_programming/test_cpu_only_execution.py b/python/cuopt/cuopt/tests/linear_programming/test_cpu_only_execution.py index 792942aae9..652c426c44 100644 --- a/python/cuopt/cuopt/tests/linear_programming/test_cpu_only_execution.py +++ b/python/cuopt/cuopt/tests/linear_programming/test_cpu_only_execution.py @@ -4,26 +4,35 @@ """ Tests for CPU-only execution mode and solution interface polymorphism. -TestCPUOnlyExecution / TestCuoptCliCPUOnly: - Run in subprocesses with CUDA_VISIBLE_DEVICES="" so the CUDA driver - never initializes. Subprocess isolation is required because the - driver reads that variable once at init time. +These tests verify that cuOpt can run on a CPU host without GPU access, +forwarding solves to a real cuopt_grpc_server over gRPC. A single shared +server is started once per test class to avoid per-test startup overhead. TestSolutionInterfacePolymorphism: Run in-process on real GPU hardware and assert correctness of solution values against known optima. 
""" +import logging import os +import re +import shutil +import signal +import socket import subprocess import sys +import time import cuopt_mps_parser import pytest from cuopt import linear_programming from cuopt.linear_programming.solver.solver_parameters import CUOPT_TIME_LIMIT -RAPIDS_DATASET_ROOT_DIR = os.environ.get("RAPIDS_DATASET_ROOT_DIR", "./") +logger = logging.getLogger(__name__) + +RAPIDS_DATASET_ROOT_DIR = os.environ.get( + "RAPIDS_DATASET_ROOT_DIR", "./datasets" +) # --------------------------------------------------------------------------- @@ -31,12 +40,235 @@ # --------------------------------------------------------------------------- -def _cpu_only_env(): +def _find_grpc_server(): + """Locate cuopt_grpc_server binary.""" + env_path = os.environ.get("CUOPT_GRPC_SERVER_PATH") + if env_path and os.path.isfile(env_path) and os.access(env_path, os.X_OK): + return env_path + + found = shutil.which("cuopt_grpc_server") + if found: + return found + + for candidate in [ + "./cuopt_grpc_server", + "../cpp/build/cuopt_grpc_server", + "../../cpp/build/cuopt_grpc_server", + ]: + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return os.path.abspath(candidate) + + conda_prefix = os.environ.get("CONDA_PREFIX", "") + if conda_prefix: + p = os.path.join(conda_prefix, "bin", "cuopt_grpc_server") + if os.path.isfile(p) and os.access(p, os.X_OK): + return p + return None + + +def _wait_for_port(port, timeout=15): + """Block until TCP port accepts connections or timeout expires.""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with socket.create_connection(("127.0.0.1", port), timeout=1): + return True + except OSError: + time.sleep(0.2) + return False + + +def _cpu_only_env(port): """Return an env dict that hides all GPUs and enables remote mode.""" env = os.environ.copy() + for key in [k for k in env if k.startswith("CUOPT_TLS_")]: + env.pop(key) env["CUDA_VISIBLE_DEVICES"] = "" env["CUOPT_REMOTE_HOST"] = 
"localhost" - env["CUOPT_REMOTE_PORT"] = "12345" + env["CUOPT_REMOTE_PORT"] = str(port) + return env + + +def _parse_cli_output(output): + """Extract solver status and objective value from cuopt_cli output. + + Handles both the LP summary format + (``Status: Optimal Objective: -464.753 ... Time: 0.1s``) + and the MIP format + (``Optimal solution found.`` + ``Solution objective: 2.000000 ...``). + """ + result = {"status": "Unknown", "objective_value": float("nan")} + + for line in output.split("\n"): + stripped = line.strip() + + # LP summary: "Status: Optimal Objective: -464.753 ... Time: 0.1s" + if stripped.startswith("Status:") and "Time:" in stripped: + m = re.match(r"Status:\s*(\S+)", stripped) + if m: + result["status"] = m.group(1) + m = re.search( + r"Objective:\s*([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)", + stripped, + ) + if m: + result["objective_value"] = float(m.group(1)) + continue + + # MIP termination + if stripped == "Optimal solution found.": + result["status"] = "Optimal" + continue + + # MIP solution: "Solution objective: 2.000000 , ..." + m = re.match( + r"Solution objective:\s*([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)", + stripped, + ) + if m: + result["objective_value"] = float(m.group(1)) + continue + + return result + + +def _generate_test_certs(cert_dir): + """Generate a CA, server cert, and client cert for TLS/mTLS tests. + + Returns True on success, False if openssl is missing or a command fails. 
+ """ + if not shutil.which("openssl"): + return False + + def _run(cmd): + result = subprocess.run(cmd, capture_output=True, timeout=30) + if result.returncode != 0: + logger.warning( + "cert command failed: %s (rc=%d)\nstdout: %s\nstderr: %s", + cmd, + result.returncode, + result.stdout.decode(errors="replace"), + result.stderr.decode(errors="replace"), + ) + return False + return True + + ca_key = os.path.join(cert_dir, "ca.key") + ca_crt = os.path.join(cert_dir, "ca.crt") + if not _run( + [ + "openssl", + "req", + "-x509", + "-newkey", + "rsa:2048", + "-keyout", + ca_key, + "-out", + ca_crt, + "-days", + "1", + "-nodes", + "-subj", + "/CN=TestCA", + ] + ): + return False + + server_key = os.path.join(cert_dir, "server.key") + server_csr = os.path.join(cert_dir, "server.csr") + server_crt = os.path.join(cert_dir, "server.crt") + server_ext = os.path.join(cert_dir, "server.ext") + if not _run( + [ + "openssl", + "req", + "-newkey", + "rsa:2048", + "-keyout", + server_key, + "-out", + server_csr, + "-nodes", + "-subj", + "/CN=localhost", + ] + ): + return False + with open(server_ext, "w") as f: + f.write("subjectAltName=DNS:localhost,IP:127.0.0.1\n") + if not _run( + [ + "openssl", + "x509", + "-req", + "-in", + server_csr, + "-CA", + ca_crt, + "-CAkey", + ca_key, + "-CAcreateserial", + "-out", + server_crt, + "-days", + "1", + "-extfile", + server_ext, + ] + ): + return False + + client_key = os.path.join(cert_dir, "client.key") + client_csr = os.path.join(cert_dir, "client.csr") + client_crt = os.path.join(cert_dir, "client.crt") + if not _run( + [ + "openssl", + "req", + "-newkey", + "rsa:2048", + "-keyout", + client_key, + "-out", + client_csr, + "-nodes", + "-subj", + "/CN=TestClient", + ] + ): + return False + if not _run( + [ + "openssl", + "x509", + "-req", + "-in", + client_csr, + "-CA", + ca_crt, + "-CAkey", + ca_key, + "-CAcreateserial", + "-out", + client_crt, + "-days", + "1", + ] + ): + return False + + return True + + +def _tls_env(port, cert_dir, 
mtls=False): + """Return an env dict for remote execution over TLS (or mTLS).""" + env = _cpu_only_env(port) + env["CUOPT_TLS_ENABLED"] = "1" + env["CUOPT_TLS_ROOT_CERT"] = os.path.join(cert_dir, "ca.crt") + if mtls: + env["CUOPT_TLS_CLIENT_CERT"] = os.path.join(cert_dir, "client.crt") + env["CUOPT_TLS_CLIENT_KEY"] = os.path.join(cert_dir, "client.key") return env @@ -88,6 +320,13 @@ def _impl_lp_solve_cpu_only(): obj = solution.get_primal_objective() assert obj is not None, "objective is None" + _AFIRO_OBJ = -464.7531428571 + rel_err = abs(obj - _AFIRO_OBJ) / max(abs(_AFIRO_OBJ), 1e-12) + assert rel_err < 0.01, ( + f"objective {obj} differs from expected {_AFIRO_OBJ} " + f"(rel error {rel_err:.4e})" + ) + def _impl_lp_dual_solution_cpu_only(): """Dual solution and reduced costs are correctly sized.""" @@ -110,6 +349,14 @@ def _impl_lp_dual_solution_cpu_only(): rc = solution.get_reduced_cost() assert len(rc) == n_vars, f"reduced_cost size {len(rc)} != n_vars {n_vars}" + obj = solution.get_primal_objective() + _AFIRO_OBJ = -464.7531428571 + rel_err = abs(obj - _AFIRO_OBJ) / max(abs(_AFIRO_OBJ), 1e-12) + assert rel_err < 0.01, ( + f"dual test: objective {obj} differs from expected {_AFIRO_OBJ} " + f"(rel error {rel_err:.4e})" + ) + def _impl_mip_solve_cpu_only(): """MIP solve returns correctly-sized solution vector.""" @@ -131,6 +378,18 @@ def _impl_mip_solve_cpu_only(): vals = solution.get_primal_solution() assert len(vals) == n_vars, f"solution size {len(vals)} != n_vars {n_vars}" + obj_coeffs = dm.get_objective_coefficients() + computed_obj = sum(c * v for c, v in zip(obj_coeffs, vals)) + reported_obj = solution.get_primal_objective() + if abs(reported_obj) > 1e-12: + rel_err = abs(computed_obj - reported_obj) / abs(reported_obj) + else: + rel_err = abs(computed_obj - reported_obj) + assert rel_err < 0.01, ( + f"MIP objective mismatch: computed {computed_obj} vs reported " + f"{reported_obj} (rel error {rel_err:.4e})" + ) + def _impl_warmstart_cpu_only(): 
"""Warmstart round-trip works without touching CUDA.""" @@ -138,6 +397,7 @@ def _impl_warmstart_cpu_only(): from cuopt.linear_programming.solver.solver_parameters import ( CUOPT_METHOD, CUOPT_ITERATION_LIMIT, + CUOPT_PRESOLVE, ) from cuopt.linear_programming.solver_settings import SolverMethod import cuopt_mps_parser @@ -148,6 +408,7 @@ def _impl_warmstart_cpu_only(): settings = linear_programming.SolverSettings() settings.set_parameter(CUOPT_METHOD, SolverMethod.PDLP) + settings.set_parameter(CUOPT_PRESOLVE, 0) settings.set_parameter(CUOPT_ITERATION_LIMIT, 100) sol1 = linear_programming.Solve(dm, settings) @@ -160,36 +421,84 @@ def _impl_warmstart_cpu_only(): assert sol2.get_primal_solution() is not None +# --------------------------------------------------------------------------- +# Shared fixture helpers (used by TestCPUOnlyExecution and TestCuoptCliCPUOnly) +# --------------------------------------------------------------------------- + + +def _start_grpc_server_fixture(port_offset): + """Locate the server, start it on BASE + port_offset, return (proc, env).""" + server_bin = _find_grpc_server() + if server_bin is None: + pytest.skip("cuopt_grpc_server not found") + + port = int(os.environ.get("CUOPT_TEST_PORT_BASE", "18000")) + port_offset + proc = subprocess.Popen( + [server_bin, "--port", str(port), "--workers", "1"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if not _wait_for_port(port, timeout=15): + proc.kill() + proc.wait() + pytest.fail("cuopt_grpc_server failed to start within 15s") + + return proc, _cpu_only_env(port) + + +def _stop_grpc_server(proc): + """Gracefully shut down a server process.""" + proc.send_signal(signal.SIGTERM) + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + + # --------------------------------------------------------------------------- # CPU-only Python tests (subprocess required) # --------------------------------------------------------------------------- 
class TestCPUOnlyExecution: - """Tests that run with CUDA_VISIBLE_DEVICES='' to simulate CPU-only hosts.""" + """Tests that run with CUDA_VISIBLE_DEVICES='' to simulate CPU-only hosts. + + A shared cuopt_grpc_server is started once for the whole class. + """ - @pytest.fixture - def env(self): - return _cpu_only_env() + @pytest.fixture(scope="class") + def cpu_only_env_with_server(self): + proc, env = _start_grpc_server_fixture(port_offset=600) + yield env + _stop_grpc_server(proc) - def test_lp_solve_cpu_only(self, env): + def test_lp_solve_cpu_only(self, cpu_only_env_with_server): """LP solve returns correctly-sized solution vectors.""" - result = _run_in_subprocess(_impl_lp_solve_cpu_only, env=env) + result = _run_in_subprocess( + _impl_lp_solve_cpu_only, env=cpu_only_env_with_server + ) assert result.returncode == 0, f"Test failed:\n{result.stderr}" - def test_lp_dual_solution_cpu_only(self, env): + def test_lp_dual_solution_cpu_only(self, cpu_only_env_with_server): """Dual solution and reduced costs are correctly sized.""" - result = _run_in_subprocess(_impl_lp_dual_solution_cpu_only, env=env) + result = _run_in_subprocess( + _impl_lp_dual_solution_cpu_only, env=cpu_only_env_with_server + ) assert result.returncode == 0, f"Test failed:\n{result.stderr}" - def test_mip_solve_cpu_only(self, env): + def test_mip_solve_cpu_only(self, cpu_only_env_with_server): """MIP solve returns correctly-sized solution vector.""" - result = _run_in_subprocess(_impl_mip_solve_cpu_only, env=env) + result = _run_in_subprocess( + _impl_mip_solve_cpu_only, env=cpu_only_env_with_server + ) assert result.returncode == 0, f"Test failed:\n{result.stderr}" - def test_warmstart_cpu_only(self, env): + def test_warmstart_cpu_only(self, cpu_only_env_with_server): """Warmstart round-trip works without touching CUDA.""" - result = _run_in_subprocess(_impl_warmstart_cpu_only, env=env) + result = _run_in_subprocess( + _impl_warmstart_cpu_only, env=cpu_only_env_with_server + ) assert 
result.returncode == 0, f"Test failed:\n{result.stderr}" @@ -199,16 +508,19 @@ def test_warmstart_cpu_only(self, env): class TestCuoptCliCPUOnly: - """Test that cuopt_cli runs without CUDA in remote-execution mode.""" + """Test that cuopt_cli runs without CUDA in remote-execution mode. - @pytest.fixture - def env(self): - return _cpu_only_env() + A shared cuopt_grpc_server is started once for the whole class. + """ + + @pytest.fixture(scope="class") + def cpu_only_env_with_server(self): + proc, env = _start_grpc_server_fixture(port_offset=700) + yield env + _stop_grpc_server(proc) @staticmethod def _find_cuopt_cli(): - import shutil - for loc in [ shutil.which("cuopt_cli"), "./cuopt_cli", @@ -221,7 +533,7 @@ def _find_cuopt_cli(): conda_prefix = os.environ.get("CONDA_PREFIX", "") if conda_prefix: p = os.path.join(conda_prefix, "bin", "cuopt_cli") - if os.path.isfile(p): + if os.path.isfile(p) and os.access(p, os.X_OK): return p return None @@ -233,15 +545,24 @@ def _find_cuopt_cli(): "CUDA initialization failed", ] - def _run_cli(self, mps_file, env): + def _run_cli(self, mps_file, env, extra_args=None): + """Run cuopt_cli on *mps_file* in remote-execution mode. + + Returns the combined stdout+stderr so callers can parse it. + Asserts no CUDA errors and zero exit code. 
+ """ cli = self._find_cuopt_cli() if cli is None: pytest.skip("cuopt_cli not found") if not os.path.exists(mps_file): pytest.skip(f"Test file not found: {mps_file}") + cmd = [cli, mps_file, "--time-limit", "60"] + if extra_args: + cmd.extend(extra_args) + result = subprocess.run( - [cli, mps_file, "--time-limit", "60"], + cmd, env=env, capture_output=True, text=True, @@ -259,15 +580,62 @@ def _run_cli(self, mps_file, env): assert result.returncode == 0, ( f"cuopt_cli exited with {result.returncode}" ) + return combined + + _REMOTE_INDICATORS = [ + "connecting to gRPC server", + "solve completed successfully", + ] + + def _assert_remote_execution(self, output): + """Check that log output contains evidence of remote gRPC execution.""" + for indicator in self._REMOTE_INDICATORS: + assert indicator in output, ( + f"Remote execution indicator '{indicator}' not found " + "in CLI output -- solve may not have been forwarded" + ) - def test_cuopt_cli_lp_cpu_only(self, env): - self._run_cli( + def test_cli_lp_remote(self, cpu_only_env_with_server): + """LP solve via cuopt_cli runs remotely with correct objective.""" + output = self._run_cli( f"{RAPIDS_DATASET_ROOT_DIR}/linear_programming/afiro_original.mps", - env, + cpu_only_env_with_server, ) + self._assert_remote_execution(output) - def test_cuopt_cli_mip_cpu_only(self, env): - self._run_cli(f"{RAPIDS_DATASET_ROOT_DIR}/mip/bb_optimality.mps", env) + parsed = _parse_cli_output(output) + assert parsed["status"] == "Optimal", ( + f"Expected Optimal, got {parsed['status']}" + ) + expected_obj = -464.7531428571 + rel_err = abs(parsed["objective_value"] - expected_obj) / abs( + expected_obj + ) + assert rel_err < 0.01, ( + f"Objective {parsed['objective_value']} differs from expected " + f"{expected_obj} (rel error {rel_err:.4e})" + ) + + def test_cli_mip_remote(self, cpu_only_env_with_server): + """MIP solve via cuopt_cli runs remotely with correct objective.""" + output = self._run_cli( + 
f"{RAPIDS_DATASET_ROOT_DIR}/mip/bb_optimality.mps", + cpu_only_env_with_server, + ) + self._assert_remote_execution(output) + + parsed = _parse_cli_output(output) + assert parsed["status"] == "Optimal", ( + f"Expected Optimal, got {parsed['status']}" + ) + expected_obj = 2.0 + rel_err = abs(parsed["objective_value"] - expected_obj) / max( + abs(expected_obj), 1e-12 + ) + assert rel_err < 0.01, ( + f"Objective {parsed['objective_value']} differs from expected " + f"{expected_obj} (rel error {rel_err:.4e})" + ) # --------------------------------------------------------------------------- @@ -331,9 +699,164 @@ def test_mip_solution_values(self): assert stats["mip_gap"] >= 0, f"Negative MIP gap: {stats['mip_gap']}" +# --------------------------------------------------------------------------- +# TLS tests (subprocess required, server with --tls) +# --------------------------------------------------------------------------- + + +class TestTLSExecution: + """Test remote execution over a TLS-encrypted channel. + + A shared cuopt_grpc_server is started with --tls and self-signed certs. + The client connects using CUOPT_TLS_* env vars. 
+ """ + + @pytest.fixture(scope="class") + def tls_env_with_server(self, tmp_path_factory): + cert_dir = str(tmp_path_factory.mktemp("tls_certs")) + if not _generate_test_certs(cert_dir): + pytest.skip("openssl not available or cert generation failed") + + server_bin = _find_grpc_server() + if server_bin is None: + pytest.skip("cuopt_grpc_server not found") + + port = int(os.environ.get("CUOPT_TEST_PORT_BASE", "18000")) + 800 + proc = subprocess.Popen( + [ + server_bin, + "--port", + str(port), + "--workers", + "1", + "--tls", + "--tls-cert", + os.path.join(cert_dir, "server.crt"), + "--tls-key", + os.path.join(cert_dir, "server.key"), + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if not _wait_for_port(port, timeout=15): + proc.kill() + proc.wait() + pytest.fail("TLS cuopt_grpc_server failed to start within 15s") + + env = _tls_env(port, cert_dir, mtls=False) + yield env + + _stop_grpc_server(proc) + + def test_lp_solve_tls(self, tls_env_with_server): + """LP solve succeeds over a TLS channel.""" + result = _run_in_subprocess( + _impl_lp_solve_cpu_only, env=tls_env_with_server + ) + assert result.returncode == 0, f"TLS LP solve failed:\n{result.stderr}" + + +# --------------------------------------------------------------------------- +# mTLS tests (subprocess required, server with --tls + --require-client-cert) +# --------------------------------------------------------------------------- + + +class TestMTLSExecution: + """Test remote execution over an mTLS-encrypted channel. + + A shared cuopt_grpc_server is started with --tls, --tls-root, and + --require-client-cert. The client must present a valid certificate + signed by the test CA. 
+ """ + + @pytest.fixture(scope="class") + def mtls_server_info(self, tmp_path_factory): + cert_dir = str(tmp_path_factory.mktemp("mtls_certs")) + if not _generate_test_certs(cert_dir): + pytest.skip("openssl not available or cert generation failed") + + server_bin = _find_grpc_server() + if server_bin is None: + pytest.skip("cuopt_grpc_server not found") + + port = int(os.environ.get("CUOPT_TEST_PORT_BASE", "18000")) + 900 + proc = subprocess.Popen( + [ + server_bin, + "--port", + str(port), + "--workers", + "1", + "--tls", + "--tls-cert", + os.path.join(cert_dir, "server.crt"), + "--tls-key", + os.path.join(cert_dir, "server.key"), + "--tls-root", + os.path.join(cert_dir, "ca.crt"), + "--require-client-cert", + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if not _wait_for_port(port, timeout=15): + proc.kill() + proc.wait() + pytest.fail("mTLS cuopt_grpc_server failed to start within 15s") + + yield {"port": port, "cert_dir": cert_dir} + + proc.send_signal(signal.SIGTERM) + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + + def test_lp_solve_mtls(self, mtls_server_info): + """LP solve succeeds over an mTLS channel with valid client cert.""" + env = _tls_env( + mtls_server_info["port"], + mtls_server_info["cert_dir"], + mtls=True, + ) + result = _run_in_subprocess(_impl_lp_solve_cpu_only, env=env) + assert result.returncode == 0, ( + f"mTLS LP solve failed:\n{result.stderr}" + ) + + def test_mtls_rejects_no_client_cert(self, mtls_server_info): + """Server rejects a client that does not present a certificate.""" + env = _tls_env( + mtls_server_info["port"], + mtls_server_info["cert_dir"], + mtls=False, + ) + result = _run_in_subprocess( + _impl_lp_solve_cpu_only, env=env, timeout=30 + ) + assert result.returncode != 0, ( + "Expected failure when connecting without client cert" + ) + + # --------------------------------------------------------------------------- # Subprocess entry point # 
--------------------------------------------------------------------------- if __name__ == "__main__": - globals()[sys.argv[1]]() + _ALLOWED_ENTRIES = { + "_impl_lp_solve_cpu_only": _impl_lp_solve_cpu_only, + "_impl_lp_dual_solution_cpu_only": _impl_lp_dual_solution_cpu_only, + "_impl_mip_solve_cpu_only": _impl_mip_solve_cpu_only, + "_impl_warmstart_cpu_only": _impl_warmstart_cpu_only, + } + name = sys.argv[1] if len(sys.argv) > 1 else "" + if name not in _ALLOWED_ENTRIES: + print(f"Unknown entry point: {name!r}", file=sys.stderr) + print( + f"Available: {', '.join(sorted(_ALLOWED_ENTRIES))}", + file=sys.stderr, + ) + sys.exit(1) + _ALLOWED_ENTRIES[name]() diff --git a/python/libcuopt/CMakeLists.txt b/python/libcuopt/CMakeLists.txt index 7868d66567..b524d5f6e3 100644 --- a/python/libcuopt/CMakeLists.txt +++ b/python/libcuopt/CMakeLists.txt @@ -39,6 +39,9 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(argparse) +# gRPC must be available as an installed CMake package (gRPCConfig.cmake). +# On RockyLinux 8 wheel builds we install it in CI via ci/utils/install_protobuf_grpc.sh. +find_package(gRPC CONFIG REQUIRED) find_package(Boost 1.65 REQUIRED) if(Boost_FOUND) @@ -93,3 +96,4 @@ message(STATUS "libcuopt: Final RPATH = ${rpaths}") set_property(TARGET cuopt PROPERTY INSTALL_RPATH ${rpaths} APPEND) set_property(TARGET cuopt_cli PROPERTY INSTALL_RPATH ${rpaths} APPEND) +set_property(TARGET cuopt_grpc_server PROPERTY INSTALL_RPATH ${rpaths} APPEND)