1- __author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de'
1+ # -*- coding: utf-8 -*-
22
33from pybrain .rl .learners .directsearch .directsearch import DirectSearchLearner
44from pybrain .rl .learners .learner import DataSetLearner , ExploringLearner
99from pybrain .structure .networks import FeedForwardNetwork
1010from pybrain .structure .connections import IdentityConnection
1111
12+ __author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de'
13+
1214
1315class LoglhDataSet (DataSet ):
1416 def __init__ (self , dim ):
@@ -39,7 +41,6 @@ def __init__(self):
3941 # network to tie module and explorer together
4042 self .network = None
4143
42-
4344 def _setLearningRate (self , alpha ):
4445 """ pass the alpha value through to the gradient descent object """
4546 self .gd .alpha = alpha
@@ -49,13 +50,13 @@ def _getLearningRate(self):
4950
5051 learningRate = property (_getLearningRate , _setLearningRate )
5152
52- def _setModule (self , module ):
53+ def _setModule (self , moduleParam ):
5354 """ initialize gradient descender with module parameters and
5455 the loglh dataset with the outdim of the module. """
55- self ._module = module
56+ self ._module = moduleParam
5657
5758 # initialize explorer
58- self ._explorer = NormalExplorer (module .outdim )
59+ self ._explorer = NormalExplorer (moduleParam .outdim )
5960
6061 # build network
6162 self ._initializeNetwork ()
@@ -81,7 +82,6 @@ def _getExplorer(self):
8182
8283 explorer = property (_getExplorer , _setExplorer )
8384
84-
8585 def _initializeNetwork (self ):
8686 """ build the combined network consisting of the module and
8787 the explorer and initializing the log likelihoods dataset.
@@ -98,12 +98,14 @@ def _initializeNetwork(self):
9898 # initialize loglh dataset
9999 self .loglh = LoglhDataSet (self .network .paramdim )
100100
101-
102101 def learn (self ):
103102 """ calls the gradient calculation function and executes a step in direction
104103 of the gradient, scaled with a small learning rate alpha. """
105- assert self .dataset != None
106- assert self .module != None
104+ if self .dataset is None :
105+ raise Exception ("Dataset must be not None!" )
106+
107+ if self .module is None :
108+ raise Exception ("Module must be not None!" )
107109
108110 # calculate the gradient with the specific function from subclass
109111 gradient = self .calculateGradient ()
0 commit comments