Skip to content

Commit 17bae33

Browse files
committed
make create_learner.Rmd precommit compatible (#2824)
1 parent 99fdfcb commit 17bae33

File tree

1 file changed: +16 −13 lines

vignettes/tutorial/create_learner.Rmd

+16-13
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ In addition to the data of the task, we also need the formula that describes wha
170170
We use the function `getTaskFormula()` to extract this from the task.
171171

172172
```{r, code=showFunctionDef(mlr:::trainLearner.classif.lda), eval=FALSE, tidy=TRUE}
173-
trainLearner.classif.lda = function(.learner, .task, .subset, .weights = NULL, ...) {
173+
trainLearner.classif.lda = function(.learner, .task, .subset, .weights = NULL, ...) {
174174
f = getTaskFormula(.task)
175175
MASS::lda(f, data = getTaskData(.task, .subset), ...)
176176
}
@@ -198,10 +198,11 @@ It is pretty much just a straight pass-through of the arguments to the `base::pr
198198
```{r, code=showFunctionDef(mlr:::predictLearner.classif.lda), eval=FALSE, tidy=TRUE, tidy.opts=list(indent=2, width.cutoff=100)}
199199
predictLearner.classif.lda = function(.learner, .model, .newdata, predict.method = "plug-in", ...) {
200200
p = predict(.model$learner.model, newdata = .newdata, method = predict.method, ...)
201-
if(.learner$predict.type == "response")
201+
if (.learner$predict.type == "response") {
202202
return(p$class)
203-
else
203+
} else {
204204
return(p$posterior)
205+
}
205206
}
206207
```
207208

@@ -243,7 +244,7 @@ makeRLearner.regr.earth = function() {
243244
```
244245

245246
```{r, code=showFunctionDef(mlr:::trainLearner.regr.earth), eval=FALSE, tidy=TRUE}
246-
trainLearner.regr.earth = function(.learner, .task, .subset, .weights = NULL, ...) {
247+
trainLearner.regr.earth = function(.learner, .task, .subset, .weights = NULL, ...) {
247248
f = getTaskFormula(.task)
248249
earth::earth(f, data = getTaskData(.task, .subset), ...)
249250
}
@@ -297,29 +298,30 @@ makeRLearner.surv.coxph = function() {
297298
```
298299

299300
```{r, code=showFunctionDef(mlr:::trainLearner.surv.coxph), eval=FALSE, tidy=TRUE, tidy.opts=list(indent=2, width.cutoff=100)}
300-
trainLearner.surv.coxph = function(.learner, .task, .subset, .weights = NULL, ...) {
301+
trainLearner.surv.coxph = function(.learner, .task, .subset, .weights = NULL, ...) {
301302
f = getTaskFormula(.task)
302303
data = getTaskData(.task, subset = .subset)
303304
if (is.null(.weights)) {
304305
mod = survival::coxph(formula = f, data = data, ...)
305-
} else {
306+
} else {
306307
mod = survival::coxph(formula = f, data = data, weights = .weights, ...)
307308
}
308-
if (.learner$predict.type == "prob")
309+
if (.learner$predict.type == "prob") {
309310
mod = attachTrainingInfo(mod, list(surv.range = range(getTaskTargets(.task)[, 1L])))
311+
}
310312
mod
311313
}
312314
```
313315

314316
```{r, code=showFunctionDef(mlr:::predictLearner.surv.coxph), eval=FALSE, tidy=TRUE, tidy.opts=list(indent=2, width.cutoff=100)}
315317
predictLearner.surv.coxph = function(.learner, .model, .newdata, ...) {
316-
if(.learner$predict.type == "response") {
318+
if (.learner$predict.type == "response") {
317319
predict(.model$learner.model, newdata = .newdata, type = "lp", ...)
318320
} else if (.learner$predict.type == "prob") {
319321
surv.range = getTrainingInfo(.model$learner.model)$surv.range
320322
times = seq(from = surv.range[1L], to = surv.range[2L], length.out = 1000)
321323
t(summary(survival::survfit(.model$learner.model, newdata = .newdata,
322-
se.fit = FALSE, conf.int = FALSE), times = times)$surv)
324+
se.fit = FALSE, conf.int = FALSE), times = times)$surv)
323325
} else {
324326
stop("Unknown predict type")
325327
}
@@ -355,7 +357,7 @@ makeRLearner.cluster.FarthestFirst = function() {
355357
```
356358

357359
```{r, code=showFunctionDef(mlr:::trainLearner.cluster.FarthestFirst), eval=FALSE, tidy=TRUE, tidy.opts=list(indent=2, width.cutoff=100)}
358-
trainLearner.cluster.FarthestFirst = function(.learner, .task, .subset, .weights = NULL, ...) {
360+
trainLearner.cluster.FarthestFirst = function(.learner, .task, .subset, .weights = NULL, ...) {
359361
ctrl = RWeka::Weka_control(...)
360362
RWeka::FarthestFirst(getTaskData(.task, .subset), control = ctrl)
361363
}
@@ -459,14 +461,15 @@ This method takes the Learner (`makeLearner()`) `.learner` and the WrappedModel
459461
It must return the predictions in the same format as the `predictLearner()` function.
460462

461463
```{r, code=showFunctionDef(mlr:::getOOBPredsLearner.classif.randomForest), eval=FALSE, tidy=TRUE, tidy.opts=list(indent=2, width.cutoff=100)}
464+
462465
```
463466

464467
# Registering your learner
465468

466469
If your interface code to a new learning algorithm exists only locally, i.e., it is not (yet) merged into `mlr` or does not live in an extra package with a proper namespace you might want to register the new S3 methods to make sure that these are found by, e.g., `listLearners()`.
467470
You can do this as follows:
468471

469-
```{r, eval=FALSE}
472+
```r
470473
registerS3method("makeRLearner", "<awesome_new_learner_class>",
471474
makeRLearner.<awesome_new_learner_class>)
472475
registerS3method("trainLearner", "<awesome_new_learner_class>",
@@ -477,14 +480,14 @@ registerS3method("predictLearner", "<awesome_new_learner_class>",
477480

478481
If you have written more methods, for example in order to extract feature importance values or out-of-bag predictions these also need to be registered in the same manner, for example:
479482

480-
```{r, eval=FALSE}
483+
```r
481484
registerS3method("getFeatureImportanceLearner", "<awesome_new_learner_class>",
482485
getFeatureImportanceLearner.<awesome_new_learner_class>)
483486
```
484487

485488
For the new learner to work with parallelization, you may have to export the new methods explicitly:
486489

487-
```{r, eval=FALSE}
490+
```r
488491
parallelExport("trainLearner.<awesome_new_learner_class>",
489492
"predictLearner.<awesome_new_learner_class>")
490493
```

0 commit comments

Comments (0)