MATLAB Object-Oriented Model Fitting Tutorial

I frequently fit computational models to behavioral and neuroimaging datasets using statistical, reinforcement learning, and utility frameworks.  However, I typically write separate scripts for each specific study.  To make this process easier, I have developed a new MATLAB function that allows anyone to easily fit any type of model to a multi-subject dataset using very simple commands.  To use it, simply clone my GitHub toolbox repository and add the matlab folder to your MATLAB path.

Model Fitting Overview

The general idea of model fitting is to posit a theory about how something might work and to test it by examining how well the theory can account for a dataset given a set of free parameters.  In practice, I try to evaluate how well the model performs (i.e., its goodness of fit) relative to competing theories.  This comparison assumes that the model with the fewest free parameters and the best goodness of fit for the most subjects is the best model to explain the data.  Calculating the goodness of fit usually requires determining the optimal values of the free parameters specified in the model, a process referred to as parameter estimation.  The most common estimation techniques are Least Squares Estimation, which minimizes the squared difference between the model prediction and the observed data, and Maximum Likelihood Estimation, which finds the parameters that maximize the likelihood of the data under the model.  I recommend reading Myung (2003) for a tutorial on maximum likelihood.

I typically use a two-stage multi-level framework to estimate parameters for a group.  This is analogous to the summary-statistics approach in neuroimaging, in which parameters are estimated for each individual subject and then averaged across participants.  For hypothesis testing, I often use non-parametric statistical tests (e.g., the signed rank test) to perform model comparison on model fits penalized for their number of free parameters (e.g., AIC, BIC).  I also occasionally use F and chi-square tests for nested model comparisons and the Vuong statistic for non-nested model comparisons.  More recently, I have been playing with the Stan statistical language, which has both R and Python wrappers, to perform hierarchical Bayesian parameter estimation.  For a more in-depth overview of this process, I highly recommend an excellent chapter by Nathaniel Daw.
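To make the penalized fit metrics concrete, here is a minimal sketch of how AIC and BIC can be computed from a least-squares fit for one subject, using the standard Gaussian-error approximations.  The values of n, k, and sse below are placeholders, not output from the toolbox.

```matlab
% Sketch (not toolbox code): penalized fit metrics from a least-squares fit.
% Under Gaussian errors, AIC = n*log(SSE/n) + 2*k and BIC = n*log(SSE/n) + k*log(n).
n   = 100;   % number of observations for one subject (placeholder)
k   = 2;     % number of free parameters in the model (placeholder)
sse = 250;   % sum of squared errors returned by the model function (placeholder)

aic = n * log(sse / n) + 2 * k;       % penalizes each extra parameter by 2
bic = n * log(sse / n) + k * log(n);  % penalizes extra parameters more as n grows
```

Because the BIC penalty grows with the number of observations, BIC tends to favor simpler models than AIC for larger datasets.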

Computational Model Class ( comp_model() )

The computational model ( comp_model() ) class creates an object instance that is used to fit a computational model to a multi-subject dataset. The object uses the design_matrix() class for the dataset, which has rudimentary functionality similar to an R data.frame().  There are additional fields for the model and for the model-fitting procedure, such as the parameter constraints, the number of iterations, and the type of estimation (e.g., maximum likelihood or least squares).  I use MATLAB's fmincon optimization function to find the parameters that result in the best fit of the model.  This constrained optimization function is very efficient, but has a tendency to occasionally get stuck in a local minimum.  Using more iterations reruns the model multiple times with different initial starting values, which decreases the likelihood of fmincon getting stuck in a local minimum but increases the amount of time it takes to fit the models.
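The multi-start strategy described above can be sketched as follows.  This is a simplified standalone version for illustration only (the toolbox handles this internally via the nStart option); it assumes a single subject's data in a matrix named data and a model function named linear_model, like the example function shown later in this tutorial.

```matlab
% Sketch (assumptions: `data` is one subject's data matrix and
% `linear_model` is a model function on the path, as shown later).
nStart    = 10;           % number of random restarts
param_min = [-5, -20];    % lower bounds on the parameters
param_max = [60, 20];     % upper bounds on the parameters
options   = optimset(@fmincon);

bestFit = Inf;
for i = 1:nStart
    % Draw random starting values uniformly within the parameter bounds
    x0 = param_min + rand(1, numel(param_min)) .* (param_max - param_min);

    % Minimize the model's error function subject to the bounds
    [xpar, fval] = fmincon(@(x) linear_model(x, data), x0, [], [], [], [], ...
                           param_min, param_max, [], options);

    % Keep the restart that achieved the best (lowest) fit value
    if fval < bestFit
        bestFit    = fval;
        bestParams = xpar;
    end
end
```

Each restart begins from a different random point inside the parameter bounds, so even if some runs converge to a local minimum, the best run across restarts is more likely to be near the global minimum.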

Current Methods for comp_model (also inherits methods from the design_matrix class)

• avg_aic – display average AIC value
• avg_bic – display average BIC value
• avg_params – display average parameter estimates
• comp_model – class constructor
• fit_model – estimate parameters using model
• get_aic – extract all subjects’ AIC values
• get_bic – extract all subjects’ BIC values
• get_params – extract all subjects’ estimated parameters
• plot – plot average model predictions across subjects
• summary – display summary table for model
• save – save object as .mat file
• write_tables – write out parameter estimates and trial-to-trial predictions to .csv files

Example Usage

First, we will load some example data using the importdata function and set some of the default settings for the optimization algorithm.  The comp_model() function assumes that the data are in long format and that the subject indicator is the first column of the data frame.

basedir = '~/Dropbox/TreatmentExpectations';
dat = importdata(fullfile(basedir, 'Data','Seattle_Sona','seattle_ses_by_session.csv'));

% Set optimization parameters for fmincon (OPTIONAL)
options = optimset(@fmincon);
options = optimset(options, 'TolX', 0.00001, 'TolFun', 0.00001, 'MaxFunEvals', 900000000, 'LargeScale','off');

Next, we will create a comp_model() class instance for an example linear model.  This requires that we specify a data frame, a cell array of column names, and a model name. The model name must refer to a function containing the model, which must be located on the MATLAB path (see the example ‘linear_model’ function below).  We can also specify some additional parameters that are required for the model-fitting procedure.

• nStart – the number of iterations to repeat model estimation; selects the iteration associated with the best model fit. Higher numbers decrease the likelihood of fmincon getting stuck in a local minimum, but increase the computational time.
• param_min – vector of lower bound of parameters
• param_max – vector of upper bound of parameters
• esttype – type of parameter estimation (‘SSE’ – minimize sum of squared error; ‘LLE’ – maximize log likelihood; ‘LE’ – maximize likelihood)

lin = comp_model(dat.data,dat.textdata,'linear_model','nStart',10, 'param_min',[-5, -20], 'param_max', [60, 20], 'esttype','SSE');

911x8 comp_model array with properties:

model: 'linear_model'
param_min: [-5 -20]
param_max: [60 20]
nStart: 10
esttype: 'SSE'
params: []
trial: []
dat: [911x8 double]
varname: {'subj' 'group' 'sess' 'se_count' 'se_sum_intensity' 'any_action_taken' 'hamtot' 'bditot'}
fname: ''

Once the comp_model() object has been created with all of the necessary setup parameters, the model can be fit to the data with the following command.

lin = lin.fit_model();

911x8 comp_model array with properties:

model: 'linear_model'
param_min: [-5 -20]
param_max: [60 20]
nStart: 10
esttype: 'SSE'
params: [77x6 double]
trial: [911x4 double]
dat: [911x8 double]
varname: {'subj' 'group' 'sess' 'se_count' 'se_sum_intensity' 'any_action_taken' 'hamtot' 'bditot'}
fname: ''

This adds data to two new fields, params and trial.

• params – the parameters estimated for each subject. Rows are individual subjects; columns are {‘Subject’, ‘Estimated Parameters (specific to each model)’, ‘Model Fit’, ‘AIC’, ‘BIC’}.
• trial – the trial-by-trial data and predicted values for all subjects stacked together.

The overall average results from the model can be quickly viewed using the summary() method.

summary(lin)

Summary of Model: linear_model
-----------------------------------------
Average Parameters: 18.1073
Average AIC: 35.8264
Average BIC: 36.3802
Average SSE: -1.86
Number of Subjects: 77
-----------------------------------------

The ‘params’ and ‘trial’ data frames can be written to separate .csv files.  This can be helpful for opening the data in a separate program such as R or Python for plotting or further analyses.

lin.write_tables(fullfile(basedir,'Analysis','Modeling'))

The overall object instance can be saved as a .mat file. This is helpful because model estimation can sometimes take a long time, especially when using a high number of iterations.

lin.save(fullfile(basedir,'Analysis','Modeling','Linear_ModelFit.mat'))

The average predicted values from the model can be quickly plotted using the plot() function.  It is necessary to specify the particular columns from obj.trial to plot, as these will be specific to the dataset and model. This method has some rudimentary options to customize the plot.

plot(lin, [3,4], 'title', 'Linear Model', 'xlabel','session', 'ylabel', 'Average BDI', 'legend', {'Predicted','Observed'})

Example Model Function

Here is an example function of a very simple linear model. Model functions can be completely flexible, but they need to take the free parameters (xpar) and the data as inputs and output the model fit (e.g., sse).  This allows fmincon to optimize the parameters for this function by minimizing the Sum of Squared Error (sse, for this example).  This function needs to be in a separate function file on the MATLAB path.

function sse = linear_model(xpar, data)
% Fit Linear Decay Treatment Model

global trialout % this allows trial-by-trial output to be saved to the comp_model() object

% Model Parameters
beta0 = xpar(1); % Intercept
beta1 = xpar(2); % Slope

% Parse Data
obssx = data(:,8); % BDI symptom scores

% Model Initial Values
sse = 0; % initial value of sum of squared error
predsx = zeros(length(obssx),1); % preallocate model predictions
time = 1:size(data,1);
time = time - mean(time); % center time variable

% This model loops through every trial. Obviously this isn't
% necessary for this specific model, but it is for more dynamic models
% that change with respect to time, such as RL models.
for t = 1:length(obssx)

    % Calculate predicted symptoms using linear decay
    predsx(t) = beta0 + beta1 * time(t); % linear trend of symptom change

    % update sum of squared error (sse)
    sse = sse + (predsx(t) - obssx(t))^2;

end

% Output trial-by-trial results - saved as obj.trial
trialout = [ones(t,1)*data(1,1), (1:t)', obssx, predsx];

end % model
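If you use the ‘LLE’ estimation type instead of ‘SSE’, the model function returns a log likelihood rather than a sum of squared errors.  Here is a hypothetical Gaussian log-likelihood variant of the linear model above, written to illustrate the idea; the function name, the third free parameter, and the sign convention expected by the toolbox for ‘LLE’ are assumptions and should be checked against your setup.

```matlab
% Hypothetical 'LLE' variant of the linear model (a sketch, not toolbox
% code). Assumes Gaussian errors with a free standard deviation, so the
% model now has 3 free parameters instead of 2.
function lle = linear_model_lle(xpar, data)

beta0 = xpar(1); % Intercept
beta1 = xpar(2); % Slope
sigma = xpar(3); % Error standard deviation (constrain to be > 0 via param_min)

obssx = data(:,8); % BDI symptom scores
time = (1:size(data,1))';
time = time - mean(time); % center time variable

predsx = beta0 + beta1 * time; % model predictions

% Gaussian log likelihood of the observed symptoms given the predictions,
% written out explicitly to avoid a Statistics Toolbox dependency
lle = sum(-0.5 * log(2 * pi * sigma^2) - (obssx - predsx).^2 / (2 * sigma^2));

end
```

Note that maximizing this log likelihood with sigma fixed is equivalent to minimizing the SSE, which is why the ‘SSE’ version of the model above can get away with only two parameters.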

Model Comparison

For a quick example of testing a hypothesis by comparing competing models, we will use MATLAB’s non-parametric signrank test to determine whether the test model (i.e., linear_expect_model) fits the data significantly better than the null model (i.e., linear_model), using the get_aic() command.  We use the AIC metric to compare the models because the test model has 3 free parameters while the null model has only 2.

[P,H,STATS] = signrank(get_aic(lin),get_aic(lin_expect))

P =
8.8387e-08

H =
1

STATS =
zval: 5.3491
signedrank: 2555

In this example, we see that the test Linear Expectation model fits the data significantly better than the null Linear model, supporting our hypothesis.

Let me know if you have any comments or suggestions on how to improve the software.