From 6eeaa477b8c142de5a30292d6fbf8138fe977b86 Mon Sep 17 00:00:00 2001 From: Michael Krayer Date: Fri, 28 May 2021 00:59:54 +0200 Subject: [PATCH] added broadcast: useful for subtracting mean flow --- parallel.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/parallel.py b/parallel.py index bfe6f0a..0229dba 100644 --- a/parallel.py +++ b/parallel.py @@ -346,6 +346,49 @@ class PPP: # Iterate inplace from now on key = key_out + def broadcast(self,key,arg,operation): + '''Broadcasts an inplace operation involving a scalar or matrix on + the entire grid. If 'arg' is a matrix, it must be three-dimensional + and its axes must be singular or of length nx/ny/nz.''' + import numpy as np + import operator + if operation in ('add','+'): + op = operator.iadd + elif operation in ('subtract','sub','-'): + op = operator.isub + elif operation in ('divide','div','/'): + op = operator.itruediv + elif operation in ('multiply','mul','*'): + op = operator.imul + elif operation in ('power','pow','^','**'): + op = operator.ipow + else: + raise ValueError("Invalid operation: {}".format(operation)) + if isinstance(arg,np.ndarray): + sl_arg = 3*[slice(None)] + for axis in range(3): + if arg.shape[axis]==1: + continue + elif arg.shape[axis]==self.proc_grid[key][2*axis+1][-1]: + pos = self.position_from_rank(self.rank,external=False)[axis] + sl_arg[axis] = slice( + self.proc_grid[key][2*axis][pos]-1, + self.proc_grid[key][2*axis+1][pos]) + else: + raise ValueError("'arg' must either be singular or match global "\ + "grid dimension. (axis={}: got {:d}, expected {:d}".format( + axis,arg.shape[axis],self.proc_grid[key][2*axis+1][-1])) + # Only operate on interior and communcate ghosts later + sl_int = tuple(slice(self.num_ghost[ii],-self.num_ghost[ii]) for ii in range(3)) + sl_arg = tuple(sl_arg) + op(self.field[key].data[sl_int],arg[sl_arg]) + # Exchange ghost cells and set boundary conditions + self.exchange_ghost_cells(key) + self.impose_boundary_conditions(key) + elif isinstance(arg,(int,float)): + op(self.field[key].data,arg) + return + def vtk_contour(self,key,val): '''Compute isocontour for chunks.''' if any([self.num_ghost[ii]>1 for ii in range(3)]):