added ravel/unravel but untested (it can be loaded from mat file)

This commit is contained in:
Michael Krayer 2021-05-25 10:32:26 +02:00
parent d23824128c
commit 69230e4edd
1 changed files with 18 additions and 27 deletions

View File

@ -91,11 +91,6 @@ class Trajectories:
for key in self.part[0].attr:
self.attr[key] = [self.part[0].attr[key].view()]
self.period = self.part[0].period
if any([period is not None for period in self.period]):
self._crossings = numpy.zeros((6,self.numpart,1),dtype=numpy.bool)
else:
self._crossings = None
return
self.unraveled = False
def __str__(self):
str = 'Trajectory with\n'
@ -124,7 +119,7 @@ class Trajectories:
raise TypeError("Trajectories can only be sliced by slice objects or integers.")
return self.get_trajectory(slice_part=sl[0],slice_time=sl[1])
@classmethod
def from_mat(cls,file):
def from_mat(cls,file,unraveled=False):
from .helper import load_mat
pp,col,time,ccinfo = load_mat(file,['pp','colpy','time','ccinfo'])
period = [None,None,None]
@ -137,6 +132,7 @@ class Trajectories:
traj = cls(Particles(pp[:,:,0],col,time[0],period))
for ii in range(1,ntime):
traj += Particles(pp[:,:,ii],col,time[ii],period)
traj.unraveled = unraveled
return traj
def add_particles(self,part):
import numpy
@ -206,36 +202,31 @@ class Trajectories:
self.attr['y'][id,slice_time],
self.attr['z'][id,slice_time]],axis=0)
def get_segments(self,id=None):
#self.unravel()
#if id is None: # npart X ndim X ntime
return
def unravel(self):
import numpy
if self.unraveled: return
if self._crossings is None: return
self._compute_crossings()
# do sth
raise NotImplementedError('Implemented but untested!')
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
coeff = -(numpy.abs(posdiff)>0.5*self.period[axis]).*numpy.sign(posdiff)
coeff = numpy.cumsum(coeff)
self.part[ii].attr[key][1:] += coeff*self.period[axis]
self.unraveled = True
return
def ravel(self):
if not self.unraveled: return
raise NotImplementedError('Implemented but untested!')
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
self.part[ii].attr[key] %= self.period[axis]
self.unraveled = False
return
def _compute_crossings(self):
import numpy
if self._crossings is not None: # None means no periodicity
num_computed = self._crossings.shape[1]
if num_computed==self.numtime:
return
crossings = numpy.zeros((6,self.numpart,self.numtime),dtype=numpy.bool)
crossings[:,:,0:num_computed] = self._crossings
for ii in range(num_computed,self.numtime):
for axis in range(0,3):
if self.period[axis] is not None:
key = ('x','y','z')[axis]
posdiff = (self.part[ii].attr[key]-self.part[ii-1].attr[key]).squeeze()
doesCross = abs(posdiff)>0.5*self.period[axis]
upCross = numpy.logical_and(posdiff<0,doesCross)
crossings[2*axis,:,ii] = doesCross
crossings[2*axis+1,:,ii] = upCross
self._crossings = crossings
def _make_data_array(self):
import numpy
#print('DEBUG: _make_data_array')