Commit aa7dfa0e authored by Kaitlin Trenfield

changed random to not check types

parent 2122d13f
Branches main
@@ -11,9 +11,10 @@ import warnings
 import jax.numpy as jnp
 from jax import jit, vmap
-from jax.scipy.special import logsumexp, i0
+from jax.scipy.special import logsumexp, i0, xlogy
 from jax.nn import softmax
 from jax.scipy.optimize import minimize as jminimize
 #from scipy.optimize import minimize as spminimize
 from . import random
@@ -109,6 +110,7 @@ class VMFMixture:
         else:
             self.converged = False
         self.log_likelihood = self.log_likelihood_trace[-1]
+        return self.converged

     def predict_proba(self, X):
         """
@@ -221,7 +223,7 @@ def von_mises_component_log_prob(x, mu, kappa, w):
     return von_mises_log_pdf(x, mu, kappa) + jnp.log(w)

 def fit_vmfmm(theta, n_components, atol=1e-3, maxiter=100,
-              mu=None, kappa=None, w=None, max_reseed=5, min_w=1e-2):
+              mu=None, kappa=None, w=None, max_reseed=0, min_w=1e-2):
     """
     Parameters
     ----------
@@ -261,18 +263,21 @@ def fit_vmfmm(theta, n_components, atol=1e-3, maxiter=100,
     N = theta.shape[-1] # dimension of the observations
     K = n_components
     ### initialize parameters with random guesses
-    if w is None:
-        #w = jnp.array(np.random.uniform(size=n_components))
-        w = jnp.ones(K)
-        w /= w.sum()
-    if mu is None:
-        # initialize cluster centers at randomly chosen entries of theta
-        #_ix = np.random.randint(M, size=K)
-        #mu = theta[_ix,:]
-        mu = jnp.array(np.random.uniform(low=-np.pi, high=np.pi, size=(K, N)))
-    if kappa is None:
-        kappa = K*jnp.ones((K,N))
+    def init_params(mu=None, kappa=None, w=None):
+        if w is None:
+            #w = jnp.array(np.random.uniform(size=n_components))
+            w = jnp.ones(K)
+            w /= w.sum()
+        if mu is None:
+            # initialize cluster centers at randomly chosen entries of theta
+            #_ix = np.random.randint(M, size=K)
+            #mu = theta[_ix,:]
+            mu = jnp.array(np.random.uniform(low=-np.pi, high=np.pi, size=(K, N)))
+        if kappa is None:
+            kappa = K*jnp.ones((K,N))
+        return mu, kappa, w
+    mu, kappa, w = init_params(mu, kappa, w)
     converged = False
     n_iter = 0
     # get log likelihood of initial parameter guess
@@ -298,11 +303,20 @@
             r = _e_step(theta, mu, kappa, w)
             n_reseed += 1
-        mu, kappa, w = _m_step(theta, mu, kappa, r)
-        ll = mix_ll(theta, mu, kappa, w) #jnp.sum(mix_ll(theta, mu, kappa, w))
-        log_likelihoods.append(ll)
-        n_iter += 1
+        mu_new, kappa_new, w_new = _m_step(theta, mu, kappa, r)
+        ll = mix_ll(theta, mu_new, kappa_new, w_new) #jnp.sum(mix_ll(theta, mu, kappa, w))
+        if not jnp.isnan(ll):
+            log_likelihoods.append(ll)
+            mu, kappa, w = mu_new, kappa_new, w_new
+            n_iter += 1
+        else:
+            # try the next step with regularized concentration parameters instead
+            mu, kappa, w = _m_step(theta, mu, kappa, r, reg_kappa=1.0/len(kappa))
+            ll = mix_ll(theta, mu, kappa, w)
         if np.abs(log_likelihoods[-2] - log_likelihoods[-1]) < atol:
             converged = True
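
The M-step result is now staged in mu_new/kappa_new/w_new and only committed when the log likelihood comes back finite. The NaN guard matters because the von Mises normalizer involves i0(kappa), which overflows to inf in float32 well before the 1000.0 cap applied in _m_step, and a -inf log pdf multiplied by a zero responsibility yields NaN. A quick illustration of the failure mode, assuming JAX's default float32:

    import jax.numpy as jnp
    from jax.scipy.special import i0

    jnp.log(i0(jnp.float32(100.0)))         # inf: i0(100) ~ 1e42 overflows float32
    0.0 * -jnp.log(i0(jnp.float32(100.0)))  # nan: 0 * -inf
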
@@ -363,7 +377,7 @@ def _e_step(theta, mu, kappa, w):
     return r

 @jit
-def _m_step(theta, mu, kappa, r, gtol=1e-5, maxiter=None):
+def _m_step(theta, mu, kappa, r, gtol=1e-5, maxiter=None, reg_kappa=0.0):
     """
     Parameters
     ----------
@@ -393,7 +407,7 @@ def _m_step(theta, mu, kappa, r, gtol=1e-5, maxiter=None):
         _p = p.reshape(params.shape)
         mus = _p[:,:,0]
         kappas = jnp.minimum(jnp.abs(_p[:,:,1]), 1000.)
-        return -expected_log_likelihood(theta, mus, kappas, r, w)/M
+        return -expected_log_likelihood(theta, mus, kappas, r, w)/M + reg_kappa*jnp.sum(kappas)**2
     x0 = params.flatten()
     res = jminimize(negative_ell,
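
reg_kappa defaults to 0.0, which leaves the objective unchanged; a positive value adds a ridge-style penalty on the concentrations. Since _m_step is @jit-compiled, the scalar is traced like any other input, so varying the penalty strength does not force recompilation. A self-contained sketch of the penalized objective (it mirrors the changed return line but is not the module's code):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def penalized(nll, kappas, reg_kappa=0.0):
        # reg_kappa == 0.0 recovers the unpenalized objective exactly
        return nll + reg_kappa * jnp.sum(kappas)**2

    penalized(1.0, jnp.ones((3, 2)))        # -> 1.0
    penalized(1.0, jnp.ones((3, 2)), 0.5)   # -> 1.0 + 0.5 * 6**2 = 19.0
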
@@ -411,7 +425,7 @@ def expected_log_likelihood(theta, mu, kappa, r, w):
         lambda mu, kappa: von_mises_log_pdf(theta, mu, kappa),
         in_axes=(0,0), out_axes=1
     )
-    return jnp.sum(r * component_log_likelihoods(mu, kappa)) + w.dot(jnp.log(w + 1e-12))
+    return jnp.sum(r * component_log_likelihoods(mu, kappa)) + jnp.sum(xlogy(w, w))

 def rand_vmf_mixture(N, mu, kappa, w):
     """
     ...
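
xlogy(w, w) evaluates w*log(w) with the convention 0*log(0) = 0, so the entropy-like term no longer needs the 1e-12 fudge factor and stays exact for zero-weight components, where the naive expression produces NaN. For example:

    import jax.numpy as jnp
    from jax.scipy.special import xlogy

    w = jnp.array([0.5, 0.5, 0.0])
    jnp.sum(w * jnp.log(w))   # nan: the zero entry gives 0 * -inf
    jnp.sum(xlogy(w, w))      # -0.6931...: the zero entry contributes 0
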
@@ -131,14 +131,6 @@ def rand_von_mises(mu, kappa, shape=None):
     N, M = shape
     # mu should be a real scalar. It can wrap around the circle, so it can be negative, positive and also
     # outside the range [0,2*pi].
-    if (type(mu) not in {float, int, np.int32, np.int64, np.float32, np.float64}):
-        raise TypeError("mu must be a real scalar number.")
-    # kappa should be positive real scalar
-    if (type(kappa) not in {float, int, np.int32, np.int64, np.float32, np.float64}):
-        raise TypeError("kappa must be a positive float.")
-    if kappa < 0:
-        raise Exception("kappa must be a positive float.")
     # SPECIAL CASE
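
Dropping the type whitelist matches the commit message: the checks compared type(mu) against a fixed set of Python/NumPy scalar types, so 0-d arrays and JAX tracers (e.g. when mu comes from jnp computations) were rejected even though they behave like scalars. If validation is still wanted, a tracer-friendly approach is to coerce first and keep only the value check, called outside of jit. A hypothetical sketch, not part of this commit:

    import jax.numpy as jnp

    def _validate_von_mises_params(mu, kappa):
        # coerce rather than whitelist: Python scalars, NumPy scalars,
        # and 0-d arrays all pass through jnp.asarray
        mu = jnp.asarray(mu, dtype=float)
        kappa = jnp.asarray(kappa, dtype=float)
        if kappa < 0:  # concrete-value check; would fail on traced inputs
            raise ValueError("kappa must be non-negative.")
        return mu, kappa
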
@@ -277,4 +269,4 @@ def rand_von_mises_jittable(rng_key, mu, kappa, shape):
     theta = theta - 2*jnp.pi*(theta > np.pi)
     theta = theta.reshape(shape)
-    return theta
-\ No newline at end of file
+    return theta