
Commit 9884571

Issue #563: Use weights in Permutation Feature Importance calculation (#564)
* Pass explainer weights to loss_after_permutation
* Define weights for sampled data
* Function to handle loss functions with or without sample_weight arg
* Replace loss function calls with wrapper
* Add imports
* Avoid ambiguous truth values
* More explicit warning if weights passed but not used in loss calc
1 parent db2ae5d commit 9884571
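
The core of the change is a new `calculate_loss` wrapper that inspects the loss function's signature and only forwards sample weights when the function actually exposes a `sample_weight` parameter, warning otherwise. Below is a minimal standalone sketch of that dispatch idea; the toy loss functions `weighted_mse` and `plain_mae` are illustrative stand-ins, not part of dalex.

    import inspect
    import warnings

    import numpy as np

    def weighted_mse(y_true, y_pred, sample_weight=None):
        # Accepts sample_weight, so the wrapper forwards the weights.
        return float(np.average((y_true - y_pred) ** 2, weights=sample_weight))

    def plain_mae(y_true, y_pred):
        # No sample_weight parameter, so the wrapper falls back to unweighted loss.
        return float(np.mean(np.abs(y_true - y_pred)))

    def calculate_loss(loss_function, observed, predicted, sample_weights=None):
        # Forward weights only if the loss function declares `sample_weight`.
        if "sample_weight" in inspect.signature(loss_function).parameters:
            return loss_function(observed, predicted, sample_weight=sample_weights)
        if sample_weights is not None:
            warnings.warn(
                f"Loss function `{loss_function.__name__}` does not have `sample_weight` "
                "argument. Calculating unweighted loss."
            )
        return loss_function(observed, predicted)

    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([1.5, 2.0, 2.0])
    w = np.array([1.0, 1.0, 10.0])

    print(calculate_loss(weighted_mse, y_true, y_pred, w))  # weighted loss
    print(calculate_loss(plain_mae, y_true, y_pred, w))     # warns, then unweighted loss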

File tree

1 file changed: +26 -6

  • python/dalex/dalex/model_explanations/_variable_importance/utils.py

python/dalex/dalex/model_explanations/_variable_importance/utils.py (+26 -6)

@@ -1,4 +1,6 @@
+import inspect
 import multiprocessing as mp
+import warnings
 from numpy.random import SeedSequence, default_rng

 import numpy as np
@@ -18,15 +20,15 @@ def calculate_variable_importance(explainer,
     if processes == 1:
         result = [None] * B
         for i in range(B):
-            result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.model, explainer.predict_function,
+            result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function,
                                                loss_function, variables, N, np.random)
     else:
         # Create number generator for each iteration
         ss = SeedSequence(random_state)
         generators = [default_rng(s) for s in ss.spawn(B)]
         pool = mp.get_context('spawn').Pool(processes)
         result = pool.starmap_async(loss_after_permutation, [
-            (explainer.data, explainer.y, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
+            (explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
             i in range(B)]).get()
         pool.close()

@@ -49,21 +51,24 @@ def calculate_variable_importance(explainer,
     return result, raw_permutations


-def loss_after_permutation(data, y, model, predict, loss_function, variables, N, rng):
+def loss_after_permutation(data, y, weights, model, predict, loss_function, variables, N, rng):
     if isinstance(N, int):
         N = min(N, data.shape[0])
         sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False)
         sampled_data = data.iloc[sampled_rows, :]
         observed = y[sampled_rows]
+        sample_weights = weights[sampled_rows] if weights is not None else None
     else:
         sampled_data = data
         observed = y
+        sample_weights = weights

     # loss on the full model or when outcomes are permuted
-    loss_full = loss_function(observed, predict(model, sampled_data))
+    loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights)

     sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False)
-    loss_baseline = loss_function(observed[sampled_rows2], predict(model, sampled_data))
+    sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None
+    loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2)

     loss_features = {}
     for variables_set_key in variables:
@@ -74,9 +79,24 @@ def loss_after_permutation(data, y, model, predict, loss_function, variables, N,

         predicted = predict(model, ndf)

-        loss_features[variables_set_key] = loss_function(observed, predicted)
+        loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights)

     loss_features['_full_model_'] = loss_full
     loss_features['_baseline_'] = loss_baseline

     return pd.DataFrame(loss_features, index=[0])
+
+
+def calculate_loss(loss_function, observed, predicted, sample_weights=None):
+    # Determine if loss function accepts 'sample_weight'
+    loss_args = inspect.signature(loss_function).parameters
+    supports_weight = "sample_weight" in loss_args
+
+    if supports_weight:
+        return loss_function(observed, predicted, sample_weight=sample_weights)
+    else:
+        if sample_weights is not None:
+            warnings.warn(
+                f"Loss function `{loss_function.__name__}` does not have `sample_weight` argument. Calculating unweighted loss."
+            )
+        return loss_function(observed, predicted)
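
For context, here is a hedged end-to-end sketch of how the weighting is expected to reach this code from the user side, assuming the usual dalex interface (`dx.Explainer(..., weights=...)` and `model_parts()` accepting a callable `loss_function`); the dataset, model, and weights below are made up for illustration.

    import numpy as np
    import pandas as pd
    import dalex as dx
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import mean_squared_error

    rng = np.random.default_rng(0)
    X = pd.DataFrame({"x1": rng.normal(size=200), "x2": rng.normal(size=200)})
    y = 3 * X["x1"] + rng.normal(size=200)
    w = rng.uniform(0.5, 2.0, size=200)  # per-observation weights

    model = RandomForestRegressor(random_state=0).fit(X, y, sample_weight=w)

    # explainer.weights is what the diff above now forwards into loss_after_permutation.
    explainer = dx.Explainer(model, X, y, weights=w, verbose=False)

    # sklearn's mean_squared_error exposes `sample_weight`, so after this commit the
    # permutation losses are computed as weighted losses; a loss without that
    # parameter would instead trigger the new warning and an unweighted loss.
    vi = explainer.model_parts(loss_function=mean_squared_error, random_state=0)
    print(vi.result)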
