{ "cells": [ { "cell_type": "markdown", "id": "83778a8a", "metadata": {}, "source": [ "# Upstream Usage Example" ] }, { "cell_type": "markdown", "id": "0e52dafd", "metadata": {}, "source": [ "In this example we cover advanced usage of `BoolXAI.RuleClassifier` via upstream classification - i.e., using an ensemble of rule classifiers. " ] }, { "cell_type": "markdown", "id": "c0e96408", "metadata": {}, "source": [ "## Input data\n", "\n", "We'll start with the same binarized data we used in the Basic Usage Example. In order to speed up the execution, we'll only use a subset of the data:" ] }, { "cell_type": "code", "execution_count": 1, "id": "36f1a160", "metadata": { "execution": { "iopub.execute_input": "2023-05-08T23:04:28.645497Z", "iopub.status.busy": "2023-05-08T23:04:28.645186Z", "iopub.status.idle": "2023-05-08T23:04:29.809460Z", "shell.execute_reply": "2023-05-08T23:04:29.808578Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from sklearn import set_config\n", "from sklearn import datasets\n", "from sklearn.metrics import balanced_accuracy_score\n", "\n", "from boolxai import BoolXAI\n", "from util import BoolXAIKBinsDiscretizer\n", "\n", "set_config(transform_output=\"pandas\")\n", "\n", "X, y = datasets.load_breast_cancer(return_X_y=True, as_frame=True)\n", "\n", "# Use a subset of the data to speed up execution.\n", "# For higher quality results, comment these lines out.\n", "X = X.iloc[:100, :100]\n", "y = y.iloc[:100]\n", "\n", "# Binarize the data\n", "binarizer = BoolXAIKBinsDiscretizer(\n", " n_bins=10, strategy=\"quantile\", encode=\"onehot-dense\"\n", ")\n", "X_binarized = binarizer.fit_transform(X)\n", "X_binarized.head();" ] }, { "cell_type": "markdown", "id": "f66af7de", "metadata": {}, "source": [ "## Boosting\n", "\n", "Boosting defines a meta-classifier in which copies of a classifier are trained iteratively such that they focus on the most difficult samples to predict, given the previously trained classifiers. The final result is obtained by combining the weighted predictions from these sub-classifiers. 
Boosting can be used to greatly improve the results provided by a weak learner, such as our highly regularized rule classifiers.\n", "\n", "As a baseline result, we'll train and evaluate a rule classifier without boosting:" ] }, { "cell_type": "code", "execution_count": 2, "id": "6b9b09a5", "metadata": { "execution": { "iopub.execute_input": "2023-05-08T23:04:29.813813Z", "iopub.status.busy": "2023-05-08T23:04:29.813418Z", "iopub.status.idle": "2023-05-08T23:04:48.767991Z", "shell.execute_reply": "2023-05-08T23:04:48.767143Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Without boosting: score=0.91\n" ] } ], "source": [ "seed = 43\n", "rule_classifier = BoolXAI.RuleClassifier(random_state=seed)\n", "rule_classifier.fit(X_binarized, y)\n", "y_predict = rule_classifier.predict(X_binarized)\n", "score = balanced_accuracy_score(y, y_predict)\n", "print(f\"Without boosting: {score=:.2f}\")" ] }, { "cell_type": "markdown", "id": "4b29646b", "metadata": {}, "source": [ "Now we'll use `sklearn`'s `AdaBoostClassifier` to create a boosted classifier with 5 underlying sub-classifiers:" ] }, { "cell_type": "code", "execution_count": 3, "id": "8f225237", "metadata": { "execution": { "iopub.execute_input": "2023-05-08T23:04:48.771978Z", "iopub.status.busy": "2023-05-08T23:04:48.771686Z", "iopub.status.idle": "2023-05-08T23:05:37.996554Z", "shell.execute_reply": "2023-05-08T23:05:37.995175Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "With boosting: score=1.00\n" ] } ], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "\n", "boosted_rule_classifier = AdaBoostClassifier(\n", "    BoolXAI.RuleClassifier(random_state=seed),\n", "    n_estimators=5,\n", "    algorithm=\"SAMME\",\n", "    random_state=seed,\n", ")\n", "boosted_rule_classifier.fit(X_binarized, y)\n", "y_predict = boosted_rule_classifier.predict(X_binarized)\n", "score = balanced_accuracy_score(y, y_predict)\n", "print(f\"With boosting: {score=:.2f}\")" ] }, { "cell_type": "markdown", "id": "50a62922", "metadata": {}, "source": [ "This score is clearly much higher than the one obtained without boosting. Note that this comes at the price of additional training (and inference) time, as well as higher complexity (and hence lower interpretability). Boosted classifiers are also more likely to overfit the data. This can be assessed by comparing performance on out-of-sample data (e.g., via cross-validation). 
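A minimal sketch of such a check (not run here), reusing the `boosted_rule_classifier`, `X_binarized`, and `y` defined above:\n", "\n", "```python\n", "from sklearn.model_selection import cross_val_score\n", "\n", "# Estimate out-of-sample balanced accuracy with 5-fold cross-validation\n", "cv_scores = cross_val_score(\n", "    boosted_rule_classifier, X_binarized, y, cv=5, scoring=\"balanced_accuracy\"\n", ")\n", "print(f\"{cv_scores.mean()=:.2f}, {cv_scores.std()=:.2f}\")\n", "```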
\n", "\n", "We can print the best rule for each of the sub-classifiers:" ] }, { "cell_type": "code", "execution_count": 4, "id": "ed962602", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AtLeast1([worst perimeter<77.186], [compactness error<0.0105], [0.0546<=mean compactness<0.0719], [0.0074<=concave points error<0.0087], [area error<16.119])\n", "Or([0.0049<=smoothness error<0.0055], [565.14<=worst area<648.84], [365.18<=mean area<448.1], [0.0719<=mean compactness<0.0871], [worst radius<12.199])\n", "AtLeast1([87.668<=worst perimeter<96.489], Choose1([perimeter error<1.5335], [15.614<=worst radius<16.5]), [0.2212<=worst concavity<0.2656])\n", "And(~[0.1908<=worst compactness<0.2131], ~[0.0189<=compactness error<0.0245], ~[0.156<=worst concave points<0.1726], ~[worst area>=1713.1], ~[0.034<=concavity error<0.0388])\n", "Or([mean area<365.18], [21.328<=worst texture<23.204], [worst texture<19.063], [0.0899<=worst concavity<0.1394])\n" ] } ], "source": [ "# Print the best rule for each of the sub-classifiers inside boosted_rule_classifier\n", "for subclassifier in boosted_rule_classifier.estimators_:\n", " print(subclassifier.best_rule_.to_str(X_binarized.columns))" ] }, { "cell_type": "markdown", "id": "d07a0911", "metadata": {}, "source": [ "They are clearly very different from one another. We can also score each of the sub-classifiers separately:" ] }, { "cell_type": "code", "execution_count": 5, "id": "29de7eb3", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "score=0.88\n", "score=0.87\n", "score=0.69\n", "score=0.74\n", "score=0.80\n" ] } ], "source": [ "for subclassifier in boosted_rule_classifier.estimators_:\n", " y_predict = subclassifier.predict(X_binarized)\n", " score = balanced_accuracy_score(y, y_predict)\n", " print(f\"{score=:.2f}\")" ] }, { "cell_type": "markdown", "id": "c054f3b2", "metadata": {}, "source": [ "and inspect the weight given to each of the sub-classifiers:" ] }, { "cell_type": "code", "execution_count": 6, "id": "7204d7d7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2.19722458, 2.47293046, 1.64179982, 1.61359288, 2.40632095])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "boosted_rule_classifier.estimator_weights_" ] }, { "cell_type": "markdown", "id": "9c8bc506", "metadata": {}, "source": [ "## Multiclass classification\n", "\n", "In this case instead of having two classes (say 0 and 1), we have more than two classes. `sklearn` provides several ways of converting a binary classifier to a multiclass classifier. We choose to use `OneVsRestClassifier` rather than `OneVsOneClassifier` or `OutputCodeClassifier` since it is far more interpretable. 
`OneVsRestClassifier` trains a subclassifier for each class on a binary classification task defined by labels consisting of the chosen class, and a label including all other classes. Each class therefore gets its own rule, whereas `OneVsOneClassifier` would train a subclassifier for every pair of classes.\n", "\n", "First, we load a multiclass classification dataset:" ] }, { "cell_type": "code", "execution_count": 7, "id": "28cabda7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unique labels:\n" ] }, { "data": { "text/plain": [ "array([0, 1, 2])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn import datasets\n", "\n", "X, y = datasets.load_iris(return_X_y=True, as_frame=True)\n", "\n", "print(\"Unique labels:\")\n", "np.unique(y)" ] }, { "cell_type": "code", "execution_count": 8, "id": "1403a732", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(150, 4)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)
05.13.51.40.2
14.93.01.40.2
24.73.21.30.2
34.63.11.50.2
45.03.61.40.2
\n", "
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", "0 5.1 3.5 1.4 0.2\n", "1 4.9 3.0 1.4 0.2\n", "2 4.7 3.2 1.3 0.2\n", "3 4.6 3.1 1.5 0.2\n", "4 5.0 3.6 1.4 0.2" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Inspect the data\n", "print(X.shape)\n", "X.head()" ] }, { "cell_type": "markdown", "id": "cfd5dc2e", "metadata": {}, "source": [ "We binarize the data as in the Basic Usage Example, but this time we use a smaller number of bins:" ] }, { "cell_type": "code", "execution_count": 9, "id": "1364f4c9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(150, 16)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
[sepal length (cm)<5.1][5.1<=sepal length (cm)<5.8][5.8<=sepal length (cm)<6.4][sepal length (cm)>=6.4][sepal width (cm)<2.8][2.8<=sepal width (cm)<3.0][3.0<=sepal width (cm)<3.3][sepal width (cm)>=3.3][petal length (cm)<1.6][1.6<=petal length (cm)<4.35][4.35<=petal length (cm)<5.1][petal length (cm)>=5.1][petal width (cm)<0.3][0.3<=petal width (cm)<1.3][1.3<=petal width (cm)<1.8][petal width (cm)>=1.8]
00.01.00.00.00.00.00.01.01.00.00.00.01.00.00.00.0
11.00.00.00.00.00.01.00.01.00.00.00.01.00.00.00.0
21.00.00.00.00.00.01.00.01.00.00.00.01.00.00.00.0
31.00.00.00.00.00.01.00.01.00.00.00.01.00.00.00.0
41.00.00.00.00.00.00.01.01.00.00.00.01.00.00.00.0
\n", "
" ], "text/plain": [ " [sepal length (cm)<5.1] [5.1<=sepal length (cm)<5.8] \n", "0 0.0 1.0 \\\n", "1 1.0 0.0 \n", "2 1.0 0.0 \n", "3 1.0 0.0 \n", "4 1.0 0.0 \n", "\n", " [5.8<=sepal length (cm)<6.4] [sepal length (cm)>=6.4] \n", "0 0.0 0.0 \\\n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "\n", " [sepal width (cm)<2.8] [2.8<=sepal width (cm)<3.0] \n", "0 0.0 0.0 \\\n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "\n", " [3.0<=sepal width (cm)<3.3] [sepal width (cm)>=3.3] \n", "0 0.0 1.0 \\\n", "1 1.0 0.0 \n", "2 1.0 0.0 \n", "3 1.0 0.0 \n", "4 0.0 1.0 \n", "\n", " [petal length (cm)<1.6] [1.6<=petal length (cm)<4.35] \n", "0 1.0 0.0 \\\n", "1 1.0 0.0 \n", "2 1.0 0.0 \n", "3 1.0 0.0 \n", "4 1.0 0.0 \n", "\n", " [4.35<=petal length (cm)<5.1] [petal length (cm)>=5.1] \n", "0 0.0 0.0 \\\n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "\n", " [petal width (cm)<0.3] [0.3<=petal width (cm)<1.3] \n", "0 1.0 0.0 \\\n", "1 1.0 0.0 \n", "2 1.0 0.0 \n", "3 1.0 0.0 \n", "4 1.0 0.0 \n", "\n", " [1.3<=petal width (cm)<1.8] [petal width (cm)>=1.8] \n", "0 0.0 0.0 \n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Binarize the data\n", "binarizer = BoolXAIKBinsDiscretizer(\n", " n_bins=4, strategy=\"quantile\", encode=\"onehot-dense\"\n", ")\n", "X_binarized = binarizer.fit_transform(X)\n", "X_binarized.head()\n", "print(X_binarized.shape)\n", "X_binarized.head()" ] }, { "cell_type": "markdown", "id": "e827abeb", "metadata": {}, "source": [ "With the data in hand, we can now use `OneVsRestClassifier` to implicitly train multiple rule classifiers, one for each class, combined into a single classifier:" ] }, { "cell_type": "code", "execution_count": 10, "id": "ec7e1236", "metadata": {}, "outputs": [], "source": [ "from sklearn.multiclass import OneVsRestClassifier\n", "\n", "from boolxai import BoolXAI\n", "\n", "# Instantiate a multiclass rule classifier and fit it\n", "multiclass_rule_classifier = OneVsRestClassifier(\n", " BoolXAI.RuleClassifier(random_state=43)\n", ")\n", "multiclass_rule_classifier.fit(X_binarized, y);" ] }, { "cell_type": "markdown", "id": "4df66a0e", "metadata": {}, "source": [ "We can make predictions and calculate scores as usual with the combined classifier:" ] }, { "cell_type": "code", "execution_count": 11, "id": "963723ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "score=0.96\n" ] } ], "source": [ "# Apply Rules\n", "y_predict = multiclass_rule_classifier.predict(X_binarized)\n", "score = balanced_accuracy_score(y, y_predict)\n", "print(f\"{score=:.2f}\")" ] }, { "cell_type": "markdown", "id": "3c57b6ba", "metadata": {}, "source": [ "We can also print out the best rule for used internally by `OneVsRestClassifier` for each of the classes:" ] }, { "cell_type": "code", "execution_count": 12, "id": "ea4e9017", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Or(And(~[petal width (cm)>=1.8], [sepal width (cm)>=3.3]), [petal length (cm)<1.6], [petal width (cm)<0.3])\n", "And(~[petal width (cm)>=1.8], ~[petal length (cm)<1.6], ~[petal width (cm)<0.3], ~[sepal width (cm)>=3.3], ~[petal length (cm)>=5.1])\n", "Or([petal width (cm)>=1.8], [petal length (cm)>=5.1])\n" ] } ], "source": [ "# Print the best rule for each of the sub-classifiers inside multiclass_rule_classifier\n", "for subclassifier in multiclass_rule_classifier.estimators_:\n", 
" print(subclassifier.best_rule_.to_str(X_binarized.columns))" ] }, { "cell_type": "markdown", "id": "c56dbe2f", "metadata": {}, "source": [ "## Multilabel classification\n", "\n", "In the multilabel (multioutput) case, we have labels that consist of multiple binary numbers. We can fit a classifier to each output using `MultiOutputClassifier`. To try this out, we first need some data. We can generate a small synthetic dataset with `make_multilabel_classification()`:" ] }, { "cell_type": "code", "execution_count": 13, "id": "ca80c411", "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_multilabel_classification\n", "\n", "X, y = make_multilabel_classification(\n", " n_classes=3, n_samples=200, n_features=4, random_state=seed\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "d567f8af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unique labels:\n" ] }, { "data": { "text/plain": [ "array([[0, 0, 0],\n", " [0, 0, 1],\n", " [0, 1, 0],\n", " [0, 1, 1],\n", " [1, 0, 0],\n", " [1, 1, 0],\n", " [1, 1, 1]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Unique labels:\")\n", "np.unique(y, axis=0)" ] }, { "cell_type": "markdown", "id": "18607031", "metadata": {}, "source": [ "Let's inspect the data:" ] }, { "cell_type": "code", "execution_count": 15, "id": "3f2fdab2", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(200, 4)\n", "[[10. 23. 13. 2.]\n", " [13. 10. 16. 11.]\n", " [12. 16. 13. 9.]\n", " [16. 18. 20. 7.]\n", " [11. 12. 12. 5.]\n", " [16. 22. 15. 7.]\n", " [10. 14. 24. 17.]\n", " [ 8. 7. 16. 9.]\n", " [21. 13. 15. 10.]\n", " [20. 19. 15. 9.]\n", " [ 6. 11. 14. 15.]\n", " [11. 12. 16. 8.]\n", " [ 7. 10. 13. 6.]\n", " [11. 9. 11. 11.]\n", " [18. 1. 11. 16.]\n", " [15. 14. 17. 15.]\n", " [ 6. 19. 13. 3.]\n", " [13. 11. 12. 15.]\n", " [10. 10. 10. 11.]\n", " [12. 16. 16. 2.]\n", " [ 9. 16. 21. 3.]\n", " [14. 10. 13. 11.]\n", " [13. 16. 14. 4.]\n", " [ 5. 15. 11. 21.]\n", " [ 8. 13. 10. 14.]\n", " [11. 12. 9. 12.]\n", " [14. 28. 18. 2.]\n", " [12. 19. 17. 4.]\n", " [10. 23. 19. 6.]\n", " [18. 10. 16. 10.]\n", " [11. 19. 26. 1.]\n", " [22. 14. 7. 17.]\n", " [10. 26. 22. 5.]\n", " [ 1. 22. 21. 9.]\n", " [ 8. 12. 16. 4.]\n", " [13. 11. 15. 10.]\n", " [ 4. 9. 23. 3.]\n", " [11. 16. 19. 15.]\n", " [16. 12. 18. 12.]\n", " [ 8. 13. 10. 16.]\n", " [11. 24. 18. 2.]\n", " [16. 15. 17. 8.]\n", " [ 7. 9. 16. 13.]\n", " [13. 14. 12. 11.]\n", " [ 8. 19. 12. 1.]\n", " [14. 20. 10. 1.]\n", " [11. 18. 20. 4.]\n", " [ 5. 15. 16. 10.]\n", " [18. 20. 17. 0.]\n", " [12. 14. 11. 13.]\n", " [10. 12. 18. 8.]\n", " [ 9. 20. 21. 4.]\n", " [11. 8. 15. 7.]\n", " [22. 7. 18. 11.]\n", " [14. 23. 12. 1.]\n", " [13. 26. 17. 3.]\n", " [14. 20. 12. 10.]\n", " [13. 13. 18. 11.]\n", " [20. 3. 11. 14.]\n", " [ 9. 15. 13. 2.]\n", " [ 7. 16. 20. 4.]\n", " [16. 20. 24. 2.]\n", " [ 6. 23. 15. 4.]\n", " [ 8. 24. 18. 3.]\n", " [15. 10. 18. 14.]\n", " [16. 16. 11. 2.]\n", " [18. 8. 19. 12.]\n", " [ 8. 10. 8. 7.]\n", " [19. 8. 20. 11.]\n", " [ 7. 15. 8. 5.]\n", " [12. 8. 17. 14.]\n", " [11. 15. 17. 2.]\n", " [18. 18. 17. 0.]\n", " [18. 14. 10. 7.]\n", " [16. 7. 14. 7.]\n", " [15. 18. 8. 2.]\n", " [ 8. 17. 23. 3.]\n", " [12. 12. 13. 6.]\n", " [18. 0. 15. 16.]\n", " [10. 12. 26. 11.]\n", " [11. 15. 13. 3.]\n", " [11. 21. 9. 1.]\n", " [10. 18. 14. 6.]\n", " [15. 9. 11. 14.]\n", " [ 2. 10. 21. 1.]\n", " [17. 13. 12. 7.]\n", " [14. 7. 14. 14.]\n", " [17. 
13. 10. 5.]\n", " [13. 20. 9. 4.]\n", " [10. 25. 19. 1.]\n", " [10. 10. 9. 9.]\n", " [12. 19. 20. 3.]\n", " [ 3. 22. 19. 6.]\n", " [17. 8. 14. 15.]\n", " [ 5. 7. 17. 11.]\n", " [10. 15. 18. 6.]\n", " [11. 7. 18. 9.]\n", " [14. 15. 10. 11.]\n", " [17. 20. 10. 21.]\n", " [11. 14. 19. 14.]\n", " [14. 15. 18. 9.]\n", " [ 9. 23. 20. 8.]\n", " [ 9. 14. 20. 2.]\n", " [16. 12. 14. 7.]\n", " [13. 18. 12. 16.]\n", " [16. 16. 16. 1.]\n", " [ 9. 17. 16. 2.]\n", " [15. 10. 21. 10.]\n", " [13. 11. 18. 15.]\n", " [ 4. 18. 24. 5.]\n", " [15. 20. 11. 2.]\n", " [ 9. 17. 21. 7.]\n", " [12. 23. 11. 5.]\n", " [11. 17. 15. 3.]\n", " [21. 1. 8. 16.]\n", " [16. 0. 14. 16.]\n", " [12. 9. 17. 15.]\n", " [ 7. 23. 20. 4.]\n", " [19. 9. 20. 10.]\n", " [ 7. 13. 11. 6.]\n", " [ 9. 21. 10. 1.]\n", " [ 8. 10. 23. 4.]\n", " [ 9. 21. 14. 2.]\n", " [11. 10. 31. 6.]\n", " [ 3. 19. 25. 2.]\n", " [11. 11. 16. 9.]\n", " [ 8. 22. 13. 5.]\n", " [12. 18. 14. 11.]\n", " [14. 21. 20. 5.]\n", " [17. 12. 23. 10.]\n", " [13. 14. 15. 12.]\n", " [10. 20. 16. 6.]\n", " [11. 16. 17. 6.]\n", " [13. 9. 17. 6.]\n", " [10. 23. 19. 7.]\n", " [14. 11. 14. 13.]\n", " [ 6. 20. 17. 5.]\n", " [12. 20. 20. 2.]\n", " [12. 25. 18. 14.]\n", " [13. 15. 13. 9.]\n", " [12. 21. 12. 2.]\n", " [13. 8. 14. 19.]\n", " [ 7. 9. 18. 10.]\n", " [14. 0. 17. 19.]\n", " [14. 0. 19. 14.]\n", " [ 4. 10. 15. 9.]\n", " [16. 25. 11. 2.]\n", " [ 9. 15. 19. 0.]\n", " [12. 9. 15. 3.]\n", " [ 9. 14. 13. 6.]\n", " [10. 14. 15. 6.]\n", " [15. 15. 11. 1.]\n", " [15. 21. 20. 3.]\n", " [12. 12. 16. 10.]\n", " [11. 13. 16. 8.]\n", " [18. 11. 15. 6.]\n", " [ 9. 22. 13. 5.]\n", " [ 9. 12. 10. 18.]\n", " [14. 12. 17. 12.]\n", " [13. 9. 8. 14.]\n", " [12. 10. 9. 15.]\n", " [19. 20. 15. 13.]\n", " [11. 26. 17. 1.]\n", " [ 9. 17. 5. 2.]\n", " [ 9. 8. 19. 12.]\n", " [10. 25. 11. 0.]\n", " [17. 7. 17. 6.]\n", " [ 9. 22. 19. 0.]\n", " [13. 12. 19. 9.]\n", " [15. 10. 11. 14.]\n", " [10. 13. 8. 7.]\n", " [ 5. 23. 16. 1.]\n", " [19. 18. 22. 10.]\n", " [ 5. 11. 16. 12.]\n", " [12. 13. 14. 7.]\n", " [11. 15. 19. 13.]\n", " [11. 12. 16. 12.]\n", " [14. 10. 16. 8.]\n", " [14. 10. 24. 11.]\n", " [ 5. 10. 14. 12.]\n", " [ 9. 4. 12. 13.]\n", " [19. 0. 17. 14.]\n", " [19. 17. 10. 7.]\n", " [11. 11. 21. 7.]\n", " [12. 9. 18. 8.]\n", " [ 8. 14. 15. 9.]\n", " [17. 2. 14. 21.]\n", " [16. 13. 19. 11.]\n", " [11. 9. 18. 7.]\n", " [12. 18. 13. 1.]\n", " [ 1. 11. 22. 5.]\n", " [11. 20. 18. 3.]\n", " [12. 20. 25. 4.]\n", " [ 8. 23. 14. 4.]\n", " [ 6. 9. 10. 5.]\n", " [12. 27. 13. 3.]\n", " [14. 14. 13. 6.]\n", " [20. 9. 16. 8.]\n", " [10. 19. 14. 3.]\n", " [16. 20. 18. 4.]]\n" ] } ], "source": [ "print(X.shape)\n", "print(X)" ] }, { "cell_type": "markdown", "id": "9a35e106", "metadata": {}, "source": [ "As before, we binarize the data. We'll also give the features names, so that we get more intuitive binarized feature names below:" ] }, { "cell_type": "code", "execution_count": 16, "id": "e9d900a7", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(200, 8)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
[a<12.0][a>=12.0][b<14.0][b>=14.0][c<16.0][c>=16.0][d<7.0][d>=7.0]
01.00.00.01.01.00.01.00.0
10.01.01.00.00.01.00.01.0
20.01.00.01.01.00.00.01.0
30.01.00.01.00.01.00.01.0
41.00.01.00.01.00.01.00.0
\n", "
" ], "text/plain": [ " [a<12.0] [a>=12.0] [b<14.0] [b>=14.0] [c<16.0] [c>=16.0] [d<7.0] \n", "0 1.0 0.0 0.0 1.0 1.0 0.0 1.0 \\\n", "1 0.0 1.0 1.0 0.0 0.0 1.0 0.0 \n", "2 0.0 1.0 0.0 1.0 1.0 0.0 0.0 \n", "3 0.0 1.0 0.0 1.0 0.0 1.0 0.0 \n", "4 1.0 0.0 1.0 0.0 1.0 0.0 1.0 \n", "\n", " [d>=7.0] \n", "0 0.0 \n", "1 1.0 \n", "2 1.0 \n", "3 1.0 \n", "4 0.0 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "X = pd.DataFrame(X)\n", "X.columns = [\"a\", \"b\", \"c\", \"d\"]\n", "\n", "# Binarize the data\n", "binarizer = BoolXAIKBinsDiscretizer(\n", " n_bins=2, strategy=\"quantile\", encode=\"onehot-dense\"\n", ")\n", "X_binarized = binarizer.fit_transform(pd.DataFrame(X))\n", "X_binarized.head()\n", "print(X_binarized.shape)\n", "X_binarized.head()" ] }, { "cell_type": "markdown", "id": "d5b00c6d", "metadata": {}, "source": [ "With the data in hand, we can now use `OneVsRestClassifier` to implicitly train multiple rule classifiers, one for each label, combined into a single classifier:" ] }, { "cell_type": "code", "execution_count": 17, "id": "2716c6e8", "metadata": {}, "outputs": [], "source": [ "from sklearn.multioutput import MultiOutputClassifier\n", "\n", "from boolxai import BoolXAI\n", "\n", "# Instantiate a multilabel rule classifier and fit it\n", "multilabel_rule_classifier = MultiOutputClassifier(\n", " BoolXAI.RuleClassifier(random_state=43)\n", ")\n", "multilabel_rule_classifier.fit(X_binarized, y);" ] }, { "cell_type": "markdown", "id": "428afda8", "metadata": {}, "source": [ "We can make predictions and calculate scores as usual with the combined classifier. Note, however, that `balanced_accuracy_score` does not support multilabel classification, so we switch to `accuracy_score`:" ] }, { "cell_type": "code", "execution_count": 18, "id": "fe301cec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "score=0.38\n" ] } ], "source": [ "from sklearn.metrics import accuracy_score\n", "\n", "# Apply Rules\n", "y_predict = multilabel_rule_classifier.predict(X_binarized)\n", "score = accuracy_score(y, y_predict)\n", "print(f\"{score=:.2f}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "c736c88f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AtMost1(~[a<12.0], [c<16.0], [a>=12.0], ~[d<7.0])\n", "AtLeast1(Choose1(~[d<7.0], [c<16.0], ~[a<12.0]), ~[d>=7.0])\n", "AtMost1(AtMost1(~[b>=14.0], ~[a<12.0]), [d<7.0])\n" ] } ], "source": [ "# Print the best rule for each of the sub-classifiers inside multilabel_rule_classifier\n", "for subclassifier in multilabel_rule_classifier.estimators_:\n", " print(subclassifier.best_rule_.to_str(X_binarized.columns))" ] } ], "metadata": { "kernelspec": { "display_name": "boolxai_test", "language": "python", "name": "boolxai_test" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" } }, "nbformat": 4, "nbformat_minor": 5 }