Commit 45800292 authored by Emile Anclin's avatar Emile Anclin Committed by Philipp Zumstein
Browse files

strongly improved network stripping

parent 40273c2d
......@@ -292,6 +292,14 @@ class Softmax(Network):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = zeros((No,Nh+1))
def postLoad(self):
self.DW2 = zeros(self.W2.shape)
def preSave(self):
for var in ('state', 'DW2'):
if hasattr(self, var):
delattr(self, var)
def ninputs(self):
return self.Nh
def noutputs(self):
......@@ -596,6 +604,12 @@ class Stacked(Network):
yield self
for sub in self.nets:
for x in sub.walk(): yield x
def preSave(self):
self.dstats = defaultdict(list) # reset
for delta in ('deltas', 'ldeltas'):
if hasattr(self, delta):
delattr(self, delta)
def ninputs(self):
return self.nets[0].ninputs()
def noutputs(self):
......@@ -859,11 +873,16 @@ class SeqRecognizer:
def walk(self):
for x in self.lstm.walk(): yield x
def clear_log(self):
def clear_log(self, deallocate_tempvars=False):
self.command_log = []
self.error_log = []
self.cerror_log = []
self.key_log = []
if deallocate_tempvars:
for attrname in ('outputs', 'targets', 'aligned'):
if hasattr(self, attrname):
delattr(self, attrname)
def __setstate__(self,state):
......@@ -874,7 +893,7 @@ class SeqRecognizer:
if "cerror_log" not in dir(self): self.cerror_log = []
if "key_log" not in dir(self): self.key_log = []
def info(self):
def setLearningRate(self,r,momentum=0.9):
def predictSequence(self,xs):
......@@ -156,7 +156,8 @@ def save_lstm(fname,network):
if args.strip:
print yellow('saving stripped network (without temporary variables)...')
for x in network.walk(): x.preSave()
if args.strip:
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment