diff --git a/particle.py b/particle.py index a878201..4afb72e 100644 --- a/particle.py +++ b/particle.py @@ -16,11 +16,10 @@ class Particles: self.attr = {} sortidx = pp[col['id'],:,0].argsort() idsorted = pp[col['id'],sortidx,0] - #assert numpy.isclose(idsorted[0],1), "Particle IDs do not start at 1." - #assert numpy.isclose(idsorted[-1],self.num), "Particle IDs do not end at Np." - #assert numpy.all(numpy.diff(idsorted)>0), "Particle IDs are not unique." for key in col: - if (key!='id') and (select_col is None or key in select_col or key in ('x','y','z')): + if key=='id': + self.add_attribute(key,pp[col[key],sortidx,0].astype('int')) + elif (select_col is None or key in select_col or key in ('x','y','z')): self.add_attribute(key,pp[col[key],sortidx,0]) self.period = period self.time = time @@ -139,6 +138,7 @@ class Trajectories: # Verify part assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order." assert part.num==self.numpart, "Number of particles is different from previous time steps." + assert all(part['id']==self.part[0].attr['id']), "Particle IDs differ or a not in the same order." assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!" for key in self.attr: assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key) @@ -181,26 +181,15 @@ class Trajectories: for ii,part in enumerate(self.part): time[ii] = part.time return time - def get_trajectory(self,id=None,slice_part=slice(None),slice_time=slice(None)): - # WARNING: slice uses array indexing instead of ID + def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)): import numpy assert isinstance(slice_part,slice), "'slice_part' must be a slice." assert isinstance(slice_time,slice), "'slice_time' must be a slice." self._make_data_array() - if id is None: # npart X ndim X ntime - return numpy.stack([ - self.attr['x'][slice_part,slice_time], - self.attr['y'][slice_part,slice_time], - self.attr['z'][slice_part,slice_time]],axis=0) - else: - if id<0: - id = self.numpart+id+1 - assert id>=1 and id<=self.numpart, "Particle ID out-of-bounds: {:d}".format(id) - id = slice(id-1,id) - return numpy.stack([ - self.attr['x'][id,slice_time], - self.attr['y'][id,slice_time], - self.attr['z'][id,slice_time]],axis=0) + return numpy.stack([ + self.attr['x'][slice_part,slice_time], + self.attr['y'][slice_part,slice_time], + self.attr['z'][slice_part,slice_time]],axis=0) def get_segments(self,id=None): #self.unravel() #if id is None: # npart X ndim X ntime @@ -209,13 +198,15 @@ class Trajectories: import numpy if self.unraveled: return raise NotImplementedError('Implemented but untested!') + self._make_data_array() for axis in range(0,3): if self.period[axis] is not None: key = ('x','y','z')[axis] - posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze() - coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff) - coeff = numpy.cumsum(coeff) - self.part[ii].attr[key][1:] += coeff*self.period[axis] + #posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze() + posdiff = self.attr[key][:,1:]-self.attr[key][:,0:-1] + coeff = -numpy.sign(posdiff)*(numpy.abs(posdiff)>0.5*self.period[axis]) + coeff = numpy.cumsum(coeff,axis=1) + self.attr[key][:,1:] += coeff*self.period[axis] self.unraveled = True return def ravel(self): @@ -228,6 +219,8 @@ class Trajectories: self.unraveled = False return def _make_data_array(self): + '''Transforms self.attr[key] to a numpy array of dimension npart X ntime + and updates self.part[itime].attr[key] to views on self.attr[key][:,itime].''' import numpy #print('DEBUG: _make_data_array') if not self._is_data_list: @@ -241,6 +234,8 @@ class Trajectories: self._is_data_list = False return def _make_data_list(self): + '''Transforms self.attr[key] to a list of length ntime of numpy arrays of + dimension npart and updates self.part[itime].attr[key] to views on self.attr[key][itime].''' import numpy #print('DEBUG: _make_data_list') if self._is_data_list: @@ -255,7 +250,6 @@ class Trajectories: return - def sort(pp,col): ncol,npart,ntime = pp.shape assert('id' in col)