added deep copy flag. The default is now shallow copy!
This commit is contained in:
parent
5b40dea605
commit
9527941936
36
field.py
36
field.py
|
|
@ -1,9 +1,13 @@
|
|||
import numpy as np
|
||||
class Field3d:
|
||||
def __init__(self,data,origin,spacing):
|
||||
def __init__(self,data,origin,spacing,deep=False):
|
||||
assert len(origin)==3, "'origin' must be of length 3"
|
||||
assert len(spacing)==3, "'spacing' must be of length 3"
|
||||
self.data = np.array(data)
|
||||
assert isinstance(data,np.ndarray), "'data' must be numpy.ndarray."
|
||||
if deep:
|
||||
self.data = data.copy()
|
||||
else:
|
||||
self.data = data
|
||||
self.origin = tuple([float(x) for x in origin])
|
||||
self.spacing = tuple([float(x) for x in spacing])
|
||||
self.eps_collapse = 1e-8
|
||||
|
|
@ -187,16 +191,28 @@ class Field3d:
|
|||
self.data[ib:ib+nx,jb:jb+ny,kb:kb+nz] = subfield.data[:,:,:]
|
||||
return
|
||||
|
||||
def extract_subfield(self,idx_origin,dim,stride=(1,1,1)):
|
||||
assert all(idx_origin[ii]>=0 and idx_origin[ii]<self.dim(axis=ii) for ii in range(3)),\
|
||||
"'origin' is out-of-bounds."
|
||||
assert all(idx_origin[ii]+stride[ii]*(dim[ii]-1)<self.dim(axis=ii) for ii in range(3)),\
|
||||
"endpoint is out-of-bounds."
|
||||
sl = tuple(slice(idx_origin[ii],idx_origin[ii]+stride[ii]*dim[ii],stride[ii]) for ii in range(3))
|
||||
def extract_subfield(self,idx_origin,dim,stride=(1,1,1),deep=False,strict_bounds=True):
|
||||
if strict_bounds:
|
||||
assert all(idx_origin[ii]>=0 and idx_origin[ii]<self.dim(axis=ii) for ii in range(3)),\
|
||||
"'origin' is out-of-bounds."
|
||||
assert all(idx_origin[ii]+stride[ii]*(dim[ii]-1)<self.dim(axis=ii) for ii in range(3)),\
|
||||
"endpoint is out-of-bounds."
|
||||
else:
|
||||
for ii in range(3):
|
||||
if idx_origin[ii]<0:
|
||||
dim[ii] += idx_origin[ii]
|
||||
idx_origin[ii] = 0
|
||||
if idx_origin[ii]+stride[ii]*(dim[ii]-1)>=self.dim(axis=ii):
|
||||
dim[ii] = (self.dim(axis=ii)-idx_origin[ii])//stride[ii]
|
||||
if dim[ii]<=0:
|
||||
return None
|
||||
sl = tuple(slice(idx_origin[ii],
|
||||
idx_origin[ii]+stride[ii]*dim[ii],
|
||||
stride[ii]) for ii in range(3))
|
||||
origin = self.coordinate(idx_origin)
|
||||
spacing = tuple(self.spacing[ii]*stride[ii] for ii in range(3))
|
||||
data = self.data[sl].copy()
|
||||
return Field3d(data,origin,spacing)
|
||||
data = self.data[sl]
|
||||
return Field3d(data,origin,spacing,deep=deep)
|
||||
|
||||
def coordinate(self,idx,axis=None):
|
||||
if axis is None:
|
||||
|
|
|
|||
|
|
@ -728,7 +728,8 @@ class PPP:
|
|||
return self._subfield(key,stride,(0,0,0))
|
||||
|
||||
def _subfield(self,key,stride,num_ghost,no_lower_ghost=(False,False,False),
|
||||
no_upper_ghost=(False,False,False)):
|
||||
no_upper_ghost=(False,False,False),
|
||||
deep=True):
|
||||
'''Returns the field with a stride applied.'''
|
||||
stride = np.array(stride,dtype=int)
|
||||
num_ghost = np.array(num_ghost,dtype=int)
|
||||
|
|
@ -765,7 +766,7 @@ class PPP:
|
|||
assert all(idx_origin>=0)
|
||||
assert all(idx_endpoint<np.array(self.field[key].dim()))
|
||||
# Return the subfield
|
||||
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride)
|
||||
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride,deep=deep)
|
||||
|
||||
def _merge_at_root(self,key,stride=(1,1,1)):
|
||||
'''Returns the entire field gathered from all processors with a
|
||||
|
|
|
|||
Loading…
Reference in New Issue