fixed unravel, but still untested. some small improvements

This commit is contained in:
Michael Krayer 2021-05-25 11:01:06 +02:00
parent 69230e4edd
commit 0e7b21c239
1 changed files with 19 additions and 25 deletions

View File

@ -16,11 +16,10 @@ class Particles:
self.attr = {} self.attr = {}
sortidx = pp[col['id'],:,0].argsort() sortidx = pp[col['id'],:,0].argsort()
idsorted = pp[col['id'],sortidx,0] idsorted = pp[col['id'],sortidx,0]
#assert numpy.isclose(idsorted[0],1), "Particle IDs do not start at 1."
#assert numpy.isclose(idsorted[-1],self.num), "Particle IDs do not end at Np."
#assert numpy.all(numpy.diff(idsorted)>0), "Particle IDs are not unique."
for key in col: for key in col:
if (key!='id') and (select_col is None or key in select_col or key in ('x','y','z')): if key=='id':
self.add_attribute(key,pp[col[key],sortidx,0].astype('int'))
elif (select_col is None or key in select_col or key in ('x','y','z')):
self.add_attribute(key,pp[col[key],sortidx,0]) self.add_attribute(key,pp[col[key],sortidx,0])
self.period = period self.period = period
self.time = time self.time = time
@ -139,6 +138,7 @@ class Trajectories:
# Verify part # Verify part
assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order." assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order."
assert part.num==self.numpart, "Number of particles is different from previous time steps." assert part.num==self.numpart, "Number of particles is different from previous time steps."
assert all(part['id']==self.part[0].attr['id']), "Particle IDs differ or a not in the same order."
assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!" assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!"
for key in self.attr: for key in self.attr:
assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key) assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key)
@ -181,26 +181,15 @@ class Trajectories:
for ii,part in enumerate(self.part): for ii,part in enumerate(self.part):
time[ii] = part.time time[ii] = part.time
return time return time
def get_trajectory(self,id=None,slice_part=slice(None),slice_time=slice(None)): def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)):
# WARNING: slice uses array indexing instead of ID
import numpy import numpy
assert isinstance(slice_part,slice), "'slice_part' must be a slice." assert isinstance(slice_part,slice), "'slice_part' must be a slice."
assert isinstance(slice_time,slice), "'slice_time' must be a slice." assert isinstance(slice_time,slice), "'slice_time' must be a slice."
self._make_data_array() self._make_data_array()
if id is None: # npart X ndim X ntime return numpy.stack([
return numpy.stack([ self.attr['x'][slice_part,slice_time],
self.attr['x'][slice_part,slice_time], self.attr['y'][slice_part,slice_time],
self.attr['y'][slice_part,slice_time], self.attr['z'][slice_part,slice_time]],axis=0)
self.attr['z'][slice_part,slice_time]],axis=0)
else:
if id<0:
id = self.numpart+id+1
assert id>=1 and id<=self.numpart, "Particle ID out-of-bounds: {:d}".format(id)
id = slice(id-1,id)
return numpy.stack([
self.attr['x'][id,slice_time],
self.attr['y'][id,slice_time],
self.attr['z'][id,slice_time]],axis=0)
def get_segments(self,id=None): def get_segments(self,id=None):
#self.unravel() #self.unravel()
#if id is None: # npart X ndim X ntime #if id is None: # npart X ndim X ntime
@ -209,13 +198,15 @@ class Trajectories:
import numpy import numpy
if self.unraveled: return if self.unraveled: return
raise NotImplementedError('Implemented but untested!') raise NotImplementedError('Implemented but untested!')
self._make_data_array()
for axis in range(0,3): for axis in range(0,3):
if self.period[axis] is not None: if self.period[axis] is not None:
key = ('x','y','z')[axis] key = ('x','y','z')[axis]
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze() #posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff) posdiff = self.attr[key][:,1:]-self.attr[key][:,0:-1]
coeff = numpy.cumsum(coeff) coeff = -numpy.sign(posdiff)*(numpy.abs(posdiff)>0.5*self.period[axis])
self.part[ii].attr[key][1:] += coeff*self.period[axis] coeff = numpy.cumsum(coeff,axis=1)
self.attr[key][:,1:] += coeff*self.period[axis]
self.unraveled = True self.unraveled = True
return return
def ravel(self): def ravel(self):
@ -228,6 +219,8 @@ class Trajectories:
self.unraveled = False self.unraveled = False
return return
def _make_data_array(self): def _make_data_array(self):
'''Transforms self.attr[key] to a numpy array of dimension npart X ntime
and updates self.part[itime].attr[key] to views on self.attr[key][:,itime].'''
import numpy import numpy
#print('DEBUG: _make_data_array') #print('DEBUG: _make_data_array')
if not self._is_data_list: if not self._is_data_list:
@ -241,6 +234,8 @@ class Trajectories:
self._is_data_list = False self._is_data_list = False
return return
def _make_data_list(self): def _make_data_list(self):
'''Transforms self.attr[key] to a list of length ntime of numpy arrays of
dimension npart and updates self.part[itime].attr[key] to views on self.attr[key][itime].'''
import numpy import numpy
#print('DEBUG: _make_data_list') #print('DEBUG: _make_data_list')
if self._is_data_list: if self._is_data_list:
@ -255,7 +250,6 @@ class Trajectories:
return return
def sort(pp,col): def sort(pp,col):
ncol,npart,ntime = pp.shape ncol,npart,ntime = pp.shape
assert('id' in col) assert('id' in col)