---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- reinforcement-learning
- meta-learning
- pytorch
---
Pretrained weights for the Disco103 meta-network from *Discovering State-of-the-art Reinforcement Learning Algorithms* (Nature, 2025).
## What is this?
A small LSTM neural network (754,778 parameters) that generates loss targets for RL agents. Instead of a hand-crafted objective such as the PPO or GRPO loss, Disco103 observes an agent's rollout (policy logits, rewards, advantages, auxiliary predictions) and outputs target distributions the agent should match.
Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). The original implementation is in JAX; this repository hosts a PyTorch port.
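The "target distributions the agent should match" idea boils down to a distribution-matching loss between the policy and the meta-network's targets. A generic NumPy sketch of that idea (illustrative only; the actual Disco103 loss is produced by the meta-network itself):

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distribution_matching_loss(logits, target_probs):
    """KL(target || policy), averaged over the batch.

    Generic sketch of matching a target distribution; not the real
    Disco103 objective.
    """
    p = softmax(logits)
    kl = np.sum(
        target_probs * (np.log(target_probs + 1e-8) - np.log(p + 1e-8)),
        axis=-1,
    )
    return float(np.mean(kl))
```

When the policy already equals the target, the loss is (numerically) zero; any mismatch yields a positive penalty that pulls the policy toward the target.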
## Quick Start
```python
import torch
from disco_torch import DiscoTrainer, collect_rollout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = YourAgent(obs_dim=64, num_actions=3).to(device)
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates: all handled
```
DiscoTrainer encapsulates the full training loop: replay buffer, 32 inner gradient steps per update, per-element gradient clipping, Polyak target-network updates, and meta-state management. See https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that reaches a 99% catch rate within 1000 steps.
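`YourEnv` in the Quick Start is a placeholder: any object exposing `obs()` and `step(actions) -> (rewards, dones)` over a batch of environments fits the interface that `step_fn` wraps. A minimal hypothetical toy environment for smoke-testing the loop (not part of disco_torch):

```python
import numpy as np

class ToyVecEnv:
    """Minimal vectorized env matching the interface assumed above.

    obs() returns the current batched observation; step(actions) advances
    every sub-environment and returns (rewards, dones). Hypothetical sketch
    with random observations and a dummy reward.
    """

    def __init__(self, num_envs=2, obs_dim=64, num_actions=3, horizon=10):
        self.num_envs, self.obs_dim = num_envs, obs_dim
        self.num_actions, self.horizon = num_actions, horizon
        self._t = np.zeros(num_envs, dtype=np.int64)
        self._obs = np.random.randn(num_envs, obs_dim).astype(np.float32)

    def obs(self):
        return self._obs

    def step(self, actions):
        actions = np.asarray(actions)
        rewards = (actions == 0).astype(np.float32)  # dummy reward signal
        self._t += 1
        dones = self._t >= self.horizon  # episode ends after `horizon` steps
        self._t[dones] = 0               # auto-reset finished episodes
        self._obs = np.random.randn(self.num_envs, self.obs_dim).astype(np.float32)
        return rewards, dones
```

Swapping this in for `YourEnv` lets the Quick Start loop run end to end before wiring up a real environment.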
## Advanced: Low-level API
```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# Generate loss targets from a rollout
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```
## File

`disco_103.npz`: a NumPy archive with 42 parameter arrays (754,778 values in total), converted from the original JAX checkpoint.
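Assuming only NumPy, the archive can be sanity-checked after download. A generic `.npz` inspection sketch (`npz_param_count` is a hypothetical helper, not part of this repo):

```python
import numpy as np

def npz_param_count(path):
    """Return (num_arrays, total_values) for a NumPy .npz archive."""
    with np.load(path) as arrs:
        return len(arrs.files), int(sum(a.size for a in arrs.values()))
```

For `disco_103.npz` this should report 42 arrays and 754,778 values in total.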
## PyTorch Port
See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment
results.
## Citation

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```
## License

Apache 2.0, the same license as the original https://github.com/google-deepmind/disco_rl.