fixed unravel, but still untested. some small improvements
This commit is contained in:
parent
69230e4edd
commit
0e7b21c239
36
particle.py
36
particle.py
|
|
@ -16,11 +16,10 @@ class Particles:
|
|||
self.attr = {}
|
||||
sortidx = pp[col['id'],:,0].argsort()
|
||||
idsorted = pp[col['id'],sortidx,0]
|
||||
#assert numpy.isclose(idsorted[0],1), "Particle IDs do not start at 1."
|
||||
#assert numpy.isclose(idsorted[-1],self.num), "Particle IDs do not end at Np."
|
||||
#assert numpy.all(numpy.diff(idsorted)>0), "Particle IDs are not unique."
|
||||
for key in col:
|
||||
if (key!='id') and (select_col is None or key in select_col or key in ('x','y','z')):
|
||||
if key=='id':
|
||||
self.add_attribute(key,pp[col[key],sortidx,0].astype('int'))
|
||||
elif (select_col is None or key in select_col or key in ('x','y','z')):
|
||||
self.add_attribute(key,pp[col[key],sortidx,0])
|
||||
self.period = period
|
||||
self.time = time
|
||||
|
|
@ -139,6 +138,7 @@ class Trajectories:
|
|||
# Verify part
|
||||
assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order."
|
||||
assert part.num==self.numpart, "Number of particles is different from previous time steps."
|
||||
assert all(part['id']==self.part[0].attr['id']), "Particle IDs differ or a not in the same order."
|
||||
assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!"
|
||||
for key in self.attr:
|
||||
assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key)
|
||||
|
|
@ -181,26 +181,15 @@ class Trajectories:
|
|||
for ii,part in enumerate(self.part):
|
||||
time[ii] = part.time
|
||||
return time
|
||||
def get_trajectory(self,id=None,slice_part=slice(None),slice_time=slice(None)):
|
||||
# WARNING: slice uses array indexing instead of ID
|
||||
def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)):
|
||||
import numpy
|
||||
assert isinstance(slice_part,slice), "'slice_part' must be a slice."
|
||||
assert isinstance(slice_time,slice), "'slice_time' must be a slice."
|
||||
self._make_data_array()
|
||||
if id is None: # npart X ndim X ntime
|
||||
return numpy.stack([
|
||||
self.attr['x'][slice_part,slice_time],
|
||||
self.attr['y'][slice_part,slice_time],
|
||||
self.attr['z'][slice_part,slice_time]],axis=0)
|
||||
else:
|
||||
if id<0:
|
||||
id = self.numpart+id+1
|
||||
assert id>=1 and id<=self.numpart, "Particle ID out-of-bounds: {:d}".format(id)
|
||||
id = slice(id-1,id)
|
||||
return numpy.stack([
|
||||
self.attr['x'][id,slice_time],
|
||||
self.attr['y'][id,slice_time],
|
||||
self.attr['z'][id,slice_time]],axis=0)
|
||||
def get_segments(self,id=None):
|
||||
#self.unravel()
|
||||
#if id is None: # npart X ndim X ntime
|
||||
|
|
@ -209,13 +198,15 @@ class Trajectories:
|
|||
import numpy
|
||||
if self.unraveled: return
|
||||
raise NotImplementedError('Implemented but untested!')
|
||||
self._make_data_array()
|
||||
for axis in range(0,3):
|
||||
if self.period[axis] is not None:
|
||||
key = ('x','y','z')[axis]
|
||||
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||
coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff)
|
||||
coeff = numpy.cumsum(coeff)
|
||||
self.part[ii].attr[key][1:] += coeff*self.period[axis]
|
||||
#posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||
posdiff = self.attr[key][:,1:]-self.attr[key][:,0:-1]
|
||||
coeff = -numpy.sign(posdiff)*(numpy.abs(posdiff)>0.5*self.period[axis])
|
||||
coeff = numpy.cumsum(coeff,axis=1)
|
||||
self.attr[key][:,1:] += coeff*self.period[axis]
|
||||
self.unraveled = True
|
||||
return
|
||||
def ravel(self):
|
||||
|
|
@ -228,6 +219,8 @@ class Trajectories:
|
|||
self.unraveled = False
|
||||
return
|
||||
def _make_data_array(self):
|
||||
'''Transforms self.attr[key] to a numpy array of dimension npart X ntime
|
||||
and updates self.part[itime].attr[key] to views on self.attr[key][:,itime].'''
|
||||
import numpy
|
||||
#print('DEBUG: _make_data_array')
|
||||
if not self._is_data_list:
|
||||
|
|
@ -241,6 +234,8 @@ class Trajectories:
|
|||
self._is_data_list = False
|
||||
return
|
||||
def _make_data_list(self):
|
||||
'''Transforms self.attr[key] to a list of length ntime of numpy arrays of
|
||||
dimension npart and updates self.part[itime].attr[key] to views on self.attr[key][itime].'''
|
||||
import numpy
|
||||
#print('DEBUG: _make_data_list')
|
||||
if self._is_data_list:
|
||||
|
|
@ -255,7 +250,6 @@ class Trajectories:
|
|||
return
|
||||
|
||||
|
||||
|
||||
def sort(pp,col):
|
||||
ncol,npart,ntime = pp.shape
|
||||
assert('id' in col)
|
||||
|
|
|
|||
Loading…
Reference in New Issue