DIMA — Diffusion–Isocoder Manifold-Autoencoder

DIMA combines three components into a practical, scalable manifold autoencoder:

  • DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
  • GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
  • DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.

Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).


Installation

CPU-only (recommended to start)

pip install dima

Optional extras

  • Faster kNN search (FAISS):
pip install dima[faiss]
  • Hugging Face upload/load:
pip install dima[hf]

GPU note (JAX): pip install dima installs CPU jaxlib by default. For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then install dima.


Quickstart

import numpy as np
from dima import DIMA

R_iX = np.random.randn(5000, 16).astype(np.float32)

dima = DIMA(R_iX)                 # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])        # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)  # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)  # DDPM refinement

X_gen = dima.sample(1000)         # unconditional samples -> ambient (np)
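
A quick sanity check on the round trip (decode returns NumPy arrays, so plain NumPy works):

mse = np.mean((X_hat - R_iX[:10]) ** 2)   # reconstruction error, no refinement
print(f"round-trip MSE: {mse:.4g}")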

What DIMA trains

1) DMAP encoder (Diffusion Maps)

DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.

Key idea (no theory): DMAP produces a geometry-aware latent space where nearby points on the manifold remain nearby in diffusion distance.
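
For intuition only, here is a minimal dense NumPy sketch of the classical diffusion-maps recipe that the knobs below parameterize. The library itself operates on sparse kNN kernel graphs, so treat this as illustrative rather than DIMA internals:

import numpy as np
from scipy.sparse.linalg import eigsh

def diffusion_coords(K, d=2, alpha=1.0, t=0.5):
    # K: dense (N, N) RBF affinity matrix (the library uses sparse kNN kernels).
    q = K.sum(axis=1)
    K = K / np.outer(q, q) ** alpha          # density normalization (alpha)
    D = K.sum(axis=1)
    A = K / np.sqrt(np.outer(D, D))          # symmetric normalized kernel
    vals, vecs = eigsh(A, k=d + 1)           # leading eigenpairs
    idx = np.argsort(-vals)
    vals, vecs = vals[idx], vecs[:, idx]
    psi = vecs / np.sqrt(D)[:, None]         # right eigenvectors of the Markov matrix
    return psi[:, 1:] * vals[1:] ** t        # drop trivial mode, scale by lambda^t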

Main knobs (DMAP):

  • d (int): latent dimension.
  • k (int): neighbors in the kNN graph. If None, a heuristic is used.
  • beta / β (float): kernel sharpness for the RBF affinity $K_{ij}=\exp\!\left(-\beta \|x_i-x_j\|^2/\varepsilon\right)$.
  • eps / ε (float or None): kernel bandwidth. If None, estimated from kNN distances (median heuristic).
  • alpha / α (float): density normalization exponent. Common values: 0.0 (none) or 1.0 (often robust).
  • t (float): diffusion time exponent (scales eigenvalues as $\lambda^t$). Often 0.5 or 1.0.
  • drop_trivial (bool): drops the top eigenvector/eigenvalue (the constant mode).
  • sym (str): symmetrization mode for sparse kNN kernel graph: "max" or "mean".
  • ann_backend (str): "auto" | "faiss" | "pynndescent" | "sklearn" | "brute".

Typical DMAP presets

  • Fast-ish and stable: k=128..512, alpha=1.0, t=0.5, drop_trivial=True.
  • If your data is very noisy, try a larger k and/or a larger eps.

2) GPLM decoder (Nyström kernel ridge / inducing GP)

GPLM learns a mapping from latents back to ambient:

  • Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
  • Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$

Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:

  • Build affinities $C_{im}=\exp\!\left(-\beta\|R_{ix}-Z_{mx}\|^2/\varepsilon\right)$
  • Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
  • Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$
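
As a minimal NumPy sketch of this prediction step (array names follow the README; gplm_predict, beta, and eps here are illustrative placeholders, not the library's internals):

import numpy as np

def gplm_predict(R_ax, Z_mx, M_mX, mu_X, beta=1.0, eps=1.0):
    # Squared distances from query latents (A, d) to inducing points (m, d).
    d2 = ((R_ax[:, None, :] - Z_mx[None, :, :]) ** 2).sum(axis=-1)  # (A, m)
    C_am = np.exp(-beta * d2 / eps)   # RBF affinities to inducing points
    return C_am @ M_mX + mu_X         # (A, D) ambient-space prediction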

Main knobs (GPLM):

  • m (int): number of inducing points. Bigger → better accuracy, more compute.

  • inducing (str): inducing strategy:

    • "kmeans_medoids" (default): kmeans centers snapped to nearest training latent.
    • "fps": farthest-point sampling (space-filling).
    • "random_subset": fastest.
    • "given": use provided Z_mx.
  • sigma2 / σ2 (float): ridge regularization. Too small can overfit / cause instability; too large blurs reconstructions.

  • jitter (float): tiny diagonal stabilizer for Cholesky.

  • eps / ε (float or None): RBF bandwidth in latent space. If None, estimated from latent kNN distances.

  • k_eps / κ_eps (int): neighbors used for the $\varepsilon$ heuristic.

  • pred_k / pred_κ (int or None): prediction-time inducing neighbors:

    • None means use all inducing points (best accuracy).
    • a small number (e.g. 128 or 256) speeds inference (slightly lower accuracy).
  • whiten_latent (bool): optionally standardize latent dimensions before kernel computation.

  • center_X (bool): subtract and re-add ambient mean (usually helpful).

  • fit_block (int): block size for streaming $C^T C$ accumulation (memory/perf knob).

  • ann_backend (str): same options as DMAP.

Typical GPLM presets

  • Accurate: m=1024..4096, pred_k=None, sigma2=1e-5 (tune).
  • Faster inference: set pred_k=128..512.
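
Expressed as constructor calls (values illustrative; tune for your data):

dima_acc  = DIMA(R_iX, gplm_kwargs=dict(m=4096, sigma2=1e-5, pred_k=None))  # accurate
dima_fast = DIMA(R_iX, gplm_kwargs=dict(m=1024, sigma2=1e-5, pred_k=256))   # faster inference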

3) DDPM latent diffusion (optional generative prior)

DDPM learns a distribution over normalized latents $Z = \frac{R_x - \mu}{\sigma}$ and can:

  • sample new latents,
  • refine a given latent by projecting it onto the learned latent manifold/prior.

In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.

Main knobs (DDPM):

  • T (int): number of diffusion steps. Common: 100..1000. DIMA default: 200.
  • hidden_dim (int): MLP width.
  • t_embed_dim (int): time embedding dimension.
  • n_iter (int): training iterations (more is better for sampling quality).
  • batch_size (int): training batch size.
  • learning_rate (float): Adam learning rate.
  • ema_decay (float): EMA smoothing for stable sampling (typical: 0.999).
  • beta_max (float): caps noise schedule (stability knob).
  • eps (float): numerical stabilizer.
  • verbose_every (int): progress prints.

Refinement knobs (during decoding):

  • refine (bool): enable/disable DDPM refinement.

  • t_start (int): how strongly to “project” using reverse diffusion.

    • small (1..10) = gentle projection
    • larger (20..100) = stronger projection (can oversmooth or drift if DDPM undertrained)
  • add_noise (bool): whether to forward-noise before reverse steps.
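
Conceptually, refinement forward-noises a latent up to step t_start and then runs the learned reverse chain back down. A rough sketch (alpha_bar and reverse_step are hypothetical stand-ins for the noise schedule and the trained denoiser, not DIMA's API):

import numpy as np

def refine_sketch(Z, t_start, alpha_bar, reverse_step, add_noise=True, seed=0):
    rng = np.random.default_rng(seed)
    if add_noise:
        a = alpha_bar[t_start]   # cumulative noise schedule at t_start
        Z = np.sqrt(a) * Z + np.sqrt(1.0 - a) * rng.standard_normal(Z.shape)
    for t in range(t_start, 0, -1):   # larger t_start = stronger projection
        Z = reverse_step(Z, t)        # one learned denoising step
    return Z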


API

Core class

from dima import DIMA
dima = DIMA(R_iX)

Encode / decode

Z = dima.encode(X)                      # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)    # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode

Polymorphic call

Z = dima(X)     # if X.shape[-1] == D
X = dima(Z)     # if Z.shape[-1] == d

Sampling

X_gen = dima.sample(1000)            # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)
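
To inspect or filter latents before decoding, sample in latent space and decode explicitly:

Z_gen = dima.sample(1000, decode=False)   # latent samples (jnp)
X_gen = dima.decode(Z_gen, refine=False)  # ambient samples (np)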

Configuration patterns

1) Use defaults (cleanest)

dima = DIMA(R_iX)

2) Pass only a few knobs

dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)

Saving / Loading

Local

dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")
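
A quick round-trip check (decoding with refine=False is a deterministic kernel-ridge prediction, so the two models should agree closely):

X1 = dima.decode(Z, refine=False)    # Z: any batch of latents, e.g. Z = dima.encode(X)
X2 = dima2.decode(Z, refine=False)
assert np.allclose(X1, X2, atol=1e-5)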

Hugging Face (optional)

dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")

dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")

Performance notes

  • DMAP: main costs are kNN search + sparse eigensolve.

    • kNN is faster with FAISS.
    • Increasing k increases graph density and compute.
  • GPLM: training cost roughly scales with $N\cdot m$ (streamed by fit_block).

    • Inference cost is $B\cdot m$ if pred_k=None, or $B\cdot \text{pred\_k}$ if using inducing kNN.
  • DDPM: scales with latent dimension d, steps T, and training iterations.

    • GPU helps but CPU works for smaller d and fewer iterations.
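
For example, with B = 1,000 queries and m = 2,048 inducing points, full prediction evaluates roughly B·m ≈ 2M kernel entries, while pred_k=256 cuts that to ≈ 256k (about 8× less, plus the cost of the inducing kNN lookup).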

Troubleshooting

“Unknown backend: gpu”

Your JAX install only has CPU. Use:

dima = DIMA(R_iX, ddpm_device="cpu")

or install a GPU-enabled JAX build.

Reconstructions are blurry / low quality

  • Increase gplm_kwargs["m"]
  • Decrease gplm_kwargs["sigma2"] slightly (careful: too small can destabilize)
  • Set pred_k=None (use all inducing points)
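
For example (values illustrative):

dima = DIMA(R_iX, gplm_kwargs=dict(m=4096, sigma2=1e-6, pred_k=None))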

DDPM refinement makes reconstructions worse

  • Reduce t_start (try 3..10)
  • Increase DDPM training (ddpm_kwargs["n_iter"])
  • Disable add_noise for gentler behavior
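
For example, assuming add_noise is accepted by decode alongside t_start:

X_ref = dima.decode(Z, refine=True, t_start=5, add_noise=False)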

Citations

  • Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
  • Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).

License

MIT
