Source code for skillmodels.transition_functions

"""Contains transition functions and corresponding helper functions.

Below the signature and purpose of a transition function and its helper
functions is explained with a transition function called example_func:


**example_func(** *sigma_points, params**)**:

    The actual transition function.

    Args:
        * sigma_points: 4d numpy array of sigma_points or states being transformed.
            The shape is n_obs, n_mixtures, n_sigma, n_fac.
        * params: 1d numpy array with coefficients specific to this transition function

    Returns
        * np.ndarray: Shape is n_obs, n_mixtures, n_sigma


**index_tuples_example_func(** *factor, factors, period* **)**:

    Generate a list of index tuples for the params of the transition function.

    Each index tuple contains four entries

    - 'transition' (fix)
    - period
    - factor
    - 'some-name'


The transition functions have to be JAX jittable and differentiable. However, they
should not be jitted yet.

"""
from itertools import combinations

import jax
import jax.numpy as jnp


[docs]def linear(sigma_points, params): """Linear production function where the constant is the last parameter.""" constant = params[-1] betas = params[:-1] return jnp.dot(sigma_points, betas) + constant
[docs]def index_tuples_linear(factor, factors, period): """Index tuples for linear transition function.""" ind_tups = [("transition", period, factor, rhs_fac) for rhs_fac in factors] return ind_tups + [("transition", period, factor, "constant")]
[docs]def translog(sigma_points, params): """Translog transition function. The name is a convention in the skill formation literature even though the function is better described as a linear in parameters transition function with squares and interaction terms of the states. """ nfac = sigma_points.shape[-1] constant = params[-1] lin_beta = params[:nfac] square_beta = params[nfac : 2 * nfac] inter_beta = params[2 * nfac : -1] res = jnp.dot(sigma_points, lin_beta) res += jnp.dot(sigma_points ** 2, square_beta) for p, (a, b) in zip(inter_beta, combinations(range(nfac), 2)): res += p * sigma_points[..., a] * sigma_points[..., b] res += constant return res
[docs]def index_tuples_translog(factor, factors, period): """Index tuples for the translog production function.""" ind_tups = [("transition", period, factor, rhs_fac) for rhs_fac in factors] ind_tups += [ ("transition", period, factor, f"{rhs_fac} ** 2") for rhs_fac in factors ] ind_tups += [ ("transition", period, factor, f"{a} * {b}") for a, b in combinations(factors, 2) ] ind_tups += [("transition", period, factor, "constant")] return ind_tups
[docs]def log_ces(sigma_points, params): """Log CES production function (KLS version).""" phi = params[-1] gammas = params[:-1] scaling_factor = 1 / phi # note: once the b argument is supported in jax.scipy.special.logsumexp, we can set # b = gammas instead of adding the log of gammas to sigma_points * phi # the log step for gammas underflows for gamma = 0, but this is handled correctly # by logsumexp and does not raise a warning. unscaled = jax.scipy.special.logsumexp( jnp.log(gammas) + sigma_points * phi, axis=-1 ) result = unscaled * scaling_factor return result
[docs]def index_tuples_log_ces(factor, factors, period): """Index tuples for the log_ces production function.""" ind_tups = [("transition", period, factor, rhs_fac) for rhs_fac in factors] return ind_tups + [("transition", period, factor, "phi")]
def constraints_log_ces(factor, factors, period): ind_tups = index_tuples_log_ces(factor, factors, period) loc = ind_tups[:-1] return {"loc": loc, "type": "probability"}
[docs]def constant(sigma_points, params): """Constant production function should never be called.""" raise NotImplementedError
[docs]def index_tuples_constant(factor, factors, period): """Index tuples for the constant production function.""" return []