more efficient particle masking

This commit is contained in:
Michael Krayer 2021-08-19 15:43:10 +02:00
parent 928cc9c9d2
commit 96c727b824
1 changed files with 46 additions and 14 deletions

View File

@ -166,20 +166,52 @@ class Particles:
def mask_field(self,fld,cval=np.nan,padding=0.0):
'''Fills grid points which lie inside of solid phase with values.'''
assert self.has_attribute('r'), "Attribute 'r' required."
for ipart in range(self.num):
# Slice a box from the field around the particle
rp = self.attr['r'][ipart]+padding
for pos in self.position_with_duplicates(ipart,padding=padding):
idxlo = np.array(fld.nearest_gridpoint(pos-rp,lower=True))
sfdim = np.ceil(2*rp/fld.spacing+2).astype('int')
sf = fld.extract_subfield(idxlo,sfdim,deep=False,strict_bounds=False)
if sf is None:
from numba import jit
# Define functions for efficient duplication and masking
def __duplicate(xyzr,period,axis):
li = (xyzr[:,axis]-xyzr[:,3])<0.0
dupl_lo = xyzr[li,:].copy()
dupl_lo[:,axis] += period
li = (xyzr[:,axis]+xyzr[:,3])>period
dupl_hi = xyzr[li,:].copy()
dupl_hi[:,axis] -= period
return np.concatenate((xyzr,dupl_lo,dupl_hi),axis=0)
@jit(nopython=True)
def __mask(origin,spacing,data,cval,xyzr):
npart = xyzr.shape[0]
nx,ny,nz = data.shape
for ipart in range(npart):
xp = xyzr[ipart,0]
yp = xyzr[ipart,1]
zp = xyzr[ipart,2]
rp = xyzr[ipart,3]
ib = int(np.ceil((xp-rp-origin[0])/spacing[0]))
jb = int(np.ceil((yp-rp-origin[1])/spacing[1]))
kb = int(np.ceil((zp-rp-origin[2])/spacing[2]))
ie = min(nx,ib+int(np.ceil(2*rp/spacing[0])))
je = min(ny,jb+int(np.ceil(2*rp/spacing[1])))
ke = min(nz,kb+int(np.ceil(2*rp/spacing[2])))
ib = max(ib,0)
jb = max(jb,0)
kb = max(kb,0)
for ii in range(ib,ie):
x_ = origin[0]+ii*spacing[0]
for jj in range(jb,je):
y_ = origin[1]+jj*spacing[1]
for kk in range(kb,ke):
z_ = origin[2]+kk*spacing[2]
dsq = (x_-xp)**2+(y_-yp)**2+(z_-zp)**2
if dsq<=rp**2:
data[ii,jj,kk] = cval
return
# Duplicate particles in all periodic directions
xyzr = np.stack((self.attr['x'],self.attr['y'],self.attr['z'],self.attr['r']+padding),axis=1)
for axis in range(3):
if self.period[axis] is None:
continue
xsf = sf.x().reshape((-1,1,1))
ysf = sf.y().reshape((1,-1,1))
zsf = sf.z().reshape((1,1,-1))
dist = (xsf-pos[0])**2 + (ysf-pos[1])**2 + (zsf-pos[2])**2
sf.data[dist<=rp*rp] = cval
xyzr = __duplicate(xyzr,self.period[axis],axis)
# Mask values
__mask(fld.origin,fld.spacing,fld.data,cval,xyzr)
return
def clip(self,position,axis,invert=False):