diff --git a/particle.py b/particle.py index 4afb72e..835a86e 100644 --- a/particle.py +++ b/particle.py @@ -94,7 +94,8 @@ class Trajectories: def __str__(self): str = 'Trajectory with\n' str+= ' time steps: {:d}\n'.format(self.numtime) - str+= ' particles: {:d}'.format(self.numpart) + str+= ' particles: {:d}\n'.format(self.numpart) + str+= ' unraveled: {}'.format(self.unraveled) return str def __iadd__(self,other): if isinstance(other,Trajectories): @@ -116,7 +117,7 @@ class Trajectories: sl.append(x) else: raise TypeError("Trajectories can only be sliced by slice objects or integers.") - return self.get_trajectory(slice_part=sl[0],slice_time=sl[1]) + return self.get_trajectories(slice_part=sl[0],slice_time=sl[1]) @classmethod def from_mat(cls,file,unraveled=False): from .helper import load_mat @@ -181,7 +182,8 @@ class Trajectories: for ii,part in enumerate(self.part): time[ii] = part.time return time - def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)): + def get_trajectories(self,slice_part=slice(None),slice_time=slice(None)): + '''Get (x,y,z) trajectories in numpy array of dimension (3,npart,ntime).''' import numpy assert isinstance(slice_part,slice), "'slice_part' must be a slice." assert isinstance(slice_time,slice), "'slice_time' must be a slice." @@ -190,10 +192,48 @@ class Trajectories: self.attr['x'][slice_part,slice_time], self.attr['y'][slice_part,slice_time], self.attr['z'][slice_part,slice_time]],axis=0) - def get_segments(self,id=None): - #self.unravel() - #if id is None: # npart X ndim X ntime - return + def get_trajectories_segmented(self,slice_part=slice(None),slice_time=slice(None),restore_ravel_state=True): + '''Get (x,y,z) segments of trajectories as a tuple (len: npart) of + lists (len: nsegments) of numpy arrays (shape: 3 X ntime of the segment)''' + import numpy + assert isinstance(slice_part,slice), "'slice_part' must be a slice." + assert isinstance(slice_time,slice), "'slice_time' must be a slice." + was_unraveled = self.unraveled + self.unravel() + xyzpath = self.get_trajectories(slice_part=slice_part,slice_time=slice_time) + npart = xyzpath.shape[1] + ntime = xyzpath.shape[2] + # Initialize output and helper arrays + out = tuple([] for ii in range(npart)) + lastJump = numpy.zeros(npart,dtype=numpy.uint) + lastPeriod = numpy.zeros((3,npart),dtype=numpy.int) + for axis in range(3): + if self.period[axis] is not None: + lastPeriod[axis,:] = numpy.floor_divide(xyzpath[axis,:,0],self.period[axis]) + # Construct output tuple + for itime in range(1,ntime+1): + thisPeriod = numpy.zeros((3,npart),dtype=int) + if itime==ntime: + hasJumped = numpy.ones(npart,dtype=bool) + else: + for axis in range(3): + if self.period[axis] is not None: + thisPeriod[axis,:] = numpy.floor_divide(xyzpath[axis,:,itime],self.period[axis]) + hasJumped = numpy.any(thisPeriod!=lastPeriod,axis=0) + for ipart in range(npart): + if hasJumped[ipart]: + sl = slice(lastJump[ipart],itime) + segment = xyzpath[:,ipart,sl].copy() + for axis in range(3): + if self.period[axis] is not None: + segment[axis,:] -= lastPeriod[axis,ipart]*self.period[axis] + out[ipart].append(segment) + lastJump[ipart] = itime + lastPeriod = thisPeriod + # Restore previous ravel state + if restore_ravel_state and not was_unraveled: + self.ravel() + return out def unravel(self): import numpy if self.unraveled: return diff --git a/visu.py b/visu.py index 152fa9d..a9a24fd 100644 --- a/visu.py +++ b/visu.py @@ -47,6 +47,23 @@ def add_particles(plotter,part, opacity=opacity) return +def add_trajectories(plotter,traj, + name=None, + color='black', + scalars=None,cmap=None,clim=None, + opacity=1.0): + import pyvista + segments = traj.get_trajectories_segmented() + lines = pyvista.PolyData() + for part in segments: + for seg in part: + lines += pyvista.helpers.lines_from_points(seg.transpose()) + plotter.add_mesh(lines,name=name, + color=color, + scalars=scalars,cmap=cmap,clim=clim, + opacity=opacity) + return + def chunk_to_pvmesh(chunk,gridg): import pyvista mesh = pyvista.UniformGrid()