From 69230e4edd8cb7217fc91cd98ca8e000731761a7 Mon Sep 17 00:00:00 2001 From: Michael Krayer Date: Tue, 25 May 2021 10:32:26 +0200 Subject: [PATCH] added ravel/unravel but untested (it can be loaded from mat file) --- particle.py | 45 ++++++++++++++++++--------------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/particle.py b/particle.py index 432816e..a878201 100644 --- a/particle.py +++ b/particle.py @@ -91,11 +91,6 @@ class Trajectories: for key in self.part[0].attr: self.attr[key] = [self.part[0].attr[key].view()] self.period = self.part[0].period - if any([period is not None for period in self.period]): - self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool) - else: - self._crossings = None - return self.unraveled = False def __str__(self): str = 'Trajectory with\n' @@ -124,7 +119,7 @@ class Trajectories: raise TypeError("Trajectories can only be sliced by slice objects or integers.") return self.get_trajectory(slice_part=sl[0],slice_time=sl[1]) @classmethod - def from_mat(cls,file): + def from_mat(cls,file,unraveled=False): from .helper import load_mat pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo']) period = [None,None,None] @@ -137,6 +132,7 @@ class Trajectories: traj = cls(Particles(pp[:,:,0],col,time[0],period)) for ii in range(1,ntime): traj += Particles(pp[:,:,ii],col,time[ii],period) + traj.unraveled = unraveled return traj def add_particles(self,part): import numpy @@ -206,36 +202,31 @@ class Trajectories: self.attr['y'][id,slice_time], self.attr['z'][id,slice_time]],axis=0) def get_segments(self,id=None): + #self.unravel() + #if id is None: # npart X ndim X ntime return def unravel(self): + import numpy if self.unraveled: return - if self._crossings is None: return - self._compute_crossings() - # do sth + raise NotImplementedError('Implemented but untested!') + for axis in range(0,3): + if self.period[axis] is not None: + key = ('x','y','z')[axis] + posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze() + coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff) + coeff = numpy.cumsum(coeff) + self.part[ii].attr[key][1:] += coeff*self.period[axis] self.unraveled = True return def ravel(self): if not self.unraveled: return + raise NotImplementedError('Implemented but untested!') + for axis in range(0,3): + if self.period[axis] is not None: + key = ('x','y','z')[axis] + self.part[ii].attr[key] %= self.period[axis] self.unraveled = False return - def _compute_crossings(self): - import numpy - if self._crossings is not None: # None means no periodicity - num_computed = self._crossings.shape[1] - if num_computed==self.numtime: - return - crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool) - crossings[:,:,0:num_computed] = self._crossings - for ii in range(num_computed,self.numtime): - for axis in range(0,3): - if self.period[axis] is not None: - key = ('x','y','z')[axis] - posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze() - doesCross = abs(posdiff)>0.5*self.period[axis] - upCross = numpy.logical_and(posdiff<0,doesCross) - crossings[2*axis,:,ii] = doesCross - crossings[2*axis+1,:,ii] = upCross - self._crossings = crossings def _make_data_array(self): import numpy #print('DEBUG: _make_data_array')