Source code for skillmodels.parse_params

import warnings

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax.ops import index
from jax.ops import index_update

def create_parsing_info(params_index, update_info, labels, anchoring):
    """Collect everything needed to split the flat parameter vector.

    Args:
        params_index (pandas.MultiIndex): It has the levels
            ["category", "period", "name1", "name2"].
        update_info (pandas.DataFrame): DataFrame with one row per Kalman update
            needed in the likelihood function. The factor columns and the
            "purpose" column are read here. See :ref:`update_info`.
        labels (dict): Dict of lists with labels for the model quantities like
            factors, periods, controls, stagemap and stages. Only "factors" and
            "periods" are read here. See :ref:`labels`.
        anchoring (dict): Anchoring options. Only the entry
            "ignore_constant_when_anchoring" is read and passed through.

    Returns:
        dict: Maps model quantities to positions or slices of the parameter
            vector. "transition" is a list with one dict per period that maps
            each factor to its selector. The remaining entries
            ("is_anchoring_loading", "is_anchored_factor", "is_anchoring_update",
            "ignore_constant_when_anchoring") describe the anchoring setup.

    """
    factors = labels["factors"]

    # Integer position of every parameter, addressable via the params index.
    positions = pd.Series(data=np.arange(len(params_index)), index=params_index)

    # Quantities that are selected by their category name alone.
    category_names = (
        "initial_states",
        "initial_cholcovs",
        "mixture_weights",
        "controls",
        "loadings",
        "meas_sds",
        "shock_sds",
    )
    parsing_info = {
        name: _get_positional_selector_from_loc(positions, name)
        for name in category_names
    }

    # Transition parameters are selected per period and per factor.
    parsing_info["transition"] = [
        {
            factor: _get_positional_selector_from_loc(
                positions, ("transition", period, factor)
            )
            for factor in factors
        }
        for period in labels["periods"]
    ]

    # Anchoring: a loading is an anchoring loading if it is flagged in the factor
    # columns of update_info and belongs to an update whose purpose is anchoring.
    loading_flags = update_info[factors].to_numpy()
    anchoring_rows = (update_info["purpose"] == "anchoring").to_numpy().reshape(-1, 1)
    parsing_info["is_anchoring_loading"] = jnp.array(loading_flags & anchoring_rows)

    anchoring_updates = update_info.query("purpose == 'anchoring'")
    parsing_info["is_anchored_factor"] = jnp.array(
        anchoring_updates[factors].any(axis=0)
    )

    parsing_info["is_anchoring_update"] = anchoring_rows.flatten()
    parsing_info["ignore_constant_when_anchoring"] = anchoring[
        "ignore_constant_when_anchoring"
    ]

    return parsing_info
def _get_positional_selector_from_loc(range_sr, loc):
    """Translate a label based selection into positions of the parameter vector.

    Args:
        range_sr (pandas.Series): Series whose values are the integer positions
            of the parameters and whose index is the params index.
        loc (str or tuple): Label or (partial) tuple of labels passed to
            ``range_sr.loc``.

    Returns:
        jax.numpy.array or slice: The selected positions, or the empty slice
            ``slice(0, 0)`` if ``loc`` is not present in the index.

    """
    with warnings.catch_warnings():
        # Partial selections on a MultiIndex that is not fully lexsorted trigger
        # a pandas performance warning that is irrelevant here.
        warnings.filterwarnings(
            "ignore", message="indexing past lexsort depth may impact performance."
        )
        # Only the lookup is guarded, so a KeyError can only mean that the label
        # is missing. Any other exception propagates unchanged.
        try:
            selected = range_sr.loc[loc]
        except KeyError:
            ilocs = slice(0, 0)
        else:
            ilocs = jnp.array(selected)

    return ilocs
def parse_params(params, parsing_info, dimensions, labels, n_obs):
    """Split the flat parameter vector into the quantities that depend on it.

    Args:
        params (jax.numpy.array): 1d array with model parameters.
        parsing_info (dict): Dictionary with information on how the parameters
            have to be parsed.
        dimensions (dict): Dimensional information like n_states, n_periods,
            n_controls, n_mixtures. See :ref:`dimensions`.
        labels (dict): Dict of lists with labels for the model quantities. It is
            passed on to the parsing of the transition parameters.
        n_obs (int): Number of observations.

    Returns:
        jax.numpy.array: Array of shape (n_obs, n_mixtures, n_states) with initial
            state estimates.
        jax.numpy.array: Array of shape (n_obs, n_mixtures, n_states, n_states) with
            the transpose of the lower triangular cholesky factors of the initial
            covariance matrices.
        jax.numpy.array: Array of shape (n_obs, n_mixtures) with the log of the
            initial weight for each element in the finite mixture of normals.
        dict: Dictionary with other parameters. It has the following keys:
            - "controls"
            - "loadings"
            - "meas_sds"
            - "shock_sds"
            - "transition"
            - "anchoring_scaling_factors"
            - "anchoring_constants"

    """
    initial_states = _get_initial_states(params, parsing_info, dimensions, n_obs)
    initial_upper_chols = _get_initial_upper_chols(
        params, parsing_info, dimensions, n_obs
    )
    initial_log_weights = _get_initial_log_mixture_weights(
        params, parsing_info, n_obs
    )

    control_params = _get_control_params(params, parsing_info, dimensions)
    loadings = _get_loadings(params, parsing_info, dimensions)
    meas_sds = _get_meas_sds(params, parsing_info)
    shock_sds = _get_shock_sds(params, parsing_info, dimensions)
    transition_params = _get_transition_params(
        params, parsing_info, dimensions, labels
    )

    # The anchoring quantities are derived from already parsed loadings and
    # controls, not read from the parameter vector directly.
    scaling_factors = _get_anchoring_scaling_factors(
        loadings, parsing_info, dimensions
    )
    anchoring_constants = _get_anchoring_constants(
        control_params, parsing_info, dimensions
    )

    other_params = {
        "controls": control_params,
        "loadings": loadings,
        "meas_sds": meas_sds,
        "shock_sds": shock_sds,
        "transition": transition_params,
        "anchoring_scaling_factors": scaling_factors,
        "anchoring_constants": anchoring_constants,
    }

    return initial_states, initial_upper_chols, initial_log_weights, other_params
