Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit of simple rewriting of certain components of AssayTool… #145

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
122 changes: 122 additions & 0 deletions AssayTools/numpyrobindingmodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python

"""
Various ligand binding models for use in assays.

"""

# =============================================================================================
# IMPORTS
# =============================================================================================

import numpy as np
import jax.numpy as jnp
import copy

from math import sqrt, exp, log

# =============================================================================================
# Physical constants
# =============================================================================================

Na = 6.02214179e23 # Avogadro's number (number/mol)
kB = Na * 1.3806504e-23 / 4184.0 # Boltzmann constant (kcal/mol/K)
C0 = 1.0 # standard concentration (M)

# =============================================================================================
# Binding models
# =============================================================================================


class BindingModel(object):
"""
Abstract base class for reaction models.

"""

def __init__(self):
pass


# =============================================================================================
# Two-component binding model
# =============================================================================================


class TwoComponentBindingModel(BindingModel):
"""
Simple two-component association.

"""

@classmethod
def equilibrium_concentrations(cls, DeltaG, Ptot, Ltot):
"""
Compute equilibrium concentrations for simple two-component association.

Parameters
----------
DeltaG : float
Reduced free energy of binding (in units of kT)
Ptot : float or numpy array
Total protein concentration summed over bound and unbound species, molarity.
Ltot : float or numpy array
Total ligand concentration summed over bound and unbound speciesl, molarity.

Returns
-------
P : float or numpy array with same dimensions as Ptot
Free protein concentration, molarity.
L : float or numpy array with same dimensions as Ptot
Free ligand concentration, molarity.
PL : float or numpy array with same dimensions as Ptot
Bound complex concentration, molarity.

"""
# Handle only strictly positive elements---all others are set to zero as constants
try:
nonzero_indices = jnp.where(Ltot > 0)[0]
zero_indices = jnp.where(Ltot <= 0)[0]
except:
nonzero_indices = jnp.array(range(Ltot.shape[0]))
zero_indices = jnp.array([])
nnonzero = len(nonzero_indices)
nzeros = len(zero_indices)

# Numerically stable variant
dtype = jnp.float32
Ptot = Ptot.astype(dtype) # promote to dtype
Ltot = Ltot.astype(dtype) # promote to dtype
PL = jnp.zeros(Ptot.shape, dtype)
logP = jnp.log(jnp.take(Ptot, nonzero_indices))
logL = jnp.log(jnp.take(Ltot, nonzero_indices))
logPLK = jnp.logaddexp(jnp.logaddexp(logP, logL), DeltaG)
PLK = jnp.exp(logPLK)
sqrt_arg = 1.0 - jnp.exp(jnp.log(4.0) + logP + logL - 2.0 * logPLK)
sqrt_arg = jnp.where(sqrt_arg >= 0.0, sqrt_arg, 0) # ensure always positive
PL = PL.at[nonzero_indices].set(
0.5 * PLK * (1.0 - jnp.sqrt(sqrt_arg))
) # complex concentration (M)

# Compute remaining concentrations.
P = Ptot - PL
# free protein concentration in sample cell after n injections (M)
L = Ltot - PL
# free ligand concentration in sample cell after n injections (M)

# Ensure all concentrations are within limits, correcting cases where numerical issues cause problems.
PL = jnp.where(PL >= 0.0, PL, 0.0) # complex cannot have negative concentration
P = jnp.where(P >= 0.0, P, 0.0)
L = jnp.where(L >= 0.0, L, 0.0)

"""
# Check all concentrations are nonnegative
# this check doesn't work with jax as it requires concrete values (no tracer)
# but P, L, PL all have tracers

assert jnp.all(P >= 0)
assert jnp.all(L >= 0)
assert jnp.all(PL >= 0)
"""

return [P, L, PL]
144 changes: 144 additions & 0 deletions AssayTools/numpyromodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""

numpyro models for analysis of fluorescence assay data

