From 5528b39d64f590c96be9eefcb693adebe10b0256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 20:04:58 -0800 Subject: [PATCH] Extract config structs and unified create_dsp() construction path Add typed config structs (LinearConfig, LSTMConfig, ConvNetConfig, WaveNetConfig) and parse_config_json() functions per architecture. Introduce ModelConfig variant, ModelMetadata, and create_dsp() for unified model construction independent of JSON parsing. Refactor get_dsp() to use the new unified path. Register Linear factory (was previously missing). Co-Authored-By: Claude Opus 4.6 --- NAM/convnet.cpp | 29 +++++----- NAM/convnet.h | 17 ++++++ NAM/dsp.cpp | 26 ++++++--- NAM/dsp.h | 15 ++++++ NAM/get_dsp.cpp | 128 ++++++++++++++++++++++++++++++--------------- NAM/lstm.cpp | 23 +++++--- NAM/lstm.h | 15 ++++++ NAM/model_config.h | 51 ++++++++++++++++++ NAM/wavenet.cpp | 77 +++++++++++++-------------- NAM/wavenet.h | 23 ++++++++ 10 files changed, 293 insertions(+), 111 deletions(-) create mode 100644 NAM/model_config.h diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp index fc7c151..3be4ca6 100644 --- a/NAM/convnet.cpp +++ b/NAM/convnet.cpp @@ -322,22 +322,27 @@ void nam::convnet::ConvNet::_rewind_buffers_() this->Buffer::_rewind_buffers_(); } +// Config parser +nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config) +{ + ConvNetConfig c; + c.channels = config["channels"]; + c.dilations = config["dilations"].get>(); + c.batchnorm = config["batchnorm"]; + c.activation = activations::ActivationConfig::from_json(config["activation"]); + c.groups = config.value("groups", 1); + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory std::unique_ptr nam::convnet::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int channels = config["channels"]; - const std::vector dilations = config["dilations"]; - const bool batchnorm = config["batchnorm"]; - // Parse JSON into typed ActivationConfig at model loading boundary - const activations::ActivationConfig activation_config = - activations::ActivationConfig::from_json(config["activation"]); - const int groups = config.value("groups", 1); // defaults to 1 - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.channels, c.dilations, c.batchnorm, + c.activation, weights, expectedSampleRate, c.groups); } namespace diff --git a/NAM/convnet.h b/NAM/convnet.h index 0d963df..c1a961d 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -165,6 +165,23 @@ class ConvNet : public Buffer int PrewarmSamples() override { return mPrewarmSamples; }; }; +/// \brief Configuration for a ConvNet model +struct ConvNetConfig +{ + int channels; + std::vector dilations; + bool batchnorm; + activations::ActivationConfig activation; + int groups; + int in_channels; + int out_channels; +}; + +/// \brief Parse ConvNet configuration from JSON +/// \param config JSON configuration object +/// \return ConvNetConfig +ConvNetConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate ConvNet from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 05dab09..8bc4c3e 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -300,16 +300,30 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num nam::Buffer::_advance_input_buffer_(num_frames); } +// Config parser +nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config) +{ + LinearConfig c; + c.receptive_field = config["receptive_field"]; + c.bias = config["bias"]; + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory std::unique_ptr nam::linear::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int receptive_field = config["receptive_field"]; - const bool bias = config["bias"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique(in_channels, out_channels, receptive_field, bias, weights, expectedSampleRate); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.receptive_field, c.bias, weights, + expectedSampleRate); +} + +// Register the factory +namespace +{ +static nam::factory::Helper _register_Linear("Linear", nam::linear::Factory); } // NN modules ================================================================= diff --git a/NAM/dsp.h b/NAM/dsp.h index 1313ad9..7c47d38 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -258,6 +258,21 @@ class Linear : public Buffer namespace linear { + +/// \brief Configuration for a Linear model +struct LinearConfig +{ + int receptive_field; + bool bias; + int in_channels; + int out_channels; +}; + +/// \brief Parse Linear configuration from JSON +/// \param config JSON configuration object +/// \return LinearConfig +LinearConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate Linear model from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index 57d0fbd..efd31d0 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "dsp.h" #include "registry.h" @@ -11,6 +12,7 @@ #include "convnet.h" #include "wavenet.h" #include "get_dsp.h" +#include "model_config.h" namespace nam { @@ -146,62 +148,102 @@ std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConf return get_dsp(conf); } -struct OptionalValue +// ============================================================================= +// Unified construction path +// ============================================================================= + +ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config, double sample_rate) { - bool have = false; - double value = 0.0; -}; + if (architecture == "Linear") + return linear::parse_config_json(config); + else if (architecture == "LSTM") + return lstm::parse_config_json(config); + else if (architecture == "ConvNet") + return convnet::parse_config_json(config); + else if (architecture == "WaveNet") + return wavenet::parse_config_json(config, sample_rate); + else + throw std::runtime_error("Unknown architecture: " + architecture); +} -std::unique_ptr get_dsp(dspData& conf) +namespace { - verify_config_version(conf.version); - auto& architecture = conf.architecture; - nlohmann::json& config = conf.config; - std::vector& weights = conf.weights; - OptionalValue loudness, inputLevel, outputLevel; +void apply_metadata(DSP& dsp, const ModelMetadata& metadata) +{ + if (metadata.loudness.has_value()) + dsp.SetLoudness(metadata.loudness.value()); + if (metadata.input_level.has_value()) + dsp.SetInputLevel(metadata.input_level.value()); + if (metadata.output_level.has_value()) + dsp.SetOutputLevel(metadata.output_level.value()); +} + +} // anonymous namespace + +std::unique_ptr create_dsp(ModelConfig config, std::vector weights, const ModelMetadata& metadata) +{ + const double sample_rate = metadata.sample_rate; - auto AssignOptional = [&conf](const std::string key, OptionalValue& v) { - if (conf.metadata.find(key) != conf.metadata.end()) - { - if (!conf.metadata[key].is_null()) + std::unique_ptr out = std::visit( + [&](auto&& cfg) -> std::unique_ptr { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + return std::make_unique(cfg.in_channels, cfg.out_channels, cfg.receptive_field, cfg.bias, weights, + sample_rate); + } + else if constexpr (std::is_same_v) + { + return std::make_unique(cfg.in_channels, cfg.out_channels, cfg.num_layers, cfg.input_size, + cfg.hidden_size, weights, sample_rate); + } + else if constexpr (std::is_same_v) { - v.value = conf.metadata[key]; - v.have = true; + return std::make_unique(cfg.in_channels, cfg.out_channels, cfg.channels, cfg.dilations, + cfg.batchnorm, cfg.activation, weights, sample_rate, cfg.groups); } - } - }; + else if constexpr (std::is_same_v) + { + return std::make_unique(cfg.in_channels, cfg.layer_array_params, cfg.head_scale, + cfg.with_head, std::move(weights), std::move(cfg.condition_dsp), + sample_rate); + } + }, + std::move(config)); - if (!conf.metadata.is_null()) - { - AssignOptional("loudness", loudness); - AssignOptional("input_level_dbu", inputLevel); - AssignOptional("output_level_dbu", outputLevel); - } - const double expectedSampleRate = conf.expected_sample_rate; + apply_metadata(*out, metadata); + out->prewarm(); + return out; +} - // Initialize using registry-based factory - std::unique_ptr out = - nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate); +// ============================================================================= +// get_dsp(dspData&) — now uses unified path +// ============================================================================= - if (loudness.have) - { - out->SetLoudness(loudness.value); - } - if (inputLevel.have) - { - out->SetInputLevel(inputLevel.value); - } - if (outputLevel.have) +std::unique_ptr get_dsp(dspData& conf) +{ + verify_config_version(conf.version); + + // Extract metadata from JSON + ModelMetadata metadata; + metadata.version = conf.version; + metadata.sample_rate = conf.expected_sample_rate; + + if (!conf.metadata.is_null()) { - out->SetOutputLevel(outputLevel.value); + auto extract = [&conf](const std::string& key) -> std::optional { + if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null()) + return conf.metadata[key].get(); + return std::nullopt; + }; + metadata.loudness = extract("loudness"); + metadata.input_level = extract("input_level_dbu"); + metadata.output_level = extract("output_level_dbu"); } - // "pre-warm" the model to settle initial conditions - // Can this be removed now that it's part of Reset()? - out->prewarm(); - - return out; + ModelConfig model_config = parse_model_config_json(conf.architecture, conf.config, conf.expected_sample_rate); + return create_dsp(std::move(model_config), std::move(conf.weights), metadata); } double get_sample_rate_from_nam_file(const nlohmann::json& j) diff --git a/NAM/lstm.cpp b/NAM/lstm.cpp index d162d55..93e4d33 100644 --- a/NAM/lstm.cpp +++ b/NAM/lstm.cpp @@ -163,18 +163,25 @@ void nam::lstm::LSTM::_process_sample() this->_output.noalias() += this->_head_bias; } +// Config parser +nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config) +{ + LSTMConfig c; + c.num_layers = config["num_layers"]; + c.input_size = config["input_size"]; + c.hidden_size = config["hidden_size"]; + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory to instantiate from nlohmann json std::unique_ptr nam::lstm::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int num_layers = config["num_layers"]; - const int input_size = config["input_size"]; - const int hidden_size = config["hidden_size"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.num_layers, c.input_size, c.hidden_size, + weights, expectedSampleRate); } // Register the factory diff --git a/NAM/lstm.h b/NAM/lstm.h index d97de20..fa00d4d 100644 --- a/NAM/lstm.h +++ b/NAM/lstm.h @@ -95,6 +95,21 @@ class LSTM : public DSP Eigen::VectorXf _output; }; +/// \brief Configuration for an LSTM model +struct LSTMConfig +{ + int num_layers; + int input_size; + int hidden_size; + int in_channels; + int out_channels; +}; + +/// \brief Parse LSTM configuration from JSON +/// \param config JSON configuration object +/// \return LSTMConfig +LSTMConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate LSTM from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/model_config.h b/NAM/model_config.h new file mode 100644 index 0000000..cfc5b86 --- /dev/null +++ b/NAM/model_config.h @@ -0,0 +1,51 @@ +#pragma once +// Unified model configuration types for both JSON and binary loaders. +// No circular dependencies: architecture headers define config structs, +// this header combines them into a variant. + +#include +#include +#include +#include +#include + +#include "convnet.h" +#include "dsp.h" +#include "lstm.h" +#include "wavenet.h" + +namespace nam +{ + +/// \brief Metadata common to all model formats +struct ModelMetadata +{ + std::string version; + double sample_rate = -1.0; + std::optional loudness; + std::optional input_level; + std::optional output_level; +}; + +/// \brief Variant of all architecture configs +using ModelConfig = std::variant; + +/// \brief Construct a DSP object from a typed config, weights, and metadata +/// +/// This is the single construction path used by both JSON and binary loaders. +/// Handles construction, metadata application, and prewarm. +/// \param config Architecture-specific configuration (variant) +/// \param weights Model weights (taken by value to allow move for WaveNet) +/// \param metadata Model metadata (version, sample rate, loudness, levels) +/// \return Unique pointer to a DSP object +std::unique_ptr create_dsp(ModelConfig config, std::vector weights, const ModelMetadata& metadata); + +/// \brief Parse a ModelConfig from a JSON architecture name and config block +/// \param architecture Architecture name string (e.g., "WaveNet", "LSTM") +/// \param config JSON config block for this architecture +/// \param sample_rate Expected sample rate from metadata +/// \return ModelConfig variant +ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config, + double sample_rate); + +} // namespace nam diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 6eb74a3..dd571e2 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -568,41 +568,43 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, con } } -// Factory to instantiate from nlohmann json -std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) +// Config parser - extracts all configuration from JSON without constructing the DSP +nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json& config, + const double expectedSampleRate) { - std::unique_ptr condition_dsp = nullptr; - if (config.find("condition_dsp") != config.end() && !config["condition_dsp"].is_null()) + WaveNetConfig wc; + + // Condition DSP (eagerly built via get_dsp) + if ((config.find("condition_dsp") != config.end()) && !config["condition_dsp"].is_null()) { const nlohmann::json& condition_dsp_json = config["condition_dsp"]; - condition_dsp = nam::get_dsp(condition_dsp_json); - if (condition_dsp->GetExpectedSampleRate() != expectedSampleRate) + wc.condition_dsp = nam::get_dsp(condition_dsp_json); + if (wc.condition_dsp->GetExpectedSampleRate() != expectedSampleRate) { std::stringstream ss; - ss << "Condition DSP expected sample rate (" << condition_dsp->GetExpectedSampleRate() + ss << "Condition DSP expected sample rate (" << wc.condition_dsp->GetExpectedSampleRate() << ") doesn't match WaveNet expected sample rate (" << expectedSampleRate << "!\n"; throw std::runtime_error(ss.str().c_str()); } } - std::vector layer_array_params; + for (size_t i = 0; i < config["layers"].size(); i++) { nlohmann::json layer_config = config["layers"][i]; - const int groups = layer_config.value("groups_input", 1); // defaults to 1 - const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1 + const int groups = layer_config.value("groups_input", 1); + const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); const int channels = layer_config["channels"]; - const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present + const int bottleneck = layer_config.value("bottleneck", channels); // Parse layer1x1 parameters - bool layer1x1_active = true; // default to active if not present + bool layer1x1_active = true; int layer1x1_groups = 1; if (layer_config.find("layer1x1") != layer_config.end()) { const auto& layer1x1_config = layer_config["layer1x1"]; - layer1x1_active = layer1x1_config["active"]; // default to active + layer1x1_active = layer1x1_config["active"]; layer1x1_groups = layer1x1_config["groups"]; } nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups); @@ -618,7 +620,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st std::vector activation_configs; if (layer_config["activation"].is_array()) { - // Array of activation configs for (const auto& activation_json : layer_config["activation"]) { activation_configs.push_back(activations::ActivationConfig::from_json(activation_json)); @@ -632,12 +633,12 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single activation config - duplicate it for all layers const activations::ActivationConfig activation_config = activations::ActivationConfig::from_json(layer_config["activation"]); activation_configs.resize(num_layers, activation_config); } - // Parse gating mode(s) - support both single value and array, and old "gated" boolean + + // Parse gating mode(s) std::vector gating_modes; std::vector secondary_activation_configs; @@ -656,21 +657,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st { if (layer_config["gating_mode"].is_array()) { - // Array of gating modes for (const auto& gating_mode_json : layer_config["gating_mode"]) { std::string gating_mode_str = gating_mode_json.get(); GatingMode mode = parse_gating_mode_str(gating_mode_str); gating_modes.push_back(mode); - // Parse corresponding secondary activation if gating is enabled if (mode != GatingMode::NONE) { if (layer_config.find("secondary_activation") != layer_config.end()) { if (layer_config["secondary_activation"].is_array()) { - // Array of secondary activations - use corresponding index if (gating_modes.size() > layer_config["secondary_activation"].size()) { throw std::runtime_error("Layer array " + std::to_string(i) @@ -682,21 +680,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single secondary activation - use for all gated layers secondary_activation_configs.push_back( activations::ActivationConfig::from_json(layer_config["secondary_activation"])); } } else { - // Default to Sigmoid for backward compatibility secondary_activation_configs.push_back( activations::ActivationConfig::simple(activations::ActivationType::Sigmoid)); } } else { - // NONE mode - use empty config secondary_activation_configs.push_back(activations::ActivationConfig{}); } } @@ -706,7 +701,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st + std::to_string(gating_modes.size()) + ") must match dilations size (" + std::to_string(num_layers) + ")"); } - // Validate secondary_activation array size if it's an array if (layer_config.find("secondary_activation") != layer_config.end() && layer_config["secondary_activation"].is_array()) { @@ -720,12 +714,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single gating mode - duplicate for all layers std::string gating_mode_str = layer_config["gating_mode"].get(); GatingMode gating_mode = parse_gating_mode_str(gating_mode_str); gating_modes.resize(num_layers, gating_mode); - // Parse secondary activation activations::ActivationConfig secondary_activation_config; if (gating_mode != GatingMode::NONE) { @@ -736,7 +728,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to Sigmoid for backward compatibility secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid); } } @@ -745,7 +736,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else if (layer_config.find("gated") != layer_config.end()) { - // Backward compatibility: convert old "gated" boolean to new enum bool gated = layer_config["gated"]; GatingMode gating_mode = gated ? GatingMode::GATED : GatingMode::NONE; gating_modes.resize(num_layers, gating_mode); @@ -763,7 +753,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to NONE for all layers gating_modes.resize(num_layers, GatingMode::NONE); secondary_activation_configs.resize(num_layers, activations::ActivationConfig{}); } @@ -792,11 +781,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st const nlohmann::json& film_config = layer_config[key]; bool active = film_config.value("active", true); bool shift = film_config.value("shift", true); - int groups = film_config.value("groups", 1); - return nam::wavenet::_FiLMParams(active, shift, groups); + int film_groups = film_config.value("groups", 1); + return nam::wavenet::_FiLMParams(active, shift, film_groups); }; - // Parse FiLM parameters nam::wavenet::_FiLMParams conv_pre_film_params = parse_film_params("conv_pre_film"); nam::wavenet::_FiLMParams conv_post_film_params = parse_film_params("conv_post_film"); nam::wavenet::_FiLMParams input_mixin_pre_film_params = parse_film_params("input_mixin_pre_film"); @@ -806,32 +794,37 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film"); nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film"); - // Validation: if layer1x1_post_film is active, layer1x1 must also be active if (_layer1x1_post_film_params.active && !layer1x1_active) { throw std::runtime_error("Layer array " + std::to_string(i) + ": layer1x1_post_film cannot be active when layer1x1.active is false"); } - layer_array_params.push_back(nam::wavenet::LayerArrayParams( + wc.layer_array_params.push_back(nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params, head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params)); } - const bool with_head = config.find("head") != config.end() && !config["head"].is_null(); - const float head_scale = config["head_scale"]; - if (layer_array_params.empty()) + wc.with_head = config.find("head") != config.end() && !config["head"].is_null(); + wc.head_scale = config["head_scale"]; + wc.in_channels = config.value("in_channels", 1); + + if (wc.layer_array_params.empty()) throw std::runtime_error("WaveNet config requires at least one layer array"); - // Backward compatibility: assume 1 input channel - const int in_channels = config.value("in_channels", 1); + return wc; +} - // out_channels is determined from last layer array's head_size - return std::make_unique( - in_channels, layer_array_params, head_scale, with_head, weights, std::move(condition_dsp), expectedSampleRate); +// Factory to instantiate from nlohmann json +std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + auto wc = parse_config_json(config, expectedSampleRate); + return std::make_unique(wc.in_channels, wc.layer_array_params, wc.head_scale, wc.with_head, + weights, std::move(wc.condition_dsp), expectedSampleRate); } // Register the factory diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 63e1378..b75d15f 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -713,6 +713,29 @@ class WaveNet : public DSP int PrewarmSamples() override { return mPrewarmSamples; }; }; +/// \brief Configuration for a WaveNet model +struct WaveNetConfig +{ + int in_channels; + std::vector layer_array_params; + float head_scale; + bool with_head; + std::unique_ptr condition_dsp; + + // Move-only due to unique_ptr + WaveNetConfig() = default; + WaveNetConfig(WaveNetConfig&&) = default; + WaveNetConfig& operator=(WaveNetConfig&&) = default; + WaveNetConfig(const WaveNetConfig&) = delete; + WaveNetConfig& operator=(const WaveNetConfig&) = delete; +}; + +/// \brief Parse WaveNet configuration from JSON +/// \param config JSON configuration object +/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) +/// \return WaveNetConfig +WaveNetConfig parse_config_json(const nlohmann::json& config, const double expectedSampleRate); + /// \brief Factory function to instantiate WaveNet from JSON configuration /// \param config JSON configuration object /// \param weights Model weights vector