Commit faec00de authored by Tom's avatar Tom
Browse files

(Hopefully) fixed data type problems with find_objects/label on 32 bit machines.

parent ebcd7d9f
......@@ -50,7 +50,7 @@ def pyargsort(seq,cmp=cmp,key=lambda x:x):
return sorted(range(len(seq)),key=lambda x:key(seq.__getitem__(x)),cmp=cmp)
def renumber_by_xcenter(seg):
objects = [(slice(0,0),slice(0,0))]+measurements.find_objects(seg)
objects = [(slice(0,0),slice(0,0))]+find_objects(seg)
def xc(o): return mean((o[1].start,o[1].stop))
xs = array([xc(o) for o in objects])
order = argsort(xs)
......@@ -58,20 +58,32 @@ def renumber_by_xcenter(seg):
for i,j in enumerate(order): segmap[j] = i
return segmap[seg]
def flexible_find_objects(image):
"""Like measurements.find_objects, but tries to
be a bit more flexible about the datatypes it accepts."""
# first try the default type
try: return measurements.find_objects(image)
def label(image,**kw):
"""measurements.label fails to document what types it accepts,
and it fails randomly with different types on different
platforms. This tries to work around that."""
try: return measurements.label(image,**kw)
except: pass
types = ["int32","int64","int16"]
types = ["int32","uint32","int64","unit64","int16","uint16"]
for t in types:
# try with type conversions
try: return measurements.find_objects(array(image,dtype=t))
try: return measurements.label(array(image,dtype=t),**kw)
except: pass
# let it raise the same exception as before
return measurements.find_objects(image)
return measurements.label(image,**kw)
def find_objects(image,**kw):
"""measurements.find_objects fails to document what types it accepts,
and it fails randomly with different types on different
platforms. This tries to work around that."""
try: return measurements.find_objects(image,**kw)
except: pass
types = ["int32","uint32","int64","unit64","int16","uint16"]
for t in types:
try: return measurements.find_objects(array(image,dtype=t),**kw)
except: pass
# let it raise the same exception as before
return measurements.find_objects(image,**kw)
def rgb2int(image):
"""Converts a rank 3 array with RGB values stored in the
last axis into a rank 2 array containing 32 bit RGB values."""
......@@ -109,7 +121,7 @@ class RegionExtractor:
labels,correspondence = renumber_labels_ordered(labels,correspondence=1)
self.labels = labels
self.correspondence = correspondence
self.objects = [None]+flexible_find_objects(labels)
self.objects = [None]+find_objects(labels)
def setPageColumns(self,image):
"""Set the image to be iterated over. This should be an RGB image,
ndim==3, dtype=='B'. This iterates over the columns."""
......@@ -811,7 +823,7 @@ def estimate_xheight(line,scale=1.0,debug=0):
return bottom-top,bottom
def keep_marked(image,markers):
labels,_ = measurements.label(image)
labels,_ = label(image)
imshow(sin(17.1*labels),cmap=cm.jet)
marked = unique(labels*(markers!=0))
print marked
......@@ -840,7 +852,7 @@ def latin_filter(line,scale=1.0,r=1.2,debug=0):
def remove_noise(line,minsize=8):
bin = (line>0.5*amax(line))
labels,n = measurements.label(bin)
labels,n = label(bin)
sums = measurements.sum(bin,labels,range(n+1))
sums = sums[labels]
good = minimum(bin,1-(sums>0)*(sums<minsize))
......
......@@ -6,8 +6,8 @@ import scipy
from scipy import stats
from scipy.ndimage import measurements
from pylab import *
from common import *
import common
def avg(*args):
return mean(args)
......@@ -19,7 +19,7 @@ def seg_boxes(seg,math=0):
coordinates are used (however, the order of the values in the
tuple doesn't change)."""
seg = array(seg,'uint32')
slices = measurements.find_objects(seg)
slices = common.find_objects(seg)
h = seg.shape[0]
result = []
for i in range(len(slices)):
......@@ -117,7 +117,7 @@ def bbox(image):
"""Compute the bounding box for the pixels in the image."""
assert len(image.shape)==2,"wrong shape: "+str(image.shape)
image = array(image!=0,'uint32')
cs = scipy.ndimage.measurements.find_objects(image)
cs = common.find_objects(image)
if len(cs)<1: return None
c = cs[0]
return (c[0].start,c[1].start,c[0].stop,c[1].stop)
......
......@@ -69,7 +69,7 @@ class Grouper(PyComponent):
# print sorted(correspondences)
self.pre2seg = correspondences
# compute the bounding boxes in order
boxes = [None]+measurements.find_objects(segmentation)
boxes = [None]+common.find_objects(segmentation)
n = len(boxes)
# now consider groups of boxes
groups = []
......@@ -107,7 +107,7 @@ class Grouper(PyComponent):
the groups corresponding to each labeled object. Objects should be labeled
consecutively."""
# compute the bounding boxes in order
boxes = [None] + measurements.find_objects(segmentation)
boxes = [None] + common.find_objects(segmentation)
n = len(boxes)
# now consider groups of boxes
groups = []
......
from pylab import *
from scipy.ndimage import filters,morphology,measurements
import psegutils
import common
......@@ -85,7 +86,7 @@ def ccslineseg(image):
center = filters.maximum_filter(center,(3,3))
center = psegutils.keep_marked(image>0.5,center)
center = filters.maximum_filter(center,(2,2))
center,_ = measurements.label(center)
center,_ = common.label(center)
center = psegutils.spread_labels(center)
center *= image
return center
......@@ -136,7 +137,7 @@ class DPSegmentLine(SimpleParams):
tracks = dplineseg2(line,imweight=self.imweight,bweight=self.bweight,
diagweight=self.diagweight,debug=self.debug,r=self.r)
tracks = array(tracks<0.5*amax(tracks),'i')
tracks,_ = measurements.label(tracks)
tracks,_ = common.label(tracks)
self.tracks = tracks
rsegs = psegutils.spread_labels(tracks)
rsegs = rsegs*(line>0.5*amax(line))
......
......@@ -4,6 +4,7 @@ import argparse,glob,os,os.path
from scipy.ndimage import filters,interpolation,morphology,measurements
from scipy import stats
from scipy.misc import imsave
import common
class record:
def __init__(self,**kw): self.__dict__.update(kw)
......@@ -80,7 +81,7 @@ def spread_labels(labels,maxdist=9999999):
return spread
def keep_marked(image,markers):
labels,_ = measurements.label(image)
labels,_ = common.label(image)
marked = unique(labels*(markers!=0))
kept = in1d(labels.ravel(),marked)
return (image!=0)*kept.reshape(*labels.shape)
......@@ -100,7 +101,7 @@ def correspondences(labels1,labels2):
def propagate_labels_simple(regions,labels):
"""Spread the labels to the corresponding regions."""
rlabels,_ = measurements.label(regions)
rlabels,_ = common.label(regions)
cors = correspondences(rlabels,labels)
outputs = zeros(amax(rlabels)+1,'i')
for o,i in cors.T: outputs[o] = i
......@@ -109,7 +110,7 @@ def propagate_labels_simple(regions,labels):
def propagate_labels(regions,labels,conflict=0):
"""Spread the labels to the corresponding regions."""
rlabels,_ = measurements.label(regions)
rlabels,_ = common.label(regions)
cors = correspondences(rlabels,labels)
outputs = zeros(amax(rlabels)+1,'i')
oops = -(1<<30)
......@@ -126,8 +127,8 @@ def A(s): return W(s)*H(s)
def M(s): return mean([s[0].start,s[0].stop]),mean([s[1].start,s[1].stop])
def binary_objects(binary):
labels,n = measurements.label(binary)
objects = measurements.find_objects(labels)
labels,n = common.label(binary)
objects = common.find_objects(labels)
return objects
def estimate_scale(binary):
......@@ -153,7 +154,7 @@ def compute_boxmap(binary,scale,threshold=(.5,4),dtype='i'):
def compute_lines(segmentation,scale):
"""Given a line segmentation map, computes a list
of tuples consisting of 2D slices and masked images."""
lobjects = measurements.find_objects(segmentation)
lobjects = common.find_objects(segmentation)
lines = []
for i,o in enumerate(lobjects):
if o is None: continue
......@@ -291,8 +292,8 @@ def rgbshow(r,g,b=None,gn=1,cn=0,ab=0,**kw):
imshow(clip(combo,0,1),**kw)
def select_regions(binary,f,min=0,nbest=100000):
labels,n = measurements.label(binary)
objects = measurements.find_objects(labels)
labels,n = common.label(binary)
objects = common.find_objects(labels)
scores = [f(o) for o in objects]
best = argsort(scores)
keep = zeros(len(objects)+1,'B')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment