added ravel/unravel but untested (it can be loaded from mat file)
This commit is contained in:
parent
d23824128c
commit
69230e4edd
45
particle.py
45
particle.py
|
|
@ -91,11 +91,6 @@ class Trajectories:
|
||||||
for key in self.part[0].attr:
|
for key in self.part[0].attr:
|
||||||
self.attr[key] = [self.part[0].attr[key].view()]
|
self.attr[key] = [self.part[0].attr[key].view()]
|
||||||
self.period = self.part[0].period
|
self.period = self.part[0].period
|
||||||
if any([period is not None for period in self.period]):
|
|
||||||
self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool)
|
|
||||||
else:
|
|
||||||
self._crossings = None
|
|
||||||
return
|
|
||||||
self.unraveled = False
|
self.unraveled = False
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
str = 'Trajectory with\n'
|
str = 'Trajectory with\n'
|
||||||
|
|
@ -124,7 +119,7 @@ class Trajectories:
|
||||||
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
|
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
|
||||||
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
|
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_mat(cls,file):
|
def from_mat(cls,file,unraveled=False):
|
||||||
from .helper import load_mat
|
from .helper import load_mat
|
||||||
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
|
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
|
||||||
period = [None,None,None]
|
period = [None,None,None]
|
||||||
|
|
@ -137,6 +132,7 @@ class Trajectories:
|
||||||
traj = cls(Particles(pp[:,:,0],col,time[0],period))
|
traj = cls(Particles(pp[:,:,0],col,time[0],period))
|
||||||
for ii in range(1,ntime):
|
for ii in range(1,ntime):
|
||||||
traj += Particles(pp[:,:,ii],col,time[ii],period)
|
traj += Particles(pp[:,:,ii],col,time[ii],period)
|
||||||
|
traj.unraveled = unraveled
|
||||||
return traj
|
return traj
|
||||||
def add_particles(self,part):
|
def add_particles(self,part):
|
||||||
import numpy
|
import numpy
|
||||||
|
|
@ -206,36 +202,31 @@ class Trajectories:
|
||||||
self.attr['y'][id,slice_time],
|
self.attr['y'][id,slice_time],
|
||||||
self.attr['z'][id,slice_time]],axis=0)
|
self.attr['z'][id,slice_time]],axis=0)
|
||||||
def get_segments(self,id=None):
|
def get_segments(self,id=None):
|
||||||
|
#self.unravel()
|
||||||
|
#if id is None: # npart X ndim X ntime
|
||||||
return
|
return
|
||||||
def unravel(self):
|
def unravel(self):
|
||||||
|
import numpy
|
||||||
if self.unraveled: return
|
if self.unraveled: return
|
||||||
if self._crossings is None: return
|
raise NotImplementedError('Implemented but untested!')
|
||||||
self._compute_crossings()
|
for axis in range(0,3):
|
||||||
# do sth
|
if self.period[axis] is not None:
|
||||||
|
key = ('x','y','z')[axis]
|
||||||
|
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
||||||
|
coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff)
|
||||||
|
coeff = numpy.cumsum(coeff)
|
||||||
|
self.part[ii].attr[key][1:] += coeff*self.period[axis]
|
||||||
self.unraveled = True
|
self.unraveled = True
|
||||||
return
|
return
|
||||||
def ravel(self):
|
def ravel(self):
|
||||||
if not self.unraveled: return
|
if not self.unraveled: return
|
||||||
|
raise NotImplementedError('Implemented but untested!')
|
||||||
|
for axis in range(0,3):
|
||||||
|
if self.period[axis] is not None:
|
||||||
|
key = ('x','y','z')[axis]
|
||||||
|
self.part[ii].attr[key] %= self.period[axis]
|
||||||
self.unraveled = False
|
self.unraveled = False
|
||||||
return
|
return
|
||||||
def _compute_crossings(self):
|
|
||||||
import numpy
|
|
||||||
if self._crossings is not None: # None means no periodicity
|
|
||||||
num_computed = self._crossings.shape[1]
|
|
||||||
if num_computed==self.numtime:
|
|
||||||
return
|
|
||||||
crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool)
|
|
||||||
crossings[:,:,0:num_computed] = self._crossings
|
|
||||||
for ii in range(num_computed,self.numtime):
|
|
||||||
for axis in range(0,3):
|
|
||||||
if self.period[axis] is not None:
|
|
||||||
key = ('x','y','z')[axis]
|
|
||||||
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
|
|
||||||
doesCross = abs(posdiff)>0.5*self.period[axis]
|
|
||||||
upCross = numpy.logical_and(posdiff<0,doesCross)
|
|
||||||
crossings[2*axis,:,ii] = doesCross
|
|
||||||
crossings[2*axis+1,:,ii] = upCross
|
|
||||||
self._crossings = crossings
|
|
||||||
def _make_data_array(self):
|
def _make_data_array(self):
|
||||||
import numpy
|
import numpy
|
||||||
#print('DEBUG: _make_data_array')
|
#print('DEBUG: _make_data_array')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue