Add unit tests, fix the bugs they discover #1
21 changed files with 919924 additions and 341 deletions
5  .gitignore  vendored  Normal file
@@ -0,0 +1,5 @@
+.coverage
+.cache
+__pycache__
+*.swp
+*.pyc
0  code/__init__.py  Normal file

code/detect_events.py
@@ -1,229 +1,782 @@
 #!/usr/bin/python
 # -*- coding: iso-8859-15 -*-
 import numpy as np
+from statsmodels.robust.scale import mad
+from scipy import signal
+from scipy import ndimage
+import os
+from os.path import join as opj
+from scipy.signal import savgol_filter  # Savitzky–Golay filter, for smoothing data
+from scipy.ndimage import median_filter
 import sys
 import gzip
+from os.path import basename
+from os.path import exists
+from glob import glob
+from math import (
+    degrees,
+    atan2,
+)
 
 #infile = sys.argv[1]
 #outfile = sys.argv[2]
+import logging
+lgr = logging.getLogger('studyforrest.detect_eyegaze_events')
 
-def get_signal_props(data, px2deg):
 
+def deg_per_pixel(screen_size, viewing_distance, screen_resolution):
+    """Determine `px2deg` factor for EyegazeClassifier
+
+    Parameters
+    ----------
+    screen_size : float
+      Either vertical or horizontal dimension of the screen in any unit.
+    viewing_distance : float
+      Viewing distance from the screen in the same unit as `screen_size`.
+    screen_resolution : int
+      Number of pixels along the dimensions reported for `screen_size`.
+    """
+    return degrees(atan2(.5 * screen_size, viewing_distance)) / \
+        (.5 * screen_resolution)
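As a quick cross-check (not part of the diff): with the BEH screen geometry from the deleted helper script further down (52.2 cm width, 85 cm viewing distance, 1280 px horizontal resolution), this reproduces the px2deg constant used in code/tests/test_detect.py.

px2deg = deg_per_pixel(52.2, 85, 1280)
assert abs(px2deg - 0.0266711972026) < 1e-4  # BEH-screen value, cf. test_detect.py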
+
+
+def find_peaks(vels, threshold):
+    """Find above-threshold time periods
+
+    Parameters
+    ----------
+    vels : array
+      Velocities.
+    threshold : float
+      Velocity threshold.
+
+    Returns
+    -------
+    list
+      Each item is a list with the start and end index of a window where
+      velocities exceed the threshold, plus the (NaN-free) velocities in
+      that window.
+    """
+    def _get_vels(start, end):
+        v = vels[start:end]
+        v = v[~np.isnan(v)]
+        return v
+
+    sacs = []
+    sac_on = None
+    for i, v in enumerate(vels):
+        if sac_on is None and v > threshold:
+            # start of a saccade
+            sac_on = i
+        elif sac_on is not None and v < threshold:
+            sacs.append([
+                sac_on,
+                i,
+                _get_vels(
+                    sac_on,
+                    min(len(vels), i + 1))
+            ])
+            sac_on = None
+    if sac_on:
+        # end of data, but velocities still high
+        sacs.append([
+            sac_on,
+            len(vels) - 1,
+            _get_vels(sac_on, len(vels))])
+    return sacs
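A minimal illustration of the return value (standalone sketch using the function above):

vels = np.array([0., 50., 400., 500., 60., 0.])
find_peaks(vels, threshold=300.0)
# -> [[2, 4, array([400., 500., 60.])]]  (onset index, offset index, window velocities)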
+
+
+def find_saccade_onsetidx(vels, start_idx, sac_onset_velthresh):
+    idx = start_idx
+    while idx > 0 \
+            and (vels[idx] > sac_onset_velthresh or
+                 vels[idx] <= vels[idx - 1]):
+        # find first local minimum after vel drops below onset threshold
+        # going backwards in time
+
+        # we used to do this, but it could mean detecting very long
+        # saccades that consist of (mostly) missing data
+        # or np.isnan(vels[sacc_start])):
+        idx -= 1
+    return idx
+
+
+def find_movement_offsetidx(vels, start_idx, off_velthresh):
+    idx = start_idx
+    # shift saccade end index to the first element that is below the
+    # velocity threshold
+    while idx < len(vels) - 1 \
+            and (vels[idx] > off_velthresh or
+                 (vels[idx] > vels[idx + 1])):
+        # we used to do this, but it could mean detecting very long
+        # saccades that consist of (mostly) missing data
+        # or np.isnan(vels[idx])):
+        idx += 1
+    return idx
+
+
+def find_psoend(velocities, sac_velthresh, sac_peak_velthresh):
+    pso_peaks = find_peaks(velocities, sac_peak_velthresh)
+    if pso_peaks:
+        pso_label = 'HPSO'
+    else:
+        pso_peaks = find_peaks(velocities, sac_velthresh)
+        if pso_peaks:
+            pso_label = 'LPSO'
+    if not pso_peaks:
+        # no PSO
+        return
+
+    # find minimum after the offset of the last reported peak
+    pso_end = find_movement_offsetidx(
+        velocities, pso_peaks[-1][1], sac_velthresh)
+
+    if pso_end > len(velocities):
+        # velocities did not go down within the given window
+        return
+
+    return pso_label, pso_end
+
+
+def filter_spikes(data):
+    """In-place high-frequency spike filter
+
+    Inspired by:
+
+      Stampe, D. M. (1993). Heuristic filtering and reliable calibration
+      methods for video-based pupil-tracking systems. Behavior Research
+      Methods, Instruments, & Computers, 25(2), 137–142.
+      doi:10.3758/bf03204486
+    """
+    def _filter(arr):
+        # over all triples of neighboring samples
+        for i in range(1, len(arr) - 1):
+            if (arr[i - 1] < arr[i] and arr[i] > arr[i + 1]) \
+                    or (arr[i - 1] > arr[i] and arr[i] < arr[i + 1]):
+                # immediate sign-reversal of the difference from
+                # x-1 -> x -> x+1
+                prev_dist = abs(arr[i - 1] - arr[i])
+                next_dist = abs(arr[i + 1] - arr[i])
+                # replace x by the neighboring value that is closest
+                # in value
+                arr[i] = arr[i - 1] \
+                    if prev_dist < next_dist else arr[i + 1]
+        return arr
+
+    data['x'] = _filter(data['x'])
+    data['y'] = _filter(data['y'])
+    return data
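A minimal before/after sketch (standalone; the record layout matches expand_samp in code/tests/utils.py):

x = np.array([0., 0., 5., 0., 0.])
rec = np.core.records.fromarrays([x, np.zeros(5)], names=['x', 'y'])
filter_spikes(rec)
rec['x']
# -> array([0., 0., 0., 0., 0.])  the isolated spike was replaced in place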
+
+
+def get_dilated_nan_mask(arr, iterations, max_ignore_size=None):
+    clusters, nclusters = ndimage.label(np.isnan(arr))
+    # go through all clusters and remove any cluster that is not
+    # larger than max_ignore_size
+    for i in range(nclusters):
+        # cluster index is base1
+        i = i + 1
+        if (clusters == i).sum() <= max_ignore_size:
+            clusters[clusters == i] = 0
+    # mask to cover all samples with dataloss > `max_ignore_size`
+    mask = ndimage.binary_dilation(clusters > 0, iterations=iterations)
+    return mask
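A small standalone demonstration: a gap at or below max_ignore_size is ignored, a longer one is kept and dilated by iterations samples on each side:

a = np.array([1., np.nan, 1., 1., np.nan, np.nan, np.nan, 1., 1.])
get_dilated_nan_mask(a, iterations=1, max_ignore_size=1)
# -> array([False, False, False,  True,  True,  True,  True,  True, False])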
+
+
+class EyegazeClassifier(object):
+
+    record_field_names = [
+        'id', 'label',
+        'start_time', 'end_time',
+        'start_x', 'start_y',
+        'end_x', 'end_y',
+        'amp', 'peak_vel', 'avg_vel',
+    ]
+
+    def __init__(self,
+                 px2deg,
+                 sampling_rate,
+                 velthresh_startvelocity=300.0,
+                 min_intersaccade_duration=0.04,
+                 min_saccade_duration=0.01,
+                 max_initial_saccade_freq=2.0,
+                 saccade_context_window_length=1.0,
+                 max_pso_duration=0.04,
+                 min_fixation_duration=0.04,
+                 max_fixation_amp=0.7):
+        self.px2deg = px2deg
+        self.sr = sr = sampling_rate
+        self.velthresh_startvel = velthresh_startvelocity
+        self.max_fix_amp = max_fixation_amp
+
+        # convert to #samples
+        self.min_intersac_dur = int(
+            min_intersaccade_duration * sr)
+        self.min_sac_dur = int(
+            min_saccade_duration * sr)
+        self.sac_context_winlen = int(
+            saccade_context_window_length * sr)
+        self.max_pso_dur = int(
+            max_pso_duration * sr)
+        self.min_fix_dur = int(
+            min_fixation_duration * sr)
+
+        self.max_sac_freq = max_initial_saccade_freq / sr
+
+    # TODO dissolve
+    def _get_signal_props(self, data):
+        data = data[~np.isnan(data['vel'])]
+        pv = data['vel'].max()
+        amp = (((data[0]['x'] - data[-1]['x']) ** 2 + \
-            (data[0]['y'] - data[-1]['y']) ** 2) ** 0.5) * px2deg
-        avVel = data['vel'].mean()
-        return pv, amp, avVel
+                (data[0]['y'] - data[-1]['y']) ** 2) ** 0.5) * self.px2deg
+        medvel = np.median(data['vel'])
+        return amp, pv, medvel
+
+    def get_adaptive_saccade_velocity_velthresh(self, vels):
+        """Determine saccade peak velocity threshold.
+
-def detect(infile, outfile, fixation_threshold, px2deg):
-    data = np.recfromcsv(
-        infile,
-        delimiter='\t',
-        names=['vel', 'accel', 'x', 'y'])
-    print ("Data length", len(data))
+        Takes global noise-level of data into account. Implementation
+        based on algorithm proposed by NYSTROM and HOLMQVIST (2010).
+
-    out=gzip.open(outfile,"wb")
+        Parameters
+        ----------
+        vels : array
+          Velocity values. The start velocity for the adaptation algorithm
+          is taken from `velthresh_startvelocity` and should be larger than
+          any conceivable minimal saccade velocity (in deg/s).
+          TODO std unit multipliers
+
-    #####get threshold function #######
-    newThr=200 # What is this "threshold"?
-    def getThresh(cut): # def refers to defining your own function; cut is input arg
-        vel_uthr = data['vel'][data['vel'] < cut]
-        avg = vel_uthr.mean()
-        sd = vel_uthr.std()
-        return avg+6*sd, avg, sd # outputs of function; average+6*sd denotes a RANGE in any normal distribution
+        Returns
+        -------
+        tuple
+          (peak saccade velocity threshold, saccade onset velocity threshold).
+          The latter (and lower) value can be used to determine a more precise
+          saccade onset.
+        """
+        cur_thresh = self.velthresh_startvel
+
-    ###### threshold function ######### NYSTROM and HOLMQVIST (2010) ALGORITHM IS USED to find a suitable threshold
+        def _get_thresh(cut):
+            # helper function
+            vel_uthr = vels[vels < cut]
+            med = np.median(vel_uthr)
+            scale = mad(vel_uthr)
+            return med + 10 * scale, med, scale
+
-    dif=2
-    while dif > 1:
-        oldThr = newThr #Threshold in 100-300 degree/sec. 200 here.
-        newThr, avg, sd = getThresh(oldThr) #Average and std is calculated and Thr is renewed
-        dif= abs(oldThr - newThr) #return absolute value, keep doing the loop until PTn-PTn-1 is smaller than 1 degree
+        # re-compute threshold until value converges
+        count = 0
+        dif = 2
+        while dif > 1 and count < 30:  # less than 1deg/s difference
+            old_thresh = cur_thresh
+            cur_thresh, med, scale = _get_thresh(old_thresh)
+            if not cur_thresh:
+                # safe-guard in case threshold runs to zero in
+                # case of really clean and sparse data
+                cur_thresh = old_thresh
+                break
+            lgr.debug(
+                'Saccade threshold velocity: %.1f '
+                '(non-saccade mvel: %.1f, stdvel: %.1f)',
+                cur_thresh, med, scale)
+            dif = abs(old_thresh - cur_thresh)
+            count += 1
+
-    threshold=newThr
-    soft_threshold = avg + 3 * sd
-    print("after thr selection", threshold)
+        return cur_thresh, (med + 5 * scale)
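For intuition, a minimal standalone sketch of the converging-threshold loop above (note: statsmodels' mad() is a scaled median absolute deviation, the plain MAD is used here, so absolute numbers differ by a constant factor):

rng = np.random.RandomState(0)
vels = np.abs(rng.randn(1000)) * 10.0   # mostly slow, noise-like samples
vels[500:510] = 400.0                   # one saccade-like burst
thresh = 300.0                          # cf. velthresh_startvelocity
for _ in range(30):
    below = vels[vels < thresh]
    med = np.median(below)
    new_thresh = med + 10 * np.median(np.abs(below - med))  # un-normalized MAD
    if abs(thresh - new_thresh) < 1:
        break
    thresh = new_thresh
# thresh settles just above the noise distribution, well below the burst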
+
+    def _mk_event_record(self, data, idx, label, start, end):
+        return dict(zip(self.record_field_names, (
+            idx,
+            label,
+            start,
+            end,
+            data[start]['x'],
+            data[start]['y'],
+            data[end - 1]['x'],
+            data[end - 1]['y']) +
+            self._get_signal_props(data[start:end])))
+
-    ####get peaks#### Saccade by definition, is the first velocity that goes above the saccade threshold (NOT VELOCITY threshold)
-    peaks=[]
+    def __call__(self, data, classify_isp=True, sort_events=True):
+        # find threshold velocities
+        sac_peak_med_velthresh, sac_onset_med_velthresh = \
+            self.get_adaptive_saccade_velocity_velthresh(data['med_vel'])
+        lgr.info(
+            'Global saccade MEDIAN velocity thresholds: '
+            '%.1f, %.1f (onset, peak)',
+            sac_onset_med_velthresh, sac_peak_med_velthresh)
+
-    peaks = np.where(
-        np.logical_and(
-            data['vel'][:-1] < threshold,
-            data['vel'][1:] > threshold))[0]
-    # XXX original code had [0] at index 1
-    # XXX really?! why 1
-    peaks += 1
+        saccade_locs = find_peaks(
+            data['med_vel'],
+            sac_peak_med_velthresh)
+
-    above_thr_clusters, nclusters = ndimage.label(data['vel'] > soft_threshold)
-    if not nclusters:
-        print('Got no above threshold values, baby. Going home...')
-        return
-    # reinclude any timepoint that has missing data, and treat it as above threshold
-    # XXX could this possibly introduce fake saccades? MIH think not, but isnt sure
-    above_thr_clusters[np.isnan(data['vel'])] = 1
-
-    fix=[]
+        events = []
+        saccade_events = []
+        for e in self._detect_saccades(
+                saccade_locs,
+                data,
+                0,
+                len(data),
+                context=self.sac_context_winlen):
+            saccade_events.append(e.copy())
+            events.append(e)
+
-    print (peaks)
+        lgr.info('Start ISP classification')
+
-    for i, pos in enumerate(peaks):
-        sacc_start = pos
-        while sacc_start > 0 and above_thr_clusters[sacc_start] > 0:
-            sacc_start -= 1
+        if classify_isp:
+            events.extend(self._classify_intersaccade_periods(
+                data,
+                0,
+                len(data),
+                # needs to be in order of appearance
+                sorted(saccade_events, key=lambda x: x['start_time']),
+                saccade_detection=True))
+
-        # TODO: make sane
-        fix.append(-(sacc_start - 1)) # this is chinese for saying "I am not a fixation anymore"
-
-        off_period_vel = data['vel'][sacc_start - 41:sacc_start]
-        # exclude NaN
-        off_period_vel = off_period_vel[~np.isnan(off_period_vel)]
-        # go with adaptive threshold, but only if the 40ms prior to the saccade have some
-        # data to compute a velocity stdev from
-        off_threshold = (0.7 * soft_threshold) + \
-            (0.3 * (np.mean(off_period_vel) + 3 * np.std(off_period_vel))) \
-            if len(off_period_vel) > 40 else soft_threshold
-
-        sacc_end = pos
-        while sacc_end < len(data) - 1 > 0 and \
-                (data['vel'][sacc_end] > off_threshold or \
-                np.isnan(data['vel'][sacc_end])):
-            sacc_end += 1
-        # mark start of a fixation
-        fix.append(sacc_end)
-
-        # minimum duration 9 ms and no blinks allowed (!) If we increase this then saccades higher than 9ms will be considered as fixations too --- we can now get the "short" saccades
-
-        if sacc_end - sacc_start >= 21 and\
-                not np.sum(np.isnan(data['x'][sacc_start:sacc_end])):
-            sacc_data = data[sacc_start:sacc_end]
-            pv, amp, avVel = get_signal_props(sacc_data, px2deg)
-            sacc_duration = sacc_end - sacc_start
-            events.append((
-                "SACCADE",
-                sacc_start,
-                sacc_end,
-                sacc_data[0]['x'],
-                sacc_data[0]['y'],
-                sacc_data[-1]['x'],
-                sacc_data[-1]['y'],
-                amp,
-                pv,
-                avVel,
-                sacc_duration))
-        # The rest of the shorter saccades will be assigned as "FIX"'s as well.
-        # Note: they may become indistinguisble from our events that meet the formal fixation criterion.
-        # Could call them something else "PURSUIT" ?
-
-        elif sacc_end - sacc_start < 21 and sacc_end - sacc_start > 9 and\
-                not np.sum(np.isnan(data['x'][sacc_start:sacc_end])):
-            sacc_data = data[sacc_start:sacc_end]
-            pv, amp, avVel = get_signal_props(sacc_data, px2deg)
-            sacc_duration = sacc_end - sacc_start
-            events.append((
-                "FIX",
-                sacc_start,
-                sacc_end,
-                sacc_data[0]['x'],
-                sacc_data[0]['y'],
-                sacc_data[-1]['x'],
-                sacc_data[-1]['y'],
-                amp,
-                pv,
-                avVel,
-                sacc_duration))
-
-        ######## end of saccade detection = begin of glissade detection ########
-
-        idx = sacc_end + 1
-        # below=False
-        # offset=False
-        # pval=[]
-        #sacc_data
-
-        gldata = data[sacc_end:sacc_end + 40]
-        # going from the end of the window to find the last match
-        for i in range(0, len(gldata) - 2):
-            # velocity after saccade end goes below the soft threshold
-            # and immediately afterwards stay or increases the velocity again
-            if gldata[(-1 * i) - 2]['vel'] < soft_threshold and \
-                    gldata[(-1 * i) - 1]['vel'] > soft_threshold and \
-                    gldata[(-1 * i) -3]['vel'] >= gldata[(-1 * i) -2]['vel']:
-                gliss_data = gldata[:-i]
-                gliss_end = sacc_end + len(gldata) - i
-
-                if not len(gliss_data) or np.sum(np.isnan(gliss_data['vel'])) > 10:
-                    # not more than 10 ms of signal loss in glissades
-                    break
-                pv, amp, avVel = get_signal_props(gliss_data, px2deg)
-                gl_duration = gliss_end - (sacc_end + 1)
-                events.append((
-                    "GLISSADE",
-                    sacc_end + 1,
-                    gliss_end,
-                    gldata[0]['x'],
-                    gldata[0]['y'],
-                    gldata[-i]['x'],
-                    gldata[-1]['y'],
-                    amp,
-                    pv,
-                    avVel,
-                    gl_duration))
-                fix.pop()
-                fix.append(gliss_end)
-                break
-
-        ######### fixation detection after everything else is identified ########
-
-    for j, f in enumerate(fix[:-1]):
-        fix_start = f
-        # end times are coded negative
-        fix_end = abs(fix[j + 1])
-        if f > 0 and fix_end - f > 40: #onset of fixation
-            fixdata = data[f:fix_end]
-            if not len(fixdata) or np.isnan(fixdata[0][0]):
-                print("Erroneous fixation interval")
-                continue
-            pv, amp, avVel = get_signal_props(fixdata, px2deg)
-            fix_duration = fix_end - f
-
-            if avVel < fixation_threshold and amp < 2 and np.sum(np.isnan(fixdata['vel'])) <= 10:
-                events.append((
-                    "FIX",
-                    f,
-                    fix[j + 1],
-                    data[f]['x'],
-                    data[f]['y'],
-                    data[fix_end]['x'],
-                    data[fix_end]['y'],
-                    amp,
-                    pv,
-                    avVel,
-                    fix_duration))
-
-    # TODO think about just saving it in binary form
-    f = gzip.open (outfile, "w")
+        # make timing info absolute times, not samples
+        for e in events:
-        f.write('%s\t%i\t%i\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\n' % e)
-    print ("done")
+            for i in ('start_time', 'end_time'):
+                e[i] = e[i] / self.sr
+
+        return sorted(events, key=lambda x: x['start_time']) \
+            if sort_events else events
+
+    def _detect_saccades(
+            self,
+            candidate_locs,
+            data,
+            start,
+            end,
+            context):
+
+        saccade_events = []
+
+        if context is None:
+            # no context size was given, use all data
+            # to determine velocity thresholds
+            lgr.debug(
+                'Determine velocity thresholds on full segment '
+                '[%i, %i]', start, end)
+            sac_peak_velthresh, sac_onset_velthresh = \
+                self.get_adaptive_saccade_velocity_velthresh(
+                    data['vel'][start:end])
+        if candidate_locs is None:
+            lgr.debug(
+                'Find velocity peaks on full segment '
+                '[%i, %i]', start, end)
+            candidate_locs = [
+                (e[0] + start, e[1] + start, e[2]) for e in find_peaks(
+                    data['vel'][start:end],
+                    sac_peak_velthresh)]
+
+        # status map indicating which event class any timepoint has been
+        # assigned to so far
+        status = np.zeros((len(data),), dtype=int)
+
+        # loop over all peaks sorted by the sum of their velocities
+        # i.e. longer and faster goes first
+        for i, props in enumerate(sorted(
+                candidate_locs, key=lambda x: x[2].sum(), reverse=True)):
+            sacc_start, sacc_end, peakvels = props
+            lgr.info(
+                'Process peak velocity window [%i, %i] at ~%.1f deg/s',
+                sacc_start, sacc_end, peakvels.mean())
+
+            if context:
+                # extract velocity data in the vicinity of the peak to
+                # calibrate threshold
+                win_start = max(
+                    start,
+                    sacc_start - int(context / 2))
+                win_end = min(
+                    end,
+                    sacc_end + context - (sacc_start - win_start))
+                lgr.debug(
+                    'Determine velocity thresholds in context window '
+                    '[%i, %i]', win_start, win_end)
+                lgr.debug('Actual context window: [%i, %i] -> %i',
+                          win_start, win_end, win_end - win_start)
+
+                sac_peak_velthresh, sac_onset_velthresh = \
+                    self.get_adaptive_saccade_velocity_velthresh(
+                        data['vel'][win_start:win_end])
+
+            lgr.info('Active saccade velocity thresholds: '
+                     '%.1f, %.1f (onset, peak)',
+                     sac_onset_velthresh, sac_peak_velthresh)
+
+            # move backwards in time to find the saccade onset
+            sacc_start = find_saccade_onsetidx(
+                data['vel'], sacc_start, sac_onset_velthresh)
+
+            # move forward in time to find the saccade offset
+            sacc_end = find_movement_offsetidx(
+                data['vel'], sacc_end, sac_onset_velthresh)
+
+            sacc_data = data[sacc_start:sacc_end]
+            if sacc_end - sacc_start < self.min_sac_dur:
+                lgr.debug('Skip saccade candidate, too short')
+                continue
+            elif np.sum(np.isnan(sacc_data['x'])):  # pragma: no cover
+                # should not happen
+                lgr.debug('Skip saccade candidate, missing data')
+                continue
+            elif status[
+                    max(0,
+                        sacc_start - self.min_intersac_dur):min(
+                            len(data), sacc_end + self.min_intersac_dur)].sum():
+                lgr.debug('Skip saccade candidate, too close to another event')
+                continue
+
+            lgr.debug('Found SACCADE [%i, %i]',
+                      sacc_start, sacc_end)
+            event = self._mk_event_record(data, i, "SACC", sacc_start, sacc_end)
+
+            yield event.copy()
+            saccade_events.append(event)
+
+            # mark as a saccade
+            status[sacc_start:sacc_end] = 1
+
+            pso = find_psoend(
+                data['vel'][sacc_end:sacc_end + self.max_pso_dur],
+                sac_onset_velthresh,
+                sac_peak_velthresh)
+            if pso:
+                pso_label, pso_end = pso
+                lgr.debug('Found %s [%i, %i]',
+                          pso_label, sacc_end, pso_end)
+                psoevent = self._mk_event_record(
+                    data, i, pso_label, sacc_end, sacc_end + pso_end)
+                if psoevent['amp'] < saccade_events[-1]['amp']:
+                    # discard PSOs with amplitudes larger than their
+                    # anchor saccades
+                    yield psoevent.copy()
+                    # mark as a saccade part
+                    status[sacc_end:sacc_end + pso_end] = 1
+                else:
+                    lgr.debug(
+                        'Ignore PSO, amplitude larger than that of '
+                        'the previous saccade: %.1f >= %.1f',
+                        psoevent['amp'], saccade_events[-1]['amp'])
+
+            if self.max_sac_freq and \
+                    float(len(saccade_events)) / len(data) > self.max_sac_freq:
+                lgr.info('Stop initial saccade detection, max frequency '
+                         'reached')
+                break
+
+    def _classify_intersaccade_periods(
+            self,
+            data,
+            start,
+            end,
+            saccade_events,
+            saccade_detection):
+
+        lgr.warn(
+            'Determine ISPs %i, %i (%i saccade-related events)',
+            start, end, len(saccade_events))
+
+        prev_sacc = None
+        prev_pso = None
+        for ev in saccade_events:
+            if prev_sacc is None:
+                if 'SAC' not in ev['label']:
+                    continue
+            elif prev_pso is None and 'PS' in ev['label']:
+                prev_pso = ev
+                continue
+            elif 'SAC' not in ev['label']:
+                continue
+
+            # at this point we have a previous saccade (and possibly its PSO)
+            # on record, and we have just found the next saccade
+            # -> inter-saccade window is determined
+            if prev_sacc is None:
+                win_start = start
+            else:
+                if prev_pso is not None:
+                    win_start = prev_pso['end_time']
+                else:
+                    win_start = prev_sacc['end_time']
+            # enforce dtype for indexing
+            win_end = ev['start_time']
+            if win_start == win_end:
+                prev_sacc = ev
+                prev_pso = None
+                continue
+
+            lgr.warn('Found ISP [%i:%i]', win_start, win_end)
+            for e in self._classify_intersaccade_period(
+                    data,
+                    win_start,
+                    win_end,
+                    saccade_detection=saccade_detection):
+                yield e
+
+            # lastly, the current saccade becomes the previous one
+            prev_sacc = ev
+            prev_pso = None
+
+        if prev_sacc is not None and prev_sacc['end_time'] == end:
+            return
+
+        lgr.debug("LAST_SEGMENT_ISP: %s -> %s", prev_sacc, prev_pso)
+        # and for everything beyond the last saccade (if there was any)
+        for e in self._classify_intersaccade_period(
+                data,
+                start if prev_sacc is None
+                else prev_sacc['end_time'] if prev_pso is None
+                else prev_pso['end_time'],
+                end,
+                saccade_detection=saccade_detection):
+            yield e
+
+    def _classify_intersaccade_period(
+            self,
+            data,
+            start,
+            end,
+            saccade_detection):
+        lgr.warn('Determine NaN-free intervals in [%i:%i] (%i)',
+                 start, end, end - start)
+
+        # split the ISP up into its non-NaN pieces:
+        win_start = None
+        for idx in range(start, end + 1):
+            if win_start is None and not np.isnan(data['x'][idx]):
+                win_start = idx
+            elif win_start is not None and \
+                    ((idx == end) or np.isnan(data['x'][idx])):
+                for e in self._classify_intersaccade_period_helper(
+                        data,
+                        win_start,
+                        idx,
+                        saccade_detection):
+                    yield e
+                # reset non-NaN window start
+                win_start = None
+
+    def _classify_intersaccade_period_helper(
+            self,
+            data,
+            start,
+            end,
+            saccade_detection):
+        # no NaN values in data at this point!
+        lgr.warn(
+            'Process non-NaN segment [%i, %i] -> %i',
+            start, end, end - start)
+
+        label_remap = {
+            'SACC': 'ISAC',
+            'HPSO': 'IHPS',
+            'LPSO': 'ILPS',
+        }
+
+        length = end - start
+        # detect saccades, if there is enough space to maintain minimal
+        # distance to other saccades
+        if length > (
+                2 * self.min_intersac_dur) \
+                + self.min_sac_dur + self.max_pso_dur:
+            lgr.warn('Perform saccade detection in [%i:%i]', start, end)
+            saccades = self._detect_saccades(
+                None,
+                data,
+                start,
+                end,
+                context=None)
+            saccade_events = []
+            if saccades is not None:
+                kill_pso = False
+                for s in saccades:
+                    if kill_pso:
+                        kill_pso = False
+                        if s['label'].endswith('PSO'):
+                            continue
+                    if s['start_time'] - start < self.min_intersac_dur or \
+                            end - s['end_time'] < self.min_intersac_dur:
+                        # too close to another saccade
+                        kill_pso = True
+                        continue
+                    s['label'] = label_remap.get(s['label'], s['label'])
+                    # need to make a copy of the dict to not have outside
+                    # modification interfere with further inside processing
+                    yield s.copy()
+                    saccade_events.append(s)
+            if saccade_events:
+                lgr.warn('Found additional saccades in ISP')
+            # and now process the intervals between the saccades
+            for e in self._classify_intersaccade_periods(
+                    data,
+                    start,
+                    end,
+                    sorted(saccade_events,
+                           key=lambda x: x['start_time']),
+                    saccade_detection=False):
+                yield e
+            return
+
+        max_amp, label = self._fix_or_pursuit(data, start, end)
+        if label is not None:
+            yield self._mk_event_record(
+                data,
+                max_amp,
+                label,
+                start,
+                end)
+
+    def _fix_or_pursuit(self, data, start, end):
+        win_data = data[start:end].copy()
+
+        if len(win_data) < self.min_fix_dur:
+            return None, None
+
+        def _butter_lowpass(cutoff, fs, order=5):
+            nyq = 0.5 * fs
+            normal_cutoff = cutoff / nyq
+            b, a = signal.butter(
+                order,
+                normal_cutoff,
+                btype='low',
+                analog=False)
+            return b, a
+
+        b, a = _butter_lowpass(10.0, 1000.0)
+        win_data['x'] = signal.filtfilt(b, a, win_data['x'], method='gust')
+        win_data['y'] = signal.filtfilt(b, a, win_data['y'], method='gust')
+
+        win_data = win_data[10:-10]
+        start_x = win_data[0]['x']
+        start_y = win_data[0]['y']
+
+        # determine max location deviation from start coordinate
+        amp = (((start_x - win_data['x']) ** 2 +
+                (start_y - win_data['y']) ** 2) ** 0.5)
+        amp_argmax = amp.argmax()
+        max_amp = amp[amp_argmax] * self.px2deg
+        #print('MAX IN WIN [{}:{}]@{:.1f})'.format(start, end, max_amp))
+
+        if max_amp > self.max_fix_amp:
+            return max_amp, 'PURS'
+        return max_amp, 'FIXA'
+
+    def preproc(
+            self,
+            data,
+            min_blink_duration=0.02,
+            dilate_nan=0.01,
+            median_filter_length=0.05,
+            savgol_length=0.019,
+            savgol_polyord=2,
+            max_vel=1000.0):
+        """
+        Parameters
+        ----------
+        data : array
+          Record array with fields ('x', 'y', 'pupil')
+        px2deg : float
+          Size of a pixel in visual angles.
+        min_blink_duration : float
+          In seconds. Any signal loss shorter than this duration will not be
+          considered for `dilate_nan`.
+        dilate_nan : float
+          Duration by which to dilate a blink window (missing data segment) on
+          either side (in seconds).
+        median_filter_length : float
+          Filter window length in seconds.
+        savgol_length : float
+          Filter window length in seconds.
+        savgol_polyord : int
+          Filter polynomial order used to fit the samples.
+        sampling_rate : float
+          In Hertz
+        max_vel : float
+          Maximum velocity in deg/s. Any velocity value larger than this threshold
+          will be replaced by the previous velocity value. Additionally a warning
+          will be issued to indicate a potentially inappropriate filter setup.
+        """
+        # convert params in seconds to #samples
+        dilate_nan = int(dilate_nan * self.sr)
+        min_blink_duration = int(min_blink_duration * self.sr)
+        savgol_length = int(savgol_length * self.sr)
+        median_filter_length = int(median_filter_length * self.sr)
+
+        # in-place spike filter
+        data = filter_spikes(data)
+
+        # for signal loss exceeding the minimum blink duration, add additional
+        # dilate_nan at either end
+        # find clusters of "no data"
+        if dilate_nan:
+            mask = get_dilated_nan_mask(
+                data['x'],
+                dilate_nan,
+                min_blink_duration)
+            data['x'][mask] = np.nan
+            data['y'][mask] = np.nan
+
+        if savgol_length:
+            for i in ('x', 'y'):
+                data[i] = savgol_filter(data[i], savgol_length, savgol_polyord)
+
+        # velocity calculation, exclude velocities over `max_vel`
+        # euclidean distance between successive coordinate samples
+        # no entry for first datapoint!
+        velocities = (np.diff(data['x']) ** 2 + np.diff(data['y']) ** 2) ** 0.5
+        # convert from px/sample to deg/s
+        velocities *= self.px2deg * self.sr
+
+        if median_filter_length:
+            med_velocities = np.zeros((len(data),), velocities.dtype)
+            med_velocities[1:] = (
+                np.diff(median_filter(data['x'],
+                                      size=median_filter_length)) ** 2 +
+                np.diff(median_filter(data['y'],
+                                      size=median_filter_length)) ** 2) ** 0.5
+            # convert from px/sample to deg/s
+            med_velocities *= self.px2deg * self.sr
+            # remove any velocity bordering NaN
+            med_velocities[get_dilated_nan_mask(
+                med_velocities, dilate_nan, 0)] = np.nan
+
+        # replace "too fast" velocities with previous velocity
+        # add missing first datapoint
+        filtered_velocities = [float(0)]
+        for vel in velocities:
+            if vel > max_vel:  # deg/s
+                # ignore very fast velocities
+                lgr.warning(
+                    'Computed velocity exceeds threshold. '
+                    'Inappropriate filter setup? [%.1f > %.1f deg/s]',
+                    vel,
+                    max_vel)
+                vel = filtered_velocities[-1]
+            filtered_velocities.append(vel)
+        velocities = np.array(filtered_velocities)
+
+        # acceleration is change of velocities over the last time unit
+        acceleration = np.zeros(velocities.shape, velocities.dtype)
+        acceleration[1:] = (velocities[1:] - velocities[:-1]) * self.sr
+
+        arrs = [med_velocities] if median_filter_length else []
+        names = ['med_vel'] if median_filter_length else []
+        arrs.extend([
+            velocities,
+            acceleration,
+            data['x'],
+            data['y']])
+        names.extend(['vel', 'accel', 'x', 'y'])
+        return np.core.records.fromarrays(arrs, names=names)
+
+
 if __name__ == '__main__':
-    fixation_threshold = float(sys.argv[1])
+    fixation_velthresh = float(sys.argv[1])
     px2deg = float(sys.argv[2])
     infpath = sys.argv[3]
     outfpath = sys.argv[4]
-    detect(infpath, outfpath, fixation_threshold, px2deg)
+    data = np.recfromcsv(
+        infpath,
+        delimiter='\t',
+        names=['vel', 'accel', 'x', 'y'])
+
+    events = detect(data, outfpath, fixation_velthresh, px2deg)
+
+    # TODO think about just saving it in binary form
+    f = gzip.open(outfpath, "w")
+    for e in events:
+        f.write('%s\t%i\t%i\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\n' % e)
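For orientation, a hedged end-to-end sketch of driving the new classifier (mirroring test_real_data in code/tests/test_detect.py below; assumes the module file is code/detect_events.py, as the tests' relative import suggests):

import numpy as np
from code.detect_events import EyegazeClassifier

data = np.recfromcsv(
    'inputs/raw_eyegaze/sub-02/ses-movie/func/'
    'sub-02_ses-movie_task-movie_run-5_recording-eyegaze_physio.tsv.gz',
    delimiter='\t', names=['x', 'y', 'pupil', 'frame'])
clf = EyegazeClassifier(px2deg=0.0266711972026, sampling_rate=1000.0)
events = clf(clf.preproc(data))  # list of dicts: 'label', 'start_time', 'end_time', ...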


@@ -1,37 +0,0 @@
-# Sebastiaan Mathot (modified)
-# -*- coding: iso-8859-15 -*-
-
-
-from math import atan2, degrees
-h = 52.2 # Monitor width in cm, this is the BEH screen
-d = 85 # Distance between monitor and participant in cm
-r = 1280 # Horizontal resolution of the monitor
-size_in_px = 1 # The stimulus size in pixels
-# Calculate the number of degrees that correspond to a single pixel. This will
-# generally be a very small value, something like 0.03.
-deg_per_px = degrees(atan2(.5*h, d)) / (.5*r)
-print '%s degrees correspond to a single pixel' % deg_per_px
-# Calculate the size of the stimulus in degrees
-size_in_deg = size_in_px * deg_per_px
-print 'The size of the stimulus is %s pixels and %s visual degrees' \
-    % (size_in_px, size_in_deg)
-
-
-#h = 26.5 # Monitor width in cm
-#d = 63 # Distance between monitor and participant in cm
-#This is for the MRI screen
-
-#Checking from paper
-#MRI
-#In [1]: 0.0185581232561*1280
-#Out[1]: 23.754397767808
-#BEH
-#In [2]: 0.0266711972026*1280
-#Out[2]: 34.139132419328
-
-#Calculating the ratio
-#In [7]: 0.0185581232561*0.01
-#Out[7]: 0.000185581232561
-
-#In [8]: 0.000185581232561/0.0266711972026
-#Out[8]: 0.006958114071568895


@@ -1,98 +0,0 @@
-#!/usr/bin/python
-# -*- coding: iso-8859-15 -*-
-
-import sys
-import numpy as np
-from scipy.signal import savgol_filter # Savitzky–Golay filter, for smoothing data
-from scipy import ndimage as ndimage
-from glob import glob # The glob.glob returns the list of files with their full path
-import gzip
-import os
-from os.path import basename # returns the tail of the path
-from os.path import dirname
-from os.path import curdir
-from os.path import exists # logical for if a certain file exists
-from os.path import join as opj
-
-sampling_rate = 1000.0 # in Hertz
-
-
-def preproc(infile, outfile, px2deg):
-    # TODO parameter
-    # max_signal_loss_without_something
-    # blank_duration
-    # savgol_window_length
-    # savgol_polyord
-    # savgol_iterations
-
-    outdir = dirname(outfile)
-    outdir = curdir if not outdir else outdir
-    if not exists(outdir):
-        os.makedirs(outdir)
-
-    data = np.recfromcsv(
-        infile,
-        delimiter='\t',
-        names=['x', 'y', 'pupil', 'frame'])
-
-    # for signal loss exceeding 20 ms, additional 10 ms at beginning
-    # find clusters of "no data"
-    clusters, nclusters = ndimage.label(np.isnan(data['x']))
-    # go through all clusters and remove any cluster that is less than 20 samples
-    for i in range(1, nclusters):
-        if (clusters == i).sum() <= 20:
-            clusters[clusters == i] = 0
-    # mask to cover all samples with dataloss > 20ms, plus 10 samples on either
-    # side of the lost segment
-    mask = ndimage.binary_dilation(clusters > 0, iterations=10)
-    data['x'][mask] = np.nan
-    data['y'][mask] = np.nan
-    data['pupil'][mask] = np.nan
-
-    # TODO filtering with NaNs in place kicks out additional datapoints, maybe
-    # do no or less dilation of the mask above
-    data['x'] = savgol_filter(data['x'], 19, 1)
-    data['y'] = savgol_filter(data['y'], 19, 1)
-
-    #velocity calculation, exclude velocities over 1000
-
-    # euclidean distance between successive coordinate samples
-    # no entry for first datapoint
-    velocities = (np.diff(data['x']) ** 2 + np.diff(data['y']) ** 2) ** 0.5
-
-    # convert from px/msec to deg/s
-    velocities *= px2deg * sampling_rate
-
-    # replace "too fast" velocities with previous velocity
-    accelerations = [float(0)]
-    filtered_velocities = [float(0)]
-    for vel in velocities:
-        # TODO make threshold a parameter
-        if vel > 1000: # deg/s
-            # ignore very fast velocities
-            vel = filtered_velocities[-1]
-        # acceleration is change of velocities over the last msec
-        accelerations.append((vel - filtered_velocities[-1]) * sampling_rate)
-        filtered_velocities.append(vel)
-        # TODO report how often that happens
-
-    #save data to file
-    data=np.array([
-        filtered_velocities,
-        accelerations,
-        # TODO add time np.arange(len(filtered_velocities))
-        data['x'],
-        data['y']])
-
-    # TODO think about just saving it in binary form
-    np.savetxt(
-        outfile,
-        data.T,
-        fmt=['%f', '%f', '%f', '%f'],
-        delimiter='\t')
-
-if __name__ == '__main__':
-    px2deg = float(sys.argv[1])
-    infpath = sys.argv[2]
-    outfpath = sys.argv[3]
-    preproc(infpath, outfpath, px2deg)
0  code/tests/__init__.py  Normal file

116  code/tests/test_detect.py  Normal file
@@ -0,0 +1,116 @@
+import pytest
+import numpy as np
+from . import utils as ut
+from .. import detect_events as d
+
+
+common_args = dict(
+    px2deg=0.01,
+    sampling_rate=1000.0,
+)
+
+
+def test_no_saccade():
+    samp = np.random.randn(1000)
+    data = ut.expand_samp(samp, y=0.0)
+    clf = d.EyegazeClassifier(**common_args)
+    p = clf.preproc(data, savgol_length=0.0, dilate_nan=0)
+    # the entire segment is labeled as a fixation
+    events = clf(p)
+    assert len(events) == 1
+    print(events)
+    assert events[0]['end_time'] - events[0]['start_time'] == 1.0
+    assert events[0]['label'] == 'FIXA'
+
+    # missing data split events
+    p[500:510]['x'] = np.nan
+    events = clf(p)
+    assert len(events) == 2
+    assert np.all([e['label'] == 'FIXA' for e in events])
+
+    # size doesn't matter
+    p[500:800]['x'] = np.nan
+    assert len(clf(p)) == len(events)
+
+
+def test_one_saccade():
+    samp = ut.mk_gaze_sample()
+
+    data = ut.expand_samp(samp, y=0.0)
+    clf = d.EyegazeClassifier(**common_args)
+    p = clf.preproc(data, dilate_nan=0)
+    events = clf(p)
+    assert events is not None
+    # we find at least the saccade
+    events = ut.events2df(events)
+    assert len(events) > 2
+    if len(events) == 4:
+        # full set
+        assert list(events['label']) == ['FIXA', 'ISAC', 'ILPS', 'FIXA'] or \
+            list(events['label']) == ['FIXA', 'ISAC', 'IHPS', 'FIXA']
+    for i in range(0, len(events) - 1):
+        # complete segmentation
+        assert events['start_time'][i + 1] == events['end_time'][i]
+
+
+def test_too_long_pso():
+    samp = ut.mk_gaze_sample(
+        pre_fix=1000,
+        post_fix=1000,
+        sacc=20,
+        sacc_dist=200,
+        # just under 30deg/s (max smooth pursuit)
+        pso=80,
+        pso_dist=100)
+    data = ut.expand_samp(samp, y=0.0)
+    clf = d.EyegazeClassifier(
+        max_initial_saccade_freq=.2,
+        **common_args)
+    p = clf.preproc(data, dilate_nan=0)
+    events = clf(p)
+    events = ut.events2df(events)
+    # there is no PSO detected
+    assert list(events['label']) == ['FIXA', 'SACC', 'FIXA']
+
+
+@pytest.mark.parametrize('infile', [
+    'inputs/raw_eyegaze/sub-32/beh/sub-32_task-movie_run-2_recording-eyegaze_physio.tsv.gz',
+    'inputs/raw_eyegaze/sub-09/ses-movie/func/sub-09_ses-movie_task-movie_run-2_recording-eyegaze_physio.tsv.gz',
+    'inputs/raw_eyegaze/sub-02/ses-movie/func/sub-02_ses-movie_task-movie_run-5_recording-eyegaze_physio.tsv.gz',
+])
+def test_real_data(infile):
+    data = np.recfromcsv(
+        infile,
+        delimiter='\t',
+        names=['x', 'y', 'pupil', 'frame'])
+
+    clf = d.EyegazeClassifier(
+        #px2deg=0.0185581232561,
+        px2deg=0.0266711972026,
+        sampling_rate=1000.0)
+    p = clf.preproc(data)
+
+    events = clf(
+        p[:50000],
+        #p,
+    )
+
+    evdf = ut.events2df(events)
+
+    labels = list(evdf['label'])
+    # find all kinds of events
+    for t in ('FIXA', 'PURS', 'SACC', 'LPSO', 'HPSO',
+              'ISAC', 'IHPS'):
+        # 'ILPS' one file doesn't have any
+        assert t in labels
+    return
+
+    ut.show_gaze(pp=p[:50000], events=events, px2deg=0.0185581232561)
+    #ut.show_gaze(pp=p, events=events, px2deg=0.0185581232561)
+    import pylab as pl
+    saccades = evdf[evdf['label'] == 'SACC']
+    isaccades = evdf[evdf['label'] == 'ISAC']
+    print('#saccades', len(saccades), len(isaccades))
+    pl.plot(saccades['amp'], saccades['peak_vel'], '.', alpha=.3)
+    pl.plot(isaccades['amp'], isaccades['peak_vel'], '.', alpha=.3)
+    pl.show()
94  code/tests/test_nystrom.py  Normal file
@@ -0,0 +1,94 @@
+import numpy as np
+from . import utils as ut
+from .. import detect_events as d
+
+
+common_args = dict(
+    # we don't know what the stimulus props are
+    # variant1
+    #px2deg=d.deg_per_pixel(0.516, 0.6, 1024),
+    #sampling_rate=500.0,
+    # variant2
+    px2deg=d.deg_per_pixel(0.377, 0.67, 1024),
+    sampling_rate=1250.0,
+)
+
+
+def test_target_data():
+    label_remap = {
+        1: 'FIXA',
+        2: 'SACC',
+        3: 'PSO',
+    }
+    clf = d.EyegazeClassifier(**common_args)
+    data = np.recfromcsv(
+        'inputs/nystrom_target/1_2.csv',
+        usecols=[1, 2, 3, 4])
+    events = []
+    ev_type = None
+    ev_start = None
+    vels = []
+    for i in range(len(data)):
+        s = data[i]
+        if ev_type is None and s['event_type'] in (1, 2):
+            ev_type = s['event_type']
+            ev_start = i
+        elif ev_type is not None and s['event_type'] != ev_type:
+            amp, pv, medvel = clf._get_signal_props(data[ev_start:i])
+            events.append(dict(
+                id=len(events),
+                label=label_remap.get(ev_type),
+                start_time=0.0 if ev_start is None else
+                float(ev_start) / common_args['sampling_rate'],
+                end_time=float(i) / common_args['sampling_rate'],
+                peak_vel=pv,
+                amp=amp,
+            ))
+            vels = []
+            ev_type = s['event_type'] if s['event_type'] in (1, 2) else None
+            ev_start = i
+        if ev_type:
+            vels.append(s['vel'])
+    ut.show_gaze(pp=data, events=events, **common_args)
+    #for e in events:
+    #    print(e)
+    import pylab as pl
+    events = ut.events2df(events)
+    saccades = events[events['label'] == 'SACC']
+    isaccades = events[events['label'] == 'ISAC']
+    print('#saccades', len(saccades), len(isaccades))
+    pl.plot(saccades['amp'], saccades['peak_vel'], '.', alpha=.3)
+    pl.plot(isaccades['amp'], isaccades['peak_vel'], '.', alpha=.3)
+    pl.show()
+
+
+def test_real_data():
+    data = np.recfromcsv(
+        'inputs/event_detector_1.1/1_2.csv',
+        usecols=[0, 1])
+    # when both coords are zero -> missing data
+    data[np.logical_and(data['x'] == 0, data['y'] == 0)] = (np.nan, np.nan)
+
+    clf = d.EyegazeClassifier(
+        min_intersaccade_duration=0.04,
+        # high threshold, static stimuli, should not have pursuit
+        max_fixation_amp=4.0,
+        **common_args)
+    p = clf.preproc(data, dilate_nan=0.03)
+
+    events = clf(p)
+
+    # TODO compare against output from original matlab code
+    #return
+    for e in events:
+        print('{:.4f} -> {:.4f}: {} ({})'.format(
+            e['start_time'], e['end_time'], e['label'], e['id']))
+
+    ut.show_gaze(pp=p, events=events, **common_args)
+    import pylab as pl
+    events = ut.events2df(events)
+    saccades = events[events['label'] == 'SACC']
+    isaccades = events[events['label'] == 'ISAC']
+    print('#saccades', len(saccades), len(isaccades))
+    pl.plot(saccades['amp'], saccades['peak_vel'], '.', alpha=.3)
+    pl.plot(isaccades['amp'], isaccades['peak_vel'], '.', alpha=.3)
+    pl.show()
93  code/tests/test_preproc.py  Normal file
@@ -0,0 +1,93 @@
+import numpy as np
+from . import utils as ut
+from .. import detect_events as d
+
+
+common_args = dict(
+    px2deg=0.0185581232561,
+    sampling_rate=1000.0,
+)
+
+
+def test_px2deg():
+    assert (
+        d.deg_per_pixel(26.5, 63, 1280) -
+        # value from paper
+        0.0185546875) < 0.0001
+
+
+def test_spike_filter():
+    samp = np.random.randn(1000)
+    data = ut.expand_samp(samp, y=0.0)
+    p = d.filter_spikes(data.copy())
+    assert np.std(data['x']) > np.std(p['x'])
+    assert data['x'][0] == p['x'][0]
+    assert data['x'][-1] == p['x'][-1]
+
+
+def test_preproc():
+    samp = np.random.randn(1000)
+    data = ut.expand_samp(samp, y=0.0)
+    clf = d.EyegazeClassifier(**common_args)
+    p = clf.preproc(data.copy(), savgol_length=0.019, savgol_polyord=1)
+    # first values are always zero
+    assert p[0]['vel'] == 0
+    assert p[0]['accel'] == 0
+    p = p['x']
+    # shorter filter leaves more "noise"
+    p_linshort = clf.preproc(
+        data.copy(), savgol_length=0.009, savgol_polyord=1)['x']
+    assert np.std(p) < np.std(p_linshort)
+    # more flexible filter leaves more "noise"
+    p_quad = clf.preproc(
+        data.copy(), savgol_length=0.019, savgol_polyord=2)['x']
+    assert np.std(p) < np.std(p_quad)
+
+    # insert small NaN patch
+    data['x'][100:110] = np.nan
+    assert np.sum(np.isnan(data['x'])) == 10
+    p = clf.preproc(
+        data.copy(), savgol_length=0.019, savgol_polyord=1,
+        min_blink_duration=10.0)['x']
+    # the original data does NOT change!
+    assert np.sum(np.isnan(data['x'])) == 10
+    # the gap will widen
+    assert np.sum(np.isnan(p)) == 28
+    # a wider filter will increase the gap, actual impact depends on
+    # filter setup
+    p = clf.preproc(
+        data.copy(), savgol_length=0.101, savgol_polyord=1,
+        min_blink_duration=10.0)['x']
+    assert np.sum(np.isnan(p)) > 28
+    # now widen the gap pre-filtering (disable filter to test that)
+    p = clf.preproc(
+        data.copy(), savgol_length=0.001, savgol_polyord=0,
+        min_blink_duration=0.0, dilate_nan=0.015)['x']
+    # the original data still does NOT change!
+    assert np.sum(np.isnan(data['x'])) == 10
+    # the gap will widen
+    assert np.sum(np.isnan(p)) == 10 + 2 * 15
+
+    # insert another small gap that we do not want to widen
+    data['x'][200:202] = np.nan
+    assert np.sum(np.isnan(data['x'])) == 12
+    p = clf.preproc(
+        data.copy(), savgol_length=0.001, savgol_polyord=0,
+        min_blink_duration=0.008, dilate_nan=0.015)['x']
+    assert np.sum(np.isnan(p)) == 10 + 2 * 15 + 2
+
+    samp = [0.0, 2.0]
+    data = ut.expand_samp(samp, y=0.0)
+    clf = d.EyegazeClassifier(px2deg=1.0, sampling_rate=10.0)
+    p = clf.preproc(
+        data.copy(), savgol_length=0, dilate_nan=0,
+        median_filter_length=0)
+    # 2 deg in 0.1s -> 20deg/s
+    assert p['vel'][-1] == 20
+    assert p['accel'][-1] == 200
+    assert 'med_vel' not in p.dtype.names
+
+    data['x'][1] = 200
+    p = clf.preproc(data.copy(), savgol_length=0, dilate_nan=0)
+    assert p['vel'][-1] == 0
+    assert p['accel'][-1] == 0
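The arithmetic behind the final velocity/acceleration assertions above, spelled out with the test's own values:

dx = 2.0                  # px step between the two samples
px2deg, sr = 1.0, 10.0
vel = dx * px2deg * sr    # 2 deg covered in 0.1 s -> 20 deg/s
accel = (vel - 0.0) * sr  # velocity change over one sample -> 200 deg/s**2
assert (vel, accel) == (20.0, 200.0)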
132  code/tests/utils.py  Normal file
@@ -0,0 +1,132 @@
+import numpy as np
+import os
+
+import logging
+lgr = logging.getLogger('studyforrest.utils')
+
+
+if 'NOISE_SEED' in os.environ:
+    seed = int(os.environ['NOISE_SEED'])
+else:
+    seed = np.random.randint(100000000)
+lgr.warn('RANDOM SEED: %i', seed)
+np.random.seed(seed)
+
+
+def get_noise(size, loc, std):
+    noise = np.random.randn(size)
+    noise *= std
+    noise += loc
+    return noise
+
+
+def get_drift(size, start, dist):
+    return np.linspace(start, start + dist, size)
+
+
+def mk_gaze_sample(
+        pre_fix=1000,
+        post_fix=1000,
+        fix_std=5,
+        sacc=20,
+        sacc_dist=200,
+        pso=30,
+        pso_dist=-40,
+        start_x=0.0,
+        noise_std=2,
+        ):
+    duration = pre_fix + sacc + pso + post_fix
+    samp = np.empty(duration)
+    # pre_fix
+    t = 0
+    pos = start_x
+    samp[t:t + pre_fix] = get_noise(pre_fix, pos, fix_std)
+    t += pre_fix
+    # saccade
+    samp[t:t + sacc] = get_drift(sacc, pos, sacc_dist)
+    t += sacc
+    pos += sacc_dist
+    # pso
+    samp[t:t + pso] = get_drift(pso, pos, pso_dist)
+    t += pso
+    pos += pso_dist
+    # post fixation
+    samp[t:t + post_fix] = get_noise(post_fix, pos, fix_std)
+    samp += get_noise(len(samp), 0, noise_std)
+
+    return samp
+
+
+def expand_samp(samp, y=1000.0):
+    n = len(samp)
+    return np.core.records.fromarrays([
+        samp,
+        [y] * n,
+        [0.0] * n,
+        [0] * n],
+        names=['x', 'y', 'pupil', 'frame'])
+
+
+def samp2file(data, fname):
+    np.savetxt(
+        fname,
+        data.T,
+        fmt=['%.1f', '%.1f', '%.1f', '%i'],
+        delimiter='\t')
+
+
+def show_gaze(data=None, pp=None, events=None, px2deg=None, sampling_rate=1000.0):
+    colors = {
+        'FIXA': 'gray',
+        'PURS': 'red',
+        'SACC': 'green',
+        'ISAC': 'pink',
+        'HPSO': 'yellow',
+        'IHPS': 'orange',
+        'LPSO': 'cyan',
+        'ILPS': 'blue',
+    }
+
+    import pylab as pl
+    if data is not None:
+        pl.plot(
+            np.linspace(0, len(data) / sampling_rate, len(data)),
+            data['x'],
+            color='blue')
+        pl.plot(
+            np.linspace(0, len(data) / sampling_rate, len(data)),
+            data['y'],
+            color='blue')
+    if pp is not None:
+        pl.plot(
+            np.linspace(0, len(pp) / sampling_rate, len(pp)),
+            pp['vel'],
+            #(pp['accel'] / np.abs(pp['accel'][~np.isnan(pp['accel'])]).max()) * 1000,
+            #(pp['accel'] / np.abs(pp['accel']).max()) * 1000,
+            color='gray')
+        pl.plot(
+            np.linspace(0, len(pp) / sampling_rate, len(pp)),
+            pp['x'],
+            color='orange')
+        pl.plot(
+            np.linspace(0, len(pp) / sampling_rate, len(pp)),
+            pp['y'],
+            color='orange')
+        #pl.plot(
+        #    np.linspace(0, len(pp) / sampling_rate, len(pp)),
+        #    pp['med_vel'],
+        #    color='black')
+    if events is not None:
+        for ev in events:
+            pl.axvspan(
+                ev['start_time'],
+                ev['end_time'],
+                color=colors[ev['label']],
+                alpha=0.3)
+            pl.text(ev['start_time'], 0, '{:.1f}'.format(ev['id']), color='red')
+    pl.show()
+
+
+def events2df(events):
+    import pandas as pd
+    return pd.DataFrame(events)
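A small sketch of how these helpers compose into a synthetic trial (the same recipe test_one_saccade uses; argument values are the defaults):

samp = mk_gaze_sample(sacc_dist=200, pso_dist=-40)  # fixation, saccade, PSO, fixation
data = expand_samp(samp, y=0.0)
# `data` can now be fed to EyegazeClassifier.preproc() and the classifier call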
80916  inputs/event_detector_1.1/1_13.csv  Normal file (diff suppressed, file too large)
76306  inputs/event_detector_1.1/1_2.csv  Normal file (diff suppressed, file too large)
85915  inputs/event_detector_1.1/1_3.csv  Normal file (diff suppressed, file too large)
77698  inputs/event_detector_1.1/1_4.csv  Normal file (diff suppressed, file too large)
80216  inputs/event_detector_1.1/1_7.csv  Normal file (diff suppressed, file too large)
76460  inputs/event_detector_1.1/1_9.csv  Normal file (diff suppressed, file too large)
94950  inputs/event_detector_1.1/2_1.csv  Normal file (diff suppressed, file too large)
87419  inputs/event_detector_1.1/2_2.csv  Normal file (diff suppressed, file too large)
81404  inputs/event_detector_1.1/2_3.csv  Normal file (diff suppressed, file too large)
101135  inputs/event_detector_1.1/2_4.csv  Normal file (diff suppressed, file too large)
76306  inputs/nystrom_target/1_2.csv  Normal file (diff suppressed, file too large)