Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
V
VMFMixture
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Kaitlin Trenfield
VMFMixture
Commits
aa7dfa0e
Commit
aa7dfa0e
authored
2 years ago
by
Kaitlin Trenfield
Browse files
Options
Downloads
Patches
Plain Diff
changed random to not check types
parent
2122d13f
Branches
main
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
VMFMixture/core.py
+31
-17
31 additions, 17 deletions
VMFMixture/core.py
VMFMixture/random.py
+1
-9
1 addition, 9 deletions
VMFMixture/random.py
with
32 additions
and
26 deletions
VMFMixture/core.py
+
31
−
17
View file @
aa7dfa0e
...
...
@@ -11,9 +11,10 @@ import warnings
import
jax.numpy
as
jnp
from
jax
import
jit
,
vmap
from
jax.scipy.special
import
logsumexp
,
i0
from
jax.scipy.special
import
logsumexp
,
i0
,
xlogy
from
jax.nn
import
softmax
from
jax.scipy.optimize
import
minimize
as
jminimize
#from scipy.optimize import minimize as spminimize
from
.
import
random
...
...
@@ -109,6 +110,7 @@ class VMFMixture:
else
:
self
.
converged
=
False
self
.
log_likelihood
=
self
.
log_likelihood_trace
[
-
1
]
return
self
.
converged
def
predict_proba
(
self
,
X
):
"""
...
...
@@ -221,7 +223,7 @@ def von_mises_component_log_prob(x, mu, kappa, w):
return
von_mises_log_pdf
(
x
,
mu
,
kappa
)
+
jnp
.
log
(
w
)
def
fit_vmfmm
(
theta
,
n_components
,
atol
=
1e-3
,
maxiter
=
100
,
mu
=
None
,
kappa
=
None
,
w
=
None
,
max_reseed
=
5
,
min_w
=
1e-2
):
mu
=
None
,
kappa
=
None
,
w
=
None
,
max_reseed
=
0
,
min_w
=
1e-2
):
"""
Parameters
----------
...
...
@@ -261,18 +263,21 @@ def fit_vmfmm(theta, n_components, atol=1e-3, maxiter=100,
N
=
theta
.
shape
[
-
1
]
# dimension of the observations
K
=
n_components
### initialize parameters with random guesses
if
w
is
None
:
def
init_params
(
mu
=
None
,
kappa
=
None
,
w
=
None
):
if
w
is
None
:
#w = jnp.array(np.random.uniform(size=n_components))
w
=
jnp
.
ones
(
K
)
w
/=
w
.
sum
()
if
mu
is
None
:
w
=
jnp
.
ones
(
K
)
w
/=
w
.
sum
()
if
mu
is
None
:
# initialize cluster centers at randomly chosen entries of theta
#_ix = np.random.randint(M, size=K)
#mu = theta[_ix,:]
mu
=
jnp
.
array
(
np
.
random
.
uniform
(
low
=-
np
.
pi
,
high
=
np
.
pi
,
size
=
(
K
,
N
)))
if
kappa
is
None
:
kappa
=
K
*
jnp
.
ones
((
K
,
N
))
#mu = theta[_ix,:]
mu
=
jnp
.
array
(
np
.
random
.
uniform
(
low
=-
np
.
pi
,
high
=
np
.
pi
,
size
=
(
K
,
N
)))
if
kappa
is
None
:
kappa
=
K
*
jnp
.
ones
((
K
,
N
))
return
mu
,
kappa
,
w
mu
,
kappa
,
w
=
init_params
(
mu
,
kappa
,
w
)
converged
=
False
n_iter
=
0
# get log likelihood of initial parameter guess
...
...
@@ -298,11 +303,20 @@ def fit_vmfmm(theta, n_components, atol=1e-3, maxiter=100,
r
=
_e_step
(
theta
,
mu
,
kappa
,
w
)
n_reseed
+=
1
mu
,
kappa
,
w
=
_m_step
(
theta
,
mu
,
kappa
,
r
)
mu_new
,
kappa_new
,
w_new
=
_m_step
(
theta
,
mu
,
kappa
,
r
)
ll
=
mix_ll
(
theta
,
mu_new
,
kappa_new
,
w_new
)
#jnp.sum(mix_ll(theta, mu, kappa, w))
ll
=
mix_ll
(
theta
,
mu
,
kappa
,
w
)
#jnp.sum(mix_ll(theta, mu, kappa, w))
log_likelihoods
.
append
(
ll
)
n_iter
+=
1
if
not
jnp
.
isnan
(
ll
):
log_likelihoods
.
append
(
ll
)
mu
,
kappa
,
w
=
mu_new
,
kappa_new
,
w_new
n_iter
+=
1
else
:
# try the next step with regularized concentration parameters instead
mu
,
kappa
,
w
=
_m_step
(
theta
,
mu
,
kappa
,
w
,
reg_kappa
=
1.0
/
len
(
kappa
))
ll
=
mix_ll
(
theta
,
mu_new
,
kappa_new
,
w_new
)
if
np
.
abs
(
log_likelihoods
[
-
2
]
-
log_likelihoods
[
-
1
])
<
atol
:
converged
=
True
...
...
@@ -363,7 +377,7 @@ def _e_step(theta, mu, kappa, w):
return
r
@jit
def
_m_step
(
theta
,
mu
,
kappa
,
r
,
gtol
=
1e-5
,
maxiter
=
None
):
def
_m_step
(
theta
,
mu
,
kappa
,
r
,
gtol
=
1e-5
,
maxiter
=
None
,
reg_kappa
=
0.0
):
"""
Parameters
----------
...
...
@@ -393,7 +407,7 @@ def _m_step(theta, mu, kappa, r, gtol=1e-5, maxiter=None):
_p
=
p
.
reshape
(
params
.
shape
)
mus
=
_p
[:,:,
0
]
kappas
=
jnp
.
minimum
(
jnp
.
abs
(
_p
[:,:,
1
]),
1000.
)
return
-
expected_log_likelihood
(
theta
,
mus
,
kappas
,
r
,
w
)
/
M
return
-
expected_log_likelihood
(
theta
,
mus
,
kappas
,
r
,
w
)
/
M
+
reg_kappa
*
reg_kappa
*
jnp
.
sum
(
kappas
)
**
2
x0
=
params
.
flatten
()
res
=
jminimize
(
negative_ell
,
...
...
@@ -411,7 +425,7 @@ def expected_log_likelihood(theta, mu, kappa, r, w):
lambda
mu
,
kappa
:
von_mises_log_pdf
(
theta
,
mu
,
kappa
),
in_axes
=
(
0
,
0
),
out_axes
=
1
)
return
jnp
.
sum
(
r
*
component_log_likelihoods
(
mu
,
kappa
))
+
w
.
dot
(
jnp
.
log
(
w
+
1e-12
))
return
jnp
.
sum
(
r
*
component_log_likelihoods
(
mu
,
kappa
))
+
jnp
.
sum
(
x
log
y
(
w
,
w
))
def
rand_vmf_mixture
(
N
,
mu
,
kappa
,
w
):
"""
...
...
This diff is collapsed.
Click to expand it.
VMFMixture/random.py
+
1
−
9
View file @
aa7dfa0e
...
...
@@ -131,14 +131,6 @@ def rand_von_mises(mu, kappa, shape=None):
N
,
M
=
shape
# mu should be a real scalar. It can wrap around the circle, so it can be negative, positive and also
# outside the range [0,2*pi].
if
(
type
(
mu
)
not
in
{
float
,
int
,
np
.
int32
,
np
.
int64
,
np
.
float32
,
np
.
float64
}):
raise
TypeError
(
"
mu must be a real scalar number.
"
)
# kappa should be positive real scalar
if
(
type
(
kappa
)
not
in
{
float
,
int
,
np
.
int32
,
np
.
int64
,
np
.
float32
,
np
.
float64
}):
raise
TypeError
(
"
kappa must be a positive float.
"
)
if
kappa
<
0
:
raise
Exception
(
"
kappa must be a positive float.
"
)
# SPECIAL CASE
...
...
@@ -277,4 +269,4 @@ def rand_von_mises_jittable(rng_key, mu, kappa, shape):
theta
=
theta
-
2
*
jnp
.
pi
*
(
theta
>
np
.
pi
)
theta
=
theta
.
reshape
(
shape
)
return
theta
\ No newline at end of file
return
theta
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment