%load_ext autoreload
%autoreload 2
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import re
import shutil
import pandas as pd
import scipy.stats
import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot, despine
from spatial_maps.fields import find_peaks, calculate_field_centers, separate_fields_by_laplace
from spike_statistics.core import permutation_resampling
from tqdm import tqdm_notebook as tqdm
from tqdm._tqdm_notebook import tqdm_notebook
tqdm_notebook.pandas()
# Resolve the expipe project and prepare output directories for the
# statistics tables and figures produced by this notebook.
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions
output_path = pathlib.Path("output") / "longitudinal-comparisons-gridcells"
(output_path / "statistics").mkdir(exist_ok=True, parents=True)
(output_path / "figures").mkdir(exist_ok=True, parents=True)
# Load precomputed per-unit statistics and the session/unit identification
# tables from earlier pipeline actions, joining them on 'action'.
statistics_action = actions['calculate-statistics']
identification_action = actions['identify-neurons']
sessions = pd.read_csv(identification_action.data_path('sessions'))
units = pd.read_csv(identification_action.data_path('units'))
session_units = pd.merge(sessions, units, on='action')
statistics_results = pd.read_csv(statistics_action.data_path('results'))
# Left merge keeps every identified unit even if it has no statistics row.
statistics = pd.merge(session_units, statistics_results, how='left')
statistics.head()
# 'unit_day' tags a unit with the recording day (the second '-'-separated
# token of the action id) so the same neuron can be tracked across the
# sessions recorded on that day.
statistics['unit_day'] = statistics.apply(lambda x: str(x.unit_idnum) + '_' + x.action.split('-')[1], axis=1)
# Merge in stimulus-response results (left join on the shared columns).
stim_response_action = actions['stimulus-response']
stim_response_results = pd.read_csv(stim_response_action.data_path('results'))
statistics = pd.merge(statistics, stim_response_results, how='left')
print('N cells:',statistics.shape[0])
# 95th-percentile shuffling quantiles per unit; the '_threshold' suffix
# distinguishes them from the measured values after the merge.
shuffling = actions['shuffling']
quantiles_95 = pd.read_csv(shuffling.data_path('quantiles_95'))
quantiles_95.head()
action_columns = ['action', 'channel_group', 'unit_name']
data = pd.merge(statistics, quantiles_95, on=action_columns, suffixes=("", "_threshold"))
# Spatial specificity: log10 ratio of in-field to out-of-field mean rate.
data['specificity'] = np.log10(data['in_field_mean_rate'] / data['out_field_mean_rate'])
data.head()
data.groupby('stimulated').count()['action']
# Grid-cell inclusion criterion: gridness and information rate must exceed
# the unit's own shuffling thresholds, plus absolute bounds on gridness
# and average rate.
query = (
    'gridness > gridness_threshold and '
    'information_rate > information_rate_threshold and '
    'gridness > .2 and '
    'average_rate < 25'
)
sessions_above_threshold = data.query(query)
print("Number of sessions above threshold", len(sessions_above_threshold))
print("Number of animals", len(sessions_above_threshold.groupby(['entity'])))
# "Once a grid cell": keep every recording belonging to a unit-day that
# passed the criterion in at least one of its sessions.
once_a_gridcell = statistics[statistics.unit_day.isin(sessions_above_threshold.unit_day.values)]
print("Number of gridcells", once_a_gridcell.unit_idnum.nunique())
print("Number of gridcell recordings", len(once_a_gridcell))
print("Number of animals", len(once_a_gridcell.groupby(['entity'])))
# Split the grid-cell recordings into the four protocol session types:
# baselines tagged Hz11/Hz30 and the matching 11 Hz / 30 Hz medial-septum
# (ms) stimulation sessions.
baseline_i = once_a_gridcell.query('baseline and Hz11')
stimulated_11 = once_a_gridcell.query('stimulated and frequency==11 and stim_location=="ms"')
baseline_ii = once_a_gridcell.query('baseline and Hz30')
stimulated_30 = once_a_gridcell.query('stimulated and frequency==30 and stim_location=="ms"')
print("Number of gridcells in baseline i sessions", len(baseline_i))
print("Number of gridcells in stimulated 11Hz ms sessions", len(stimulated_11))
print("Number of gridcells in baseline ii sessions", len(baseline_ii))
print("Number of gridcells in stimulated 30Hz ms sessions", len(stimulated_30))
# Analysis parameters passed to the data loader below.
# BUG FIX: max_speed and min_speed previously had trailing commas, which
# made them 1-tuples ((1,) and (0.02,)) instead of scalars.
max_speed = 1  # m/s, only used for speed score
min_speed = 0.02  # m/s, only used for speed score
position_sampling_rate = 100  # Hz, for interpolation of tracking
position_low_pass_frequency = 6  # Hz, for low-pass filtering of position
box_size = [1.0, 1.0]  # m, arena dimensions
bin_size = 0.02  # m, rate-map bin width
smoothing_low = 0.03  # rate-map smoothing width (narrow)
smoothing_high = 0.06  # rate-map smoothing width (wide)
stim_mask = True  # forwarded to dp.Data below
baseline_duration = 600  # s, forwarded to dp.Data below
# Project data loader configured with the parameters defined above;
# provides tracking, spike trains and rate maps per action/unit.
data_loader = dp.Data(
    position_sampling_rate=position_sampling_rate,
    position_low_pass_frequency=position_low_pass_frequency,
    box_size=box_size, bin_size=bin_size,
    stim_mask=stim_mask, baseline_duration=baseline_duration
)
def fftcorrelate2d(arr1, arr2, normalize=False, **kwargs):
    """Cross-correlate two 2D arrays via FFT convolution.

    Convolving ``arr1`` with ``arr2`` flipped along both axes is
    equivalent to cross-correlating them.  With ``normalize=True`` the
    inputs are mean-shifted and variance-scaled first, yielding a
    normalized cross-correlation
    (https://stackoverflow.com/questions/53436231).  Extra keyword
    arguments are forwarded to ``astropy.convolution.convolve_fft``.
    """
    from copy import copy
    from astropy.convolution import convolve_fft

    # Work on copies so the callers' arrays are never modified.
    a = copy(arr1)
    b = copy(arr2)
    if normalize:
        a_flat = a.ravel()
        b_flat = b.ravel()
        a = (a - np.mean(a_flat)) / (np.std(a_flat) * len(a_flat))
        b = (b - np.mean(b_flat)) / np.std(b_flat)
    # Flip the kernel along both axes so convolution computes correlation.
    flipped_b = np.flipud(np.fliplr(b))
    return convolve_fft(a, flipped_b, normalize_kernel=False, **kwargs)
