%load_ext autoreload
%autoreload 2
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import re
import shutil
import pandas as pd
import scipy.stats
import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot
from spike_statistics.core import permutation_resampling
# Open the expipe project and its registered actions (processing steps).
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions
# All tables and figures produced by this script go under this directory.
output_path = pathlib.Path("output") / "comparisons-gridcells"
(output_path / "statistics").mkdir(exist_ok=True, parents=True)
(output_path / "figures").mkdir(exist_ok=True, parents=True)
# Precomputed per-unit statistics plus the session/unit identification tables.
statistics_action = actions['calculate-statistics']
identification_action = actions['identify-neurons']
sessions = pd.read_csv(identification_action.data_path('sessions'))
units = pd.read_csv(identification_action.data_path('units'))
session_units = pd.merge(sessions, units, on='action')
statistics_results = pd.read_csv(statistics_action.data_path('results'))
# Left merge keeps every identified unit even if it has no statistics row.
statistics = pd.merge(session_units, statistics_results, how='left')
statistics.head()
# 95% quantiles from the shuffling analysis; merged in below with a
# "_threshold" suffix and used as per-unit selection thresholds.
shuffling = actions['shuffling']
quantiles_95 = pd.read_csv(shuffling.data_path('quantiles_95'))
quantiles_95.head()
action_columns = ['action', 'channel_group', 'unit_name']
data = pd.merge(statistics, quantiles_95, on=action_columns, suffixes=("", "_threshold"))
# Spatial specificity: log10 ratio of in-field to out-of-field mean rate.
data['specificity'] = np.log10(data['in_field_mean_rate'] / data['out_field_mean_rate'])
data.head()
data.groupby('stimulated').count()['action']
# Unit-per-day key: unit id + the second dash-separated component of the
# action name (presumably the recording date — TODO confirm).
data['unit_day'] = data.apply(lambda x: str(x.unit_idnum) + '_' + x.action.split('-')[1], axis=1)
# Grid-cell criterion: a session must beat the unit's own shuffling
# thresholds AND pass absolute cutoffs on gridness and mean rate.
query = (
'gridness > gridness_threshold and '
'information_rate > information_rate_threshold and '
'gridness > .2 and '
'average_rate < 25'
)
sessions_above_threshold = data.query(query)
print("Number of sessions above threshold", len(sessions_above_threshold))
print("Number of animals", len(sessions_above_threshold.groupby(['entity'])))
# Keep ALL recordings (baseline and stimulated) of any unit that passed the
# criterion on that day, so conditions can be compared within unit.
gridcell_sessions = data[data.unit_day.isin(sessions_above_threshold.unit_day.values)]
print("Number of gridcells", gridcell_sessions.unit_idnum.nunique())
print("Number of gridcell recordings", len(gridcell_sessions))
print("Number of animals", len(gridcell_sessions.groupby(['entity'])))
# Comparison groups: each baseline set is paired with one stimulation
# frequency at stim_location "ms" (Hz11/Hz30 flag which protocol a
# baseline session belongs with).
baseline_i = gridcell_sessions.query('baseline and Hz11')
stimulated_11 = gridcell_sessions.query('frequency==11 and stim_location=="ms"')
baseline_ii = gridcell_sessions.query('baseline and Hz30')
stimulated_30 = gridcell_sessions.query('frequency==30 and stim_location=="ms"')
print("Number of gridcells in baseline i sessions", len(baseline_i))
print("Number of gridcells in stimulated 11Hz ms sessions", len(stimulated_11))
print("Number of gridcells in baseline ii sessions", len(baseline_ii))
print("Number of gridcells in stimulated 30Hz ms sessions", len(stimulated_30))
# Reduce to one row per unit within each group.
# NOTE(review): deduplication keys on 'unit_id' while the rest of the script
# uses 'unit_idnum'/'unit_name' — confirm 'unit_id' is the intended column.
baseline_i = baseline_i.drop_duplicates('unit_id')
stimulated_11 = stimulated_11.drop_duplicates('unit_id')
baseline_ii = baseline_ii.drop_duplicates('unit_id')
stimulated_30 = stimulated_30.drop_duplicates('unit_id')
print("Number of gridcells in baseline i sessions", len(baseline_i))
print("Number of gridcells in stimulated 11Hz ms sessions", len(stimulated_11))
print("Number of gridcells in baseline ii sessions", len(baseline_ii))
print("Number of gridcells in stimulated 30Hz ms sessions", len(stimulated_30))
# Statistics compared between conditions in all tables and figures below.
columns = [
'average_rate', 'gridness', 'sparsity', 'selectivity', 'information_specificity',
'max_rate', 'information_rate', 'interspike_interval_cv',
'in_field_mean_rate', 'out_field_mean_rate',
'burst_event_ratio', 'specificity', 'speed_score'
]
# Quick interactive look at group means and distributions.
gridcell_sessions.groupby('stimulated')[columns].mean()
gridcell_sessions.query('baseline')[columns].describe()
gridcell_sessions.query("stimulated")[columns].describe()
def summarize(data):
    """Format a sample as 'mean ± SEM (n)', where n counts non-NaN entries."""
    n_valid = int(np.sum(~np.isnan(data)))
    return f"{data.mean():.2f} ± {data.sem():.2f} ({n_valid})"
