YAML Metadata Warning: empty or missing yaml metadata in repo card
Check out the documentation for more information.
DIMA — Diffusion–Isocoder Manifold-Autoencoder
DIMA combines three components into a practical, scalable manifold autoencoder:
- DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
- GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
- DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.
Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).
Installation
CPU-only (recommended to start)
pip install dima
Optional extras
- Faster kNN search (FAISS):
pip install dima[faiss]
- Hugging Face upload/load:
pip install dima[hf]
GPU note (JAX):
pip install dimainstalls CPUjaxlibby default. For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then installdima.
Quickstart
import numpy as np
import jax
from dima import DIMA
R_iX = np.random.randn(5000, 16).astype(np.float32)
dima = DIMA(R_iX) # trains DMAP, GPLM, DDPM with defaults
Z = dima.encode(R_iX[:10]) # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False) # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10) # DDPM refinement
X_gen = dima.sample(1000) # unconditional samples -> ambient (np)
What DIMA trains
1) DMAP encoder (Diffusion Maps)
DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.
Key idea (no theory): DMAP produces a geometry-aware latent space where nearby points on the manifold remain nearby in diffusion distance.
Main knobs (DMAP):
d(int): latent dimension.k(int): neighbors in the kNN graph. IfNone, a heuristic is used.beta/β(float): kernel sharpness for the RBF affinity $K_{ij}=\exp{-\beta |x_i-x_j|^2/\varepsilon}$.eps/ε(float or None): kernel bandwidth. IfNone, estimated from kNN distances (median heuristic).alpha/α(float): density normalization exponent. Common values:0.0(none) or1.0(often robust).t(float): diffusion time exponent (scales eigenvalues as $\lambda^t$). Often0.5or1.0.drop_trivial(bool): drops the top eigenvector/eigenvalue (the constant mode).sym(str): symmetrization mode for sparse kNN kernel graph:"max"or"mean".ann_backend(str):"auto" | "faiss" | "pynndescent" | "sklearn" | "brute".
Typical DMAP presets
- Fast-ish and stable:
k=128..512,alpha=1.0,t=0.5,drop_trivial=True. - If your data is very noisy, try larger
kand/or largereps_mul.
2) GPLM decoder (Nyström kernel ridge / inducing GP)
GPLM learns a mapping from latents back to ambient:
- Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
- Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$
Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:
- Build affinities $C_{im}=\exp{-\beta|R_{ix}-Z_{mx}|^2/\varepsilon}$
- Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
- Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$
Main knobs (GPLM):
m(int): number of inducing points. Bigger → better accuracy, more compute.inducing(str): inducing strategy:"kmeans_medoids"(default): kmeans centers snapped to nearest training latent."fps": farthest-point sampling (space-filling)."random_subset": fastest."given": use providedZ_mx.
sigma2/σ2(float): ridge regularization. Too small can overfit / cause instability; too large blurs reconstructions.jitter(float): tiny diagonal stabilizer for Cholesky.eps/ε(float or None): RBF bandwidth in latent space. IfNone, estimated from latent kNN distances.k_eps/κ_eps(int): neighbors used for the $\varepsilon$ heuristic.pred_k/pred_κ(int or None): prediction-time inducing neighbors:Nonemeans use all inducing points (best accuracy).- a small number (e.g.
128or256) speeds inference (slightly lower accuracy).
whiten_latent(bool): optionally standardize latent dimensions before kernel computation.center_X(bool): subtract and re-add ambient mean (usually helpful).fit_block(int): block size for streaming $C^T C$ accumulation (memory/perf knob).ann_backend(str): same options as DMAP.
Typical GPLM presets
- Accurate:
m=1024..4096,pred_k=None,sigma2=1e-5(tune). - Faster inference: set
pred_k=128..512.
3) DDPM latent diffusion (optional generative prior)
DDPM learns a distribution over normalized latents: [ Z = \frac{R_x - \mu}{\sigma} ] and can:
- sample new latents,
- refine a given latent by projecting it onto the learned latent manifold/prior.
In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.
Main knobs (DDPM):
T(int): number of diffusion steps. Common:100..1000. DIMA default:200.hidden_dim(int): MLP width.t_embed_dim(int): time embedding dimension.n_iter(int): training iterations (more is better for sampling quality).batch_size(int): training batch size.learning_rate(float): Adam learning rate.ema_decay(float): EMA smoothing for stable sampling (typical:0.999).beta_max(float): caps noise schedule (stability knob).eps(float): numerical stabilizer.verbose_every(int): progress prints.
Refinement knobs (during decoding):
refine(bool): enable/disable DDPM refinement.t_start(int): how strongly to “project” using reverse diffusion.- small (
1..10) = gentle projection - larger (
20..100) = stronger projection (can oversmooth or drift if DDPM undertrained)
- small (
add_noise(bool): whether to forward-noise before reverse steps.
API
Core class
from dima import DIMA
dima = DIMA(R_iX)
Encode / decode
Z = dima.encode(X) # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False) # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10) # refined decode
Polymorphic call
Z = dima(X) # if X.shape[-1] == D
X = dima(Z) # if Z.shape[-1] == d
Sampling
X_gen = dima.sample(1000) # ambient samples (np)
Z_gen = dima.sample(1000, decode=False) # latent samples (jnp)
Configuration patterns
1) Use defaults (cleanest)
dima = DIMA(R_iX)
2) Pass only a few knobs
dima = DIMA(
R_iX,
d=64,
dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)
Saving / Loading
Local
dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")
Hugging Face (optional)
dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")
dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")
Performance notes
DMAP: main costs are kNN search + sparse eigensolve.
- kNN is faster with FAISS.
- Increasing
kincreases graph density and compute.
GPLM: training cost roughly scales with $N\cdot m$ (streamed by
fit_block).- Inference cost is $B\cdot m$ if
pred_k=None, or $B\cdot \text{pred_k}$ if using inducing kNN.
- Inference cost is $B\cdot m$ if
DDPM: scales with latent dimension
d, stepsT, and training iterations.- GPU helps but CPU works for smaller
dand fewer iterations.
- GPU helps but CPU works for smaller
Troubleshooting
“Unknown backend: gpu”
Your JAX install only has CPU. Use:
dima = DIMA(R_iX, ddpm_device="cpu")
or install a GPU-enabled JAX build.
Reconstructions are blurry / low quality
- Increase
gplm_kwargs["m"] - Decrease
gplm_kwargs["sigma2"]slightly (careful: too small can destabilize) - Set
pred_k=None(use all inducing points)
DDPM refinement makes reconstructions worse
- Reduce
t_start(try3..10) - Increase DDPM training (
ddpm_kwargs["n_iter"]) - Disable
add_noisefor gentler behavior
Citations
- Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
- Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).
License
MIT