Skip to content

Commit fab61b8

Browse files
committed
[python] add weights to loss function: fix tests, update changelog
1 parent 9884571 commit fab61b8

File tree

6 files changed

+27
-20
lines changed

6 files changed

+27
-20
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
**token
12
.vscode/settings.json
23
**.DS_Store
34

python/dalex/NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
## Changelog
22

3+
### development
4+
5+
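* added a way to pass `sample_weight` to loss functions in `model_parts()` (variable importance) using `weights` from `dx.Explainer`; a loss function without a `sample_weight` argument now raises `UserWarning` when weights are provided ([#563](https://github.com/ModelOriented/DALEX/issues/563))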
36

47
### v1.7.0 (2024-02-28)
58

python/dalex/dalex/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .aspect import Aspect
1010

1111

12-
__version__ = '1.7.0'
12+
__version__ = '1.7.0.9000'
1313

1414
__all__ = [
1515
"Arena",

python/dalex/dalex/_global_checks.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import pkg_resources
22
from importlib import import_module
33
from re import search
4-
import numpy as np
5-
import pandas as pd
64

75
# WARNING: below code is parsed by setup.py
86
# WARNING: each dependency should be in new line

python/dalex/dalex/model_explanations/_variable_importance/utils.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,18 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari
5757
sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False)
5858
sampled_data = data.iloc[sampled_rows, :]
5959
observed = y[sampled_rows]
60-
sample_weights = weights[sampled_rows] if weights is not None else None
60+
sample_weight = weights[sampled_rows] if weights is not None else None
6161
else:
6262
sampled_data = data
6363
observed = y
64-
sample_weights = weights
64+
sample_weight = weights
6565

6666
# loss on the full model or when outcomes are permuted
67-
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights)
67+
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weight)
6868

6969
sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False)
70-
sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None
71-
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2)
70+
sample_weight_rows2 = sample_weight[sampled_rows2] if sample_weight is not None else None
71+
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weight_rows2)
7272

7373
loss_features = {}
7474
for variables_set_key in variables:
@@ -79,24 +79,22 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari
7979

8080
predicted = predict(model, ndf)
8181

82-
loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights)
82+
loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weight)
8383

8484
loss_features['_full_model_'] = loss_full
8585
loss_features['_baseline_'] = loss_baseline
8686

8787
return pd.DataFrame(loss_features, index=[0])
8888

8989

def calculate_loss(loss_function, observed, predicted, sample_weight=None):
    """Evaluate `loss_function`, forwarding `sample_weight` when it is supported.

    Parameters
    ----------
    loss_function : callable
        Called as `loss_function(observed, predicted)`. If its signature has a
        parameter named `sample_weight`, the weights are passed by that keyword.
    observed
        First positional argument handed to `loss_function`.
    predicted
        Second positional argument handed to `loss_function`.
    sample_weight : optional
        Weights forwarded to `loss_function` (default `None`).

    Returns
    -------
    Whatever `loss_function` returns.

    Raises
    ------
    UserWarning
        If `sample_weight` is not `None` but `loss_function` has no
        `sample_weight` parameter.
    """
    # Determine if loss function accepts 'sample_weight'
    loss_args = inspect.signature(loss_function).parameters
    if "sample_weight" in loss_args:
        return loss_function(observed, predicted, sample_weight=sample_weight)

    if sample_weight is not None:
        # functools.partial objects and callable instances have no `__name__`;
        # fall back to repr() so the intended UserWarning is raised instead of
        # an unrelated AttributeError.
        loss_name = getattr(loss_function, "__name__", repr(loss_function))
        # NOTE: this is raised as an exception (callers/tests rely on that), so
        # no unweighted loss is actually computed despite the message wording.
        raise UserWarning(f"Loss function `{loss_name}` does not have `sample_weight` argument. Calculating unweighted loss.")
    return loss_function(observed, predicted)

python/dalex/test/test_variable_importance.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,43 @@ def test_loss_after_permutation(self):
5555
variables = {}
5656
for col in self.X.columns:
5757
variables[col] = col
58-
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, rmse,
58+
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, rmse,
5959
variables, 100, np.random)
6060
self.assertIsInstance(lap, pd.DataFrame)
6161
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
6262
lap.columns).all(), np.random)
63+
64+
with self.assertRaises(UserWarning):
65+
lap = utils.loss_after_permutation(self.X, self.y, self.y, self.exp.model, self.exp.predict_function, rmse,
66+
variables, 100, np.random)
67+
self.assertIsInstance(lap, pd.DataFrame)
68+
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
69+
lap.columns).all(), np.random)
6370

6471
variables = {'age': 'age', 'embarked': 'embarked'}
65-
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mad,
72+
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mad,
6673
variables, 10, np.random)
6774
self.assertIsInstance(lap, pd.DataFrame)
6875
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
6976
lap.columns).all())
7077

7178
variables = {'embarked': 'embarked'}
72-
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mae,
79+
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mae,
7380
variables, None, np.random)
7481
self.assertIsInstance(lap, pd.DataFrame)
7582
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
7683
lap.columns).all())
7784

7885
variables = {'age': 'age'}
79-
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, rmse,
86+
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, rmse,
8087
variables, None, np.random)
8188
self.assertIsInstance(lap, pd.DataFrame)
8289
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
8390
lap.columns).all())
8491

8592
variables = {'personal': ['gender', 'age', 'sibsp', 'parch'],
8693
'wealth': ['class', 'fare']}
87-
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mae,
94+
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mae,
8895
variables, None, np.random)
8996
self.assertIsInstance(lap, pd.DataFrame)
9097
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),

0 commit comments

Comments
 (0)