def MWU(column, stim, base):
    """Two-sided Mann-Whitney U test of `column`, stim sample vs base sample.

    NaNs are dropped from both samples before testing. Returns the string
    'U, p' with U to two decimals and p to three.
    """
    stim_values = stim[column].dropna()
    base_values = base[column].dropna()
    Uvalue, pvalue = scipy.stats.mannwhitneyu(
        stim_values, base_values, alternative='two-sided')
    return f"{Uvalue:.2f}, {pvalue:.3f}"
def PRS(column, stim, base):
    """Permutation-resampling test of the difference in medians of `column`.

    NaNs are dropped from both samples. Returns the string
    'observed_diff, p' with the median difference to two decimals and the
    permutation p-value to three.
    """
    stim_values = stim[column].dropna()
    base_values = base[column].dropna()
    pvalue, observed_diff, diffs = permutation_resampling(
        stim_values, base_values, statistic=np.median)
    return f"{observed_diff:.2f}, {pvalue:.3f}"
def rename(name):
    """Turn a snake_case statistic name into a display label.

    '_field' becomes '-field' (so 'in_field' reads 'in-field'), remaining
    underscores become spaces, and the first character is upper-cased.
    """
    label = name.replace("_field", "-field")
    label = label.replace("_", " ")
    return label.capitalize()
# Pooled comparison: all baseline vs all stimulated gridcell recordings.
_stim_data = gridcell_sessions.query('stimulated')
_base_data = gridcell_sessions.query('baseline')
result = pd.DataFrame()
result['Baseline'] = _base_data[columns].agg(summarize)
result['Stimulated'] = _stim_data[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _stim_data, _base_data), columns))
result['PRS'] = list(map(lambda x: PRS(x, _stim_data, _base_data), columns))
result.to_latex(output_path / "statistics" / "statistics.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics.csv")
result
# Baseline I vs 11 Hz stimulation (one row per unit in each group).
_stim_data = stimulated_11
_base_data = baseline_i
result = pd.DataFrame()
result['Baseline'] = _base_data[columns].agg(summarize)
result['11 Hz'] = _stim_data[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _stim_data, _base_data), columns))
result['PRS'] = list(map(lambda x: PRS(x, _stim_data, _base_data), columns))
result.to_latex(output_path / "statistics" / "statistics_11.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics_11.csv")
result
# Baseline II vs 30 Hz stimulation (one row per unit in each group).
_stim_data = stimulated_30
_base_data = baseline_ii
result = pd.DataFrame()
result['Baseline'] = _base_data[columns].agg(summarize)
result['30 Hz'] = _stim_data[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _stim_data, _base_data), columns))
result['PRS'] = list(map(lambda x: PRS(x, _stim_data, _base_data), columns))
result.to_latex(output_path / "statistics" / "statistics_30.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics_30.csv")
result
# Cross comparison: Baseline I vs 30 Hz stimulation.
_stim_data = stimulated_30
_base_data = baseline_i
result = pd.DataFrame()
result['Baseline I'] = _base_data[columns].agg(summarize)
result['30 Hz'] = _stim_data[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _stim_data, _base_data), columns))
result['PRS'] = list(map(lambda x: PRS(x, _stim_data, _base_data), columns))
result.to_latex(output_path / "statistics" / "statistics_b_i_30.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics_b_i_30.csv")
result
# Cross comparison: 11 Hz vs 30 Hz stimulation ("base" here is the 11 Hz group).
_stim_data = stimulated_30
_base_data = stimulated_11
result = pd.DataFrame()
result['11 Hz'] = _base_data[columns].agg(summarize)
result['30 Hz'] = _stim_data[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _stim_data, _base_data), columns))
result['PRS'] = list(map(lambda x: PRS(x, _stim_data, _base_data), columns))
result.to_latex(output_path / "statistics" / "statistics_11_vs_30.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics_11_vs_30.csv")
result
# Sanity check: Baseline I vs Baseline II (the two baselines should agree).
_data_i = baseline_i
_data_ii = baseline_ii
result = pd.DataFrame()
result['Baseline I'] = _data_i[columns].agg(summarize)
result['Baseline II'] = _data_ii[columns].agg(summarize)
result.index = map(rename, result.index)
result['MWU'] = list(map(lambda x: MWU(x, _data_i, _data_ii), columns))
result['PRS'] = list(map(lambda x: PRS(x, _data_i, _data_ii), columns))
result.to_latex(output_path / "statistics" / "statistics_base_i_vs_base_ii.tex")
# Bug fix: the .csv file was previously written with to_latex; write real CSV.
result.to_csv(output_path / "statistics" / "statistics_base_i_vs_base_ii.csv")
result
%matplotlib inline
# Plot styling shared by all violin figures below.
plt.rc('axes', titlesize=12)
plt.rcParams.update({
'font.size': 12,
'figure.figsize': (1.7, 3),
'figure.dpi': 150
})
# colors = ['#1b9e77','#d95f02','#7570b3','#e7298a']
# labels = ['Baseline I', '11 Hz', 'Baseline II', '30 Hz']
# Three comparisons to plot: pooled baseline vs stimulated (key ''), and the
# duplicate-free 11 Hz and 30 Hz pairs. Keys double as filename suffixes.
stuff = {
'': {
'base': gridcell_sessions.query('baseline'),
'stim': gridcell_sessions.query('stimulated')
},
'_11': {
'base': baseline_i,
'stim': stimulated_11
},
'_30': {
'base': baseline_ii,
'stim': stimulated_30
}
}
# x tick labels per comparison: [left (baseline), right (stimulated)].
label = {
'': ['Baseline ', ' Stimulated'],
'_11': ['Baseline I ', ' 11 Hz'],
'_30': ['Baseline II ', ' 30 Hz']
}
# Violin colors per comparison; None falls back to the plotting default.
colors = {
'': None,
'_11': ['#1b9e77', '#d95f02'],
'_30': ['#7570b3', '#e7298a']
}
def _violin_figure(column, title, ylabel, ylim=None, filename=None, verbose=False):
    """Plot baseline-vs-stimulated violins of `column` for each comparison.

    One figure per key in `stuff` (pooled, 11 Hz, 30 Hz), saved as SVG and
    600-dpi PNG under output_path/figures as '<filename><key>.{svg,png}'.

    Parameters
    ----------
    column : str
        Column of the session DataFrames to plot.
    title, ylabel : str
        Figure title and y-axis label.
    ylim : tuple or None
        (ymin, ymax) for plt.ylim; None leaves autoscaling in place.
    filename : str or None
        Basename for the saved files; defaults to `column`.
    verbose : bool
        If True, print each comparison key as it is plotted.
    """
    filename = filename or column
    for key, frames in stuff.items():
        if verbose:
            print(key)
        base_values = frames['base'][column].to_numpy()
        stim_values = frames['stim'][column].to_numpy()
        plt.figure()
        violinplot(base_values, stim_values, xticks=label[key], colors=colors[key])
        plt.title(title)
        plt.ylabel(ylabel)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.savefig(output_path / "figures" / f"{filename}{key}.svg", bbox_inches="tight")
        plt.savefig(output_path / "figures" / f"{filename}{key}.png", dpi=600, bbox_inches="tight")

