Dirichlet Process Gaussian mixture model via the stick-breaking construction in various PPLs

This page was last updated on 29 Mar, 2021.


In this post, I’ll explore implementing posterior inference for Dirichlet process Gaussian mixture models (GMMs) via the stick-breaking construction in various probabilistic programming languages (PPLs). For an overview of the Dirichlet process (DP) and Chinese restaurant process (CRP), visit this post on Probabilistic Modeling using the Infinite Mixture Model by the Turing team. Basic familiarity with Gaussian mixture models and Bayesian methods is assumed in this post. This Coursera Course on Mixture Models offers a great intro to the subject.

Stick-breaking Model for Mixture Weights

As described in Probabilistic Modeling using the Infinite Mixture Model, a Bayesian infinite GMM can be expressed as follows:

\[\begin{aligned} \alpha &\sim \t{Gamma(a, b)} & \t{(optional prior)} \\ \bm{z} \mid \alpha &\sim \t{CRP}_N(\alpha) \\ (\mu_k, \sigma_k) &\sim G_0, & \t{for } k=1, \dots, \infty \\ y_i \mid z_i, \bm{\mu}, \bm{\sigma} &\sim \t{Normal}(\mu_{z_i}, \sigma_{z_i}), & \t{for } i=1, \dots, N \\ \end{aligned}\]

Note that this is a collapsed version of a DP (Gaussian) mixture model. Line 4 is the sampling distribution (or model), where each of the $N$ real-valued univariate observations ($y_i$) is assumed to be Normally distributed; line 3 specifies independent priors for the mixture locations and scales. For example, the base distribution $G_0$ could be $\t{Normal}(\cdot \mid m, s) \times \t{Gamma}(\cdot \mid a, b)$, i.e. a bivariate distribution with independent components – one Normal and one Gamma. The CRP allows the number of mixture components to be unbounded (though it is practically bounded by $N$), hence $k=1,\dots,\infty$. Line 2 is the prior for the cluster membership indicators; through $z_i$, the data can be partitioned (i.e. grouped) into clusters. Line 1 is an optional prior for the concentration parameter $\alpha$, which influences the number of clusters: the larger $\alpha$, the greater the number of clusters. Note that the data usually do not contain much information about $\alpha$, so the prior for $\alpha$ needs to be at least moderately informative.

This model is difficult to implement generically in PPLs because (1) it allows the number of mixture components (i.e. the dimensions of $\bm{\mu}$ and $\bm{\sigma}$) to vary, and (2) it contains discrete parameters (the cluster membership indicators $\bm{z}$). It is true that for a particular class of DP mixture models (with carefully chosen priors), efficient posterior inference is just a matter of iteratively sampling directly from the full conditionals. But this convenience is not available in general, and we still have the issue of a varying number of mixture components. The PPL NIMBLE (in R) addresses this issue by allowing the user to specify a maximum number of components; it also cleverly exploits conjugate priors when possible. This feature is not typically seen in PPLs, including Turing (at the moment). NIMBLE also supports the sampling of discrete parameters, which is uncommon in PPLs, as they usually implement generic and efficient sampling algorithms like HMC and NUTS for continuous parameters. Turing supports the sampling of discrete parameters via sequential Monte Carlo (SMC) and particle Gibbs (PG) / conditional SMC.

That said, the finite stick-breaking construction of the DP bypasses the need for a varying number of mixture components. In the context of mixture models, the infinite stick-breaking construction of the DP defines the mixture weights $(w_1, w_2, \dots)$ as follows:

\[\begin{aligned} v_\ell \mid \alpha &\sim \t{Beta}(1, \alpha), &\t{for } \ell\in\mathbb{N} \\ w_1 &= v_1 \\ w_k &= v_k \prod_{\ell=1}^{k - 1} (1 - v_\ell), &\t{for } k\in\bc{2,3,\dots} \\ \end{aligned}\]

This image shows a realization of $\bm{w}$ from the stick-breaking process with $\alpha=3$.

stickbreak image

Probability vectors ($\bm{w}$) generated from this infinite construction have the same distribution as the mixture weights implied by the CRP (with parameter $\alpha$). The probability vector $\bm{w}$ (aka simplex), where each element is non-negative and the elements sum to 1, will be sparse: most of the weights will be extremely close to 0, and only a handful will be noticeably (to substantially) greater than 0. The degree of sparsity is influenced by $\alpha$. In practice, a truncated (or finite) version of the stick-breaking construction is used. (An infinite array cannot be created on a computer…) The finite stick-breaking model simply places an upper bound on the number of mixture components, which, if chosen reasonably, allows us to reap the benefits of the DP (a model whose complexity can grow with the data size) at a cheaper computational cost. The finite version is as follows:

\[\begin{aligned} v_\ell \mid \alpha &\sim \t{Beta}(1, \alpha), &\t{for } \ell=1,\dots,K-1 \\ w_1 &= v_1 \\ w_k &= v_k \prod_{\ell=1}^{k - 1} (1 - v_\ell), &\t{for } k=2,\dots,K-1 \\ w_K &= \prod_{\ell=1}^{K - 1} (1 - v_\ell). \end{aligned}\]

Note that $w_K$ is defined such that $w_K = 1 - \sum_{k=1}^{K-1} w_k$; line 4 above is simply the form typically implemented in software for numerical stability. From here on, for brevity, $\bm{w} = \t{stickbreak}(\bm{v})$ will denote the transformation in lines 2-4 above. Notice that the probability vector $\bm{w}$ is $K$-dimensional, while $\bm{v}$ is $(K-1)$-dimensional. (A simplex $\bm{w}$ of length $K$ only requires $K-1$ elements to be specified; the remaining element is constrained so that the probability vector sums to 1.)
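To make the $\t{stickbreak}(\bm{v})$ transformation concrete, here is a minimal NumPy sketch (it is not taken from any of the libraries below, though it mirrors the stickbreak helpers used in the TFP and Numpyro examples):

import numpy as np

def stickbreak(v):
    # v has length K-1 with entries in (0, 1); the result w has length K and sums to 1.
    cumprod_one_minus_v = np.cumprod(1 - v)      # prod_{l <= k} (1 - v_l)
    one_v = np.append(v, 1)                      # (v_1, ..., v_{K-1}, 1)
    c_one = np.append(1, cumprod_one_minus_v)    # (1, prod_{l <= 1}, ..., prod_{l <= K-1})
    return one_v * c_one                         # w_k = v_k * prod_{l < k} (1 - v_l)

# Example: draw v from Beta(1, alpha) and form the weights.
rng = np.random.default_rng(0)
alpha, K = 3.0, 10
v = rng.beta(1, alpha, size=K - 1)
w = stickbreak(v)
print(w.sum())  # 1.0 (up to floating-point error)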

A DP GMM under the stick-breaking construction can thus be specified as follows:

\[\begin{aligned} % Priors. \alpha &\sim \t{Gamma(a, b)} &\t{(optional prior)} \\ v_k \mid \alpha &\sim \t{Beta}(1, \alpha), &\t{for } k=1,\dots,K-1 \\ \bm{w} &= \t{stickbreak}(\bm{v}), \\ \bm{z} \mid \bm{w} &\sim \t{Categorical}_K(\bm{w}) \\ \mu_k &\sim G_\mu, &\t{for } k = 1,\dots,K \\ \sigma_k &\sim G_\sigma, &\t{for } k = 1,\dots,K \\ % Sampling Distribution. y_i \mid z_i, \bm{\mu}, \bm{\sigma} &\sim \t{Normal}(\mu_{z_i}, \sigma_{z_i}), & \t{for } i=1, \dots, N \\ \end{aligned}\]

where $G_\mu$ and $G_\sigma$ are appropriate priors for $\mu_k$ and $\sigma_k$, respectively. Marginalizing over the (discrete) cluster membership indicators $\bm{z}$ may be beneficial in practice if an efficient posterior inference algorithm (e.g. ADVI, HMC, NUTS) exists for learning the joint posterior of the remaining model parameters. If this is the case, one further reduction can be made to yield:

\[\begin{aligned} % Priors. \alpha &\sim \t{Gamma(a, b)} &\t{(optional prior)} \\ v_k \mid \alpha &\sim \t{Beta}(1, \alpha), &\t{for } k=1,\dots,K-1 \\ \bm{w} &= \t{stickbreak}(\bm{v}) \\ \mu_k &\sim G_\mu, &\t{for } k = 1,\dots,K \\ \sigma_k &\sim G_\sigma, &\t{for } k = 1,\dots,K \\ % Sampling Distribution. y_i \mid \bm{\mu}, \bm{\sigma}, \bm{w} &\sim \sum_{k=1}^K w_k \cdot \t{Normal}(\mu_k, \sigma_k), & \t{for } i=1, \dots, N \\ \end{aligned}\]
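To make the marginalization over $\bm{z}$ concrete, the log-likelihood of each $y_i$ in the last line is a log-sum-exp over the $K$ components. Here is a minimal NumPy/SciPy sketch with made-up values for y, mu, sigma, and w (the PPL examples below compute the equivalent quantity via log_sum_exp, logsumexp, UnivariateGMM, or MixtureSameFamily):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Hypothetical values, for illustration only.
y = np.array([-1.0, 0.5, 2.0])    # observations, shape (N,)
mu = np.array([-1.0, 2.0])        # component means, shape (K,)
sigma = np.array([0.5, 0.8])      # component scales, shape (K,)
w = np.array([0.3, 0.7])          # mixture weights, shape (K,), sums to 1

# log p(y_i | mu, sigma, w) = logsumexp_k { log(w_k) + log Normal(y_i | mu_k, sigma_k) }
lp = norm.logpdf(y[:, None], loc=mu[None, :], scale=sigma[None, :]) + np.log(w)[None, :]
loglike = logsumexp(lp, axis=-1).sum()  # total log-likelihood, summed over observations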

The joint posterior of the parameters $\bm{\theta} = \p{\bm{\mu}, \bm{\sigma}, \bm{v}, \alpha}$ can be sampled from using NUTS or HMC, or approximated via ADVI. This can be done in the various PPLs as follows. (Note that these are excerpts from complete examples which are also linked.)

Full Turing Example (notebook)

# NOTE: Import libraries here ...

# DP GMM model under stick-breaking construction
@model dp_gmm_sb(y, K) = begin
    nobs = length(y)

    mu ~ filldist(Normal(0, 3), K)
    sig ~ filldist(Gamma(1, 1/10), K)  # mean = 0.1

    alpha ~ Gamma(1, 1/10)  # mean = 0.1
    crm = DirichletProcess(alpha)
    v ~ filldist(StickBreakingProcess(crm), K - 1)
    eta = stickbreak(v)

    y .~ UnivariateGMM(mu, sig, Categorical(eta))
end

# NOTE: Read data y here ...
# Here, y (a vector of length 500) is noisy univariate draws from a
# mixture distribution with 4 components.

# Fit DP-SB-GMM with ADVI
advi = ADVI(1, 2000)  # num_elbo_samples, max_iters
q = vi(dp_gmm_sb(y, 10), advi, optimizer=Flux.ADAM(1e-2));

# Fit DP-SB-GMM with HMC
hmc_chain = sample(dp_gmm_sb(y, 10),  # data, number of mixture components
                   HMC(0.01, 100),  # stepsize, number of leapfrog steps
                   1000)  # iterations

# Fit DP-SB-GMM with NUTS
@time nuts_chain = begin
    n_samples = 500  # number of MCMC samples
    nadapt = 500  # number of iterations to adapt tuning parameters in NUTS
    iterations = n_samples + nadapt
    target_accept_ratio = 0.8
    
    sample(dp_gmm_sb(y, 10),  # data, number of mixture components.
           NUTS(nadapt, target_accept_ratio, max_depth=10),
           iterations);
end

Full STAN Example (notebook)

# import libraries here ...

model = """
data {
  int<lower=0> K;  // Number of clusters
  int<lower=0> N;  // Number of observations
  real y[N];  // observations
  real<lower=0> alpha_shape;
  real<lower=0> alpha_rate;
  real<lower=0> sigma_shape;
  real<lower=0> sigma_rate;
}

parameters {
  real mu[K]; // cluster means
  // real <lower=0,upper=1> v[K - 1];  // stickbreak components
  vector<lower=0,upper=1>[K - 1] v;  // stickbreak components
  real<lower=0> sigma[K];  // error scale
  real<lower=0> alpha;  // hyper prior DP(alpha, base)
}

transformed parameters {
  simplex[K] eta;
  vector<lower=0,upper=1>[K - 1] cumprod_one_minus_v;

  cumprod_one_minus_v = exp(cumulative_sum(log1m(v)));
  eta[1] = v[1];
  eta[2:(K-1)] = v[2:(K-1)] .* cumprod_one_minus_v[1:(K-2)];
  eta[K] = cumprod_one_minus_v[K - 1];
}

model {
  real ps[K];
  // real alpha = 1;
  
  alpha ~ gamma(alpha_shape, alpha_rate);  // mean = a/b = shape/rate 
  sigma ~ gamma(sigma_shape, sigma_rate);
  mu ~ normal(0, 3);
  v ~ beta(1, alpha);

  for(i in 1:N){
    for(k in 1:K){
      ps[k] = log(eta[k]) + normal_lpdf(y[i] | mu[k], sigma[k]);
    }
    target += log_sum_exp(ps);
  }
}

generated quantities {
  real ll;
  real ps_[K];
  
  ll = 0;
  for(i in 1:N){
    for(k in 1:K){
      ps_[k] = log(eta[k]) + normal_lpdf(y[i] | mu[k], sigma[k]);
    }
    ll += log_sum_exp(ps_);
  }  
}
"""

# Compile the stan model.
sm = pystan.StanModel(model_code=model)

# NOTE: Read data y here ...
# Here, y (a vector of length 500) is noisy univariate draws from a
# mixture distribution with 4 components.

# Construct data dictionary.
data = dict(y=simdata['y'], K=10, N=len(simdata['y']),
            alpha_shape=1, alpha_rate=10, sigma_shape=1, sigma_rate=10)

# Approximate posterior via ADVI
# - ADVI is sensitive to starting values. Should run several times and pick the run
#   that has the best fit (e.g. highest ELBO / log-likelihood).
# - Variational inference works better with more data. Inference is less accurate
#   with small datasets, due to the variational approximation.
fit = sm.vb(data=data, iter=1000, seed=1, algorithm='meanfield',
            adapt_iter=1000, verbose=False, grad_samples=1, elbo_samples=100,
            adapt_engaged=True, output_samples=1000)

### Settings for MCMC ###
burn = 500  # Number of burn in iterations
nsamples = 500  # Number of samples to keep
niters = burn + nsamples  # Number of MCMC (HMC / NUTS) iterations in total

# Sample from posterior via HMC
# NOTE: num_leapfrog = int_time / stepsize.
hmc_fit = sm.sampling(data=data, iter=niters, chains=1, warmup=burn,
                      thin=1, seed=1, algorithm='HMC',
                      control=dict(stepsize=0.01, int_time=1,
                                   adapt_engaged=False))

# Sample from posterior via NUTS
nuts_fit = sm.sampling(data=data, iter=niters, chains=1, warmup=burn, thin=1,
                       seed=1)

Full TFP Example (notebook)

# import libraries here ...

# Stickbreak function
def stickbreak(v):
    batch_ndims = len(v.shape) - 1
    cumprod_one_minus_v = tf.math.cumprod(1 - v, axis=-1)
    one_v = tf.pad(v, [[0, 0]] * batch_ndims + [[0, 1]], "CONSTANT",
                   constant_values=1)
    c_one = tf.pad(cumprod_one_minus_v, [[0, 0]] * batch_ndims + [[1, 0]],
                   "CONSTANT", constant_values=1)
    return one_v * c_one

# See: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MixtureSameFamily
# See: https://www.tensorflow.org/probability/examples/Bayesian_Gaussian_Mixture_Model
# Define model builder.
def create_dp_sb_gmm(nobs, K, dtype=np.float64):
    return tfd.JointDistributionNamed(dict(
        # Mixture means
        mu = tfd.Independent(
            tfd.Normal(np.zeros(K, dtype), 3),
            reinterpreted_batch_ndims=1
        ),
        # Mixture scales
        sigma = tfd.Independent(
            tfd.LogNormal(loc=np.full(K, - 2, dtype), scale=0.5),
            reinterpreted_batch_ndims=1
        ),
        # Mixture weights (stick-breaking construction)
        alpha = tfd.Gamma(concentration=np.float64(1.0), rate=10.0),
        v = lambda alpha: tfd.Independent(
            # tfd.Beta(np.ones(K - 1, dtype), alpha),
            # NOTE: Dave Moore suggests doing this instead, to ensure 
            # that a batch dimension in alpha doesn't conflict with 
            # the other parameters.
            tfd.Beta(np.ones(K - 1, dtype), alpha[..., tf.newaxis]),
            reinterpreted_batch_ndims=1
        ),

        # Observations (likelihood)
        obs = lambda mu, sigma, v: tfd.Sample(tfd.MixtureSameFamily(
            # This will be marginalized over.
            mixture_distribution=tfd.Categorical(probs=stickbreak(v)),
            # mixture_distribution=tfd.Categorical(probs=v),
            components_distribution=tfd.Normal(mu, sigma)),
            sample_shape=nobs)
    ))

# Number of mixture components.
ncomponents = 10

# Create model.
model = create_dp_sb_gmm(nobs=len(simdata['y']), K=ncomponents)

# Define log unnormalized joint posterior density.
def target_log_prob_fn(mu, sigma, alpha, v):
    return model.log_prob(obs=y, mu=mu, sigma=sigma, alpha=alpha, v=v)

# NOTE: Read data y here ...
# Here, y (a vector of length 500) is noisy univariate draws from a
# mixture distribution with 4 components.

### ADVI ###
# Prep work for ADVI. Credit: Thanks to Dave Moore at BayesFlow for helping
# with the implementation!

# ADVI is quite sensitive to the initial distribution.
tf.random.set_seed(7) # 7

# Create variational parameters.
qmu_loc = tf.Variable(tf.random.normal([ncomponents], dtype=np.float64) * 3,
                      name='qmu_loc')
qmu_rho = tf.Variable(tf.random.normal([ncomponents], dtype=np.float64) * 2,
                      name='qmu_rho')

qsigma_loc = tf.Variable(tf.random.normal([ncomponents], dtype=np.float64) - 2,
                         name='qsigma_loc')
qsigma_rho = tf.Variable(tf.random.normal([ncomponents], dtype=np.float64) - 2,
                         name='qsigma_rho')

qv_loc = tf.Variable(tf.random.normal([ncomponents - 1], dtype=np.float64) - 2,
                     name='qv_loc')
qv_rho = tf.Variable(tf.random.normal([ncomponents - 1], dtype=np.float64) - 1,
                     name='qv_rho')

qalpha_loc = tf.Variable(tf.random.normal([], dtype=np.float64),
                         name='qalpha_loc')
qalpha_rho = tf.Variable(tf.random.normal([], dtype=np.float64),
                         name='qalpha_rho')

# Create variational distribution.
surrogate_posterior = tfd.JointDistributionNamed(dict(
    # qmu
    mu=tfd.Independent(tfd.Normal(qmu_loc, tf.nn.softplus(qmu_rho)),
                       reinterpreted_batch_ndims=1),
    # qsigma
    sigma=tfd.Independent(tfd.LogNormal(qsigma_loc,
                                        tf.nn.softplus(qsigma_rho)),
                          reinterpreted_batch_ndims=1),
    # qv
    v=tfd.Independent(tfd.LogitNormal(qv_loc, tf.nn.softplus(qv_rho)),
                      reinterpreted_batch_ndims=1),
    # qalpha
    alpha=tfd.LogNormal(qalpha_loc, tf.nn.softplus(qalpha_rho))))

# Run ADVI.
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=1e-2),
    sample_size=100, seed=1, num_steps=2000)  # 9 seconds


### MCMC (HMC/NUTS) ###

# Creates initial values for HMC, NUTS. 
def generate_initial_state(seed=None):
    tf.random.set_seed(seed)
    return [
        tf.zeros(ncomponents, dtype, name='mu'),
        tf.ones(ncomponents, dtype, name='sigma') * 0.1,
        tf.ones([], dtype, name='alpha') * 0.5,
        tf.fill(ncomponents - 1, value=np.float64(0.5), name='v')
    ]

# Create bijectors to transform between the unconstrained and constrained
# parameter spaces. For example, if X ~ Exponential(theta), then X is
# constrained to be positive. A transformation that puts X onto an
# unconstrained space is Y = log(X). In that case, the bijector used should
# be the **inverse transform**, which is exp(.) (i.e. so that X = exp(Y)).
#
# NOTE: Define the inverse-transforms for each parameter in sequence.
bijectors = [
    tfb.Identity(),  # mu
    tfb.Exp(),  # sigma
    tfb.Exp(),  # alpha
    tfb.Sigmoid()  # v
]

# Define HMC sampler.
@tf.function(autograph=False, experimental_compile=True)
def hmc_sample(num_results, num_burnin_steps, current_state, step_size=0.01,
               num_leapfrog_steps=100):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=current_state,
        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                    target_log_prob_fn=target_log_prob_fn,
                    step_size=step_size, num_leapfrog_steps=num_leapfrog_steps, seed=1),
                bijector=bijectors),
            num_adaptation_steps=num_burnin_steps),
        trace_fn = lambda _, pkr: pkr.inner_results.inner_results.is_accepted)


# Define NUTS sampler.
@tf.function(autograph=False, experimental_compile=True)
def nuts_sample(num_results, num_burnin_steps, current_state, max_tree_depth=10):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=current_state,
        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=tfp.mcmc.NoUTurnSampler(
                     target_log_prob_fn=target_log_prob_fn,
                     max_tree_depth=max_tree_depth, step_size=0.01, seed=1),
                bijector=bijectors),
            num_adaptation_steps=num_burnin_steps,  # should be smaller than burn-in.
            target_accept_prob=0.8),
        trace_fn = lambda _, pkr: pkr.inner_results.inner_results.is_accepted)

# Run HMC sampler.
current_state = generate_initial_state()
[mu, sigma, alpha, v], is_accepted = hmc_sample(500, 500, current_state)
hmc_output = dict(mu=mu, sigma=sigma, alpha=alpha, v=v,
                  acceptance_rate=is_accepted.numpy().mean())

# Run NUTS sampler.
current_state = generate_initial_state()
[mu, sigma, alpha, v], is_accepted = nuts_sample(500, 500, current_state)
nuts_output = dict(mu=mu, sigma=sigma, alpha=alpha, v=v,
                   acceptance_rate=is_accepted.numpy().mean())

Full Pyro Example (notebook)

# import libraries here ...

# Stick break function
# See: https://pyro.ai/examples/dirichlet_process_mixture.html
def stickbreak(v):
    cumprod_one_minus_v = torch.cumprod(1 - v, dim=-1)
    one_v = pad(v, (0, 1), value=1)
    c_one = pad(cumprod_one_minus_v, (1, 0), value=1)
    return one_v * c_one

# See: https://pyro.ai/examples/gmm.html#
# See: https://pyro.ai/examples/dirichlet_process_mixture.html
# See: https://forum.pyro.ai/t/fitting-models-with-nuts-is-slow/1900

# DP SB GMM model.
# NOTE: In pyro, priors are assigned to parameters in the following manner:
#
#   random_variable = pyro.sample('name_of_random_variable', some_distribution)
#
# Note that random variables appear on the left hand side of the `pyro.sample` statement.
# Data will appear *inside* the `pyro.sample` statement, via the obs argument.
# 
# In this example, the cluster labels appear explicitly in the model. But they are, in fact,
# marginalized out automatically by pyro (via enumeration). Hence, they do not appear in the
# posterior samples.
#
# Both the marginalized and auxiliary-variable versions are equally slow.
def dp_sb_gmm(y, num_components):
    # Constants
    N = y.shape[0]
    K = num_components
    
    # Priors
    # NOTE: In pyro, the Gamma distribution is parameterized with shape and rate.
    # Hence, Gamma(shape, rate) => mean = shape/rate
    alpha = pyro.sample('alpha', dist.Gamma(1, 10))
    
    with pyro.plate('mixture_weights', K - 1):
        v = pyro.sample('v', dist.Beta(1, alpha, K - 1))
    
    eta = stickbreak(v)
    
    with pyro.plate('components', K):
        mu = pyro.sample('mu', dist.Normal(0., 3.))
        sigma = pyro.sample('sigma', dist.Gamma(1, 10))

    with pyro.plate('data', N):
        label = pyro.sample('label', dist.Categorical(eta), infer={"enumerate": "parallel"})
        pyro.sample('obs', dist.Normal(mu[label], sigma[label]), obs=y)


# NOTE: Read data y here ...
# Here, y (a vector of length 500) is noisy univariate draws from a
# mixture distribution with 4 components.

# Fit DP SB GMM via ADVI
# See: https://pyro.ai/examples/dirichlet_process_mixture.html

# Automatically define variational distribution (a mean field guide).
pyro.clear_param_store()  # clear global parameter cache
guide = AutoDiagonalNormal(pyro.poutine.block(dp_sb_gmm, expose=['alpha', 'v', 'mu', 'sigma']))
svi = SVI(dp_sb_gmm, guide, Adam({'lr': 1e-2}), TraceEnum_ELBO())
pyro.set_rng_seed(7)  # set random seed
# Do gradient steps.
for step in range(2000):
    svi.step(y, 10)

# Fit DP SB GMM via HMC
pyro.clear_param_store()
pyro.set_rng_seed(1)
kernel = HMC(dp_sb_gmm, step_size=0.01, trajectory_length=1,
             target_accept_prob=0.8, adapt_step_size=False,
             adapt_mass_matrix=False)
hmc = MCMC(kernel, num_samples=500, warmup_steps=500)
hmc.run(y, 10)

# Fit DP SB GMM via NUTS
pyro.clear_param_store()
pyro.set_rng_seed(1)
kernel = NUTS(dp_sb_gmm, target_accept_prob=0.8)
nuts = MCMC(kernel, num_samples=500, warmup_steps=500)
nuts.run(y, 10)

Full Numpyro Example (notebook)

# import some libraries here...

# Stick break function
def stickbreak(v):
    batch_ndims = len(v.shape) - 1
    cumprod_one_minus_v = np.exp(np.log1p(-v).cumsum(-1))
    one_v = np.pad(v, [[0, 0]] * batch_ndims + [[0, 1]], constant_values=1)
    c_one = np.pad(cumprod_one_minus_v, [[0, 0]] * batch_ndims +[[1, 0]],
                   constant_values=1)
    return one_v * c_one

# Custom distribution: Mixture of normals.
# This implements the abstract class `dist.Distribution`.
class NormalMixture(dist.Distribution):
    support = constraints.real_vector

    def __init__(self, mu, sigma, w):
        super(NormalMixture, self).__init__(event_shape=(1, ))
        self.mu = mu
        self.sigma = sigma
        self.w = w

    def sample(self, key, sample_shape=()):
        # it is enough to return an arbitrary sample with correct shape
        return np.zeros(sample_shape + self.event_shape)

    def log_prob(self, y, axis=-1):
        lp = dist.Normal(self.mu, self.sigma).log_prob(y) + np.log(self.w)
        return logsumexp(lp, axis=axis)

# DP SB GMM model.
# NOTE: In numpyro, priors are assigned to parameters in the following manner:
#
#   random_variable = numpyro.sample('name_of_random_variable', some_distribution)
#
# Note that random variables appear on the left hand side of the
# `numpyro.sample` statement.  Data will appear *inside* the `numpyro.sample`
# statement, via the obs argument.
# 
# In this example, labels are explicitly mentioned. But they are, in fact,
# marginalized out automatically by numpyro. Hence, they do not appear in the
# posterior samples.
def dp_sb_gmm(y, num_components):
    # Constants
    N = y.shape[0]
    K = num_components
    
    # Priors
    # NOTE: In numpyro, the Gamma distribution is parameterized with shape and
    # rate.  Hence, Gamma(shape, rate) => mean = shape/rate
    alpha = numpyro.sample('alpha', dist.Gamma(1, 10))
    
    with numpyro.plate('mixture_weights', K - 1):
        v = numpyro.sample('v', dist.Beta(1, alpha, K - 1))
    
    eta = stickbreak(v)
    
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu', dist.Normal(0., 3.))
        sigma = numpyro.sample('sigma', dist.Gamma(1, 10))

    with numpyro.plate('data', N):
        numpyro.sample('obs', NormalMixture(mu[None, :] , sigma[None, :],
                                            eta[None, :]), obs=y[:, None])

# NOTE: Read data y here ...
# Here, y (a vector of length 200) is noisy univariate draws from a
# mixture distribution with 4 components.

# Set random seed for reproducibility.
rng_key = random.PRNGKey(0)

# FIT DP SB GMM via HMC
# NOTE: num_leapfrog = trajectory_length / step_size
kernel = HMC(dp_sb_gmm, step_size=.01, trajectory_length=1)
hmc = MCMC(kernel, num_samples=500, num_warmup=500)
hmc.run(rng_key, y, 10)
hmc_samples = get_posterior_samples(hmc)

# FIT DP SB GMM via NUTS
kernel = NUTS(dp_sb_gmm, max_tree_depth=10, target_accept_prob=0.8)
nuts = MCMC(kernel, num_samples=500, num_warmup=500)
nuts.run(rng_key, y, 10)
nuts_samples = get_posterior_samples(nuts)

# FIT DP SB GMM via ADVI
sigmoid = lambda x: 1 / (1 + onp.exp(-x))

# Setup ADVI.
guide = AutoDiagonalNormal(dp_sb_gmm)  # surrogate posterior
optimizer = numpyro.optim.Adam(step_size=0.01)  # adam optimizer
svi = SVI(guide.model, guide, optimizer, loss=ELBO())  # ELBO loss
init_state = svi.init(random.PRNGKey(2), y, 10)  # initial state

# Run optimizer
state, losses = lax.scan(lambda state, i: 
                         svi.update(state, y, 10), init_state, np.arange(2000))

# Extract surrogate posterior.
params = svi.get_params(state)
def sample_advi_posterior(guide, params, nsamples, seed=1):
    samples = guide.get_posterior(params).sample(random.PRNGKey(seed),
                                                 (nsamples, ))
    # NOTE: Samples are arranged in alphabetical order.
    #       Not in the order in which they appear in the
    #       model. This is different from pyro.
    return dict(alpha=onp.exp(samples[:, 0]),
                mu=onp.array(samples[:, 1:11]).T,
                sigma=onp.exp(samples[:, 11:21]).T,
                eta=onp.array(stickbreak(sigmoid(samples[:, 21:]))).T)  # v

advi_samples = sample_advi_posterior(guide, params, nsamples=500, seed=1)

ADVI, HMC, and NUTS are not supported in NIMBLE at the moment. However, the model can still be implemented, and posterior inference can be carried out via alternative inference algorithms. See: this NIMBLE example.

The purpose of this post is to compare various PPLs for this particular model, including things like timings, inference quality, and syntactic sugar. Comparing inferences is difficult without overwhelming the reader with many figures. Suffice it to say that, where possible, HMC and NUTS performed similarly across the various PPLs, while some discrepancies occurred for ADVI. (Full notebooks with visuals are provided here.) Nevertheless, I have included some information about the data used and the inferences from ADVI in Turing here. (The inferences from HMC and NUTS were not vastly different.)

Here is a histogram of the data used: $y_i$, for $i=1,\dots,200$, are univariate draws simulated from a mixture of four Normal distributions with varying locations and scales. The specific dataset can be found here, and the scripts for generating the data can be found here.

simdata-n200
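For reference, data of this form can be simulated with a few lines of NumPy. The weights, locations, and scales below are made up for illustration and are not the simulation truths used in the linked scripts:

import numpy as np

rng = np.random.default_rng(42)
N = 200
w_true = np.array([0.1, 0.2, 0.3, 0.4])       # hypothetical mixture weights
mu_true = np.array([-4.0, -1.0, 2.0, 5.0])    # hypothetical component means
sigma_true = np.array([0.6, 0.3, 0.5, 0.4])   # hypothetical component scales

z = rng.choice(len(w_true), size=N, p=w_true)  # latent component labels
y = rng.normal(mu_true[z], sigma_true[z])      # draws from a 4-component Normal mixture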

Here we have the posterior distributions of $\eta$ (the mixture weights, aka $\bm{w}$ above), $\mu$ (the mixture locations), $\sigma$ (the mixture scales), and $\alpha$ (the concentration parameter).

advi-n200-posterior

Note that the whiskers in the box plots are the ends of the 95% credible intervals. The triangles and solid horizontal lines in the box plots are the posterior means and medians, respectively. The dashed lines are the simulation truths, which match up closely with the posterior means and fall within the 95% credible intervals. Note that for components 5-10, $\eta$ is near 0, due to the sparsity-inducing prior placed on it. In addition, $\mu$ and $\sigma$ for components 5-10 are simply samples from the prior, as the data carry no information about those (unused) components. The histogram for $\alpha$ is provided for reference. Since there is no simulation truth for $\alpha$, not much more can be said about it; its concentration in the range (0, 1) is consistent with the small number of occupied mixture components and the prior. We can see that this model effectively learned that there are only 4 clusters in this data.

Comparing the PPLs

First off, these are the settings used for ADVI, HMC, and NUTS:

  • ADVI:
    • The optimizers for ADVI were run for 2000 iterations.
    • Full ADVI was done (i.e. no sub-sampling of the data as in stochastic ADVI).
    • STAN defaults to using 100 samples for the ELBO approximation and 10 samples for the ELBO gradient approximation. When I set these to 1, as in Turing (and as recommended in the ADVI paper), the inferences were quite bad, so I left them at the defaults. Nevertheless, STAN still ran the fastest.
    • NOTE: ADVI is extremely sensitive to initial values. The best run from multiple runs with varying initial values was used.
    • NOTE: There are discrepancies in the optimizers used in each PPL. In Turing, TFP, Pyro, and Numpyro, the Adam optimizer was used; in STAN, RMSprop was used.
  • HMC
    • Step-size: 0.01
    • Number of leapfrog steps: 100
    • Number of burn-in iterations: 500
    • Number of subsequent HMC iterations: 500
      • This is also the number of posterior samples
    • NOTE: Relatively robust to initial values.
  • NUTS
    • Target acceptance rate: 80%
    • Number of iterations for burn-in: 500
      • NOTE: In TFP, the number of iterations for adapting NUTS hyper-parameters was set to 400 (this needs to be less than the burn-in, as it is part of the burn-in period).
    • Maximum tree-depth: 10
    • Number of subsequent NUTS iterations: 500
      • This is also the number of posterior samples
    • NOTE: Relatively robust to initial values.
  • Note: Automatic differentiation (AD) libraries used vary between the PPLs. The default AD libraries in each PPL were used. This may have a small effect on the timings.

Specifying the DP GMM via stick-breaking is quite simple in each of the PPLs. Each PPL also has strengths and weaknesses for this particular model. In particular, some PPLs had shorter compile times; others had shorter inference times. The inferences from HMC and NUTS appear similar across the PPLs. Pyro implements ADVI, but the inferences were quite poor compared to STAN, Turing, and TFP (several random initial values were tried). Numpyro was the fastest at HMC and NUTS. STAN was the fastest at ADVI, though it had the longest compile time. (Note that a STAN model only needs to be compiled once, as the compiled model can be cached.)

STAN, being the oldest of these PPLs, has the most comprehensive documentation, so implementing this particular model was quite easy.

TFP is quite verbose, but offers a lot of control to the user. I initially had some difficulty getting the dimensions (batch_shape, event_shape) correct. Dave Moore, who works on Google’s BayesFlow team, graciously looked at my code and offered fixes!

I would like to point out (while acknowledging my current affiliation with the Turing team) the elegance of Turing’s model-specification syntax. It is close to how one would naturally write down the model, and it is also the shortest. For those already familiar with Julia, custom functions (such as the stickbreak function used here) can be implemented by the user and used within the model specification. Functions from other libraries can also be used quite effortlessly; for example, the logsumexp and normlogpdf methods from the StatsFuns.jl library worked in Turing without any tweaking. Another strength of Turing is that one can sample from the full conditionals of individual parameters (or groups of parameters) using different inference algorithms. For example,

hmc_chain = sample(dp_gmm_sb(y, 10),  # data, number of mixture components
                   Gibbs(HMC(0.01, 100, :mu, :sig, :v),
                         MH(:alpha)),
                   1000)  # iterations

enables the sampling of $(\bm{\mu}, \bm{\sigma}, \bm{v})$ jointly, conditioned on the current value of $\alpha$, using HMC, and the sampling of $\alpha \mid \bm{\mu},\bm{\sigma},\bm{v}$ via a vanilla Metropolis-Hastings step. This is not possible in STAN, Pyro, Numpyro, or TFP. (It is also possible in NIMBLE; however, NIMBLE currently does not support inference algorithms based on AD.) Another possibility in Turing is:

hmc_chain = sample(dp_gmm_sb(y, 10),  # data, number of mixture components
                   Gibbs(HMC(0.01, 100, :mu, :sig),
                         HMC(0.01, 100, :v),
                         MH(:alpha)),
                   1000)  # iterations

where separate HMC samplers are used for $(\mu, \sigma)$ and $\bm{v}$.

For HMC, Turing’s inference times are within an order of magnitude of the fastest PPL (Numpyro). I think time and inference comparisons via HMC may be fairer than the NUTS and ADVI comparisons, as the implementation of HMC is comparatively more standardized and relatively simple, whereas the implementations of ADVI and NUTS are more nuanced and, for NUTS, quite complex. As already stated, the quality of the AD libraries will affect timings, and possibly the quality of the inferences, though the AD libraries used (Flux, autodiff, torch, tensorflow, JAX) all seem rather robust.

Timings

Here are the compile and inference times (in seconds) for each PPL for this model. Smaller is better.

PPL       ADVI (run)   HMC (run)   NUTS (run)   ADVI (compile)   HMC (compile)   NUTS (compile)
stan         2.3          23.3        99.0          52.3             52.3            52.3
turing       7.0         116.0       676.0           7.9             18.5            10.3
tfp          3.0          14.4        51.4           0.0              8.8            16.3
pyro         8.0         178.0      1106.0           0.0              0.0             0.0
numpyro      3.4           9.0        30.0           6.6              7.2             2.6

For STAN, the compile time listed is the time required to compile the model, i.e. the time required to run this command:

sm = pystan.StanModel(model_code=model)

The model is compiled only once (and not three times), then the user is free to select the inference algorithm.
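A common pattern (a sketch, assuming PyStan 2.x as in the example above; the pickle filename is arbitrary) is to cache the compiled model with pickle so subsequent runs skip compilation entirely:

import pickle
import pystan

try:
    # Reuse a previously compiled model if one is cached on disk.
    with open('dp_sb_gmm_model.pkl', 'rb') as f:
        sm = pickle.load(f)
except FileNotFoundError:
    # Compile once and cache for future runs.
    sm = pystan.StanModel(model_code=model)
    with open('dp_sb_gmm_model.pkl', 'wb') as f:
        pickle.dump(sm, f)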

The timings provided are from a single run, but they do not vary much from run to run and do not affect the ranking of the PPLs in terms of speed for each inference algorithm.

Note that for the Turing model, noticeable (40-60%) speedups can be achieved by replacing the UnivariateGMM likelihood with a direct increment of the log unnormalized joint posterior density. See this implementation and the resulting timings.

The Manifest.toml and requirements.txt files in the GitHub project page list, respectively, the specific Julia and Python libraries used for these runs. All experiments for this project were run on a c5.xlarge AWS Spot Instance. As of this writing, here are the specs for this instance:

  • vCPU: 4 Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz
  • RAM: 8 GB
  • Storage: EBS only (32 GB)
  • Network Bandwidth: Up to 10 Gbps
  • EBS Bandwidth: Up to 4750 Mbps

Next

In my next post, I will do a similar comparison of the PPLs and inference algorithms mentioned here for a basic Gaussian process model.

Feel free to comment below.