{ "cells": [ { "cell_type": "markdown", "id": "b1db8c9a", "metadata": {}, "source": [ "# Generating Models with PyXAI" ] }, { "cell_type": "markdown", "id": "514a5144", "metadata": {}, "source": [ "\n", "The ```Learning``` module of PyXAI provides methods to: \n", "* create a [Scikit-learn](https://scikit-learn.org/stable/) or an [XGBoost](https://xgboost.readthedocs.io/en/stable/) ML classifier;\n", "* create an [XGBoost](https://xgboost.readthedocs.io/en/stable/) or a [LightGBM](https://github.com/microsoft/LightGBM) ML regressor;\n", "* carry out an experimental protocol based on train/test splits (i.e., a cross-validation method); \n", "* get one or several ML models based on Decision Trees, Random Forests or Boosted Trees;\n", "* get, save and load specific instances and models.\n", "\n", "In this page, we detail the first four points, as well as how to get specific instances. For saving and loading instances and models, please see the [Saving/Loading Models](/documentation/saving) page. \n" ] }, { "cell_type": "markdown", "id": "bfb6ec7d", "metadata": {}, "source": [ "## Loading Data" ] }, { "cell_type": "markdown", "id": "ed95f878", "metadata": {}, "source": [ "The first step is to create a ```Learner``` object that provides all the methods needed to generate models. 
To this end, you can use one of these methods, depending on the chosen library:\n", " - ```Learning.Scikitlearn```\n", " - ```Learning.Xgboost```\n", " - ```Learning.LightGBM```" ] }, { "cell_type": "code", "execution_count": 1, "id": "3a86dda4", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:28.599646Z", "iopub.status.busy": "2026-05-15T14:33:28.599555Z", "iopub.status.idle": "2026-05-15T14:33:30.601232Z", "shell.execute_reply": "2026-05-15T14:33:30.600725Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------------- Information ---------------\n", "Problem type: classification\n", "Instances type: tabular\n", "Labels type: classes\n", "\n", "Dataset path: ../dataset/iris.csv\n", "nFeatures (nAttributes, with the labels): 4\n", "nInstances (nObservations): 150\n", "nLabels: 3\n" ] } ], "source": [ "from pyxai import Learning\n", "learner = Learning.Xgboost(\"../dataset/iris.csv\", problem_type=Learning.CLASSIFICATION)" ] }, { "cell_type": "markdown", "id": "32adc48c", "metadata": {}, "source": [ "{: .attention }\n", "> You can launch your program from the command line with the ```-dataset``` option to specify the dataset filename: \n", "> ```console\n", "> python3 example.py -dataset=\"../dataset/iris.csv\"\n", "> ```\n", "> To get the value of the ```-dataset``` option in your program, you need to import the ```Tools``` module:\n", "> ```python\n", "> from pyxai import Learning, Tools\n", "> learner = Learning.Xgboost(Tools.Options.dataset, problem_type=Learning.CLASSIFICATION)\n", "> ```" ] }, { "cell_type": "markdown", "id": "901275d9", "metadata": {}, "source": [ "{: .warning }\n", "> The dataset must contain the column names in the first row and the classes/values in the last column. \n", "> If this is not the case, you must modify your data using the [pandas](https://pandas.pydata.org/docs/index.html) library and pass a ```pandas.DataFrame``` to the functions of the ```Learning``` module. 
In this example, we add the missing column names:\n", "> ```python\n", "> import pandas\n", "> from pyxai import Learning\n", "> data = pandas.read_csv(\"../dataset/iris.data\", names=['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width', 'Class'])\n", "> learner = Learning.Xgboost(data, problem_type=Learning.CLASSIFICATION)\n", "> ```\n", "> \n", "> You can also use the [Preprocessor]({{ site.baseurl }}/documentation/preprocessor/) object of PyXAI, which helps you clean the dataset." ] }, { "cell_type": "markdown", "id": "dbfc19ac", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "id": "5e34e2d2", "metadata": {}, "source": [ "The ```evaluate``` method of the ```Learner``` object trains a classifier/regressor and produces one or several models, according to the chosen cross-validation method and ML model." ] }, { "cell_type": "markdown", "id": "a35011c2", "metadata": {}, "source": [ "Information about cross-validators can be found in the [Scikit-learn](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection) documentation. " ] }, { "cell_type": "markdown", "id": "322e27da", "metadata": {}, "source": [ "In this example, we create 3 boosted trees (classifiers) using the k-fold cross-validator: each model is trained on 100 of the 150 instances and tested on the remaining 50. 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "52a02932", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.602477Z", "iopub.status.busy": "2026-05-15T14:33:30.602370Z", "iopub.status.idle": "2026-05-15T14:33:30.810237Z", "shell.execute_reply": "2026-05-15T14:33:30.809760Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Model creation, fitting and evaluation ---------------\n", "Splitting method: k-folds\n", "Problem type: classification\n", "Models type: boosted-tree\n", "model_parameters: {}\n", "--------- Evaluation Information ---------\n", "For the evaluation number 0:\n", "Metrics:\n", " micro_averaging_accuracy: 97.33333333333334\n", " micro_averaging_precision: 96.0\n", " micro_averaging_recall: 96.0\n", " macro_averaging_accuracy: 97.33333333333333\n", " macro_averaging_precision: 96.02339181286548\n", " macro_averaging_recall: 96.02339181286548\n", " true_positives: {'Iris-setosa': 16, 'Iris-versicolor': 18, 'Iris-virginica': 14}\n", " true_negatives: {'Iris-setosa': 34, 'Iris-versicolor': 30, 'Iris-virginica': 34}\n", " false_positives: {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 1}\n", " false_negatives: {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 1}\n", " accuracy: 96.0\n", " sklearn_confusion_matrix: [[16, 0, 0], [0, 18, 1], [0, 1, 14]]\n", "Number of Training instances: 100\n", "Number of Testing instances: 50\n", "\n", "For the evaluation number 1:\n", "Metrics:\n", " micro_averaging_accuracy: 94.66666666666667\n", " micro_averaging_precision: 92.0\n", " micro_averaging_recall: 92.0\n", " macro_averaging_accuracy: 94.66666666666667\n", " macro_averaging_precision: 91.66666666666666\n", " macro_averaging_recall: 92.85714285714285\n", " true_positives: {'Iris-setosa': 15, 'Iris-versicolor': 13, 'Iris-virginica': 18}\n", " true_negatives: {'Iris-setosa': 34, 'Iris-versicolor': 33, 'Iris-virginica': 29}\n", " false_positives: {'Iris-setosa': 1, 
'Iris-versicolor': 3, 'Iris-virginica': 0}\n", " false_negatives: {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 3}\n", " accuracy: 92.0\n", " sklearn_confusion_matrix: [[15, 0, 0], [1, 13, 0], [0, 3, 18]]\n", "Number of Training instances: 100\n", "Number of Testing instances: 50\n", "\n", "For the evaluation number 2:\n", "Metrics:\n", " micro_averaging_accuracy: 97.33333333333334\n", " micro_averaging_precision: 96.0\n", " micro_averaging_recall: 96.0\n", " macro_averaging_accuracy: 97.33333333333333\n", " macro_averaging_precision: 95.83333333333334\n", " macro_averaging_recall: 96.07843137254902\n", " true_positives: {'Iris-setosa': 19, 'Iris-versicolor': 15, 'Iris-virginica': 14}\n", " true_negatives: {'Iris-setosa': 31, 'Iris-versicolor': 33, 'Iris-virginica': 34}\n", " false_positives: {'Iris-setosa': 0, 'Iris-versicolor': 0, 'Iris-virginica': 2}\n", " false_negatives: {'Iris-setosa': 0, 'Iris-versicolor': 2, 'Iris-virginica': 0}\n", " accuracy: 96.0\n", " sklearn_confusion_matrix: [[19, 0, 0], [0, 15, 2], [0, 0, 14]]\n", "Number of Training instances: 100\n", "Number of Testing instances: 50\n", "\n", "--------------- Explainer ----------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "For the split number 0:\n", "**Boosted Tree model**\n", "NClasses: 3\n", "nTrees: 300\n", "nVariables: 23\n", "\n", "For the split number 1:\n", "**Boosted Tree model**\n", "NClasses: 3\n", "nTrees: 300\n", "nVariables: 25\n", "\n", "For the split number 2:\n", "**Boosted Tree model**\n", "NClasses: 3\n", "nTrees: 300\n", "nVariables: 20\n", "\n" ] } ], "source": [ "models = learner.evaluate(splitting_method=Learning.K_FOLDS, model_type=Learning.BT, splitting_parameters={'n_models': 3, 'random_state': 0})" ] }, { "cell_type": "markdown", "id": "0804030f", "metadata": {}, "source": [ "Beyond carrying out the experimental protocol, this method also returns the models in a dedicated format used to compute explanations. 
" ] }, { "cell_type": "markdown", "id": "03806b56", "metadata": {}, "source": [ "However, this may not meet your requirements (for instance, you may need another ML classifier and/or another cross-validation method):\n", "* Other ML classifiers and cross-validation methods are under development (the goal is to support all cross-validation methods of Scikit-learn); \n", "* You can code your own experimental protocol and then import your models (see the [Importing Models](/documentation/importing) page)." ] }, { "cell_type": "markdown", "id": "68ec59f2", "metadata": {}, "source": [ "## Selecting Instances" ] }, { "cell_type": "markdown", "id": "b37a5641", "metadata": {}, "source": [ "The ```get_instances``` method makes it easy to select specific instances." ] }, { "cell_type": "markdown", "id": "96947e99", "metadata": {}, "source": [ "More details on the ```indexes```, ```save_directory``` and ```instances_id``` parameters are given in the [Saving/Loading Models](/documentation/saving) page. Let us now look at some examples. \n", "\n", "First, we select a single instance, using the first of the three computed models. With ```n=1```, the method directly returns the instance and its prediction rather than a list."
] }, { "cell_type": "code", "execution_count": 3, "id": "6061cdfa", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.811604Z", "iopub.status.busy": "2026-05-15T14:33:30.811482Z", "iopub.status.idle": "2026-05-15T14:33:30.816778Z", "shell.execute_reply": "2026-05-15T14:33:30.816375Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 1\n", "----------------------------------------------\n", "Sepal.Length 5.1\n", "Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 0, dtype: float64 Iris-setosa\n" ] } ], "source": [ "instance, prediction = learner.get_instances(models[0], n=1)\n", "print(instance, prediction)" ] }, { "cell_type": "markdown", "id": "ba6142ad", "metadata": {}, "source": [ "Now, we select 3 instances. In this case, we obtain a list of (instance, prediction) pairs. " ] }, { "cell_type": "code", "execution_count": 4, "id": "2cfea8dc", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.817926Z", "iopub.status.busy": "2026-05-15T14:33:30.817807Z", "iopub.status.idle": "2026-05-15T14:33:30.823030Z", "shell.execute_reply": "2026-05-15T14:33:30.822660Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 3\n", "----------------------------------------------\n", "(Sepal.Length 5.1\n", "Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 0, dtype: float64, 'Iris-setosa')\n", "(Sepal.Length 4.9\n", "Sepal.Width 3.0\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 1, dtype: float64, 'Iris-setosa')\n", "(Sepal.Length 4.7\n", "Sepal.Width 3.2\n", "Petal.Length 1.3\n", "Petal.Width 0.2\n", "Name: 2, dtype: float64, 'Iris-setosa')\n" ] } ], "source": [ "instances = learner.get_instances(models[0], n=3)\n", "for instance in instances:\n", "    print(instance)" ] 
}, { "cell_type": "markdown", "id":
"192c5161", "metadata": {}, "source": [ "The same call without the model as a parameter returns the same instances, but without their predictions: ```None``` is given in place of each prediction." ] }, { "cell_type": "code", "execution_count": 5, "id": "8dc00ef4", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.824114Z", "iopub.status.busy": "2026-05-15T14:33:30.824003Z", "iopub.status.idle": "2026-05-15T14:33:30.827311Z", "shell.execute_reply": "2026-05-15T14:33:30.826900Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 3\n", "----------------------------------------------\n", "[(Sepal.Length 5.1\n", "Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 0, dtype: float64, None), (Sepal.Length 4.9\n", "Sepal.Width 3.0\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 1, dtype: float64, None), (Sepal.Length 4.7\n", "Sepal.Width 3.2\n", "Petal.Length 1.3\n", "Petal.Width 0.2\n", "Name: 2, dtype: float64, None)]\n" ] } ], "source": [ "instances = learner.get_instances(n=3)\n", "print(instances)" ] }, { "cell_type": "markdown", "id": "dab333ec", "metadata": {}, "source": [ "Now, we select 3 instances that the model predicts as Iris-setosa, using the ```subset_predicted_classes``` parameter."
] }, { "cell_type": "code", "execution_count": 6, "id": "ae29f933", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.828308Z", "iopub.status.busy": "2026-05-15T14:33:30.828204Z", "iopub.status.idle": "2026-05-15T14:33:30.835525Z", "shell.execute_reply": "2026-05-15T14:33:30.835170Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 3\n", "----------------------------------------------\n", "[(Sepal.Length 5.1\n", "Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 0, dtype: float64, 'Iris-setosa'), (Sepal.Length 4.9\n", "Sepal.Width 3.0\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 1, dtype: float64, 'Iris-setosa'), (Sepal.Length 4.7\n", "Sepal.Width 3.2\n", "Petal.Length 1.3\n", "Petal.Width 0.2\n", "Name: 2, dtype: float64, 'Iris-setosa')]\n" ] } ], "source": [ "instances = learner.get_instances(models[0], n=3, subset_predicted_classes=[\"Iris-setosa\"])\n", "print(instances)" ] }, { "cell_type": "markdown", "id": "91a88f29", "metadata": {}, "source": [ "Next, we ask for 3 instances that the model predicts as Iris-virginica and for which this prediction is wrong (```is_correct=False```), i.e., the prediction returned by the model differs from the label in the dataset. Only one instance meets these criteria, so the returned list contains a single element."
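, "\n", "\n", "Conceptually, combining ```subset_predicted_classes``` with ```is_correct=False``` keeps a row only if its predicted class belongs to the requested subset and its prediction differs from its label. The pandas sketch below illustrates this logic on a small hypothetical table. The ```label``` and ```prediction``` columns are assumptions made for illustration, and the code is not PyXAI's internal implementation:\n", "\n", "```python\n", "import pandas as pd\n", "\n", "# Hypothetical toy table (not the Iris data): one row per instance,\n", "# with its label in the dataset and the prediction of the model.\n", "df = pd.DataFrame({\n", "    'label': ['Iris-setosa', 'Iris-versicolor', 'Iris-versicolor', 'Iris-virginica'],\n", "    'prediction': ['Iris-setosa', 'Iris-virginica', 'Iris-versicolor', 'Iris-virginica'],\n", "})\n", "\n", "# Keep rows whose predicted class is in the requested subset...\n", "mask = df['prediction'].isin(['Iris-virginica'])\n", "# ...and whose prediction is wrong (the analogue of is_correct=False).\n", "mask &= df['prediction'] != df['label']\n", "\n", "# Take at most n=3 matching rows; fewer are returned if fewer rows match.\n", "selected = df[mask].head(3)\n", "print(selected.index.tolist())  # [1]\n", "```"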
] }, { "cell_type": "code", "execution_count": 7, "id": "48ae777f", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:30.836621Z", "iopub.status.busy": "2026-05-15T14:33:30.836513Z", "iopub.status.idle": "2026-05-15T14:33:31.018903Z", "shell.execute_reply": "2026-05-15T14:33:31.018476Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Number of instances selected: 1\n", "----------------------------------------------\n", "[(Sepal.Length 6.0\n", "Sepal.Width 2.7\n", "Petal.Length 5.1\n", "Petal.Width 1.6\n", "Name: 83, dtype: float64, 'Iris-virginica')]\n" ] } ], "source": [ "instances = learner.get_instances(models[0], n=3, subset_predicted_classes=[\"Iris-virginica\"], is_correct=False)\n", "print(instances)" ] }, { "cell_type": "markdown", "id": "bf0c55fe", "metadata": {}, "source": [ "To get random instances, we set ```seed=None```: two successive calls then return different instances."
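, "\n", "\n", "The ```seed=None``` argument used in the next cell follows a common convention for pseudo-random sampling: ```None``` means fresh randomness at each call, whereas a fixed seed usually makes the draw reproducible. As an analogy only (this is not PyXAI's code), the pandas method ```DataFrame.sample``` behaves this way through its ```random_state``` parameter:\n", "\n", "```python\n", "import pandas as pd\n", "\n", "# Hypothetical table standing in for a dataset of 150 instances.\n", "df = pd.DataFrame({'row': range(150)})\n", "\n", "# With a fixed seed, the draw is reproducible: both calls select the same rows.\n", "first = df.sample(n=3, random_state=0).index.tolist()\n", "second = df.sample(n=3, random_state=0).index.tolist()\n", "print(first == second)  # True\n", "\n", "# With random_state=None, each call uses fresh randomness, so the rows usually differ.\n", "print(df.sample(n=3, random_state=None).index.tolist())\n", "```"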
] }, { "cell_type": "code", "execution_count": 8, "id": "58b9cd61", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:31.019956Z", "iopub.status.busy": "2026-05-15T14:33:31.019842Z", "iopub.status.idle": "2026-05-15T14:33:31.026340Z", "shell.execute_reply": "2026-05-15T14:33:31.026011Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Number of instances selected: 1\n", "----------------------------------------------\n", "Sepal.Length 6.3\n", "Sepal.Width 2.5\n", "Petal.Length 4.9\n", "Petal.Width 1.5\n", "Name: 72, dtype: float64 Iris-versicolor\n", "--------------- Instances ----------------\n", "Number of instances selected: 1\n", "----------------------------------------------\n", "Sepal.Length 5.8\n", "Sepal.Width 2.7\n", "Petal.Length 5.1\n", "Petal.Width 1.9\n", "Name: 142, dtype: float64 Iris-virginica\n" ] } ], "source": [ "instance, prediction = learner.get_instances(models[0], n=1, seed=None)\n", "print(instance, prediction)\n", "instance, prediction = learner.get_instances(models[0], n=1, seed=None)\n", "print(instance, prediction)" ] }, { "cell_type": "markdown", "id": "a9fc8445", "metadata": {}, "source": [ "With ```details=True```, each selected instance is returned as a dictionary containing the instance, its prediction, its label in the dataset and its index:" ] }, { "cell_type": "code", "execution_count": 9, "id": "5dd67e34", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:31.027452Z", "iopub.status.busy": "2026-05-15T14:33:31.027352Z", "iopub.status.idle": "2026-05-15T14:33:31.031767Z", "shell.execute_reply": "2026-05-15T14:33:31.031418Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 3\n", "----------------------------------------------\n", "[{'instance': Sepal.Length 5.1\n", 
"Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", 
"Name: 0, dtype: float64, 'prediction': 'Iris-setosa', 'label': 'Iris-setosa', 'index': 0}, {'instance': Sepal.Length 4.9\n", "Sepal.Width 3.0\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 1, dtype: float64, 'prediction': 'Iris-setosa', 'label': 'Iris-setosa', 'index': 1}, {'instance': Sepal.Length 4.7\n", "Sepal.Width 3.2\n", "Petal.Length 1.3\n", "Petal.Width 0.2\n", "Name: 2, dtype: float64, 'prediction': 'Iris-setosa', 'label': 'Iris-setosa', 'index': 2}]\n" ] } ], "source": [ "instances = learner.get_instances(models[2], n=3, details=True)\n", "print(instances)" ] }, { "cell_type": "markdown", "id": "9ff0756c", "metadata": {}, "source": [ "Finally, we select 3 instances from the test set associated with ```models[2]``` (```indexes=Learning.TEST```):" ] }, { "cell_type": "code", "execution_count": 10, "id": "58e33698", "metadata": { "execution": { "iopub.execute_input": "2026-05-15T14:33:31.032670Z", "iopub.status.busy": "2026-05-15T14:33:31.032576Z", "iopub.status.idle": "2026-05-15T14:33:31.037170Z", "shell.execute_reply": "2026-05-15T14:33:31.036866Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- Instances ----------------\n", "Number of instances selected: 3\n", "----------------------------------------------\n", "[(Sepal.Length 5.1\n", "Sepal.Width 3.5\n", "Petal.Length 1.4\n", "Petal.Width 0.2\n", "Name: 0, dtype: float64, 'Iris-setosa'), (Sepal.Length 5.4\n", "Sepal.Width 3.9\n", "Petal.Length 1.7\n", "Petal.Width 0.4\n", "Name: 5, dtype: float64, 'Iris-setosa'), (Sepal.Length 4.9\n", "Sepal.Width 3.1\n", "Petal.Length 1.5\n", "Petal.Width 0.1\n", "Name: 9, dtype: float64, 'Iris-setosa')]\n" ] } ], "source": [ "instances = learner.get_instances(models[2], indexes=Learning.TEST, n=3)\n", "print(instances)" ] }, { "cell_type": "markdown", "id": "4812f67a", "metadata": {}, "source": [ "Saving or loading instances is presented in the [Saving/Loading Models](/documentation/saving) page. 
" ] }, { "cell_type": "markdown", "id": "f253addd", "metadata": {}, "source": [ "## A complete example" ] }, { "cell_type": "markdown", "id": "b502929b", "metadata": {}, "source": [ "As you can see, carrying out an experimental protocol requires only a few instructions. The ```Learning``` module makes it easy to obtain the models and instances that we want to explain. " ] }, { "cell_type": "markdown", "id": "49fef8d0", "metadata": {}, "source": [ "```python\n", "from pyxai import Learning\n", "\n", "learner = Learning.Xgboost(\"../dataset/iris.csv\", problem_type=Learning.CLASSIFICATION)\n", "models = learner.evaluate(splitting_method=Learning.K_FOLDS, model_type=Learning.BT)\n", "for model in models:\n", "    instances_with_prediction = learner.get_instances(model, n=10, indexes=Learning.TEST)\n", "    for instance, prediction in instances_with_prediction:\n", "        print(\"instance:\", instance)\n", "        print(\"prediction:\", prediction)\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.7" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }