Skip to content

Commit de294b6

Browse files
committed
Allow saving with pickle
1 parent 62b80a5 commit de294b6

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pyESN.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,17 @@ def correct_dimensions(s, targetlength):
2424
return s
2525

2626

27+
def identity(x):
28+
return x
29+
30+
2731
class ESN():
2832

2933
def __init__(self, n_inputs, n_outputs, n_reservoir=200,
3034
spectral_radius=0.95, sparsity=0, noise=0.001, input_shift=None,
3135
input_scaling=None, teacher_forcing=True, feedback_scaling=None,
3236
teacher_scaling=None, teacher_shift=None,
33-
out_activation=lambda x: x, inverse_out_activation=lambda x: x,
37+
out_activation=identity, inverse_out_activation=identity,
3438
random_state=None, silent=True):
3539
"""
3640
Args:

testing.py

+11
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,17 @@ def test_nonsense(self):
5858
ESN(N_in, N_out, random_state=0.5)
5959
self.assertIn("Invalid seed", str(cm.exception))
6060

61+
def test_serialisation(self):
62+
import pickle
63+
import io
64+
esn = ESN(N_in, N_out, random_state=1)
65+
with io.BytesIO() as buf:
66+
pickle.dump(esn, buf)
67+
buf.flush()
68+
buf.seek(0)
69+
esn_unpickled = pickle.load(buf)
70+
self._compare(esn, esn_unpickled, should_be='same')
71+
6172

6273
class InitArguments(unittest.TestCase):
6374

0 commit comments

Comments
 (0)