Skip to content

Commit 3938cee

Browse files
author
Timur Gilmullin
committed
#14: refactor gaussprocess.py
1 parent 39a12de commit 3938cee

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

pybrain/auxiliary/gaussprocess.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
__author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de; Christian Osendorfer, osendorf@in.tum.de'
2-
1+
# -*- coding: utf-8 -*-
32

43
from scipy import r_, exp, zeros, eye, array, asarray, random, ravel, diag, sqrt, sin, cos, sort, mgrid, dot, floor
54
from scipy import c_ #@UnusedImport
65
from scipy.linalg import solve, inv
76
from pybrain.datasets import SupervisedDataSet
87
from scipy.linalg import norm
98

9+
__author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de; Christian Osendorfer, osendorf@in.tum.de'
10+
1011

1112
class GaussianProcess:
1213
""" This class represents a basic n-dimensional Gaussian Process. The implementation
@@ -51,12 +52,16 @@ def _kernel(self, a, b):
5152
def _buildGrid(self):
5253
(start, stop, step) = (self.start, self.stop, self.step)
5354
""" returns a mgrid type of array for 'dim' dimensions """
54-
if isinstance(start, (int, long, float, complex)):
55-
dimstr = 'start:stop:step, '*self.indim
55+
if isinstance(start, (int, float, complex)):
56+
dimstr = 'start:stop:step, ' * self.indim
57+
5658
else:
57-
assert len(start) == len(stop) == len(step)
59+
if not (len(start) == len(stop) == len(step)):
60+
raise Exception("not len(start) == len(stop) == len(step)")
61+
5862
dimstr = ["start[%i]:stop[%i]:step[%i], " % (i, i, i) for i in range(len(start))]
5963
dimstr = ''.join(dimstr)
64+
6065
return eval('c_[map(ravel, mgrid[' + dimstr + '])]').T
6166

6267
def _buildCov(self, a, b):
@@ -75,8 +80,11 @@ def reset(self):
7580

7681
def trainOnDataset(self, dataset):
7782
""" takes a SequentialDataSet with indim input dimension and scalar target """
78-
assert (dataset.getDimension('input') == self.indim)
79-
assert (dataset.getDimension('target') == 1)
83+
if dataset.getDimension('input') != self.indim:
84+
raise Exception("(dataset.getDimension('input') != self.indim)")
85+
86+
if dataset.getDimension('target') != 1:
87+
raise Exception("dataset.getDimension('target') != 1")
8088

8189
self.trainx = dataset.getField('input')
8290
self.trainy = ravel(dataset.getField('target'))

0 commit comments

Comments
 (0)