In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import re
import shutil
import pandas as pd
import scipy.stats

import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot, despine
from spatial_maps.fields import (
    find_peaks, calculate_field_centers, separate_fields_by_laplace, 
    map_pass_to_unit_circle, calculate_field_centers, distance_to_edge_function, 
    which_field, compute_crossings)
from phase_precession import cl_corr
from spike_statistics.core import permutation_resampling
import matplotlib.mlab as mlab
import scipy.signal as ss
from scipy.interpolate import interp1d
from septum_mec.analysis.plotting import regplot
from skimage import measure
from tqdm.notebook import tqdm_notebook as tqdm
tqdm.pandas()

import pycwt
09:52:58 [I] klustakwik KlustaKwik2 version 0.2.6
/home/mikkel/.virtualenvs/expipe/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
  return f(*args, **kwds)
/home/mikkel/.virtualenvs/expipe/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
  return f(*args, **kwds)
/home/mikkel/.virtualenvs/expipe/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
  return f(*args, **kwds)
In [3]:
# Analysis parameters for this notebook.
max_speed = 1 # m/s; upper edge of the speed bins built in process(), also stored in the result attrs
min_speed = 0.02 # m/s; lower edge of the speed bins built in process(), also stored in the result attrs
position_sampling_rate = 1000 # for interpolation
position_low_pass_frequency = 6 # for low pass filtering of position

# Forwarded to dp.Data below; presumably arena size and spatial bin size in metres -- TODO confirm
box_size = [1.0, 1.0]
bin_size = 0.02

# Width of one speed bin in process(); same unit as min_speed / max_speed
speed_binsize = 0.02

# Forwarded to dp.Data below. NOTE(review): their exact meaning is defined in
# septum_mec.analysis.data_processing and is not visible here -- confirm there.
stim_mask = True
baseline_duration = 600
In [4]:
# Shared data accessor; the functions below call its lfp(), tracking() and
# stim_times() methods with an action id.
data_loader = dp.Data(
    position_sampling_rate=position_sampling_rate, 
    position_low_pass_frequency=position_low_pass_frequency,
    box_size=box_size, bin_size=bin_size, 
    stim_mask=stim_mask, baseline_duration=baseline_duration
)
In [5]:
# Open the expipe project that holds the recordings and analysis actions.
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions

# Local output directory for this notebook, with one sub-directory per artefact type.
output_path = pathlib.Path("output", "lfp-speed-stim")
for subdirectory in ("statistics", "figures"):
    (output_path / subdirectory).mkdir(parents=True, exist_ok=True)
In [6]:
# Read the session table stored under the 'identify-neurons' action.
identify_neurons = actions['identify-neurons']
sessions = pd.read_csv(identify_neurons.data_path('sessions'))
In [7]:
# Expand every session into one record per channel group (0..7), so that each
# record can be processed independently.
channel_groups = []
for _, session in sessions.iterrows():
    session_record = session.to_dict()
    for channel_group in range(8):
        channel_groups.append({**session_record, 'channel_group': channel_group})
In [8]:
# Replace the session table with the expanded one (one row per session and channel group).
# Note: this rebinds `sessions`, so the cell above must not be re-run after this one.
sessions = pd.DataFrame(channel_groups)
In [9]:
def signaltonoise(a, axis=0, ddof=0):
    """Return the signal-to-noise ratio (mean divided by standard deviation).

    Parameters
    ----------
    a : array_like
        Input data.
    axis : int or None, optional
        Axis along which the mean and standard deviation are computed.
    ddof : int, optional
        Delta degrees of freedom passed to ``std``.

    Returns
    -------
    np.ndarray
        ``mean / std`` along ``axis``; entries whose standard deviation is
        exactly zero are set to 0.
    """
    a = np.asanyarray(a)
    m = a.mean(axis)
    sd = a.std(axis=axis, ddof=ddof)
    # The quotient is evaluated for every entry before np.where selects, so
    # silence the divide-by-zero / invalid warnings for the masked entries.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = m / sd
    return np.where(sd == 0, 0, ratio)


