Compare commits
2 Commits
d23824128c
...
0e7b21c239
| Author | SHA1 | Date |
|---|---|---|
|
|
0e7b21c239 | |
|
|
69230e4edd |
81
particle.py
81
particle.py
|
|
@ -16,11 +16,10 @@ class Particles:
|
||||||
self.attr = {}
|
self.attr = {}
|
||||||
sortidx = pp[col['id'],:,0].argsort()
|
sortidx = pp[col['id'],:,0].argsort()
|
||||||
idsorted = pp[col['id'],sortidx,0]
|
idsorted = pp[col['id'],sortidx,0]
|
||||||
#assert numpy.isclose(idsorted[0],1), "Particle IDs do not start at 1."
|
|
||||||
#assert numpy.isclose(idsorted[-1],self.num), "Particle IDs do not end at Np."
|
|
||||||
#assert numpy.all(numpy.diff(idsorted)>0), "Particle IDs are not unique."
|
|
||||||
for key in col:
|
for key in col:
|
||||||
if (key!='id') and (select_col is None or key in select_col or key in ('x','y','z')):
|
if key=='id':
|
||||||
|
self.add_attribute(key,pp[col[key],sortidx,0].astype('int'))
|
||||||
|
elif (select_col is None or key in select_col or key in ('x','y','z')):
|
||||||
self.add_attribute(key,pp[col[key],sortidx,0])
|
self.add_attribute(key,pp[col[key],sortidx,0])
|
||||||
self.period = period
|
self.period = period
|
||||||
self.time = time
|
self.time = time
|
||||||
|
|
@ -91,11 +90,6 @@ class Trajectories:
|
||||||
for key in self.part[0].attr:
|
for key in self.part[0].attr:
|
||||||
self.attr[key] = [self.part[0].attr[key].view()]
|
self.attr[key] = [self.part[0].attr[key].view()]
|
||||||
self.period = self.part[0].period
|
self.period = self.part[0].period
|
||||||
if any([period is not None for period in self.period]):
|
|
||||||
self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool)
|
|
||||||
else:
|
|
||||||
self._crossings = None
|
|
||||||
return
|
|
||||||
self.unraveled = False
|
self.unraveled = False
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
str = 'Trajectory with\n'
|
str = 'Trajectory with\n'
|
||||||
|
|
@ -124,7 +118,7 @@ class Trajectories:
|
||||||
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
|
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
|
||||||
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
|
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_mat(cls,file):
|
def from_mat(cls,file,unraveled=False):
|
||||||
from .helper import load_mat
|
from .helper import load_mat
|
||||||
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
|
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
|
||||||
period = [None,None,None]
|
period = [None,None,None]
|
||||||
|
|
@ -137,12 +131,14 @@ class Trajectories:
|
||||||
traj = cls(Particles(pp[:,:,0],col,time[0],period))
|
traj = cls(Particles(pp[:,:,0],col,time[0],period))
|
||||||
for ii in range(1,ntime):
|
for ii in range(1,ntime):
|
||||||
traj += Particles(pp[:,:,ii],col,time[ii],period)
|
traj += Particles(pp[:,:,ii],col,time[ii],period)
|
||||||
|
traj.unraveled = unraveled
|
||||||
return traj
|
return traj
|
||||||
def add_particles(self,part):
|
def add_particles(self,part):
|
||||||
import numpy
|
import numpy
|
||||||
# Verify part
|
# Verify part
|
||||||
assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order."
|
assert part.time>self.part[-1].time, "Time steps must be added in monotonically increasing order."
|
||||||
assert part.num==self.numpart, "Number of particles is different from previous time steps."
|
assert part.num==self.numpart, "Number of particles is different from previous time steps."
|
||||||
|
assert all(part['id']==self.part[0].attr['id']), "Particle IDs differ or a not in the same order."
|
||||||
assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!"
|
assert all([part.period[ii]==self.period[ii] for ii in range(0,3)]), "Period differs!"
|
||||||
for key in self.attr:
|
for key in self.attr:
|
||||||
assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key)
|
assert key in part.attr, "Particles to be added are missing attribute '{}'".format(key)
|
||||||
|
|
@ -185,58 +181,46 @@ class Trajectories:
|
||||||
for ii,part in enumerate(self.part):
|
for ii,part in enumerate(self.part):
|
||||||
time[ii] = part.time
|
time[ii] = part.time
|
||||||
return time
|
return time
|
||||||
def get_trajectory(self,id=None,slice_part=slice(None),slice_time=slice(None)):
|
def get_trajectory(self,slice_part=slice(None),slice_time=slice(None)):
|
||||||
# WARNING: slice uses array indexing instead of ID
|
|
||||||
import numpy
|
import numpy
|
||||||
assert isinstance(slice_part,slice), "'slice_part' must be a slice."
|
assert isinstance(slice_part,slice), "'slice_part' must be a slice."
|
||||||
assert isinstance(slice_time,slice), "'slice_time' must be a slice."
|
assert isinstance(slice_time,slice), "'slice_time' must be a slice."
|
||||||
self._make_data_array()
|
self._make_data_array()
|
||||||
if id is None: # npart X ndim X ntime
|
return numpy.stack([
|
||||||
return numpy.stack([
|
self.attr['x'][slice_part,slice_time],
|
||||||
self.attr['x'][slice_part,slice_time],
|
self.attr['y'][slice_part,slice_time],
|
||||||
self.attr['y'][slice_part,slice_time],
|
self.attr['z'][slice_part,slice_time]],axis=0)
|
||||||
self.attr['z'][slice_part,slice_time]],axis=0)
|
|
||||||
else:
|
|
||||||
if id<0:
|
|
||||||
id = self.numpart+id+1
|
|
||||||
assert id>=1 and id<=self.numpart, "Particle ID out-of-bounds: {:d}".format(id)
|
|
||||||
id = slice(id-1,id)
|
|
||||||
return numpy.stack([
|
|
||||||
self.attr['x'][id,slice_time],
|
|
||||||
self.attr['y'][id,slice_time],
|
|
||||||
self.attr['z'][id,slice_time]],axis=0)
|
|
||||||
def get_segments(self,id=None):
|
def get_segments(self,id=None):
|
||||||
|
#self.unravel()
|
||||||
|
#if id is None: # npart X ndim X ntime
|
||||||
return
|
return
|
||||||
def unravel(self):
|
def unravel(self):
|
||||||
|
import numpy
|
||||||
if self.unraveled: return
|
if self.unraveled: return
|
||||||
if self._crossings is None: return
|
raise NotImplementedError('Implemented but untested!')
|
||||||
self._compute_crossings()
|
self._make_data_array()
|
||||||
# do sth
|
for axis in range(0,3):
|
||||||
|
if self.period[axis] is not None:
|
||||||
|
key = ('x','y','z')[axis]
|
||||||
|
#posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||||
|
posdiff = self.attr[key][:,1:]-self.attr[key][:,0:-1]
|
||||||
|
coeff = -numpy.sign(posdiff)*(numpy.abs(posdiff)>0.5*self.period[axis])
|
||||||
|
coeff = numpy.cumsum(coeff,axis=1)
|
||||||
|
self.attr[key][:,1:] += coeff*self.period[axis]
|
||||||
self.unraveled = True
|
self.unraveled = True
|
||||||
return
|
return
|
||||||
def ravel(self):
|
def ravel(self):
|
||||||
if not self.unraveled: return
|
if not self.unraveled: return
|
||||||
|
raise NotImplementedError('Implemented but untested!')
|
||||||
|
for axis in range(0,3):
|
||||||
|
if self.period[axis] is not None:
|
||||||
|
key = ('x','y','z')[axis]
|
||||||
|
self.part[ii].attr[key] %= self.period[axis]
|
||||||
self.unraveled = False
|
self.unraveled = False
|
||||||
return
|
return
|
||||||
def _compute_crossings(self):
|
|
||||||
import numpy
|
|
||||||
if self._crossings is not None: # None means no periodicity
|
|
||||||
num_computed = self._crossings.shape[1]
|
|
||||||
if num_computed==self.numtime:
|
|
||||||
return
|
|
||||||
crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool)
|
|
||||||
crossings[:,:,0:num_computed] = self._crossings
|
|
||||||
for ii in range(num_computed,self.numtime):
|
|
||||||
for axis in range(0,3):
|
|
||||||
if self.period[axis] is not None:
|
|
||||||
key = ('x','y','z')[axis]
|
|
||||||
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
|
||||||
doesCross = abs(posdiff)>0.5*self.period[axis]
|
|
||||||
upCross = numpy.logical_and(posdiff<0,doesCross)
|
|
||||||
crossings[2*axis,:,ii] = doesCross
|
|
||||||
crossings[2*axis+1,:,ii] = upCross
|
|
||||||
self._crossings = crossings
|
|
||||||
def _make_data_array(self):
|
def _make_data_array(self):
|
||||||
|
'''Transforms self.attr[key] to a numpy array of dimension npart X ntime
|
||||||
|
and updates self.part[itime].attr[key] to views on self.attr[key][:,itime].'''
|
||||||
import numpy
|
import numpy
|
||||||
#print('DEBUG: _make_data_array')
|
#print('DEBUG: _make_data_array')
|
||||||
if not self._is_data_list:
|
if not self._is_data_list:
|
||||||
|
|
@ -250,6 +234,8 @@ class Trajectories:
|
||||||
self._is_data_list = False
|
self._is_data_list = False
|
||||||
return
|
return
|
||||||
def _make_data_list(self):
|
def _make_data_list(self):
|
||||||
|
'''Transforms self.attr[key] to a list of length ntime of numpy arrays of
|
||||||
|
dimension npart and updates self.part[itime].attr[key] to views on self.attr[key][itime].'''
|
||||||
import numpy
|
import numpy
|
||||||
#print('DEBUG: _make_data_list')
|
#print('DEBUG: _make_data_list')
|
||||||
if self._is_data_list:
|
if self._is_data_list:
|
||||||
|
|
@ -264,7 +250,6 @@ class Trajectories:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sort(pp,col):
|
def sort(pp,col):
|
||||||
ncol,npart,ntime = pp.shape
|
ncol,npart,ntime = pp.shape
|
||||||
assert('id' in col)
|
assert('id' in col)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue