The archive

When you fit a Brush estimator, two new attributes are created: best_estimator_ and archive_.

If you set use_arch=True when instantiating the estimator, archive_ will store the Pareto front as a list. This front always consists of the individuals from the final population that are non-dominated with respect to the scorer and complexity objectives. Using "scorer" as an objective means optimizing whichever metric is passed, as a string, to the scorer argument.
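The dominance test behind such a front can be sketched in a few lines. This is a minimal illustration, not pybrush's implementation. It assumes, as the archive entry printed later on this page suggests, that each serialized individual stores fitness['wvalues'] with every objective already multiplied by its weight, so that larger is better for every objective. The helpers dominates and pareto_front are hypothetical names.

```python
def dominates(a, b):
    """Return True if serialized individual `a` Pareto-dominates `b`.

    Assumption: fitness['wvalues'] holds weighted objectives where larger is better.
    """
    wa = a["fitness"]["wvalues"]
    wb = b["fitness"]["wvalues"]
    no_worse = all(x >= y for x, y in zip(wa, wb))
    strictly_better = any(x > y for x, y in zip(wa, wb))
    return no_worse and strictly_better


def pareto_front(individuals):
    """Keep only the individuals that no other individual dominates."""
    return [
        ind
        for ind in individuals
        if not any(dominates(other, ind) for other in individuals if other is not ind)
    ]


# Toy individuals: weighted (scorer, -complexity) pairs, weights [1, -1]
toy = [
    {"id": 1, "fitness": {"wvalues": [0.54, -14.0]}},
    {"id": 2, "fitness": {"wvalues": [0.90, -67.0]}},
    {"id": 3, "fitness": {"wvalues": [0.80, -70.0]}},  # worse than id 2 on both
]
print([ind["id"] for ind in pareto_front(toy)])  # → [1, 2]
```

With two objectives, id 1 survives because it is simpler even though its score is lower. id 3 is dropped because id 2 is at least as good on both objectives.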

If you need more flexibility, set use_arch=False. The archive will then contain the entire final population, and you can iterate through this list to select individuals using your own criteria. Brush also supports different optimization objectives through the objectives argument.
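As one example of a custom criterion, the sketch below picks the least complex individual whose weighted scorer value is within a tolerance of the best one. The helper name and the tol default are hypothetical. It assumes that the first entry of fitness['wvalues'] is the weighted scorer and that fitness['linear_complexity'] exists, as in the archive entry printed later on this page, which was produced with objectives=["scorer", "linear_complexity"].

```python
def simplest_within_tolerance(individuals, tol=0.02):
    """Return the lowest-complexity individual whose weighted scorer value
    is within `tol` of the best one.

    Assumptions: fitness['wvalues'][0] is the weighted scorer (larger is
    better) and fitness['linear_complexity'] is the complexity to minimize.
    """
    best = max(ind["fitness"]["wvalues"][0] for ind in individuals)
    candidates = [
        ind for ind in individuals if ind["fitness"]["wvalues"][0] >= best - tol
    ]
    return min(candidates, key=lambda ind: ind["fitness"]["linear_complexity"])


# Toy population (hypothetical values)
pop = [
    {"id": 10, "fitness": {"wvalues": [0.90, -67.0], "linear_complexity": 67}},
    {"id": 11, "fitness": {"wvalues": [0.89, -30.0], "linear_complexity": 30}},
    {"id": 12, "fitness": {"wvalues": [0.54, -14.0], "linear_complexity": 14}},
]
print(simplest_within_tolerance(pop)["id"])  # → 11
```

With a fitted estimator, simplest_within_tolerance(est.archive_) would return one serialized individual, which can then be loaded with from_json.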

Each element from the archive is a serialized individual (JSON object).
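Each entry is a plain dict, so one way to persist an archive is the standard json module. This is a sketch with hypothetical helper names. It assumes the entries contain only JSON-serializable types, which holds for the entry printed below but is not something this page guarantees for every case.

```python
import json
import os
import tempfile


def save_archive(archive, path):
    """Write a list of serialized individuals to a JSON file."""
    with open(path, "w") as f:
        json.dump(archive, f)


def load_archive(path):
    """Read the list of serialized individuals back from disk."""
    with open(path) as f:
        return json.load(f)


# Round-trip a toy entry shaped like the archive entries on this page
toy_archive = [
    {"id": 264, "fitness": {"values": [0.54, 14.0]}, "variation": "subtree"}
]
path = os.path.join(tempfile.mkdtemp(), "archive.json")
save_archive(toy_archive, path)
restored = load_archive(path)
print(restored == toy_archive)  # → True
```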

import pandas as pd
from pybrush import BrushClassifier

# load data
df = pd.read_csv('../examples/datasets/d_analcatdata_aids.csv')
X = df.drop(columns='target')
y = df['target']
est = BrushClassifier(
    # functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    use_arch=True,
    objectives=["scorer", "linear_complexity"],
    scorer='balanced_accuracy', # brush implements several metrics for clf and reg!
    max_gens=100,
    pop_size=100,
    max_depth=10,
    max_size=100,
    verbosity=1
)

est.fit(X,y)

print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Completed 100% [====================]
Best model: Logistic(Sum(-0.24,If(AIDS>15890.50,1.00,If(Total>1572255.50,0.00,Asin(If(AIDS>123.00,1.00,1.00*Atan(Tan(AIDS))))))))
score: 0.9

You can inspect individuals in the archive by index:

print(len(est.archive_[0]))  # number of top-level keys in the serialized individual

est.archive_[0]
7
{'fitness': {'complexity': 80,
  'crowding_dist': 0.0,
  'dcounter': 0,
  'depth': 3,
  'dominated': [],
  'linear_complexity': 14,
  'loss': 0.5400000214576721,
  'loss_v': 0.5400000214576721,
  'rank': 1,
  'size': 5,
  'values': [0.5400000214576721, 14.0],
  'weights': [1.0, -1.0],
  'wvalues': [0.5400000214576721, -14.0]},
 'id': 264,
 'is_fitted_': False,
 'objectives': ['balanced_accuracy', 'linear_complexity'],
 'parent_id': [207],
 'program': {'Tree': [{'W': 1.0,
    'arg_types': ['ArrayF'],
    'center_op': False,
    'feature': '',
    'fixed': True,
    'is_weighted': False,
    'name': 'Logistic',
    'node_type': 'Logistic',
    'prob_change': 0.0,
    'ret_type': 'ArrayF',
    'sig_dual_hash': 10617925524997611780,
    'sig_hash': 13326223354425868050},
   {'W': 8.940696716308594e-08,
    'arg_types': ['ArrayF'],
    'center_op': False,
    'feature': '',
    'fixed': True,
    'is_weighted': True,
    'name': 'OffsetSum',
    'node_type': 'OffsetSum',
    'prob_change': 0.0,
    'ret_type': 'ArrayF',
    'sig_dual_hash': 10617925524997611780,
    'sig_hash': 13326223354425868050},
   {'W': 0.0,
    'arg_types': [],
    'center_op': True,
    'feature': 'MeanLabel',
    'fixed': False,
    'is_weighted': True,
    'name': 'MeanLabel',
    'node_type': 'MeanLabel',
    'prob_change': 0.19857122004032135,
    'ret_type': 'ArrayF',
    'sig_dual_hash': 509529941281334733,
    'sig_hash': 17717457037689164349}],
  'is_fitted_': True},
 'variation': 'subtree'}
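The printed dict is long. To compare all archive members at once, a small helper can pull out a few fields shown above: 'id', 'objectives', and the 'values', 'size' and 'depth' entries of 'fitness'. This is a sketch. The helper summarize is hypothetical and only reads keys visible in the output above. The second toy entry is invented for illustration.

```python
def summarize(archive):
    """Flatten a few fields of each serialized individual into one row each,
    sorted from smallest to largest program size."""
    rows = [
        {
            "id": ind["id"],
            "objectives": tuple(ind["objectives"]),
            "values": tuple(ind["fitness"]["values"]),
            "size": ind["fitness"]["size"],
            "depth": ind["fitness"]["depth"],
        }
        for ind in archive
    ]
    return sorted(rows, key=lambda row: row["size"])


objs = ["balanced_accuracy", "linear_complexity"]
toy_archive = [
    {"id": 1, "objectives": objs,  # hypothetical entry
     "fitness": {"values": [0.9, 67.0], "size": 30, "depth": 8}},
    {"id": 264, "objectives": objs,  # fields copied from the entry above
     "fitness": {"values": [0.54, 14.0], "size": 5, "depth": 3}},
]
for row in summarize(toy_archive):
    print(row)
```

If pandas is available, pd.DataFrame(summarize(est.archive_)) turns these rows into a table.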

You can also get predictions from every individual in the archive at once with predict_archive, or with predict_proba_archive if est is a BrushClassifier:

est.predict_archive(X)
[{'id': 264,
  'y_pred': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True, False, False])},
 {'id': 284,
  'y_pred': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True, False,  True,  True,
          True,  True, False,  True,  True,  True, False,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False])},
 {'id': 260,
  'y_pred': array([False,  True,  True,  True, False, False,  True,  True,  True,
         False, False,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False])},
 {'id': 286,
  'y_pred': array([False,  True,  True,  True,  True, False,  True,  True,  True,
          True, False,  True,  True,  True, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False])},
 {'id': 276,
  'y_pred': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False,  True,  True,  True, False,  True, False, False,
          True,  True,  True,  True, False, False,  True, False,  True,
         False,  True, False,  True, False, False,  True, False, False,
          True,  True, False, False, False,  True, False, False, False,
         False, False, False, False,  True])},
 {'id': 251,
  'y_pred': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False,  True,  True,  True, False, False, False, False,
          True,  True,  True,  True, False, False,  True, False, False,
         False, False, False,  True, False, False,  True, False, False,
          True,  True, False, False, False,  True, False, False, False,
         False, False, False, False,  True])},
 {'id': 227,
  'y_pred': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False,  True,  True,  True, False,  True, False, False,
          True,  True,  True,  True, False,  True,  True, False, False,
          True, False, False,  True, False, False,  True, False, False,
          True,  True, False, False, False, False, False, False, False,
         False, False, False, False,  True])},
 {'id': 261,
  'y_pred': array([False,  True,  True,  True,  True, False,  True,  True,  True,
         False,  True,  True,  True,  True, False, False,  True,  True,
          True,  True, False,  True,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False,  True, False, False, False,  True, False, False,
         False, False, False, False, False])},
 {'id': 247,
  'y_pred': array([False,  True,  True,  True,  True, False,  True,  True,  True,
         False,  True,  True,  True,  True, False, False,  True,  True,
          True,  True, False,  True,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False])},
 {'id': 228,
  'y_pred': array([False,  True,  True,  True,  True, False,  True,  True,  True,
         False,  True,  True,  True,  True, False,  True,  True,  True,
          True,  True, False,  True,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False])},
 {'id': 230,
  'y_pred': array([False,  True,  True,  True,  True, False,  True,  True,  True,
         False,  True,  True,  True,  True, False,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False])}]
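The list returned by predict_archive can be scored member by member. A minimal sketch, assuming the list-of-dicts format shown above. It computes balanced accuracy (the mean of per-class recall) with numpy to stay dependency-light; sklearn.metrics.balanced_accuracy_score computes the same metric. The helper names are hypothetical.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall. Boolean predictions compare equal to 0/1 labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


def score_archive(predictions, y_true):
    """Map each archive id to the balanced accuracy of its predictions.

    `predictions` is assumed to look like the output of est.predict_archive(X).
    """
    return {p["id"]: balanced_accuracy(y_true, p["y_pred"]) for p in predictions}


# Toy labels and predictions (hypothetical)
y_toy = np.array([1, 1, 0, 0])
preds = [
    {"id": 1, "y_pred": np.array([True, True, True, True])},
    {"id": 2, "y_pred": np.array([True, False, False, False])},
]
print(score_archive(preds, y_toy))  # → {1: 0.5, 2: 0.75}
```

On real data, score_archive(est.predict_archive(X), y) would give one score per archive id.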
est.predict_proba_archive(X)
[{'id': 264,
  'y_pred': array([0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.50000006, 0.50000006,
         0.50000006, 0.50000006, 0.50000006, 0.5       , 0.5       ],
        dtype=float32)},
 {'id': 284,
  'y_pred': array([1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 0.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         0.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9985957e-01,
         7.8951918e-37, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         0.0000000e+00, 1.2521880e-21, 1.2202326e-11, 1.0412302e-27,
         8.6581284e-34, 0.0000000e+00, 8.4371766e-24, 3.1038638e-24,
         0.0000000e+00, 0.0000000e+00], dtype=float32)},
 {'id': 260,
  'y_pred': array([1.0081212e-35, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         1.2496071e-02, 5.2254774e-36, 1.0000000e+00, 1.0000000e+00,
         1.0000000e+00, 7.0163466e-14, 0.0000000e+00, 9.9999726e-01,
         1.0000000e+00, 1.3096678e-11, 5.5269798e-28, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 1.2658817e-27, 7.8919543e-26,
         1.5775963e-37, 0.0000000e+00, 0.0000000e+00, 1.5053514e-09,
         3.0856791e-03, 1.2355835e-24, 8.1867808e-36, 0.0000000e+00,
         4.9887171e-28, 1.6857511e-26, 1.5835223e-38, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00], dtype=float32)},
 {'id': 286,
  'y_pred': array([0.40833515, 0.9820269 , 0.9980566 , 0.9287328 , 0.67898846,
         0.40701428, 0.90484595, 0.9760761 , 0.8018543 , 0.5646876 ,
         0.38469985, 0.7565818 , 0.841811  , 0.58659005, 0.45134798,
         0.35905573, 0.3723781 , 0.38076416, 0.3690718 , 0.36326265,
         0.35882685, 0.36579695, 0.3663163 , 0.36100355, 0.3592465 ,
         0.36772186, 0.4536888 , 0.4658134 , 0.40029088, 0.3931104 ,
         0.38719437, 0.60710794, 0.6725419 , 0.47429875, 0.40791473,
         0.3677797 , 0.45106068, 0.4611968 , 0.39615962, 0.375852  ,
         0.35844555, 0.3599145 , 0.3603538 , 0.3596472 , 0.35938004,
         0.3583884 , 0.35981902, 0.35979992, 0.3589222 , 0.35859805],
        dtype=float32)},
 {'id': 276,
  'y_pred': array([7.8108364e-01, 7.4068010e-01, 7.4068010e-01, 7.4068010e-01,
         7.4068010e-01, 7.6991338e-01, 7.4068010e-01, 7.4068010e-01,
         7.4068010e-01, 4.2578113e-01, 3.9292258e-01, 7.4068010e-01,
         7.4068010e-01, 8.1273812e-01, 3.1526712e-01, 5.8899599e-01,
         1.8576130e-01, 4.1078791e-01, 7.4497253e-01, 6.1298168e-01,
         7.7351385e-01, 6.2793267e-01, 2.5962147e-01, 7.3486399e-03,
         7.7723646e-01, 4.9020657e-01, 5.6600595e-01, 2.5054170e-03,
         5.1463062e-01, 3.0869538e-01, 7.8104508e-01, 3.8264209e-01,
         5.0942674e-02, 9.2560732e-01, 4.2246539e-02, 4.5395932e-01,
         8.4166944e-01, 7.4169952e-01, 3.3088744e-01, 1.4896183e-01,
         4.3991685e-01, 9.8073673e-01, 2.6703531e-01, 4.8349547e-01,
         2.5045840e-20, 4.7675678e-01, 3.6594018e-01, 1.0172682e-14,
         4.0319046e-01, 9.9932098e-01], dtype=float32)},
 {'id': 251,
  'y_pred': array([0.6278971 , 0.60289556, 0.60289556, 0.60289556, 0.60289556,
         0.62100005, 0.60289556, 0.60289556, 0.60289556,        nan,
                nan, 0.60289556, 0.60289556, 0.6476004 ,        nan,
         0.4936501 ,        nan,        nan, 0.6055672 , 0.51452976,
         0.62322253, 0.5264537 ,        nan,        nan, 0.6255207 ,
                nan, 0.47062764,        nan, 0.38050255,        nan,
         0.62787324,        nan,        nan, 0.7290232 ,        nan,
                nan, 0.66614634, 0.60353065,        nan,        nan,
                nan, 0.80018765,        nan,        nan,        nan,
                nan,        nan,        nan,        nan, 0.891787  ],
        dtype=float32)},
 {'id': 227,
  'y_pred': array([0.6689781 , 0.6821383 , 0.6821383 , 0.6821383 , 0.6821383 ,
         0.66368186, 0.6821383 , 0.6821383 , 0.6821383 , 0.35940248,
         0.3313259 , 0.6821383 , 0.6821383 , 0.6796199 , 0.27467927,
         0.5172356 , 0.22523105, 0.34633443, 0.6495391 , 0.54054886,
         0.6654655 , 0.55476266, 0.24482079, 0.67573524, 0.667234  ,
         0.41946912, 0.49450168, 0.50200063, 0.4434057 , 0.2706287 ,
         0.66896105, 0.32298976, 0.40061048, 0.59498376, 0.44833356,
         0.38498515, 0.68186295, 0.64747214, 0.28481302, 0.2293548 ,
         0.37208086, 0.287013  , 0.24817689, 0.41297618, 0.23708849,
         0.40650123, 0.30995077, 0.27102837, 0.33987376, 0.64190036],
        dtype=float32)},
 {'id': 261,
  'y_pred': array([0.3447548 , 0.58851355, 0.58851355, 0.58851355, 0.58851355,
         0.3447548 , 0.58851355, 0.58851355, 0.58851355, 0.3447548 ,
         0.9996426 , 0.58851355, 0.58851355, 1.        , 0.3447548 ,
         0.40058893, 0.9812676 , 0.9987348 , 0.946967  , 0.7274303 ,
         0.38259757, 0.8596523 , 0.8789301 , 0.5594563 , 0.41579157,
         0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 ,
         0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 , 0.3447548 ,
         0.3447548 , 0.3447548 , 0.3447548 , 0.9999908 , 0.3447548 ,
         0.35333464, 0.47008803, 0.50624126, 0.4482104 , 0.42653137,
         0.3490325 , 0.46225694, 0.46069285, 0.390058  , 0.3649223 ],
        dtype=float32)},
 {'id': 247,
  'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
         0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
         0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
         0.44111538, 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
         0.44111538, 0.7915279 , 0.7915279 , 0.7915279 , 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538],
        dtype=float32)},
 {'id': 228,
  'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
         0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
         0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
         0.5197103 , 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
                nan, 0.7915279 , 0.7915279 , 0.7915279 ,        nan,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
         0.37007034,        nan,        nan, 0.41278994,        nan,
         0.4062064 , 0.29336157,        nan, 0.33311623,        nan],
        dtype=float32)},
 {'id': 230,
  'y_pred': array([0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.6820835 ,
         0.44111538, 0.6820835 , 0.6820835 , 0.6820835 , 0.44111538,
         0.7915279 , 0.6820835 , 0.6820835 , 0.7915279 , 0.44111538,
         0.51724   , 0.7915279 , 0.7915279 , 0.7915279 , 0.7915279 ,
         0.6927554 , 0.7915279 , 0.7915279 , 0.7915279 , 0.69658685,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.44111538, 0.44111538,
         0.44111538, 0.44111538, 0.44111538, 0.7915279 , 0.44111538,
         0.3719757 ,        nan, 0.23363459, 0.41291445,        nan,
         0.40643886, 0.30856097,        nan, 0.3394832 ,        nan],
        dtype=float32)}]
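Several members above, for example ids 251, 228 and 230, return nan probabilities for some rows. Checking for non-finite outputs before relying on an archive member is therefore worthwhile. A minimal sketch, assuming the list-of-dicts format shown above. The helper name nan_counts is hypothetical.

```python
import numpy as np


def nan_counts(predictions):
    """Return {id: number of NaN entries in y_pred} for archive predictions."""
    return {p["id"]: int(np.isnan(p["y_pred"]).sum()) for p in predictions}


# Toy output shaped like est.predict_proba_archive(X)
toy_preds = [
    {"id": 251, "y_pred": np.array([0.6, np.nan, 0.5], dtype=np.float32)},
    {"id": 227, "y_pred": np.array([0.6, 0.4], dtype=np.float32)},
]
counts = nan_counts(toy_preds)
finite_ids = [i for i, n in counts.items() if n == 0]
print(counts, finite_ids)  # → {251: 1, 227: 0} [227]
```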

Loading a specific model from the archive

The static class method from_json lets you load an individual directly from its serialized (JSON) representation in the archive:

from pybrush import individual

loaded_from_arch = individual.ClassifierIndividual.from_json(est.archive_[-1])
print(loaded_from_arch.get_model())
print(loaded_from_arch.fitness)
Logistic(Sum(-0.24,If(AIDS>15890.50,1.00,If(Total>1572255.50,0.00,Asin(If(AIDS>123.00,1.00,1.00*Atan(Tan(AIDS))))))))
Fitness(0.900000 67.000000 )
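The example above loads the last element of the archive. To load a particular model instead, the ids printed by predict_archive can be used to find its entry first. A minimal sketch with a hypothetical helper. The result could then be passed to individual.ClassifierIndividual.from_json as above.

```python
def find_by_id(archive, ind_id):
    """Return the serialized individual with the given id, or None if absent."""
    return next((ind for ind in archive if ind["id"] == ind_id), None)


# Toy archive (hypothetical entries, trimmed to the 'id' key)
toy_archive = [{"id": 264}, {"id": 284}, {"id": 260}]
print(find_by_id(toy_archive, 284))  # → {'id': 284}
print(find_by_id(toy_archive, 999))  # → None
```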

To make predictions with the loaded model, wrap the data in a Dataset:

from pybrush import Dataset

loaded_from_arch.predict(Dataset(X=X, ref_dataset=est.data_, 
                              feature_names=est.feature_names_))
array([False,  True,  True,  True,  True, False,  True,  True,  True,
       False,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,  True, False, False, False, False, False, False,
       False, False, False, False, False])

Visualizing the Pareto front of the archive

import matplotlib.pyplot as plt

xs, ys = [], []
for ind in est.archive_:
    # Plot the same quantities used as objectives; otherwise the points do not form a valid Pareto front
    xs.append(ind['fitness']['loss'])
    ys.append(ind['fitness']['linear_complexity'])

print(len(xs))
plt.scatter(xs, ys, alpha=0.25, c='b', linewidth=1.0)
plt.yscale('log')
plt.xlabel("Loss on validation partition (greater is better)")
plt.ylabel("Complexity (smaller is better)")
11
[Figure: Pareto front of the archive, with loss on the x-axis and complexity on a logarithmic y-axis.]
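A scatter plot shows the points but not the trade-off between them. Ordering the points by complexity lets you join them into a line. A minimal sketch with a hypothetical helper. It reads the same 'loss' and 'linear_complexity' keys as the plotting code above.

```python
def sorted_front(archive, x_key="loss", y_key="linear_complexity"):
    """Return (xs, ys) read from each individual's fitness dict, ordered by y_key."""
    points = sorted(
        ((ind["fitness"][x_key], ind["fitness"][y_key]) for ind in archive),
        key=lambda p: p[1],
    )
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    return xs, ys


# Toy archive (hypothetical values)
toy_archive = [
    {"fitness": {"loss": 0.9, "linear_complexity": 67}},
    {"fitness": {"loss": 0.54, "linear_complexity": 14}},
]
print(sorted_front(toy_archive))  # → ([0.54, 0.9], [14, 67])
```

With matplotlib, xs, ys = sorted_front(est.archive_) followed by plt.plot(xs, ys, marker="o") draws the connected front.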

Storing the population (unique individuals)

If use_arch is False, archive_ stores the unique individuals of the final population instead. Note that, while the archive contains only the Pareto front when use_arch=True, this list contains every unique individual, including dominated ones.

est = BrushClassifier(
    # functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    use_arch=False,  # store the unique individuals of the final population
    objectives=["scorer", "linear_complexity"],
    max_depth=10,
    max_size=100,
    max_gens=100,
    pop_size=200,
    verbosity=1
)

est.fit(X,y)

print("Best model:", est.best_estimator_.get_model())
print('score:', est.score(X,y))
Completed 100% [====================]
Best model: Logistic(Sum(-1.39,If(AIDS>15890.50,17.24,If(Total>1572255.50,If(AIDS>2320.50,11.61*Logistic(If(Age>0.00,13.75,-21.10)),1.00*Mul(If(AIDS>491.50,-39544.36,Total),Add(-10166.09,0.00))),If(AIDS>123.00,If(AIDS>1653.50,If(AIDS>6597.50,15.18,Total),Add(15.18,1.00)),If(AIDS>51.50,-14.40,If(AIDS>20.00,If(Total>168995.50,15.98,-21.11),-14.01)))))))
score: 0.84
xs, ys = [], []
for ind in est.archive_:
    # plot the same quantities that were used as objectives
    xs.append(ind['fitness']['loss'])
    ys.append(ind['fitness']['linear_complexity'])

print(len(xs))
plt.scatter(xs, ys, alpha=0.25, c='b', linewidth=1.0)
plt.yscale('log')
plt.xlabel("Loss on validation partition (smaller is better)")
plt.ylabel("Complexity (smaller is better)")
42
[Figure: unique individuals of the final population, with loss on the x-axis and complexity on a logarithmic y-axis.]

(advanced) Interactive plot of the archive

# TODO