Compare commits

...

2 Commits

1 changed files with 33 additions and 48 deletions

View File

@ -16,11 +16,10 @@ class Particles:
self.attr = {}
sortidx = pp[col['id'],:,0].argsort()
idsorted = pp[col['id'],sortidx,0]
#assert numpy.isclose(idsorted[0],1), "Particle IDs do not start at 1."
#assert numpy.isclose(idsorted[-1],self.num), "Particle IDs do not end at Np."
#assert numpy.all(numpy.diff(idsorted)>0), "Particle IDs are not unique."
for key in col:
if (key!='id') and (select_col is None or key in select_col or key in ('x','y','z')):
if key=='id':
self.add_attribute(key,pp[col[key],sortidx,0].astype('int'))
elif (select_col is None or key in select_col or key in ('x','y','z')):
self.add_attribute(key,pp[col[key],sortidx,0])
self.period = period
self.time = time
@ -91,11 +90,6 @@ class Trajectories:
for key in self.part[0].attr:
self.attr[key] = [self.part[0].attr[key].view()]
self.period = self.part[0].period
if any([period is not None for period in self.period]):
self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool)
else:
self._crossings = None
return
self.unraveled = False
def __str__(self):
str = 'Trajectory with\n'
@ -124,7 +118,7 @@ class Trajectories:
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
@classmethod
def from_mat(cls,file):
def from_mat(cls,file,unraveled=False):
from .helper import load_mat
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
period = [None,None,None]
@ -137,12 +131,14 @@ class Trajectories:
traj = cls(Particles(pp[:,:,0],col,time[0],period))
for ii in range(1,ntime):
traj += Particles(pp[:,:,ii],col,time[ii],period)
traj.unraveled = unraveled
return traj
def add_particles(self,part):
import numpy
# Verify part
assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order."
assert part.num==self.numpart, "Number of particles is different from previous time steps."
assert all(part['id']==self.part[0].attr['id']), "Particle IDs differ or a not in the same order."
assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!"
for key in self.attr:
assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key)
@ -185,58 +181,46 @@ class Trajectories:
for ii,part in enumerate(self.part):
time[ii] = part.time
return time
def get_trajectory(self,id=None,slice_part=slice(None),slice_time=slice(None)):
# WARNING: slice uses array indexing instead of ID
def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)):
import numpy
assert isinstance(slice_part,slice), "'slice_part' must be a slice."
assert isinstance(slice_time,slice), "'slice_time' must be a slice."
self._make_data_array()
if id is None: # npart X ndim X ntime
return numpy.stack([
self.attr['x'][slice_part,slice_time],
self.attr['y'][slice_part,slice_time],
self.attr['z'][slice_part,slice_time]],axis=0)
else:
if id<0:
id = self.numpart+id+1
assert id>=1 and id<=self.numpart, "Particle ID out-of-bounds: {:d}".format(id)
id = slice(id-1,id)
return numpy.stack([
self.attr['x'][id,slice_time],
self.attr['y'][id,slice_time],
self.attr['z'][id,slice_time]],axis=0)
return numpy.stack([
self.attr['x'][slice_part,slice_time],
self.attr['y'][slice_part,slice_time],
self.attr['z'][slice_part,slice_time]],axis=0)
def get_segments(self,id=None):
#self.unravel()
#if id is None: # npart X ndim X ntime
return
def unravel(self):
import numpy
if self.unraveled: return
if self._crossings is None: return
self._compute_crossings()
# do sth
raise NotImplementedError('Implemented but untested!')
self._make_data_array()
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
#posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
posdiff = self.attr[key][:,1:]-self.attr[key][:,0:-1]
coeff = -numpy.sign(posdiff)*(numpy.abs(posdiff)>0.5*self.period[axis])
coeff = numpy.cumsum(coeff,axis=1)
self.attr[key][:,1:] += coeff*self.period[axis]
self.unraveled = True
return
def ravel(self):
if not self.unraveled: return
raise NotImplementedError('Implemented but untested!')
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
self.part[ii].attr[key] %= self.period[axis]
self.unraveled = False
return
def _compute_crossings(self):
import numpy
if self._crossings is not None: # None means no periodicity
num_computed = self._crossings.shape[1]
if num_computed==self.numtime:
return
crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool)
crossings[:,:,0:num_computed] = self._crossings
for ii in range(num_computed,self.numtime):
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
doesCross = abs(posdiff)>0.5*self.period[axis]
upCross = numpy.logical_and(posdiff<0,doesCross)
crossings[2*axis,:,ii] = doesCross
crossings[2*axis+1,:,ii] = upCross
self._crossings = crossings
def _make_data_array(self):
'''Transforms self.attr[key] to a numpy array of dimension npart X ntime
and updates self.part[itime].attr[key] to views on self.attr[key][:,itime].'''
import numpy
#print('DEBUG: _make_data_array')
if not self._is_data_list:
@ -250,6 +234,8 @@ class Trajectories:
self._is_data_list = False
return
def _make_data_list(self):
'''Transforms self.attr[key] to a list of length ntime of numpy arrays of
dimension npart and updates self.part[itime].attr[key] to views on self.attr[key][itime].'''
import numpy
#print('DEBUG: _make_data_list')
if self._is_data_list:
@ -264,7 +250,6 @@ class Trajectories:
return
def sort(pp,col):
ncol,npart,ntime = pp.shape
assert('id' in col)