def _get_initial_states(params, info, dimensions, n_obs):
    """Create the array of initial states.

    The selected parameters are reshaped to (1, n_mixtures, n_states) and
    repeated ``n_obs`` times along the first axis.

    """
    state = params[info["initial_states"]].reshape(
        1, dimensions["n_mixtures"], dimensions["n_states"]
    )
    return jnp.repeat(state, n_obs, axis=0)


def _get_initial_upper_chols(params, info, dimensions, n_obs):
    """Create the array with cholesky factors of the initial states covariance matrix.

    Note:
        The matrices contain the transpose of the lower triangular cholesky factors.

    Returns:
        jax.numpy.array: Array of shape (n_obs, n_mixtures, n_states, n_states).

    """
    n_states, n_mixtures = dimensions["n_states"], dimensions["n_mixtures"]
    chol_params = params[info["initial_cholcovs"]].reshape(n_mixtures, -1)

    # The positions of the lower triangle are the same for every mixture element.
    lower_indices = jnp.tril_indices(n_states)

    upper_chols = jnp.zeros((n_obs, n_mixtures, n_states, n_states))
    for i in range(n_mixtures):
        # ``.at[...].set`` returns an updated copy; nothing is modified in place.
        lower = jnp.zeros((n_states, n_states)).at[lower_indices].set(chol_params[i])
        # Every observation gets the same matrix for mixture element i.
        upper_chols = upper_chols.at[:, i].set(lower.T)
    return upper_chols


def _get_initial_log_mixture_weights(params, info, n_obs):
    """Create the array with the log of initial mixture weights.

    The result has one row per observation; all rows are identical.

    """
    log_weights = jnp.log(params[info["mixture_weights"]]).reshape(1, -1)
    return jnp.repeat(log_weights, n_obs, axis=0)


def _get_control_params(params, info, dimensions):
    """Create the parameters for control variables in measurement equations.

    The result has ``dimensions["n_controls"]`` columns.

    """
    return params[info["controls"]].reshape(-1, dimensions["n_controls"])


def _get_loadings(params, info, dimensions):
    """Create the array of factor loadings.

    The result has ``dimensions["n_states"]`` columns.

    """
    return params[info["loadings"]].reshape(-1, dimensions["n_states"])


def _get_meas_sds(params, info):
    """Create the array of standard deviations of the measurement errors."""
    return params[info["meas_sds"]]


def _get_shock_sds(params, info, dimensions):
    """Create the array of standard deviations of the shocks in transition functions.

    The result has ``dimensions["n_states"]`` columns.

    """
    return params[info["shock_sds"]].reshape(-1, dimensions["n_states"])


def _get_transition_params(params, info, dims, labels):
    """Create a tuple of tuples of arrays with transition equation parameters.

    The outer tuple has one entry for each of the first ``n_periods - 1`` periods,
    the inner tuples have one array per factor in ``labels["factors"]``.

    """
    return tuple(
        tuple(params[info["transition"][period][factor]] for factor in labels["factors"])
        for period in range(dims["n_periods"] - 1)
    )


def _get_anchoring_scaling_factors(loadings, info, dimensions):
    """Create an array of anchoring scaling factors.

    Note:
        Parameters are not taken from the parameter vector but from the loadings.

    Returns:
        jax.numpy.array: Array of shape (n_periods, n_states). Entries of factors
            that are not anchored are one.

    """
    n_periods = dimensions["n_periods"]
    free_anchoring_loadings = loadings[info["is_anchoring_loading"]].reshape(
        n_periods, -1
    )
    # Start from ones and overwrite the columns of the anchored factors.
    scaling_factors = jnp.ones((n_periods, dimensions["n_states"]))
    return scaling_factors.at[:, info["is_anchored_factor"]].set(
        free_anchoring_loadings
    )


def _get_anchoring_constants(controls, info, dimensions):
    """Create an array of anchoring constants.

    Note:
        Parameters are not taken from the parameter vector but from the controls.

    Returns:
        jax.numpy.array: Array of shape (n_periods, n_states). It is all zeros if
            ``info["ignore_constant_when_anchoring"]`` is true; otherwise the
            columns of the anchored factors are filled from the first column of
            ``controls`` in the rows of the anchoring updates.

    """
    n_periods = dimensions["n_periods"]
    constants = jnp.zeros((n_periods, dimensions["n_states"]))
    if info["ignore_constant_when_anchoring"]:
        return constants

    values = controls[:, 0][info["is_anchoring_update"]].reshape(n_periods, -1)
    return constants.at[:, info["is_anchored_factor"]].set(values)