# suspendtools/field.py
# (1072 lines, 48 KiB, Python)

import numpy as np
class Field3d:
    '''Scalar field stored on a uniform, axis-aligned three-dimensional grid.

    An instance holds a numpy array ('data'), the coordinate of the first grid
    point ('origin') and the grid spacing per axis ('spacing'). The tolerance
    'eps_collapse' is used whenever coordinates are compared.'''
    def __init__(self,data,origin,spacing,deep=False):
        '''Creates a field from 'data' (numpy array), 'origin' and 'spacing'
        (both of length 3). If 'deep' is True the data is copied, otherwise the
        array passed in is referenced.'''
        assert len(origin)==3, "'origin' must be of length 3"
        assert len(spacing)==3, "'spacing' must be of length 3"
        assert isinstance(data,np.ndarray), "'data' must be numpy.ndarray."
        self.data = data.copy() if deep else data
        self.origin = tuple(float(x) for x in origin)
        self.spacing = tuple(float(x) for x in spacing)
        # Tolerance below which two coordinates are considered identical
        self.eps_collapse = 1e-8
        # Only set for pseudo fields, which carry a dimension but no data
        self._dim = None
    def __str__(self):
        '''Returns a human-readable summary of grid and datatype.'''
        out = 'Field3d with\n'
        out+= ' dimension: {}, {}, {}\n'.format(*self.dim())
        out+= ' origin: {:G}, {:G}, {:G}\n'.format(*self.origin)
        out+= ' spacing: {:G}, {:G}, {:G}\n'.format(*self.spacing)
        out+= ' datatype: {}'.format(self.dtype())
        return out
    def _operand(self,other):
        '''Returns the raw right-hand operand for arithmetic: the data of another
        Field3d (after asserting that the grids match) or 'other' unchanged.'''
        if isinstance(other,Field3d):
            assert self.has_same_grid(other), "Grid mismatch."
            return other.data
        return other
    def __add__(self,other):
        return Field3d(self.data+self._operand(other),self.origin,self.spacing)
    def __sub__(self,other):
        return Field3d(self.data-self._operand(other),self.origin,self.spacing)
    def __mul__(self,other):
        return Field3d(self.data*self._operand(other),self.origin,self.spacing)
    def __truediv__(self,other):
        return Field3d(self.data/self._operand(other),self.origin,self.spacing)
    def __radd__(self,other):
        return Field3d(other+self.data,self.origin,self.spacing)
    def __rmul__(self,other):
        return Field3d(other*self.data,self.origin,self.spacing)
    def __pow__(self,other):
        return Field3d(self.data**other,self.origin,self.spacing)
    def __iadd__(self,other):
        self.data += self._operand(other)
        return self
    def __isub__(self,other):
        self.data -= self._operand(other)
        return self
    def __imul__(self,other):
        self.data *= self._operand(other)
        return self
    def __itruediv__(self,other):
        self.data /= self._operand(other)
        return self
    # TBD: this should return another Field3d object
    # def __getitem__(self,val):
    #     assert isinstance(val,tuple) and len(val)==3, "Field3d must be indexed by [ii,jj,kk]."
    #     sl = []
    #     for x in val:
    #         if isinstance(x,int):
    #             lo,hi = x,x+1
    #             hi = hi if hi!=0 else None
    #             sl.append(slice(lo,hi))
    #         elif isinstance(x,slice):
    #             sl.append(x)
    #         else:
    #             raise TypeError("Trajectories can only be sliced by slice objects or integers.")
    #     return self.data[sl[0],sl[1],sl[2]]
    @classmethod
    def from_chunk(cls,chunk,gridg):
        '''Initialize Field3d from chunk data and global grid.
        'chunk' is a mapping providing 'ibeg','jbeg','kbeg','ighost','nxl','nyl',
        'nzl' and 'data'; 'gridg' is the tuple (xg,yg,zg) of global coordinates.
        The grid is taken to be uniform: spacing is computed from the first two
        entries of each coordinate array.'''
        xg,yg,zg = gridg
        # NOTE(review): the '-1' suggests that 'ibeg' etc. are 1-based indices -- confirm
        ib,jb,kb = chunk['ibeg']-1, chunk['jbeg']-1, chunk['kbeg']-1
        dx,dy,dz = xg[1]-xg[0], yg[1]-yg[0], zg[1]-zg[0]
        # The origin is shifted back by the number of ghost cells
        xo,yo,zo = xg[ib]-chunk['ighost']*dx, yg[jb]-chunk['ighost']*dy, zg[kb]-chunk['ighost']*dz
        nx,ny,nz = chunk['data'].shape
        assert (chunk['nxl']+2*chunk['ighost'])==nx, "Invalid chunk data: nxl != chunk['data'].shape[0]"
        assert (chunk['nyl']+2*chunk['ighost'])==ny, "Invalid chunk data: nyl != chunk['data'].shape[1]"
        assert (chunk['nzl']+2*chunk['ighost'])==nz, "Invalid chunk data: nzl != chunk['data'].shape[2]"
        return cls(chunk['data'],origin=(xo,yo,zo),spacing=(dx,dy,dz))
    @classmethod
    def from_file(cls,file,name='Field3d'):
        '''Loads a field from the HDF5 group 'name'. 'file' is either an open
        h5py.File/h5py.Group (left open) or a path (opened read-only and closed
        again, even if reading fails).'''
        import h5py
        is_open = isinstance(file,(h5py.File,h5py.Group))
        f = file if is_open else h5py.File(file,'r')
        try:
            g = f[name]
            origin = tuple(g['origin'])
            spacing = tuple(g['spacing'])
            data = g['data'][:]
        finally:
            if not is_open: f.close()
        return cls(data,origin,spacing)
    @classmethod
    def allocate(cls,dim,origin,spacing,fill=None,dtype=np.float64):
        '''Allocates an empty field, or a field filled with 'fill'.'''
        assert isinstance(dim,(tuple,list,np.ndarray)) and len(dim)==3,\
            "'dim' must be a tuple/list of length 3."
        assert isinstance(origin,(tuple,list,np.ndarray)) and len(origin)==3,\
            "'origin' must be a tuple/list of length 3."
        assert isinstance(spacing,(tuple,list,np.ndarray)) and len(spacing)==3,\
            "'spacing' must be a tuple/list of length 3."
        if fill is None:
            data = np.empty(dim,dtype=dtype)
        else:
            data = np.full(dim,fill,dtype=dtype)
        return cls(data,origin,spacing)
    @classmethod
    def pseudo_field(cls,dim,origin,spacing):
        '''Creates a Field3d instance without allocating any memory.
        The returned field reports 'dim' as its dimension but holds an empty array.'''
        assert isinstance(dim,(tuple,list,np.ndarray)) and len(dim)==3,\
            "'dim' must be a tuple/list of length 3."
        assert isinstance(origin,(tuple,list,np.ndarray)) and len(origin)==3,\
            "'origin' must be a tuple/list of length 3."
        assert isinstance(spacing,(tuple,list,np.ndarray)) and len(spacing)==3,\
            "'spacing' must be a tuple/list of length 3."
        data = np.empty((0,0,0))
        r = cls(data,origin,spacing)
        r._dim = dim
        return r
    def save(self,file,name='Field3d',truncate=False):
        '''Writes origin, spacing and data into a new HDF5 group 'name'.
        'file' is either an open h5py.File/h5py.Group (left open) or a path,
        which is opened in mode 'w' if 'truncate' is True and 'a' otherwise and
        closed again even if writing fails.'''
        import h5py
        is_open = isinstance(file,(h5py.File,h5py.Group))
        flag = 'w' if truncate else 'a'
        f = file if is_open else h5py.File(file,flag)
        try:
            g = f.create_group(name)
            g.create_dataset('origin',data=self.origin)
            g.create_dataset('spacing',data=self.spacing)
            g.create_dataset('data',data=self.data)
        finally:
            if not is_open: f.close()
    def copy(self):
        '''Returns a new field with a copy of the data.'''
        return Field3d(self.data.copy(),self.origin,self.spacing)
    def pseudo_copy(self):
        '''Returns a pseudo field (no data) with the same grid.'''
        return self.pseudo_field(self.dim(),self.origin,self.spacing)
    def insert_subfield(self,subfield):
        '''Copies the data of 'subfield' into this field. The subfield must have
        the same spacing, its grid points must coincide with grid points of this
        field and it must lie completely within bounds.'''
        assert all([abs(subfield.spacing[ii]-self.spacing[ii])<self.eps_collapse
                    for ii in range(3)]), "spacing differs. Got {}, have {}".format(subfield.spacing,self.spacing)
        assert all([self.distance_to_nearest_gridpoint(subfield.origin[ii],axis=ii)<self.eps_collapse
                    for ii in range(3)]), "subfield has shifted origin."
        assert all(self.is_within_bounds(subfield.origin,axis=None)), "subfield origin is out-of-bounds."
        assert all(self.is_within_bounds(subfield.endpoint(),axis=None)), "subfield endpoint is out-of-bounds."
        ib,jb,kb = self.nearest_gridpoint(subfield.origin,axis=None)
        nx,ny,nz = subfield.dim()
        self.data[ib:ib+nx,jb:jb+ny,kb:kb+nz] = subfield.data[:,:,:]
    def extract_subfield(self,idx_origin,dim,stride=(1,1,1),deep=False,strict_bounds=True):
        '''Extracts a subfield of 'dim' points starting at index 'idx_origin'
        taking every 'stride'-th point. With 'strict_bounds' the request must lie
        within the field; otherwise it is clipped and None is returned if nothing
        remains. The data is a view unless 'deep' is True.'''
        # Work on private copies: clipping must neither modify the caller's
        # sequences nor fail when tuples are passed.
        idx_origin = list(idx_origin)
        dim = list(dim)
        if strict_bounds:
            assert all(idx_origin[ii]>=0 and idx_origin[ii]<self.dim(axis=ii) for ii in range(3)),\
                "'origin' is out-of-bounds."
            assert all(idx_origin[ii]+stride[ii]*(dim[ii]-1)<self.dim(axis=ii) for ii in range(3)),\
                "endpoint is out-of-bounds."
        else:
            for ii in range(3):
                if idx_origin[ii]<0:
                    dim[ii] += idx_origin[ii]
                    idx_origin[ii] = 0
                if idx_origin[ii]+stride[ii]*(dim[ii]-1)>=self.dim(axis=ii):
                    dim[ii] = (self.dim(axis=ii)-idx_origin[ii])//stride[ii]
                if dim[ii]<=0:
                    return None
        sl = tuple(slice(idx_origin[ii],
                         idx_origin[ii]+stride[ii]*dim[ii],
                         stride[ii]) for ii in range(3))
        origin = self.coordinate(idx_origin)
        spacing = tuple(self.spacing[ii]*stride[ii] for ii in range(3))
        data = self.data[sl]
        return Field3d(data,origin,spacing,deep=deep)
    def coordinate(self,idx,axis=None):
        '''Returns the coordinate of grid index 'idx' along 'axis', or a tuple of
        three coordinates if 'axis' is None and 'idx' has length 3.'''
        if axis is None:
            assert len(idx)==3, "If 'axis' is None, 'idx' must be a tuple/list of length 3."
            return tuple(self.coordinate(idx[ii],axis=ii) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        assert idx<self.dim(axis=axis), "'idx' is out-of-bounds."
        return self.origin[axis]+idx*self.spacing[axis]
    def grid(self,axis=None):
        '''Returns the grid coordinates along 'axis' (or all three if None).'''
        if axis is None:
            return tuple(self.grid(axis=ii) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        return self.origin[axis]+np.arange(0,self.dim(axis=axis))*self.spacing[axis]
    def x(self): return self.grid(axis=0)
    def y(self): return self.grid(axis=1)
    def z(self): return self.grid(axis=2)
    def nearest_gridpoint(self,coord,axis=None,lower=False):
        '''Returns the index of the grid point nearest to 'coord'. If 'lower' is
        True, the nearest grid point at or below 'coord' (within tolerance) is
        returned instead. The index is not checked against the bounds.'''
        if axis is None:
            assert len(coord)==3, "If 'axis' is None, 'coord' must be a tuple/list of length 3."
            return tuple(self.nearest_gridpoint(coord[ii],axis=ii,lower=lower) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        if lower:
            return np.floor((coord+self.eps_collapse-self.origin[axis])/self.spacing[axis]).astype('int')
        else:
            return np.round((coord-self.origin[axis])/self.spacing[axis]).astype('int')
    def distance_to_nearest_gridpoint(self,coord,axis=None,lower=False):
        '''Returns the distance from 'coord' to the nearest grid point, or to the
        next lower grid point if 'lower' is True.'''
        if axis is None:
            assert len(coord)==3, "If 'axis' is None, 'coord' must be a tuple/list of length 3."
            return tuple(self.distance_to_nearest_gridpoint(coord[ii],axis=ii,lower=lower) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        # Adding and subtracting the tolerance keeps points which lie marginally
        # below a grid point from being wrapped to a distance of almost one spacing.
        val = np.remainder(coord+self.eps_collapse-self.origin[axis],self.spacing[axis])-self.eps_collapse
        if not lower and val>0.5*self.spacing[axis]:
            val = self.spacing[axis]-val
        return val
    def is_within_bounds(self,coord,axis=None):
        '''Returns True if 'coord' lies between the first and the last grid point
        (within tolerance). For 'axis' None a tuple of three bools is returned.'''
        if axis is None:
            assert len(coord)==3, "If 'axis' is None, 'coord' must be a tuple/list of length 3."
            return tuple(self.is_within_bounds(coord[ii],axis=ii) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        idx_nearest = self.nearest_gridpoint(coord,axis=axis)
        idx_last = self.dim(axis=axis)-1
        if idx_nearest>0 and idx_nearest<idx_last:
            return True
        # At the first/last grid point only coordinates coinciding with it are accepted
        dist_nearest = self.distance_to_nearest_gridpoint(coord,axis=axis)
        return bool((idx_nearest==0 or idx_nearest==idx_last) and abs(dist_nearest)<self.eps_collapse)
    def has_same_grid(self,other):
        '''Returns True if origin, spacing and dimension of 'other' match.'''
        if not self.has_same_origin(other): return False
        if not self.has_same_spacing(other): return False
        if not self.has_same_dim(other): return False
        return True
    def has_same_origin(self,other,axis=None):
        '''Compares the origin with that of 'other' (Field3d or sequence).'''
        if axis is None:
            return all([self.has_same_origin(other,axis=ii) for ii in range(3)])
        origin = other.origin if isinstance(other,Field3d) else other
        return abs(origin[axis]-self.origin[axis])<self.eps_collapse
    def has_same_spacing(self,other,axis=None):
        '''Compares the spacing with that of 'other' (Field3d or sequence).'''
        if axis is None:
            return all([self.has_same_spacing(other,axis=ii) for ii in range(3)])
        spacing = other.spacing if isinstance(other,Field3d) else other
        return abs(spacing[axis]-self.spacing[axis])<self.eps_collapse
    def has_same_dim(self,other,axis=None):
        '''Compares the dimension with that of 'other', which may be a Field3d,
        a sequence of three dimensions or (for a single axis) a scalar.'''
        if axis is None:
            return all([self.has_same_dim(other,axis=ii) for ii in range(3)])
        if isinstance(other,Field3d):
            dim = other.dim(axis=axis)
        elif np.ndim(other)==0:
            dim = other
        else:
            dim = other[axis]
        return dim==self.dim(axis=axis)
    def dim(self,axis=None):
        '''Returns the number of grid points along 'axis' (or all three if None).
        Pseudo fields report the dimension they were created with.'''
        if axis is None:
            return tuple(self.dim(axis=ii) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        if self._dim is not None:
            return self._dim[axis]
        else:
            return self.data.shape[axis]
    def endpoint(self,axis=None):
        '''Returns the coordinate of the last grid point.'''
        if axis is None:
            return tuple(self.endpoint(axis=ii) for ii in range(3))
        assert axis<3, "'axis' must be one of 0,1,2."
        return self.origin[axis]+(self.dim(axis=axis)-1)*self.spacing[axis]
    def dtype(self):
        return self.data.dtype
    def convert_dtype(self,dtype):
        '''Converts the data to 'dtype' (no copy is made if it already matches).'''
        self.data = self.data.astype(dtype,copy=False)
    def derivative(self,axis,only_keep_interior=False,shift_origin='before'):
        '''Computes derivative wrt to direction 'axis' with 2nd order finite differences
        centered between the origin grid points.
        With shift_origin='before' the value at index i is (f[i]-f[i-1])/dx and
        the origin moves back by half a spacing; with 'after' it is
        (f[i+1]-f[i])/dx and the origin moves forward by half a spacing. The
        value which would require data outside the field is NaN and is dropped
        if 'only_keep_interior' is True.'''
        from scipy import ndimage
        assert axis<3, "'axis' must be one of 0,1,2."
        assert shift_origin in ('before','after'), "'shift_origin' must be one of {'before','after'}."
        origin = list(self.origin)
        if shift_origin=='before':
            # Stencil covers (i-1,i): result belongs to x[i]-dx/2
            corr_orig = 0
            origin[axis] -= 0.5*self.spacing[axis]
        else:
            # Stencil covers (i,i+1): result belongs to x[i]+dx/2
            corr_orig = -1
            origin[axis] += 0.5*self.spacing[axis]
        data = ndimage.correlate1d(self.data,np.array([-1.,1.])/self.spacing[axis],axis=axis,
                                   mode='constant',cval=np.nan,origin=corr_orig)
        if only_keep_interior:
            sl = 3*[slice(None)]
            if shift_origin=='before':
                # First entry is NaN: drop it and advance the origin accordingly
                sl[axis] = slice(1,None)
                origin[axis] += self.spacing[axis]
            else:
                # Last entry is NaN
                sl[axis] = slice(0,-1)
            data = data[tuple(sl)]
        return Field3d(data,origin,self.spacing)
    def gradient(self,axis=None,preserve_origin=False,only_keep_interior=False,add_border='before'):
        '''Returns the list [d/dx,d/dy,d/dz] of derivative fields, see derivative().
        'axis' is ignored and only kept for backward compatibility. 'add_border'
        is forwarded as 'shift_origin'. 'preserve_origin' is not implemented and
        raises NotImplementedError if set.'''
        if preserve_origin:
            raise NotImplementedError("'preserve_origin' is not supported.")
        return [self.derivative(ii,only_keep_interior=only_keep_interior,shift_origin=add_border)
                for ii in range(3)]
    def laplacian(self,only_keep_interior=False):
        '''Computes the Laplacian of a field with a three-point stencil per axis.
        Points which lack a neighbor are NaN; they are removed if
        'only_keep_interior' is True.'''
        from scipy import ndimage
        data = ndimage.correlate1d(self.data,np.array([1.,-2.,1.])/self.spacing[0]**2,axis=0,mode='constant',cval=np.nan,origin=0)
        data += ndimage.correlate1d(self.data,np.array([1.,-2.,1.])/self.spacing[1]**2,axis=1,mode='constant',cval=np.nan,origin=0)
        data += ndimage.correlate1d(self.data,np.array([1.,-2.,1.])/self.spacing[2]**2,axis=2,mode='constant',cval=np.nan,origin=0)
        origin = list(self.origin)
        if only_keep_interior:
            data = data[1:-1,1:-1,1:-1]
            for axis in range(3):
                origin[axis] = origin[axis]+self.spacing[axis]
        return Field3d(data,origin,self.spacing)
    def integral(self,integrate_axis,average=False,ignore_nan=False,return_weights=False,ufunc=None):
        '''Computes the integral or average along a given axis applying the
        function 'ufunc' to each node.
        'integrate_axis' holds three bools selecting the axes to reduce. The
        reduced axes are kept with length one. The result is sum/weight, where
        weight is the number of (non-NaN, if 'ignore_nan') samples for an average
        and 1/(product of spacings) for an integral, i.e. the integral is the sum
        multiplied by the spacings. If 'return_weights' is True the tuple
        (sum,weight) is returned instead.'''
        assert isinstance(integrate_axis,(list,tuple,np.ndarray)) and len(integrate_axis)==3,\
            "'integrate_axis' must be a tuple/list of length 3."
        assert all([isinstance(integrate_axis[ii],(bool,int)) for ii in range(3)]),\
            "'integrate_axis' requires bool values."
        assert any(integrate_axis), "'integrate_axis' must contain at least one True."
        axes = []
        weight = 1.0
        for axis in range(3):
            if integrate_axis[axis]:
                axes.append(axis)
                if average:
                    weight *= self.dim(axis=axis)
                else:
                    # The sum is divided by 'weight', so the spacing enters inversely
                    weight /= self.spacing[axis]
        axes = tuple(axes)
        if ignore_nan:
            if average:
                weight = np.sum(~np.isnan(self.data),axis=axes,keepdims=True)
            func_sum = np.nansum
        else:
            func_sum = np.sum
        if ufunc is None:
            out = func_sum(self.data,axis=axes,keepdims=True)
        else:
            assert isinstance(ufunc,np.ufunc), "'ufunc' needs to be a numpy ufunc. "\
                "Check out https://numpy.org/doc/stable/reference/ufuncs.html for reference."
            assert ufunc.nin==1, "Only ufunc with single input argument are supported for now."
            out = func_sum(ufunc(self.data),axis=axes,keepdims=True)
        if return_weights:
            return (out,weight)
        else:
            return out/weight
    def gaussian_filter(self,sigma,truncate=4.0,only_keep_interior=False):
        '''Applies a gaussian filter: sigma is standard deviation for Gaussian kernel for each axis.
        Values outside the field are treated as NaN. With 'only_keep_interior'
        a border of gaussian_filter_radius() points is removed on each side.'''
        from scipy import ndimage
        assert isinstance(sigma,(tuple,list,np.ndarray)) and len(sigma)==3,\
            "'sigma' must be a tuple/list of length 3"
        # Convert sigma from simulation length scales to grid points as required by ndimage
        sigma_img = tuple(sigma[ii]/self.spacing[ii] for ii in range(3))
        data = ndimage.gaussian_filter(self.data,sigma_img,truncate=truncate,mode='constant',cval=np.nan)
        origin = list(self.origin)
        if only_keep_interior:
            r = self.gaussian_filter_radius(sigma,truncate=truncate)
            # Explicit upper bounds: 'r:-r' would be empty for a radius of zero
            sl = tuple(slice(r[ii],data.shape[ii]-r[ii]) for ii in range(3))
            data = data[sl]
            for axis in range(3):
                origin[axis] = origin[axis]+r[axis]*self.spacing[axis]
        return Field3d(data,origin,self.spacing)
    def gaussian_filter_radius(self,sigma,truncate=4.0):
        '''Radius of Gaussian filter. Stencil width is 2*radius+1.'''
        assert isinstance(sigma,(tuple,list,np.ndarray)) and len(sigma)==3,\
            "'sigma' must be a tuple/list of length 3"
        # Convert sigma from simulation length scales to grid points as required by ndimage
        sigma_img = tuple(sigma[ii]/self.spacing[ii] for ii in range(3))
        return tuple(int(truncate*sigma_img[ii]+0.5) for ii in range(3))
    def shift_origin(self,rel_shift,only_keep_interior=False):
        '''Shifts the origin of a field by multiple of spacing.
        'rel_shift' holds one value per axis within [-1,1]; the data is linearly
        interpolated onto the shifted grid. Points which would need data outside
        the field are NaN and are removed if 'only_keep_interior' is True.'''
        from scipy import ndimage
        assert isinstance(rel_shift,(tuple,list,np.ndarray)) and len(rel_shift)==3,\
            "'shift' must be tuple/list with length 3."
        assert all([rel_shift[ii]>=-1.0-self.eps_collapse and rel_shift[ii]<=1.0+self.eps_collapse for ii in range(3)]),\
            "'shift' must be in (-1.0,1.0). {}".format(rel_shift)
        data = self.data.copy()
        origin = list(self.origin)
        sl = 3*[slice(None)]
        for axis in range(3):
            if abs(rel_shift[axis])<self.eps_collapse:
                continue
            elif rel_shift[axis]>0:
                w = rel_shift[axis] if rel_shift[axis]<=1.0 else 1.0
                # Interpolate between points i and i+1
                weights = (1.0-w,w)
                data = ndimage.correlate1d(data,weights,axis=axis,mode='constant',cval=np.nan,origin=-1)
                origin[axis] += w*self.spacing[axis]
                if only_keep_interior:
                    sl[axis] = slice(0,-1)
            else:
                w = rel_shift[axis] if rel_shift[axis]>=-1.0 else -1.0
                # Interpolate between points i-1 and i
                weights = (-w,1.0+w)
                data = ndimage.correlate1d(data,weights,axis=axis,mode='constant',cval=np.nan,origin=0)
                origin[axis] += w*self.spacing[axis]
                if only_keep_interior:
                    sl[axis] = slice(1,None)
                    origin[axis] += self.spacing[axis]
        if only_keep_interior:
            # Must be a tuple: numpy does not accept a list of slices as index
            data = data[tuple(sl)]
        return Field3d(data,origin,self.spacing)
    def relative_shift(self,field):
        '''Compute the relative shift (in terms of spacing) to shift self onto field.'''
        assert self.has_same_spacing(field), "spacing differs."
        rel_shift = [0.0,0.0,0.0]
        for axis in range(3):
            dist = field.origin[axis]-self.origin[axis]
            if abs(dist)>self.eps_collapse:
                rel_shift[axis] = dist/self.spacing[axis]
        return tuple(rel_shift)
    def change_grid(self,origin,spacing,dim,padding=None,numpad=1):
        '''Interpolates the field trilinearly onto a new grid which must lie
        within this field. If 'padding' is 'before', 'after' or 'both', the
        returned field is enlarged by 'numpad' zero-filled points on the
        respective side(s), see padding().'''
        assert all([origin[ii]>=self.origin[ii] for ii in range(0,3)]), "New origin is out of bounds."
        endpoint = [origin[ii]+(dim[ii]-1)*spacing[ii] for ii in range(0,3)]
        assert all([endpoint[ii]<=self.endpoint(ii) for ii in range(0,3)]), "New end point is out of bounds."
        # Allocate (possibly padded array). The static method has to be accessed
        # through 'self' because the argument 'padding' shadows its name.
        origin_pad,dim_pad,sl_out = self.padding(origin,spacing,dim,padding,numpad)
        data = np.zeros(dim_pad,dtype=self.data.dtype)
        # Trilinear interpolation
        if np.allclose(spacing,self.spacing):
            # spacing is the same: we can construct universal weights for the stencil
            i0,j0,k0 = self.nearest_gridpoint(origin,axis=None,lower=True)
            cx,cy,cz = [self.distance_to_nearest_gridpoint(origin[ii],axis=ii,lower=True)/self.spacing[ii]
                        for ii in range(3)]
            c = self.weights_trilinear((cx,cy,cz))
            for ii in range(0,2):
                for jj in range(0,2):
                    for kk in range(0,2):
                        if c[ii,jj,kk]>self.eps_collapse:
                            data[sl_out] += c[ii,jj,kk]*self.data[
                                i0+ii:i0+ii+dim[0],
                                j0+jj:j0+jj+dim[1],
                                k0+kk:k0+kk+dim[2]]
        else:
            # 'data_ref' is a view, so assignments end up in 'data'
            data_ref = data[sl_out]
            for ii in range(0,dim[0]):
                for jj in range(0,dim[1]):
                    for kk in range(0,dim[2]):
                        coord = (
                            origin[0]+ii*spacing[0],
                            origin[1]+jj*spacing[1],
                            origin[2]+kk*spacing[2])
                        data_ref[ii,jj,kk] = self.interpolate(coord)
        return Field3d(data,origin_pad,spacing)
    def interpolate(self,coord):
        '''Returns the trilinearly interpolated value at 'coord', which must lie
        within the field.'''
        assert all([coord[ii]>=self.origin[ii] for ii in range(0,3)]), "'coord' is out of bounds."
        assert all([coord[ii]<=self.endpoint(ii) for ii in range(0,3)]), "'coord' is out of bounds."
        i0,j0,k0 = self.nearest_gridpoint(coord,axis=None,lower=True)
        cx,cy,cz = [self.distance_to_nearest_gridpoint(coord[ii],axis=ii,lower=True)/self.spacing[ii]
                    for ii in range(3)]
        c = self.weights_trilinear((cx,cy,cz))
        val = 0.0
        for ii in range(0,2):
            for jj in range(0,2):
                for kk in range(0,2):
                    # Skipping vanishing weights avoids reading beyond the last grid point
                    if c[ii,jj,kk]>self.eps_collapse:
                        val += c[ii,jj,kk]*self.data[i0+ii,j0+jj,k0+kk]
        return val
    @staticmethod
    def padding(origin,spacing,dim,padding,numpad):
        '''Computes origin and dimension of a grid padded by 'numpad' points
        (int or sequence of length 3) 'before', 'after' or on 'both' sides, and
        the slices which select the unpadded region within the padded array.
        Returns (origin_pad,dim_pad,sl_out). 'padding' None means no padding.'''
        if isinstance(numpad,(int,np.integer)):
            numpad = np.full(3,numpad,dtype=int)
        else:
            numpad = np.array(numpad,dtype=int)
            assert len(numpad)==3, "'numpad' must be either an integer or tuple/list of length 3."
        # Floating point origin so that the in-place shift below is always valid
        origin_pad = np.array(origin,dtype=float)
        dim_pad = np.array(dim,dtype=int)
        dim_in = tuple(int(x) for x in dim)
        shift = numpad*np.asarray(spacing,dtype=float)
        sl_out = [slice(None),slice(None),slice(None)]
        if padding is not None:
            if padding=='before':
                dim_pad += numpad
                origin_pad -= shift
                for axis in range(3):
                    sl_out[axis] = slice(int(numpad[axis]),None)
            elif padding=='after':
                dim_pad += numpad
                for axis in range(3):
                    # Explicit upper bound: '0:-numpad' would be empty for numpad==0
                    sl_out[axis] = slice(0,dim_in[axis])
            elif padding=='both':
                dim_pad += 2*numpad
                origin_pad -= shift
                for axis in range(3):
                    sl_out[axis] = slice(int(numpad[axis]),int(numpad[axis])+dim_in[axis])
            else:
                raise ValueError("'padding' must either be None or one of {'before','after','both'}.")
        sl_out = tuple(sl_out)
        origin_pad = tuple(float(x) for x in origin_pad)
        dim_pad = tuple(int(x) for x in dim_pad)
        return (origin_pad,dim_pad,sl_out)
    def weights_trilinear(self,rel_dist):
        '''Returns the (2,2,2) array of trilinear interpolation weights for a
        point at relative distance 'rel_dist' (each within [0,1]) from the lower
        corner of a cell. Values outside [0,1] by less than the tolerance are
        clipped.'''
        assert len(rel_dist)==3, "len(rel_dist) must be 3."
        cx,cy,cz = rel_dist
        if cx<0.0 and cx>-self.eps_collapse: cx=0.0
        if cy<0.0 and cy>-self.eps_collapse: cy=0.0
        if cz<0.0 and cz>-self.eps_collapse: cz=0.0
        if cx>1.0 and cx<1.0+self.eps_collapse: cx=1.0
        if cy>1.0 and cy<1.0+self.eps_collapse: cy=1.0
        if cz>1.0 and cz<1.0+self.eps_collapse: cz=1.0
        assert cx>=0.0 and cy>=0.0 and cz>=0.0, "'rel_dist' must be >=0"
        assert cx<=1.0 and cy<=1.0 and cz<=1.0, "'rel_dist' must be <=1"
        c = np.zeros((2,2,2))
        c[0,0,0] = 1.0-(cx+cy+cz)+(cx*cy+cx*cz+cy*cz)-(cx*cy*cz)
        c[1,0,0] = cx-(cx*cy+cx*cz)+(cx*cy*cz)
        c[0,1,0] = cy-(cx*cy+cy*cz)+(cx*cy*cz)
        c[0,0,1] = cz-(cx*cz+cy*cz)+(cx*cy*cz)
        c[1,1,0] = (cx*cy)-(cx*cy*cz)
        c[1,0,1] = (cx*cz)-(cx*cy*cz)
        c[0,1,1] = (cy*cz)-(cx*cy*cz)
        c[1,1,1] = (cx*cy*cz)
        return c
    def set_writable(self,flag):
        '''Sets the write flag of the underlying numpy array.'''
        self.data.setflags(write=flag)
    def to_vtk(self,deep=False):
        '''Returns the field as a pyvista uniform grid with point array 'data'.'''
        import pyvista as pv
        # NOTE(review): pyvista renamed UniformGrid to ImageData and point_arrays
        # to point_data; prefer the new names and fall back to the old ones.
        grid_cls = getattr(pv,'ImageData',None)
        if grid_cls is None:
            grid_cls = pv.UniformGrid
        mesh = grid_cls()
        mesh.dimensions = self.dim(axis=None)
        mesh.origin = self.origin
        mesh.spacing = self.spacing
        arrays = mesh.point_data if hasattr(mesh,'point_data') else mesh.point_arrays
        # order needs to be F no matter how array is stored in memory
        if deep:
            arrays['data'] = self.data.flatten(order='F')
        else:
            arrays['data'] = self.data.ravel(order='F')
        return mesh
    def vtk_contour(self,val,deep=False,method='contour'):
        '''Returns the iso-contour(s) of the field for the value(s) 'val'.'''
        if not isinstance(val,(tuple,list)):
            val = [val]
        return self.to_vtk(deep=deep).contour(val,method=method)
    def vtk_slice(self,normal,origin,deep=False):
        '''Returns a slice through 'origin' perpendicular to 'normal'.'''
        assert (normal in ('x','y','z') or (isinstance(normal,(tuple,list))
                and len(normal)==3)), "'normal' must be 'x','y','z' or tuple of length 3."
        assert isinstance(origin,(tuple,list)) and len(origin)==3,\
            "'origin' must be tuple of length 3."
        return self.to_vtk(deep=deep).slice(normal=normal,origin=origin)
def gaussian_filter_umean_channel(array,spacing,sigma,truncate=4.0):
    '''Smooths a wall-normal profile stored as a numpy array of shape (1,ny,1),
    e.g. the mean streamwise velocity of a channel, with a Gaussian kernel.
    'sigma' is given in physical units and converted to grid points by means
    of 'spacing'. The profile is extended by mirroring at both ends (scipy's
    'mirror' mode), so no ghost cells are required.'''
    from scipy.ndimage import gaussian_filter1d
    assert array.ndim==3, "Expected an array with three dimensions/axes."
    shape = array.shape
    is_profile = shape[0]==1 and shape[1]>1 and shape[2]==1
    assert is_profile,\
        "Expected an array with shape (1,ny,1)."
    # Standard deviation measured in grid points along the second axis
    return gaussian_filter1d(array,sigma/spacing,axis=1,truncate=truncate,mode='mirror')
class BinaryFieldNd:
def __init__(self,input):
assert isinstance(input,np.ndarray) and input.dtype==np.dtype('bool'),\
"'input' must be a numpy array of dtype('bool')."
self.data = input
self._dim = input.shape
self._ndim = input.ndim
self.labels = None
self.nlabels = 0
self.wrap = tuple(self._ndim*[None])
self._feat_slice = None
self.set_structure(False)
self.set_periodicity(self._ndim*[False])
@classmethod
def from_threshold(cls,fld,threshold,invert=False):
if isinstance(fld,Field3d):
fld = fld.data
if invert:
return cls(fld<threshold)
else:
return cls(fld>=threshold)
def set_periodicity(self,periodicity):
assert all([isinstance(x,(bool,int)) for x in periodicity]),\
"'periodicity' requires bool values."
assert len(periodicity)==self._ndim,\
"Number of entries in 'periodicity' must match dimension of binary field."
self.periodicity = tuple(bool(x) for x in periodicity)
return
def set_structure(self,connect_diagonals):
from scipy import ndimage
if connect_diagonals:
self.structure = ndimage.generate_binary_structure(self._ndim,self._ndim)
else:
self.structure = ndimage.generate_binary_structure(self._ndim,1)
def enable_diagonal_connections(self): self.set_structure(True)
def disable_diagonal_connections(self): self.set_structure(False)
def label(self):
'''Labels connected regions in binary fields.'''
from scipy import ndimage
if any(self.periodicity):
self.labels,self.nlabels,self.wrap = self._labels_periodic()
else:
self.labels,self.nlabels = ndimage.label(self.data,structure=self.structure)
self._feat_slice = ndimage.find_objects(self.labels)
def _labels_periodic(self,map_to_zero=False):
'''Label features in an array while taking into account periodic wrapping.
If map_to_zero=True, every feature which overlaps or is attached to the
periodic boundary will be removed.'''
from scipy import ndimage
# Pad input data
if map_to_zero:
pw = tuple((1,1) if x else (0,0) for x in self.periodicity)
sl_pad = tuple(slice(1,-1) if x else slice(None) for x in self.periodicity)
else:
pw = tuple((0,1) if x else (0,0) for x in self.periodicity)
sl_pad = tuple(slice(0,-1) if x else slice(None) for x in self.periodicity)
data_ = np.pad(self.data,pw,mode='wrap')
# Compute labels on padded array
labels_,nlabels_ = ndimage.label(data_,structure=self.structure)
# Get a mapping of labels which differ at periodic overlap
map_ = np.array(range(nlabels_+1),dtype=labels_.dtype)
wrap_ = self._ndim*[None]
for axis in range(self._ndim):
if not self.periodicity[axis]: continue
if map_to_zero:
sl_lo = tuple(slice(0,2) if ii==axis else slice(None) for ii in range(self._ndim))
sl_hi = tuple(slice(-2,None) if ii==axis else slice(None) for ii in range(self._ndim))
lab_lo = labels_[sl_lo]
lab_hi = labels_[sl_hi]
li = (lab_lo!=lab_hi)
for source_ in np.unique(lab_lo[li]):
map_[source_] = 0
for source_ in np.unique(lab_hi[li]):
map_[source_] = 0
else:
sl_lo = tuple(slice(0,1) if ii==axis else slice(None) for ii in range(self._ndim))
sl_hi = tuple(slice(-1,None) if ii==axis else slice(None) for ii in range(self._ndim))
sl_pre = tuple(slice(-2,-1) if ii==axis else slice(None) for ii in range(self._ndim))
lab_lo = labels_[sl_lo]
lab_hi = labels_[sl_hi]
lab_pre = np.unique(labels_[sl_pre]) # all labels in last (unwrapped) slice
# Initialize array to keep track of wrapping
wrap_[axis] = np.zeros(nlabels_+1,dtype=bool)
# Determine new label and map
lab_new = np.minimum(lab_lo,lab_hi)
for lab_ in [lab_lo,lab_hi]:
li = (lab_!=lab_new)
lab_li = lab_[li]
lab_new_li = lab_new[li]
for idx_ in np.unique(lab_li,return_index=True)[1]:
source_ = lab_li[idx_] # the label to be changed
target_ = lab_new_li[idx_] # the label which will be newly assigned
while target_ != map_[target_]: # map it recursively
target_ = map_[target_]
map_[source_] = target_
if source_ in lab_pre: # check if source is not a ghost
wrap_[axis][target_] = True
# Remove gaps from target mapping
idx_,map_ = np.unique(map_,return_index=True,return_inverse=True)[1:3]
# Relabel and remove padding
labels_ = map_[labels_[sl_pad]]
nlabels_ = np.max(map_)
assert nlabels_==len(idx_)-1, "DEBUG assertion"
for axis in range(self._ndim):
if wrap_[axis] is not None:
wrap_[axis] = wrap_[axis][idx_]
return labels_,nlabels_,tuple(wrap_)
def fill_holes(self):
'''Fill the holes in binary objects while taking into account periodicity.
In the non-periodic sense, a hole is a region of zeros which does not connect
to a boundary. In the periodic sense, a hole is a region of zeros which is not
connected to itself accross the periodic boundaries.'''
from scipy import ndimage
# Reimplementation of "binary_fill_holes" from ndimage
mask = np.logical_not(self.data) # only modify locations which are "False" at the moment
tmp = np.zeros(mask.shape,bool) # create empty array to "grow from boundaries"
ndimage.binary_dilation(tmp,structure=None,iterations=-1,
mask=mask,output=self.data,border_value=1,
origin=0) # everything connected to the boundary is now True in self.data
# Remove holes which overlap the boundaries
if any(self.periodicity):
self.data = self._labels_periodic(map_to_zero=True)[0]>0
# Invert to get the final result
np.logical_not(self.data,self.data)
# If labels have been computed already, recompute them to stay consistent
if self.labels is not None:
self.label()
def probe(self,idx):
'''Returns whether or not a point at idx is True or False.'''
return self.data[tuple(idx)]
def volume(self):
'''Returns the sum of True values.'''
return np.sum(self.data)
def volume_feature(self,label=None):
'''Returns volume of features, i.e. connected regions which have been
labeled using the label() method. If 'label' is None all volumes
are returned including the volume of the background region. The array
is sorted by labels, i.e. vol[0] is the volume of the background region,
vol[1] the volume of label 1, etc. If 'label' is an integer value, only
the volume of the corresponding region is returned.
Note: it is more efficient to retrieve all volumes at once than querying
single labels.'''
if self.labels is None:
self.label()
if label is None:
return np.bincount(self.labels.ravel())
# TBD: compare performance with
# sizes = ndimage.sum(mask, label_im, range(nb_labels + 1))
# http://scipy-lectures.org/advanced/image_processing/auto_examples/plot_find_object.html
else:
return np.sum(self.labels==label)
def volume_domain(self):
'''Returns volume of entire domain. Should be equal to sum(volume_feature()).'''
return np.prod(self._dim)
def feature_labels_by_volume(self,descending=True):
'''Returns labels of connected regions sorted by volume.'''
labels = np.argsort(self.volume_feature()[1:])+1
if descending: labels = labels[::-1]
return labels
def discard_feature(self,selection):
if self.labels is None:
self.label()
selection = self._select_feature(selection)
# Map tagged regions to zero in order to discard them
map_ = np.array(range(self.nlabels+1),dtype=self.labels.dtype)
map_[selection] = 0
# Remove gaps from target mapping
idx_,map_ = np.unique(map_,return_index=True,return_inverse=True)[1:3]
# Discard regions
self.labels = map_[self.labels]
self.nlabels = np.max(map_)
for axis in range(self._ndim):
if self.wrap[axis] is not None:
self.wrap[axis] = self.wrap[axis][idx_]
self.data = self.labels>0
def isolate_feature(self,selection,array=None):
from scipy import ndimage
if self.labels is None:
self.label()
selection = self._select_feature(selection)
output1 = []
has_array = array is not None
if has_array:
assert np.all(array.shape==self._dim)
output2 = []
for lab_ in selection:
# Extract feature of interest
if lab_==0:
data_ = np.logical_not(self.data)
if has_array: data2_ = array
else:
data_ = (self.labels[self._feat_slice[lab_-1]]==lab_)
if has_array: data2_ = array[self._feat_slice[lab_-1]]
# If feature is wrapped periodically, duplicate it and extract
# largest one
iswrapped = False
rep_ = self._ndim*[1]
for axis in range(self._ndim):
if self.wrap[axis] is not None and self.wrap[axis][lab_]:
rep_[axis] = 2
iswrapped = True
if iswrapped:
data_ = np.tile(data_,rep_)
l_,nl_ = ndimage.label(data_,structure=self.structure)
vol_ = np.bincount(l_.ravel())
il_ = np.argmax(vol_[1:])+1
sl_ = ndimage.find_objects(l_==il_)[0]
data_ = data_[sl_]
if has_array:
data2_ = np.tile(data2_,rep_)[sl_]
# Add to output
output1.append(data_)
if has_array: output2.append(data2_)
if has_array:
return (output1,output2)
else:
return output1
def triangulate_feature(self,selection,origin=(0,0,0),spacing=(1,1,1),array=None):
assert self._ndim==3, "Triangulation requires 3D data."
from scipy import ndimage
if self.labels is None:
self.label()
selection = self._select_feature(selection)
output = []
pw = tuple((1,1) for ii in range(self._ndim))
has_array = array is not None
if has_array:
assert np.all(array.shape==self._dim)
output2 = []
for lab_ in selection:
if has_array:
data_,scal_ = self.isolate_feature(lab_,array=array)
data_,scal_ = data_[0],scal_[0]
# Fill interior in case we filled holes which is not in scal_
data_ = ndimage.binary_erosion(data_,structure=self.structure,iterations=2)
scal_[data_] = 1.0
scal_ = np.pad(scal_,pw,mode='reflect',reflect_type='odd')
pd = Field3d(scal_,origin,spacing).vtk_contour(0.0)
else:
data_ = self.isolate_feature(lab_)[0]
data_ = np.pad(data_.astype(float),pw,mode='constant',constant_values=-1.0)
pd = Field3d(data_,origin,spacing).vtk_contour(0.5).smooth(1000)
output.append(pd)
return output
def _select_feature(self,selection):
    '''Normalizes a feature selection into a 1D array of labels.

    Accepted inputs:
      None: selects all features, i.e. labels 1..nlabels.
      integer: a single label.
      list, tuple or ndarray of integers: the listed labels.
      list, tuple or ndarray of bools: a mask of length nlabels+1 where
          entry i selects label i (entry 0 being the background).

    Returns a 1D array with the dtype of self.labels which can be iterated
    label by label. An empty selection yields an empty array.

    Raises AssertionError if a label is outside [0,nlabels] or a boolean
    mask has the wrong length, ValueError for unsupported input types.'''
    dtype = self.labels.dtype
    if selection is None:
        return np.arange(1,self.nlabels+1,dtype=dtype)
    if not (np.issubdtype(type(selection),np.integer) or
            isinstance(selection,(list,tuple,np.ndarray))):
        raise ValueError('Invalid input. Accepting int,list,tuple,ndarray.')
    # ndmin=1 guarantees an iterable result even for scalars and 0-d arrays
    selection = np.array(selection,ndmin=1)
    if selection.dtype==np.dtype('bool'):
        assert selection.ndim==1 and selection.shape[0]==self.nlabels+1,\
            "Boolean indexing must provide count+1 values."
        # Callers iterate over labels, so translate the mask into indices
        return np.flatnonzero(selection).astype(dtype)
    selection = selection.astype(dtype)
    if selection.size==0:
        # Nothing selected: np.max/np.min would fail on an empty array
        return selection
    assert np.max(selection)<=self.nlabels and np.min(selection)>=0,\
        "Entry in selection is out-of-bounds."
    return selection
class ConnectedRegions:
    '''Labels the connected regions of a two or three dimensional binary
    array and merges regions which touch across periodic boundaries.

    Attributes:
      label: integer array with the shape of the input. 0 marks the
          background, regions are numbered consecutively from 1 to 'count'.
      count: number of labeled regions.'''
    def __init__(self,binarr,periodicity,connect_diagonals=False,fill_holes=False,bytes_label=32):
        '''Labels 'binarr'.

        Parameters:
          binarr: numpy array of dtype bool with two or three dimensions.
          periodicity: one bool per axis; True if the axis is periodic.
          connect_diagonals: if False, the corner neighbours of the 3x3(x3)
              stencil are not considered connected. In 3D the neighbours
              across an edge remain connected.
          fill_holes: accepted for interface compatibility; this
              implementation does not use it.
          bytes_label: width of the unsigned integer type used for the
              initial labeling. Despite the name the value is a number of
              bits (8, 16, 32 or 64). NOTE: after the remapping step the
              'label' array carries the integer type of the index map
              returned by numpy.unique, not this type.'''
        assert isinstance(binarr,np.ndarray) and binarr.dtype==np.dtype('bool'),\
            "'binarr' must be a numpy array of dtype('bool')."
        assert all([isinstance(x,(bool,int,np.bool_)) for x in periodicity]),\
            "'periodicity' requires bool values."
        assert bytes_label in (8,16,32,64),\
            "'bytes_label' must be one of {8,16,32,64}."
        self._dim = binarr.shape
        self._ndim = binarr.ndim
        assert self._ndim in (2,3),\
            "'binarr' must be either two or three dimensional."
        assert len(periodicity)==self._ndim,\
            "Length of 'periodicity' must match number of dimensions of data."
        from scipy import ndimage
        # Construct connectivity stencil. Without diagonal connections all
        # corners are removed; the stencil has to stay point-symmetric.
        connectivity = np.ones(self._ndim*(3,),dtype='bool')
        if not connect_diagonals:
            connectivity[np.ix_(*(self._ndim*[[0,2]]))] = False
        # Compute labels:
        # this does not take into account periodic wrapping
        dtype_label = np.dtype('uint'+str(bytes_label))
        self.label = np.empty(self._dim,dtype=dtype_label)
        ndimage.label(binarr,structure=connectivity,output=self.label)
        # Convert to a Python int so that adding one cannot wrap around in
        # a narrow unsigned type.
        nlabel = int(np.max(self.label))
        # Merge labels if there are periodic overlaps. 'parent' is a
        # union-find forest in which every label points to a label which is
        # smaller or equal, so the root is the smallest label of its set.
        parent = np.arange(nlabel+1,dtype=np.intp)
        for axis in range(self._ndim):
            if not periodicity[axis]:
                continue
            # Merge the first and last plane and compute connectivity
            sl = self._ndim*[slice(None)]
            sl[axis] = [-1,0]
            sl = tuple(sl)
            # The slab may contain more components than the domain has
            # regions, so it is labeled with scipy's default integer type.
            slab_label,_ = ndimage.label(binarr[sl],structure=connectivity)
            fg = slab_label>0
            if not np.any(fg):
                continue
            # Unique (slab component, global label) pairs, sorted by slab
            # component first and global label second.
            pairs = np.stack((slab_label[fg].astype(np.intp),
                              self.label[sl][fg].astype(np.intp)),axis=1)
            pairs = np.unique(pairs,axis=0)
            # Consecutive rows which share the slab component hold global
            # labels that are connected over the boundary.
            same = pairs[1:,0]==pairs[:-1,0]
            for lab_a,lab_b in zip(pairs[:-1,1][same],pairs[1:,1][same]):
                root_a = self._find_root(parent,lab_a)
                root_b = self._find_root(parent,lab_b)
                if root_a!=root_b:
                    parent[max(root_a,root_b)] = min(root_a,root_b)
        # Resolve every label to its root (pointer jumping)
        while True:
            jumped = parent[parent]
            if np.array_equal(jumped,parent):
                break
            parent = jumped
        # Remove gaps from target mapping
        map_tgt = np.unique(parent,return_inverse=True)[1]
        # Remap labels
        self.label = map_tgt[self.label]
        self.count = np.max(map_tgt)
    @staticmethod
    def _find_root(parent,idx):
        '''Returns the root of 'idx' in the union-find forest 'parent' and
        shortens the visited path on the way.'''
        idx = int(idx)
        while parent[idx]!=idx:
            parent[idx] = parent[parent[idx]]
            idx = int(parent[idx])
        return idx
    @classmethod
    def from_field(cls,fld3d,threshold,periodicity,connect_diagonals=False,bytes_label=32,invert=False):
        '''Alternate constructor: thresholds 'fld3d' with
        VoxelThreshold.from_field and labels the result.'''
        voxthr = VoxelThreshold.from_field(fld3d,threshold,invert=invert)
        return cls.from_voxelthresh(voxthr,periodicity,
                                    connect_diagonals=connect_diagonals,
                                    bytes_label=bytes_label)
    @classmethod
    def from_voxelthresh(cls,voxthr,periodicity,connect_diagonals=False,bytes_label=32):
        '''Alternate constructor: labels the binary array 'voxthr.data'.'''
        return cls(voxthr.data,periodicity,
                   connect_diagonals=connect_diagonals,
                   bytes_label=bytes_label)
    def volume(self,label=None):
        '''Returns volume of labeled regions. If 'label' is None all volumes
        are returned including the volume of the background region. The array
        is sorted by labels, i.e. vol[0] is the volume of the background region,
        vol[1] the volume of label 1, etc. If 'label' is an integer value, only
        the volume of the corresponding region is returned.
        Note: it is more efficient to retrieve all volumes at once than querying
        single labels.'''
        if label is None:
            return np.bincount(self.label.ravel())
        else:
            return np.sum(self.label==label)
    def volume_domain(self):
        '''Returns volume of entire domain. Should be equal to sum(volume()).'''
        return np.prod(self._dim)
    def labels_by_volume(self,descending=True):
        '''Returns labels of connected regions sorted by volume.'''
        labels = np.argsort(self.volume()[1:])+1
        if descending:
            labels = labels[::-1]
        return labels
    def discard_regions(self,selection):
        '''Removes the selected regions by assigning them to the background
        and renumbers the remaining regions consecutively. Updates 'label'
        and 'count' in place.'''
        selection = self._parse_selection(selection)
        # Map tagged regions to zero in order to discard them
        map_tgt = np.arange(int(self.count)+1,dtype=self.label.dtype)
        map_tgt[selection] = 0
        # Remove gaps from target mapping
        map_tgt = np.unique(map_tgt,return_inverse=True)[1]
        # Discard regions
        self.label = map_tgt[self.label]
        self.count = np.max(map_tgt)
    def probe(self,idx):
        '''Returns label for given index.'''
        return self.label[tuple(idx)]
    def vtk_contour(self,fld3,val,selection):
        '''Computes contours of a Field3d only within selected structures.'''
        assert isinstance(fld3,Field3d), "'fld3' must be a Field3d instance."
        assert tuple(self._dim)==tuple(fld3.dim()), \
            "'fld3' must be of dimension {}.".format(self._dim)
        selection = self._parse_selection(selection)
        from scipy import ndimage
        # Create binary map of selection
        map_tgt = np.zeros(self.count+1,dtype='bool')
        map_tgt[selection] = True
        binary_map = map_tgt[self.label]
        # Add an extra cell to get the contour interpolation right
        binary_map = ndimage.binary_dilation(binary_map)
        # Extract the subfield: everything outside the selection is masked
        # with NaN in a copy of the field.
        fld_con = fld3.copy()
        fld_con.data[~binary_map] = np.nan
        # Compute the contour
        return fld_con.vtk_contour(val)
    def _parse_selection(self,selection):
        '''Converts a selection into an array usable as index into a
        per-label array of length count+1.

        Accepts an integer label, a list/tuple/ndarray of integer labels or
        a boolean mask of length count+1. Integer input is returned with the
        dtype of 'label', a boolean mask is returned as boolean array.

        Raises AssertionError for out-of-bounds labels or a mask of wrong
        length, ValueError for unsupported input types.'''
        dtype = self.label.dtype
        if np.issubdtype(type(selection),np.integer):
            assert 0<=selection<=self.count,\
                "Entry in selection is out-of-bounds."
            return np.array(selection,dtype=dtype)
        elif isinstance(selection,(list,tuple,np.ndarray)):
            selection = np.array(selection)
            if selection.dtype==np.dtype('bool'):
                assert selection.ndim==1 and selection.shape[0]==self.count+1,\
                    "Boolean indexing must provide count+1 values."
            else:
                selection = selection.astype(dtype)
                # An empty selection is valid and selects nothing; the
                # bounds check is skipped since np.max fails on empty input.
                if selection.size>0:
                    assert np.max(selection)<=self.count and np.min(selection)>=0,\
                        "Entry in selection is out-of-bounds."
            return selection
        else:
            raise ValueError('Invalid input. Accepting int,list,tuple,ndarray.')
class ChunkIterator:
    '''Iterator over the chunks of a snapshot, one chunk per rank.

    'snapshot' has to provide nproc(), giving the number of ranks, and
    field_chunk(rank,key,keep_ghost=keep_ghost), giving the chunk of
    field 'key' for a rank. UCFSnapshot from suspendtools.ucf is one
    such implementation. Calling iter() on the object restarts the
    iteration at rank 0.'''
    def __init__(self,snapshot,key,keep_ghost=True):
        self.snapshot,self.key,self.keep_ghost = snapshot,key,keep_ghost
        # Rank whose chunk is returned by the next call to __next__
        self.iter_rank = 0
    def __iter__(self):
        self.iter_rank = 0
        return self
    def __next__(self):
        rank = self.iter_rank
        if rank>=self.snapshot.nproc():
            raise StopIteration
        chunk = self.snapshot.field_chunk(rank,self.key,keep_ghost=self.keep_ghost)
        # Advance only after the chunk has been retrieved successfully
        self.iter_rank = rank+1
        return chunk