From 6c5fe80a2668a9d5a8274616709b4a4ec60f4970 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 13 May 2026 15:43:42 +0200 Subject: [PATCH 1/3] add rate as transition --- docs/source/python/m-modelgenerator.rst | 16 +++++- .../memilio/modelgenerator/schema.py | 11 +++- .../modelgenerator/templates/model_h.jinja2 | 20 ++++++- .../memilio/modelgenerator/validator.py | 5 +- .../tests/test_modelgenerator.py | 55 +++++++++++++++++++ 5 files changed, 101 insertions(+), 6 deletions(-) diff --git a/docs/source/python/m-modelgenerator.rst b/docs/source/python/m-modelgenerator.rst index 81f05bfc8c..d363b22c22 100644 --- a/docs/source/python/m-modelgenerator.rst +++ b/docs/source/python/m-modelgenerator.rst @@ -197,9 +197,9 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized - Target compartment (must be in ``infection_states``, must differ from ``from``) * - ``type`` - yes - - ``infection``, ``linear``, or ``custom`` + - ``infection``, ``linear``, ``rate``, or ``custom`` * - ``parameter`` - - for ``infection`` and ``linear`` + - for ``infection``, ``linear``, and ``rate`` - Name of the driving parameter (must be in ``parameters``) * - ``infectious_state`` - for ``infection`` @@ -240,6 +240,17 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized where :math:`T_i` is the time parameter for age group *i*. +``rate`` + The `rate` flow is a simple outflow proportional to the compartment size + using the parameter directly as a rate: + + .. math:: + + {X}'_i \leftarrow -r_i \cdot X_i + + where :math:`r_i` is the rate parameter for age group *i*. Unlike + ``linear``, setting :math:`r_i = 0` disables the transition. + ``custom`` For `custom`, a placeholder is inserted into ``get_flows()`` with a ``TODO`` comment. If ``custom_formula`` is provided, it is shown as a hint next to the placeholder. @@ -423,6 +434,7 @@ Common validation errors: * Missing or empty ``model``, ``infection_states``, ``parameters``, or ``transitions`` section * Fewer than two infection states, or duplicate state names * Parameter ``type`` is not one of ``probability``, ``time``, ``custom`` +* Transition ``type`` is not one of ``infection``, ``linear``, ``rate``, ``custom`` * ``parameter`` or ``infectious_state`` / ``infectious_states`` in a transition references an unknown name * A transition has the same ``from`` and ``to`` state (self-loop) diff --git a/pycode/memilio-generation/memilio/modelgenerator/schema.py b/pycode/memilio-generation/memilio/modelgenerator/schema.py index 6525357340..3163e33760 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/schema.py +++ b/pycode/memilio-generation/memilio/modelgenerator/schema.py @@ -40,10 +40,12 @@ class TransitionType: """Force-of-infection flow using contact matrix and S*I/N.""" LINEAR = "linear" """Simple outflow: (1 / parameter) * source_compartment.""" + RATE = "rate" + """Simple outflow: parameter * source_compartment.""" CUSTOM = "custom" """Placeholder. User must supply the expression manually.""" - ALL = (INFECTION, LINEAR, CUSTOM) + ALL = (INFECTION, LINEAR, RATE, CUSTOM) class ParameterType: @@ -161,6 +163,13 @@ def all_parameters(self) -> list[ParameterConfig]: """ return self.parameters + def parameter_by_name(self, name: str) -> ParameterConfig: + """Return the parameter config with the given name.""" + for parameter in self.parameters: + if parameter.name == name: + return parameter + raise KeyError(name) + def parameters_for_constraint_check(self) -> list[ParameterConfig]: """Return parameters that have explicit bound constraints.""" return [ diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 index 4cbedf0989..aeafc14d93 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 @@ -33,6 +33,11 @@ GCC_CLANG_DIAGNOSTIC(ignored "-Wshadow") #include GCC_CLANG_DIAGNOSTIC(pop) +{% macro param_value(parameter_name, index_name='i') -%} +{%- set param = cfg.parameter_by_name(parameter_name) -%} +params.template get<{{ parameter_name }}>(){{ "[" ~ index_name ~ "]" if param.per_age_group else "" }} +{%- endmacro %} + namespace mio { namespace {{ cfg.meta.namespace }} @@ -111,7 +116,7 @@ public: params.template get>() .get_cont_freq_mat() .get_matrix_at(SimulationTime(t))(i.get(), j.get()) * - params.template get<{{ t.parameter }}>()[i] * divNj; + {{ param_value(t.parameter) }} * divNj; flows[Base::template get_flat_flow_index< InfectionState::{{ t.from_state }}, @@ -133,7 +138,18 @@ public: flows[Base::template get_flat_flow_index< InfectionState::{{ t.from_state }}, InfectionState::{{ t.to_state }}>(i)] = - (FP(1.0) / params.template get<{{ t.parameter }}>()[i]) * + (FP(1.0) / {{ param_value(t.parameter) }}) * + y[idx_{{ t.from_state }}_i]; +{% endfor %} + + // ---------------------------------------------------------------- + // Rate outflow transitions + // ---------------------------------------------------------------- +{% for t in cfg.transitions if t.type == 'rate' %} + flows[Base::template get_flat_flow_index< + InfectionState::{{ t.from_state }}, + InfectionState::{{ t.to_state }}>(i)] = + {{ param_value(t.parameter) }} * y[idx_{{ t.from_state }}_i]; {% endfor %} diff --git a/pycode/memilio-generation/memilio/modelgenerator/validator.py b/pycode/memilio-generation/memilio/modelgenerator/validator.py index def8fd347b..dcc8ef4dfd 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/validator.py +++ b/pycode/memilio-generation/memilio/modelgenerator/validator.py @@ -178,7 +178,10 @@ def validate(data: Any) -> None: ) continue - if ttype in (TransitionType.INFECTION, TransitionType.LINEAR): + if ttype in ( + TransitionType.INFECTION, + TransitionType.LINEAR, + TransitionType.RATE): param = t.get("parameter") if param not in param_name_set: errors.append( diff --git a/pycode/memilio-generation/tests/test_modelgenerator.py b/pycode/memilio-generation/tests/test_modelgenerator.py index 879bc7e499..8f36591932 100644 --- a/pycode/memilio-generation/tests/test_modelgenerator.py +++ b/pycode/memilio-generation/tests/test_modelgenerator.py @@ -68,6 +68,21 @@ def test_seir_transitions(self): self.assertIn("infection", types) self.assertIn("linear", types) + def test_rate_transition_parses(self): + d = { + "model": {"name": "EI", "namespace": "oei", "prefix": "ode_ei"}, + "infection_states": ["E", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "custom", "default": 0.2} + ], + "transitions": [ + {"from": "E", "to": "I", "type": "rate", "parameter": "Rate"} + ], + } + gen = Generator.from_dict(d) + self.assertEqual(gen._config.transitions[0].type, "rate") + def test_seird_has_custom_transition(self): gen = Generator.from_yaml(SEIRD_YAML) custom = [t for t in gen._config.transitions if t.type == "custom"] @@ -339,6 +354,39 @@ def test_linear_flows(self): self.assertIn("TimeExposed>()[i]", self.content) self.assertIn("TimeInfected>()[i]", self.content) + def test_rate_flows(self): + d = { + "model": {"name": "EI", "namespace": "oei", "prefix": "ode_ei"}, + "infection_states": ["E", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "custom", "default": 0.2} + ], + "transitions": [ + {"from": "E", "to": "I", "type": "rate", "parameter": "Rate"} + ], + } + content = Generator.from_dict(d).render()["cpp/models/ode_ei/model.h"] + self.assertIn("Rate outflow transitions", content) + self.assertIn("Rate>()[i] *", content) + self.assertIn("y[idx_E_i]", content) + + def test_rate_flow_supports_scalar_parameter(self): + d = { + "model": {"name": "EI", "namespace": "oei", "prefix": "ode_ei"}, + "infection_states": ["E", "I"], + "parameters": [ + {"name": "Rate", "description": "d", "type": "custom", + "default": 0.2, "per_age_group": False} + ], + "transitions": [ + {"from": "E", "to": "I", "type": "rate", "parameter": "Rate"} + ], + } + content = Generator.from_dict(d).render()["cpp/models/ode_ei/model.h"] + self.assertIn("params.template get>() *", content) + self.assertNotIn("Rate>()[i] *", content) + def test_serialize_deserialize(self): self.assertIn("void serialize(", self.content) self.assertIn("static IOResult deserialize(", self.content) @@ -601,6 +649,13 @@ def test_unknown_parameter_in_transition(self): with self.assertRaises(ValidationError): Generator.from_dict(d) + def test_unknown_parameter_in_rate_transition(self): + d = self._base() + d["transitions"][0]["type"] = "rate" + d["transitions"][0]["parameter"] = "NoSuchParam" + with self.assertRaises(ValidationError): + Generator.from_dict(d) + def test_invalid_transition_type(self): d = self._base() d["transitions"][0]["type"] = "magic" From 9a5962f455f9782871d5f1e0dc1577c5df344674 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 13 May 2026 16:19:47 +0200 Subject: [PATCH 2/3] "costum rates" and derived_quantities --- docs/source/python/m-modelgenerator.rst | 21 +++++- .../memilio/modelgenerator/generator.py | 8 +++ .../memilio/modelgenerator/schema.py | 39 +++++++++++ .../modelgenerator/templates/model_h.jinja2 | 15 +++++ .../templates/parameters_h.jinja2 | 2 +- .../memilio/modelgenerator/validator.py | 66 +++++++++++++++++-- .../tests/test_modelgenerator.py | 42 ++++++++++++ 7 files changed, 187 insertions(+), 6 deletions(-) diff --git a/docs/source/python/m-modelgenerator.rst b/docs/source/python/m-modelgenerator.rst index d363b22c22..67d76e24d2 100644 --- a/docs/source/python/m-modelgenerator.rst +++ b/docs/source/python/m-modelgenerator.rst @@ -199,8 +199,11 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized - yes - ``infection``, ``linear``, ``rate``, or ``custom`` * - ``parameter`` - - for ``infection``, ``linear``, and ``rate`` + - for ``infection``, ``linear``, and simple ``rate`` - Name of the driving parameter (must be in ``parameters``) + * - ``rate`` + - alternatively for ``rate`` + - Formula for the source-proportional rate. Names from ``infection_states``, ``parameters``, and ``derived_quantities`` are replaced in generated C++. * - ``infectious_state`` - for ``infection`` - Compartment whose population drives the force of infection (e.g. ``Infected``). You can pass a single state or a list of states (e.g. ``[InfectedNoSymptoms, InfectedSymptoms]``), in which case their populations are summed in the force of infection. @@ -251,6 +254,22 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized where :math:`r_i` is the rate parameter for age group *i*. Unlike ``linear``, setting :math:`r_i = 0` disables the transition. + Instead of ``parameter``, ``rate`` may specify a formula, e.g. + ``rate: "a1 * lambda_human"``. Formula names may refer to states, + parameters, or previously defined ``derived_quantities``. + +``derived_quantities`` + Optional local quantities can be defined before ``transitions`` and reused + in ``rate`` formulas: + + .. code-block:: yaml + + derived_quantities: + - name: N + formula: "Susceptible + Infected" + - name: lambda_human + formula: "safe_div(Infected, N)" + ``custom`` For `custom`, a placeholder is inserted into ``get_flows()`` with a ``TODO`` comment. If ``custom_formula`` is provided, it is shown as a hint next to the placeholder. diff --git a/pycode/memilio-generation/memilio/modelgenerator/generator.py b/pycode/memilio-generation/memilio/modelgenerator/generator.py index 643090fef3..62dab6c00e 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/generator.py +++ b/pycode/memilio-generation/memilio/modelgenerator/generator.py @@ -42,6 +42,7 @@ from jinja2 import Environment, PackageLoader, StrictUndefined from .schema import ( + DerivedQuantityConfig, ModelConfig, ModelMeta, ParameterConfig, @@ -338,6 +339,11 @@ def _parse(raw: dict) -> ModelConfig: ) ) + derived_quantities = [ + DerivedQuantityConfig(name=d["name"], formula=d["formula"]) + for d in raw.get("derived_quantities", []) + ] + transitions = [] for t in raw["transitions"]: raw_infectious_states = t.get("infectious_states") @@ -359,6 +365,7 @@ def _parse(raw: dict) -> ModelConfig: to_state=t["to"], type=t["type"], parameter=t.get("parameter"), + rate=t.get("rate"), infectious_state=infectious_states[0] if infectious_states else None, infectious_states=infectious_states, @@ -368,5 +375,6 @@ def _parse(raw: dict) -> ModelConfig: meta=meta, infection_states=states, parameters=parameters, + derived_quantities=derived_quantities, transitions=transitions, ) diff --git a/pycode/memilio-generation/memilio/modelgenerator/schema.py b/pycode/memilio-generation/memilio/modelgenerator/schema.py index 3163e33760..4dcd1c0e9b 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/schema.py +++ b/pycode/memilio-generation/memilio/modelgenerator/schema.py @@ -28,6 +28,7 @@ from __future__ import annotations +import re from dataclasses import dataclass, field from typing import List, Optional, Tuple @@ -102,6 +103,17 @@ class ParameterConfig: """(lower, upper) bounds used in the constraint checks. ``None`` means unchecked.""" +@dataclass +class DerivedQuantityConfig: + """Configuration for a local formula-derived quantity.""" + + name: str + """C++ local variable name.""" + + formula: str + """Formula using state, parameter, and earlier derived quantity names.""" + + @dataclass class TransitionConfig: """Configuration for a single compartment flow.""" @@ -118,6 +130,9 @@ class TransitionConfig: parameter: str | None = None """Name of the `ParameterConfig` that drives this flow.""" + rate: str | None = None + """For ``type == "rate"``: formula for the source-proportional rate.""" + infectious_state: str | None = None """ For ``type == "infection"``: the compartment whose population drives @@ -145,6 +160,7 @@ class ModelConfig: meta: ModelMeta infection_states: list[str] parameters: list[ParameterConfig] + derived_quantities: list[DerivedQuantityConfig] transitions: list[TransitionConfig] @property @@ -170,6 +186,29 @@ def parameter_by_name(self, name: str) -> ParameterConfig: return parameter raise KeyError(name) + @property + def uses_safe_div(self) -> bool: + """``True`` if generated formulas need the safe_div helper.""" + return any("safe_div" in d.formula for d in self.derived_quantities) or any( + t.rate is not None and "safe_div" in t.rate for t in self.transitions) + + def formula_to_cpp(self, formula: str, index_name: str = "i") -> str: + """Translate a formula by replacing known names with C++ expressions.""" + states = set(self.infection_states) + parameters = {p.name: p for p in self.parameters} + derived = {d.name for d in self.derived_quantities} + + def replace(match): + name = match.group(0) + if name in states: + return f"y[idx_{name}_{index_name}]" + if name in parameters: + suffix = f"[{index_name}]" if parameters[name].per_age_group else "" + return f"params.template get<{name}>(){suffix}" + return name if name in derived or name in {"t", "safe_div"} else name + + return re.sub(r"\b[A-Za-z_][A-Za-z0-9_]*\b", replace, formula) + def parameters_for_constraint_check(self) -> list[ParameterConfig]: """Return parameters that have explicit bound constraints.""" return [ diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 index aeafc14d93..dc4b07433c 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 @@ -91,6 +91,17 @@ public: this->populations.get_flat_index({i, InfectionState::{{ state }}}); {% endfor %} +{% if cfg.derived_quantities %} +{% if cfg.uses_safe_div %} + const auto safe_div = [](FP numerator, FP denominator) { + return denominator < Limits::zero_tolerance() ? FP(0.0) : numerator / denominator; + }; +{% endif %} +{% for q in cfg.derived_quantities %} + const FP {{ q.name }} = {{ cfg.formula_to_cpp(q.formula) }}; +{% endfor %} + +{% endif %} {% if cfg.has_infection_transition %} // ---------------------------------------------------------------- // Infection transitions – double loop over contact age groups @@ -149,7 +160,11 @@ public: flows[Base::template get_flat_flow_index< InfectionState::{{ t.from_state }}, InfectionState::{{ t.to_state }}>(i)] = +{% if t.rate %} + {{ cfg.formula_to_cpp(t.rate) }} * +{% else %} {{ param_value(t.parameter) }} * +{% endif %} y[idx_{{ t.from_state }}_i]; {% endfor %} diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 index d19b839123..6e4916b45f 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 @@ -217,7 +217,7 @@ private: {% else %} {% set per_age_group_params = cfg.parameters | selectattr('per_age_group') | list %} {% if per_age_group_params %} - , m_num_groups(AgeGroup(this->template get<{{ per_age_group_params[0].name }}>().size())) + , m_num_groups(AgeGroup(this->template get<{{ per_age_group_params[0].name }}>().template size().get())) {% else %} , m_num_groups(AgeGroup(1)) {% endif %} diff --git a/pycode/memilio-generation/memilio/modelgenerator/validator.py b/pycode/memilio-generation/memilio/modelgenerator/validator.py index dcc8ef4dfd..9893ad9815 100644 --- a/pycode/memilio-generation/memilio/modelgenerator/validator.py +++ b/pycode/memilio-generation/memilio/modelgenerator/validator.py @@ -27,6 +27,7 @@ from __future__ import annotations +import re from typing import Any, Dict, List from .schema import ParameterType, TransitionType @@ -143,6 +144,43 @@ def validate(data: Any) -> None: param_name_set = set(param_names) + # derived_quantities + derived_quantities = data.get("derived_quantities", []) + if not isinstance(derived_quantities, list): + errors.append("'derived_quantities' must be a list if provided.") + derived_quantities = [] + + formula_names = set(states) | param_name_set | {"t", "safe_div"} + derived_names: list[str] = [] + for i, d in enumerate(derived_quantities): + loc = f"derived_quantities[{i}]" + if not isinstance(d, dict): + errors.append(f"'{loc}' must be a mapping.") + continue + name = d.get("name") + formula = d.get("formula") + if not isinstance(name, str) or not name.strip(): + errors.append(f"'{loc}.name' must be a non-empty string.") + elif name in state_set or name in param_name_set: + errors.append( + f"'{loc}.name' must not duplicate a state or parameter name." + ) + elif name in derived_names: + errors.append( + f"'derived_quantities' contains duplicate 'name' entry {name!r}." + ) + if not isinstance(formula, str) or not formula.strip(): + errors.append(f"'{loc}.formula' must be a non-empty string.") + else: + for ref in re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", formula): + if ref not in formula_names: + errors.append( + f"'{loc}.formula' references unknown or later-defined symbol {ref!r}." + ) + if isinstance(name, str) and name.strip(): + derived_names.append(name) + formula_names.add(name) + # transitions transitions = data.get("transitions") if not isinstance(transitions, list) or len(transitions) == 0: @@ -178,16 +216,36 @@ def validate(data: Any) -> None: ) continue - if ttype in ( - TransitionType.INFECTION, - TransitionType.LINEAR, - TransitionType.RATE): + if ttype in (TransitionType.INFECTION, TransitionType.LINEAR): param = t.get("parameter") if param not in param_name_set: errors.append( f"'{loc}.parameter' references unknown parameter {param!r}." ) + if ttype == TransitionType.RATE: + if "rate" in t: + if "parameter" in t: + errors.append( + f"'{loc}' must define only one of 'parameter' or 'rate'." + ) + rate = t.get("rate") + if not isinstance(rate, str) or not rate.strip(): + errors.append( + f"'{loc}.rate' must be a non-empty string.") + else: + for ref in re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", rate): + if ref not in formula_names: + errors.append( + f"'{loc}.rate' references unknown symbol {ref!r}." + ) + else: + param = t.get("parameter") + if param not in param_name_set: + errors.append( + f"'{loc}.parameter' references unknown parameter {param!r}." + ) + if ttype == TransitionType.INFECTION: has_singular = "infectious_state" in t has_plural = "infectious_states" in t diff --git a/pycode/memilio-generation/tests/test_modelgenerator.py b/pycode/memilio-generation/tests/test_modelgenerator.py index 8f36591932..626182fa74 100644 --- a/pycode/memilio-generation/tests/test_modelgenerator.py +++ b/pycode/memilio-generation/tests/test_modelgenerator.py @@ -83,6 +83,26 @@ def test_rate_transition_parses(self): gen = Generator.from_dict(d) self.assertEqual(gen._config.transitions[0].type, "rate") + def test_derived_quantity_parses(self): + d = { + "model": {"name": "EI", "namespace": "oei", "prefix": "ode_ei"}, + "infection_states": ["E", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "custom", "default": 0.2} + ], + "derived_quantities": [ + {"name": "effective_rate", "formula": "2 * Rate"} + ], + "transitions": [ + {"from": "E", "to": "I", "type": "rate", + "rate": "effective_rate"} + ], + } + gen = Generator.from_dict(d) + self.assertEqual(gen._config.derived_quantities[0].name, + "effective_rate") + def test_seird_has_custom_transition(self): gen = Generator.from_yaml(SEIRD_YAML) custom = [t for t in gen._config.transitions if t.type == "custom"] @@ -387,6 +407,28 @@ def test_rate_flow_supports_scalar_parameter(self): self.assertIn("params.template get>() *", content) self.assertNotIn("Rate>()[i] *", content) + def test_rate_flow_supports_formula(self): + d = { + "model": {"name": "EI", "namespace": "oei", "prefix": "ode_ei"}, + "infection_states": ["E", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "custom", "default": 0.2} + ], + "derived_quantities": [ + {"name": "N", "formula": "E + I"}, + {"name": "share", "formula": "safe_div(I, N)"} + ], + "transitions": [ + {"from": "E", "to": "I", "type": "rate", + "rate": "Rate * share"} + ], + } + content = Generator.from_dict(d).render()["cpp/models/ode_ei/model.h"] + self.assertIn("const FP N = y[idx_E_i] + y[idx_I_i];", content) + self.assertIn("const FP share = safe_div(y[idx_I_i], N);", content) + self.assertIn("params.template get>()[i] * share", content) + def test_serialize_deserialize(self): self.assertIn("void serialize(", self.content) self.assertIn("static IOResult deserialize(", self.content) From 26e08a91aaca87463429a1230a70d19555a2aa8e Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 13 May 2026 16:36:21 +0200 Subject: [PATCH 3/3] split up tests --- pycode/memilio-generation/tests/test_modelgenerator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-generation/tests/test_modelgenerator.py b/pycode/memilio-generation/tests/test_modelgenerator.py index 626182fa74..0b39109a94 100644 --- a/pycode/memilio-generation/tests/test_modelgenerator.py +++ b/pycode/memilio-generation/tests/test_modelgenerator.py @@ -425,8 +425,10 @@ def test_rate_flow_supports_formula(self): ], } content = Generator.from_dict(d).render()["cpp/models/ode_ei/model.h"] - self.assertIn("const FP N = y[idx_E_i] + y[idx_I_i];", content) - self.assertIn("const FP share = safe_div(y[idx_I_i], N);", content) + self.assertIn("const FP N", content) + self.assertIn("= y[idx_E_i] + y[idx_I_i];", content) + self.assertIn("const FP share", content) + self.assertIn("= safe_div(y[idx_I_i], N);", content) self.assertIn("params.template get>()[i] * share", content) def test_serialize_deserialize(self):