saving of reduced field for visualization at later point; save_state now based on hdf5
This commit is contained in:
parent
0e48c9b382
commit
fcd989b3f9
24
field.py
24
field.py
|
|
@ -100,6 +100,18 @@ class Field3d:
|
||||||
assert (chunk['nzl']+2*chunk['ighost'])==nz, "Invalid chunk data: nzl != chunk['data'].shape[2]"
|
assert (chunk['nzl']+2*chunk['ighost'])==nz, "Invalid chunk data: nzl != chunk['data'].shape[2]"
|
||||||
return cls(chunk['data'],origin=(xo,yo,zo),spacing=(dx,dy,dz))
|
return cls(chunk['data'],origin=(xo,yo,zo),spacing=(dx,dy,dz))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls,file,name='Field3d'):
|
||||||
|
import h5py
|
||||||
|
is_open = isinstance(file,(h5py.File,h5py.Group))
|
||||||
|
f = file if is_open else h5py.File(file,'r')
|
||||||
|
g = f[name]
|
||||||
|
origin = tuple(g['origin'])
|
||||||
|
spacing = tuple(g['spacing'])
|
||||||
|
data = g['data'][:]
|
||||||
|
if not is_open: f.close()
|
||||||
|
return cls(data,origin,spacing)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def allocate(cls,dim,origin,spacing,fill=None,dtype=numpy.float64,pseudo=False):
|
def allocate(cls,dim,origin,spacing,fill=None,dtype=numpy.float64,pseudo=False):
|
||||||
'''Allocates an empty field, or a field filled with 'fill'.'''
|
'''Allocates an empty field, or a field filled with 'fill'.'''
|
||||||
|
|
@ -121,6 +133,18 @@ class Field3d:
|
||||||
data = numpy.full(dim,fill,dtype=dtype)
|
data = numpy.full(dim,fill,dtype=dtype)
|
||||||
return cls(data,origin,spacing)
|
return cls(data,origin,spacing)
|
||||||
|
|
||||||
|
def save(self,file,name='Field3d',truncate=False):
|
||||||
|
import h5py
|
||||||
|
is_open = isinstance(file,h5py.File)
|
||||||
|
flag = 'w' if truncate else 'a'
|
||||||
|
f = file if is_open else h5py.File(file,flag)
|
||||||
|
g = f.create_group(name)
|
||||||
|
g.create_dataset('origin',data=self.origin)
|
||||||
|
g.create_dataset('spacing',data=self.spacing)
|
||||||
|
g.create_dataset('data',data=self.data)
|
||||||
|
if not is_open: f.close()
|
||||||
|
return
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return Field3d(self.data.copy(),self.origin,self.spacing)
|
return Field3d(self.data.copy(),self.origin,self.spacing)
|
||||||
|
|
||||||
|
|
|
||||||
297
parallel.py
297
parallel.py
|
|
@ -70,18 +70,55 @@ class PPP:
|
||||||
nghbr,field,symmetries)
|
nghbr,field,symmetries)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls,filename):
|
def from_state(cls,file,parallel=True,io_limit=None):
|
||||||
import pickle
|
import h5py
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
|
from .field import Field3d
|
||||||
comm = MPI.COMM_WORLD
|
comm = MPI.COMM_WORLD
|
||||||
rank = comm.Get_rank()
|
rank = comm.Get_rank()
|
||||||
fin = filename+'.{:05d}.pickle'.format(rank)
|
# Only use parallel IO if flag is set and h5py has MPIIO support
|
||||||
with open(fin,"rb") as f:
|
parallel = (h5py.h5fd.MPIO>=0) and parallel
|
||||||
payload = pickle.load(f)
|
if parallel:
|
||||||
(num_ghost,chunks_per_proc,origin,spacing,periodicity,bounds,
|
f = h5py.File(file,'r',driver='mpio',comm=comm)
|
||||||
proc_grid,nxp,nyp,nzp,nghbr,field,symmetries) = payload
|
else:
|
||||||
|
baton_wait(io_limit,comm=comm)
|
||||||
|
f = h5py.File(file,'r')
|
||||||
|
# Read attributes which are the same for all ranks
|
||||||
|
g = f['PPP']
|
||||||
|
field_names = tuple(x.decode('ascii') for x in g['field_names'])
|
||||||
|
field_loaded = tuple(x.decode('ascii') for x in g['field_loaded'])
|
||||||
|
origin = {}
|
||||||
|
proc_grid = {}
|
||||||
|
for key in field_names:
|
||||||
|
origin[key] = tuple(g[key]['origin'])
|
||||||
|
proc_grid[key] = []
|
||||||
|
proc_grid[key].append(tuple(g[key]['ibeg']))
|
||||||
|
proc_grid[key].append(tuple(g[key]['iend']))
|
||||||
|
proc_grid[key].append(tuple(g[key]['jbeg']))
|
||||||
|
proc_grid[key].append(tuple(g[key]['jend']))
|
||||||
|
proc_grid[key].append(tuple(g[key]['kbeg']))
|
||||||
|
proc_grid[key].append(tuple(g[key]['kend']))
|
||||||
|
num_ghost = tuple(g['num_ghost'])
|
||||||
|
chunks_per_proc = tuple(g['chunks_per_proc'])
|
||||||
|
spacing = tuple(g['spacing'])
|
||||||
|
periodicity = tuple(g['periodicity'])
|
||||||
|
bounds = tuple(g['bounds'])
|
||||||
|
nxp = int(g['nxp'][:])
|
||||||
|
nyp = int(g['nyp'][:])
|
||||||
|
nzp = int(g['nzp'][:])
|
||||||
|
# Independent read
|
||||||
|
grp_rank = '{:05d}'.format(rank)
|
||||||
|
g = f[grp_rank]
|
||||||
|
nghbr = g['nghbr'][:]
|
||||||
|
field = {}
|
||||||
|
symmetries = {}
|
||||||
|
for key in field_loaded:
|
||||||
|
field[key] = Field3d.from_file(g,name=key)
|
||||||
|
symmetries[key] = g[key]['symmetries'][:]
|
||||||
|
f.close()
|
||||||
assert nxp*nyp*nzp==comm.Get_size(), "The loaded state requires {} processors, but "\
|
assert nxp*nyp*nzp==comm.Get_size(), "The loaded state requires {} processors, but "\
|
||||||
"we are currently running with {}.".format(nxp*nyp*nzp,comm.Get_size())
|
"we are currently running with {}.".format(nxp*nyp*nzp,comm.Get_size())
|
||||||
|
if not parallel: baton_pass(io_limit,comm=comm)
|
||||||
func_load = None
|
func_load = None
|
||||||
proc_grid_ext = None
|
proc_grid_ext = None
|
||||||
nxp_ext,nyp_ext,nzp_ext = 3*[None]
|
nxp_ext,nyp_ext,nzp_ext = 3*[None]
|
||||||
|
|
@ -491,25 +528,134 @@ class PPP:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_state(self,filename):
|
def save_state(self,file,parallel=False):
|
||||||
import pickle
|
import h5py
|
||||||
fout = filename+'.{:05d}.pickle'.format(self.rank)
|
from mpi4py import MPI
|
||||||
payload = (
|
tbeg = MPI.Wtime()
|
||||||
self.num_ghost,
|
ascii_type = h5py.string_dtype('ascii',32)
|
||||||
self.chunks_per_proc,
|
# Only use parallel IO if flag is set and h5py has MPIIO support
|
||||||
self.origin,
|
parallel = (h5py.h5fd.MPIO>=0) and parallel
|
||||||
self.spacing,
|
if parallel:
|
||||||
self.periodicity,
|
f = h5py.File(file,'w',driver='mpio',comm=self.comm)
|
||||||
self.bounds,
|
else:
|
||||||
self.proc_grid,
|
self._baton_wait(1)
|
||||||
self.nxp,
|
f = h5py.File(file,'w') if self.rank==0 else h5py.File(file,'a')
|
||||||
self.nyp,
|
# Write attributes which are the same for all ranks
|
||||||
self.nzp,
|
if parallel or self.rank==0:
|
||||||
self.nghbr,
|
g = f.create_group('/PPP')
|
||||||
self.field,
|
# Create a variable which stores all the field names
|
||||||
self.symmetries)
|
d = g.create_dataset('field_names',len(self.origin),dtype=ascii_type)
|
||||||
with open(fout,"wb") as f:
|
for ii,key in enumerate(self.origin):
|
||||||
pickle.dump(payload,f)
|
d[ii] = key
|
||||||
|
d = g.create_dataset('field_loaded',len(self.field),dtype=ascii_type)
|
||||||
|
for ii,key in enumerate(self.field):
|
||||||
|
d[ii] = key
|
||||||
|
# Store all global field dependent data
|
||||||
|
for key in self.origin:
|
||||||
|
g2 = g.create_group(key)
|
||||||
|
d = g2.create_dataset('origin',3,dtype='f')
|
||||||
|
if self.rank==0: d[:] = self.origin[key]
|
||||||
|
d = g2.create_dataset('ibeg',self.nxp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][0]
|
||||||
|
d = g2.create_dataset('iend',self.nxp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][1]
|
||||||
|
d = g2.create_dataset('jbeg',self.nyp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][2]
|
||||||
|
d = g2.create_dataset('jend',self.nyp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][3]
|
||||||
|
d = g2.create_dataset('kbeg',self.nzp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][4]
|
||||||
|
d = g2.create_dataset('kend',self.nzp,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.proc_grid[key][5]
|
||||||
|
# Store all global field independent data
|
||||||
|
d = g.create_dataset('num_ghost',3,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.num_ghost
|
||||||
|
d = g.create_dataset('chunks_per_proc',3,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.chunks_per_proc
|
||||||
|
d = g.create_dataset('spacing',3,dtype='f')
|
||||||
|
if self.rank==0: d[:] = self.spacing
|
||||||
|
d = g.create_dataset('periodicity',3,dtype='i')
|
||||||
|
if self.rank==0: d[:] = self.periodicity
|
||||||
|
d = g.create_dataset('bounds',6,dtype='f')
|
||||||
|
if self.rank==0: d[:] = self.bounds
|
||||||
|
d = g.create_dataset('nxp',1,'i')
|
||||||
|
if self.rank==0: d[:] = self.nxp
|
||||||
|
d = g.create_dataset('nyp',1,'i')
|
||||||
|
if self.rank==0: d[:] = self.nyp
|
||||||
|
d = g.create_dataset('nzp',1,'i')
|
||||||
|
if self.rank==0: d[:] = self.nzp
|
||||||
|
# Collectively create groups and datasets for rank-specific data
|
||||||
|
for ii in range(self.nproc):
|
||||||
|
if parallel or ii==self.rank:
|
||||||
|
g1 = f.create_group('{:05d}'.format(ii))
|
||||||
|
g1.create_dataset('nghbr',self.nghbr.shape,data=self.nghbr)
|
||||||
|
for key in self.field:
|
||||||
|
# To support parallel h5py, we need to create the groups here and cannot use
|
||||||
|
# the 'save' method of Field3d
|
||||||
|
g2 = g1.create_group('{}'.format(key))
|
||||||
|
g2.create_dataset('origin',3,dtype='f')
|
||||||
|
g2.create_dataset('spacing',3,dtype='f')
|
||||||
|
g2.create_dataset('data',
|
||||||
|
self.chunk_size(key,rank=ii,incl_ghost=True),
|
||||||
|
dtype=self.chunk_dtype(key))
|
||||||
|
g2.create_dataset('symmetries',(3,3,3),dtype='i')
|
||||||
|
# Independent write
|
||||||
|
grp_rank = '{:05d}'.format(self.rank)
|
||||||
|
f[grp_rank]['nghbr'][:] = self.nghbr
|
||||||
|
for key in self.field:
|
||||||
|
f[grp_rank][key]['origin'][:] = self.field[key].origin
|
||||||
|
f[grp_rank][key]['spacing'][:] = self.field[key].spacing
|
||||||
|
f[grp_rank][key]['data'][:] = self.field[key].data
|
||||||
|
f[grp_rank][key]['symmetries'][:] = self.symmetries[key]
|
||||||
|
f.close()
|
||||||
|
if not parallel: self._baton_pass(1)
|
||||||
|
self.comm.Barrier()
|
||||||
|
tend = MPI.Wtime()
|
||||||
|
if self.rank==0:
|
||||||
|
print("[save_state] Elapsed time: {:f}".format(tend-tbeg))
|
||||||
|
return
|
||||||
|
|
||||||
|
def save_for_vtk(self,file,key,stride=(1,1,1),truncate=True,merge_at_root=True,on_pressure_grid=True):
|
||||||
|
'''Saves a field for visualization purposes. This means it will only have a single
|
||||||
|
lower ghost cell if there is an upper neighbor, and both a single and an upper
|
||||||
|
ghost cell if there is no upper neighbor (or merged).'''
|
||||||
|
import h5py
|
||||||
|
from mpi4py import MPI
|
||||||
|
# Recursive saving if key is a list. Take care of 'truncate'.
|
||||||
|
if isinstance(key,(tuple,list)):
|
||||||
|
for key_ in key:
|
||||||
|
self.save_for_vtk(file,key_,stride=stride,merge_at_root=merge_at_root,
|
||||||
|
truncate=truncate,on_pressure_grid=on_pressure_grid)
|
||||||
|
truncate = False
|
||||||
|
return
|
||||||
|
# Since the data is usually much smaller than a full 'save_state', I only
|
||||||
|
# implement sequential IO for now.
|
||||||
|
tbeg = MPI.Wtime()
|
||||||
|
# If flag is set, shift data onto pressure grid first. Use a temporary field for this.
|
||||||
|
name = key
|
||||||
|
if on_pressure_grid:
|
||||||
|
key_tmp = 'tmp'
|
||||||
|
self.shift_to_pressure_grid(key,key_out='tmp')
|
||||||
|
key = key_tmp
|
||||||
|
# Get the subfield and save them
|
||||||
|
if merge_at_root:
|
||||||
|
fld = self._merge_at_root(key,stride=stride)
|
||||||
|
if self.rank==0: fld.save(file,name=name,truncate=truncate)
|
||||||
|
else:
|
||||||
|
self._baton_wait(1)
|
||||||
|
if self.rank!=0: truncate=False
|
||||||
|
name += '/{:05d}'.format(self.rank)
|
||||||
|
self._subfield_for_vtk(key,stride=stride).save(file,name=name,truncate=truncate)
|
||||||
|
self._baton_pass(1)
|
||||||
|
# Free the temporary field (if it was created)
|
||||||
|
if on_pressure_grid: self.delete(key_tmp)
|
||||||
|
# Sync (important in case there is another write to the same file following!)
|
||||||
|
self.comm.Barrier()
|
||||||
|
# Print timing
|
||||||
|
tend = MPI.Wtime()
|
||||||
|
if self.rank==0:
|
||||||
|
print("[save_for_vtk] Elapsed time: {:f}".format(tend-tbeg))
|
||||||
|
return
|
||||||
|
|
||||||
def to_vtk(self,key,stride=(1,1,1),merge_at_root=False):
|
def to_vtk(self,key,stride=(1,1,1),merge_at_root=False):
|
||||||
'''Returns the field (only its interior + some ghost cells for plotting)
|
'''Returns the field (only its interior + some ghost cells for plotting)
|
||||||
|
|
@ -517,25 +663,7 @@ class PPP:
|
||||||
methods and apply .to_vtk() to the result.'''
|
methods and apply .to_vtk() to the result.'''
|
||||||
from .field import Field3d
|
from .field import Field3d
|
||||||
if merge_at_root:
|
if merge_at_root:
|
||||||
if self.rank==0:
|
return self._merge_at_root(key,stride=stride).to_vtk()
|
||||||
stride = np.array(stride)
|
|
||||||
# Allocate the full output field
|
|
||||||
origin = np.array(self.origin[key])
|
|
||||||
spacing = stride*np.array(self.spacing)
|
|
||||||
dim = (np.array(self.dim(key,axis=None))/stride+1).astype(int)
|
|
||||||
# the following seems necessary and seems to work, but i didn't think much about it
|
|
||||||
for axis in range(3):
|
|
||||||
if self.dim(key,axis=axis)%stride[axis]!=0:
|
|
||||||
dim[axis]+=1
|
|
||||||
output = Field3d.allocate(dim,origin,spacing,dtype=self.field[key].dtype())
|
|
||||||
# Recieve subfields and insert them
|
|
||||||
output.insert_subfield(self._subfield_for_vtk_merge(key,stride=stride))
|
|
||||||
for rank_src in range(1,self.nproc):
|
|
||||||
output.insert_subfield(self.comm.recv(source=rank_src))
|
|
||||||
return output.to_vtk()
|
|
||||||
else:
|
|
||||||
self.comm.send(self._subfield_for_vtk_merge(key,stride=stride),dest=0)
|
|
||||||
return None
|
|
||||||
else:
|
else:
|
||||||
return self._subfield_for_vtk(key,stride=stride).to_vtk()
|
return self._subfield_for_vtk(key,stride=stride).to_vtk()
|
||||||
|
|
||||||
|
|
@ -611,6 +739,30 @@ class PPP:
|
||||||
# Return the subfield
|
# Return the subfield
|
||||||
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride)
|
return self.field[key].extract_subfield(idx_origin,num_points,stride=stride)
|
||||||
|
|
||||||
|
def _merge_at_root(self,key,stride=(1,1,1)):
|
||||||
|
'''Returns the entire field gathered from all processors with a
|
||||||
|
stride applied as a Field3d.'''
|
||||||
|
from .field import Field3d
|
||||||
|
if self.rank==0:
|
||||||
|
stride = np.array(stride)
|
||||||
|
# Allocate the full output field
|
||||||
|
origin = np.array(self.origin[key])
|
||||||
|
spacing = stride*np.array(self.spacing)
|
||||||
|
dim = (np.array(self.dim(key,axis=None))/stride+1).astype(int)
|
||||||
|
# the following seems necessary and seems to work, but i didn't think much about it
|
||||||
|
for axis in range(3):
|
||||||
|
if self.dim(key,axis=axis)%stride[axis]!=0:
|
||||||
|
dim[axis]+=1
|
||||||
|
output = Field3d.allocate(dim,origin,spacing,dtype=self.field[key].dtype())
|
||||||
|
# Recieve subfields and insert them
|
||||||
|
output.insert_subfield(self._subfield_for_vtk_merge(key,stride=stride))
|
||||||
|
for rank_src in range(1,self.nproc):
|
||||||
|
output.insert_subfield(self.comm.recv(source=rank_src))
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
self.comm.send(self._subfield_for_vtk_merge(key,stride=stride),dest=0)
|
||||||
|
return None
|
||||||
|
|
||||||
def rank_from_position(self,ip,jp,kp,external=False):
|
def rank_from_position(self,ip,jp,kp,external=False):
|
||||||
if external:
|
if external:
|
||||||
nyp,nzp = self.nyp_ext,self.nzp_ext
|
nyp,nzp = self.nyp_ext,self.nzp_ext
|
||||||
|
|
@ -644,12 +796,23 @@ class PPP:
|
||||||
assert axis<3, "'axis' must be one of 0,1,2."
|
assert axis<3, "'axis' must be one of 0,1,2."
|
||||||
return int(round((self.origin[key][axis]-self.bounds[2*axis])/(0.5*self.spacing[axis])))
|
return int(round((self.origin[key][axis]-self.bounds[2*axis])/(0.5*self.spacing[axis])))
|
||||||
|
|
||||||
def chunk_size(self,key,axis=None):
|
def chunk_size(self,key,axis=None,rank=None,incl_ghost=False):
|
||||||
'''Returns size of chunk without ghost cells.'''
|
'''Returns size of a chunk.'''
|
||||||
if axis is None:
|
if axis is None:
|
||||||
return tuple(self.chunk_size(key,axis=ii) for ii in range(3))
|
return tuple(self.chunk_size(key,
|
||||||
|
axis=ii,rank=rank,incl_ghost=incl_ghost) for ii in range(3))
|
||||||
assert axis<3, "'axis' must be one of 0,1,2."
|
assert axis<3, "'axis' must be one of 0,1,2."
|
||||||
return self.field[key].dim(axis=axis)-2*self.num_ghost[axis]
|
#return self.field[key].dim(axis=axis)-2*self.num_ghost[axis]
|
||||||
|
if rank is None:
|
||||||
|
rank=self.rank
|
||||||
|
pos = self.position_from_rank(rank,external=False)
|
||||||
|
r = self.proc_grid[key][2*axis+1][pos[axis]]-self.proc_grid[key][2*axis][pos[axis]]+1
|
||||||
|
if incl_ghost:
|
||||||
|
r+=2*self.num_ghost[axis]
|
||||||
|
return r
|
||||||
|
|
||||||
|
def chunk_dtype(self,key):
|
||||||
|
return self.field[key].dtype()
|
||||||
|
|
||||||
def dim(self,key,axis=None):
|
def dim(self,key,axis=None):
|
||||||
'''Returns the total number of gridpoints across all processors
|
'''Returns the total number of gridpoints across all processors
|
||||||
|
|
@ -923,21 +1086,12 @@ class PPP:
|
||||||
def _baton_wait(self,batch_size,tag=420):
|
def _baton_wait(self,batch_size,tag=420):
|
||||||
'''Block execution until an empty message from rank-batch_wait
|
'''Block execution until an empty message from rank-batch_wait
|
||||||
is received (issued by _baton_pass)'''
|
is received (issued by _baton_pass)'''
|
||||||
from mpi4py import MPI
|
baton_wait(batch_size,comm=self.comm,tag=tag)
|
||||||
if batch_size is not None:
|
|
||||||
if self.rank>=batch_size:
|
|
||||||
source = self.rank-batch_size
|
|
||||||
self.comm.recv(source=source,tag=tag)
|
|
||||||
|
|
||||||
def _baton_pass(self,batch_size,tag=420):
|
def _baton_pass(self,batch_size,tag=420):
|
||||||
'''Sends an empty message to rank+batch_wait to unblock its
|
'''Sends an empty message to rank+batch_wait to unblock its
|
||||||
execution (issued by _baton_wait)'''
|
execution (issued by _baton_wait)'''
|
||||||
from mpi4py import MPI
|
baton_pass(batch_size,comm=self.comm,tag=tag)
|
||||||
if batch_size is not None:
|
|
||||||
dest = self.rank+batch_size
|
|
||||||
if dest<self.comm.Get_size():
|
|
||||||
data = None
|
|
||||||
self.comm.send(data,dest=dest,tag=tag)
|
|
||||||
|
|
||||||
class GatherIterator:
|
class GatherIterator:
|
||||||
'''Sends 'data' sequentially to 'root' which can iterate over it
|
'''Sends 'data' sequentially to 'root' which can iterate over it
|
||||||
|
|
@ -992,3 +1146,26 @@ def gather(data,comm=None):
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
comm = MPI.COMM_WORLD if comm is None else comm
|
comm = MPI.COMM_WORLD if comm is None else comm
|
||||||
return comm.gather(data,root=0)
|
return comm.gather(data,root=0)
|
||||||
|
|
||||||
|
def baton_wait(batch_size,comm=None,tag=420):
|
||||||
|
'''Block execution until an empty message from rank-batch_wait
|
||||||
|
is received (issued by _baton_pass)'''
|
||||||
|
from mpi4py import MPI
|
||||||
|
comm = MPI.COMM_WORLD if comm is None else comm
|
||||||
|
rank = comm.Get_rank()
|
||||||
|
if batch_size is not None:
|
||||||
|
if rank>=batch_size:
|
||||||
|
source = rank-batch_size
|
||||||
|
comm.recv(source=source,tag=tag)
|
||||||
|
|
||||||
|
def baton_pass(batch_size,comm=None,tag=420):
|
||||||
|
'''Sends an empty message to rank+batch_wait to unblock its
|
||||||
|
execution (issued by _baton_wait)'''
|
||||||
|
from mpi4py import MPI
|
||||||
|
comm = MPI.COMM_WORLD if comm is None else comm
|
||||||
|
rank = comm.Get_rank()
|
||||||
|
if batch_size is not None:
|
||||||
|
dest = rank+batch_size
|
||||||
|
if dest<comm.Get_size():
|
||||||
|
data = None
|
||||||
|
comm.send(data,dest=dest,tag=tag)
|
||||||
Loading…
Reference in New Issue