Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions docs/source/python/m-modelgenerator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,13 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized
- Target compartment (must be in ``infection_states``, must differ from ``from``)
* - ``type``
- yes
- ``infection``, ``linear``, or ``custom``
- ``infection``, ``linear``, ``rate``, or ``custom``
* - ``parameter``
- for ``infection`` and ``linear``
- for ``infection``, ``linear``, and simple ``rate``
- Name of the driving parameter (must be in ``parameters``)
* - ``rate``
- alternatively for ``rate``
- Formula for the source-proportional rate. Names from ``infection_states``, ``parameters``, and ``derived_quantities`` are replaced in generated C++.
* - ``infectious_state``
- for ``infection``
- Compartment whose population drives the force of infection (e.g. ``Infected``). You can pass a single state or a list of states (e.g. ``[InfectedNoSymptoms, InfectedSymptoms]``), in which case their populations are summed in the force of infection.
Expand Down Expand Up @@ -240,6 +243,33 @@ In order to allow the on-the-fly computation of newly infected (or hospitalized

where :math:`T_i` is the time parameter for age group *i*.

``rate``
The `rate` flow is a simple outflow proportional to the compartment size
using the parameter directly as a rate:

.. math::

{X}'_i \leftarrow -r_i \cdot X_i

where :math:`r_i` is the rate parameter for age group *i*. Unlike
``linear``, setting :math:`r_i = 0` disables the transition.

Instead of ``parameter``, ``rate`` may specify a formula, e.g.
``rate: "a1 * lambda_human"``. Formula names may refer to states,
parameters, or previously defined ``derived_quantities``.

``derived_quantities``
Optional local quantities can be defined before ``transitions`` and reused
in ``rate`` formulas:

.. code-block:: yaml

derived_quantities:
- name: N
formula: "Susceptible + Infected"
- name: lambda_human
formula: "safe_div(Infected, N)"

``custom``
For `custom`, a placeholder is inserted into ``get_flows()`` with a ``TODO`` comment.
If ``custom_formula`` is provided, it is shown as a hint next to the placeholder.
Expand Down Expand Up @@ -423,6 +453,7 @@ Common validation errors:
* Missing or empty ``model``, ``infection_states``, ``parameters``, or ``transitions`` section
* Fewer than two infection states, or duplicate state names
* Parameter ``type`` is not one of ``probability``, ``time``, ``custom``
* Transition ``type`` is not one of ``infection``, ``linear``, ``rate``, ``custom``
* ``parameter`` or ``infectious_state`` / ``infectious_states`` in a transition references an unknown name
* A transition has the same ``from`` and ``to`` state (self-loop)

Expand Down
8 changes: 8 additions & 0 deletions pycode/memilio-generation/memilio/modelgenerator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from jinja2 import Environment, PackageLoader, StrictUndefined

from .schema import (
DerivedQuantityConfig,
ModelConfig,
ModelMeta,
ParameterConfig,
Expand Down Expand Up @@ -338,6 +339,11 @@ def _parse(raw: dict) -> ModelConfig:
)
)

derived_quantities = [
DerivedQuantityConfig(name=d["name"], formula=d["formula"])
for d in raw.get("derived_quantities", [])
]

transitions = []
for t in raw["transitions"]:
raw_infectious_states = t.get("infectious_states")
Expand All @@ -359,6 +365,7 @@ def _parse(raw: dict) -> ModelConfig:
to_state=t["to"],
type=t["type"],
parameter=t.get("parameter"),
rate=t.get("rate"),
infectious_state=infectious_states[0]
if infectious_states else None,
infectious_states=infectious_states,
Expand All @@ -368,5 +375,6 @@ def _parse(raw: dict) -> ModelConfig:
meta=meta,
infection_states=states,
parameters=parameters,
derived_quantities=derived_quantities,
transitions=transitions,
)
50 changes: 49 additions & 1 deletion pycode/memilio-generation/memilio/modelgenerator/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

Expand All @@ -40,10 +41,12 @@ class TransitionType:
"""Force-of-infection flow using contact matrix and S*I/N."""
LINEAR = "linear"
"""Simple outflow: (1 / parameter) * source_compartment."""
RATE = "rate"
"""Simple outflow: parameter * source_compartment."""
CUSTOM = "custom"
"""Placeholder. User must supply the expression manually."""

