faster binary hole filling

This commit is contained in:
Michael Krayer 2021-08-16 21:46:40 +02:00
parent c5908e00f5
commit 2acef17323
1 changed files with 70 additions and 63 deletions

133
field.py
View File

@ -774,9 +774,6 @@ class Features3d:
self._offset = offset self._offset = offset
self._faces = contour.faces.reshape(contour.n_faces,4)[ind,:] self._faces = contour.faces.reshape(contour.n_faces,4)[ind,:]
self._points = contour.points self._points = contour.points
# self._offset = offset
# self._faces = contour.faces.reshape(contour.n_faces,4)
# self._points = contour.points
# Compute the volume and area per cell. For the volume computation, an arbitrary component # Compute the volume and area per cell. For the volume computation, an arbitrary component
# of the normal has to be chosen which defaults to the z-component and is set by # of the normal has to be chosen which defaults to the z-component and is set by
# 'cellvol_normal_component'. # 'cellvol_normal_component'.
@ -786,9 +783,10 @@ class Features3d:
C = self._points[self._faces[:,3],:] C = self._points[self._faces[:,3],:]
cn = np.cross(B-A,C-A) cn = np.cross(B-A,C-A)
# Check if cell normal points in direction of gradient. If not, switch vertex order. # Check if cell normal points in direction of gradient. If not, switch vertex order.
idx = (contour.point_arrays['Gradients'][self._faces[:,1],:]*cn).sum(axis=-1)<0 idx = (contour.point_arrays['Gradients'][self._faces[:,1],:]*cn).sum(axis=-1)>0
self._faces[np.ix_(idx,[2,3])] = self._faces[np.ix_(idx,[3,2])] # print(idx.shape,np.sum(idx),self._faces.shape,self._faces[idx,2:].shape,self._faces[idx,3:1:-1].shape)
cn[idx] = -cn[idx] # self._faces[np.ix_(idx,[2,3])] = self._faces[np.ix_(idx,[3,2])]
# cn[idx,:] = -cn[idx,:]
# Compute area and signed volume per cell # Compute area and signed volume per cell
cc = (A+B+C)/3 cc = (A+B+C)/3
self._cell_areas = 0.5*np.sqrt(np.square(cn).sum(axis=1)) self._cell_areas = 0.5*np.sqrt(np.square(cn).sum(axis=1))
@ -826,6 +824,7 @@ class Features3d:
features = self.list_of_features(features) features = self.list_of_features(features)
# Get index ranges which are to be deleted and also create an array # Get index ranges which are to be deleted and also create an array
# which determines the size of the block to be deleted # which determines the size of the block to be deleted
print('Discarding',features)
idx = [] idx = []
gapsize = np.zeros((self._nfeatures,),dtype=np.int) gapsize = np.zeros((self._nfeatures,),dtype=np.int)
for feature in features: for feature in features:
@ -1169,7 +1168,6 @@ class BinaryFieldNd:
"'periodicity' requires bool values." "'periodicity' requires bool values."
assert len(periodicity)==input.ndim,\ assert len(periodicity)==input.ndim,\
"Number of entries in 'periodicity' must match dimension of binary field." "Number of entries in 'periodicity' must match dimension of binary field."
from scipy import ndimage
if has_ghost and deep: if has_ghost and deep:
self._data = input.copy() self._data = input.copy()
elif has_ghost: elif has_ghost:
@ -1184,10 +1182,7 @@ class BinaryFieldNd:
self.nlabels = None self.nlabels = None
self.wrap = tuple(self._ndim*[None]) self.wrap = tuple(self._ndim*[None])
self.periodicity = tuple(bool(x) for x in periodicity) self.periodicity = tuple(bool(x) for x in periodicity)
if connect_diagonals: self.connect_diagonals = connect_diagonals
self.structure = ndimage.generate_binary_structure(self._ndim,self._ndim)
else:
self.structure = ndimage.generate_binary_structure(self._ndim,1)
@property @property
def data(self): def data(self):
@ -1199,34 +1194,37 @@ class BinaryFieldNd:
return None return None
return self._labels[self._sldata] return self._labels[self._sldata]
def label(self): def label(self,use_cc3d=False):
'''Labels connected regions in binary fields.''' '''Labels connected regions in binary fields.'''
from scipy import ndimage if use_cc3d:
if any(self.periodicity): import cc3d
self._labels,self.nlabels,self.wrap = self._labels_periodic() if self._ndim==2: connectivity = 8 if self.connect_diagonals else 4
elif self._ndim==3: connectivity = 18 if self.connect_diagonals else 6
else: raise RuntimeError("'use_cc3d' can only be used with 2D or 3D data.")
else: else:
self._labels,self.nlabels = ndimage.label(self._data,structure=self.structure) from scipy import ndimage
if self.connect_diagonals: structure = ndimage.generate_binary_structure(self._ndim,self._ndim)
def _labels_periodic(self,map_to_zero=False): else: structure = ndimage.generate_binary_structure(self._ndim,1)
'''Label features in an array while taking into account periodic wrapping. #
If map_to_zero=True, every feature which overlaps or is attached to the if use_cc3d:
periodic boundary will be removed.''' self._labels,self.nlabels = cc3d.connected_components(self._data,connectivity=connectivity,return_N=True)
from scipy import ndimage else:
# Compute labels on padded array self._labels,self.nlabels = ndimage.label(self._data,structure=structure)
labels_,nlabels_ = ndimage.label(self._data,structure=self.structure) if not any(self.periodicity):
return
# Get a mapping of labels which differ at periodic overlap # Get a mapping of labels which differ at periodic overlap
map_ = np.array(range(nlabels_+1),dtype=labels_.dtype) map_ = np.array(range(nlabels_+1),dtype=labels_.dtype)
wrap_ = self._ndim*[None] wrap_ = self._ndim*[None]
for axis in range(self._ndim): for axis in range(self._ndim):
if not self.periodicity[axis]: continue if not self.periodicity[axis]: continue
sl_lo = tuple(slice(0,1) if ii==axis else slice(None) for ii in range(self._ndim)) sl_lo = tuple(slice(0,1) if ii==axis else slice(None) for ii in range(self._ndim))
sl_hi = tuple(slice(-1,None) if ii==axis else slice(None) for ii in range(self._ndim)) sl_hi = tuple(slice(-1,None) if ii==axis else slice(None) for ii in range(self._ndim))
sl_pre = tuple(slice(-2,-1) if ii==axis else slice(None) for ii in range(self._ndim)) sl_pre = tuple(slice(-2,-1) if ii==axis else slice(None) for ii in range(self._ndim))
lab_lo = labels_[sl_lo] lab_lo = self._labels[sl_lo]
lab_hi = labels_[sl_hi] lab_hi = self._labels[sl_hi]
lab_pre = np.unique(labels_[sl_pre]) # all labels in last (unwrapped) slice lab_pre = np.unique(self._labels[sl_pre]) # all labels in last (unwrapped) slice
# Initialize array to keep track of wrapping # Initialize array to keep track of wrapping
wrap_[axis] = np.zeros(nlabels_+1,dtype=bool) wrap_[axis] = np.zeros(self.nlabels+1,dtype=bool)
# Determine new label and map # Determine new label and map
lab_new = np.minimum(lab_lo,lab_hi) lab_new = np.minimum(lab_lo,lab_hi)
for lab_ in [lab_lo,lab_hi]: for lab_ in [lab_lo,lab_hi]:
@ -1236,47 +1234,56 @@ class BinaryFieldNd:
for idx_ in np.unique(lab_li,return_index=True)[1]: for idx_ in np.unique(lab_li,return_index=True)[1]:
source_ = lab_li[idx_] # the label to be changed source_ = lab_li[idx_] # the label to be changed
target_ = lab_new_li[idx_] # the label which will be newly assigned target_ = lab_new_li[idx_] # the label which will be newly assigned
if map_to_zero and source_ in lab_pre: while target_ != map_[target_]: # map it recursively
map_[source_] = 0 target_ = map_[target_]
map_[target_] = 0 map_[source_] = target_
else: if source_ in lab_pre: # check if source is not a ghost
while target_ != map_[target_]: # map it recursively wrap_[axis][target_] = True
target_ = map_[target_]
map_[source_] = target_
if source_ in lab_pre: # check if source is not a ghost
wrap_[axis][target_] = True
# Remove gaps from target mapping # Remove gaps from target mapping
idx_,map_ = np.unique(map_,return_index=True,return_inverse=True)[1:3] idx_,map_ = np.unique(map_,return_index=True,return_inverse=True)[1:3]
# Relabel and remove padding # Relabel
labels_ = map_[labels_] self._labels = map_[self._labels]
nlabels_ = np.max(map_) self.nlabels = np.max(map_)
assert nlabels_==len(idx_)-1, "DEBUG assertion" self.wrap = tuple(None if x is None else x[idx_] for x in wrap_)
self.wrap = tuple(None if x is None else x[idx_] for x in self.wrap) return
# for axis in range(self._ndim):
# if wrap_[axis] is not None:
# wrap_[axis] = wrap_[axis][idx_]
return labels_,nlabels_,tuple(wrap_)
def fill_holes(self): def fill_holes(self,keep_wall_attached=True,use_cc3d=False,return_mask=False):
'''Fill the holes in binary objects while taking into account periodicity. '''Fill the holes in binary objects while taking into account periodicity.
In the non-periodic sense, a hole is a region of zeros which does not connect In the non-periodic sense, a hole is a region of zeros which does not connect
to a boundary. In the periodic sense, a hole is a region of zeros which is not to a boundary. When keep_wall_attached==False, only regions are kept which
connected to itself accross the periodic boundaries.''' fully connect from top to bottom wall.
from scipy import ndimage In the periodic sense, a hole is a region of zeros which is not
# Reimplementation of "binary_fill_holes" from ndimage connected to itself accross the opposite periodic boundary.'''
mask = np.logical_not(self._data) # only modify locations which are "False" at the moment if return_mask: mask = self._data.copy()
tmp = np.zeros(mask.shape,bool) # create empty array to "grow from boundaries"
ndimage.binary_dilation(tmp,structure=None,iterations=-1,
mask=mask,output=self._data,border_value=1,
origin=0) # everything connected to the boundary is now True in self._data
# Remove holes which overlap the boundaries
if any(self.periodicity):
self._data = self._labels_periodic(map_to_zero=True)[0]>0
# Invert to get the final result
np.logical_not(self._data,self._data) np.logical_not(self._data,self._data)
if use_cc3d:
import cc3d
if self._ndim==2: connectivity = 8 if self.connect_diagonals else 4
elif self._ndim==3: connectivity = 18 if self.connect_diagonals else 6
else: raise RuntimeError("'use_cc3d' can only be used with 2D or 3D data.")
labels_,nlabels_ = cc3d.connected_components(self._data,connectivity=connectivity,return_N=True)
else:
from scipy import ndimage
if self.connect_diagonals: structure = ndimage.generate_binary_structure(self._ndim,self._ndim)
else: structure = ndimage.generate_binary_structure(self._ndim,1)
labels_,nlabels_ = ndimage.label(self._data,structure=structure)
labels_keep = set()
for axis in range(3):
sl_lo = tuple(slice(None) if ii!=axis else 0 for ii in range(3))
sl_hi = tuple(slice(None) if ii!=axis else -1 for ii in range(3))
if self.periodicity[axis]:
labels_keep |= set(np.unique(labels_[sl_lo])) & set(np.unique(labels_[sl_hi]))
elif keep_wall_attached:
labels_keep |= set(np.unique(labels_[sl_lo])) | set(np.unique(labels_[sl_hi]))
labels_keep.discard(0)
self._data = np.isin(labels_,tuple(labels_keep),invert=True)
# If labels have been computed already, recompute them to stay consistent # If labels have been computed already, recompute them to stay consistent
if self._labels is not None: if self._labels is not None: self.label()
self.label() # Compute mask if it is to be returned, otherwise we are done
if return_mask:
np.logical_xor(mask,self._data,out=mask)
return mask
return
def probe(self,idx,probe_label=False): def probe(self,idx,probe_label=False):
'''Returns whether or not a point at idx is True or False.''' '''Returns whether or not a point at idx is True or False.'''