Skip to content

Commit

Permalink
Merge pull request #317 from kboone/emcee_deprecation
Browse files Browse the repository at this point in the history
Update emcee usage to avoid deprecation warnings
  • Loading branch information
kboone committed Sep 9, 2021
2 parents 98f16f8 + 855b4b5 commit ff508e3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 43 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ test_coverage = # Required for coverage tests.
all = # Used by some parts of sncosmo, but not required.
iminuit
matplotlib
emcee
emcee>=3
nestle
docs = # Required to build the docs.
cython
Expand All @@ -62,7 +62,7 @@ docs = # Required to build the docs.
sphinx_rtd_theme
pillow
iminuit
emcee
emcee>=3
numpydoc

[options.package_data]
Expand Down
60 changes: 19 additions & 41 deletions sncosmo/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,14 +1142,14 @@ def loglike(parameters):
def mcmc_lc(data, model, vparam_names, bounds=None, priors=None,
guess_amplitude=True, guess_t0=True, guess_z=True,
minsnr=5., modelcov=False, nwalkers=10, nburn=200,
nsamples=1000, sampler='ensemble', ntemps=4, thin=1,
a=2.0, warn=True):
nsamples=1000, sampler='ensemble', thin=1, a=2.0,
warn=True):
"""Run an MCMC chain to get model parameter samples.
This is a convenience function around `emcee.EnsembleSampler` andx
`emcee.PTSampler`. It defines the likelihood function and makes a
heuristic guess at a good set of starting points for the
walkers. It then runs the sampler, starting with a burn-in run.
This is a convenience function around `emcee.EnsembleSampler`.
It defines the likelihood function and makes a heuristic guess
at a good set of starting points for the walkers. It then runs
the sampler, starting with a burn-in run.
If you're not getting good results, you might want to try
increasing the burn-in, increasing the walkers, or specifying a
Expand Down Expand Up @@ -1202,12 +1202,8 @@ def mcmc_lc(data, model, vparam_names, bounds=None, priors=None,
nsamples : int, optional
Number of samples in production run.
sampler: str, optional
The kind of sampler to use. Currently 'ensemble' for
`emcee.EnsembleSampler` and 'pt' for `emcee.PTSampler` are
supported.
ntemps : int, optional
If `sampler == 'pt'` the number of temperatures to use for the
parallel tempered sampler.
The kind of sampler to use. Currently only 'ensemble' for
`emcee.EnsembleSampler` is supported.
thin : int, optional
Factor by which to thin samples in production run. Output samples
array will have (nsamples/thin) samples.
Expand Down Expand Up @@ -1349,31 +1345,13 @@ def lnprior(parameters):
def lnprob(parameters):
return lnprior(parameters) + lnlike(parameters)

# Heuristic determination of walker initial positions: distribute
# walkers uniformly over parameter space. If no bounds are
# supplied for a given parameter, use a heuristically determined
# scale.
# Moves to use
moves = emcee.moves.StretchMove(a=a)

if sampler == 'pt':
pos = np.empty((ndim, nwalkers, ntemps))
for i, name in enumerate(vparam_names):
if name in bounds:
pos[i] = np.random.uniform(low=bounds[name][0],
high=bounds[name][1],
size=(nwalkers, ntemps))
else:
ctr = model.get(name)
scale = np.abs(ctr)
pos[i] = np.random.uniform(low=ctr-scale, high=ctr+scale,
size=(nwalkers, ntemps))
pos = np.swapaxes(pos, 0, 2)
sampler = emcee.PTSampler(ntemps, nwalkers, ndim, lnlike, lnprob, a=a)

# Heuristic determination of walker initial positions: distribute
# walkers in a symmetric gaussian ball, with heuristically
# determined scale.

elif sampler == 'ensemble':
if sampler == 'ensemble':
# Heuristic determination of walker initial positions: distribute
# walkers in a symmetric gaussian ball, with heuristically
# determined scale.
ctr = model.parameters[modelidx]
scale = np.ones(ndim)
for i, name in enumerate(vparam_names):
Expand All @@ -1384,17 +1362,17 @@ def lnprob(parameters):
else:
scale[i] = 0.1
pos = ctr + scale * np.random.normal(size=(nwalkers, ndim))
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, a=a)
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, moves=moves)

else:
raise ValueError('Invalid sampler type. Currently "pt" '
'and "ensemble" are supported.')
raise ValueError('Invalid sampler type. Currently only '
'"ensemble" is supported.')

# Run the sampler.
pos, prob, state = sampler.run_mcmc(pos, nburn) # burn-in
sampler.reset()
sampler.run_mcmc(pos, nsamples, thin=thin) # production run
samples = sampler.flatchain.reshape(-1, ndim)
sampler.run_mcmc(pos, nsamples, thin_by=thin) # production run
samples = sampler.get_chain(flat=True).reshape(-1, ndim)

# Summary statistics.
vparameters = np.mean(samples, axis=0)
Expand Down
20 changes: 20 additions & 0 deletions sncosmo/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ def test_nest_lc(self):

assert_allclose(fitmodel.parameters, self.model.parameters, rtol=0.05)

@pytest.mark.skipif('not HAS_EMCEE')
def test_mcmc_lc(self):
"""Ensure that mcmc runs."""
self.model.set(**self.params)

res, fitmodel = sncosmo.mcmc_lc(self.data, self.model,
['amplitude', 'z', 't0'],
bounds={'z': (0., 1.0)})

assert_allclose(fitmodel.parameters, self.model.parameters, rtol=0.05)

@pytest.mark.skipif('not HAS_EMCEE')
def test_mcmc_invalid_sampler(self):
"""Ensure that mcmc runs."""
with pytest.raises(ValueError):
sncosmo.mcmc_lc(self.data, self.model,
['amplitude', 'z', 't0'],
bounds={'z': (0., 1.0)},
sampler='invalid_sampler')

@pytest.mark.skipif('not HAS_IMINUIT')
def test_flatten_result(self):
"""Ensure that the flatten_result function works correctly."""
Expand Down

0 comments on commit ff508e3

Please sign in to comment.