ALL = (INFECTION, LINEAR, CUSTOM)
ALL = (INFECTION, LINEAR, RATE, CUSTOM)


class ParameterType:
Expand Down Expand Up @@ -100,6 +103,17 @@ class ParameterConfig:
"""(lower, upper) bounds used in the constraint checks. ``None`` means unchecked."""


@dataclass
class DerivedQuantityConfig:
"""Configuration for a local formula-derived quantity."""

name: str
"""C++ local variable name."""

formula: str
"""Formula using state, parameter, and earlier derived quantity names."""


@dataclass
class TransitionConfig:
"""Configuration for a single compartment flow."""
Expand All @@ -116,6 +130,9 @@ class TransitionConfig:
parameter: str | None = None
"""Name of the `ParameterConfig` that drives this flow."""

rate: str | None = None
"""For ``type == "rate"``: formula for the source-proportional rate."""

infectious_state: str | None = None
"""
For ``type == "infection"``: the compartment whose population drives
Expand Down Expand Up @@ -143,6 +160,7 @@ class ModelConfig:
meta: ModelMeta
infection_states: list[str]
parameters: list[ParameterConfig]
derived_quantities: list[DerivedQuantityConfig]
transitions: list[TransitionConfig]

@property
Expand All @@ -161,6 +179,36 @@ def all_parameters(self) -> list[ParameterConfig]:
"""
return self.parameters

def parameter_by_name(self, name: str) -> ParameterConfig:
"""Return the parameter config with the given name."""
for parameter in self.parameters:
if parameter.name == name:
return parameter
raise KeyError(name)

@property
def uses_safe_div(self) -> bool:
"""``True`` if generated formulas need the safe_div helper."""
return any("safe_div" in d.formula for d in self.derived_quantities) or any(
t.rate is not None and "safe_div" in t.rate for t in self.transitions)

def formula_to_cpp(self, formula: str, index_name: str = "i") -> str:
"""Translate a formula by replacing known names with C++ expressions."""
states = set(self.infection_states)
parameters = {p.name: p for p in self.parameters}
derived = {d.name for d in self.derived_quantities}

def replace(match):
name = match.group(0)
if name in states:
return f"y[idx_{name}_{index_name}]"
if name in parameters:
suffix = f"[{index_name}]" if parameters[name].per_age_group else ""
return f"params.template get<{name}<FP>>(){suffix}"
return name if name in derived or name in {"t", "safe_div"} else name

return re.sub(r"\b[A-Za-z_][A-Za-z0-9_]*\b", replace, formula)

def parameters_for_constraint_check(self) -> list[ParameterConfig]:
"""Return parameters that have explicit bound constraints."""
return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ GCC_CLANG_DIAGNOSTIC(ignored "-Wshadow")
#include <Eigen/Dense>
GCC_CLANG_DIAGNOSTIC(pop)

{% macro param_value(parameter_name, index_name='i') -%}
{%- set param = cfg.parameter_by_name(parameter_name) -%}
params.template get<{{ parameter_name }}<FP>>(){{ "[" ~ index_name ~ "]" if param.per_age_group else "" }}
{%- endmacro %}

