{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Using Longitudinal Data\n", "=======================\n", "\n", "This example demonstrates how to use longitudinal data with FEAT, including how to do cross validation with it.\n", "By *longitudinal*, we mean predictors that take more than one value for each sample, such as repeated measurements over time.\n", "This could be time series data or any other sequential data we might want to model.\n", "\n", "\n", "Example Patient Data\n", "------------------------------\n", "\n", "First, we generate some example data and store it using [this script](https://github.com/lacava/feat/blob/master/docs/examples/longitudinal/generate_example_longitudinal_data.py).\n", "\n", "Let’s imagine we have patient data from a hospital.\n", "The measurements come from different visits, so different patients have different numbers of measurements, collected at non-uniform intervals.\n", "In this example, we make up a risk model in which a patient's risk increases when their body mass index (BMI) is increasing and when the maximum\n", "glucose level in their blood panel is high." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from sklearn.model_selection import KFold\n", "\n", "random_state = 42\n", "\n", "# static (non-longitudinal) features and the target, one row per patient\n", "df = pd.read_csv('data/d_example_patients.csv')\n", "df.drop('id', axis=1, inplace=True)\n", "X = df.drop('target', axis=1)\n", "y = df['target']\n", "kf = KFold(n_splits=3, shuffle=True, random_state=random_state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Longitudinal format\n", "-------------------\n", "\n", "FEAT expects longitudinal data in the following format:\n", "\n", "```python\n", "Z = {\n", "    'variable': ([patient1_values, patient2_values], [patient1_times, patient2_times])\n", "}\n", "```\n", "\n", "\n", "Longitudinal data is a dictionary in which the keys are the variable names and the values are tuples. 
\n", "The first element of the tuple contains the observations, and the second element contains the corresponding timestamps for those observations.\n", "Both are lists with one element per patient, in the same patient order as the rows of `X`.\n", "Each patient's element is an array holding all of the observations (or timestamps) for that patient.\n", "\n", "\n", "On the C++ side, FEAT converts this into the following type:\n", "\n", "```c++\n", "typedef std::map<string, std::pair<vector<ArrayXf>, vector<ArrayXf>>\n", "                > LongData;\n", "```\n", "\n", "Although a little clunky, this layout stores each patient's values in arrays under the hood, allowing as much SIMD optimization as possible when evaluating operators." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the longitudinal data we generated:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " id name date value\n", "0 0 bmi 2515 40.000000\n", "1 0 age 2515 33.000000\n", "2 0 glucose 2515 0.058084\n", "3 0 bmi 2979 40.000000\n", "4 0 age 2979 34.000000\n", "... ... ... ... ...\n", "147250 992 age 71187 213.000000\n", "147251 992 glucose 71187 0.388632\n", "147252 992 bmi 72149 87.000000\n", "147253 992 age 72149 215.000000\n", "147254 992 glucose 72149 0.426824\n", "\n", "[147255 rows x 4 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "zfile = 'data/d_example_patients_long.csv'\n", "zdf = pd.read_csv(zfile)\n", "zdf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the data are in long format, with one row per measurement (patient `id`, variable `name`, `date`, `value`). Below we convert them to the FEAT input format." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Build Z = {variable: ([values per patient], [timestamps per patient])}.\n", "# groupby sorts its keys, so patients appear in ascending id order,\n", "# which must match the row order of X.\n", "Z = {}\n", "for name, zg in zdf.groupby('name'):\n", "    # group once per variable so values and timestamps stay aligned per patient\n", "    by_patient = [zgid for _, zgid in zg.groupby('id')]\n", "    values = [zgid['value'].values for zgid in by_patient]\n", "    timestamps = [zgid['date'].values for zgid in by_patient]\n", "    Z[name] = (values, timestamps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Next we set up the learner. We need to declare the longitudinal\n", "operators we want to search over. They are passed to the ``functions``\n", "argument as a list of operator names. 
In this case, the\n", "operators on the second row of the declaration below all operate on\n", "longitudinal data.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from feat import Feat\n", "\n", "clf = Feat(max_depth=5,\n", "           max_dim=5,\n", "           gens=10,\n", "           pop_size=100,\n", "           max_time=30,  # seconds\n", "           verbosity=0,\n", "           shuffle=True,\n", "           normalize=False,  # don't normalize input data\n", "           # second row: operators that act on longitudinal variables\n", "           functions=['and','or','not','split','split_c',\n", "                      'mean','median','max','min','variance','skew','kurtosis','slope','count'\n", "                     ],\n", "           backprop=True,\n", "           batch_size=10,\n", "           iters=10,\n", "           random_state=random_state,\n", "           n_jobs=1,\n", "           simplify=0.01  # prune final representations\n", "          )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cross validation\n", "----------------\n", "\n", "Cross validation works a little differently with longitudinal data.\n", "`X` and `y` can be indexed by row, but `Z` is a dictionary of per-patient lists, so each fold's train and test indices must also be applied to every variable in `Z`.\n", "The block below shows how to train a model using k-fold cross validation."
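 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before splitting, it can help to check that `Z` lines up with `X`: every variable needs one value array and one timestamp array per row of `X`, in the same patient order.\n", "The next cell is a minimal sketch of such a check. It cannot verify patient order, since `id` was dropped from `X`; it assumes the rows of `X` are sorted by patient id, which is the order `groupby('id')` produces when building `Z`.\n", "It also wraps the per-fold slicing in a helper; `subset_longitudinal` is a hypothetical name, not part of FEAT's API, and does the same thing as the dictionary comprehensions in the cross-validation loop." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hypothetical helper (not part of FEAT's API):\n", "# keep only the patients at positions idx in every longitudinal variable\n", "def subset_longitudinal(Z, idx):\n", "    return {k: ([v[0][i] for i in idx], [v[1][i] for i in idx])\n", "            for k, v in Z.items()}\n", "\n", "# sanity check: one value array and one timestamp array per row of X,\n", "# with as many timestamps as values for every patient\n", "for name, (vals, times) in Z.items():\n", "    assert len(vals) == len(times) == len(X), name\n", "    assert all(len(v) == len(t) for v, t in zip(vals, times)), name\n", "\n", "# usage: slice Z with the training indices of the first fold\n", "train_idx, test_idx = next(kf.split(X, y))\n", "Ztrain = subset_longitudinal(Z, train_idx)\n", "assert all(len(v[0]) == len(train_idx) for v in Ztrain.values())"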
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "scores: [0.0019200502112349603, 0.0013323824931506492, 0.020377830997444116]\n" ] } ], "source": [ "scores = []\n", "\n", "for train_idx, test_idx in kf.split(X, y):\n", "    # KFold returns positional indices: use them with iloc for X and y,\n", "    # and apply the same indices to every longitudinal variable in Z\n", "    Ztrain = {k: ([v[0][i] for i in train_idx], [v[1][i] for i in train_idx]) for k, v in Z.items()}\n", "    Ztest = {k: ([v[0][i] for i in test_idx], [v[1][i] for i in test_idx]) for k, v in Z.items()}\n", "    clf.fit(X.iloc[train_idx], y.iloc[train_idx], Ztrain)\n", "    scores.append(clf.score(X.iloc[test_idx], y.iloc[test_idx], Ztest))\n", "\n", "print('scores:', scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model Interpretation\n", "--------------------\n", "\n", "Now let’s fit a model to all the data and try to interpret it.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fitting longer to all data...\n" ] }, { "data": { "text/plain": [ "Feat(backprop=True, batch_size=10, feature_names='sex,race',\n", " functions='and,or,not,split,split_c,mean,median,max,min,variance,skew,kurtosis,slope,count',\n", " max_depth=5, max_dim=5, max_time=30, normalize=False, random_state=42,\n", " simplify=0.01, verbosity=2)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit to all data, allowing more generations than during cross validation\n", "\n", "print('fitting longer to all data...')\n", "clf.gens = 100\n", "clf.fit(X, y, Z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "To see the learned representation, we run ``clf.get_representation()``:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[slope(z_bmi)][max(z_glucose)]'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"clf.get_representation()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here our final representation is composed of ``slope(z_bmi)`` and\n", "``max(z_glucose)``, both of which we know to be correct features for\n", "this simulated dataset. The best training representation displays clear\n", "overfitting, highlighting the importance of using archive validation for\n", "model selection.\n", "We can also look at the representation with the model weights, sorted by\n", "magnitude, using ``clf.get_model()``:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Weight\tFeature\n", "8.23132\toffset\n", "0.80\tslope(z_bmi)\n", "0.73\tmax(z_glucose)\n", "\n" ] } ], "source": [ "print(clf.get_model())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "View runtime stats\n", "------------------\n", "\n", "FEAT stores statistics about the training procedure in a dictionary `clf.stats_`. \n", "An example of plotting from this dictionary is shown below." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['generation', 'med_complexity', 'med_dim', 'med_loss', 'med_loss_v', 'med_num_params', 'med_size', 'min_loss', 'min_loss_v', 'time'])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.stats_.keys()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# minimum training and validation loss, plotted against elapsed time\n", "plt.plot(clf.stats_['time'], clf.stats_['min_loss'], 'b', label='training loss')\n", "plt.plot(clf.stats_['time'], clf.stats_['min_loss_v'], 'r', label='validation loss')\n", "plt.legend()\n", "plt.xlabel('Time (s)')\n", "plt.ylabel('MSE')\n", "plt.gca().set_yscale('log')\n", "plt.gca().set_xscale('log')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# median model complexity, plotted against elapsed time\n", "plt.plot(clf.stats_['time'], clf.stats_['med_complexity'], 'b', label='median complexity')\n", "plt.legend()\n", "plt.xlabel('Time (s)')\n", "plt.ylabel('Median Complexity')\n", "plt.gca().set_xscale('log')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualizing the representation\n", "------------------------------\n", "\n", "Here we project the data onto the two learned features and color each patient by their target value.\n", "This shows us the risk surface as a function of these learned features." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "proj: (993, 2)\n", "rep: ['slope(z_bmi)', 'max(z_glucose)']\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Visualize the representation\n", "import matplotlib.pyplot as plt\n", "\n", "# project the data onto the learned features (one column per feature)\n", "proj = clf.transform(X, Z)\n", "print('proj:', proj.shape)\n", "\n", "# scatter plot of the patients in the learned feature space, colored by target\n", "f = plt.figure(figsize=(6, 6))\n", "ax = plt.subplot(aspect='equal')\n", "sc = ax.scatter(proj[:, 0], proj[:, 1], lw=0, s=20,\n", "                c=y, cmap='RdBu')\n", "plt.colorbar(sc)\n", "ax.axis('square')\n", "ax.axis('tight')\n", "\n", "# add axis labels parsed from the representation string, e.g. '[slope(z_bmi)][max(z_glucose)]'\n", "rep = [r.split('[')[-1] for r in clf.get_representation().split(']') if r != '']\n", "print('rep:', rep)\n", "plt.xlabel(rep[0])\n", "plt.ylabel(rep[1])\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.9 ('feat')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "d66ca72d61f5dfef1f2206d4040f625e956f2c7717f82cf69012f4e80879aa3c" } } }, "nbformat": 4, "nbformat_minor": 4 }