# The thirteen statistic-specific figure loops below were identical except for
# column/title/labels/limits; they are collapsed into calls to the helper.
# (This also stops the loop variable from shadowing the global `data` frame.)
_violin_figure('information_specificity', "Spatial information specificity",
               "bits/spike", ylim=(-0.2, 1.6), verbose=True)
_violin_figure('information_rate', "Spatial information", "bits/s",
               ylim=(-0.2, 4), filename='spatial_information', verbose=True)
_violin_figure('specificity', "Spatial specificity", "", ylim=(-0.02, 1.25))
_violin_figure('average_rate', "Average rate", "spikes/s", ylim=(-0.2, 40))
_violin_figure('max_rate', "Max rate", "spikes/s")
_violin_figure('interspike_interval_cv', "ISI CV", "Coefficient of variation",
               filename='isi_cv')
_violin_figure('in_field_mean_rate', "In-field rate", "spikes/s")
_violin_figure('out_field_mean_rate', "Out-of-field rate", "spikes/s")
_violin_figure('burst_event_ratio', "Bursting ratio", "")
_violin_figure('max_field_mean_rate', "Mean rate of max field", "(spikes/s)")
_violin_figure('bursty_spike_ratio', "ratio of spikes per burst", "")
_violin_figure('gridness', "Gridness", "Gridness", ylim=(-0.6, 1.5))
# TODO narrow broad spiking
_violin_figure('speed_score', "Speed score", "Speed score")
# fig, (ax1, ax2) = plt.subplots(2,1, figsize=(6,6), sharey=True)
# for key, data in stuff.items():
# ax1.set_title('Baseline')
# peak_rate = data['base']['max_rate'].to_numpy()
# spacing = data['base']['spacing'].to_numpy()
# ax1.scatter(spacing, peak_rate)
# ax2.set_title('Stim')
# peak_rate = data['stim']['max_rate'].to_numpy()
# spacing = data['stim']['spacing'].to_numpy()
# ax2.scatter(spacing, peak_rate, label=key)
# ax2.legend()
# Register the results: copy the output tree into the action's data path
# and store this notebook alongside it.
action = project.require_action("comparisons-gridcells")
# distutils.dir_util.copy_tree is deprecated (distutils is removed in
# Python 3.12); shutil.copytree with dirs_exist_ok=True (3.8+) merges into
# an existing destination tree the same way and is the supported replacement.
shutil.copytree(output_path, str(action.data_path()), dirs_exist_ok=True)
septum_mec.analysis.registration.store_notebook(action, "20_comparisons_gridcells.ipynb")