Skip to content

Commit e2122b9

Browse files
authored
BUG: Allow multiple names for vector indicators (kernc#382) (kernc#980)
* BUG: Allow multiple names for vector indicators (kernc#382) Previously we only allowed one name per vector indicator: def _my_indicator(open, close): return tuple( _my_indicator_one(open, close), _my_indicator_two(open, close), ) self.I( _my_indicator, # One name is used to describe two values name="My Indicator", self.data.Open, self.data.Close ) Now, the user can supply two (or more) names to annotate each value individually. The names will be shown in the plot legend. The following is now valid: self.I( _my_indicator, # Two names can now be passed name=["My Indicator One", "My Indicator Two"], self.data.Open, self.data.Close )
1 parent 235e516 commit e2122b9

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

backtesting/_plotting.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,23 @@ def __eq__(self, other):
539539
colors = value._opts['color']
540540
colors = colors and cycle(_as_list(colors)) or (
541541
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
542-
legend_label = LegendStr(value.name)
543-
for j, arr in enumerate(value, 1):
542+
543+
if isinstance(value.name, str):
544+
tooltip_label = value.name
545+
if len(value) == 1:
546+
legend_labels = [LegendStr(value.name)]
547+
else:
548+
legend_labels = [
549+
LegendStr(f"{value.name}[{i}]")
550+
for i in range(len(value))
551+
]
552+
else:
553+
tooltip_label = ", ".join(value.name)
554+
legend_labels = [LegendStr(item) for item in value.name]
555+
556+
for j, arr in enumerate(value):
544557
color = next(colors)
545-
source_name = f'{legend_label}_{i}_{j}'
558+
source_name = f'{legend_labels[j]}_{i}_{j}'
546559
if arr.dtype == bool:
547560
arr = arr.astype(int)
548561
source.add(arr, source_name)
@@ -552,24 +565,24 @@ def __eq__(self, other):
552565
if is_scatter:
553566
fig.circle(
554567
'index', source_name, source=source,
555-
legend_label=legend_label, color=color,
568+
legend_label=legend_labels[j], color=color,
556569
line_color='black', fill_alpha=.8,
557570
radius=BAR_WIDTH / 2 * .9)
558571
else:
559572
fig.line(
560573
'index', source_name, source=source,
561-
legend_label=legend_label, line_color=color,
574+
legend_label=legend_labels[j], line_color=color,
562575
line_width=1.3)
563576
else:
564577
if is_scatter:
565578
r = fig.circle(
566579
'index', source_name, source=source,
567-
legend_label=LegendStr(legend_label), color=color,
580+
legend_label=legend_labels[j], color=color,
568581
radius=BAR_WIDTH / 2 * .6)
569582
else:
570583
r = fig.line(
571584
'index', source_name, source=source,
572-
legend_label=LegendStr(legend_label), line_color=color,
585+
legend_label=legend_labels[j], line_color=color,
573586
line_width=1.3)
574587
# Add dashed centerline just because
575588
mean = float(pd.Series(arr).mean())
@@ -580,9 +593,9 @@ def __eq__(self, other):
580593
line_color='#666666', line_dash='dashed',
581594
line_width=.5))
582595
if is_overlay:
583-
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
596+
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
584597
else:
585-
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
598+
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
586599
# If the sole indicator line on this figure,
587600
# have the legend only contain text without the glyph
588601
if len(value) == 1:

backtesting/backtesting.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def I(self, # noqa: E743
9393
same length as `backtesting.backtesting.Strategy.data`.
9494
9595
In the plot legend, the indicator is labeled with
96-
function name, unless `name` overrides it.
96+
function name, unless `name` overrides it. If `func` returns
97+
multiple arrays, `name` can be a sequence of strings, and
98+
its size must agree with the number of arrays returned.
9799
98100
If `plot` is `True`, the indicator is plotted on the resulting
99101
`backtesting.backtesting.Backtest.plot`.
@@ -118,13 +120,21 @@ def I(self, # noqa: E743
118120
def init():
119121
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
120122
"""
123+
def _format_name(name: str) -> str:
124+
return name.format(*map(_as_str, args),
125+
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
126+
121127
if name is None:
122128
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
123129
func_name = _as_str(func)
124130
name = (f'{func_name}({params})' if params else f'{func_name}')
131+
elif isinstance(name, str):
132+
name = _format_name(name)
133+
elif try_(lambda: all(isinstance(item, str) for item in name), False):
134+
name = [_format_name(item) for item in name]
125135
else:
126-
name = name.format(*map(_as_str, args),
127-
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
136+
raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or '
137+
'`Sequence[str]`')
128138

129139
try:
130140
value = func(*args, **kwargs)
@@ -142,6 +152,11 @@ def init():
142152
if is_arraylike and np.argmax(value.shape) == 0:
143153
value = value.T
144154

155+
if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)):
156+
raise ValueError(
157+
f'Length of `name=` ({len(name)}) must agree with the number '
158+
f'of arrays the indicator returns ({value.shape[0]}).')
159+
145160
if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
146161
raise ValueError(
147162
'Indicators must return (optionally a tuple of) numpy.arrays of same '

backtesting/test/_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,37 @@ def test_resample(self):
771771
# Give browser time to open before tempfile is removed
772772
time.sleep(1)
773773

774+
def test_indicator_name(self):
775+
test_self = self
776+
777+
class S(Strategy):
778+
def init(self):
779+
def _SMA():
780+
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)
781+
782+
test_self.assertRaises(TypeError, self.I, _SMA, name=42)
783+
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", ))
784+
test_self.assertRaises(
785+
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))
786+
787+
for overlay in (True, False):
788+
self.I(SMA, self.data.Close, 5, overlay=overlay)
789+
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
790+
self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay)
791+
self.I(_SMA, overlay=overlay)
792+
self.I(_SMA, name="My SMA", overlay=overlay)
793+
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)
794+
795+
def next(self):
796+
pass
797+
798+
bt = Backtest(GOOG, S)
799+
bt.run()
800+
with _tempfile() as f:
801+
bt.plot(filename=f,
802+
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
803+
open_browser=False)
804+
774805
def test_indicator_color(self):
775806
class S(Strategy):
776807
def init(self):

0 commit comments

Comments
 (0)