The archive#
When you fit a Brush estimator, two new attributes are created: best_estimator_ and archive_.
If you set use_arch to True when instantiating the estimator, the Pareto front of the final population is stored as a list in archive_. This Pareto front is built from the individuals of the final population that are not dominated in the objectives scorer and complexity. Setting scorer as an objective means optimizing the metric given by the scorer: str argument.
If you need more flexibility, set use_arch to False: the archive will then contain the entire final population, and you can iterate through this list to select individuals by your own criteria. Also keep in mind that Brush supports different optimization objectives through the objectives argument.
Each element of the archive is a serialized individual (a JSON object).
import pandas as pd
from pybrush import BrushClassifier
# load data
df = pd.read_csv('../examples/datasets/d_analcatdata_aids.csv')
X = df.drop(columns='target')
y = df['target']
est = BrushClassifier(
# functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
use_arch=True,
objectives=["scorer", "linear_complexity"],
scorer='balanced_accuracy', # brush implements several metrics for clf and reg!
max_gens=100,
pop_size=100,
max_depth=10,
max_size=100,
verbosity=1
)
est.fit(X,y)
print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Completed 100% [====================]
Best model: Logistic(Sum(-0.24,If(AIDS>15890.50,1.00,If(Total>1572255.50,0.00,Asin(If(AIDS>123.00,1.00,1.00*Atan(Tan(AIDS))))))))
score: 0.9
You can inspect individuals from the archive by index:
print(len(est.archive_[0]))
est.archive_[0]
7
{'fitness': {'complexity': 80,
'crowding_dist': 0.0,
'dcounter': 0,
'depth': 3,
'dominated': [],
'linear_complexity': 14,
'loss': 0.5400000214576721,
'loss_v': 0.5400000214576721,
'rank': 1,
'size': 5,
'values': [0.5400000214576721, 14.0],
'weights': [1.0, -1.0],
'wvalues': [0.5400000214576721, -14.0]},
'id': 264,
'is_fitted_': False,
'objectives': ['balanced_accuracy', 'linear_complexity'],
'parent_id': [207],
'program': {'Tree': [{'W': 1.0,
'arg_types': ['ArrayF'],
'center_op': False,
'feature': '',
'fixed': True,
'is_weighted': False,
'name': 'Logistic',
'node_type': 'Logistic',
'prob_change': 0.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 10617925524997611780,
'sig_hash': 13326223354425868050},
{'W': 8.940696716308594e-08,
'arg_types': ['ArrayF'],
'center_op': False,
'feature': '',
'fixed': True,
'is_weighted': True,
'name': 'OffsetSum',
'node_type': 'OffsetSum',
'prob_change': 0.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 10617925524997611780,
'sig_hash': 13326223354425868050},
{'W': 0.0,
'arg_types': [],
'center_op': True,
'feature': 'MeanLabel',
'fixed': False,
'is_weighted': True,
'name': 'MeanLabel',
'node_type': 'MeanLabel',
'prob_change': 0.19857122004032135,
'ret_type': 'ArrayF',
'sig_dual_hash': 509529941281334733,
'sig_hash': 17717457037689164349}],
'is_fitted_': True},
'variation': 'subtree'}
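Since each archive entry is a plain dict, you can filter the archive with your own criteria. Below is a minimal sketch that picks the least complex individual whose score is within 10% of the best; it uses a small hypothetical archive list (with made-up values) in place of est.archive_, keeping only the fields relevant to the filter:

```python
# Hypothetical archive entries mimicking the structure of est.archive_
archive = [
    {'id': 264, 'fitness': {'loss': 0.54, 'linear_complexity': 14}},
    {'id': 284, 'fitness': {'loss': 0.90, 'linear_complexity': 67}},
    {'id': 260, 'fitness': {'loss': 0.88, 'linear_complexity': 42}},
]

# Best score in the archive (here the scorer is maximized,
# as with balanced_accuracy)
best_loss = max(ind['fitness']['loss'] for ind in archive)

# Keep individuals within 10% of the best, then take the least complex one
candidates = [ind for ind in archive
              if ind['fitness']['loss'] >= 0.9 * best_loss]
simplest = min(candidates, key=lambda ind: ind['fitness']['linear_complexity'])
print(simplest['id'])  # -> 260
```

The same pattern works on the real archive_; just swap in the fitness fields that match the objectives you configured.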
And you can call predict (or predict_proba, if your est is an instance of BrushClassifier) with the entire archive:
est.predict_archive(X)
[{'id': 264,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, False, False])},
{'id': 284,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, False, True, True,
True, True, False, True, True, True, False, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, False, False, False, False,
False, False, False, False, False])},
{'id': 260,
'y_pred': array([False, True, True, True, False, False, True, True, True,
False, False, True, True, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 286,
'y_pred': array([False, True, True, True, True, False, True, True, True,
True, False, True, True, True, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, True, True, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 276,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
False, False, True, True, True, False, True, False, False,
True, True, True, True, False, False, True, False, True,
False, True, False, True, False, False, True, False, False,
True, True, False, False, False, True, False, False, False,
False, False, False, False, True])},
{'id': 251,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
False, False, True, True, True, False, False, False, False,
True, True, True, True, False, False, True, False, False,
False, False, False, True, False, False, True, False, False,
True, True, False, False, False, True, False, False, False,
False, False, False, False, True])},
{'id': 227,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
False, False, True, True, True, False, True, False, False,
True, True, True, True, False, True, True, False, False,
True, False, False, True, False, False, True, False, False,
True, True, False, False, False, False, False, False, False,
False, False, False, False, True])},
{'id': 261,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, False, True, True,
True, True, False, True, True, True, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, False, False, True, False, False,
False, False, False, False, False])},
{'id': 247,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, False, True, True,
True, True, False, True, True, True, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 228,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, True, True, True,
True, True, False, True, True, True, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 230,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, True, True, True,
True, True, True, True, True, True, True, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, False, False, False, False, False,
False, False, False, False, False])}]
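Each entry returned by predict_archive pairs an individual's id with its predictions, so it is straightforward to collect them into a single DataFrame with one column per individual. A sketch using hypothetical predictions in place of the real output:

```python
import numpy as np
import pandas as pd

# Hypothetical predict_archive output: one dict per archive individual
preds = [
    {'id': 264, 'y_pred': np.array([True, True, False])},
    {'id': 284, 'y_pred': np.array([True, False, False])},
]

# One column per individual, rows aligned with the samples in X
pred_df = pd.DataFrame({p['id']: p['y_pred'] for p in preds})
print(pred_df.shape)  # -> (3, 2)
```

This makes it easy to compare archive members sample by sample, or to build an ensemble by majority vote across columns.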
est.predict_proba_archive(X)
[{'id': 264,
'y_pred': array([0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
0.50000006, 0.50000006, 0.50000006, 0.5 , 0.5 ],
dtype=float32)},
{'id': 284,
'y_pred': array([1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 0.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
0.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9985957e-01,
7.8951918e-37, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
0.0000000e+00, 1.2521880e-21, 1.2202326e-11, 1.0412302e-27,
8.6581284e-34, 0.0000000e+00, 8.4371766e-24, 3.1038638e-24,
0.0000000e+00, 0.0000000e+00], dtype=float32)},
{'id': 260,
'y_pred': array([1.0081212e-35, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
1.2496071e-02, 5.2254774e-36, 1.0000000e+00, 1.0000000e+00,
1.0000000e+00, 7.0163466e-14, 0.0000000e+00, 9.9999726e-01,
1.0000000e+00, 1.3096678e-11, 5.5269798e-28, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 1.2658817e-27, 7.8919543e-26,
1.5775963e-37, 0.0000000e+00, 0.0000000e+00, 1.5053514e-09,
3.0856791e-03, 1.2355835e-24, 8.1867808e-36, 0.0000000e+00,
4.9887171e-28, 1.6857511e-26, 1.5835223e-38, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00], dtype=float32)},
{'id': 286,
'y_pred': array([0.40833515, 0.9820269 , 0.9980566 , 0.9287328 , 0.67898846,
0.40701428, 0.90484595, 0.9760761 , 0.8018543 , 0.5646876 ,
0.38469985, 0.7565818 , 0.841811 , 0.58659005, 0.45134798,
0.35905573, 0.3723781 , 0.38076416, 0.3690718 , 0.36326265,
0.35882685, 0.36579695, 0.3663163 , 0.36100355, 0.3592465 ,
0.36772186, 0.4536888 , 0.4658134 , 0.40029088, 0.3931104 ,
0.38719437, 0.60710794, 0.6725419 , 0.47429875, 0.40791473,
0.3677797 , 0.45106068, 0.4611968 , 0.39615962, 0.375852 ,
0.35844555, 0.3599145 , 0.3603538 , 0.3596472 , 0.35938004,
0.3583884 , 0.35981902, 0.35979992, 0.3589222 , 0.35859805],
dtype=float32)},
{'id': 276,
'y_pred': array([7.8108364e-01, 7.4068010e-01, 7.4068010e-01, 7.4068010e-01,
7.4068010e-01, 7.6991338e-01, 7.4068010e-01, 7.4068010e-01,
7.4068010e-01, 4.2578113e-01, 3.9292258e-01, 7.4068010e-01,
7.4068010e-01, 8.1273812e-01, 3.1526712e-01, 5.8899599e-01,
1.8576130e-01, 4.1078791e-01, 7.4497253e-01, 6.1298168e-01,
7.7351385e-01, 6.2793267e-01, 2.5962147e-01, 7.3486399e-03,
7.7723646e-01, 4.9020657e-01, 5.6600595e-01, 2.5054170e-03,
5.1463062e-01, 3.0869538e-01, 7.8104508e-01, 3.8264209e-01,
5.0942674e-02, 9.2560732e-01, 4.2246539e-02, 4.5395932e-01,
8.4166944e-01, 7.4169952e-01, 3.3088744e-01, 1.4896183e-01,
4.3991685e-01, 9.8073673e-01, 2.6703531e-01, 4.8349547e-01,
2.5045840e-20, 4.7675678e-01, 3.6594018e-01, 1.0172682e-14,
4.0319046e-01, 9.9932098e-01], dtype=float32)},
{'id': 251,
'y_pred': array([0.6278971 , 0.60289556, 0.60289556, 0.60289556, 0.60289556,
0.62100005, 0.60289556, 0.60289556, 0.60289556, nan,
nan, 0.60289556, 0.60289556, 0.6476004 , nan,
0.4936501 , nan, nan, 0.6055672 , 0.51452976,
0.62322253, 0.5264537 , nan, nan, 0.6255207 ,
nan, 0.47062764, nan, 0.38050255, nan,
0.62787324, nan, nan, 0.7290232 , nan,
nan, 0.66614634, 0.60353065, nan, nan,
nan, 0.80018765, nan, nan, nan,
nan, nan, nan, nan, 0.891787 ],
dtype=float32)},
{'id': 227,
'y_pred': array([0.6689781 , 0.6821383 , 0.6821383 , 0.6821383 , 0.6821383 ,
0.66368186, 0.6821383 , 0.6821383 , 0.6821383 , 0.35940248,
0.3313259 , 0.6821383 , 0.6821383 , 0.6796199 , 0.27467927,
0.5172356 , 0.22523105, 0.34633443, 0.6495391 , 0.54054886,
0.6654655 , 0.55476266, 0.24482079, 0.67573524, 0.667234 ,
0.41946912, 0.49450168, 0.50200063, 0.4434057 , 0.2706287 ,
0.66896105, 0.32298976, 0.40061048, 0.59498376, 0.44833356,
0.38498515, 0.68186295, 0.64747214, 0.28481302, 0.2293548 ,
0.37208086, 0.287013 , 0.24817689, 0.41297618, 0.23708849,
0.40650123, 0.30995077, 0.27102837, 0.33987376, 0.64190036],
dtype=float32)},
{'id': 261,
'y_pred': array([0.3447548 , 0.58851355, 0.58851355, 0.58851355, 0.58851355,
0.3447548 , 0.58851355, 0.58851355, 0.58851355, 0.3447548 ,
0.9996426 , 0.58851355, 0.58851355, 1. , 0.3447548 ,
0.40058893, 0.9812676 , 0.9987348 , 0.946967 , 0.7274303 ,
0.38259757, 0.8596523 , 0.8789301 , 0.5594563 , 0.41579157,
0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 ,
0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 ,
0.3447548 , 0.3447548 , 0.3447548 , 0.9999908 , 0.3447548 ,
0.35333464, 0.47008803, 0.50624126, 0.4482104 , 0.42653137,
0.3490325 , 0.46225694, 0.46069285, 0.390058 , 0.3649223 ],
dtype=float32)},
{'id': 247,
'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
0.44111538, 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
0.44111538, 0.7915279 , 0.7915279 , 0.7915279 , 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538],
dtype=float32)},
{'id': 228,
'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
0.5197103 , 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
nan, 0.7915279 , 0.7915279 , 0.7915279 , nan,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
0.37007034, nan, nan, 0.41278994, nan,
0.4062064 , 0.29336157, nan, 0.33311623, nan],
dtype=float32)},
{'id': 230,
'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
0.51724 , 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
0.6927554 , 0.7915279 , 0.7915279 , 0.7915279 , 0.69658685,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
0.3719757 , nan, 0.23363459, 0.41291445, nan,
0.40643886, 0.30856097, nan, 0.3394832 , nan],
dtype=float32)}]
Loading a specific model from archive#
The static class method from_json lets you load an individual from its JSON representation stored in the archive.
from pybrush import individual
loaded_from_arch = individual.ClassifierIndividual.from_json(est.archive_[-1])
print(loaded_from_arch.get_model())
print(loaded_from_arch.fitness)
Logistic(Sum(-0.24,If(AIDS>15890.50,1.00,If(Total>1572255.50,0.00,Asin(If(AIDS>123.00,1.00,1.00*Atan(Tan(AIDS))))))))
Fitness(0.900000 67.000000 )
To use this loaded model for predictions, you need to wrap the data in a Dataset:
from pybrush import Dataset
loaded_from_arch.predict(Dataset(X=X, ref_dataset=est.data_,
feature_names=est.feature_names_))
array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, True, True, True,
True, True, True, True, True, True, True, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, False, False, False, False, False,
False, False, False, False, False])
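Because archive entries are JSON-serializable dicts, you can also persist the whole archive with the standard json module and reload entries later to pass to from_json. A sketch with a stand-in entry and an example file name:

```python
import json

# Stand-in for est.archive_ (real entries hold the full fitness/program dicts)
archive = [{'id': 1, 'fitness': {'loss': 0.5}}]

# Save the archive to disk...
with open('archive.json', 'w') as f:
    json.dump(archive, f)

# ...and load it back in a later session
with open('archive.json') as f:
    restored = json.load(f)
print(restored[0]['id'])  # -> 1
```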
Visualizing the Pareto front of the archive#
import matplotlib.pyplot as plt
xs, ys = [], []
for ind in est.archive_:
# We should look at the same objectives to get a valid pareto front
xs.append(ind['fitness']['loss'])
ys.append(ind['fitness']['linear_complexity'])
print(len(xs))
plt.scatter(xs, ys, alpha=0.25, c='b', linewidth=1.0)
plt.yscale('log')
plt.xlabel("Loss on validation partition (greater is better)")
plt.ylabel("Complexity (smaller is better)")
11
Text(0, 0.5, 'Complexity (smaller is better)')
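The Pareto-front structure of these points can also be checked programmatically: a point is on the front if no other point is at least as good on both objectives and strictly better on one. A small sketch with hypothetical (loss, complexity) pairs, assuming for simplicity that both objectives are minimized:

```python
# Hypothetical (loss, complexity) points; assume both objectives are minimized
points = [(0.10, 80), (0.20, 14), (0.15, 30), (0.25, 10), (0.30, 50)]

def dominates(a, b):
    """True if a is at least as good as b on both objectives, better on one."""
    return (a[0] <= b[0] and a[1] <= b[1]) and (a[0] < b[0] or a[1] < b[1])

# Keep only non-dominated points
front = [p for p in points if not any(dominates(q, p) for q in points)]
print(sorted(front))
```

Here (0.30, 50) is dominated by (0.15, 30) and drops out, while the remaining four points trade off loss against complexity. For a maximized scorer, flip the inequalities on the first objective.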

Storing the population (unique individuals)#
If use_arch is False, the unique individuals from the final population are stored in archive_ instead. Notice that, while the archive contains only the Pareto front when use_arch=True, the stored population contains all unique individuals, even dominated ones.
est = BrushClassifier(
# functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
use_arch=False,
objectives=["scorer", "linear_complexity"],
max_depth=10,
max_size=100,
max_gens=100,
pop_size=200,
verbosity=1
)
est.fit(X,y)
print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Completed 100% [====================]
Best model: Logistic(Sum(-1.39,If(AIDS>15890.50,17.24,If(Total>1572255.50,If(AIDS>2320.50,11.61*Logistic(If(Age>0.00,13.75,-21.10)),1.00*Mul(If(AIDS>491.50,-39544.36,Total),Add(-10166.09,0.00))),If(AIDS>123.00,If(AIDS>1653.50,If(AIDS>6597.50,15.18,Total),Add(15.18,1.00)),If(AIDS>51.50,-14.40,If(AIDS>20.00,If(Total>168995.50,15.98,-21.11),-14.01)))))))
score: 0.84
xs, ys = [], []
for ind in est.archive_:
# use the same objectives as the search
xs.append(ind['fitness']['loss'])
ys.append(ind['fitness']['linear_complexity'])
print(len(xs))
plt.scatter(xs, ys, alpha=0.25, c='b', linewidth=1.0)
plt.yscale('log')
plt.xlabel("Loss on validation partition (smaller is better)")
plt.ylabel("Complexity (smaller is better)")
42
Text(0, 0.5, 'Complexity (smaller is better)')

(advanced) Interactive plot of archive#
# TODO