Merged
67 changes: 67 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,67 @@
name: CI

on:
  push:
    branches: ["**"]
  pull_request:

# Cancel superseded runs on the same ref to save CI minutes.
concurrency:
  group: ci-${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    name: test (py${{ matrix.python-version }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10", "3.12"]

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip

      - name: Upgrade pip tooling
        run: python -m pip install --upgrade pip setuptools wheel

      # Install CPU-only torch explicitly so we never pull CUDA wheels on CI.
      - name: Install PyTorch (CPU)
        run: pip install torch --index-url https://download.pytorch.org/whl/cpu

      # Install the package plus dev + eval extras. torch is already satisfied
      # by the CPU wheel above, so pip will not replace it.
      - name: Install package (dev + eval extras)
        run: pip install -e ".[dev,eval]"

      - name: Show environment
        run: |
          python --version
          pip show torch | grep -E "Name|Version" || true
          python -c "import torch; print('torch', torch.__version__, 'cuda', torch.cuda.is_available())"

      # Strict formatting gate for the files added by this infrastructure work.
      # These must stay black-clean.
      - name: Black (format check - new infra files)
        run: black --check scripts/benchmark.py tests/test_smoke.py tests/test_imports.py

      # Repo-wide format check is advisory for now: the existing tree still has
      # pre-v2 formatting debt, so this reports diffs without failing the build.
      # Flip continue-on-error to false once the tree is fully formatted.
      - name: Black (format check - whole tree, advisory)
        continue-on-error: true
        run: black --check src tests scripts

      - name: Run tests
        run: pytest -q

      # mypy is advisory: type issues are reported but do not fail the build.
      - name: Mypy (non-strict, advisory)
        run: mypy src/dimba || true
28 changes: 28 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,28 @@
# Pre-commit hooks for DIMBA.
# Install with: pip install pre-commit && pre-commit install
# Run on all files: pre-commit run --all-files
#
# black and isort read their configuration from pyproject.toml
# (line-length = 100, isort profile = "black").

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-added-large-files

  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: ["--profile", "black"]
208 changes: 110 additions & 98 deletions AGENTS.md
@@ -1,124 +1,136 @@
# AGENTS.md

This file provides guidance to Any Agentic Coding Model when working with code in this repository.
Guidance for any agentic coding model working in this repo. Reflects the **v2 overhaul**
(branch `feature/dimba-v2-overhaul`, PR #18). For deeper detail see
`docs/OVERHAUL_STATUS.md`, `docs/IMPROVEMENT_PLAN.md`, and `docs/RESEARCH_DIRECTIONS.md`.

## Project overview

**DIMBA** is a non-autoregressive **latent-diffusion** language model: continuous Gaussian
diffusion runs in a learned latent space (VAE/projector over token embeddings; raw-embedding
diffusion is the degenerate `latent_diffusion=False` case), denoised by a **bidirectional
Mamba** backbone, generating whole sequences in parallel by iterative denoising.

- **v1 = `paper/main.pdf`** — an *architectural concept* (explicitly untested). Do not treat
it as ground truth: it contains a prompt-conditioning leak (`C = PromptEncoder(X₀)`) and an
MSE-only objective that the v2 code deliberately fixes.
- **v2 = this repo** — the implementation + the overhaul below. **This file describes v2.**

## ⚠️ Environment gotchas (read first)

- **`import torch` segfaults at interpreter *teardown*** on the original dev box (Windows),
and the bare `python` on PATH is a WindowsApps shim that hangs. **Use the project venv**:
`venv\Scripts\python.exe` (Windows) / `venv/bin/python` (mac). For scripts that import
torch, **end with `os._exit(0)`** after flushing to dodge the teardown crash.
- **Validate without running torch**: `python -m compileall src/dimba scripts tests` (syntax).
- **Runtime smoke**: `venv/bin/python .sisyphus/smoke_full.py` (end-to-end, uses `os._exit`).
- **CI** (`.github/workflows/ci.yml`) runs the real `pytest` suite on clean Linux runners
(py3.10/3.12) with working torch — that's the source of truth for runtime tests.
- **Apple Silicon (M1/M2/M3) training**: use the **PyTorch MPS** path (the `backends/mlx/`
port is a skeleton, not training-ready). Set `PYTORCH_ENABLE_MPS_FALLBACK=1`, use **fp32**,
and `latent_diffusion=True`. `SimpleMamba2` uses the vectorized scan on MPS; keep `seq_len`
≤ ~256. See the small-model recipe in `docs/OVERHAUL_STATUS.md`.
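The `os._exit(0)` workaround from the first gotcha can be sketched as follows. This is a minimal illustration, not code from the repo: `finish` and `main` are hypothetical names, and the `exit_fn` parameter exists only so the helper can be exercised without killing the interpreter.

```python
import os
import sys


def finish(exit_code=0, exit_fn=os._exit):
    """Flush buffered output, then hard-exit without interpreter teardown."""
    # os._exit skips atexit handlers and does not flush stdio buffers,
    # so flush explicitly first or the last lines of output can be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    exit_fn(exit_code)


def main():
    # Hypothetical body: a real script would `import torch` and do its work here.
    print("work done")
    return 0


# At the very bottom of a script that imports torch:
#     finish(main())
```

Because `os._exit` bypasses normal cleanup, anything that relies on `atexit` hooks or context-manager teardown should complete before `finish` is called.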

## Architecture (current data flow)

## Project Overview
```
input_ids ─► token_embed ─► encode_latent (×latent_scale → ~unit variance)
prompt (pooled, response-free) ─┐ add_noise (cosine, zero-terminal-SNR)
timestep_embed τ(t) ────────────┤ │ (only response positions if prompt_mask)
▼ ▼
Mamba denoiser (N bidirectional blocks, FiLM/additive cond)
│ [+ self-conditioning: prev x̂₀ fused in]
raw pred → x̂₀ latent (x0 or v param)
decode_latent (÷latent_scale → embedding) ─► output_head ─► logits
```

This repository is implementing **DIMBA** (Diffusion-based Mamba architecture), a non-autoregressive text generation model that combines:
- Cosine-scheduled diffusion process for parallel denoising
- Mamba-2 state-space model (SSM) as the denoiser backbone
- Conditioning via prompt embeddings and timestep embeddings
Key points that differ from v1 and **must not be reintroduced as bugs**:
- **No conditioning leak.** Conditioning is the *prompt only* — a pooled prompt summary, and
(when `prompt_mask` is given) the prompt tokens kept *clean in-sequence* while only the
response is noised, with loss on the response. Never condition on the clean target.
- **`forward()` always returns the 3-tuple** `(x_pred, noise, latent_info)`.
- **`encode_latent`/`decode_latent` carry `latent_scale`** and round-trip exactly. Anything
that diffuses or samples must go through them (signal must be ~unit variance for a
calibrated SNR). Call `model.calibrate_latent_scale(batch)` before training in latent mode.
- The model stores its full constructor config in `model.config` (used to build EMA/replicas).
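The role of `latent_scale` can be illustrated with a small standalone sketch. It assumes calibration simply divides `target_std` by the measured standard deviation of a batch of latents; the function names below are hypothetical and plain lists stand in for tensors, so this is not the repo's `calibrate_latent_scale` implementation.

```python
import math
import random


def calibrate_scale(latents, target_std=1.0):
    """Return the factor that rescales `latents` to `target_std` (assumed rule)."""
    flat = [v for row in latents for v in row]
    mean = sum(flat) / len(flat)
    var = sum((v - mean) ** 2 for v in flat) / len(flat)
    return target_std / math.sqrt(var)


def encode(latents, scale):
    # Multiply on the way in, so the diffused signal has ~unit variance.
    return [[v * scale for v in row] for row in latents]


def decode(scaled, scale):
    # Divide on the way out, back to the embedding scale.
    return [[v / scale for v in row] for row in scaled]


random.seed(0)
# Embedding-like values with std ~0.02, mirroring embed_init_std=0.02.
raw = [[random.gauss(0.0, 0.02) for _ in range(64)] for _ in range(32)]
scale = calibrate_scale(raw)
scaled = encode(raw, scale)
restored = decode(scaled, scale)
```

With a standard deviation near 0.02 the calibrated scale comes out near 50, which matches the `1/embed_init_std` default described for embedding mode. In this float-based sketch the round trip holds up to rounding error.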

The goal is to achieve faster parallel text generation compared to autoregressive transformers while maintaining output quality.
## Model API (`src/dimba/models/diffusion.py`)

## Architecture Summary
`DIMBA(...)` notable kwargs: `latent_diffusion`, `d_latent`, `use_vae_latent`,
`bidirectional=True`, `self_conditioning=False`, `prediction_type="x0"|"v"`,
`zero_terminal_snr=True`, `embed_init_std=0.02`, `latent_scale=None` (auto = `1/embed_init_std`
for embedding mode, `1.0` for latent mode → calibrate).

### Core Components (from Section 3.2 of the paper)
- `forward(input_ids, t, noise=None, prompt_mask=None, x_self_cond=None, drop_cond=False)`
- `predict_token_logits(input_ids, t)` → `[B,L,vocab]` (the **discrete/masked** track)
- `denoise_to_x0_latent(x_t, t, cond, x_self_cond=None)` / `denoise_step(...)` (inference)
- `conditioning_from_prompt(prompt_ids=None, batch_size, device, drop_cond=False)` → `[B,1,cond_dim]`
- `encode_latent` / `decode_latent` / `calibrate_latent_scale(batch, target_std=1.0)`

1. **Token Embeddings**: Input tokens mapped to continuous embedding space via learned embedding matrix E, producing X₀ ∈ ℝ^(L×d)
## Diffusion modes

2. **Prompt Encoder**: Lightweight MLP or frozen encoder that processes token embeddings to produce conditioning vector C ∈ ℝ^(L×d_c)
1. **Continuous latent (default)** — `GaussianEmbeddingCorruption`; the `forward()` path above.
2. **Discrete / masked (LLaDA/MDLM)** — `diffusion/corruption.py:AbsorbingMaskCorruption` +
`diffusion/masked_sampling.py:masked_diffusion_sample(predict_logits, ...)`; model side is
`predict_token_logits`. Needs a `[MASK]` token id (not in the tokenizer yet — pass explicitly).
3. **Hybrid (novel, experimental)** — `HybridCorruption` interpolates masked ↔ Gaussian per token.
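For the discrete track, an absorbing-state corruption step can be sketched as below. This assumes the common masked-diffusion formulation in which each token is independently replaced by the mask id with probability equal to the mask ratio. The function name and `MASK_ID` value are hypothetical and are not taken from `AbsorbingMaskCorruption`.

```python
import random


def absorbing_mask_corrupt(token_ids, mask_ratio, mask_id, rng):
    """Replace each token with `mask_id` independently with prob `mask_ratio`."""
    corrupted, is_masked = [], []
    for tok in token_ids:
        hit = rng.random() < mask_ratio
        corrupted.append(mask_id if hit else tok)
        is_masked.append(hit)
    return corrupted, is_masked


rng = random.Random(0)
tokens = [11, 42, 7, 99, 5, 23]
MASK_ID = 1000  # hypothetical id; must not collide with a real vocabulary entry
noisy, masked = absorbing_mask_corrupt(tokens, 0.5, MASK_ID, rng)
```

The returned `is_masked` flags mark the positions a masked-mode loss would typically be computed on, since unmasked tokens are already known.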

3. **Cosine Noise Schedule**: Follows Nichol & Dhariwal (2021) with formula:
- ᾱ(t) = cos²((t/T + s)/(1 + s) · π/2), s = 0.008
- β_t = 1 - ᾱ(t)/ᾱ(t-1)
- X_t = √ᾱ(t)X₀ + √(1 - ᾱ(t))ε
## Training (`src/dimba/training/trainer.py`)

4. **Timestep Embedding**: Sinusoidal positional encoding processed through MLP yielding τ(t) ∈ ℝ^d
Use **`compute_dimba_losses(model, input_ids, t, *, ce_loss_weight=1.0, min_snr_gamma=5.0,
prompt_mask=None)`** → `(loss, parts)`. It combines:
- **min-SNR-γ-weighted** diffusion regression in latent space (x0 or v target),
- a **cross-entropy / rounding anchor** (trains the head/decoder, ties to real tokens),
- **latent autoencoder consistency** + optional **VAE KL** (latent mode).

5. **Mamba-2 Denoiser**: N Mamba-2 blocks taking (X_t, C, τ(t)) as input with either additive or Feature-wise Linear Modulation (FiLM) conditioning
`DIMBALightningModule` and `SimpleTrainer` both call it; both accept `ce_loss_weight` /
`min_snr_gamma`. The CDLM consistency loss (`compute_consistency_loss`) is de-leaked (null cond).
Schedule helpers: `CosineNoiseSchedule(num_steps, zero_terminal_snr=True)` with `.add_noise`,
`.velocity`, `.predict_x0_from_v`, `.snr`, plus `enforce_zero_terminal_snr`.
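To make the min-SNR-γ weighting concrete, here is a standalone sketch of the cosine schedule and the resulting loss weight. It assumes the Nichol & Dhariwal normalization ᾱ(t) = f(t)/f(0) and the x0-prediction form of the weight, min(SNR, γ). The function names are hypothetical and this is not the repo's `CosineNoiseSchedule`.

```python
import math


def alpha_bar(t, num_steps, s=0.008):
    """Cosine cumulative signal level, normalized so alpha_bar(0) == 1."""

    def f(u):
        return math.cos((u / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    return f(t) / f(0)


def snr(t, num_steps):
    """Signal-to-noise ratio alpha_bar / (1 - alpha_bar); defined for t >= 1."""
    a = alpha_bar(t, num_steps)
    return a / (1.0 - a)


def min_snr_weight(t, num_steps, gamma=5.0):
    """Min-SNR-gamma weight for an x0-prediction loss: min(SNR, gamma)."""
    return min(snr(t, num_steps), gamma)
```

Clamping at γ keeps the nearly clean early timesteps, where the SNR is very large, from dominating the regression loss. Other parameterizations use a different normalization of the same clamp, for example dividing by SNR + 1 for a v target.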

6. **Output Projection**: Linear layer mapping denoised embeddings to token logits (optionally weight-tied with embedding matrix)
## Inference (`src/dimba/diffusion/sampling.py`)

### Data Flow
- `sample_from_model(model, prompt_ids, seq_len, num_steps, temperature, top_k, top_p,
guidance_scale=1.0, eta=0.0, clamp_to_tokens=False)` — correct x0-DDIM, CFG, self-cond carry,
clean-prefix conditioning. `DDIMSampler` wraps it.
- `diffusion/rerank.py:best_of_k(generate_fn, score_fn, k)` + `diffusion_elbo_score(...)` — best-of-K.
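The best-of-K idea can be sketched in a few lines. The argument order mirrors the `(generate_fn, score_fn, k)` signature named above, but the return value (a candidate and score pair) and the `_sketch` name are assumptions made for illustration.

```python
import itertools


def best_of_k_sketch(generate_fn, score_fn, k):
    """Draw k candidates and return (best_candidate, best_score)."""
    if k < 1:
        raise ValueError("k must be at least 1")
    candidates = [generate_fn() for _ in range(k)]
    scores = [score_fn(c) for c in candidates]
    # max() returns the first index with the top score, so ties keep the earliest draw.
    best_index = max(range(k), key=scores.__getitem__)
    return candidates[best_index], scores[best_index]


# Toy usage: the "generator" cycles through fixed strings, the score is word count.
pool = itertools.cycle(["a b", "a b c d", "a"])
best, score = best_of_k_sketch(lambda: next(pool), lambda s: len(s.split()), k=3)
```

In the setting described above, `generate_fn` would wrap a sampling call and `score_fn` would be a likelihood-style score such as the ELBO estimate.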

```
Input Prompt
Token Embeddings (X₀)
↙ ↖
Prompt Encoder Noise Injection (Cosine Schedule)
↓ ↓
Conditioning (C) Noisy Embeddings (X_T)
↓ ↓
Mamba-2 Denoiser (with Timestep Embedding τ)
Denoised Embeddings (X₀_pred)
Output Projection to Logits
Output Tokens
```
## Post-training (`scripts/finetuning/`)

## Training Procedure
- `finetune_sft.py` — SFT (leak-free per-position prompt cond; response-only CE via labels).
- `finetune_dpo.py` + `training/preference.py` — **DPO/IPO/SimPO** with a diffusion-ELBO/VRPO
surrogate (diffusion log-likelihoods are intractable). Preferred for preference pairs.
- `finetune_grpo.py` + `training/rewards.py` — GRPO with a **pluggable `--reward`** (default
`numeric`; `token_overlap` kept but deprecated — it just teaches copying).
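The preference objective can be sketched with the standard DPO loss, where each log-likelihood is replaced by a surrogate score such as an ELBO estimate. The function below is a hypothetical illustration on plain floats. How the surrogate scores are estimated in `training/preference.py` is not shown here.

```python
import math


def dpo_loss(chosen, rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO loss: -log sigmoid(beta * margin), on surrogate log-likelihoods.

    `chosen` / `rejected` are policy scores; `ref_*` are the frozen reference's.
    """
    margin = (chosen - ref_chosen) - (rejected - ref_rejected)
    # -log(sigmoid(x)) == softplus(-x); this form avoids overflow for large |x|.
    z = -beta * margin
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))
```

With a zero margin the loss is log 2. It falls as the policy raises the chosen response's score relative to the reference faster than the rejected one's.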

**Objective**: Learn to reverse the diffusion process at arbitrary timesteps.

```
For each training batch:
1. Sample random timestep: t ~ Uniform(1, T)
2. Create noisy embeddings: X_t = √ᾱ(t)X₀ + √(1 - ᾱ(t))ε
3. Encode prompt: C = PromptEncoder(X₀)
4. Create timestep embedding: τ = MLP(t)
5. Predict: X_pred = Denoiser(X_t, C, τ)
6. Loss: L = ||X_pred - X₀||²
7. Update parameters via backpropagation
```
## Performance (`src/dimba/models/parallel_scan.py`, `utils/compile.py`, `backends/`)

## Inference Procedure
- `selective_scan(dt, A, Bmat, C, x, *, stable=True, chunk_size=64)` — vectorized, numerically
stable (chunked); `selective_scan_sequential` is the parity reference; `bidirectional_*` too.
`SimpleMamba2` uses it and falls back to the sequential scan if the result is non-finite.
- `maybe_compile(module)` — `torch.compile` on CUDA only. `backends/mlx/` — MLX skeleton (WIP).
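The recurrence behind a sequential reference scan can be sketched for a single channel with a scalar state. This assumes the usual Mamba-style discretization (decay `exp(dt * A)`, input term `dt * B * x`). The function name and the simplified shapes are hypothetical and do not match the `selective_scan_sequential` signature.

```python
import math


def selective_scan_sequential_1d(dt, A, B, C, x):
    """Reference scan for one channel with a scalar state.

    h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t,   y_t = C_t * h_t
    """
    if A >= 0:
        # A positive A would make the state grow instead of decay.
        raise ValueError("A must be negative so the state decays")
    h, ys = 0.0, []
    for dt_t, b_t, c_t, x_t in zip(dt, B, C, x):
        decay = math.exp(dt_t * A)  # in (0, 1) for dt_t > 0 and A < 0
        h = decay * h + dt_t * b_t * x_t
        ys.append(c_t * h)
    return ys
```

A unit impulse at the first step with `A = -1` and unit `dt`, `B`, `C` yields the outputs 1, e⁻¹, e⁻², which is a convenient parity check for a vectorized implementation.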

**Goal**: Generate text of length L_gen by iterative denoising from noise.
## Repo layout

```
1. Compute prompt conditioning: C = PromptEncoder(X_prompt)
2. Initialize with noise: X_T ~ N(0, I) ∈ ℝ^(L_gen×d)
3. Iterative denoising loop (t = T down to 1):
- τ = MLP(t)
- X_{t-1} = Denoiser(X_t, C, τ)
4. Final projection: X₀ → linear layer → softmax → output tokens
src/dimba/{models,diffusion,training,data,tokenizers,evaluation,utils,backends}/
scripts/{train*,generate,evaluate,benchmark}.py scripts/finetuning/finetune_{sft,dpo,grpo,interactive}.py
configs/ tests/ notebooks/ docs/ paper/
```

The number of diffusion steps T controls the speed-quality trade-off: lower T = faster but potentially lower quality.

## Key Implementation Notes

### Hyperparameters to Consider
- **T**: Number of diffusion steps (controls inference speed/quality trade-off)
- **d**: Embedding dimension
- **d_c**: Conditioning dimension (may equal d)
- **N**: Number of Mamba-2 blocks in denoiser
- **s**: Noise schedule constant (0.008 per paper)

### Conditioning Mechanisms
- **Additive**: the conditioning is added element-wise to the noisy input (not concatenated)
- **FiLM (Feature-wise Linear Modulation)**: γ(C) * X_t + β(C), where γ and β are learned from C
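The two conditioning mechanisms can be compared on a single feature vector. This sketch treats "additive" as element-wise addition and takes `gamma` and `beta` as given. In a real model they would be learned projections of C, and both function names here are hypothetical.

```python
def additive_condition(x, c):
    """Additive conditioning: x + c, element-wise."""
    return [xi + ci for xi, ci in zip(x, c)]


def film_condition(x, gamma, beta):
    """FiLM conditioning: gamma * x + beta, element-wise."""
    return [g * xi + b for xi, g, b in zip(x, gamma, beta)]


x_t = [0.5, -1.0, 2.0]
cond = [0.1, 0.2, 0.3]
added = additive_condition(x_t, cond)
modulated = film_condition(x_t, gamma=[2.0, 1.0, 0.0], beta=cond)
```

Additive conditioning is the special case of FiLM with `gamma` fixed to 1, so FiLM can additionally rescale or suppress individual features.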

### Important Design Choices Requiring Implementation Decisions
1. **Prompt Encoder**: Frozen or trainable? Use existing pretrained encoder or train from scratch?
2. **Weight Tying**: Should output projection share weights with embedding matrix?
3. **Mamba-2 Architecture**: How many layers? What hidden dimension?
4. **Sampling During Inference**: Pure denoising or DDIM-style acceleration?

## Paper References

- **Main sections**: See `paper/main.txt` for full details
- **Figure 1**: Shows complete end-to-end architecture
- **Section 3.2**: Detailed component descriptions
- **Section 3.3-3.4**: Training and inference procedures
- **Section 4.1**: Hypothesized advantages (latency, coherence, controllability, reasoning, extensibility)
- **Section 4.2**: Known challenges (training cost, discrete-continuous gap, conditioning robustness, hyperparameter sensitivity)

## Dependencies & Libraries
## Conventions

When implementing, likely dependencies include:
- PyTorch (for tensor operations)
- Mamba-2 implementation (from the `mamba-ssm` package or a custom implementation)
- Hugging Face Transformers (for tokenizers, embeddings, reference models)
- Python ≥3.9, **black line-length 100**, type hints, Google-style docstrings.
- Don't reintroduce the conditioning leak, the 2-tuple `forward`, positive-`A` SSM, or
un-scaled latents. Run `compileall` + the smoke before claiming a change works.

## Testing Strategy
## Current status (2026-05-27)

Based on Section 4.2 challenges:
- Test discrete-continuous mapping accuracy for rare tokens
- Validate conditioning mechanisms (FiLM vs additive) across diverse prompts
- Benchmark inference latency with varying T values
- Compare generation quality against autoregressive baselines
- **PR #18** open into `main`: `c2352ba` (overhaul) + `60f30eb` (latent scale-factor). **Not
merged.** All known bugs fixed; `compileall` clean; runtime smoke 14/14 (venv python).
- **Open follow-ups**: first-class masked-mode training script + `[MASK]` token; an M1
quickstart config; train a real VAE to calibrate the latent against; cross-attention
conditioning (stronger than pooled-global); real speed/quality benchmarks once compute lands.
64 changes: 64 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,64 @@
# Changelog

All notable changes to this project are documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

The v2 overhaul focuses on correctness, modern diffusion-LM capabilities, and
CPU/GPU performance. Items below are in progress on the `feature/dimba-v2-overhaul`
branch and describe the direction at a high level.

### Added

- **Self-conditioning** for the denoiser: the model can condition each denoising
step on its own previous clean-sample estimate, improving sample quality at a
small compute cost.
- **Classifier-free guidance (CFG)**: joint conditional/unconditional training via
prompt dropout, with a guidance scale applied at sampling time.
- **Discrete / masked diffusion mode**: an alternative to continuous Gaussian
diffusion over embeddings, operating directly on token states with a masked /
absorbing-state corruption process.
- **Preference optimization (DPO)**: direct preference optimization for aligning
generations to preferred outputs, building on the existing preference-data
pipeline.
- **Performance backends**: pluggable denoiser backends so the optimized
`mamba-ssm` kernels are used when available (GPU) while the pure-PyTorch
`SimpleMamba2` remains the default CPU-friendly fallback.
- **Infrastructure**: a CPU inference benchmark (`scripts/benchmark.py`),
smoke/import test suites, GitHub Actions CI (Python 3.10 and 3.12, CPU only),
pre-commit hooks (black, isort, trailing-whitespace, end-of-file-fixer), and
this changelog.

### Changed

- Sampling is being consolidated around correct, schedule-consistent update rules
for both ancestral and DDIM-style accelerated inference, with configurable
step counts and guidance.
- Conditioning, latent-projection, and timestep-embedding interfaces are being
unified so continuous, latent, and discrete modes share a single denoiser path.

### Fixed

- Correctness fixes to the noise schedule and the train/inference sampling math so
the reverse process is consistent with the forward (training) process, including
terminal-SNR handling and per-step variance computation.
- More robust logit post-processing during sampling (temperature, top-k / top-p)
to avoid NaNs from fully-masked distributions.

## [0.1.0] - 2025-01-24

### Added

- Initial DIMBA library: continuous Gaussian diffusion over token embeddings with
a Mamba-2 denoiser for non-autoregressive text generation.
- Pure-PyTorch `SimpleMamba2` denoiser for CPU usage without compiled kernels.
- Cosine noise schedule, `sample_from_model`, and a DDIM sampler.
- Character and BPE tokenizers, dataset utilities, evaluation metrics, and
PyTorch Lightning training utilities.
- LoRA / Q-LoRA adapters and a finetuning data pipeline (SFT and preference data).

[Unreleased]: https://github.com/devnull37/dimba-lib/compare/v0.1.0...HEAD
[0.1.0]: https://github.com/devnull37/dimba-lib/releases/tag/v0.1.0