def remove_artifacts(anas, spikes=None, width=500, threshold=2, sampling_rate=None, fillval=0):
    """Blank out high-amplitude artifacts in an analog signal.

    Every sample whose absolute value exceeds ``threshold`` marks an artifact.
    For each such sample index ``idx`` the samples ``[idx - width, idx + width)``
    of that channel are overwritten with ``fillval``. The window is clipped to
    the signal boundaries, so artifacts near the start or end of the recording
    are handled as well. When ``spikes`` is given, spike times lying strictly
    inside any blanked window (of any channel) are discarded too.

    Parameters
    ----------
    anas : array_like, shape (n_samples,) or (n_samples, n_channels)
        Analog signal. Must expose ``.shape``; if ``sampling_rate`` is not
        given it must also expose ``.sampling_rate.magnitude``.
    spikes : array_like or None, optional
        Spike times in the same unit as ``1 / sampling_rate``.
    width : int, optional
        Half-width of the blanked window, in samples.
    threshold : float, optional
        Absolute amplitude above which a sample counts as an artifact.
    sampling_rate : float or None, optional
        Sampling rate used to build the time axis.
    fillval : scalar, optional
        Value written into blanked samples.

    Returns
    -------
    anas : np.ndarray, shape (n_samples, n_channels)
        Cleaned copy of the signal (the input is not modified).
    times : np.ndarray, shape (n_samples,)
        Sample times, ``arange(n_samples) / sampling_rate``.
    spikes : np.ndarray
        Only returned when ``spikes`` was given: the remaining spike times,
        additionally restricted to ``spikes <= times[-1]``.
    """
    if not sampling_rate:
        sampling_rate = anas.sampling_rate.magnitude
    times = np.arange(anas.shape[0]) / sampling_rate
    anas = np.array(anas)  # work on a plain copy
    if anas.ndim == 1:
        anas = np.reshape(anas, (anas.size, 1))
    assert len(times) == anas.shape[0]
    nchan = anas.shape[1]
    last_index = len(times) - 1
    if spikes is not None:
        spikes = np.array(spikes)
    for ch in range(nchan):
        # Artifact samples are located once, before any blanking of this channel.
        idxs, = np.where(np.abs(anas[:, ch]) > threshold)
        for idx in idxs:
            # Clamp at 0: a negative start would wrap around to the end of the
            # recording and leave the artifact (and nearby spikes) untouched.
            start = max(idx - width, 0)
            stop = idx + width
            if spikes is not None:
                t0 = times[start]
                t1 = times[min(stop, last_index)]
                inside_window = (spikes > t0) & (spikes < t1)
                spikes = spikes[~inside_window]
            anas[start:stop, ch] = fillval
    if spikes is not None:
        spikes = spikes[spikes <= times[-1]]
        return anas, times, spikes
    else:
        return anas, times
    
def find_theta_peak(p, f, f1, f2):
    """Find the strongest spectral peak inside the open band ``(f1, f2)``.

    Parameters
    ----------
    p : np.ndarray
        Power spectrum.
    f : np.ndarray
        Frequencies corresponding to ``p``.
    f1, f2 : float
        Lower and upper band limits (both exclusive).

    Returns
    -------
    (frequency, power)
        Frequency and power at the highest peak within the band, indexed with
        whatever ``find_peaks`` (imported from ``spatial_maps.fields``) returns.
        ``(nan, nan)`` is returned when ``p`` is all-NaN, when no frequency
        falls inside the band, or when no peak is found.
    """
    if np.all(np.isnan(p)):
        return np.nan, np.nan
    mask = (f > f1) & (f < f2)
    if not np.any(mask):
        # Nothing to search; previously this ended in an argmax-of-empty error.
        return np.nan, np.nan
    p_m = p[mask]
    f_m = f[mask]
    peaks = find_peaks(p_m)
    if len(peaks) == 0:
        return np.nan, np.nan
    idx = np.argmax(p_m[peaks])
    return f_m[peaks[idx]], p_m[peaks[idx]]
In [10]:
def zscore(a):
    """Standardise ``a`` to zero mean and unit (population) standard deviation.

    The mean and standard deviation are taken over all elements of ``a``.
    """
    centered = a - a.mean()
    return centered / a.std()
In [11]:
def compute_stim_freq(action_id):
    """Return the mean stimulation frequency of an action.

    The frequency is the reciprocal of the mean interval between consecutive
    stimulation times from ``data_loader.stim_times(action_id)`` (so it is in
    Hz if those times are in seconds -- presumably the case, confirm against
    the data loader).

    Returns
    -------
    float or None
        ``None`` when the action has no stimulation times, or fewer than two,
        in which case no interval (and hence no frequency) can be computed.
    """
    stim_times = data_loader.stim_times(action_id)
    if stim_times is None:
        return None
    stim_times = np.array(stim_times)
    if stim_times.size < 2:
        # np.diff would be empty and the mean NaN, which breaks callers that
        # build a frequency range from the result.
        return None
    return 1 / np.mean(np.diff(stim_times))
In [12]:
# Exdir file holding one group per "<action>-<channel_group>"; process() skips
# names that are already present in it.
output = exdir.File(output_path / 'results')

# Morlet mother wavelet (constructed with argument 80) used by the continuous
# wavelet transform in process().
mother = pycwt.Morlet(80)
# NOTE(review): NFFT is not referenced by process(), which passes NFFT=4000 to
# mlab.psd -- confirm whether this constant is still needed.
NFFT = 2056

