import functools
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import config
from skillmodels.constraints import add_bounds
from skillmodels.constraints import get_constraints
from skillmodels.kalman_filters import calculate_sigma_scaling_factor_and_weights
from skillmodels.kalman_filters import kalman_predict
from skillmodels.kalman_filters import kalman_update
from skillmodels.params_index import get_params_index
from skillmodels.parse_params import create_parsing_info
from skillmodels.parse_params import parse_params
from skillmodels.process_data import process_data_for_estimation
from skillmodels.process_debug_data import process_debug_data
from skillmodels.process_model import get_period_measurements
from skillmodels.process_model import process_model
config.update("jax_enable_x64", True)
def _log_likelihood_jax(
params,
parsing_info,
update_info,
measurements,
controls,
transition_functions,
sigma_scaling_factor,
sigma_weights,
dimensions,
labels,
estimation_options,
not_missing,
debug,
):
"""Log likelihood of a skill formation model.
This function is jax-differentiable and jax-jittable as long as all but the first
argument are marked as static.
The function returns both a tuple (float, dict). The first entry is the aggregated
log likelihood value. The second additional information like the log likelihood
contribution of each individual. Note that the dict also contains the aggregated
value. Returning that value separately is only needed to calculate a gradient
with Jax.
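
    A minimal sketch of how this is typically consumed (assuming the non-array
    arguments have already been bound, e.g. via ``functools.partial``;
    ``static_kwargs`` is a hypothetical dict holding those arguments)::

        partialed = functools.partial(_log_likelihood_jax, **static_kwargs)
        jitted_loglike = jax.jit(partialed)
        value, info = jitted_loglike(params)
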
Args:
params (jax.numpy.array): 1d array with model parameters.
        parsing_info (dict): Contains information on how to parse the parameter
            vector.
        update_info (pandas.DataFrame): Contains information about the number of
            updates in each period and the purpose of each update.
measurements (jax.numpy.array): Array of shape (n_updates, n_obs) with data on
observed measurements. NaN if the measurement was not observed.
controls (jax.numpy.array): Array of shape (n_periods, n_obs, n_controls)
with observed control variables for the measurement equations.
        transition_functions (tuple): Tuple of tuples where the first element is the
            name of the transition function and the second is the actual transition
            function. The order is important and corresponds to the latent
            factors in alphabetical order.
sigma_scaling_factor (float): A scaling factor that controls the spread of the
sigma points. Bigger means that sigma points are further apart. Depends on
the sigma_point algorithm chosen.
sigma_weights (jax.numpy.array): 1d array of length n_sigma with non-negative
sigma weights.
dimensions (dict): Dimensional information like n_states, n_periods, n_controls,
n_mixtures. See :ref:`dimensions`.
        labels (dict): Dict of lists with labels for the model quantities like
            factors, periods, controls, stagemap and stages. See :ref:`labels`.
        estimation_options (dict): Numerical options for the estimation, in
            particular the bounds and hardness values used to soft-clip the log
            likelihood contributions.
        not_missing (jax.numpy.array): Array with same shape as measurements that is
            True where measurements are not missing.
        debug (bool): Boolean flag. If True, more intermediate results are returned.

    Returns:
        jnp.array: Scalar jax array with the aggregated log likelihood value.
        dict: Additional data, containing the log likelihood contribution of each
            Kalman update and, if ``debug`` is ``True``, additional information like
            the filtered states.

    """
n_obs = measurements.shape[1]
states, upper_chols, log_mixture_weights, pardict = parse_params(
params, parsing_info, dimensions, labels, n_obs
)
n_updates = len(update_info)
loglikes = jnp.zeros((n_updates, n_obs))
debug_infos = []
states_history = []
k = 0
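    # k indexes the Kalman updates across all periods; update_info has one row per
    # update, in the same order as the rows of measurements.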
for t in labels["periods"]:
nmeas = len(get_period_measurements(update_info, t))
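        # Update step: incorporate each measurement of period t into the mixture of
        # normals that represents the state distribution.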
for _j in range(nmeas):
purpose = update_info.iloc[k]["purpose"]
new_states, new_upper_chols, new_weights, loglikes_k, info = kalman_update(
states,
upper_chols,
pardict["loadings"][k],
pardict["controls"][k],
pardict["meas_sds"][k],
measurements[k],
controls[t],
log_mixture_weights,
not_missing[k],
debug,
)
if debug:
states_history.append(new_states)
            loglikes = loglikes.at[k].set(loglikes_k)
log_mixture_weights = new_weights
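            # Only updates whose purpose is "measurement" change the filtered states;
            # other updates (e.g. for anchoring) contribute to the likelihood but
            # leave the states and covariances untouched.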
if purpose == "measurement":
states, upper_chols = new_states, new_upper_chols
debug_infos.append(info)
k += 1
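        # Predict step: propagate states and covariances to the next period via the
        # transition functions; no prediction is needed after the last period.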
if t != labels["periods"][-1]:
states, upper_chols = kalman_predict(
states,
upper_chols,
sigma_scaling_factor,
sigma_weights,
transition_functions,
pardict["transition"][t],
pardict["shock_sds"][t],
pardict["anchoring_scaling_factors"][t : t + 2],
pardict["anchoring_constants"][t : t + 2],
)
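    # Soft-clip the per-update log likelihood contributions so that extreme values
    # cannot dominate the objective or produce non-finite results, while keeping the
    # function smooth (see soft_clipping below).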
clipped = soft_clipping(
arr=loglikes,
lower=estimation_options["clipping_lower_bound"],
upper=estimation_options["clipping_upper_bound"],
lower_hardness=estimation_options["clipping_lower_hardness"],
upper_hardness=estimation_options["clipping_upper_hardness"],
)
value = clipped.sum()
additional_data = {
# used for scalar optimization, thus has to be clipped
"value": value,
        # can be used by optimizers that exploit the sum structure, thus has to be clipped
"contributions": clipped.sum(axis=0),
}
if debug:
additional_data["all_contributions"] = loglikes
additional_data["residuals"] = [info["residuals"] for info in debug_infos]
additional_data["residual_sds"] = [info["residual_sds"] for info in debug_infos]
initial_states, *_ = parse_params(
params, parsing_info, dimensions, labels, n_obs
)
additional_data["initial_states"] = initial_states
additional_data["filtered_states"] = states_history
return value, additional_data
def soft_clipping(arr, lower=None, upper=None, lower_hardness=1, upper_hardness=1):
"""Clip values in an array elementwise using a soft maximum to avoid kinks.
Clipping from below is taking a maximum between two values. Clipping
from above is taking a minimum, but it can be rewritten as taking a maximum after
switching the signs.
To smooth out the kinks introduced by normal clipping, we first rewrite all clipping
operations to taking maxima. Then we replace the normal maximum by the soft maximum.
For background on the soft maximum check out this
`article by John Cook: <https://www.johndcook.com/soft_maximum.pdf>`_
Note that contrary to the name, the soft maximum can be calculated using
``scipy.special.logsumexp``. ``scipy.special.softmax`` is the gradient of
``scipy.special.logsumexp``.
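
    Concretely, clipping ``x`` from below at ``lower`` with hardness ``h`` computes
    ``logsumexp([h * x, h * lower]) / h``, which is always at least ``max(x, lower)``
    and approaches it as ``h`` grows. Clipping from above applies the same operation
    to the negated values.
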
Args:
arr (jax.numpy.array): Array that is clipped elementwise.
lower (float): The value at which the array is clipped from below.
upper (float): The value at which the array is clipped from above.
lower_hardness (float): Scaling factor that is applied inside the soft maximum.
High values imply a closer approximation of the real maximum.
upper_hardness (float): Scaling factor that is applied inside the soft maximum.
High values imply a closer approximation of the real maximum.
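
    Returns:
        jax.numpy.array: Array with the same shape as ``arr`` containing the
            elementwise soft-clipped values.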
"""
shape = arr.shape
flat = arr.flatten()
dim = len(flat)
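    # Clip from below: soft maximum of each entry and the lower bound.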
if lower is not None:
helper = jnp.column_stack([flat, jnp.full(dim, lower)])
flat = (
jax.scipy.special.logsumexp(lower_hardness * helper, axis=1)
/ lower_hardness
)
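    # Clip from above: negate, take the soft maximum with -upper, and negate back.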
if upper is not None:
helper = jnp.column_stack([-flat, jnp.full(dim, -upper)])
flat = (
-jax.scipy.special.logsumexp(upper_hardness * helper, axis=1)
/ upper_hardness
)
return flat.reshape(shape)
def _to_numpy(obj):
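    """Convert an object (dict of arrays or scalars, scalar, or array) to numpy."""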
if isinstance(obj, dict):
res = {}
for key, value in obj.items():
if np.isscalar(value):
res[key] = value
else:
res[key] = np.array(value)
elif np.isscalar(obj):
res = obj
else:
res = np.array(obj)
return res