namespace mio
{
namespace {{ cfg.meta.namespace }}
Expand Down Expand Up @@ -86,6 +91,17 @@ public:
this->populations.get_flat_index({i, InfectionState::{{ state }}});
{% endfor %}

{% if cfg.derived_quantities %}
{% if cfg.uses_safe_div %}
const auto safe_div = [](FP numerator, FP denominator) {
return denominator < Limits<FP>::zero_tolerance() ? FP(0.0) : numerator / denominator;
};
{% endif %}
{% for q in cfg.derived_quantities %}
const FP {{ q.name }} = {{ cfg.formula_to_cpp(q.formula) }};
{% endfor %}

{% endif %}
{% if cfg.has_infection_transition %}
// ----------------------------------------------------------------
// Infection transitions – double loop over contact age groups
Expand All @@ -111,7 +127,7 @@ public:
params.template get<ContactPatterns<FP>>()
.get_cont_freq_mat()
.get_matrix_at(SimulationTime<FP>(t))(i.get(), j.get()) *
params.template get<{{ t.parameter }}<FP>>()[i] * divNj;
{{ param_value(t.parameter) }} * divNj;

flows[Base::template get_flat_flow_index<
InfectionState::{{ t.from_state }},
Expand All @@ -133,7 +149,22 @@ public:
flows[Base::template get_flat_flow_index<
InfectionState::{{ t.from_state }},
InfectionState::{{ t.to_state }}>(i)] =
(FP(1.0) / params.template get<{{ t.parameter }}<FP>>()[i]) *
(FP(1.0) / {{ param_value(t.parameter) }}) *
y[idx_{{ t.from_state }}_i];
{% endfor %}

// ----------------------------------------------------------------
// Rate outflow transitions
// ----------------------------------------------------------------
{% for t in cfg.transitions if t.type == 'rate' %}
flows[Base::template get_flat_flow_index<
InfectionState::{{ t.from_state }},
InfectionState::{{ t.to_state }}>(i)] =
{% if t.rate %}
{{ cfg.formula_to_cpp(t.rate) }} *
{% else %}
{{ param_value(t.parameter) }} *
{% endif %}
y[idx_{{ t.from_state }}_i];
{% endfor %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ private:
{% else %}
{% set per_age_group_params = cfg.parameters | selectattr('per_age_group') | list %}
{% if per_age_group_params %}
, m_num_groups(AgeGroup(this->template get<{{ per_age_group_params[0].name }}<FP>>().size()))
, m_num_groups(AgeGroup(this->template get<{{ per_age_group_params[0].name }}<FP>>().template size<AgeGroup>().get()))
{% else %}
, m_num_groups(AgeGroup(1))
{% endif %}
Expand Down
61 changes: 61 additions & 0 deletions pycode/memilio-generation/memilio/modelgenerator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from __future__ import annotations

import re
from typing import Any, Dict, List

from .schema import ParameterType, TransitionType
Expand Down Expand Up @@ -143,6 +144,43 @@ def validate(data: Any) -> None:

param_name_set = set(param_names)

# derived_quantities
derived_quantities = data.get("derived_quantities", [])
if not isinstance(derived_quantities, list):
errors.append("'derived_quantities' must be a list if provided.")
derived_quantities = []

formula_names = set(states) | param_name_set | {"t", "safe_div"}
derived_names: list[str] = []
for i, d in enumerate(derived_quantities):
loc = f"derived_quantities[{i}]"
if not isinstance(d, dict):
errors.append(f"'{loc}' must be a mapping.")
continue
name = d.get("name")
formula = d.get("formula")
if not isinstance(name, str) or not name.strip():
errors.append(f"'{loc}.name' must be a non-empty string.")
elif name in state_set or name in param_name_set:
errors.append(
f"'{loc}.name' must not duplicate a state or parameter name."
)
elif name in derived_names:
errors.append(
f"'derived_quantities' contains duplicate 'name' entry {name!r}."
)
if not isinstance(formula, str) or not formula.strip():
errors.append(f"'{loc}.formula' must be a non-empty string.")
else:
for ref in re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", formula):
if ref not in formula_names:
errors.append(
f"'{loc}.formula' references unknown or later-defined symbol {ref!r}."
)
if isinstance(name, str) and name.strip():
derived_names.append(name)
formula_names.add(name)

# transitions
transitions = data.get("transitions")
if not isinstance(transitions, list) or len(transitions) == 0:
Expand Down Expand Up @@ -185,6 +223,29 @@ def validate(data: Any) -> None:
f"'{loc}.parameter' references unknown parameter {param!r}."
)

if ttype == TransitionType.RATE:
if "rate" in t:
if "parameter" in t:
errors.append(
f"'{loc}' must define only one of 'parameter' or 'rate'."
)
rate = t.get("rate")
if not isinstance(rate, str) or not rate.strip():
errors.append(
f"'{loc}.rate' must be a non-empty string.")
else:
for ref in re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", rate):
if ref not in formula_names:
errors.append(
f"'{loc}.rate' references unknown symbol {ref!r}."
)
else:
param = t.get("parameter")
if param not in param_name_set:
errors.append(
f"'{loc}.parameter' references unknown parameter {param!r}."
)

if ttype == TransitionType.INFECTION:
has_singular = "infectious_state" in t
has_plural = "infectious_states" in t
Expand Down
Loading
Loading