def process(row):
    """Compute theta frequency and power versus running speed for one channel group.

    For the given session row the function
      1. loads the LFP of ``row.channel_group`` and blanks artifacts,
      2. picks the channel with the largest PSD peak between 6 and 10 Hz,
      3. runs a Morlet continuous wavelet transform on the z-scored signal in a
         +-2 Hz band around the stimulation frequency,
      4. derives the instantaneous peak frequency and mean band power,
      5. averages both within speed bins and correlates both with speed,
      6. stores everything in the module-level ``output`` exdir file under the
         group ``"<action>-<channel_group>"``.

    Nothing is computed when the group already exists or when the action has
    no stimulation frequency. The result group is created only after all
    computations have succeeded, so a failure does not leave an empty group
    behind that later runs would mistake for a finished result.

    Parameters
    ----------
    row : pandas.Series
        Row of ``sessions`` with at least ``action`` and ``channel_group``.

    Returns
    -------
    None
    """
    name = row['action'] + '-' + str(row['channel_group'])
    # Cached results are detected before any data is loaded.
    if name in output:
        return

    stim_freq = compute_stim_freq(row['action'])
    if stim_freq is None:
        return

    # Frequency band analysed by the wavelet transform.
    flim = [stim_freq - 2, stim_freq + 2]

    lfp = data_loader.lfp(row.action, row.channel_group)
    sample_rate = lfp.sampling_rate.magnitude
    sampling_period = 1 / sample_rate
    tracking = data_loader.tracking(row.action)
    t, speed = tracking.get('t'), tracking.get('v')
    cleaned_lfp, times = remove_artifacts(lfp)
    # Resample running speed onto the LFP time axis.
    speed = interp1d(t, speed, bounds_error=False, fill_value='extrapolate')(times)

    # Select the channel with the strongest 6-10 Hz peak in its power spectrum.
    peak_amp = {}
    for i, ch in enumerate(cleaned_lfp.T):
        pxx, psd_freqs = mlab.psd(ch, Fs=sample_rate, NFFT=4000)
        f, p = find_theta_peak(pxx, psd_freqs, 6, 10)
        peak_amp[i] = p
    theta_channel = max(peak_amp, key=lambda x: peak_amp[x])
    signal = zscore(cleaned_lfp[:, theta_channel])

    freqs = np.arange(*flim, .1)
    wave, scales, freqs, coi, fft, fftfreqs = pycwt.cwt(
        signal, sampling_period, freqs=freqs, wavelet=mother)

    power = (np.abs(wave)) ** 2
    power /= scales[:, None] #rectify the power spectrum according to the suggestions proposed by Liu et al. (2007)

    # Per sample: frequency of maximal power, and mean power across the band.
    theta_freq = np.asarray(freqs)[np.argmax(power, axis=0)]
    theta_power = np.mean(power, axis=0)

    # With right=True, index i collects speed_bins[i-1] < speed <= speed_bins[i];
    # index 0 collects speed <= min_speed, and samples above the last edge
    # (index len(speed_bins)) are not averaged.
    speed_bins = np.arange(min_speed, max_speed + speed_binsize, speed_binsize)
    ia = np.digitize(speed, bins=speed_bins, right=True)
    mean_freq = np.full_like(speed_bins, np.nan)
    mean_power = np.full_like(speed_bins, np.nan)
    for i in range(len(speed_bins)):
        in_bin = ia == i
        # Empty bins stay NaN without triggering "Mean of empty slice" warnings.
        if in_bin.any():
            mean_freq[i] = np.mean(theta_freq[in_bin])
            mean_power[i] = np.mean(theta_power[in_bin])

    # Pearson correlations over all samples (min_speed / max_speed are not applied here).
    freq_score = np.corrcoef(speed, theta_freq)[0,1]
    power_score = np.corrcoef(speed, theta_power)[0,1]

    results = output.require_group(name)
    results.attrs = {
        'freq_score': float(freq_score),
        'sample_rate': float(sample_rate),
        'power_score': float(power_score),
        'action': row['action'],
        'channel_group': int(row['channel_group']),
        'max_speed': max_speed,
        'min_speed': min_speed,
        'position_low_pass_frequency': position_low_pass_frequency
    }

    results.create_dataset('wavelet_power', data=power)
    results.create_dataset('wavelet_freqs', data=freqs)
    results.create_dataset('theta_freq', data=theta_freq)
    results.create_dataset('theta_power', data=theta_power)
    results.create_dataset('speed', data=speed)
    results.create_dataset('mean_freq', data=mean_freq)
    results.create_dataset('mean_power', data=mean_power)
    results.create_dataset('speed_bins', data=speed_bins)
In [13]:
# Run process() on every session/channel-group row with a tqdm progress bar;
# the trailing semicolon suppresses the returned Series of None values.
sessions.progress_apply(process, axis=1);
/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)

Store results in Expipe action

In [14]:
# Expipe action that will hold the results of this notebook.
action = project.require_action("lfp_speed_stim")
In [ ]:
# Register the results file on the action and copy the whole local output directory into it.
# NOTE(review): the file is opened above as output_path / 'results'; presumably exdir
# appends the ".exdir" suffix so that it matches the name registered here -- confirm.
action.data["results"] = "results.exdir"
copy_tree(output_path, str(action.data_path()))
In [ ]:
# Store this notebook alongside the results in the action.
septum_mec.analysis.registration.store_notebook(action, "10_lfp_speed_stim.ipynb")
In [ ]: