
Importing Models From Libraries

PyXAI can generate models for you: it provides dedicated functions that simplify this task. However, if your model has already been learned, you may want to import it into PyXAI in order to extract explanations afterwards. This page explains how to perform such a task.

Procedure

Consider the following source code, which creates a RandomForestClassifier using Scikit-learn:

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

model_rf = RandomForestClassifier(random_state=0)
data = datasets.load_breast_cancer(as_frame=True)
X = data.data.to_numpy()
Y = data.target.to_numpy()

feature_names = data.feature_names
model_rf.fit(X, Y)

You can import this ML model with the Learning.import_models() method:

Learning.import_models(models, feature_names=[]):
Imports the models. The method detects the type of the models and applies the appropriate conversions in order to translate them into PyXAI data structures. Returns a tuple (<Learner Object>, models) where the returned models depend on the conversions applied. More precisely, the returned models can be of the form DecisionTree|RandomForest|BoostedTrees|BoostedTreesRegression. Either a single model or a list of models can be given (see the sketch after this description).
models List of RandomForestClassifier|DecisionTreeClassifier|XGBClassifier|XGBRegressor|LGBMRegressor: the models to import.
feature_names List of String: the feature names. If the feature names are not specified, they are replaced by strings starting with ‘f’ followed by a number (e.g., f1,f2,f3,…,f30) in the explanations provided by the to_features() method.
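
Both forms are used later on this page; here is a minimal sketch (model_rf and feature_names come from the code above, while rf_1 and rf_2 stand for hypothetical additional classifiers):

from pyxai import Learning

# A single model: the second element of the returned tuple is one PyXAI model.
learner, model = Learning.import_models(model_rf, feature_names)

# A list of models: the second element is a list of PyXAI models.
learner, models = Learning.import_models([rf_1, rf_2], feature_names)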

Here is a table summarizing the compatibility ensured with three standard ML libraries:

| Type          | Scikit-learn           | XGBoost                     | LightGBM      |
|---------------|------------------------|-----------------------------|---------------|
| Decision Tree | DecisionTreeClassifier |                             |               |
| Random Forest | RandomForestClassifier |                             |               |
| Boosted Tree  |                        | XGBClassifier, XGBRegressor | LGBMRegressor |

The following code imports the random forest trained above:


from pyxai import Tools, Learning, Explainer
learner, model = Learning.import_models(model_rf, feature_names)
---------------   Explainer   ----------------
For the evaluation number 0:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 1755

Then, you can get explanations by executing:

instance, prediction = learner.get_instances(dataset=data.frame, model=model, n=1)
print("instance:", instance)
print("prediction:", prediction)
---------------   Instances   ----------------
data:
     mean radius  mean texture  mean perimeter  mean area  mean smoothness   
0          17.99         10.38          122.80     1001.0          0.11840  \
1          20.57         17.77          132.90     1326.0          0.08474   
2          19.69         21.25          130.00     1203.0          0.10960   
3          11.42         20.38           77.58      386.1          0.14250   
4          20.29         14.34          135.10     1297.0          0.10030   
..           ...           ...             ...        ...              ...   
564        21.56         22.39          142.00     1479.0          0.11100   
565        20.13         28.25          131.20     1261.0          0.09780   
566        16.60         28.08          108.30      858.1          0.08455   
567        20.60         29.33          140.10     1265.0          0.11780   
568         7.76         24.54           47.92      181.0          0.05263   

     mean compactness  mean concavity  mean concave points  mean symmetry   
0             0.27760         0.30010              0.14710         0.2419  \
1             0.07864         0.08690              0.07017         0.1812   
2             0.15990         0.19740              0.12790         0.2069   
3             0.28390         0.24140              0.10520         0.2597   
4             0.13280         0.19800              0.10430         0.1809   
..                ...             ...                  ...            ...   
564           0.11590         0.24390              0.13890         0.1726   
565           0.10340         0.14400              0.09791         0.1752   
566           0.10230         0.09251              0.05302         0.1590   
567           0.27700         0.35140              0.15200         0.2397   
568           0.04362         0.00000              0.00000         0.1587   

     mean fractal dimension  ...  worst texture  worst perimeter  worst area   
0                   0.07871  ...          17.33           184.60      2019.0  \
1                   0.05667  ...          23.41           158.80      1956.0   
2                   0.05999  ...          25.53           152.50      1709.0   
3                   0.09744  ...          26.50            98.87       567.7   
4                   0.05883  ...          16.67           152.20      1575.0   
..                      ...  ...            ...              ...         ...   
564                 0.05623  ...          26.40           166.10      2027.0   
565                 0.05533  ...          38.25           155.00      1731.0   
566                 0.05648  ...          34.12           126.70      1124.0   
567                 0.07016  ...          39.42           184.60      1821.0   
568                 0.05884  ...          30.37            59.16       268.6   

     worst smoothness  worst compactness  worst concavity   
0             0.16220            0.66560           0.7119  \
1             0.12380            0.18660           0.2416   
2             0.14440            0.42450           0.4504   
3             0.20980            0.86630           0.6869   
4             0.13740            0.20500           0.4000   
..                ...                ...              ...   
564           0.14100            0.21130           0.4107   
565           0.11660            0.19220           0.3215   
566           0.11390            0.30940           0.3403   
567           0.16500            0.86810           0.9387   
568           0.08996            0.06444           0.0000   

     worst concave points  worst symmetry  worst fractal dimension  target  
0                  0.2654          0.4601                  0.11890       0  
1                  0.1860          0.2750                  0.08902       0  
2                  0.2430          0.3613                  0.08758       0  
3                  0.2575          0.6638                  0.17300       0  
4                  0.1625          0.2364                  0.07678       0  
..                    ...             ...                      ...     ...  
564                0.2216          0.2060                  0.07115       0  
565                0.1628          0.2572                  0.06637       0  
566                0.1418          0.2218                  0.07820       0  
567                0.2650          0.4087                  0.12400       0  
568                0.0000          0.2871                  0.07039       1  

[569 rows x 31 columns]
--------------   Information   ---------------
Dataset name: pandas.core.frame.DataFrame
nFeatures (nAttributes, with the labels): 31
nInstances (nObservations): 569
nLabels: 2
number of instances selected: 1
----------------------------------------------
instance: [1.799e+01 1.038e+01 1.228e+02 1.001e+03 1.184e-01 2.776e-01 3.001e-01
 1.471e-01 2.419e-01 7.871e-02 1.095e+00 9.053e-01 8.589e+00 1.534e+02
 6.399e-03 4.904e-02 5.373e-02 1.587e-02 3.003e-02 6.193e-03 2.538e+01
 1.733e+01 1.846e+02 2.019e+03 1.622e-01 6.656e-01 7.119e-01 2.654e-01
 4.601e-01 1.189e-01]
prediction: 0
explainer = Explainer.initialize(model, instance=instance)

direct = explainer.direct_reason()
print("len direct reason:", len(direct))

sufficient = explainer.sufficient_reason()
print("len sufficient reason:", len(sufficient))

print("to_features:", explainer.to_features(sufficient))
len direct reason: 294
len sufficient reason: 159
to_features: ('mean radius > 15.045000076293945', 'mean texture <= 11.585000038146973', 'mean perimeter > 96.57999801635742', 'mean area > 694.5', 'mean smoothness > 0.09075499698519707', 'mean compactness > 0.09524999931454659', 'mean concavity > 0.17409999668598175', 'mean concave points > 0.07939000055193901', 'mean symmetry > 0.12639999762177467', 'radius error > 0.7730999886989594', 'texture error > 0.7377500236034393', 'perimeter error > 2.76200008392334', 'area error > 33.064998626708984', 'smoothness error in ]0.005567499902099371, 0.009928999934345484]', 'compactness error > 0.00834800023585558', 'concavity error in ]0.018459999933838844, 0.2157999947667122]', 'fractal dimension error in ]0.0030724999960511923, 0.012140000239014626]', 'worst radius > 17.72499942779541', 'worst texture in ]15.434999942779541, 18.289999961853027]', 'worst perimeter > 120.70000076293945', 'worst area > 953.7000122070312', 'worst smoothness > 0.1363999992609024', 'worst concavity > 0.4524500072002411', 'worst concave points > 0.16029999405145645', 'worst symmetry > 0.37139999866485596', 'worst fractal dimension > 0.10035499930381775')

Giving the feature_names in the Learning.import_models() parameters allows the to_features() method to display the right feature names. If you do not give them, the feature names are of the form f1, f2, f3, …, f30, where the numbers correspond to the ranks of the features in the dataset.
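
As a minimal sketch of this fallback behavior (the conditions displayed in the comment are illustrative, not actual output):

learner, model = Learning.import_models(model_rf)  # no feature_names given
explainer = Explainer.initialize(model, instance=instance)
sufficient = explainer.sufficient_reason()
# to_features() now uses generic names, e.g. ('f1 > 15.045...', 'f2 <= 11.585...', ...)
print("to_features:", explainer.to_features(sufficient))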

You can use learner.get_label_from_value(value) and learner.get_value_from_label(label) to get the right values coming from the encoding of the labels. The Python dictionary learner.dict_labels contains the encoding performed.
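
A minimal sketch of these helpers (the printed encoding depends on the dataset; for breast_cancer, the labels are already 0 and 1, so the mapping is trivial):

print(learner.dict_labels)                       # the encoding performed, as a dictionary
print(learner.get_label_from_value(prediction))  # label corresponding to an encoded value
print(learner.get_value_from_label(0))           # encoded value corresponding to a label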

Load/Save From Libraries

The creation of ML models and the computation of explanations may be done by two different programs: you can save the models with the first one and load them with the second one.

Scikit-learn

We follow the documentation of Scikit-learn, which advises using the pickle module.

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
import pickle

rf = RandomForestClassifier()
X, Y = datasets.load_breast_cancer(return_X_y=True)
rf.fit(X, Y)
with open("example.model", 'wb') as file:
    pickle.dump(rf, file)

You can load this model in another program with these lines of code:

with open("example.model", 'rb') as file:
    learner = pickle.load(file)

And then you can import your model:

from pyxai import Tools, Learning, Explainer
learner, model = Learning.import_models(learner)
---------------   Explainer   ----------------
For the evaluation number 0:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 1675

XGBoost

We follow the documentation of XGBoost.

from sklearn import datasets
from xgboost import XGBClassifier

X, Y = datasets.load_iris(return_X_y=True)
bt = XGBClassifier(eval_metric="mlogloss")
bt.fit(X, Y)
bt.save_model('my_model.json')

You can load this model in another program with these lines of code:

bt_loaded = XGBClassifier(eval_metric='mlogloss')
bt_loaded.load_model('my_model.json')

And then you can import your model:

from pyxai import Tools, Learning, Explainer
learner, model = Learning.import_models(bt_loaded)
---------------   Explainer   ----------------
For the evaluation number 0:
**Boosted Tree model**
NClasses: 3
nTrees: 300
nVariables: 33

LightGBM

The documentation of LightGBM describes how to save/load a Booster object, not an LGBMRegressor. Thus, to save/load an LGBMRegressor, we use pickle.

from sklearn import datasets
import lightgbm
import pickle

X, Y = datasets.load_iris(return_X_y=True)
learner = lightgbm.LGBMRegressor(n_estimators=5, random_state=0)
learner.fit(X, Y)
with open("example.model", 'wb') as file:
    pickle.dump(learner, file)

You can load this model in another program with these lines of code:

with open("example.model", 'rb') as file:
    learner_loaded = pickle.load(file)

And then you can import your model:

from pyxai import Tools, Learning, Explainer
learner, model = Learning.import_models(learner_loaded)
---------------   Explainer   ----------------
For the evaluation number 0:
**Boosted Tree model**
NClasses: None
nTrees: 5
nVariables: 9
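
Note that LightGBM can also save the underlying Booster in its native format; a minimal sketch (this route yields a plain Booster without the scikit-learn wrapper, which is why pickle is used above):

learner.booster_.save_model("example_booster.txt")            # native LightGBM text format
booster = lightgbm.Booster(model_file="example_booster.txt")  # loads a Booster, not an LGBMRegressor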

Example with cross-validation

This example shows how to import models and to compute explanations. We start by implementing a function to process the dataset:

import pandas
import numpy

def load_dataset(dataset):
    data = pandas.read_csv(dataset).copy()

    # extract labels
    labels = data[data.columns[-1]]
    labels = numpy.array(labels)

    # remove the label of each instance
    data = data.drop(columns=[data.columns[-1]])

    # extract the feature names
    feature_names = list(data.columns)

    return data.values, labels, feature_names

Then, we implement a function performing a cross-validation. More precisely, we choose here to use the Leave One Group Out cross-validator of Scikit-learn and the lightgbm.LGBMRegressor of the LightGBM library:

import functools
import random 
import operator
import lightgbm
from sklearn.model_selection import LeaveOneGroupOut

def cross_validation(X, Y, n_trees=100, n_forests=2):
    n_instance = len(Y)
    quotient = n_instance // n_forests
    remain = n_instance % n_forests

    # Groups creation: quotient instances per group, the remaining
    # instances being spread over the first groups
    groups = [quotient*[i] for i in range(1, n_forests+1)]
    groups = functools.reduce(operator.iconcat, groups, [])
    groups += [i for i in range(1, remain+1)]
    random.shuffle(groups)

    # LeaveOneGroupOut yields one split per group, i.e. n_forests splits
    loo = LeaveOneGroupOut()
    forests = []
    for index_training, index_test in loo.split(X, Y, groups=groups):
        # Creation of instances (X) and labels (Y) according to the indexes of loo.split()
        # for both the training and the test sets
        x_train = [X[x] for x in index_training]
        y_train = [Y[x] for x in index_training]
        x_test = [X[x] for x in index_test]
        y_test = [Y[x] for x in index_test]

        # Training phase
        learner = lightgbm.LGBMRegressor(n_estimators=n_trees, random_state=0)
        learner.fit(x_train, y_train)
        # Get the regressor predictions on the test set
        y_predict = learner.predict(x_test)

        forests.append((learner, index_training, index_test))
    return forests

Finally, we use the two previous functions and import the models into PyXAI in order to compute explanations. Note that, since the models are regressors, the explainer computes reasons with respect to an interval around the prediction, set with set_interval(): here, the reasons explain why the prediction lies within [prediction - 0.2, prediction + 0.2].

from pyxai import Tools, Learning, Explainer

data, labels, feature_names = load_dataset("../dataset/winequality-red.csv")
results = cross_validation(data, labels, n_trees=5)

models = [result[0] for result in results]
training_indexes = [result[1] for result in results]
test_indexes = [result[2] for result in results]

learner, models = Learning.import_models(models)

for i, model in enumerate(models):
    instances = learner.get_instances(dataset="../dataset/winequality-red.csv", model=model, n=2, indexes=Learning.TEST, test_indexes=test_indexes[i])

    for (instance, prediction_classifier) in instances:
        explainer = Explainer.initialize(model, instance=instance)
        prediction = model.predict_instance(instance)
        print("prediction:", prediction)
        direct = explainer.direct_reason()
        print("len direct reason:", len(direct))
        explainer.set_interval(prediction - 0.2, prediction + 0.2)
        ts = explainer.tree_specific_reason()
        print("len tree_specific_reason:", len(ts))
        print("---------------------------")
---------------   Explainer   ----------------
For the evaluation number 0:
**Boosted Tree model**
NClasses: None
nTrees: 5
nVariables: 73

For the evaluation number 1:
**Boosted Tree model**
NClasses: None
nTrees: 5
nVariables: 86

---------------   Instances   ----------------
number of instances selected: 2
----------------------------------------------
prediction: 5.344027682956049
len direct reason: 12
len tree_specific_reason: 5
---------------------------
prediction: 5.3536241433926985
len direct reason: 9
len tree_specific_reason: 5
---------------------------
---------------   Instances   ----------------
number of instances selected: 2
----------------------------------------------
prediction: 5.407780216085812
len direct reason: 28
len tree_specific_reason: 9
---------------------------
prediction: 5.54355845301202
len direct reason: 28
len tree_specific_reason: 9
---------------------------

With PyXAI, you can also generate your own models. We invite you to look at the Generating Models page for more information.