def cross_correlation_distance(r1, r2):
    """Return (distance, angle) of the cross-correlogram peak closest to center.

    The two rate maps are cross-correlated, fields in the correlogram are
    segmented and their centers computed; the field center nearest the
    middle of the correlogram (zero spatial lag) gives the spatial shift
    between the maps.  Distance is in bins, angle in radians.
    """
    xcorr = fftcorrelate2d(r1, r2)
    field_labels = separate_fields_by_laplace(xcorr, threshold=0)
    centers = calculate_field_centers(xcorr, field_labels)
    # Express peak positions relative to the correlogram center (zero lag).
    relative = centers - np.array(r1.shape) / 2
    radii = np.linalg.norm(relative, axis=1)
    nearest = np.argmin(radii)
    return radii[nearest], np.arctan2(*relative[nearest])
# Collect per-unit session pairs for the four comparisons of interest.
# Index i of each results list corresponds to:
#   0: baseline I vs baseline II   1: baseline I vs stim I
#   2: baseline II vs stim II      3: baseline I vs stim II
results_xcorr = [[], [], [], []]
results_gridness = [[], [], [], []]
results_maxrate = [[], [], [], []]
results_avgrate = [[], [], [], []]
results_unit_name = [[], [], [], []]
results_unit_id = [[], [], [], []]
results_id_map = {}  # unit_id -> unit_idnum, for the plotting loops below
for nid, unit_sessions in once_a_gridcell.groupby('unit_id'):
    base_i = unit_sessions.query("baseline and Hz11")
    base_ii = unit_sessions.query("baseline and Hz30")
    stim_i = unit_sessions.query("frequency==11")
    stim_ii = unit_sessions.query("frequency==30")
    dfs = [(base_i, base_ii), (base_i, stim_i), (base_ii, stim_ii), (base_i, stim_ii)]
    for i, pair in enumerate(dfs):
        # zip truncates to the shorter side, so units missing one of the
        # two session types contribute nothing to that comparison.
        for (_, row_1), (_, row_2) in zip(pair[0].iterrows(), pair[1].iterrows()):
            rate_map_1 = data_loader.rate_map(
                row_1['action'], row_1['channel_group'], row_1['unit_name'], smoothing_low)
            rate_map_2 = data_loader.rate_map(
                row_2['action'], row_2['channel_group'], row_2['unit_name'], smoothing_low)
            # Spatial shift between the two maps plus the scalar statistics
            # of both sessions, stored as (first, second) tuples.
            results_xcorr[i].append(cross_correlation_distance(rate_map_1, rate_map_2))
            results_gridness[i].append((row_1.gridness, row_2.gridness))
            results_maxrate[i].append((row_1.max_rate, row_2.max_rate))
            results_avgrate[i].append((row_1.average_rate, row_2.average_rate))
            results_unit_name[i].append((
                f'{row_1.action}_{row_1.channel_group}_{row_1.unit_name}',
                f'{row_2.action}_{row_2.channel_group}_{row_2.unit_name}'))
            # Both rows must belong to the same unit by construction.
            assert row_1.unit_id == row_2.unit_id
            uid = row_2.unit_id
            idnum = row_1.unit_idnum
            results_id_map[uid] = idnum
            results_unit_id[i].append(idnum)
