Modules Related to Estimation
The Likelihood Function
- get_maximization_inputs(model_dict, data)
Create inputs for estimagic’s maximize function.
- Parameters
model_dict (dict) – The model specification. See: Model specifications
data (DataFrame) – dataset in long format.
- Returns a dictionary with keys:
- loglike (function): A jax-jitted function that takes an estimagic-style params DataFrame as its only input and returns a dict with the entries:
  - “value”: The scalar log likelihood.
  - “contributions”: An array with the log likelihood contribution of each observation.
- debug_loglike (function): Similar to loglike, with the following differences:
  - It is not jitted and is therefore faster on the first call and debuggable.
  - It adds intermediate results as additional entries in the returned dictionary. These can be used for debugging and plotting.
- gradient (function): The gradient of the scalar log likelihood function with respect to the parameters.
- loglike_and_gradient (function): Combination of loglike and gradient that is faster than calling the two functions separately.
- constraints (list): List of estimagic constraints that are implied by the model specification.
- params_template (pd.DataFrame): Parameter DataFrame with the correct index and bounds but an empty value column.
- soft_clipping(arr, lower=None, upper=None, lower_hardness=1, upper_hardness=1)
Clip values in an array elementwise using a soft maximum to avoid kinks.
Clipping from below amounts to taking the maximum of two values. Clipping from above amounts to taking a minimum, which can be rewritten as a maximum after switching the signs, since min(a, b) = -max(-a, -b).
To smooth out the kinks introduced by normal clipping, we first rewrite all clipping operations to taking maxima. Then we replace the normal maximum by the soft maximum.
For background on the soft maximum check out this article by John Cook:
Note that, contrary to what the name suggests, the soft maximum is calculated with scipy.special.logsumexp, not with scipy.special.softmax; scipy.special.softmax is the gradient of scipy.special.logsumexp.
- Parameters
arr (jax.numpy.array) – Array that is clipped elementwise.
lower (float) – The value at which the array is clipped from below.
upper (float) – The value at which the array is clipped from above.
lower_hardness (float) – Scaling factor that is applied inside the soft maximum. High values imply a closer approximation of the real maximum.
upper_hardness (float) – Scaling factor that is applied inside the soft maximum. High values imply a closer approximation of the real maximum.
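A minimal NumPy sketch of the soft-clipping idea follows. It assumes the soft maximum of two values a and b with hardness h is log(exp(h*a) + exp(h*b)) / h. The name soft_clipping_sketch is hypothetical, and this is not the library's implementation, which works on jax.numpy arrays. np.logaddexp evaluates log(exp(x) + exp(y)) stably, which is the two-argument case of logsumexp.

```python
import numpy as np


def soft_clipping_sketch(arr, lower=None, upper=None, lower_hardness=1.0, upper_hardness=1.0):
    """Hypothetical sketch: clip ``arr`` elementwise using soft maxima."""
    result = np.asarray(arr, dtype=float)
    if lower is not None:
        # Soft version of max(result, lower): log(exp(h*x) + exp(h*lower)) / h.
        h = lower_hardness
        result = np.logaddexp(h * result, h * lower) / h
    if upper is not None:
        # min(result, upper) == -max(-result, -upper), so soften that maximum.
        h = upper_hardness
        result = -np.logaddexp(-h * result, -h * upper) / h
    return result
```

The soft maximum is always at least as large as the hard maximum. Soft-clipped values therefore lie slightly above lower near the lower bound and slightly below upper near the upper bound. A larger hardness shrinks that gap but makes the transition sharper.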
The Kalman Filters
- kalman_update(states, upper_chols, loadings, control_params, meas_sd, measurements, controls, log_mixture_weights, not_missing, debug)
Perform a Kalman update with likelihood evaluation.
- Parameters
states (jax.numpy.array) – Array of shape (n_obs, n_mixtures, n_states) with pre-update state estimates.
upper_chols (jax.numpy.array) – Array of shape (n_obs, n_mixtures, n_states, n_states) with the transpose of the lower triangular Cholesky factor of the pre-update covariance matrix of the state estimates.
loadings (jax.numpy.array) – 1d array of length n_states with factor loadings.
control_params (jax.numpy.array) – 1d array of length n_controls.
meas_sd (float) – Standard deviation of the measurement error.
measurements (jax.numpy.array) – 1d array of length n_obs with measurements. May contain NaNs if no measurement was observed.
controls (jax.numpy.array) – Array of shape (n_obs, n_controls) with data on the control variables.
log_mixture_weights (jax.numpy.array) – Array of shape (n_obs, n_mixtures) with the natural logarithm of the weights of each element of the mixture of normals distribution.
not_missing (jax.numpy.array) – Boolean 1d array of length n_obs that indicates whether a measurement is not missing. This could be calculated on the fly, but doing so generates a jax error on GPUs.
debug (bool) – If True, the debug_info contains the residuals of the update and their standard deviations. Otherwise, it is an empty dict.
- Returns
new_states (jax.numpy.array) – Same format as states.
new_upper_chols (jax.numpy.array) – Same format as upper_chols.
new_log_mixture_weights (jax.numpy.array) – Same format as log_mixture_weights.
new_loglikes (jax.numpy.array) – 1d array of length n_obs.
debug_info (dict) – Empty, or containing residuals and residual_sds.
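The NumPy sketch below shows one way to do a square-root Kalman update for a single observation and a single mixture element. It assumes the measurement equation measurement = controls @ control_params + state @ loadings + error, with error ~ N(0, meas_sd**2).

The sketch works directly on the upper Cholesky factor U, where the covariance is P = U.T @ U. It takes the R factor of a QR decomposition of a stacked matrix, so the updated covariance is never formed explicitly. The function name and argument order are hypothetical. The documented function additionally vectorizes over observations and mixture elements, handles missing measurements, and returns updated mixture weights.

```python
import numpy as np


def sqrt_kalman_update_sketch(state, upper_chol, loadings, control_params, meas_sd, measurement, controls):
    """Hypothetical single-observation square-root Kalman update.

    ``upper_chol`` is an upper triangular U with state covariance P = U.T @ U.
    """
    n_states = state.shape[0]
    # Stack [[meas_sd, 0], [U @ loadings, U]]. The R factor of its QR
    # decomposition contains the innovation sd, the gain and the new factor.
    m = np.zeros((n_states + 1, n_states + 1))
    m[0, 0] = meas_sd
    m[1:, 0] = upper_chol @ loadings
    m[1:, 1:] = upper_chol
    r = np.linalg.qr(m, mode="r")

    root_sigma = r[0, 0]  # +/- square root of the innovation variance
    kalman_gain = r[0, 1:] / root_sigma  # equals P @ loadings / innovation variance
    residual = measurement - controls @ control_params - state @ loadings

    new_state = state + kalman_gain * residual
    # new_upper_chol.T @ new_upper_chol is the updated covariance matrix.
    new_upper_chol = r[1:, 1:]

    # Log density of the residual under N(0, root_sigma**2).
    sigma = abs(root_sigma)
    loglike = -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * (residual / sigma) ** 2
    return new_state, new_upper_chol, loglike
```

The method works because R.T @ R equals m.T @ m. The top-left entry of R is therefore the innovation standard deviation, and the top-right block divided by it gives the Kalman gain. The bottom-right block is a factor of P - P @ loadings @ loadings.T @ P / innovation_variance.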
- calculate_sigma_scaling_factor_and_weights(n_states, kappa=2)
Calculate the scaling factor and weights for sigma points according to Julier.
There are other sigma point algorithms, but many of them can have negative weights, which makes the unscented predict step more complicated.
- Parameters
n_states (int) – Number of states.
kappa (float) – Spreading factor of the sigma points.
- Returns
float – Scaling factor.
jax.numpy.array – Sigma weights of length 2 * n_states + 1.
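Julier's scheme places 2 * n_states + 1 sigma points: one at the mean, and the rest at the mean plus or minus a scaled column of the lower Cholesky factor. The minimal sketch below assumes the standard Julier choices:

- a scaling factor of sqrt(n_states + kappa),
- a center weight of kappa / (n_states + kappa),
- a weight of 1 / (2 * (n_states + kappa)) for every other point.

Both function names are hypothetical. For kappa >= 0 all weights are non-negative, and they always sum to one.

```python
import numpy as np


def julier_sigma_sketch(n_states, kappa=2.0):
    """Hypothetical sketch of Julier's sigma point scaling factor and weights."""
    scaling_factor = np.sqrt(n_states + kappa)
    n_sigma = 2 * n_states + 1
    # Every non-central point gets weight 1 / (2 * (n_states + kappa)).
    weights = np.full(n_sigma, 0.5 / (n_states + kappa))
    # The central point (the mean itself) gets weight kappa / (n_states + kappa).
    weights[0] = kappa / (n_states + kappa)
    return scaling_factor, weights


def sigma_points_sketch(state, upper_chol, scaling_factor):
    """Mean, then mean + and - the scaled rows of U (the columns of the lower factor U.T)."""
    offsets = scaling_factor * upper_chol
    return np.vstack([state, state + offsets, state - offsets])
```

With these weights, the weighted mean of the sigma points reproduces the state. Their weighted outer products of deviations reproduce U.T @ U.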
- kalman_predict(states, upper_chols, sigma_scaling_factor, sigma_weights, transition_functions, trans_coeffs, shock_sds, anchoring_scaling_factors, anchoring_constants)
Make an unscented Kalman predict step.
- Parameters
states (jax.numpy.array) – Array of shape (n_obs, n_mixtures, n_states) with pre-update state estimates.
upper_chols (jax.numpy.array) – Array of shape (n_obs, n_mixtures, n_states, n_states) with the transpose of the lower triangular Cholesky factor of the pre-update covariance matrix of the state estimates.
sigma_scaling_factor (float) – A scaling factor that controls the spread of the sigma points. Bigger means that sigma points are further apart. Depends on the sigma point algorithm chosen.
sigma_weights (jax.numpy.array) – 1d array of length n_sigma with non-negative sigma weights.
transition_functions (tuple) – Tuple of tuples, where the first element of each inner tuple is the name of the transition function and the second is the actual transition function. Order is important and corresponds to the latent factors in alphabetical order.
trans_coeffs (tuple) – Tuple of 1d jax.numpy.arrays with transition parameters.
shock_sds (jax.numpy.array) – Standard deviations of the shocks in the transition equations.
anchoring_scaling_factors (jax.numpy.array) – Array of shape (2, n_states) with the scaling factors for anchoring. The first row corresponds to the input period, the second to the output period (i.e. input period + 1).
anchoring_constants (jax.numpy.array) – Array of shape (2, n_states) with the constants for anchoring. The first row corresponds to the input period, the second to the output period (i.e. input period + 1).
- Returns
jax.numpy.array – Predicted states, same shape as states.
jax.numpy.array – Predicted upper_chols, same shape as upper_chols.
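The following NumPy sketch shows an unscented predict step for a single observation and a single mixture element. It rests on three assumptions:

- the transition is one hypothetical callable that maps a full state vector to the next period's state,
- the shocks are independent, additive, and normal with standard deviations shock_sds,
- anchoring is omitted.

Instead of forming the predicted covariance, the sketch stacks the square-root-weighted deviations on top of diag(shock_sds) and keeps the R factor of a QR decomposition. The function name is hypothetical.

```python
import numpy as np


def unscented_predict_sketch(state, upper_chol, sigma_scaling_factor, sigma_weights, transition, shock_sds):
    """Hypothetical single-observation unscented predict step (no anchoring)."""
    # Sigma points: the mean, then mean +/- the scaled rows of U (P = U.T @ U).
    offsets = sigma_scaling_factor * upper_chol
    sigma_points = np.vstack([state, state + offsets, state - offsets])

    # Push every sigma point through the transition function.
    transformed = np.array([transition(point) for point in sigma_points])

    # The weighted mean of the transformed points is the predicted state.
    predicted_state = sigma_weights @ transformed

    # Stack sqrt(weight) * deviations on top of diag(shock_sds). The R factor of
    # its QR decomposition satisfies R.T @ R = weighted covariance + diag(shock_sds**2).
    devs = transformed - predicted_state
    stacked = np.vstack([np.sqrt(sigma_weights)[:, None] * devs, np.diag(shock_sds)])
    predicted_upper_chol = np.linalg.qr(stacked, mode="r")
    return predicted_state, predicted_upper_chol
```

Taking square roots of the weights is only possible because they are non-negative. This is the reason given for calculate_sigma_scaling_factor_and_weights using Julier's scheme. For a linear transition, the sketch reproduces the exact predicted mean and covariance.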
The Index of the Parameter DataFrame
- get_params_index(update_info, labels, dimensions)
Generate index for the params_df for estimagic.
The index has four levels. The first is the parameter category. The second is the period in which the parameters are used. The third and fourth are additional descriptors that depend on the category. If the fourth level is not really needed, it contains an empty string.
- Parameters
update_info (pandas.DataFrame) – DataFrame with one row per Kalman update needed in the likelihood function. See update_info.
labels (dict) – Dict of lists with labels for the model quantities like factors, periods, controls, stagemap and stages. See labels.
dimensions (dict) – Dimensional information like n_states, n_periods, n_controls, n_mixtures. See dimensions.
- Returns
params_index (pd.MultiIndex)
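The sketch below illustrates the four-level convention by building a few index tuples of the form (category, period, name1, name2). Unused fourth levels are padded with an empty string. The category strings, the mixture_{i} labels, and the assumption that shocks only occur between consecutive periods are all hypothetical. They are not necessarily what get_params_index produces.

```python
def make_index_tuples_sketch(factors, periods, n_mixtures):
    """Hypothetical: assemble four-level tuples (category, period, name1, name2)."""
    tuples = []
    # Categories that only need three levels get "" as the fourth entry.
    tuples += [("mixture_weights", 0, f"mixture_{i}", "") for i in range(n_mixtures)]
    tuples += [("initial_states", 0, f"mixture_{i}", fac) for i in range(n_mixtures) for fac in factors]
    # Assumption: one shock per factor for every period except the last.
    tuples += [("shock_sds", p, fac, "") for p in periods[:-1] for fac in factors]
    return tuples
```

Such a list could be turned into an index with pandas.MultiIndex.from_tuples and the level names ["category", "period", "name1", "name2"].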
- get_control_params_index_tuples(controls, update_info)
Index tuples for control coeffs.
- Parameters
controls (list) – List of lists. There is one sublist per period which contains the names of the control variables in that period. Constant not included.
update_info (pandas.DataFrame) – DataFrame with one row per Kalman update needed in the likelihood function. See update_info.
- get_loadings_index_tuples(factors, update_info)
Index tuples for loadings.
- Parameters
factors (list) – The latent factors of the model.
update_info (pandas.DataFrame) – DataFrame with one row per Kalman update needed in the likelihood function. See update_info.
- Returns
ind_tups (list)
- get_meas_sds_index_tuples(update_info)
Index tuples for meas_sd.
- Parameters
update_info (pandas.DataFrame) – DataFrame with one row per Kalman update needed in the likelihood function. See update_info.
- Returns
ind_tups (list)
- get_shock_sds_index_tuples(periods, factors)
Index tuples for shock_sd.
- Parameters
periods (list) – The periods of the model.
factors (list) – The latent factors of the model.
- Returns
ind_tups (list)
- initial_mean_index_tuples(n_mixtures, factors)
Index tuples for initial_mean.
- Parameters
n_mixtures (int) – Number of elements in the mixture distribution of the factors.
factors (list) – The latent factors of the model.
- Returns
ind_tups (list)
- get_mixture_weights_index_tuples(n_mixtures)
Index tuples for mixture_weight.
- Parameters
n_mixtures (int) – Number of elements in the mixture distribution of the factors.
- Returns
ind_tups (list)
- get_initial_cholcovs_index_tuples(n_mixtures, factors)
Index tuples for initial_cov.
- Parameters
n_mixtures (int) – Number of elements in the mixture distribution of the factors.
factors (list) – The latent factors of the model.
- Returns
ind_tups (list)
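A Cholesky factor is triangular, so only n_factors * (n_factors + 1) / 2 of its entries per mixture element are free parameters. The minimal sketch below assumes one index tuple per lower-triangular entry, in row-major order. The category string "initial_cholcovs", the mixture_{i} labels, and the "row-col" naming of the fourth level are hypothetical.

```python
def initial_cholcov_tuples_sketch(n_mixtures, factors):
    """Hypothetical: one entry per lower-triangular element of each mixture's Cholesky factor."""
    tuples = []
    for emf in range(n_mixtures):
        for row, row_factor in enumerate(factors):
            # Only elements on or below the diagonal (col <= row) are free parameters.
            for col_factor in factors[: row + 1]:
                tuples.append(("initial_cholcovs", 0, f"mixture_{emf}", f"{row_factor}-{col_factor}"))
    return tuples
```

With three factors and two mixture elements this yields 2 * 6 = 12 tuples, instead of the 18 a full covariance matrix per element would need.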
- get_transition_index_tuples(factors, periods, transition_names)
Index tuples for transition equation coefficients.
- Parameters
factors (list) – The latent factors of the model.
periods (list) – The periods of the model.
transition_names (list) – Name of the transition equation of each factor.
included_factors (list) – the factors that appear on the right hand side of the transition equations of the latent factors.
- Returns
ind_tups (list)
Parsing the Parameter Vector
- create_parsing_info(params_index, update_info, labels, anchoring)
Create a dictionary with information on how the parameter vector has to be parsed.
- Parameters
params_index (pandas.MultiIndex) – It has the levels [“category”, “period”, “name1”, “name2”]
update_info (pandas.DataFrame) – DataFrame with one row per Kalman update needed in the likelihood function. See update_info.
labels (dict) – Dict of lists with labels for the model quantities like factors, periods, controls, stagemap and stages. See labels.
- Returns
dict – Dictionary that maps model quantities to positions or slices of the parameter vector.
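One way to map model quantities to slices is to record where each category starts and stops in the index. The minimal sketch below assumes that all tuples of a category are contiguous, as happens when the index is built category by category. The function name is hypothetical, and the real parsing info may contain more than category slices.

```python
def create_slices_sketch(index_tuples):
    """Hypothetical: map each category (first index level) to a slice of the params vector."""
    slices = {}
    for pos, tup in enumerate(index_tuples):
        category = tup[0]
        if category in slices:
            # Extend the existing slice; relies on categories being contiguous.
            slices[category] = slice(slices[category].start, pos + 1)
        else:
            slices[category] = slice(pos, pos + 1)
    return slices
```

Using slices instead of boolean masks keeps parsing cheap, because params[slices[category]] is a basic slicing operation.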
- parse_params(params, parsing_info, dimensions, labels, n_obs)
Parse params into the quantities that depend on it.
- Parameters
params (jax.numpy.array) – 1d array with model parameters.
parsing_info (dict) – Dictionary with information on how the parameters have to be parsed.
dimensions (dict) – Dimensional information like n_states, n_periods, n_controls, n_mixtures. See dimensions.
n_obs (int) – Number of observations.
- Returns
jax.numpy.array – Array of shape (n_obs, n_mixtures, n_states) with initial state estimates.
jax.numpy.array – Array of shape (n_obs, n_mixtures, n_states, n_states) with the transpose of the lower triangular Cholesky factors of the initial covariance matrices.
jax.numpy.array – Array of shape (n_obs, n_mixtures) with the log of the initial weight for each element in the finite mixture of normals.
dict – Dictionary with other parameters. It has the following keys:
  - “control_params”
  - “loadings”
  - “meas_sds”
  - “shock_sds”
  - “trans_params”
  - “anchoring_scaling_factors”
  - “anchoring_constants”
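The NumPy sketch below shows how the first three return values could be built from a 1d params vector. It rests on several assumptions:

- a plain dict of slices with the hypothetical keys "initial_states", "initial_cholcovs" and "mixture_weights" stands in for parsing_info,
- the Cholesky entries are stored row-major in the lower triangle,
- the mixture weights are stored directly rather than as logs.

The function name is hypothetical, and the sketch does not build the dict of other parameters.

```python
import numpy as np


def parse_initial_quantities_sketch(params, slices, n_mixtures, n_states, n_obs):
    """Hypothetical: turn slices of a 1d params vector into initial arrays."""
    # Initial means: (n_mixtures, n_states), repeated for every observation.
    means = params[slices["initial_states"]].reshape(n_mixtures, n_states)
    states = np.broadcast_to(means, (n_obs, n_mixtures, n_states)).copy()

    # Lower-triangular Cholesky entries (row-major) for each mixture element.
    chol_entries = params[slices["initial_cholcovs"]].reshape(n_mixtures, -1)
    lower = np.zeros((n_mixtures, n_states, n_states))
    rows, cols = np.tril_indices(n_states)
    lower[:, rows, cols] = chol_entries
    upper = np.swapaxes(lower, -1, -2)  # transpose of the lower factor
    upper_chols = np.broadcast_to(upper, (n_obs, n_mixtures, n_states, n_states)).copy()

    # Log of the mixture weights, repeated for every observation.
    weights = params[slices["mixture_weights"]]
    log_weights = np.broadcast_to(np.log(weights), (n_obs, n_mixtures)).copy()
    return states, upper_chols, log_weights
```

np.tril_indices returns the lower-triangular positions in row-major order, so the entries line up with tuples generated row by row.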