{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%run notebook_setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# PyMC3 extras\n", "\n", "*exoplanet* comes bundled with a few utilities that can make it easier to use and debug PyMC3 models for fitting exoplanet data.\n", "This tutorial briefly describes these features and their use.\n", "\n", "## Custom tuning schedule\n", "\n", "The main extra is the :class:`exoplanet.PyMC3Sampler` class that wraps the PyMC3 sampling procedure to include support for learning off-diagonal elements of the mass matrix.\n", "This is *very* important for any problems where there are covariances between the parameters (this is true for pretty much all exoplanet models).\n", "A thorough discussion of this [can be found elsewhere online](https://dfm.io/posts/pymc3-mass-matrix/), but here is a simple demo where we sample a covariant Gaussian using :class:`exoplanet.PyMC3Sampler`.\n", "\n", "First, we generate a random positive definite covariance matrix for the Gaussian:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "ndim = 5\n", "np.random.seed(42)\n", "L = np.random.randn(ndim, ndim)\n", "L[np.diag_indices_from(L)] = 0.1*np.exp(L[np.diag_indices_from(L)])\n", "L[np.triu_indices_from(L, 1)] = 0.0\n", "cov = np.dot(L, L.T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And then we can sample this using PyMC3 and :class:`exoplanet.PyMC3Sampler`:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains: 100%|██████████| 308/308 [00:04<00:00, 76.53draws/s]\n", "Sampling 4 chains: 100%|██████████| 108/108 [00:01<00:00, 63.32draws/s]\n", "Sampling 4 chains: 100%|██████████| 
208/208 [00:00<00:00, 456.97draws/s]\n", "Sampling 4 chains: 100%|██████████| 408/408 [00:00<00:00, 1185.92draws/s]\n", "Sampling 4 chains: 100%|██████████| 808/808 [00:00<00:00, 1313.22draws/s]\n", "Sampling 4 chains: 100%|██████████| 1608/1608 [00:01<00:00, 1456.51draws/s]\n", "Sampling 4 chains: 100%|██████████| 4608/4608 [00:03<00:00, 1432.32draws/s]\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [x]\n", "Sampling 4 chains: 100%|██████████| 8200/8200 [00:04<00:00, 1773.95draws/s]\n" ] } ], "source": [ "import pymc3 as pm\n", "import exoplanet as xo\n", "\n", "sampler = xo.PyMC3Sampler()\n", "\n", "with pm.Model() as model:\n", " pm.MvNormal(\"x\", mu=np.zeros(ndim), chol=L, shape=(ndim,))\n", " \n", " # Run the burn-in and learn the mass matrix\n", " step_kwargs = dict(target_accept=0.9)\n", " sampler.tune(tune=2000, step_kwargs=step_kwargs)\n", " \n", " # Run the production chain\n", " trace = sampler.sample(draws=2000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a little more verbose than the standard use of PyMC3, but the performance is several orders of magnitude better than you would get without the mass matrix tuning.\n", "As you can see from the `pymc3.summary`, the autocorrelation time of this chain is about 1 as we would expect for a simple problem like this." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "|  | mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat |\n",
"|---|---|---|---|---|---|---|---|\n",
"| x__0 | 0.000182 | 0.161166 | 0.001505 | -0.321874 | 0.309854 | 10391.955695 | 1.000320 |\n",
"| x__1 | -0.001958 | 0.528701 | 0.005229 | -0.979854 | 1.084149 | 11469.348323 | 1.000107 |\n",
"| x__2 | 0.002317 | 0.654096 | 0.006766 | -1.305341 | 1.250142 | 10681.526348 | 1.000276 |\n",
"| x__3 | 0.006045 | 1.172721 | 0.012248 | -2.207018 | 2.352476 | 10931.385998 | 1.000197 |\n",
"| x__4 | -0.002606 | 2.020893 | 0.017690 | -3.955703 | 3.929923 | 10968.369908 | 0.999961 |\n", "