def session_id(row):
    """Map a recording row to its position in the four-session protocol.

    Parameters
    ----------
    row : object with boolean attributes ``baseline``, ``stimulated``,
        ``i`` and ``ii`` (e.g. a DataFrame row).

    Returns
    -------
    int
        0: baseline I, 1: stimulation I, 2: baseline II, 3: stimulation II.

    Raises
    ------
    ValueError
        If the row matches none of the four baseline/stimulated x I/II
        combinations.
    """
    # Guard clauses, checked in the same priority order as the original
    # if/elif chain.
    if row.baseline and row.i:
        return 0
    if row.stimulated and row.i:
        return 1
    if row.baseline and row.ii:
        return 2
    if row.stimulated and row.ii:
        return 3
    raise ValueError(
        'Row matches no protocol session: '
        f'baseline={row.baseline}, stimulated={row.stimulated}, '
        f'i={row.i}, ii={row.ii}')
# Assign each recording its protocol position (0-3); used below to index
# the subplot columns.
once_a_gridcell['session_id'] = once_a_gridcell.apply(session_id, axis=1)
# Global figure style for the remaining plots.
plt.rc('axes', titlesize=12)
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': (6, 4),
    'figure.dpi': 150
})
# One figure per neuron: rows are recording days, columns the four
# protocol sessions, cells show the smoothed rate map with titles
# "gridness max_rate average_rate".
for unit_id, id_num in results_id_map.items():
    sessions = once_a_gridcell.query(f'unit_id=="{unit_id}"')
    n_action = sessions.date.nunique()
    fig, axs = plt.subplots(n_action, 4, sharey=True, sharex=True, figsize=(8, n_action*4))
    sns.despine(left=True, bottom=True)
    fig.suptitle(f'Neuron {id_num}')
    if n_action == 1:
        # plt.subplots returns a 1-D axes array for a single row; wrap it
        # so the zip below always iterates rows of four axes.
        axs = [axs]
    waxs = None
    for ax, (date, rows) in zip(axs, sessions.groupby('date')):
        rows = rows.sort_values('session')
        entity = rows.iloc[0].entity
        ax[0].set_ylabel(f'{entity}-{date}')
        vmax = None
        for _, row in rows.iterrows():
            action_id = row['action']
            channel_id = row['channel_group']
            unit_name = row['unit_name']
            rate_map = data_loader.rate_map(action_id, channel_id, unit_name, smoothing_low)
            idx = row.session_id
            if vmax is None:
                # Fix the color scale to the first session of the day so
                # rates are visually comparable within the row.
                vmax = rate_map.max()
            ax[idx].imshow(rate_map, origin='lower', vmax=vmax)
            ax[idx].set_title(f'{row.gridness:.2f} {row.max_rate:.2f} {row.average_rate:.2f}')
            ax[idx].set_yticklabels([])
            ax[idx].set_xticklabels([])
    plt.tight_layout()
    fig.savefig(output_path / 'figures' / f'neuron_{id_num}_rate_map.png', bbox_inches='tight')
    fig.savefig(output_path / 'figures' / f'neuron_{id_num}_rate_map.svg', bbox_inches='tight')
    # waveforms
    # template = data_loader.template(action_id, channel_id, unit_name)
    # if waxs is None:
    #     wfig, waxs = plt.subplots(1, template.data.shape[0], sharey=True, sharex=True)
    # for i, wax in enumerate(waxs):
    #     wax.plot(template.data[i,:])
    #     if i > 0:
    #         ax.set_yticklabels([])
