added deep copy flag. The default is now shallow copy!

This commit is contained in:
Michael Krayer 2021-08-05 00:20:58 +02:00
parent 5b40dea605
commit 9527941936
2 changed files with 29 additions and 12 deletions

View File

@ -1,9 +1,13 @@
import numpy as np
class Field3d:
def __init__(self,data,origin,spacing):
def __init__(self,data,origin,spacing,deep=False):
assert len(origin)==3, "'origin' must be of length 3"
assert len(spacing)==3, "'spacing' must be of length 3"
self.data = np.array(data)
assert isinstance(data,np.ndarray), "'data' must be numpy.ndarray."
if deep:
self.data = data.copy()
else:
self.data = data
self.origin = tuple([float(x) for x in origin])
self.spacing = tuple([float(x) for x in spacing])
self.eps_collapse = 1e-8
@ -187,16 +191,28 @@ class Field3d:
self.data[ib:ib+nx,jb:jb+ny,kb:kb+nz] = subfield.data[:,:,:]
return
def extract_subfield(self,idx_origin,dim,stride=(1,1,1)):
assert all(idx_origin[ii]>=0 and idx_origin[ii]<self.dim(axis=ii) for ii in range(3)),\
"'origin' is out-of-bounds."
assert all(idx_origin[ii]+stride[ii]*(dim[ii]-1)<self.dim(axis=ii) for ii in range(3)),\
"endpoint is out-of-bounds."
sl = tuple(slice(idx_origin[ii],idx_origin[ii]+stride[ii]*dim[ii],stride[ii]) for ii in range(3))
def extract_subfield(self,idx_origin,dim,stride=(1,1,1),deep=False,strict_bounds=True):
if strict_bounds:
assert all(idx_origin[ii]>=0 and idx_origin[ii]<self.dim(axis=ii) for ii in range(3)),\
"'origin' is out-of-bounds."
assert all(idx_origin[ii]+stride[ii]*(dim[ii]-1)<self.dim(axis=ii) for ii in range(3)),\
"endpoint is out-of-bounds."
else:
for ii in range(3):
if idx_origin[ii]<0:
dim[ii] += idx_origin[ii]
idx_origin[ii] = 0
if idx_origin[ii]+stride[ii]*(dim[ii]-1)>=self.dim(axis=ii):
dim[ii] = (self.dim(axis=ii)-idx_origin[ii])//stride[ii]
if dim[ii]<=0:
return None
sl = tuple(slice(idx_origin[ii],
idx_origin[ii]+stride[ii]*dim[ii],
stride[ii]) for ii in range(3))
origin = self.coordinate(idx_origin)
spacing = tuple(self.spacing[ii]*stride[ii] for ii in range(3))
data = self.data[sl].copy()
return Field3d(data,origin,spacing)
data = self.data[sl]
return Field3d(data,origin,spacing,deep=deep)
def coordinate(self,idx,axis=None):
if axis is None:

View File

@ -728,7 +728,8 @@ class PPP:
return self._subfield(key,stride,(0,0,0))
def _subfield(self,key,stride,num_ghost,no_lower_ghost=(False,False,False),
no_upper_ghost=(False,False,False)):
no_upper_ghost=(False,False,False),
deep=True):
'''Returns the field with a stride applied.'''
stride = np.array(stride,dtype=int)
num_ghost = np.array(num_ghost,dtype=int)
@ -765,7 +766,7 @@ class PPP:
assert all(idx_origin>=0)
assert all(idx_endpoint<np.array(self.field[key].dim()))
# Return the subfield
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride)
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride,deep=deep)
def _merge_at_root(self,key,stride=(1,1,1)):
'''Returns the entire field gathered from all processors with a