bugfix: threshold() Field3d. Removed ConnectedRegions class, it has been replaced entirely by BindaryFieldNd

This commit is contained in:
Michael Krayer 2021-08-06 22:55:02 +02:00
parent f5714e1987
commit 2b4be2a375
1 changed files with 8 additions and 178 deletions

186
field.py
View File

@ -251,9 +251,9 @@ class Field3d:
'''Returns a binary array indicating which grid points are above,
or in case of invert=True below, a given threshold.'''
if invert:
return self.data>=val
else:
return self.data<val
else:
return self.data>=val
def coordinate(self,idx,axis=None):
if axis is None:
@ -656,19 +656,10 @@ class BinaryFieldNd:
self.labels = None
self.nlabels = 0
self.wrap = tuple(self._ndim*[None])
self._feat_slice = None
self._featsl = None
self.set_structure(False)
self.set_periodicity(self._ndim*[False])
@classmethod
def from_threshold(cls,fld,threshold,invert=False):
if isinstance(fld,Field3d):
fld = fld.data
if invert:
return cls(fld<threshold)
else:
return cls(fld>=threshold)
def set_periodicity(self,periodicity):
assert all([isinstance(x,(bool,int)) for x in periodicity]),\
"'periodicity' requires bool values."
@ -695,7 +686,7 @@ class BinaryFieldNd:
self.labels,self.nlabels,self.wrap = self._labels_periodic()
else:
self.labels,self.nlabels = ndimage.label(self.data,structure=self.structure)
self._feat_slice = ndimage.find_objects(self.labels)
self._featsl = ndimage.find_objects(self.labels)
def _labels_periodic(self,map_to_zero=False):
'''Label features in an array while taking into account periodic wrapping.
@ -820,6 +811,7 @@ class BinaryFieldNd:
return labels
def discard_feature(self,selection):
'''Removes a feature from data.'''
if self.labels is None:
self.label()
selection = self._select_feature(selection)
@ -852,8 +844,8 @@ class BinaryFieldNd:
data_ = np.logical_not(self.data)
if has_array: data2_ = array
else:
data_ = (self.labels[self._feat_slice[lab_-1]]==lab_)
if has_array: data2_ = array[self._feat_slice[lab_-1]]
data_ = (self.labels[self._featsl[lab_-1]]==lab_)
if has_array: data2_ = array[self._featsl[lab_-1]]
# If feature is wrapped periodically, duplicate it and extract
# largest one
iswrapped = False
@ -868,6 +860,7 @@ class BinaryFieldNd:
vol_ = np.bincount(l_.ravel())
il_ = np.argmax(vol_[1:])+1
sl_ = ndimage.find_objects(l_==il_)[0]
print(sl_)
data_ = data_[sl_]
if has_array:
data2_ = np.tile(data2_,rep_)[sl_]
@ -926,169 +919,6 @@ class BinaryFieldNd:
else:
raise ValueError('Invalid input. Accepting int,list,tuple,ndarray.')
class ConnectedRegions:
def __init__(self,binarr,periodicity,connect_diagonals=False,fill_holes=False,bytes_label=32):
assert isinstance(binarr,np.ndarray) and binarr.dtype==np.dtype('bool'),\
"'binarr' must be a numpy array of dtype('bool')."
assert all([isinstance(x,(bool,int)) for x in periodicity]),\
"'periodicity' requires bool values."
assert bytes_label in (8,16,32,64),\
"'bytes_label' must be one of {8,16,32,64}."
self._dim = binarr.shape
self._ndim = binarr.ndim
assert self._ndim in (2,3),\
"'binarr' must be either two or three dimensional."
assert len(periodicity)==self._ndim,\
"Length of 'periodicity' must match number of dimensions of data."
from scipy import ndimage
# Construct connectivity stencil
if self._ndim==2:
connectivity = np.ones((3,3),dtype='bool')
if not connect_diagonals:
connectivity[0,0] = False
connectivity[0,2] = False
connectivity[2,0] = False
connectivity[0,2] = False
else:
connectivity = np.ones((3,3,3),dtype='bool')
if not connect_diagonals:
connectivity[0,0,0] = False
connectivity[2,0,0] = False
connectivity[0,2,0] = False
connectivity[0,0,2] = False
connectivity[2,2,0] = False
connectivity[0,2,2] = False
connectivity[2,0,2] = False
connectivity[2,2,2] = False
# Compute labels:
# this does not take into account periodic wrapping
dtype_label = np.dtype('uint'+str(bytes_label))
self.label = np.empty(self._dim,dtype=dtype_label)
ndimage.label(binarr,structure=connectivity,output=self.label)
self.count = np.max(self.label)
# Merge labels if there are periodic overlaps
map_tgt = np.array(range(0,self.count+1),dtype=dtype_label)
for axis in range(self._ndim):
if not periodicity[axis]:
continue
# Merge the first and last plane and compute connectivity
sl = self._ndim*[slice(None)]
sl[axis] = (-1,0)
binarr_ = binarr[tuple(sl)]
label_ = np.empty(binarr_.shape,dtype=dtype_label)
ndimage.label(binarr_,structure=connectivity,output=label_)
for val_ in np.unique(label_):
# Get all global labels which are associated to a region
# connected over the boundary
global_labels = list(np.unique(self.label[tuple(sl)][label_==val_]))
# If there is only one label, nothing needs to be done
if len(global_labels)==1:
continue
# Determine target label:
# this needs to be done recursively because the original
# target may already be reassigned
tgt = global_labels[0]
while tgt!=map_tgt[tgt]:
tgt=map_tgt[tgt]
map_tgt[global_labels[1:]] = tgt
# Remove gaps from target mapping
map_tgt = np.unique(map_tgt,return_inverse=True)[1]
# Remap labels
self.label = map_tgt[self.label]
self.count = np.max(map_tgt)
@classmethod
def from_field(cls,fld3d,threshold,periodicity,connect_diagonals=False,bytes_label=32,invert=False):
voxthr = VoxelThreshold.from_field(fld3d,threshold,invert=invert)
return cls.from_voxelthresh(voxthr,periodicity,
connect_diagonals=connect_diagonals,
bytes_label=bytes_label)
@classmethod
def from_voxelthresh(cls,voxthr,periodicity,connect_diagonals=False,bytes_label=32):
return cls(voxthr.data,periodicity,
connect_diagonals=connect_diagonals,
bytes_label=bytes_label)
def volume(self,label=None):
'''Returns volume of labeled regions. If 'label' is None all volumes
are returned including the volume of the background region. The array
is sorted by labels, i.e. vol[0] is the volume of the background region,
vol[1] the volume of label 1, etc. If 'label' is an integer value, only
the volume of the corresponding region is returned.
Note: it is more efficient to retrieve all volumes at once than querying
single labels.'''
if label is None:
return np.bincount(self.label.ravel())
else:
return np.sum(self.label==label)
def volume_domain(self):
'''Returns volume of entire domain. Should be equal to sum(volume()).'''
return np.prod(self._dim)
def labels_by_volume(self,descending=True):
'''Returns labels of connected regions sorted by volume.'''
labels = np.argsort(self.volume()[1:])+1
if descending:
labels = labels[::-1]
return labels
def discard_regions(self,selection):
selection = self._parse_selection(selection)
# Map tagged regions to zero in order to discard them
map_tgt = np.array(range(0,self.count+1),dtype=self.label.dtype)
map_tgt[selection] = 0
# Remove gaps from target mapping
map_tgt = np.unique(map_tgt,return_inverse=True)[1]
# Discard regions
self.label = map_tgt[self.label]
self.count = np.max(map_tgt)
def probe(self,idx):
'''Returns label for given index.'''
return self.label[tuple(idx)]
def vtk_contour(self,fld3,val,selection):
'''Computes contours of a Field3d only within selected structures.'''
assert isinstance(fld3,Field3d), "'fld3' must be a Field3d instance."
assert tuple(self._dim)==tuple(fld3.dim()), \
"'fld3' must be of dimension {}.".format(self._dim)
selection = self._parse_selection(selection)
from scipy import ndimage
# Create binary map of selection
map_tgt = np.zeros(self.count+1,dtype='bool')
map_tgt[selection] = True
binary_map = map_tgt[self.label]
# Add an extra cell to get the contour interpolation right
print(np.sum(binary_map))
binary_map = ndimage.binary_dilation(binary_map)
print(np.sum(binary_map))
# Extract the subfield
fld_con = fld3.copy()
fld_con.data[~binary_map] = np.nan
# Compute the contour
return fld_con.vtk_contour(val)
def _parse_selection(self,selection):
dtype = self.label.dtype
if np.issubdtype(type(selection),np.integer):
return np.array(selection,dtype=dtype)
elif isinstance(selection,(list,tuple,np.ndarray)):
selection = np.array(selection)
if selection.dtype==np.dtype('bool'):
assert selection.ndim==1 and selection.shape[0]==self.count+1,\
"Boolean indexing must provide count+1 values."
else:
selection = selection.astype(dtype)
assert np.max(selection)<=self.count and np.min(selection)>=0,\
"Entry in selection is out-of-bounds."
return selection
else:
raise ValueError('Invalid input. Accepting int,list,tuple,ndarray.')
class ChunkIterator:
'''Iterates through all chunks. 'snapshot' must be an instance
of a class which returns a Field3d from the method call