from scipy.interpolate import interp1d
# One figure per neuron: spike positions (red dots) overlaid on the
# animal's trajectory (grey), one column per protocol session.
for unit_id, id_num in results_id_map.items():
    sessions = once_a_gridcell.query(f'unit_id=="{unit_id}"')
    n_action = sessions.date.nunique()
    fig, axs = plt.subplots(n_action, 4, sharey=True, sharex=True, figsize=(8, n_action*4))
    sns.despine(left=True, bottom=True)
    fig.suptitle(f'Neuron {id_num}')
    if n_action == 1:
        # Single row: wrap so the zip below always sees rows of four axes.
        axs = [axs]
    waxs = None
    for ax, (date, rows) in zip(axs, sessions.groupby('date')):
        entity = rows.iloc[0].entity
        ax[0].set_ylabel(f'{entity}-{date}')
        for _, row in rows.iterrows():
            action_id = row['action']
            channel_id = row['channel_group']
            unit_name = row['unit_name']
            idx = row.session_id
            x, y, t, speed = map(data_loader.tracking(action_id).get, ['x', 'y', 't', 'v'])
            ax[idx].plot(x, y, 'k', alpha=0.3)
            spike_times = data_loader.spike_train(action_id, channel_id, unit_name)
            # Clip spikes to the tracked interval so interp1d does not
            # extrapolate outside the sample range.
            spike_times = spike_times[(spike_times > min(t)) & (spike_times < max(t))]
            x_spike = interp1d(t, x)(spike_times)
            y_spike = interp1d(t, y)(spike_times)
            ax[idx].set_xticks([])
            ax[idx].set_yticks([])
            ax[idx].scatter(x_spike, y_spike, marker='.', color=(0.7, 0.2, 0.2), s=1.5)
            ax[idx].set_title(f'{row.session}')
            ax[idx].set_yticklabels([])
            ax[idx].set_xticklabels([])
        for a in ax:
            # Square aspect so the arena is not distorted.
            a.set_aspect(1)
    plt.tight_layout()
    fig.savefig(
        output_path / 'figures' / f'neuron_{id_num}_spike_map.png',
        bbox_inches='tight', transparent=True)
    fig.savefig(
        output_path / 'figures' / f'neuron_{id_num}_spike_map.svg',
        bbox_inches='tight')
# Plot style plus the shared color and label scheme for the four
# comparisons used by all remaining figures.
plt.rc('axes', titlesize=12)
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': (6, 4),
    'figure.dpi': 150
})
# ColorBrewer-style qualitative palette, one color per comparison.
cmap = ['#1b9e77','#d95f02','#7570b3', '#e7298a']
labels = [
    'Baseline I vs baseline II',
    'Baseline I vs stim I',
    'Baseline II vs stim II',
    'Baseline I vs stim II'
]
msize = 9  # marker size shared by scatter plots and legend proxies
# Per-neuron change in gridness for each comparison: x is a running
# neuron index (assigned on first encounter), color is the comparison.
fig = plt.figure()
ticks = []
nuids = {}  # unit idnum -> x position
n = 0
for i, pairs in enumerate(results_gridness):
    for j, pair in enumerate(pairs):
        nuid = results_unit_id[i][j]
        if nuid not in nuids:
            nuids[nuid] = n
            n += 1
        plt.plot(
            nuids[nuid], np.diff(pair),
            color=cmap[i], marker='.', ls='none', markersize=msize)
# Thin vertical guide line per neuron.
for l in range(n):
    plt.axvline(l, color='k', lw=.1, alpha=.5)
