
---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- reinforcement-learning
- meta-learning
- pytorch
---

Pretrained weights for the Disco103 meta-network from Discovering State-of-the-art Reinforcement Learning Algorithms (Nature, 2025).

What is this?

A small LSTM neural network (754,778 parameters) that generates loss targets for RL agents. Instead of using a hand-crafted loss function such as PPO or GRPO, Disco103 observes an agent's rollout (policy logits, rewards, advantages, auxiliary predictions) and outputs target distributions the agent should match.
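The general idea of "matching a target distribution" can be sketched with a simple KL-style distillation loss. This is an illustration of the concept only, not Disco103's actual loss; the function name and signature are hypothetical:

```python
import torch
import torch.nn.functional as F

def distillation_loss(policy_logits, target_probs):
    # Illustrative only: pull the agent's policy toward a meta-generated
    # target distribution via cross-entropy (KL divergence up to a constant).
    log_p = F.log_softmax(policy_logits, dim=-1)
    return -(target_probs * log_p).sum(dim=-1).mean()
```

In Disco103 the targets themselves come from the meta-network rather than from a fixed formula, which is what makes the update rule learned rather than hand-designed.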

Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). The original implementation is in JAX; this repository provides a PyTorch port.

Quick Start

```python
import torch
from disco_torch import DiscoTrainer, collect_rollout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = YourAgent(obs_dim=64, num_actions=3).to(device)  # your agent model
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates -- all handled
```

DiscoTrainer encapsulates the full training loop: replay buffer, 32 inner gradient steps, per-element gradient
clipping, Polyak target-network updates, and meta-state management. See
https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that
reaches a 99% catch rate in 1000 steps.
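Two of the components listed above can be sketched in a few lines each. These are generic textbook versions for orientation, not the library's internal code; the function names are hypothetical:

```python
import torch

def polyak_update(target, online, tau=0.005):
    # Exponential moving average of the online network's weights into the
    # target network: target <- (1 - tau) * target + tau * online.
    with torch.no_grad():
        for t, o in zip(target.parameters(), online.parameters()):
            t.mul_(1 - tau).add_(o, alpha=tau)

def clip_grads_per_element(model, max_abs=1.0):
    # Clamp each gradient element independently into [-max_abs, max_abs]
    # (unlike global-norm clipping, which rescales all gradients together).
    for p in model.parameters():
        if p.grad is not None:
            p.grad.clamp_(-max_abs, max_abs)
```

Per-element clipping bounds the influence of any single weight's gradient, which tends to matter when the loss targets come from a learned meta-network rather than a fixed formula.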

Advanced: Low-level API

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# Generate loss targets from a rollout
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```
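With the low-level API you own the optimization step yourself. Assuming the returned loss is a scalar `torch.Tensor` and the agent is a `torch.nn.Module` (both are assumptions about the API, not confirmed by this card), a manual inner step might look like:

```python
import torch

def inner_step(agent, loss, optimizer, clip=1.0):
    # One gradient step on the meta-generated loss; `loss` is assumed to be
    # a scalar tensor produced by something like rule.agent_loss above.
    optimizer.zero_grad()
    loss.backward()
    # Per-element gradient clipping, mirroring what DiscoTrainer does.
    torch.nn.utils.clip_grad_value_(agent.parameters(), clip)
    optimizer.step()
```

This is the piece DiscoTrainer otherwise runs for you 32 times per rollout.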

File

disco_103.npz: a NumPy archive containing 42 parameter arrays (754,778 values total), converted from the original JAX checkpoint.
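An `.npz` archive is just a zip of named arrays, so the counts above are easy to verify with NumPy. The snippet below builds a toy in-memory archive to show the pattern; point `np.load` at your downloaded `disco_103.npz` instead to check the real checkpoint:

```python
import io
import numpy as np

# Toy stand-in for the checkpoint file, built in memory.
buf = io.BytesIO()
np.savez(buf, w=np.zeros((4, 2)), b=np.zeros(2))
buf.seek(0)

params = np.load(buf)
num_arrays = len(params.files)  # 2 arrays in this toy archive
num_values = sum(int(np.prod(params[k].shape)) for k in params.files)  # 10 values
```

Run the same two lines against `disco_103.npz` and you should see 42 arrays and 754,778 values.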

PyTorch Port

See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment
results.

Citation

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```

License

Apache 2.0, the same license as the original https://github.com/google-deepmind/disco_rl.