Saving and loading populations
Another feature Brush implements is the ability to save and load entire populations. Populations are stored in JSON notation, so the files are human readable. Likewise, we can feed an estimator a previously saved population file to serve as the starting point for the evolution.
In this notebook, we will walk through how to use the save_population and load_population parameters.
We start by getting a sample dataset and splitting it into X and y:
import pandas as pd
from pybrush import BrushRegressor
# load data
df = pd.read_csv('../examples/datasets/d_enc.csv')
X = df.drop(columns='label')
y = df['label']
To save the population after the evolution finishes, you need to set the save_population parameter to a non-empty string; the final population will then be stored in that file. In this example, we save it to a temporary file.
import os, tempfile

pop_file = os.path.join(tempfile.mkdtemp(), 'population.json')
# set verbosity=2 to see the full report
est = BrushRegressor(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    max_gens=10,
    objectives=["scorer", "complexity"],
    scorer='mse',
    save_population=pop_file,
    verbosity=2
)
est.fit(X,y)
y_pred = est.predict(X)
print('score:', est.score(X,y))
Generation 1/10 [////// ]
Best model on Val:3.56*Add(If(x0>=0.76,1.18*x4,1.01*x4),4.46*x6)
Train Loss (Med): 13.87972 (61.04575)
Val Loss (Med): 12.52568 (73.18680)
Median Size (Max): 11 (76)
Median complexity (Max): 648 (1485456072)
Time (s): 0.30340
Generation 2/10 [/////////// ]
Best model on Val:2.26*Add(1.25*Cos(0.99*x2),1.18*Add(0.35*Mul(0.36*Add(495.40,-2.07*x3),0.33*Cos(1.12*Cos(0.99*Add(6.11*Cos(-1.55*Logabs(1.15*Logabs(1.02*x4))),0.99*Add(1.36*x4,0.99*Mul(0.99*x0,0.99*x3)))))),1.71*Add(If(x0>=0.76,-1.88*x0,2.32*x6),0.66*x4)))
Train Loss (Med): 7.19685 (18.29853)
Val Loss (Med): 9.99819 (19.73273)
Median Size (Max): 12 (90)
Median complexity (Max): 408 (1089403208)
Time (s): 0.46099
Generation 3/10 [//////////////// ]
Best model on Val:2.26*Add(1.25*Cos(0.99*x2),1.18*Add(0.35*Mul(0.36*Add(495.40,-2.07*x3),0.33*Cos(1.12*Cos(0.99*Add(6.11*Cos(-1.55*Logabs(1.15*Logabs(1.02*x4))),0.99*Add(1.36*x4,0.99*Mul(0.99*x0,0.99*x3)))))),1.71*Add(If(x0>=0.76,-1.88*x0,2.32*x6),0.66*x4)))
Train Loss (Med): 7.19685 (17.96509)
Val Loss (Med): 9.99819 (19.34284)
Median Size (Max): 9 (76)
Median complexity (Max): 232 (1179617608)
Time (s): 0.61931
Generation 4/10 [///////////////////// ]
Best model on Val:0.12*Add(0.51*Add(If(x0>=0.76,If(x0>=0.82,0.49*x1,3.17*x3),23.94*x4),264.10*x6),93.22*x0)
Train Loss (Med): 6.09925 (14.96953)
Val Loss (Med): 7.37980 (18.32659)
Median Size (Max): 18 (82)
Median complexity (Max): 24 (1714917288)
Time (s): 0.81694
Generation 5/10 [////////////////////////// ]
Best model on Val:0.22*Add(0.13*Add(If(x0>=0.76,If(x0>=0.82,x1,6.53*x3),49.27*x4),543.56*x6),49.88*x0)
Train Loss (Med): 6.09876 (14.38835)
Val Loss (Med): 7.37979 (16.98444)
Median Size (Max): 19 (82)
Median complexity (Max): 20 (1179617608)
Time (s): 1.10349
Generation 6/10 [/////////////////////////////// ]
Best model on Val:0.49*Add(0.32*Add(If(x0>=0.76,If(x0>=0.82,0.19*x1,1.21*x3),9.15*x4),100.96*x6),22.36*x0)
Train Loss (Med): 5.84923 (10.96850)
Val Loss (Med): 7.37968 (13.74235)
Median Size (Max): 21 (86)
Median complexity (Max): 20 (2011174224)
Time (s): 1.46493
Generation 7/10 [//////////////////////////////////// ]
Best model on Val:2.12*Add(1.90*Cos(0.99*x2),1.64*Add(0.28*Mul(0.28*Add(506.27,-1.88*x3),0.26*Cos(1.44*Cos(1.00*Add(4.00*Cos(-1.65*Logabs(1.21*Logabs(1.14*x4))),0.99*Add(1.29*x4,0.99*Mul(0.99*x0,0.99*x3)))))),0.71*Add(0.98*Mul(1.95*x0,4.53*x6),1.04*x4)))
Train Loss (Med): 3.15627 (10.75490)
Val Loss (Med): 3.90641 (12.99150)
Median Size (Max): 21 (81)
Median complexity (Max): 20 (2011174224)
Time (s): 1.88341
Generation 8/10 [///////////////////////////////////////// ]
Best model on Val:2.12*Add(1.90*Cos(0.99*x2),1.64*Add(0.28*Mul(0.28*Add(506.28,-1.88*x3),0.26*Cos(1.44*Cos(1.00*Add(4.00*Cos(-1.65*Logabs(1.21*Logabs(1.14*x4))),0.99*Add(1.29*x4,0.99*Mul(0.99*x0,0.99*x3)))))),0.71*Add(0.98*Mul(1.95*x0,4.53*x6),1.04*x4)))
Train Loss (Med): 3.15626 (7.32961)
Val Loss (Med): 3.90626 (9.54462)
Median Size (Max): 70 (83)
Median complexity (Max): 15 (2011174224)
Time (s): 2.35167
Generation 9/10 [////////////////////////////////////////////// ]
Best model on Val:2.28*Add(0.37*Cos(0.98*x2),1.79*Add(0.19*Mul(0.19*Add(511.93,-1.67*x3),0.18*Cos(2.98*Cos(1.00*Add(3.54*Cos(-1.62*Logabs(1.20*Logabs(1.12*x4))),1.00*Add(1.30*x4,1.22*Cos(0.99*Mul(0.99*x0,0.99*x3))))))),0.84*Add(0.88*Mul(1.75*x0,4.07*x6),1.31*x4)))
Train Loss (Med): 2.97018 (6.54713)
Val Loss (Med): 3.82776 (8.72991)
Median Size (Max): 72 (86)
Median complexity (Max): 15 (2011174224)
Time (s): 2.92181
Generation 10/10 [//////////////////////////////////////////////////]
Best model on Val:2.23*Add(0.18*Cos(1.01*x2),1.75*Add(0.16*Mul(0.16*Add(69.70*x4,-0.34*x3),0.15*Cos(2.60*Cos(1.01*Add(3.25*Cos(-1.63*Logabs(1.20*Logabs(1.11*x4))),1.00*Add(1.30*x4,1.32*Cos(0.99*Mul(0.99*x0,0.99*x3))))))),0.83*Add(0.86*Mul(1.70*x0,3.94*x6),1.30*x4)))
Train Loss (Med): 2.93713 (6.07609)
Val Loss (Med): 3.55941 (8.50959)
Median Size (Max): 72 (83)
Median complexity (Max): 17 (2011174224)
Time (s): 3.43889
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmpk_nrpah6/population.json
score: 0.9624956115764942
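Since save_population was set, the final population now sits on disk as an ordinary JSON file. A quick check, using only the pop_file path defined above:
# the saved population is a regular file at the path chosen earlier
print(os.path.exists(pop_file))
print(os.path.getsize(pop_file), 'bytes')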
Loading a previous population is done by providing load_population with the path of a JSON file generated by Brush. In our case, we will use the same file from the previous code block.
After loading the population, we run the evolution for 10 more generations, and we can see that the first generation starts from the previous population. This means the population was successfully saved and loaded.
est = BrushRegressor(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    load_population=pop_file,
    max_gens=10,
    verbosity=1
)
est.fit(X,y)
y_pred = est.predict(X)
print('score:', est.score(X,y))
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmpk_nrpah6/population.json of size = 200
Completed 100% [====================]
score: 0.9649080920204649
There is a convenient way to access individuals in the population: just use the index of the individual in the est.population_ list.
# unlike the archive, the population contains all individuals
print("population size:", len(est.population_))
print("archive size :", len(est.archive_))
print(est.population_[0].get_model())
population size: 100
archive size : 21
24.17
You can convert the JSON representation back into a fully functional individual by wrapping it in the individual class. It is important that the type of individual (i.e., classification or regression) is the same.
Unlike the archive (which is sorted by complexity), the individuals in the population have no particular order. So individual 5 may or may not be more complex than individual 10, for example.
ind2 = est.population_[2]
print(ind2.get_model("tree"))
If(x0>=0.76)
|- If(x0>=0.82)
| |- 33.39*x0
| |- 37.75
|- Add
| |- 1.46*Mul
| | |- 1.36*x4
| | |- 1.59*x6
| |- 3.86*x4
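Because the population is unordered, you can sample a few indices and see that size is not monotonic in the index. A minimal sketch, assuming population members expose the fitness.size attribute used later in this notebook:
# population order carries no meaning, so sizes need not be monotonic
for i in [0, 2, 5, 10]:
    print(i, est.population_[i].fitness.size)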
Saving just the archive
In case you want to use an expression other than the final best_estimator_, Brush provides the archive option.
The archive is the Pareto front of the population. You can use predict (and predict_proba, if using a BrushClassifier) to call the prediction methods for the entire archive instead of only the selected best individual.
est = BrushRegressor(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    load_population=pop_file,
    max_gens=10,
    verbosity=1
)
est.fit(X,y)
# access the first expression from the archive and inspect its fitness
print(est.archive_[0].fitness)
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmpk_nrpah6/population.json of size = 200
Completed 100% [====================]
Fitness(97.702599 2.000000 )
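Since the archive is sorted by complexity, you can also sweep the whole archive to inspect the error-versus-complexity trade-off. A minimal sketch, assuming archive members expose the same fitness and get_model() attributes shown for individuals above:
# walk the archive from simplest to most complex expression
for i, ind in enumerate(est.archive_):
    print(i, ind.fitness, ind.get_model())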
You can open the serialized file and change the individuals' programs manually.
This also allows us to have checkpoints in the execution.
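For instance, because the file is plain JSON, it can be inspected or edited with the standard json module. The exact schema depends on the Brush version, so this sketch just loads and rewrites the file without assuming its structure:
import json

# the serialized population is ordinary JSON
with open(pop_file) as f:
    pop_json = json.load(f)

# ... edit individuals' programs here if desired ...

# write it back, then point load_population at the file to resume from it
with open(pop_file, 'w') as f:
    json.dump(pop_json, f)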
Using population files with classification
To give another example, we do a two-step fit in the cells below.
First, we run the evolution and save the population to a file; then, we load it and keep evolving the individuals.
The difference is that the first run optimizes scorer and complexity, while the second run optimizes average_precision_score and linear_complexity.
from pybrush import BrushClassifier
# load data
df = pd.read_csv('../examples/datasets/d_analcatdata_aids.csv')
X = df.drop(columns='target')
y = df['target']
pop_file = os.path.join(tempfile.mkdtemp(), 'population.json')
est = BrushClassifier(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    max_gens=10,
    max_size=30,
    objectives=["scorer", "complexity"],
    scorer="log",
    save_population=pop_file,
    pop_size=200,
    verbosity=2
)
est.fit(X,y)
print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Generation 1/10 [////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.51436 (0.59792)
Val Loss (Med): 0.60301 (0.69320)
Median Size (Max): 7 (35)
Median complexity (Max): 992 (384239648)
Time (s): 0.16835
Generation 2/10 [/////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.51436 (0.59792)
Val Loss (Med): 0.60301 (0.69320)
Median Size (Max): 7 (44)
Median complexity (Max): 992 (28766624)
Time (s): 0.30593
Generation 3/10 [//////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.51436 (0.57623)
Val Loss (Med): 0.60301 (0.70170)
Median Size (Max): 7 (44)
Median complexity (Max): 992 (28766624)
Time (s): 0.44703
Generation 4/10 [///////////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.51436 (0.57595)
Val Loss (Med): 0.60301 (0.70730)
Median Size (Max): 7 (44)
Median complexity (Max): 992 (108051872)
Time (s): 0.57940
Generation 5/10 [////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.50264 (0.57441)
Val Loss (Med): 0.60301 (0.70798)
Median Size (Max): 11 (44)
Median complexity (Max): 992 (180010400)
Time (s): 0.77661
Generation 6/10 [/////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.50264 (0.56985)
Val Loss (Med): 0.60301 (0.70771)
Median Size (Max): 15 (44)
Median complexity (Max): 17312 (243874592)
Time (s): 0.94376
Generation 7/10 [//////////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.50264 (0.56147)
Val Loss (Med): 0.60301 (0.70732)
Median Size (Max): 20 (45)
Median complexity (Max): 562592 (1440005536)
Time (s): 1.12788
Generation 8/10 [///////////////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,1.00*Sin(1.00*Add(If(AIDS>=16068.00,1.00*Sin(AIDS),AIDS),If(AIDS>=16068.00,1.00,AIDS)))))
Train Loss (Med): 0.50264 (0.55554)
Val Loss (Med): 0.60301 (0.70742)
Median Size (Max): 19 (47)
Median complexity (Max): 562592 (1440005536)
Time (s): 1.31811
Generation 9/10 [////////////////////////////////////////////// ]
Best model on Val:Logistic(Sum(-0.61,0.40*Exp(-2.02*Sin(2.11*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(1.00*AIDS),-8.35*Sin(1.00*AIDS))),2.11)))))
Train Loss (Med): 0.50264 (0.55520)
Val Loss (Med): 0.57283 (0.70683)
Median Size (Max): 20 (45)
Median complexity (Max): 572192 (1440005536)
Time (s): 1.50146
Generation 10/10 [//////////////////////////////////////////////////]
Best model on Val:Logistic(Sum(-0.61,0.40*Exp(-2.02*Sin(2.11*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(1.00*AIDS),-8.35*Sin(1.00*AIDS))),2.11)))))
Train Loss (Med): 0.50264 (0.54342)
Val Loss (Med): 0.57283 (0.70766)
Median Size (Max): 19 (45)
Median complexity (Max): 562592 (612956576)
Time (s): 1.69596
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5cb1rhhe/population.json
Best model: Logistic(Sum(-0.61,0.40*Exp(-2.02*Sin(2.11*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(1.00*AIDS),-8.35*Sin(1.00*AIDS))),2.11)))))
score: 0.72
from sklearn.metrics import accuracy_score
accuracy_score(y, est.predict(X))
0.72
est = BrushClassifier(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    load_population=pop_file,
    objectives=["scorer", "linear_complexity"],
    scorer="average_precision_score",
    max_gens=10,
    pop_size=200,  # make sure this is the same as the loaded population
    verbosity=2
)
est.fit(X,y)
print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5cb1rhhe/population.json of size = 400
Generation 1/10 [////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.87814 (0.79892)
Val Loss (Med): 0.91111 (0.61905)
Median Size (Max): 14 (58)
Median complexity (Max): 38816 (1297560992)
Time (s): 0.23162
Generation 2/10 [/////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88508 (0.79892)
Val Loss (Med): 0.91111 (0.61905)
Median Size (Max): 9 (58)
Median complexity (Max): 1184 (109139360)
Time (s): 0.40799
Generation 3/10 [//////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88640 (0.79892)
Val Loss (Med): 0.91111 (0.56444)
Median Size (Max): 7 (64)
Median complexity (Max): 752 (5625248)
Time (s): 0.55956
Generation 4/10 [///////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88640 (0.79892)
Val Loss (Med): 0.91111 (0.71429)
Median Size (Max): 7 (56)
Median complexity (Max): 992 (2154560)
Time (s): 0.72180
Generation 5/10 [////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.32949)
Val Loss (Med): 0.91111 (0.46929)
Median Size (Max): 7 (62)
Median complexity (Max): 176 (1389827488)
Time (s): 0.86879
Generation 6/10 [/////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.32949)
Val Loss (Med): 0.91111 (0.46929)
Median Size (Max): 7 (62)
Median complexity (Max): 176 (1389827488)
Time (s): 1.04083
Generation 7/10 [//////////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.32949)
Val Loss (Med): 0.91111 (0.46929)
Median Size (Max): 7 (62)
Median complexity (Max): 176 (1389827488)
Time (s): 1.24817
Generation 8/10 [///////////////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.56420)
Val Loss (Med): 0.91111 (0.46929)
Median Size (Max): 7 (62)
Median complexity (Max): 176 (1389827488)
Time (s): 1.42332
Generation 9/10 [////////////////////////////////////////////// ]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.79892)
Val Loss (Med): 0.91111 (0.60067)
Median Size (Max): 7 (62)
Median complexity (Max): 992 (1389827488)
Time (s): 1.61006
Generation 10/10 [//////////////////////////////////////////////////]
Best model on Val:Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Train Loss (Med): 0.88944 (0.79892)
Val Loss (Med): 0.91111 (0.59262)
Median Size (Max): 7 (62)
Median complexity (Max): 992 (1389827488)
Time (s): 1.80558
Best model: Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
score: 0.68
We can see the fitness object, and that the scorer now matches the average precision score metric.
Some differences may be due to part of the data being held out as an inner validation partition.
# Fitness is (scorer, linear complexity)
print(est.best_estimator_.fitness)
Fitness(0.911111 105.000000 )
from sklearn.metrics import average_precision_score
# takes y_true as first argument, and y_pred as second argument.
print("AUPRC:", average_precision_score(y, est.predict_proba(X)[:, 1], average='weighted'))
print("Model size:", est.best_estimator_.fitness.size)
print("Model:", est.best_estimator_.program.get_model())
AUPRC: 0.8318032741101248
Model size: 29
Model: Logistic(Sum(0.00,-0.79*Sin(2.12*Mul(0.83*Sin(-8.35*Mul(-8.39*Logabs(0.99*AIDS),-8.35*Sin(1.00*AIDS))),2.11))))
Serialization with pickle
You can save the entire model (best individual, parameters, and archive) with pickle.
At the current stage, Brush does not serialize the search space or dataset references, only the information necessary to load a previously trained model and make predictions with it.
est
BrushClassifier(algorithm='nsga2', bandit='dynamic_thompson', batch_size=1.0, class_weights='support', constants_simplification=True, cx_prob=0.14285714285714285, final_model_selection='', functions=['SplitBest', 'Add', 'Mul', 'Sin', 'Cos', 'Exp', 'Logabs'], inexact_simplification=True, initialization='uniform', load_population='/var/folders/mh/ggnb_jv93cl_gbqbg2j... 'insert': 0.16666666666666666, 'point': 0.16666666666666666, 'subtree': 0.16666666666666666, 'toggle_weight_off': 0.16666666666666666, 'toggle_weight_on': 0.16666666666666666}, n_jobs=1, num_islands=5, objectives=['scorer', 'linear_complexity'], pop_size=200, random_state=None, save_population='', scorer='average_precision_score', sel='lexicase', shuffle_split=False, surv='nsga2', ...)
Parameters
pop_size | 200
max_gens | 10
max_stall | 0
max_time | -1
verbosity | 2
algorithm | 'nsga2'
mode | 'classification'
max_depth | 10
max_size | 100
num_islands | 5
mig_prob | 0.05
n_jobs | 1
cx_prob | 0.14285714285714285
bandit | 'dynamic_thompson'
logfile | ''
final_model_selection | ''
save_population | ''
load_population | '/var/folders/mh/ggnb_j...cb1rhhe/population.json'
mutation_probs | {'delete': 0.16666666666666666, 'insert': 0.16666666666666666, 'point': 0.16666666666666666, 'subtree': 0.16666666666666666, ...}
functions | ['SplitBest', 'Add', ...]
objectives | ['scorer', 'linear_complexity']
constants_simplification | True
inexact_simplification | True
scorer | 'average_precision_score'
shuffle_split | False
initialization | 'uniform'
random_state | None
batch_size | 1.0
sel | 'lexicase'
surv | 'nsga2'
weights_init | True
validation_size | 0.2
class_weights | 'support'
import pickle

est_file = os.path.join(tempfile.mkdtemp(), 'est.pkl')

with open(est_file, 'wb') as f:
    pickle.dump(est, f)

with open(est_file, 'rb') as f:
    loaded_est = pickle.load(f)
print(est.predict(X))
print(loaded_est.predict(X))
[ True False False True False True False True True True True True
True True True True True False True True True True False True
True True True False True False True False False True False False
False True True False False False False True False True False False
True False]
[ True False False True False True False True True True True True
True True True True True False True True True True False True
True True True False True False True False False True False False
False True True False False False False True False True False False
True False]
Stop/resume the fitting of an estimator
In the code below, we mimic how PyTorch models are trained: we can stop the training at any time and resume it later.
The idea is to demonstrate how to use population files to store checkpoints and continue from the last saved checkpoint.
def train(est, X, y):
    checkpoint = os.path.join(tempfile.mkdtemp(), 'brush_pop_checkpoint.json')

    step = 5
    max_gens = est.max_gens

    est.max_gens = step
    est.save_population = checkpoint
    est.load_population = ""

    # You can set validation_size to a value greater than zero
    # and shuffle_split to True to get random batches of data
    est.shuffle_split = True
    est.validation_size = 0.2

    for g in range(max_gens // step):
        print(f"Progress {g + 1}/{max_gens // step}")
        est.fit(X, y)  # Notice that this will reset the MAB every time!

        # Enable loading the checkpoint after the first run
        est.load_population = checkpoint

        print("Best model:", est.best_estimator_.get_model())
        print('score :', est.score(X, y))

    # Restore the initial state
    est.max_gens = max_gens
est = BrushClassifier(
    objectives=["scorer", "linear_complexity"],
    scorer="balanced_accuracy",
    max_gens=50,
    validation_size=0.2,
    pop_size=100,
    max_depth=20,
    max_size=50,
    verbosity=1
)
train(est, X, y)
Progress 1/10
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.30,0.00*AIDS))
score : 0.68
Progress 2/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.24,0.00*AIDS))
score : 0.68
Progress 3/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.65,0.65))
score : 0.5
Progress 4/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.26,0.00*AIDS))
score : 0.68
Progress 5/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-28.50,If(AIDS>=16068.00,AIDS,1.00)))
score : 0.68
Progress 6/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.25,0.00*AIDS))
score : 0.68
Progress 7/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.32,0.00*AIDS))
score : 0.68
Progress 8/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-28.50,If(AIDS>=16068.00,AIDS,1.00)))
score : 0.68
Progress 9/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.19,0.00*AIDS))
score : 0.66
Progress 10/10
Loaded population from /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json of size = 200
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Best model: Logistic(Sum(-0.20,0.00*AIDS))
score : 0.66
By default, sklearn estimators reset when fit is called twice. To continue from the last fit, you can call partial_fit and Brush will resume the training.
If you want, you can change parameters on the est object before calling partial_fit to update the execution settings.
It is important that the data has the same features (same names and dtypes) as the data used in the previous fit/partial_fit.
print(est.best_estimator_.get_model())
print(est.best_estimator_.fitness)
est.partial_fit(X, y)
print(est.best_estimator_.get_model())
print(est.best_estimator_.fitness)
Logistic(Sum(-0.20,0.00*AIDS))
Fitness(1.000000 22.000000 )
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Logistic(Sum(-0.27,0.00*AIDS))
Fitness(0.700000 22.000000 )
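As noted above, execution settings can be adjusted between calls. For example (max_gens and verbosity are the same constructor parameters used earlier in this notebook):
# update settings before resuming; the next partial_fit will use them
est.max_gens = 20
est.verbosity = 2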
The partial_fit method also allows you to lock an initial portion of the tree before doing the new fit.
You can also choose to leave the leaves out of this locking mechanism; that way, the terminals close to the root remain unlocked and can change.
If you set a large depth and also force the leaves to be locked, some (smaller) programs in the population may not change at all during the run.
print(est.best_estimator_.get_model())
print(est.best_estimator_.fitness)
est.partial_fit(X, y, lock_nodes_depth=2, keep_leaves_unlocked=True)
print(est.best_estimator_.get_model())
print(est.best_estimator_.fitness)
Logistic(Sum(-0.27,0.00*AIDS))
Fitness(0.700000 22.000000 )
Completed 100% [====================]
Saved population to file /var/folders/mh/ggnb_jv93cl_gbqbg2jxq0yr0000gn/T/tmp5fzgqu_7/brush_pop_checkpoint.json
Logistic(Sum(-0.24,0.00*AIDS))
Fitness(0.700000 22.000000 )