from matplotlib.lines import Line2D
# Proxy artists for the legend, one entry per comparison color.
custom_lines = [
    Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
    for i, label in enumerate(labels)
]
plt.ylabel('Difference in gridness')
plt.xlabel('Neuron')
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
fig.savefig(output_path / 'figures' / 'neuron_gridness.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'neuron_gridness.svg', bbox_inches='tight')
# Scatter of gridness in each comparison's second session against its
# first (baseline), with the identity line for reference.
fig = plt.figure()
for i, pairs in enumerate(results_gridness):
    for j, pair in enumerate(pairs):
        plt.plot(*pair, color=cmap[i], marker='.', ls='none', markersize=msize)
# plt.scatter(*np.array(pairs).T, label=labels[i], color=cmap[i])
# plt.legend(bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
custom_lines = [
    Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
    for i, label in enumerate(labels)
]
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
plt.ylabel('Gridness')
plt.xlabel('Baseline gridness')
# Equal axis limits plus identity line: points below the diagonal lost
# gridness relative to baseline.
lim = [-.7, 1.35]
plt.ylim(lim)
plt.xlim(lim)
plt.plot(lim, lim, '--k', alpha=.5, lw=1)
fig.savefig(output_path / 'figures' / 'baseline_gridness_vs_other.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'baseline_gridness_vs_other.svg', bbox_inches='tight')
# Same scatter-vs-baseline layout as above, for max firing rate.
fig = plt.figure()
for i, pairs in enumerate(results_maxrate):
    for j, pair in enumerate(pairs):
        plt.plot(*pair, color=cmap[i], marker='.', ls='none', markersize=msize)
# plt.scatter(*np.array(pairs).T, label=labels[i], color=cmap[i])
# plt.legend(bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
custom_lines = [
    Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
    for i, label in enumerate(labels)
]
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
plt.ylabel('Max rate')
plt.xlabel('Baseline max rate')
lim = [-.7, 100]
plt.ylim(lim)
plt.xlim(lim)
plt.plot(lim, lim, '--k', alpha=.5, lw=1)
fig.savefig(output_path / 'figures' / 'baseline_max_rate_vs_other.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'baseline_max_rate_vs_other.svg', bbox_inches='tight')
# Same scatter-vs-baseline layout, for average firing rate.
fig = plt.figure()
for i, pairs in enumerate(results_avgrate):
    for j, pair in enumerate(pairs):
        plt.plot(*pair, color=cmap[i], marker='.', ls='none', markersize=msize)
# plt.scatter(*np.array(pairs).T, label=labels[i], color=cmap[i])
# plt.legend(bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
custom_lines = [
    Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
    for i, label in enumerate(labels)
]
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
plt.ylabel('Average rate')
plt.xlabel('Baseline average rate')
lim = [-.7, 40]
plt.ylim(lim)
plt.xlim(lim)
plt.plot(lim, lim, '--k', alpha=.5, lw=1)
fig.savefig(output_path / 'figures' / 'baseline_average_rate_vs_other.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'baseline_average_rate_vs_other.svg', bbox_inches='tight')
# "Stick plot": each line connects one unit's gridness in the two
# sessions of a comparison; line color encodes the slope angle on a
# blue scale.
fig = plt.figure()
import matplotlib
cNorm = matplotlib.colors.Normalize(vmin=-np.pi/2, vmax=np.pi/2)
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=plt.cm.Blues)
ticks = []
for i, pairs in enumerate(results_gridness):
    for j, pair in enumerate(pairs):
        # Slope angle of the stick with a hard-coded run of 0.9.
        # NOTE(review): the tick span below is 0.8, not 0.9 — confirm the
        # mismatch is intentional.
        angle = float(np.arctan(np.diff(pair) / 0.9))
        color = scalarMap.to_rgba(angle)
        # color = plt.cm.Paired((np.sign(angle)+1)/14)
        tick = (i, i+.8)
        plt.plot(tick, pair, marker='.', color=color)
        ticks.append(tick)
# NOTE(review): ticks holds one (start, end) pair per stick, so the
# flattened position list is much longer than the 8 labels — verify
# matplotlib handles the mismatch as intended.
plt.xticks(
    [t for tick in ticks for t in tick],
    ['Baseline I', 'Baseline II',
     'Baseline I', 'Stimulation I',
     'Baseline II', 'Stimulation II',
     'Baseline I', 'Stimulation II'],
    rotation=-45, ha='left'
)
plt.ylabel('Gridness')
fig.savefig(output_path / 'figures' / 'stickplot_gridness.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'stickplot_gridness.svg', bbox_inches='tight')
# Violin plots of the gridness difference (second minus first session)
# for each comparison.
fig = plt.figure()
ticks = [0,0.6,1.2, 1.8]
pairwise_gridness = [[], [], [], []]
for i, pairs in enumerate(results_gridness):
    for j, pair in enumerate(pairs):
        pairwise_gridness[i].append(np.diff(pair))
violins = plt.violinplot(
    pairwise_gridness, ticks, showmedians=True, showextrema=False, points=1000, bw_method=.2)
# Style the median/extrema line collections if present.
for category in ['cbars', 'cmins', 'cmaxes', 'cmedians']:
    if category in violins:
        violins[category].set_color(['k', 'k'])
        violins[category].set_linewidth(2.0)
# NOTE(review): this rebinds the name `colors` imported from matplotlib
# at the top of the file, and the value is unused here.
colors = plt.cm.Paired(np.linspace(0,1,12))
for pc, c in zip(violins['bodies'], cmap):
    pc.set_facecolor(c)
    pc.set_edgecolor(c)
    pc.set_alpha(0.8)
plt.xticks(ticks, labels, rotation=-45, ha='center')
plt.ylabel('Difference in gridness')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
fig.savefig(output_path / 'figures' / 'violins_gridness_difference.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'violins_gridness_difference.svg', bbox_inches='tight')
# Violin plots of the relative change in max rate, (second - first)
# divided by the first (baseline) session's value.
fig = plt.figure()
ticks = [0,0.6,1.2, 1.8]
pairwise_maxrate = [[], [], [], []]
for i, pairs in enumerate(results_maxrate):
    for j, pair in enumerate(pairs):
        pairwise_maxrate[i].append(np.diff(pair) / pair[0])
violins = plt.violinplot(
    pairwise_maxrate, ticks, showmedians=True, showextrema=False, points=1000, bw_method=.2)
for category in ['cbars', 'cmins', 'cmaxes', 'cmedians']:
    if category in violins:
        violins[category].set_color(['k', 'k'])
        violins[category].set_linewidth(2.0)
# NOTE(review): rebinds the imported `colors` name and is unused here.
colors = plt.cm.Paired(np.linspace(0,1,12))
for pc, c in zip(violins['bodies'], cmap):
    pc.set_facecolor(c)
    pc.set_edgecolor(c)
    pc.set_alpha(0.8)
plt.xticks(ticks, labels, rotation=-45, ha='center')
plt.ylabel('Relative change in max rate')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
fig.savefig(output_path / 'figures' / 'violins_max_rate_difference.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'violins_max_rate_difference.svg', bbox_inches='tight')
# Violin plots of the relative change in average rate, (second - first)
# divided by the first (baseline) session's value.
fig = plt.figure()
ticks = [0,0.6,1.2, 1.8]
pairwise_avgrate = [[], [], [], []]
for i, pairs in enumerate(results_avgrate):
    for j, pair in enumerate(pairs):
        pairwise_avgrate[i].append(np.diff(pair) / pair[0])
violins = plt.violinplot(
    pairwise_avgrate, ticks, showmedians=True, showextrema=False, points=1000, bw_method=.2)
for category in ['cbars', 'cmins', 'cmaxes', 'cmedians']:
    if category in violins:
        violins[category].set_color(['k', 'k'])
        violins[category].set_linewidth(2.0)
# NOTE(review): rebinds the imported `colors` name and is unused here.
colors = plt.cm.Paired(np.linspace(0,1,12))
for pc, c in zip(violins['bodies'], cmap):
    pc.set_facecolor(c)
    pc.set_edgecolor(c)
    pc.set_alpha(0.8)
plt.xticks(ticks, labels, rotation=-45, ha='center')
plt.ylabel('Relative change in mean rate')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
fig.savefig(output_path / 'figures' / 'violins_mean_rate_difference.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'violins_mean_rate_difference.svg', bbox_inches='tight')
# Violin plots of the spatial shift between rate maps per comparison;
# pair[0] is the cross-correlation peak distance in bins, converted to
# meters via bin_size.
fig = plt.figure()
ticks = [0,0.6,1.2, 1.8]
pairwise_xcorr = [[], [], [], []]
for i, pairs in enumerate(results_xcorr):
    for j, pair in enumerate(pairs):
        pairwise_xcorr[i].append(pair[0] * bin_size)
violins = plt.violinplot(
    pairwise_xcorr, ticks, showmedians=True, showextrema=False, points=1000, bw_method=.2)
for category in ['cbars', 'cmins', 'cmaxes', 'cmedians']:
    if category in violins:
        violins[category].set_color(['k', 'k'])
        violins[category].set_linewidth(2.0)
# NOTE(review): rebinds the imported `colors` name and is unused here.
colors = plt.cm.Paired(np.linspace(0,1,12))
for pc, c in zip(violins['bodies'], cmap):
    pc.set_facecolor(c)
    pc.set_edgecolor(c)
    pc.set_alpha(0.8)
plt.xticks(ticks, labels, rotation=-45, ha='center')
plt.ylabel('Spatial shift')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
fig.savefig(output_path / 'figures' / 'violins_spatial_shift.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'violins_spatial_shift.svg', bbox_inches='tight')
# Render a simple horizontal gradient strip and save it as a standalone
# colorbar asset.
plt.imshow([np.arange(100), np.arange(100)])
despine(bottom=True, left=True, xticks=False, yticks=False)
plt.gcf().savefig('rocket_colorbar.svg')
# Scratch expression: check how capitalize() treats the roman numeral
# (it lowercases 'I'; compensated for by `form` below).
'baseline I'.capitalize()
# 4x4 histogram grid: one row per comparison; columns are spatial
# displacement, delta gridness, relative max-rate change and relative
# average-rate change.  Bin edges computed for the first row are reused
# for the rows below so histograms are comparable within a column.
ncol, nrow = 4, 4
fig, axs = plt.subplots(nrow, ncol, sharey=True, figsize=(2 * ncol, 8))
# Restore the 'I'/'II' casing that str.capitalize() lowercases.
form = lambda x: x.capitalize().replace(' i', ' I').replace(' Ii', ' II')
bins = [10, 10, 10, 10]  # per-column bin spec; replaced by edges after row 0
for i, ax in enumerate(axs):
    ax[0].set_ylabel('\n'.join([form(l) for l in labels[i].split(' vs ')]))
    # Column 0: spatial displacement, converted from meters to cm.
    h, b, _ = ax[0].hist(np.array(pairwise_xcorr[i]) * 100, bins=bins[0], color='k')
    bins[0] = b
    if i == 3:
        ax[0].set_xlabel('Displacement (cm)')
    elif i == 0:
        ax[0].set_title('Spatial displacement')
        ax[0].set_xticklabels([])
    else:
        ax[0].set_xticklabels([])
    # Column 1: change in gridness.
    h, b, _ = ax[1].hist(np.array(pairwise_gridness[i]), bins=bins[1], color='k')
    bins[1] = b
    if i == 3:
        ax[1].set_xlabel('Change')
    elif i == 0:
        ax[1].set_title('$\\Delta$ Gridness')
        ax[1].set_xticklabels([])
    else:
        ax[1].set_xticklabels([])
    # Column 2: relative change in max rate.
    h, b, _ = ax[2].hist(np.array(pairwise_maxrate[i]), bins=bins[2], color='k')
    bins[2] = b
    if i == 3:
        ax[2].set_xlabel('Relative change')
    elif i == 0:
        ax[2].set_title('$\\Delta$ Max rate')
        ax[2].set_xticklabels([])
    else:
        ax[2].set_xticklabels([])
    # Column 3: relative change in average rate.
    h, b, _ = ax[3].hist(np.array(pairwise_avgrate[i]), bins=bins[3], color='k')
    bins[3] = b
    if i == 3:
        ax[3].set_xlabel('Relative change')
    elif i == 0:
        ax[3].set_title('$\\Delta$ Average rate')
        ax[3].set_xticklabels([])
    else:
        ax[3].set_xticklabels([])
despine()
fig.savefig(output_path / 'figures' / 'histogram_grid_all.svg')
fig.savefig(output_path / 'figures' / 'histogram_grid_all.png')
# Register the outputs with the expipe project: copy the figures and
# statistics into the action's data directory and store this notebook
# alongside them for provenance.
action = project.require_action("longitudinal-comparisons-gridcells")
copy_tree(output_path, str(action.data_path()))
septum_mec.analysis.registration.store_notebook(action, "20_longitudinal_comparisons_gridcells.ipynb")