Source code for skillmodels.process_data

import warnings

import jax.numpy as jnp
import numpy as np
import pandas as pd

from skillmodels.process_model import get_period_measurements


def process_data_for_estimation(df, labels, update_info, anchoring_info):
    """Turn the raw panel into the arrays consumed by the likelihood function.

    Args:
        df (DataFrame): panel dataset in long format with a two-level
            MultiIndex. ``_pre_process_data`` interprets the first level as
            the individual and the second level as the period.
        labels (dict): Dict of lists with labels for the model quantities like
            factors, periods, controls, stagemap and stages. Only the
            "periods" and "controls" entries are read here directly.
            See :ref:`labels`
        update_info (pandas.DataFrame): DataFrame with one row per Kalman
            update needed in the likelihood function. See :ref:`update_info`.
        anchoring_info (dict): Information about anchoring. See :ref:`anchoring`

    Returns:
        meas_data (jax.numpy.array): Array of shape (n_updates, n_obs) with
            data on observed measurements. NaN if the measurement was not
            observed.
        control_data (jax.numpy.array): Array of shape (n_periods, n_obs,
            n_controls) with observed control variables for the measurement
            equations.

    """
    periods = labels["periods"]
    controls = labels["controls"]

    # Balance the panel first so that every later step sees one row per
    # (individual, period) combination.
    data = _pre_process_data(df, periods)
    data["constant"] = 1
    data = _add_copies_of_anchoring_outcome(data, anchoring_info)

    # Raises if a required variable is missing or has no variance.
    _check_data(data, controls, update_info, labels)

    # The panel is balanced, hence rows / periods is the number of individuals.
    n_obs = int(len(data) / len(periods))

    data = _handle_controls_with_missings(data, controls, update_info)

    meas_data = _generate_measurements_array(data, update_info, n_obs)
    control_data = _generate_controls_array(data, labels, n_obs)
    return meas_data, control_data


def _pre_process_data(df, periods): """Balance panel data in long format, drop unnecessary periods and set index. Args: df (DataFrame): panel dataset in long format. It has a MultiIndex where the first level indicates the period and the second the individual. Returns: balanced (DataFrame): balanced panel. It has a MultiIndex. The first enumerates individuals. The second level counts periods, starting at 0. """ df = df.copy(deep=True).sort_index() df["__old_id__"] = df.index.get_level_values(0) df["__old_period__"] = df.index.get_level_values(1) # replace existing codes for periods and df.index.names = ["id", "period"] for level in [0, 1]: df.index = df.index.set_levels(range(len(df.index.levels[level])), level) # create new index ids = sorted(df.index.get_level_values("id").unique()) new_index = pd.MultiIndex.from_product([ids, periods], names=["id", "period"]) # set new index df = df.reindex(new_index) return df def _add_copies_of_anchoring_outcome(df, anchoring_info): df = df.copy() for factor in anchoring_info["factors"]: outcome = anchoring_info["outcomes"][factor] df[f"{outcome}_{factor}"] = df[outcome] return df def _check_data(df, controls, update_info, labels): var_report = pd.DataFrame(index=update_info.index[:0], columns=["problem"]) for period in labels["periods"]: period_data = df.query(f"period == {period}") for cont in controls: if cont not in period_data.columns or period_data[cont].isnull().all(): var_report.loc[(period, cont), "problem"] = "Variable is missing" for meas in get_period_measurements(update_info, period): if meas not in period_data.columns: var_report.loc[(period, meas), "problem"] = "Variable is missing" elif len(period_data[meas].dropna().unique()) == 1: var_report.loc[(period, meas), "problem"] = "Variable has no variance" var_report = var_report.to_string() if len(var_report) > 0 else "" if var_report: raise ValueError(var_report) def _handle_controls_with_missings(df, controls, update_info): df = df.copy(deep=True) periods = 
update_info.index.get_level_values(0).unique().tolist() problematic_index = df.index[:0] for period in periods: period_data = df.query(f"period == {period}") control_data = period_data[controls] meas_data = period_data[get_period_measurements(update_info, period)] problem = control_data.isnull().any(axis=1) & meas_data.notnull().any(axis=1) problematic_index = problematic_index.union(period_data[problem].index) if len(problematic_index) > 0: old_names = df.loc[problematic_index][["__old_id__", "__old_period__"]] msg = "Set measurements to NaN because there are NaNs in the controls for:\n{}" msg = msg.format(list(map(tuple, old_names.to_numpy().tolist()))) warnings.warn(msg) df.loc[problematic_index] = np.nan return df def _generate_measurements_array(df, update_info, n_obs): arr = np.zeros((len(update_info), n_obs)) for k, (period, var) in enumerate(update_info.index): arr[k] = df.query(f"period == {period}")[var].to_numpy() return jnp.array(arr) def _generate_controls_array(df, labels, n_obs): arr = np.zeros((len(labels["periods"]), n_obs, len(labels["controls"]))) for period in labels["periods"]: arr[period] = df.query(f"period == {period}")[labels["controls"]].to_numpy() return jnp.array(arr)