
Commit 01b463d

Timur Gilmullin committed
#14: refactor gradientdescent module
1 parent f981e23

1 file changed: pybrain/auxiliary/gradientdescent.py
Lines changed: 33 additions & 9 deletions
@@ -1,16 +1,19 @@
-__author__ = ('Thomas Rueckstiess, ruecksti@in.tum.de'
-              'Justin Bayer, bayer.justin@googlemail.com')
-
+# -*- coding: utf-8 -*-
 
 from scipy import zeros, asarray, sign, array, cov, dot, clip, ndarray
 from scipy.linalg import inv
 
+__author__ = ('Thomas Rueckstiess, ruecksti@in.tum.de'
+              'Justin Bayer, bayer.justin@googlemail.com')
+
 
 class GradientDescent(object):
 
     def __init__(self):
         """ initialize algorithms with standard parameters (typical values given in parentheses)"""
 
+        self.values = []
+
         # --- BackProp parameters ---
         # learning rate (0.1-0.001, down to 1e-7 for RNNs)
         self.alpha = 0.1
@@ -33,15 +36,17 @@ def __init__(self):
         self.etaplus = 1.2
         self.etaminus = 0.5
         self.lastgradient = None
+        self.rprop_theta = 0.
 
     def init(self, values):
         """ call this to initialize data structures *after* algorithm to use
         has been selected
-
         :arg values: the list (or array) of parameters to perform gradient descent on
         (will be copied, original not modified)
         """
-        assert isinstance(values, ndarray)
+        if not isinstance(values, ndarray):
+            raise Exception("{} is not instance of type {}".format(str(values), str(ndarray)))
+
         self.values = values.copy()
         if self.rprop:
             self.lastgradient = zeros(len(values), dtype='float64')
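
The assert in init() becomes an explicit type check that raises for non-ndarray input. A minimal sketch of the new behavior, assuming the refactored module is importable and using made-up values:

from numpy import array

from pybrain.auxiliary.gradientdescent import GradientDescent

descent = GradientDescent()
descent.init(array([0.5, -0.2, 0.1]))  # an ndarray is accepted and copied into descent.values
try:
    descent.init([0.5, -0.2, 0.1])     # a plain list now raises instead of tripping an assert
except Exception as error:
    print(error)  # e.g. "[0.5, -0.2, 0.1] is not instance of type <class 'numpy.ndarray'>"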
@@ -54,7 +59,9 @@ def init(self, values):
     def __call__(self, gradient, error=None):
         """ calculates parameter change based on given gradient and returns updated parameters """
         # check if gradient has correct dimensionality, then make array """
-        assert len(gradient) == len(self.values)
+        if len(gradient) != len(self.values):
+            raise Exception("{} is not equal to {}".format(str(gradient), str(self.values)))
+
         gradient_arr = asarray(gradient)
 
         if self.rprop:
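
The dimensionality assert in __call__() gets the same treatment. A usage sketch with made-up parameters and gradients, assuming the module is on the import path and that the default (non-rprop) configuration returns updated parameters as the docstring states:

from numpy import array

from pybrain.auxiliary.gradientdescent import GradientDescent

descent = GradientDescent()             # defaults: plain backprop, alpha = 0.1
descent.init(array([1.0, 2.0]))
print(descent(array([0.5, -0.5])))      # matching length: returns the updated parameters
try:
    descent(array([0.5]))               # mismatched length: raises the new Exception
except Exception as error:
    print(error)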
@@ -66,7 +73,7 @@ def __call__(self, gradient, error=None):
             # update rprop meta parameters
             dirSwitch = self.lastgradient * gradient_arr
             rprop_theta[dirSwitch > 0] *= self.etaplus
-            idx = dirSwitch < 0 
+            idx = dirSwitch < 0
             rprop_theta[idx] *= self.etaminus
             gradient_arr[idx] = 0
 
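
The rprop branch adapts a per-parameter step size rprop_theta from the sign agreement of successive gradients. A standalone numpy rendering of the masking idiom above, with made-up numbers:

from numpy import array

etaplus, etaminus = 1.2, 0.5             # the defaults set in __init__ above
lastgradient = array([0.3, -0.2, 0.1])
gradient_arr = array([0.6, 0.1, 0.0])
rprop_theta = array([0.1, 0.1, 0.1])

dirSwitch = lastgradient * gradient_arr  # > 0 where the sign was kept, < 0 where it flipped
rprop_theta[dirSwitch > 0] *= etaplus    # grow the step where the direction was kept
idx = dirSwitch < 0
rprop_theta[idx] *= etaminus             # shrink it where the direction flipped...
gradient_arr[idx] = 0                    # ...and skip that gradient component this step
print(rprop_theta)                       # [0.12 0.05 0.1 ]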

@@ -97,6 +104,8 @@ def __call__(self, gradient, error=None):
 class NaturalGradient(object):
 
     def __init__(self, samplesize):
+        self.values = []
+
         # Counter after how many samples a new gradient estimate will be
         # returned.
         self.samplesize = samplesize
@@ -110,14 +119,17 @@ def __call__(self, gradient, error=None):
         # Append a copy to make sure this one is not changed after by the
         # client.
         self.samples.append(array(gradient))
+
         # Return None if no new estimate is being given.
         if len(self.samples) < self.samplesize:
             return None
+
         # After all the samples have been put into a single array, we can
         # delete them.
         gradientarray = array(self.samples).T
         inv_covar = inv(cov(gradientarray))
         self.values += dot(inv_covar, gradientarray.sum(axis=1))
+
         return self.values
 
 
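
Once samplesize gradients have been collected, NaturalGradient applies the inverse sample covariance to the summed gradients. The core of that computation, rewritten standalone with made-up samples (numpy.linalg.inv standing in for scipy.linalg.inv):

from numpy import array, cov, dot
from numpy.linalg import inv

samples = [array([0.5, 0.1]), array([0.4, 0.2]),
           array([0.6, 0.0]), array([0.5, 0.3])]

gradientarray = array(samples).T                  # one column per collected gradient
inv_covar = inv(cov(gradientarray))               # inverse of the sample covariance
step = dot(inv_covar, gradientarray.sum(axis=1))  # the increment added to self.values
print(step)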

@@ -126,9 +138,18 @@ class IRpropPlus(object):
     def __init__(self, upfactor=1.1, downfactor=0.9, bound=0.5):
         self.upfactor = upfactor
         self.downfactor = downfactor
+
         if not bound > 0:
             raise ValueError("bound greater than 0 needed.")
 
+        self.bound = bound
+        self.values = None
+        self.prev_values = []
+        self.more_prev_values = []
+        self.previous_gradient = 0
+        self.step = 0
+        self.previous_error = 0.
+
     def init(self, values):
         self.values = values.copy()
         self.prev_values = values.copy()
@@ -142,21 +163,24 @@ def __call__(self, gradient, error):
         signs = sign(gradient)
 
         # For positive gradient parts.
-        positive = (products > 0).astype('int8')
+        positive = int(products > 0)
         pos_step = self.step * self.upfactor * positive
         clip(pos_step, -self.bound, self.bound)
         pos_update = self.values - signs * pos_step
 
         # For negative gradient parts.
-        negative = (products < 0).astype('int8')
+        negative = int(products < 0)
         neg_step = self.step * self.downfactor * negative
         clip(neg_step, -self.bound, self.bound)
+
         if error <= self.previous_error:
             # If the error has decreased, do nothing.
             neg_update = zeros(gradient.shape)
+
         else:
             # If it has increased, move back 2 steps.
             neg_update = self.more_prev_values
+
         # Set all negative gradients to zero for the next step.
         gradient *= positive
 
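
IRpropPlus branches on products, which the surrounding code suggests is the elementwise sign product of the previous and current gradients (computed earlier in __call__, outside this hunk). A standalone sketch of that test with made-up gradients; note that the elementwise (products > 0).astype('int8') form handles arrays of any length, while int(products > 0) only accepts a single-element result:

from numpy import array

previous_gradient = array([0.3, -0.2, 0.1])
gradient = array([0.6, 0.1, -0.05])

products = previous_gradient * gradient   # sign agreement per parameter
positive = (products > 0).astype('int8')  # 1 where the gradient kept its sign
negative = (products < 0).astype('int8')  # 1 where it flipped
print(positive, negative)                 # [1 0 0] [0 1 1]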
