%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
%matplotlib inline
import spatial_maps as sp
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
import expipe
import os
import pathlib
import scipy
import scipy.signal
import numpy as np
import exdir
import pandas as pd
import optogenetics as og
import quantities as pq
import shutil
from distutils.dir_util import copy_tree
import elephant as el
import neo
from scipy.signal import find_peaks
from scipy.interpolate import interp1d
from matplotlib import mlab
from tqdm.notebook import tqdm
tqdm.pandas()  # enable DataFrame.progress_apply used below
data_loader = dp.Data()
actions = data_loader.actions
project = data_loader.project
output = pathlib.Path('output/stimulus-spike-lfp-response-other-tetrode')
(output / 'data').mkdir(parents=True, exist_ok=True)
identify_neurons = actions['identify-neurons']
# sessions = pd.read_csv(identify_neurons.data_path('sessions'))
units = pd.read_csv(identify_neurons.data_path('units'))
def get_lim(action_id):
stim_times = data_loader.stim_times(action_id)
if stim_times is None:
return [0, np.inf]
stim_times = np.array(stim_times)
return [stim_times.min(), stim_times.max()]
def get_mask(lfp, lim):
return (lfp.times >= lim[0]) & (lfp.times <= lim[1])
def zscore(a):
return (a - a.mean(0)) / a.std(0)
def compute_stim_freq(action_id):
stim_times = data_loader.stim_times(action_id)
if stim_times is None:
return np.nan
stim_times = np.array(stim_times)
return 1 / np.mean(np.diff(stim_times))
def signaltonoise(a, axis=0, ddof=0):
a = np.asanyarray(a)
m = a.mean(axis)
sd = a.std(axis=axis, ddof=ddof)
return np.where(sd == 0, 0, m / sd)
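# Quick sanity check on synthetic data (not from the recordings). This
# replicates the signal-to-noise ratio removed from scipy.stats: with axis=0
# the ratio is computed per channel, and a channel whose mean is large
# relative to its spread wins -- this is what later selects the best LFP
# channel in prepare_spike_lfp.
_rng = np.random.RandomState(0)
_demo = np.column_stack([_rng.randn(1000), 5 + _rng.randn(1000)])
signaltonoise(_demo)  # roughly [0, 5]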
def compute_energy(p, f, f1, f2):
if np.isnan(f1) or np.all(np.isnan(p)):
return np.nan
mask = (f > f1) & (f < f2)
df = f[1] - f[0]
return np.sum(p[mask]) * df
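# Worked example on synthetic input: compute_energy approximates the
# integral of p over (f1, f2) with a Riemann sum, so a flat spectrum of
# height 2 should give close to 2 * (10 - 6) = 8 over the theta band.
_f = np.linspace(0, 50, 501)
compute_energy(np.full_like(_f, 2.0), _f, 6, 10)  # ~7.8 with 0.1 Hz bins and exclusive band edges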
def find_theta_peak(p, f, f1, f2):
if np.all(np.isnan(p)):
return np.nan, np.nan
mask = (f > f1) & (f < f2)
p_m = p[mask]
f_m = f[mask]
peaks, _ = find_peaks(p_m)
    if len(peaks) == 0:  # no local maximum in the band, avoid argmax on empty
        return np.nan, np.nan
    idx = np.argmax(p_m[peaks])
return f_m[peaks[idx]], p_m[peaks[idx]]
def compute_half_width(p, f, m_p, m_f):
if np.isnan(m_p):
return np.nan, np.nan
m_p_half = m_p / 2
half_p = p - m_p_half
idx_f = np.where(f <= m_f)[0].max()
idxs_p1, = np.where(np.diff(half_p[:idx_f + 1] > 0) == 1)
if len(idxs_p1) == 0:
return np.nan, np.nan
m1 = idxs_p1.max()
idxs_p2, = np.where(np.diff(half_p[idx_f:] > 0) == 1)
    if len(idxs_p2) == 0:  # guard the downward crossing like the upward one
        return np.nan, np.nan
    m2 = idxs_p2.min() + idx_f
assert p[m1] < m_p_half < p[m1+1], (p[m1], m_p_half, p[m1+1])
assert p[m2] > m_p_half > p[m2+1], (p[m2], m_p_half, p[m2+1])
f1 = interp1d([half_p[m1], half_p[m1 + 1]], [f[m1], f[m1 + 1]])(0)
f2 = interp1d([half_p[m2], half_p[m2 + 1]], [f[m2], f[m2 + 1]])(0)
return f1, f2
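# Sketch of the half-width logic on a synthetic Gaussian peak centred at
# 8 Hz with sigma = 1: the FWHM of a Gaussian is 2*sqrt(2*ln(2))*sigma,
# about 2.355, so the two crossings should come out near 6.82 and 9.18 Hz.
_f = np.linspace(0, 20, 2001)
_p = np.exp(-0.5 * (_f - 8)**2)
_m_f, _m_p = find_theta_peak(_p, _f, 6, 10)
compute_half_width(_p, _f, _m_p, _m_f)  # ~(6.82, 9.18)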
# p = np.load('debug_p.npy')
# f = np.load('debug_f.npy')
# compute_half_width(p, f, 0.01038941, 30.30187709636872)
# plt.plot(f, p)
# plt.xlim(29.9,30.6)
def compute_stim_peak(p, f, s_f):
if np.isnan(s_f):
return np.nan
return interp1d(f, p)(s_f)
def compute_spike_lfp_coherence(anas, sptr, NFFT):
    sigs, freqs = el.sta.spike_field_coherence(anas, sptr, nperseg=NFFT)
return sigs, freqs
def butter_bandpass(lowcut, highcut, fs, order=5):
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
b, a = scipy.signal.butter(order, [low, high], btype='band')
return b, a
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
b, a = butter_bandpass(lowcut, highcut, fs, order=order)
y = scipy.signal.filtfilt(b, a, data)
return y
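# Minimal filter sketch on a synthetic two-tone signal (assumed 1 kHz
# sample rate, as in the LFP cells below): after band-passing at 6-10 Hz
# the 8 Hz component survives while the 30 Hz component is suppressed.
_fs = 1000
_t = np.arange(0, 2, 1 / _fs)
_x = np.sin(2 * np.pi * 8 * _t) + np.sin(2 * np.pi * 30 * _t)
_y = butter_bandpass_filter(_x, 6, 10, fs=_fs, order=3)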
# def compute_spike_phase_func(lfp, times, return_degrees=False):
# x_a = hilbert(lfp)
# x_phase = np.angle(x_a)
# if return_degrees:
# x_phase = x_phase * 180 / np.pi
# return interp1d(times, x_phase)
def vonmises_kde(data, kappa=100, n_bins=100):
    from scipy.special import i0
    bins = np.linspace(-np.pi, np.pi, n_bins)
    # sum one von Mises kernel per phase sample, then normalise to unit area
    kde = np.exp(kappa * np.cos(bins[:, None] - data[None, :])).sum(1) / (2 * np.pi * i0(kappa))
    kde /= np.trapz(kde, x=bins)
    return bins, kde
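# Sanity check with synthetic phases clustered around pi/2: the density
# should peak near pi/2 and, by construction of the normalisation above,
# integrate to 1 over the circle.
_phases = np.pi / 2 + 0.1 * np.random.RandomState(0).randn(500)
_bins, _kde = vonmises_kde(_phases, kappa=100)
np.trapz(_kde, x=_bins), _bins[np.argmax(_kde)]  # (1.0, ~pi/2)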
def spike_phase_score(phase_bins, density):
    import pycircstat as pc
ang = pc.mean(phase_bins, w=density)
vec_len = pc.resultant_vector_length(phase_bins, w=density)
# ci_lim = pc.mean_ci_limits(head_angle_bins, w=rate)
return ang, vec_len
def compute_clean_lfp(anas, width=500, threshold=2):
    anas = np.array(anas)
    idxs, = np.where(abs(anas) > threshold)
    for idx in idxs:
        # clamp to the start of the array; a negative start index would
        # silently skip the blanking. TODO AR model prediction
        anas[max(idx - width, 0):idx + width] = 0
    return anas, idxs
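# Sketch of the artefact blanking on synthetic data: one large deflection
# at sample 5000 exceeds the threshold, and the +/- `width` samples around
# it are zeroed out.
_sig = np.random.RandomState(0).randn(10000) * 0.1
_sig[5000] = 10
_clean, _idxs = compute_clean_lfp(_sig, width=500, threshold=2)
assert np.all(_clean[4500:5500] == 0) and len(_idxs) == 1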
def compute_clean_spikes(spikes, idxs, times, width=500):
    for idx in idxs:
        t0 = times[max(idx - width, 0)]  # clamp to the start of the recording
        stop = min(idx + width, len(times) - 1)
        t1 = times[stop]
mask = (spikes > t0) & (spikes < t1)
spikes = spikes[~mask]
spikes = spikes[spikes <= times[-1]]
return spikes
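# Companion sketch, reusing _idxs from the cell above (assuming the same
# 1 kHz sampling): spikes that fall inside the blanked window around the
# artefact are dropped, the rest are kept.
_times = np.arange(10000) / 1000
compute_clean_spikes(np.array([1.0, 4.8, 5.2, 8.0]), _idxs, _times)  # keeps 1.0 and 8.0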
def prepare_spike_lfp(anas, sptr, t_start, t_stop):
t_start = t_start * pq.s if t_start is not None else 0 * pq.s
sampling_rate = anas.sampling_rate
units = anas.units
times = anas.times
    if t_stop is not None:  # t_start is always set just above
t_stop = t_stop * pq.s
mask = (times > t_start) & (times < t_stop)
anas = np.array(anas)[mask,:]
times = times[mask]
    # take the channel with the best SNR among the other tetrodes
best_channel = np.argmax(signaltonoise(anas))
# best_channel = np.random.choice(anas.shape[1])
cleaned_anas, idxs = compute_clean_lfp(anas[:, best_channel])
cleaned_anas = neo.AnalogSignal(
signal=cleaned_anas * units, sampling_rate=sampling_rate, t_start=t_start
)
spike_units = sptr.units
spike_times = sptr.times
spike_times = compute_clean_spikes(spike_times, idxs, times)
sptr = neo.SpikeTrain(
spike_times[(spike_times > t_start) & (spike_times < times[-1])], units=spike_units,
t_start=t_start, t_stop=times[-1]
)
return cleaned_anas, sptr, best_channel
def compute_spike_phase_func(lfp, times, return_degrees=False):
from scipy.fftpack import next_fast_len
x_a = scipy.signal.hilbert(
lfp, next_fast_len(len(lfp)))[:len(lfp)]
# x_a = hilbert(lfp)
x_phase = np.angle(x_a, deg=return_degrees)
return interp1d(times, x_phase)
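# Sketch of the phase interpolator on a pure 8 Hz sine (synthetic, 1 kHz):
# the Hilbert phase of sin(w*t) is w*t - pi/2, so times at successive peaks
# of the sine should all map to a phase of about 0 rad.
_t = np.arange(0, 2, 1 / 1000)
_phase_at = compute_spike_phase_func(np.sin(2 * np.pi * 8 * _t), _t)
_phase_at(0.25 / 8 + 0.125 * np.arange(4))  # ~[0, 0, 0, 0]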
def compute_spike_phase(lfp, spikes, flim=[6,10]):
sample_rate = lfp.sampling_rate.magnitude
    # spikes can outlast the LFP recording, so clip them to the LFP time range
times = np.arange(lfp.shape[0]) / sample_rate
spikes = np.array(spikes)
spikes = spikes[(spikes > times.min()) & (spikes < times.max())]
filtered_lfp = butter_bandpass_filter(
lfp.magnitude.ravel(), *flim, fs=sample_rate, order=3)
spike_phase_func = compute_spike_phase_func(filtered_lfp, times)
return spike_phase_func(spikes), filtered_lfp
plt.figure(figsize=(16,9))
lfp = data_loader.lfp('1833-200619-2', 6)
# lfp = data_loader.lfp('1834-220319-3', 6)
# lfp = data_loader.lfp('1849-010319-4', 6)
times = np.arange(lfp.shape[0]) / lfp.sampling_rate.magnitude
clean_lfp, _ = compute_clean_lfp(lfp.magnitude[:, 0], threshold=2)
plt.plot(times, lfp[:, 0])
plt.plot(times, clean_lfp)
plt.figure(figsize=(16, 9))
plt.psd(lfp[:, 0].ravel(), Fs=1000, NFFT=10000)
plt.psd(clean_lfp, Fs=1000, NFFT=10000)
plt.xlim(0, 100)
# plt.figure(figsize=(16,9))
# plt.plot(times,lfp[:,0])
# # plt.plot(clean_lfp*100)
# plt.plot(times[:-1], np.diff(lfp[:,0].magnitude.ravel()))
# plt.xlim(64.5,65.5)
# # plt.ylim(-250,250)
drive_0_channel_groups = [0, 1, 2, 3]
drive_1_channel_groups = [4, 5, 6, 7]
# action_id_0, channel_0, unit_0 = '1833-200619-1', 6, 163
# action_id_1, channel_1, unit_1 = '1833-200619-2', 6, 28
action_id_0, channel_0, unit_0 = '1834-220319-3', 2, 46
action_id_1, channel_1, unit_1 = '1834-220319-4', 2, 60
# lfp_0 = data_loader.lfp(action_id_0, channel_0)
# lfp_1 = data_loader.lfp(action_id_1, channel_1)
# select best channel among other tetrodes in same drive
if channel_0 in drive_0_channel_groups:
lfps = []
for ch in drive_0_channel_groups:
if channel_0 != ch:
lfps.append(data_loader.lfp(action_id_0, ch))
elif channel_0 in drive_1_channel_groups:
lfps = []
for ch in drive_1_channel_groups:
if channel_0 != ch:
lfps.append(data_loader.lfp(action_id_0, ch))
# merge LFPs of the other tetrodes into a single AnalogSignal, mirroring process() below
lfps_0 = neo.AnalogSignal(
    signal=np.hstack(lfps).as_array() * lfps[0].units,
    sampling_rate=lfps[0].sampling_rate, t_start=lfps[0].t_start)
if channel_1 in drive_0_channel_groups:
lfps = []
for ch in drive_0_channel_groups:
if channel_1 != ch:
lfps.append(data_loader.lfp(action_id_1, ch))
elif channel_1 in drive_1_channel_groups:
lfps = []
for ch in drive_1_channel_groups:
if channel_1 != ch:
            lfps.append(data_loader.lfp(action_id_1, ch))
lfps_1 = neo.AnalogSignal(
    signal=np.hstack(lfps).as_array() * lfps[0].units,
    sampling_rate=lfps[0].sampling_rate, t_start=lfps[0].t_start)
sample_rate_0 = lfps_0.sampling_rate
sample_rate_1 = lfps_1.sampling_rate
lim_0 = get_lim(action_id_0)
lim_1 = get_lim(action_id_1)
sptrs_0 = data_loader.spike_trains(action_id_0, channel_0)
sptrs_1 = data_loader.spike_trains(action_id_1, channel_1)
cleaned_lfps_0, sptr_0, best_channel_0 = prepare_spike_lfp(lfps_0, sptrs_0[unit_0], *lim_0)
cleaned_lfps_1, sptr_1, best_channel_1 = prepare_spike_lfp(lfps_1, sptrs_1[unit_1], *lim_1)
coher_0, freq_0 = compute_spike_lfp_coherence(cleaned_lfps_0, sptr_0, 4096)
coher_1, freq_1 = compute_spike_lfp_coherence(cleaned_lfps_1, sptr_1, 4096)
spike_phase_0, filtered_lfp_0 = compute_spike_phase(cleaned_lfps_0, sptrs_0[unit_0], flim=[6,10])
spike_phase_1, filtered_lfp_1 = compute_spike_phase(cleaned_lfps_1, sptrs_1[unit_1], flim=[6,10])
spike_phase_1_stim, filtered_lfp_1_stim = compute_spike_phase(cleaned_lfps_1, sptrs_1[unit_1], flim=[10.5,11.5])
# spike_phase_1_stim, filtered_lfp_1_stim = compute_spike_phase(cleaned_lfps_1, sptrs_1[unit_1], flim=[29.5,30.5])
plt.figure()
plt.plot(freq_0, coher_0.ravel())
plt.plot(freq_1, coher_1.ravel())
plt.xlim(0,20)
plt.figure()
bins_0, kde_0 = vonmises_kde(spike_phase_0, 100)
ang_0, vec_len_0 = spike_phase_score(bins_0, kde_0)
plt.polar(bins_0, kde_0, color='b')
plt.polar([ang_0, ang_0], [0, vec_len_0], color='b')
bins_1, kde_1 = vonmises_kde(spike_phase_1, 100)
ang_1, vec_len_1 = spike_phase_score(bins_1, kde_1)
plt.polar(bins_1, kde_1, color='r')
plt.polar([ang_1, ang_1], [0, vec_len_1], color='r')
bins_1_stim, kde_1_stim = vonmises_kde(spike_phase_1_stim, 100)
ang_1_stim, vec_len_1_stim = spike_phase_score(bins_1_stim, kde_1_stim)
plt.polar(bins_1_stim, kde_1_stim, color='k')
plt.polar([ang_1_stim, ang_1_stim], [0, vec_len_1_stim], color='k')
lfps_0.units
lfps_0.shape, lfps_1.shape
# TODO fix artefact stuff from phase precession
NFFT = 8192
theta_band_f1, theta_band_f2 = 6, 10
drive_0_channel_groups = [0, 1, 2, 3]
drive_1_channel_groups = [4, 5, 6, 7]
coherence_data, frequency_data = {}, {}
theta_kde_data, theta_bins_data = {}, {}
stim_kde_data, stim_bins_data = {}, {}
def process(row):
action_id = row['action']
channel_group = row['channel_group']
unit_name = row['unit_name']
name = f'{action_id}_{channel_group}_{unit_name}'
# select best channel among other tetrodes in same drive
if channel_group in drive_0_channel_groups:
lfps = []
for ch in drive_0_channel_groups:
if channel_group != ch:
lfps.append(data_loader.lfp(action_id, ch))
elif channel_group in drive_1_channel_groups:
lfps = []
for ch in drive_1_channel_groups:
if channel_group != ch:
lfps.append(data_loader.lfp(action_id, ch))
    # merge LFPs of the other tetrodes into a single AnalogSignal
lfp_arrays = np.hstack(lfps).as_array()
lfp = neo.AnalogSignal(signal=lfp_arrays * lfps[0].units, sampling_rate=lfps[0].sampling_rate,
t_start=lfps[0].t_start)
sptr = data_loader.spike_train(action_id, channel_group, unit_name)
lim = get_lim(action_id)
cleaned_lfp, sptr, best_channel = prepare_spike_lfp(lfp, sptr, *lim)
p_xys, freq = compute_spike_lfp_coherence(cleaned_lfp, sptr, NFFT=NFFT)
p_xy = p_xys.magnitude.ravel()
freq = freq.magnitude
theta_f, theta_p_max = find_theta_peak(p_xy, freq, theta_band_f1, theta_band_f2)
theta_energy = compute_energy(p_xy, freq, theta_band_f1, theta_band_f2) # theta band 6 - 10 Hz
theta_half_f1, theta_half_f2 = compute_half_width(p_xy, freq, theta_p_max, theta_f)
theta_half_width = theta_half_f2 - theta_half_f1
    theta_half_energy = compute_energy(p_xy, freq, theta_half_f1, theta_half_f2) # energy within the half-width band around the theta peak
theta_spike_phase, _ = compute_spike_phase(cleaned_lfp, sptr, flim=[theta_band_f1, theta_band_f2])
theta_bins, theta_kde = vonmises_kde(theta_spike_phase)
theta_ang, theta_vec_len = spike_phase_score(theta_bins, theta_kde)
theta_kde_data.update({name: theta_kde})
theta_bins_data.update({name: theta_bins})
# stim
stim_freq = compute_stim_freq(action_id)
stim_p_max = compute_stim_peak(p_xy, freq, stim_freq)
stim_half_f1, stim_half_f2 = compute_half_width(p_xy, freq, stim_p_max, stim_freq)
stim_half_width = stim_half_f2 - stim_half_f1
stim_energy = compute_energy(p_xy, freq, stim_half_f1, stim_half_f2)
if np.isnan(stim_freq):
stim_spike_phase, stim_bins, stim_kde, stim_ang, stim_vec_len = [np.nan] * 5
else:
stim_spike_phase, _ = compute_spike_phase(cleaned_lfp, sptr, flim=[stim_freq - .5, stim_freq + .5])
stim_bins, stim_kde = vonmises_kde(stim_spike_phase)
stim_ang, stim_vec_len = spike_phase_score(stim_bins, stim_kde)
stim_kde_data.update({name: stim_kde})
stim_bins_data.update({name: stim_bins})
coherence_data.update({name: p_xy})
    frequency_data.update({name: freq})
result = pd.Series({
'best_channel': best_channel,
'theta_freq': theta_f,
'theta_peak': theta_p_max,
'theta_energy': theta_energy,
'theta_half_f1': theta_half_f1,
'theta_half_f2': theta_half_f2,
'theta_half_width': theta_half_width,
'theta_half_energy': theta_half_energy,
'theta_ang': theta_ang,
'theta_vec_len': theta_vec_len,
'stim_freq': stim_freq,
'stim_p_max': stim_p_max,
'stim_half_f1': stim_half_f1,
'stim_half_f2': stim_half_f2,
'stim_half_width': stim_half_width,
'stim_energy': stim_energy,
'stim_ang': stim_ang,
'stim_vec_len': stim_vec_len
})
return result
results = units.merge(
units.progress_apply(process, axis=1),
left_index=True, right_index=True)
# coher, freqs = {}, {}
# for i, row in tqdm(units.iterrows(), total=len(units)):
# action_id = row['action']
# channel_group = row['channel_group']
# unit_name = row['unit_name']
# name = f'{action_id}_{channel_group}_{unit_name}'
# lfp = data_loader.lfp(action_id, channel_group) # TODO consider choosing strongest stim response
# sptr = data_loader.spike_train(action_id, channel_group, unit_name)
# lim = get_lim(action_id)
# p_xys, freq, clean_lfp = compute_spike_lfp(lfp, sptr, *lim, NFFT=NFFT)
# snls = signaltonoise(clean_lfp)
# best_channel = np.argmax(snls)
# snl = snls[best_channel]
# p_xy = p_xys[:,best_channel].magnitude
# freq = freq.magnitude
# coher.update({name: p_xy})
# freqs.update({name: freq})
pd.DataFrame(coherence_data).to_feather(output / 'data' / 'coherence.feather')
pd.DataFrame(frequency_data).to_feather(output / 'data' / 'freqs.feather')
pd.DataFrame(theta_kde_data).to_feather(output / 'data' / 'theta_kde.feather')
pd.DataFrame(theta_bins_data).to_feather(output / 'data' / 'theta_bins.feather')
pd.DataFrame(stim_kde_data).to_feather(output / 'data' / 'stim_kde.feather')
pd.DataFrame(stim_bins_data).to_feather(output / 'data' / 'stim_bins.feather')
action = project.require_action("stimulus-spike-lfp-response-other-tetrode")
action.modules['parameters'] = {
'NFFT': NFFT,
'theta_band_f1': theta_band_f1,
'theta_band_f2': theta_band_f2
}
action.data['results'] = 'results.csv'
results.to_csv(action.data_path('results'), index=False)
copy_tree(str(output), str(action.data_path()))
septum_mec.analysis.registration.store_notebook(action, "10-calculate-stimulus-spike-lfp-response-other-tetrode.ipynb")