class BinaryFieldNd:
    '''N-dimensional boolean field with labeling of connected features.

    Attributes:
      data        : boolean numpy array holding the field.
      labels      : integer array of feature labels, None until label() ran.
      nlabels     : number of labeled features (background not counted).
      wrap        : tuple with one entry per axis. An entry is None, or a
                    boolean array indexed by label which is True for features
                    that were merged across the periodic boundary of that axis.
      periodicity : tuple of bool, one entry per axis.
      structure   : structuring element used for labeling.
    '''
    def __init__(self,input):
        '''Wraps the boolean array 'input'. The array is stored by reference,
        not copied, so in-place operations (fill_holes) modify it.'''
        assert isinstance(input,np.ndarray) and input.dtype==np.dtype('bool'),\
            "'input' must be a numpy array of dtype('bool')."
        self.data = input
        self._dim = input.shape
        self._ndim = input.ndim
        self.labels = None
        self.nlabels = 0
        self.wrap = tuple(self._ndim*[None])
        self.set_structure(False)
        self.set_periodicity(self._ndim*[False])

    @classmethod
    def from_threshold(cls,fld,threshold,invert=False):
        '''Creates a binary field which is True where fld>=threshold, or
        where fld<threshold if 'invert' is set. 'fld' may be a Field3d
        (its 'data' attribute is used) or an array.'''
        # NOTE(review): the comparison operators follow the former
        # VoxelThreshold behaviour (>= by default, < if inverted) -- confirm.
        if isinstance(fld,Field3d):
            fld = fld.data
        if invert:
            return cls(fld<threshold)
        return cls(fld>=threshold)

    def set_periodicity(self,periodicity):
        '''Sets for each axis whether the field is periodic along it.
        Existing labels are not recomputed.'''
        assert all([isinstance(x,(bool,int)) for x in periodicity]),\
            "'periodicity' requires bool values."
        assert len(periodicity)==self._ndim,\
            "Number of entries in 'periodicity' must match dimension of binary field."
        self.periodicity = tuple(bool(x) for x in periodicity)

    def set_structure(self,connect_diagonals):
        '''Selects the structuring element used for labeling: full
        connectivity if 'connect_diagonals' is set, else connectivity 1.
        Existing labels are not recomputed.'''
        from scipy import ndimage
        if connect_diagonals:
            self.structure = ndimage.generate_binary_structure(self._ndim,self._ndim)
        else:
            self.structure = ndimage.generate_binary_structure(self._ndim,1)

    def enable_diagonal_connections(self):
        '''Shortcut for set_structure(True).'''
        self.set_structure(True)

    def disable_diagonal_connections(self):
        '''Shortcut for set_structure(False).'''
        self.set_structure(False)

    def label(self):
        '''Labels connected regions in binary fields. Sets 'labels',
        'nlabels' and 'wrap'.'''
        from scipy import ndimage
        if any(self.periodicity):
            self.labels,self.nlabels,self.wrap = self._labels_periodic()
        else:
            self.labels,self.nlabels = ndimage.label(self.data,structure=self.structure)
            # Do not keep wrap flags of an earlier periodic labeling
            self.wrap = tuple(self._ndim*[None])

    def _labels_periodic(self,map_to_zero=False):
        '''Label features in an array while taking into account periodic wrapping.
        The data is padded with wrapped copies along periodic axes, labeled,
        and the labels of the duplicated boundary layers are compared.
        If map_to_zero=False, labels which differ between the two copies are
        merged into one feature and flagged in the returned wrap arrays.
        If map_to_zero=True, every label which differs between the two copies
        is removed (mapped to background) and all wrap entries are None.
        Returns (labels, nlabels, wrap) with gap-free label numbering.'''
        from scipy import ndimage
        # Pad input data: one wrapped layer on both sides when features are to
        # be removed, one wrapped layer at the upper end when they are merged
        if map_to_zero:
            pw = tuple((1,1) if x else (0,0) for x in self.periodicity)
            sl_pad = tuple(slice(1,-1) if x else slice(None) for x in self.periodicity)
            width = 2
        else:
            pw = tuple((0,1) if x else (0,0) for x in self.periodicity)
            sl_pad = tuple(slice(0,-1) if x else slice(None) for x in self.periodicity)
            width = 1
        data_ = np.pad(self.data,pw,mode='wrap')
        # Compute labels on padded array
        labels_,nlabels_ = ndimage.label(data_,structure=self.structure)
        # parent[x] is the label x is mapped to. Invariant: parent[x]<=x.
        parent = np.arange(nlabels_+1,dtype=np.intp)
        merged = self._ndim*[None]
        for axis in range(self._ndim):
            if not self.periodicity[axis]:
                continue
            sl_lo = tuple(slice(0,width) if ii==axis else slice(None) for ii in range(self._ndim))
            sl_hi = tuple(slice(-width,None) if ii==axis else slice(None) for ii in range(self._ndim))
            lab_lo = labels_[sl_lo].ravel()
            lab_hi = labels_[sl_hi].ravel()
            differ = (lab_lo!=lab_hi)
            if map_to_zero:
                # NOTE(review): a feature is removed as soon as any of its
                # boundary cells carries a different label in the other copy,
                # which can also hit a feature that does wrap around -- confirm
                # this is intended.
                parent[lab_lo[differ]] = 0
                parent[lab_hi[differ]] = 0
                continue
            # Keep track of labels which take part in a merge across this axis
            merged[axis] = np.zeros(nlabels_+1,dtype=bool)
            if not differ.any():
                continue
            pairs = np.unique(np.stack((lab_lo[differ],lab_hi[differ]),axis=1),axis=0)
            merged[axis][pairs.ravel()] = True
            # Union every distinct pair; the smaller root always survives so
            # that the invariant parent[x]<=x holds
            for root_a,root_b in pairs:
                while parent[root_a]!=root_a:
                    root_a = parent[root_a]
                while parent[root_b]!=root_b:
                    root_b = parent[root_b]
                if root_a!=root_b:
                    parent[max(root_a,root_b)] = min(root_a,root_b)
        # Resolve chains so that every label points directly at its root
        while True:
            resolved = parent[parent]
            if np.array_equal(resolved,parent):
                break
            parent = resolved
        # Remove gaps from target mapping
        roots,compact = np.unique(parent,return_inverse=True)
        compact = compact.reshape(parent.shape)
        # Relabel and remove padding
        labels_ = compact[labels_[sl_pad]]
        nlabels_ = int(compact.max())
        wrap_ = self._ndim*[None]
        for axis in range(self._ndim):
            if merged[axis] is not None:
                flags = np.zeros(roots.size,dtype=bool)
                flags[compact[merged[axis]]] = True
                wrap_[axis] = flags
        return labels_,nlabels_,tuple(wrap_)

    def fill_holes(self):
        '''Fill the holes in binary objects while taking into account periodicity.
        In the non-periodic sense, a hole is a region of zeros which does not connect
        to a boundary. In the periodic sense, a hole is a region of zeros which is not
        connected to itself across the periodic boundaries.
        Existing labels are not updated; call label() again afterwards.'''
        from scipy import ndimage
        # Reimplementation of "binary_fill_holes" from ndimage
        mask = np.logical_not(self.data) # only modify locations which are "False" at the moment
        tmp = np.zeros(mask.shape,bool)  # create empty array to "grow from boundaries"
        ndimage.binary_dilation(tmp,structure=None,iterations=-1,
                                mask=mask,output=self.data,border_value=1,
                                origin=0) # everything connected to the boundary is now True in self.data
        # Remove holes which overlap the boundaries
        if any(self.periodicity):
            self.data = self._labels_periodic(map_to_zero=True)[0]>0
        # Invert to get the final result
        np.logical_not(self.data,self.data)

    def probe(self,idx):
        '''Returns whether or not a point at idx is True or False.'''
        return self.data[tuple(idx)]

    def volume(self):
        '''Returns the sum of True values.'''
        return np.sum(self.data)

    def volume_feature(self,label=None):
        '''Returns volume of features, i.e. connected regions which have been
        labeled using the label() method. If 'label' is None all volumes
        are returned including the volume of the background region. The array
        is sorted by labels, i.e. vol[0] is the volume of the background region,
        vol[1] the volume of label 1, etc. If 'label' is an integer value, only
        the volume of the corresponding region is returned.
        Note: it is more efficient to retrieve all volumes at once than querying
        single labels.'''
        if self.labels is None:
            self.label()
        if label is None:
            return np.bincount(self.labels.ravel(),minlength=int(self.nlabels)+1)
        return np.sum(self.labels==label)

    def volume_domain(self):
        '''Returns volume of entire domain. Should be equal to sum(volume_feature()).'''
        return np.prod(self._dim)

    def feature_labels_by_volume(self,descending=True):
        '''Returns labels of connected regions sorted by volume.'''
        labels = np.argsort(self.volume_feature()[1:])+1
        if descending:
            labels = labels[::-1]
        return labels

    def discard_feature(self,selection):
        '''Removes the selected features from 'data' and 'labels' and renumbers
        the remaining labels without gaps. 'selection' is a label, a sequence
        of labels, or a boolean sequence with nlabels+1 entries.'''
        if self.labels is None:
            self.label()
        selection = self._select_feature(selection)
        # Map tagged regions to zero in order to discard them
        map_ = np.arange(self.nlabels+1,dtype=self.labels.dtype)
        map_[selection] = 0
        # Remove gaps from target mapping
        keep,map_ = np.unique(map_,return_index=True,return_inverse=True)[1:3]
        map_ = map_.reshape(-1)
        # Discard regions
        self.labels = map_[self.labels]
        self.nlabels = int(np.max(map_))
        self.data = self.labels>0
        # Keep the wrap flags aligned with the new label numbering
        self.wrap = tuple(w if w is None else w[keep] for w in self.wrap)

    def _select_feature(self,selection):
        '''Validates 'selection' and returns it as an array usable as an index
        into a per-label array of length nlabels+1.'''
        dtype = self.labels.dtype
        is_scalar = isinstance(selection,(int,np.integer)) and \
            not isinstance(selection,(bool,np.bool_))
        if is_scalar:
            selection = np.array(selection,dtype=dtype)
        elif isinstance(selection,(list,tuple,np.ndarray)):
            selection = np.array(selection)
            if selection.dtype==np.dtype('bool'):
                assert selection.ndim==1 and selection.shape[0]==self.nlabels+1,\
                    "Boolean indexing must provide count+1 values."
                return selection
            selection = selection.astype(dtype)
        else:
            raise ValueError('Invalid input. Accepting int,list,tuple,ndarray.')
        # An empty selection is valid and discards nothing
        if selection.size>0:
            assert np.max(selection)<=self.nlabels and np.min(selection)>=0,\
                "Entry in selection is out-of-bounds."
        return selection