%load_ext autoreload
%autoreload 2
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import re
import shutil
import pandas as pd
import scipy.stats
import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot
from spike_statistics.core import permutation_resampling
from tqdm.auto import tqdm
tqdm.pandas()
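# Load the expipe project and its registered actions.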
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions
output_path = pathlib.Path("output") / "longitudinal-comparisons-gridcells"
(output_path / "statistics").mkdir(exist_ok=True, parents=True)
(output_path / "figures").mkdir(exist_ok=True, parents=True)
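# Pull precomputed results from earlier actions: identified sessions/units and per-unit
# statistics, merged into one table with one row per unit per session.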
statistics_action = actions['calculate-statistics']
identification_action = actions['identify-neurons']
sessions = pd.read_csv(identification_action.data_path('sessions'))
units = pd.read_csv(identification_action.data_path('units'))
session_units = pd.merge(sessions, units, on='action')
statistics_results = pd.read_csv(statistics_action.data_path('results'))
statistics = pd.merge(session_units, statistics_results, how='left')
statistics.head()
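# unit_day keys a unit to a recording day: the unit id number plus the second token of the
# action name (assumed to be the recording date), so the same unit recorded in several
# sessions on one day shares a key.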
statistics['unit_day'] = statistics.apply(lambda x: str(x.unit_idnum) + '_' + x.action.split('-')[1], axis=1)
stim_response_action = actions['stimulus-response']
stim_response_results = pd.read_csv(stim_response_action.data_path('results'))
statistics = pd.merge(statistics, stim_response_results, how='left')
print('N cells:',statistics.shape[0])
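# Shuffling-based 95th-percentile quantiles act as per-unit thresholds for grid-cell classification.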
shuffling = actions['shuffling']
quantiles_95 = pd.read_csv(shuffling.data_path('quantiles_95'))
quantiles_95.head()
action_columns = ['action', 'channel_group', 'unit_name']
data = pd.merge(statistics, quantiles_95, on=action_columns, suffixes=("", "_threshold"))
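# Spatial specificity: log10 ratio of in-field to out-of-field mean rate.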
data['specificity'] = np.log10(data['in_field_mean_rate'] / data['out_field_mean_rate'])
data.head()
data.groupby('stimulated').count()['action']
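# A unit counts as a grid cell when both gridness and information rate exceed their shuffling thresholds.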
query = 'gridness > gridness_threshold and information_rate > information_rate_threshold'
sessions_above_threshold = data.query(query)
print("Number of gridcells", len(sessions_above_threshold))
print("Number of animals", len(sessions_above_threshold.groupby(['entity'])))
columns = [
'average_rate', 'gridness', 'sparsity', 'selectivity', 'information_specificity',
'max_rate', 'information_rate', 'interspike_interval_cv',
'in_field_mean_rate', 'out_field_mean_rate',
'burst_event_ratio', 'specificity', 'speed_score'
]
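# Keep every session from a unit-day on which the unit passed the grid-cell criterion at least
# once ("once a grid cell" within that day).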
once_a_gridcell = statistics[statistics.unit_day.isin(sessions_above_threshold.unit_day.values)].copy()
baseline_i = once_a_gridcell.query('baseline and Hz11')
stimulated_11 = once_a_gridcell.query('stimulated and frequency==11 and stim_location=="ms"')
baseline_ii = once_a_gridcell.query('baseline and Hz30')
stimulated_30 = once_a_gridcell.query('stimulated and frequency==30 and stim_location=="ms"')
print("Number of gridcells in baseline i sessions", len(baseline_i))
print("Number of gridcells in stimulated 11Hz ms sessions", len(stimulated_11))
print("Number of gridcells in baseline ii sessions", len(baseline_ii))
print("Number of gridcells in stimulated 30Hz ms sessions", len(stimulated_30))
max_speed = 1  # m/s, only used for speed score
min_speed = 0.02  # m/s, only used for speed score
position_sampling_rate = 100 # for interpolation
position_low_pass_frequency = 6 # for low pass filtering of position
box_size = [1.0, 1.0]
bin_size = 0.02
smoothing_low = 0.03
smoothing_high = 0.06
data_loader = dp.Data(
position_sampling_rate=position_sampling_rate,
position_low_pass_frequency=position_low_pass_frequency,
box_size=box_size, bin_size=bin_size
)
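# Pairwise comparisons per unit, indexed 0-2:
#   0: baseline I vs baseline II, 1: baseline I vs 11 Hz stimulation, 2: baseline II vs 30 Hz stimulation
# (baseline I/II denote the baselines paired with the 11 Hz and 30 Hz sessions, respectively).
# For each pair we store the spatial correlation of the two rate maps, the gridness pair and unit identifiers.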
results_corr = [[], [], []]
results_gridness = [[], [], []]
results_unit_name = [[], [], []]
results_unit_id = [[], [], []]
results_id_map = {}
for nid, unit_sessions in once_a_gridcell.groupby('unit_id'):
base_i = unit_sessions.query("baseline and Hz11")
base_ii = unit_sessions.query("baseline and Hz30")
stim_i = unit_sessions.query("frequency==11")
stim_ii = unit_sessions.query("frequency==30")
dfs = [(base_i, base_ii), (base_i, stim_i), (base_ii, stim_ii)]
for i, pair in enumerate(dfs):
for (_, row_1), (_, row_2) in zip(pair[0].iterrows(), pair[1].iterrows()):
rate_map_1 = data_loader.rate_map(
row_1['action'], row_1['channel_group'], row_1['unit_name'], smoothing_low)
rate_map_2 = data_loader.rate_map(
row_2['action'], row_2['channel_group'], row_2['unit_name'], smoothing_low)
results_corr[i].append(np.corrcoef(rate_map_1.ravel(), rate_map_2.ravel())[0,1])
results_gridness[i].append((row_1.gridness, row_2.gridness))
results_unit_name[i].append((
f'{row_1.action}_{row_1.channel_group}_{row_1.unit_name}',
f'{row_2.action}_{row_2.channel_group}_{row_2.unit_name}'))
assert row_1.unit_id == row_2.unit_id
uid = row_2.unit_id
idnum = row_1.unit_idnum
results_id_map[uid] = idnum
results_unit_id[i].append(idnum)
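# Optional quick summary of the spatial rate-map correlations collected above
# (a minimal sketch; labels assume the comparison order defined in `dfs`).
for label, corrs in zip(
        ['baseline I vs baseline II', 'baseline I vs stim 11 Hz', 'baseline II vs stim 30 Hz'],
        results_corr):
    print(f'{label}: n = {len(corrs)}, median r = {np.median(corrs):.2f}')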
def session_id(row):
if row.baseline and row.i:
n = 0
elif row.stimulated and row.i:
n = 1
elif row.baseline and row.ii:
n = 2
elif row.stimulated and row.ii:
n = 3
else:
        raise ValueError(f'Cannot assign a session id to action {row.action}')
return n
once_a_gridcell['session_id'] = once_a_gridcell.apply(session_id, axis=1)
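# session_id maps each session to a figure column: 0 baseline I, 1 stimulation I, 2 baseline II,
# 3 stimulation II; the `i`/`ii` columns are assumed to flag the first and second
# baseline/stimulation pair of the day.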
plt.rc('axes', titlesize=12)
plt.rcParams.update({
'font.size': 12,
'figure.figsize': (6, 4),
'figure.dpi': 150
})
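# One figure per tracked neuron: rate maps for each recording day (rows) and session type
# (columns), titled with gridness and peak rate.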
for unit_id, id_num in results_id_map.items():
    neuron_sessions = once_a_gridcell.query(f'unit_id=="{unit_id}"')
    n_action = neuron_sessions.date.nunique()
fig, axs = plt.subplots(n_action, 4, sharey=True, sharex=True, figsize=(8, n_action*4))
sns.despine(left=True, bottom=True)
fig.suptitle(f'Neuron {id_num}')
if n_action == 1:
axs = [axs]
waxs = None
    for ax, (date, rows) in zip(axs, neuron_sessions.groupby('date')):
entity = rows.iloc[0].entity
ax[0].set_ylabel(f'{entity}-{date}')
for _, row in rows.iterrows():
action_id = row['action']
channel_id = row['channel_group']
unit_name = row['unit_name']
rate_map = data_loader.rate_map(action_id, channel_id, unit_name, smoothing_low)
idx = row.session_id
ax[idx].imshow(rate_map, origin='lower')
ax[idx].set_title(f'{row.gridness:.2f} {row.max_rate:.2f}')
ax[idx].set_yticklabels([])
ax[idx].set_xticklabels([])
plt.tight_layout()
fig.savefig(output_path / 'figures' / f'neuron_{id_num}_rate_map.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / f'neuron_{id_num}_rate_map.svg', bbox_inches='tight')
# waveforms
# template = data_loader.template(action_id, channel_id, unit_name)
# if waxs is None:
# wfig, waxs = plt.subplots(1, template.data.shape[0], sharey=True, sharex=True)
# for i, wax in enumerate(waxs):
# wax.plot(template.data[i,:])
# if i > 0:
# ax.set_yticklabels([])
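# Spike maps: running trajectory with spike positions (interpolated onto the tracked path)
# overlaid, same row/column layout as the rate-map figures.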
from scipy.interpolate import interp1d
for unit_id, id_num in results_id_map.items():
    neuron_sessions = once_a_gridcell.query(f'unit_id=="{unit_id}"')
    n_action = neuron_sessions.date.nunique()
fig, axs = plt.subplots(n_action, 4, sharey=True, sharex=True, figsize=(8, n_action*4))
sns.despine(left=True, bottom=True)
fig.suptitle(f'Neuron {id_num}')
if n_action == 1:
axs = [axs]
waxs = None
    for ax, (date, rows) in zip(axs, neuron_sessions.groupby('date')):
entity = rows.iloc[0].entity
ax[0].set_ylabel(f'{entity}-{date}')
for _, row in rows.iterrows():
action_id = row['action']
channel_id = row['channel_group']
unit_name = row['unit_name']
idx = row.session_id
x, y, t, speed = map(data_loader.tracking(action_id).get, ['x', 'y', 't', 'v'])
ax[idx].plot(x, y, 'k', alpha=0.3)
spike_times = data_loader.spike_train(action_id, channel_id, unit_name)
spike_times = spike_times[(spike_times > min(t)) & (spike_times < max(t))]
x_spike = interp1d(t, x)(spike_times)
y_spike = interp1d(t, y)(spike_times)
ax[idx].set_xticks([])
ax[idx].set_yticks([])
ax[idx].scatter(x_spike, y_spike, marker='.', color=(0.7, 0.2, 0.2), s=1.5)
ax[idx].set_title(f'{row.session}')
ax[idx].set_yticklabels([])
ax[idx].set_xticklabels([])
for a in ax:
a.set_aspect(1)
plt.tight_layout()
fig.savefig(
output_path / 'figures' / f'neuron_{id_num}_spike_map.png',
bbox_inches='tight', transparent=True)
fig.savefig(
output_path / 'figures' / f'neuron_{id_num}_spike_map.svg',
bbox_inches='tight')
plt.rc('axes', titlesize=12)
plt.rcParams.update({
'font.size': 12,
'figure.figsize': (6, 4),
'figure.dpi': 150
})
cmap = ['#1b9e77','#d95f02','#7570b3']
len(results_gridness)
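# Per-neuron change in gridness for each comparison, one marker per comparison per neuron.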
msize = 9
fig = plt.figure()
ticks = []
nuids = {}
n = 0
for i, pairs in enumerate(results_gridness):
for j, pair in enumerate(pairs):
nuid = results_unit_id[i][j]
if nuid not in nuids:
nuids[nuid] = n
n += 1
plt.plot(
nuids[nuid], np.diff(pair),
color=cmap[i], marker='.', ls='none', markersize=msize)
for l in range(n):
plt.axvline(l, color='k', lw=.1, alpha=.5)
from matplotlib.lines import Line2D
labels = ['Baseline I vs baseline II', 'Baseline I vs stim I', 'Baseline II vs stim II']
custom_lines = [
Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
for i, label in enumerate(labels)
]
plt.ylabel('Difference in gridness')
plt.xlabel('Neuron')
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
fig.savefig(output_path / 'figures' / 'neuron_gridness.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'neuron_gridness.svg', bbox_inches='tight')
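# Gridness in the compared session plotted against its baseline, with the identity line for reference.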
fig = plt.figure()
labels = ['Baseline I vs baseline II', 'Baseline I vs stim I', 'Baseline II vs stim II']
for i, pairs in enumerate(results_gridness):
for j, pair in enumerate(pairs):
plt.plot(*pair, color=cmap[i], marker='.', ls='none', markersize=msize)
# plt.scatter(*np.array(pairs).T, label=labels[i], color=cmap[i])
# plt.legend(bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
custom_lines = [
Line2D([],[], color=cmap[i], marker='.', ls='none', label=label, markersize=msize)
for i, label in enumerate(labels)
]
plt.legend(handles=custom_lines, bbox_to_anchor=(1.04,1), borderaxespad=0, frameon=False)
plt.ylabel('Gridness')
plt.xlabel('Baseline gridness')
lim = [-.7, 1.35]
plt.ylim(lim)
plt.xlim(lim)
plt.plot(lim, lim, '--k', alpha=.5, lw=1)
fig.savefig(output_path / 'figures' / 'baseline_gridness_vs_other.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'baseline_gridness_vs_other.svg', bbox_inches='tight')
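# Stick plot: each line connects the gridness of the two sessions in a pair, colored by the size of the change.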
fig = plt.figure()
import matplotlib
cNorm = matplotlib.colors.Normalize(vmin=-np.pi/2, vmax=np.pi/2)
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=plt.cm.Blues)
ticks = []
for i, pairs in enumerate(results_gridness):
for j, pair in enumerate(pairs):
        angle = np.arctan((pair[1] - pair[0]) / 0.9)
color = scalarMap.to_rgba(angle)
# color = plt.cm.Paired((np.sign(angle)+1)/14)
tick = (i, i+.8)
plt.plot(tick, pair, marker='.', color=color)
ticks.append(tick)
plt.xticks(
    sorted(set(t for tick in ticks for t in tick)),  # unique tick positions, one pair per comparison
    ['Baseline I', 'Baseline II', 'Baseline I', 'Stimulation I', 'Baseline II', 'Stimulation II'],
    rotation=-45, ha='left'
)
plt.ylabel('Gridness')
fig.savefig(output_path / 'figures' / 'stickplot_gridness.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'stickplot_gridness.svg', bbox_inches='tight')
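# Distributions of the gridness differences for the three comparisons.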
fig = plt.figure()
ticks = [0,0.6,1.2]
diff_res = [[], [], []]
for i, pairs in enumerate(results_gridness):
for j, pair in enumerate(pairs):
# if results_unit_id[i][j] in [results_id_map[i] for i in exclude]:
# continue
        diff_res[i].append(pair[1] - pair[0])
violins = plt.violinplot(
diff_res, ticks, showmedians=True, showextrema=False, points=1000, bw_method=.2)
for category in ['cbars', 'cmins', 'cmaxes', 'cmedians']:
if category in violins:
violins[category].set_color(['k', 'k'])
violins[category].set_linewidth(2.0)
colors = plt.cm.Paired(np.linspace(0,1,12))
for pc, c in zip(violins['bodies'], cmap):
pc.set_facecolor(c)
pc.set_edgecolor(c)
plt.xticks(ticks, ['baseline', 'stim i', 'stim ii'])
plt.ylabel('Difference in gridness')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
fig.savefig(output_path / 'figures' / 'violins_gridness_difference.png', bbox_inches='tight')
fig.savefig(output_path / 'figures' / 'violins_gridness_difference.svg', bbox_inches='tight')
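# Register the output data and this notebook in the expipe project.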
action = project.require_action("longitudinal-comparisons-gridcells")
copy_tree(str(output_path), str(action.data_path()))
septum_mec.analysis.registration.store_notebook(action, "20_longitudinal_comparisons_gridcells.ipynb")