"""

# =============================================================================================
# IMPORTS
# =============================================================================================

import abc
import numpy as np
import numpyro
import numpyro.distributions as dist
import arviz as az
import jax
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive
from .numpyrobindingmodels import TwoComponentBindingModel
import matplotlib.pyplot as plt

# =============================================================================================
# Physical constants
# =============================================================================================

Na = 6.02214179e23 # Avogadro's number (number/mol)
kB = Na * 1.3806504e-23 / 4184.0 # Boltzmann constant (kcal/mol/K)
C0 = 1.0 # standard concentration (M)

# =============================================================================================
# Parameters for MCMC sampling
# =============================================================================================

DG_min = np.log(
1e-15
) # kT, most favorable (negative) binding free energy possible; 1 fM
DG_max = +0 # kT, least favorable binding free energy possible
niter = 500000 # number of iterations
nburn = 50000 # number of burn-in iterations to discard
nthin = 500 # thinning interval

# =============================================================================================
# numpyro submodels
# i'm assuming that we have LogNormal priors for now
# =============================================================================================

# construct a lognorm dist with proper loc and scale
def construct_lognorm(loc, scale):
u = jnp.log(loc ** 2 / jnp.sqrt(loc ** 2 + scale ** 2))
sig = jnp.log(1 + scale ** 2 / loc ** 2)
return dist.LogNormal(loc=u, scale=sig)


# dispense an amount of liquid with expected concentration mu and variance var
def dispense(mu, var, name=""):
return numpyro.sample(f"dispense_{name}", construct_lognorm(loc=mu, scale=var))


# sample a unknown value e.g. extinction coefficient or quantum yield
def hidden(min=0, max=10e6, name=""):
return numpyro.sample(name, dist.Uniform(low=min, high=max))


# =============================================================================================
# numpyro base modules
# =============================================================================================


class MCMCModel:
def __init__(
self,
model,
mcmc_args={"num_warmup": nburn, "num_samples": niter, "thinning": nthin},
):
self.model = model
self.mcmc = None
self.mcmc_args = mcmc_args
self.params = None
rng_key = random.PRNGKey(0) # TODO: make option to choose random seed
self.rng_key_infer, self.rng_key_predict = random.split(rng_key)

def run_mcmc(self, *args, **kwargs):
nuts_kernel = NUTS(self.model)
self.mcmc = MCMC(nuts_kernel, **self.mcmc_args)
self.mcmc.run(self.rng_key_infer, *args, **kwargs)
self.sample_params()

def sample_params(self):
self.params = self.mcmc.get_samples()

def predict(self, *args, **kwargs):
predictor = Predictive(self.model, self.params)
return predictor(self.rng_key_predict, *args, **kwargs)

def plot_results(self):
self.mcmc.print_summary()
data = az.from_numpyro(self.mcmc)
az.plot_trace(data, compact=True)


# =============================================================================================
# numpyro models
# =============================================================================================

# toy model for fluorescence curve
# assume we have precise solutions of each concentration of protein/ligand available
# and only the complex fluoresces
def toy_model(Pstated, dPstated, Lstated, dLstated, fluorescence=None):
# binding free energy (kT), using a Uniform prior
dG = numpyro.sample("dG", dist.Uniform(DG_min, DG_max))

# we use our stated concentrations as priors
# we use plate as each dispense should be conditionally independent
with numpyro.plate("dispense_plate", len(Pstated)):
Ptrue = dispense(Pstated, dPstated, name="Ptrue")
Ltrue = dispense(Lstated, dLstated, name="Ltrue")

# compute equilibrium concentrations using model
[P_i, L_i, PL_i] = TwoComponentBindingModel.equilibrium_concentrations(
dG, Ptrue, Ltrue
)

# scale/noise factors
# all the bounds are sorta arbitrary
f_var = hidden(
max=100, name="f_measure_var"
) # variance for fluorescence measurement
f_gain = hidden(max=1e13, name="f_gain") # gain / quantum yield
f_background = hidden(max=100, name="f_background") # baseline signal

# assume that only the complex fluoresces
# e.g. FRET experiment at a particular wavelength
F_PL_i = f_background + f_gain * PL_i

# assume each measurement also has conditionally independent error
with numpyro.plate("measure_plate", len(F_PL_i)):
measurement = numpyro.sample(
"measure_fluorescence",
dist.Normal(loc=F_PL_i, scale=f_var),
obs=fluorescence,
)

return measurement