
JAX MuZero

A JAX implementation of the MuZero agent.

Everything is implemented in JAX, including the Monte Carlo tree search (MCTS). The entire search process can be JIT-compiled and run on accelerators such as GPUs.
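To illustrate why the search is jittable: when the simulation loop is written with JAX control-flow primitives (`jax.lax.fori_loop`, functional array updates) instead of Python recursion, the whole search becomes a single compiled computation. The toy search below is a hypothetical sketch, not this repository's MCTS; the UCB-style scoring and node statistics are simplified assumptions for illustration.

```python
import jax
import jax.numpy as jnp

NUM_SIMULATIONS = 16  # fixed loop bound, known at compile time

@jax.jit
def toy_search(prior_logits, values):
    """Toy one-level search: repeatedly pick the action with the highest
    UCB-style score, then update its visit count and value estimate.
    All state lives in arrays, so the loop compiles to one XLA program."""
    def body(_, state):
        visits, q = state
        total = jnp.sum(visits)
        # Exploration bonus weighted by the prior, as in PUCT-style rules.
        ucb = q + jax.nn.softmax(prior_logits) * jnp.sqrt(total + 1.0) / (1.0 + visits)
        a = jnp.argmax(ucb)
        visits = visits.at[a].add(1.0)
        # Incremental mean update of the action's value estimate.
        q = q.at[a].set(q[a] + (values[a] - q[a]) / visits[a])
        return visits, q

    init = (jnp.zeros_like(values), jnp.zeros_like(values))
    return jax.lax.fori_loop(0, NUM_SIMULATIONS, body, init)

visits, q = toy_search(jnp.array([0.1, 0.5, 0.4]), jnp.array([0.0, 1.0, 0.5]))
```

Because nothing in the loop drops back to Python, the same code runs unmodified on CPU, GPU, or TPU.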

Requirements

Run the following command to create a new conda environment with all dependencies:

conda env create -f conda_env.yml

Then activate the conda environment:

conda activate muzero

Alternatively, if you prefer to use your own Python environment, install the dependencies with:

pip install -r requirements.txt

Training

Run the following command to train an agent to play the Atari game Breakout:

python -m experiments.breakout

Atari 100K Benchmark Results

Plots of the median human-normalized score and the raw per-game scores are included in the repository.

Repository Structure

.
├── algorithms              # Files for the MuZero algorithm.
│   ├── actors.py           # Agent-environment interaction.
│   ├── agents.py           # An RL agent that plans with a learned model by MCTS.
│   ├── haiku_nets.py       # Neural networks.
│   ├── muzero.py           # The training pipeline.
│   ├── replay_buffers.py   # Experience replay.
│   ├── types.py            # Customized data structures.
│   └── utils.py            # Helper functions.
├── environments            # The Atari environment interface and wrappers.
├── experiments             # Experiment configuration files.
├── vec_env                 # Vectorized environment interfaces.
├── conda_env.yml           # Conda environment specification.
├── requirements.txt        # Python dependencies.
├── LICENSE
└── README.md
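MuZero's learned model, implemented across `haiku_nets.py` and `agents.py`, consists of three functions: a representation network h (observation to latent state), a dynamics network g (latent state and action to next latent state and reward), and a prediction network f (latent state to policy and value). The sketch below is a hypothetical illustration of how the model is unrolled during training; the stand-in functions are not the repository's actual networks.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the three MuZero networks. The real versions in
# algorithms/haiku_nets.py are neural networks with learned parameters.
def representation(obs):
    """h: observation -> latent state."""
    return jnp.tanh(obs)

def dynamics(state, action):
    """g: (latent state, action) -> (next latent state, reward)."""
    next_state = jnp.tanh(state + action)
    return next_state, jnp.sum(next_state)

def prediction(state):
    """f: latent state -> (policy logits, value)."""
    return state, jnp.sum(state)

@jax.jit
def unroll(obs, actions):
    """Unroll the learned model from a real observation along a sequence of
    actions, producing per-step reward, policy, and value predictions that
    the training pipeline matches against targets."""
    def step(state, action):
        next_state, reward = dynamics(state, action)
        logits, value = prediction(next_state)
        return next_state, (reward, logits, value)

    state0 = representation(obs)
    _, (rewards, logits, values) = jax.lax.scan(step, state0, actions)
    return rewards, logits, values

rewards, logits, values = unroll(jnp.ones(4), jnp.zeros((5, 4)))
```

Expressing the unroll as a `jax.lax.scan` keeps the whole training step inside one compiled computation, consistent with the end-to-end jittable design described above.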

Resources