Skip to content

Commit 492ee27

Browse files
author
Timur Gilmullin
committed
#14: refactor policygradient.py
1 parent 12678ae commit 492ee27

1 file changed

Lines changed: 11 additions & 9 deletions

File tree

pybrain/rl/learners/directsearch/policygradient.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
__author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de'
1+
# -*- coding: utf-8 -*-
22

33
from pybrain.rl.learners.directsearch.directsearch import DirectSearchLearner
44
from pybrain.rl.learners.learner import DataSetLearner, ExploringLearner
@@ -9,6 +9,8 @@
99
from pybrain.structure.networks import FeedForwardNetwork
1010
from pybrain.structure.connections import IdentityConnection
1111

12+
__author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de'
13+
1214

1315
class LoglhDataSet(DataSet):
1416
def __init__(self, dim):
@@ -39,7 +41,6 @@ def __init__(self):
3941
# network to tie module and explorer together
4042
self.network = None
4143

42-
4344
def _setLearningRate(self, alpha):
4445
""" pass the alpha value through to the gradient descent object """
4546
self.gd.alpha = alpha
@@ -49,13 +50,13 @@ def _getLearningRate(self):
4950

5051
learningRate = property(_getLearningRate, _setLearningRate)
5152

52-
def _setModule(self, module):
53+
def _setModule(self, moduleParam):
5354
""" initialize gradient descender with module parameters and
5455
the loglh dataset with the outdim of the module. """
55-
self._module = module
56+
self._module = moduleParam
5657

5758
# initialize explorer
58-
self._explorer = NormalExplorer(module.outdim)
59+
self._explorer = NormalExplorer(moduleParam.outdim)
5960

6061
# build network
6162
self._initializeNetwork()
@@ -81,7 +82,6 @@ def _getExplorer(self):
8182

8283
explorer = property(_getExplorer, _setExplorer)
8384

84-
8585
def _initializeNetwork(self):
8686
""" build the combined network consisting of the module and
8787
the explorer and initializing the log likelihoods dataset.
@@ -98,12 +98,14 @@ def _initializeNetwork(self):
9898
# initialize loglh dataset
9999
self.loglh = LoglhDataSet(self.network.paramdim)
100100

101-
102101
def learn(self):
103102
""" calls the gradient calculation function and executes a step in direction
104103
of the gradient, scaled with a small learning rate alpha. """
105-
assert self.dataset != None
106-
assert self.module != None
104+
if self.dataset is None:
105+
raise Exception("Dataset must be not None!")
106+
107+
if self.module is None:
108+
raise Exception("Module must be not None!")
107109

108110
# calculate the gradient with the specific function from subclass
109111
gradient = self.calculateGradient()

0 commit comments

Comments (0)