clipping and thresholding for particles; improvements for field

This commit is contained in:
Michael Krayer 2021-08-06 22:02:35 +02:00
parent 4c84ea6adc
commit f5714e1987
3 changed files with 73 additions and 8 deletions

View File

@ -214,7 +214,11 @@ class Field3d:
data = self.data[sl] data = self.data[sl]
return Field3d(data,origin,spacing,deep=deep) return Field3d(data,origin,spacing,deep=deep)
def clip(self,position,axis,invert=False,deep=False): def clip(self,position,axis,invert=False,deep=False,is_relative=False):
'''Extracts a subfield by clipping with a plane with normal pointing
direction specified by 'axis'.'''
if is_relative:
coord = self.origin[axis] + coord*self.dim(axis=axis)*self.spacing[axis]
idx_clip = self.nearest_gridpoint(coord,axis=axis,lower=True) idx_clip = self.nearest_gridpoint(coord,axis=axis,lower=True)
sl = 3*[slice(None)] sl = 3*[slice(None)]
origin_ = self.origin origin_ = self.origin
@ -228,21 +232,29 @@ class Field3d:
return Field3d(data_,origin_,spacing_,deep=deep) return Field3d(data_,origin_,spacing_,deep=deep)
def clip_box(self,bounds,deep=False,is_relative=False): def clip_box(self,bounds,deep=False,is_relative=False):
'''Extracts a subfield by clipping with a box.'''
if is_relative: if is_relative:
bounds = tuple(self.origin[ii//2] + bounds[ii]*self.dim(ii//2)*self.spacing[ii//2] for ii in range(6)) bounds = tuple(self.origin[ii//2] + bounds[ii]*self.dim(ii//2)*self.spacing[ii//2] for ii in range(6))
print(bounds)
idx_lo = self.nearest_gridpoint((bounds[0],bounds[2],bounds[4]),lower=True) idx_lo = self.nearest_gridpoint((bounds[0],bounds[2],bounds[4]),lower=True)
idx_hi = self.nearest_gridpoint((bounds[1],bounds[3],bounds[5]),lower=True) idx_hi = self.nearest_gridpoint((bounds[1],bounds[3],bounds[5]),lower=True)
origin_ = self.origin
spacing_ = self.spacing
idx_lo = tuple(0 if idx_lo[axis]<0 else idx_lo[axis] for axis in range(3)) idx_lo = tuple(0 if idx_lo[axis]<0 else idx_lo[axis] for axis in range(3))
idx_hi = tuple(0 if idx_hi[axis]<0 else idx_hi[axis] for axis in range(3)) idx_hi = tuple(0 if idx_hi[axis]<0 else idx_hi[axis] for axis in range(3))
sl_ = tuple([slice(idx_lo[0],idx_hi[0]+1), sl_ = tuple([slice(idx_lo[0],idx_hi[0]+1),
slice(idx_lo[1],idx_hi[1]+1), slice(idx_lo[1],idx_hi[1]+1),
slice(idx_lo[2],idx_hi[2]+1)]) slice(idx_lo[2],idx_hi[2]+1)])
origin_ = tuple(self.origin[axis]+idx_lo[axis]*self.spacing[axis] for axis in range(3))
spacing_ = self.spacing
data_ = self.data[sl_] data_ = self.data[sl_]
return Field3d(data_,origin_,spacing_,deep=deep) return Field3d(data_,origin_,spacing_,deep=deep)
def threshold(self,val,invert=False):
'''Returns a binary array indicating which grid points are above,
or in case of invert=True below, a given threshold.'''
if invert:
return self.data>=val
else:
return self.data<val
def coordinate(self,idx,axis=None): def coordinate(self,idx,axis=None):
if axis is None: if axis is None:
assert len(idx)==3, "If 'axis' is None, 'idx' must be a tuple/list of length 3." assert len(idx)==3, "If 'axis' is None, 'idx' must be a tuple/list of length 3."

View File

@ -81,4 +81,11 @@ def h5read(file,name='data'):
f = file if is_open else h5py.File(file,'r') f = file if is_open else h5py.File(file,'r')
data = f[name][:] data = f[name][:]
if not is_open: f.close() if not is_open: f.close()
return data return data
def convert_coordinate_to_absolute(coord,bounds):
assert len(coord) in (3,6)
if len(coord)==3:
return tuple(bounds[2*ii]+(bounds[2*ii+1]-bounds[2*ii])*coord[ii] for ii in range(3))
else:
return tuple(bounds[2*(ii//2)]+(bounds[2*(ii//2)+1]-bounds[2*(ii//2)])*coord[ii] for ii in range(6))

View File

@ -29,6 +29,7 @@ class Particles:
self.period = tuple(period) self.period = tuple(period)
self.frame_velocity = np.zeros(3) self.frame_velocity = np.zeros(3)
return return
@classmethod @classmethod
def from_array(cls,pp,col,time,period,select_col=None): def from_array(cls,pp,col,time,period,select_col=None):
assert 'id' in col, "Need column 'id' to initialize Particles" assert 'id' in col, "Need column 'id' to initialize Particles"
@ -47,14 +48,14 @@ class Particles:
if (select_col is None or key in select_col or key in ('id','x','y','z')): if (select_col is None or key in select_col or key in ('id','x','y','z')):
attr[key] = pp[col[key],:,0].squeeze() attr[key] = pp[col[key],:,0].squeeze()
return cls(num,time,attr,period) return cls(num,time,attr,period)
def __getitem__(self,val): def __getitem__(self,val):
if isinstance(val,int): if isinstance(val,int):
lo,hi = val,val+1 lo,hi = val,val+1
hi = hi if hi!=0 else None hi = hi if hi!=0 else None
val = slice(lo,hi) val = slice(lo,hi)
elif not isinstance(val,slice):
raise TypeError("Particles can only be sliced by slice objects or integers.")
return self._slice(val) return self._slice(val)
def __str__(self): def __str__(self):
str = '{:d} particles with\n'.format(self.num) str = '{:d} particles with\n'.format(self.num)
str+= ' time: {}\n'.format(self.time) str+= ' time: {}\n'.format(self.time)
@ -65,17 +66,26 @@ class Particles:
str+= ' period: {}\n'.format(self.period) str+= ' period: {}\n'.format(self.period)
str+= ' frame velocity: {}'.format(self.frame_velocity) str+= ' frame velocity: {}'.format(self.frame_velocity)
return str return str
def _slice(self,slice_part=slice(None)):
def _slice(self,slice_part=None):
assert slice_part is None or \
isinstance(slice_part,slice) or \
(isinstance(slice_part,np.ndarray) and
slice_part.dtype==bool and
slice_part.shape[0]==self.num),\
"Slicing requires int, slice or logical array."
attr = {} attr = {}
for key in self.attr: for key in self.attr:
attr[key] = self.attr[key][slice_part] attr[key] = self.attr[key][slice_part]
num = attr['id'].shape[0] num = attr['id'].shape[0]
return Particles(num,self.time,attr,self.period) return Particles(num,self.time,attr,self.period)
def copy(self): def copy(self):
attr = {} attr = {}
for key in self.attr: for key in self.attr:
attr[key] = self.attr[key].copy() attr[key] = self.attr[key].copy()
return Particles(self.num,self.time,attr,self.period) return Particles(self.num,self.time,attr,self.period)
def add_attribute(self,key,val): def add_attribute(self,key,val):
if isinstance(val,(tuple,list,np.ndarray)): if isinstance(val,(tuple,list,np.ndarray)):
assert len(val)==self.num and val.ndim==1, "Invalid 'val'." assert len(val)==self.num and val.ndim==1, "Invalid 'val'."
@ -83,10 +93,13 @@ class Particles:
else: else:
self.attr[key] = np.full(self.num,val) self.attr[key] = np.full(self.num,val)
return return
def del_attribute(self,key): def del_attribute(self,key):
del self.attr[key] del self.attr[key]
def get_attribute(self,key): def get_attribute(self,key):
return self.attr[key] return self.attr[key]
def get_position(self,axis=None): def get_position(self,axis=None):
if axis is None: if axis is None:
return np.vstack([self.attr[key].copy() for key in ('x','y','z')]) return np.vstack([self.attr[key].copy() for key in ('x','y','z')])
@ -94,14 +107,17 @@ class Particles:
assert axis<3, "'axis' must be smaller than 3." assert axis<3, "'axis' must be smaller than 3."
key = ('x','y','z')[axis] key = ('x','y','z')[axis]
return self.attr[key].copy() return self.attr[key].copy()
def has_attribute(self,key): def has_attribute(self,key):
return key in self.attr return key in self.attr
def translate(self,translation,axis=0): def translate(self,translation,axis=0):
'''Translates particles. Periodicity must be enforced manually.''' '''Translates particles. Periodicity must be enforced manually.'''
assert axis<3, "'axis' must be smaller than 3." assert axis<3, "'axis' must be smaller than 3."
key = ('x','y','z')[axis] key = ('x','y','z')[axis]
self.attr[key] += translation self.attr[key] += translation
return return
def set_frame_velocity(self,val,axis=0): def set_frame_velocity(self,val,axis=0):
'''Adjust frame of reference by translating particles by '''Adjust frame of reference by translating particles by
time*frame_velocity and subtracting frame_velocity from time*frame_velocity and subtracting frame_velocity from
@ -114,6 +130,7 @@ class Particles:
self.attr[key] -= valdiff self.attr[key] -= valdiff
self.frame_velocity[axis] = val self.frame_velocity[axis] = val
return return
def enforce_periodicity(self,axis=None): def enforce_periodicity(self,axis=None):
if axis is None: if axis is None:
for ii in range(3): for ii in range(3):
@ -123,6 +140,7 @@ class Particles:
key = ('x','y','z')[axis] key = ('x','y','z')[axis]
self.attr[key] %= self.period[axis] self.attr[key] %= self.period[axis]
return return
def position_with_duplicates(self,ipart,padding=0.0): def position_with_duplicates(self,ipart,padding=0.0):
pos = np.array( pos = np.array(
(self.attr['x'][ipart], (self.attr['x'][ipart],
@ -144,6 +162,7 @@ class Particles:
tmp[axis] = tmp[axis]-self.period[axis] tmp[axis] = tmp[axis]-self.period[axis]
posd.append(tmp) posd.append(tmp)
return posd return posd
def mask_field(self,fld,cval=np.nan,padding=0.0): def mask_field(self,fld,cval=np.nan,padding=0.0):
'''Fills grid points which lie inside of solid phase with values.''' '''Fills grid points which lie inside of solid phase with values.'''
assert self.has_attribute('r'), "Attribute 'r' required." assert self.has_attribute('r'), "Attribute 'r' required."
@ -162,6 +181,32 @@ class Particles:
dist = (xsf-pos[0])**2 + (ysf-pos[1])**2 + (zsf-pos[2])**2 dist = (xsf-pos[0])**2 + (ysf-pos[1])**2 + (zsf-pos[2])**2
sf.data[dist<=rp*rp] = cval sf.data[dist<=rp*rp] = cval
return return
def clip(self,position,axis,invert=False):
'''Clips particles by a plane with normal pointing
direction specified by 'axis'.'''
return self.threshold(('x','y','z')[axis],position,invert=(not invert))
def clip_box(self,bounds,invert=False):
'''Extracts particls within a bounding box.'''
li = np.logical_and(np.logical_and(
np.logical_and(self.attr['x']>=bounds[0],self.attr['x']<=bounds[1]),
np.logical_and(self.attr['y']>=bounds[2],self.attr['y']<=bounds[3])),
np.logical_and(self.attr['z']>=bounds[4],self.attr['z']<=bounds[5]))
if invert:
np.logical_not(li,li)
return self._slice(li)
def threshold(self,key,val,invert=False):
'''Returns particles for which specified attribute is above,
or in case invert=True below, the specified threshold.'''
assert key in self.attr, "'key' not found in attr."
if invert:
li = self.attr[key]<val
else:
li = self.attr[key]>=val
return self._slice(li)
def to_vtk(self,deep=False): def to_vtk(self,deep=False):
import pyvista as pv import pyvista as pv
position = np.vstack([self.attr[key] for key in ('x','y','z')]).transpose() position = np.vstack([self.attr[key] for key in ('x','y','z')]).transpose()
@ -169,6 +214,7 @@ class Particles:
for key in self.attr: for key in self.attr:
mesh[key] = self.attr[key] mesh[key] = self.attr[key]
return mesh return mesh
def glyph(self,theta_resolution=30,phi_resolution=30,deep=False): def glyph(self,theta_resolution=30,phi_resolution=30,deep=False):
import pyvista as pv import pyvista as pv
assert self.has_attribute('r'), "Attribute 'r' required." assert self.has_attribute('r'), "Attribute 'r' required."