The archive
When you fit a Brush estimator, two new attributes are created: `best_estimator_` and `archive_`.
If you set `use_arch` to `True` when instantiating the estimator, it will store the Pareto front as a list in `archive_`. This Pareto front is always built from the individuals of the final population that are not dominated with respect to the error and complexity objectives.
In case you need more flexibility, the archive will contain the entire final population if `use_arch` is `False`, and you can iterate through this list to select individuals using different criteria. It is also worth remembering that Brush supports different optimization objectives through the `objectives` argument.
Each element from the archive is a serialized individual (JSON object).
import pandas as pd
from pybrush import BrushClassifier
# Read the example dataset and separate the label column from the features.
df = pd.read_csv('../examples/datasets/d_analcatdata_aids.csv')
y = df['target']
X = df.drop(columns='target')

# Estimator settings; use_arch=True makes archive_ hold the Pareto front.
brush_params = dict(
    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],
    use_arch=True,
    max_gens=100,
    verbosity=1,
)
est = BrushClassifier(**brush_params)

# Train on the full dataset, then predict and score on the same data.
est.fit(X, y)
y_pred = est.predict(X)
print('score:', est.score(X, y))
Completed 100% [====================]
score: 0.7
You can inspect individuals from the archive by index:
# Each archive entry is a dict, so len() here counts the top-level keys of the
# first individual -- not the number of individuals stored in the archive.
print(len(est.archive_[0]))
# Display the first archived individual (fitness, id, objectives, parent_id, program).
est.archive_[0]
5
{'fitness': {'complexity': 80,
'crowding_dist': 0.0,
'dcounter': 0,
'depth': 3,
'dominated': [],
'loss': 0.5091069936752319,
'loss_v': 0.5091069936752319,
'rank': 1,
'size': 12,
'values': [0.5091069936752319, 12.0],
'weights': [-1.0, -1.0],
'wvalues': [-0.5091069936752319, -12.0]},
'id': 10060,
'objectives': ['error', 'size'],
'parent_id': [9628],
'program': {'Tree': [{'W': 15890.5,
'arg_types': ['ArrayF', 'ArrayF'],
'center_op': True,
'feature': 'AIDS',
'fixed': False,
'is_weighted': False,
'name': 'SplitBest',
'node_type': 'SplitBest',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 9996486434638833164,
'sig_hash': 10001460114883919497},
{'W': 1.0,
'arg_types': ['ArrayF'],
'center_op': True,
'feature': '',
'fixed': False,
'is_weighted': False,
'name': 'Logabs',
'node_type': 'Logabs',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 10617925524997611780,
'sig_hash': 13326223354425868050},
{'W': 2.7182815074920654,
'arg_types': [],
'center_op': True,
'feature': 'Cf',
'fixed': False,
'is_weighted': False,
'name': 'Constant',
'node_type': 'Constant',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 509529941281334733,
'sig_hash': 17717457037689164349},
{'W': 1572255.5,
'arg_types': ['ArrayF', 'ArrayF'],
'center_op': True,
'feature': 'Total',
'fixed': False,
'is_weighted': False,
'name': 'SplitBest',
'node_type': 'SplitBest',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 9996486434638833164,
'sig_hash': 10001460114883919497},
{'W': 0.2222222238779068,
'arg_types': [],
'center_op': True,
'feature': 'MeanLabel',
'fixed': False,
'is_weighted': True,
'name': 'MeanLabel',
'node_type': 'MeanLabel',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 509529941281334733,
'sig_hash': 17717457037689164349},
{'W': 0.5217871069908142,
'arg_types': [],
'center_op': True,
'feature': 'Cf',
'fixed': False,
'is_weighted': False,
'name': 'Constant',
'node_type': 'Constant',
'prob_change': 1.0,
'ret_type': 'ArrayF',
'sig_dual_hash': 509529941281334733,
'sig_hash': 17717457037689164349}],
'is_fitted_': True}}
And you can call `predict` (or `predict_proba`, if your `est` is an instance of `BrushClassifier`) with the entire archive, using the `predict_archive` and `predict_proba_archive` methods:
# Predict with every individual in the archive; the result is a list of dicts,
# each pairing an individual's 'id' with its 'y_pred' array.
est.predict_archive(X)
[{'id': 10060,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, True, True, True,
True, True, True, True, True, True, True, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, True, True, True, True, True,
True, True, True, True, True])},
{'id': 9789,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, True, True, True, True, False, True, True, True,
True, True, True, True, True, True, True, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, False, True, True, True, True, True,
True, True, True, True, True])},
{'id': 10049,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, False, True, True, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 4384,
'y_pred': array([False, True, True, True, True, False, True, True, True,
False, False, True, True, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False])},
{'id': 9692,
'y_pred': array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True])},
{'id': 9552,
'y_pred': array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False])}]
# Same as predict_archive, but each 'y_pred' array holds float probabilities
# instead of boolean labels.
est.predict_proba_archive(X)
[{'id': 10060,
'y_pred': array([0.22222222, 0.9999999 , 0.9999999 , 0.9999999 , 0.9999999 ,
0.22222222, 0.9999999 , 0.9999999 , 0.9999999 , 0.22222222,
0.5217871 , 0.9999999 , 0.9999999 , 0.5217871 , 0.22222222,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,
0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,
0.22222222, 0.22222222, 0.22222222, 0.5217871 , 0.22222222,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ],
dtype=float32)},
{'id': 9789,
'y_pred': array([0.22222222, 0.99994993, 0.99994993, 0.99994993, 0.99994993,
0.22222222, 0.99994993, 0.99994993, 0.99994993, 0.22222222,
0.5217871 , 0.99994993, 0.99994993, 0.5217871 , 0.22222222,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,
0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,
0.22222222, 0.22222222, 0.22222222, 0.5217871 , 0.22222222,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,
0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ],
dtype=float32)},
{'id': 10049,
'y_pred': array([0.39024392, 0.9999999 , 0.9999999 , 0.9999999 , 0.9999999 ,
0.39024392, 0.9999999 , 0.9999999 , 0.9999999 , 0.39024392,
0.39024392, 0.9999999 , 0.9999999 , 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392],
dtype=float32)},
{'id': 4384,
'y_pred': array([0.39024392, 0.9999522 , 0.9999522 , 0.9999522 , 0.9999522 ,
0.39024392, 0.9999522 , 0.9999522 , 0.9999522 , 0.39024392,
0.39024392, 0.9999522 , 0.9999522 , 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,
0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392],
dtype=float32)},
{'id': 9692,
'y_pred': array([0.5317098 , 0.93985564, 0.9835824 , 0.8686745 , 0.68970597,
0.53089285, 0.8455727 , 0.9291562 , 0.7663612 , 0.6237519 ,
0.5169323 , 0.7368382 , 0.794476 , 0.63628834, 0.5578266 ,
0.50047225, 0.50908357, 0.51443684, 0.506959 , 0.50320625,
0.5003231 , 0.50484663, 0.5051821 , 0.50173986, 0.5005965 ,
0.5060892 , 0.5592239 , 0.56642807, 0.5267187 , 0.5222307 ,
0.5185086 , 0.64804167, 0.68591666, 0.5714386 , 0.5314499 ,
0.50612646, 0.5576549 , 0.5636914 , 0.5241404 , 0.5113072 ,
0.50007457, 0.5010315 , 0.5013173 , 0.50085753, 0.50068355,
0.5000373 , 0.50096935, 0.50095695, 0.5003852 , 0.500174 ],
dtype=float32)},
{'id': 9552,
'y_pred': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
dtype=float32)}]