added ravel/unravel but untested (it can be loaded from mat file)
This commit is contained in:
parent
d23824128c
commit
69230e4edd
45
particle.py
45
particle.py
|
|
@ -91,11 +91,6 @@ class Trajectories:
|
|||
for key in self.part[0].attr:
|
||||
self.attr[key] = [self.part[0].attr[key].view()]
|
||||
self.period = self.part[0].period
|
||||
if any([period is not None for period in self.period]):
|
||||
self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool)
|
||||
else:
|
||||
self._crossings = None
|
||||
return
|
||||
self.unraveled = False
|
||||
def __str__(self):
|
||||
str = 'Trajectory with\n'
|
||||
|
|
@ -124,7 +119,7 @@ class Trajectories:
|
|||
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
|
||||
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
|
||||
@classmethod
|
||||
def from_mat(cls,file):
|
||||
def from_mat(cls,file,unraveled=False):
|
||||
from .helper import load_mat
|
||||
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
|
||||
period = [None,None,None]
|
||||
|
|
@ -137,6 +132,7 @@ class Trajectories:
|
|||
traj = cls(Particles(pp[:,:,0],col,time[0],period))
|
||||
for ii in range(1,ntime):
|
||||
traj += Particles(pp[:,:,ii],col,time[ii],period)
|
||||
traj.unraveled = unraveled
|
||||
return traj
|
||||
def add_particles(self,part):
|
||||
import numpy
|
||||
|
|
@ -206,36 +202,31 @@ class Trajectories:
|
|||
self.attr['y'][id,slice_time],
|
||||
self.attr['z'][id,slice_time]],axis=0)
|
||||
def get_segments(self,id=None):
|
||||
#self.unravel()
|
||||
#if id is None: # npart X ndim X ntime
|
||||
return
|
||||
def unravel(self):
|
||||
import numpy
|
||||
if self.unraveled: return
|
||||
if self._crossings is None: return
|
||||
self._compute_crossings()
|
||||
# do sth
|
||||
raise NotImplementedError('Implemented but untested!')
|
||||
for axis in range(0,3):
|
||||
if self.period[axis] is not None:
|
||||
key = ('x','y','z')[axis]
|
||||
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||
coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff)
|
||||
coeff = numpy.cumsum(coeff)
|
||||
self.part[ii].attr[key][1:] += coeff*self.period[axis]
|
||||
self.unraveled = True
|
||||
return
|
||||
def ravel(self):
|
||||
if not self.unraveled: return
|
||||
raise NotImplementedError('Implemented but untested!')
|
||||
for axis in range(0,3):
|
||||
if self.period[axis] is not None:
|
||||
key = ('x','y','z')[axis]
|
||||
self.part[ii].attr[key] %= self.period[axis]
|
||||
self.unraveled = False
|
||||
return
|
||||
def _compute_crossings(self):
|
||||
import numpy
|
||||
if self._crossings is not None: # None means no periodicity
|
||||
num_computed = self._crossings.shape[1]
|
||||
if num_computed==self.numtime:
|
||||
return
|
||||
crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool)
|
||||
crossings[:,:,0:num_computed] = self._crossings
|
||||
for ii in range(num_computed,self.numtime):
|
||||
for axis in range(0,3):
|
||||
if self.period[axis] is not None:
|
||||
key = ('x','y','z')[axis]
|
||||
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||
doesCross = abs(posdiff)>0.5*self.period[axis]
|
||||
upCross = numpy.logical_and(posdiff<0,doesCross)
|
||||
crossings[2*axis,:,ii] = doesCross
|
||||
crossings[2*axis+1,:,ii] = upCross
|
||||
self._crossings = crossings
|
||||
def _make_data_array(self):
|
||||
import numpy
|
||||
#print('DEBUG: _make_data_array')
|
||||
|
|
|
|||
Loading…
Reference in New Issue