%load_ext autoreload
%autoreload 2
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import re
import shutil
import pandas as pd
import scipy.stats
import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot, despine
from spatial_maps.fields import (
find_peaks, calculate_field_centers, separate_fields_by_laplace,
map_pass_to_unit_circle, calculate_field_centers, distance_to_edge_function,
compute_crossings, which_field)
from phase_precession import cl_corr
import matplotlib.mlab as mlab
import scipy.signal as ss
from scipy.interpolate import interp1d
from septum_mec.analysis.plotting import regplot
from skimage import measure
from tqdm.notebook import tqdm_notebook as tqdm
tqdm.pandas()
from septum_mec.analysis.statistics import load_data_frames, make_paired_tables, make_statistics_table
# %matplotlib notebook
%matplotlib inline
# Project handles for this analysis.
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions

# Output tree: output/spikes-in-field/{statistics,figures}
output_path = pathlib.Path("output") / "spikes-in-field"
for subdir in ("statistics", "figures"):
    (output_path / subdir).mkdir(exist_ok=True, parents=True)

# NOTE: ``colors`` here rebinds the ``matplotlib.colors`` name imported above.
data, labels, colors, queries = load_data_frames()

# Pair grid cells across sessions so each unit appears once
# (removes cells that were measured multiple times).
results, labels = make_paired_tables(data, ['action', 'unit_name'], cell_types=['gridcell'])
results = results['gridcell']
results['action']
def make_unique_unit_df(label, paired=None):
    """Flatten the paired grid-cell tables into one row per unit for one session type.

    Parameters
    ----------
    label : str
        Session column of the paired tables, e.g. ``'Baseline I'`` or ``'11 Hz'``.
    paired : mapping, optional
        Holds ``'action'`` and ``'unit_name'`` DataFrames that share an index
        (the output of ``make_paired_tables``). Defaults to the module-level
        ``results`` table. Pass it explicitly if ``results`` may have been
        rebound since.

    Returns
    -------
    pandas.DataFrame
        Columns ``action``, ``channel_group``, ``unit_idnum``, ``unit_name`` and
        ``unit_day``, indexed like the paired tables. ``unit_day`` is
        ``"<unit_idnum>_<second '-'-separated field of action>"``.
        ``channel_group``, ``unit_idnum`` and ``unit_name`` are cast to int.
        Units with no session for ``label`` (NaN action) are skipped.
        NOTE(review): later code reads ``row.gridness``, which this table does
        not contain. Confirm where that column is supposed to come from.

    Raises
    ------
    ValueError
        If the ``action`` and ``unit_name`` tables are not row-aligned.
    """
    if paired is None:
        paired = results
    action_df = paired['action']
    unit_name_df = paired['unit_name']
    columns = ['action', 'channel_group', 'unit_idnum', 'unit_name', 'unit_day']

    index, records = [], []
    for (i, action_row), (j, unit_name_row) in zip(action_df.iterrows(), unit_name_df.iterrows()):
        if i != j:
            raise ValueError(f'action and unit_name tables are misaligned: index {i!r} vs {j!r}')
        action = action_row[label]
        if pd.isna(action):  # unit has no session of this type
            continue
        day = action.split('-')[1]
        index.append(i)
        records.append({
            'action': action,
            'channel_group': action_row.channel_group,
            'unit_idnum': action_row.unit_idnum,
            'unit_name': unit_name_row[label],
            'unit_day': f'{action_row.unit_idnum}_{day}',
        })

    # Build in one go (instead of growing an empty frame cell by cell); an
    # explicit column list keeps the schema even when no unit matched.
    output = pd.DataFrame(records, index=index, columns=columns)
    for column in ('unit_name', 'channel_group', 'unit_idnum'):
        output[column] = output[column].values.astype(int)
    return output
# One unique-unit table per session type.
baseline_i = make_unique_unit_df('Baseline I')
stimulated_11 = make_unique_unit_df('11 Hz')
baseline_ii = make_unique_unit_df('Baseline II')
stimulated_30 = make_unique_unit_df('30 Hz')

for message, unit_table in (
    ("Number of gridcells in baseline i sessions", baseline_i),
    ("Number of gridcells in stimulated 11Hz ms sessions", stimulated_11),
    ("Number of gridcells in baseline ii sessions", baseline_ii),
    ("Number of gridcells in stimulated 30Hz ms sessions", stimulated_30),
):
    print(message, len(unit_table))
# --- Position preprocessing ---
position_sampling_rate = 100  # for interpolation
position_low_pass_frequency = 6  # for low pass filtering of position

# --- Spatial binning / smoothing ---
box_size = [1.0, 1.0]
bin_size = 0.02
smoothing_low = 0.03
smoothing_high = 0.06

# --- Speed score (max/min in m/s; only used for speed score) ---
max_speed = 0.5
min_speed = 0.02
speed_binsize = 0.02

# --- Stimulation handling ---
stim_mask = True
baseline_duration = 600

data_loader = dp.Data(
    box_size=box_size,
    bin_size=bin_size,
    position_sampling_rate=position_sampling_rate,
    position_low_pass_frequency=position_low_pass_frequency,
    stim_mask=stim_mask,
    baseline_duration=baseline_duration,
)
def find_grid_fields(rate_map, sigma=3, seed=2.5, min_area=9.0):
    """Label grid fields in a rate map, keeping only fields confirmed by two methods.

    The map is first segmented with ``separate_fields_by_dilation``. Segments
    smaller than ``min_area`` bins are discarded. A remaining segment is kept
    only if it contains at least one peak from ``separate_fields_by_distance``
    (the Ismakov method).

    Parameters
    ----------
    rate_map : 2D array
        Firing-rate map.
    sigma, seed : float
        Forwarded to ``separate_fields_by_dilation``.
    min_area : float
        Minimum field size in bins (default 9, the previous hard-coded value).

    Returns
    -------
    ndarray of int
        Field label array; 0 marks bins outside any accepted field.
    """
    from scipy import ndimage  # scipy.ndimage is not imported at file level

    fields = sp.fields.separate_fields_by_dilation(rate_map, sigma=sigma, seed=seed).copy()

    # Area (in bins) of every label 0..max, then looked up per bin.
    areas = ndimage.sum(np.ones_like(fields), fields, index=np.arange(fields.max() + 1))
    fields[areas[fields] < min_area] = 0

    # Keep only segments that contain an Ismakov peak; everything else -> 0.
    peaks, _ = sp.separate_fields_by_distance(rate_map)
    approved = [fields[tuple(peak)] for peak in peaks]
    fields[~np.isin(fields, approved)] = 0
    return fields
def get_data(row, smoothing=0.04):
    """Load spikes, tracking, rate map, grid fields and stimulus times for one unit.

    Parameters
    ----------
    row : namedtuple or Series
        Must provide ``action``, ``channel_group`` and ``unit_name``.
    smoothing : float
        Passed to ``data_loader.rate_map`` (default 0.04, the previous hard-coded value).

    Returns
    -------
    tuple
        ``(spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times)``.
        ``spikes`` is a numpy array restricted to the open interval
        ``(pos_t.min(), pos_t.max())``, so it can be interpolated onto the
        tracking. ``stim_times`` is an array, or ``None`` when the loader
        returns none.
    """
    action, channel_group, unit_name = row.action, row.channel_group, row.unit_name
    spikes = data_loader.spike_train(action, channel_group, unit_name)
    rate_map = data_loader.rate_map(action, channel_group, unit_name, smoothing=smoothing)
    tracking = data_loader.tracking(action)
    pos_x, pos_y, pos_t = tracking.get('x'), tracking.get('y'), tracking.get('t')
    stim_times = data_loader.stim_times(action)
    if stim_times is not None:
        stim_times = np.array(stim_times)

    # Drop spikes outside the tracked period; interp1d cannot extrapolate.
    spikes = np.array(spikes)
    spikes = spikes[(spikes > pos_t.min()) & (spikes < pos_t.max())]

    fields = find_grid_fields(rate_map)
    return spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times
def compute_field_spikes(row, plot=False, z1=5e-3, z2=11e-3, surrogate_fields=None):
    """Classify a unit's spikes, and its stimulus-locked spikes, as in or out of grid fields.

    Parameters
    ----------
    row : namedtuple or Series
        Unit accepted by ``get_data``. With ``plot=True`` it must also provide
        ``channel_group``, ``unit_idnum`` and ``gridness`` for the title.
    plot : bool
        Draw three panels: the field mask with classified spikes, the rate map,
        and the trajectory with spikes. The last two get field contours.
    z1, z2 : float
        Spikes in ``(t + z1, t + z2]`` after each stimulus time ``t`` count as
        stimulus-locked. Same time unit as the spike times (default 5-11 ms,
        written in seconds).
    surrogate_fields : array, optional
        Field label array to use instead of the unit's own fields.

    Returns
    -------
    fields : ndarray
        Field labels used (0 = outside fields).
    in_field_indices : ndarray of bool
        One entry per spike: True if it falls inside a field.
    stim_in_field_indices : ndarray of bool
        One entry per stimulus-locked spike; empty if the session has no stimuli.
    """
    spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times = get_data(row)
    if surrogate_fields is not None:
        fields = surrogate_fields

    # Position at every spike time, interpolated from the tracking.
    sx, sy = interp1d(pos_t, pos_x), interp1d(pos_t, pos_y)
    spikes_x = sx(spikes)
    spikes_y = sy(spikes)
    in_field_indices = which_field(spikes_x, spikes_y, fields, box_size).astype(bool)

    # Stimulus-locked spikes. An older note suggested z1=12e3, z2=60e3 for a
    # window after the response; presumably 12e-3 / 60e-3 s was meant.
    stim_spikes_x = stim_spikes_y = np.zeros(0)
    stim_in_field_indices = np.zeros(0, dtype=bool)
    if stim_times is not None:
        stim_spikes = []
        for t in stim_times:
            start, stop = np.searchsorted(spikes, [t + z1, t + z2], side='right')
            stim_spikes.extend(spikes[start:stop])
        stim_spikes = np.asarray(stim_spikes, dtype=float)
        stim_spikes_x = sx(stim_spikes)
        stim_spikes_y = sy(stim_spikes)
        stim_in_field_indices = which_field(
            stim_spikes_x, stim_spikes_y, fields, box_size).astype(bool)

    if plot:
        fig, axs = plt.subplots(1, 3, figsize=(16, 9))
        axs[1].set_title(f'{row.action} {row.channel_group} {row.unit_idnum}, G={row.gridness:.3f}')
        dot_size = 10
        extent = [0, box_size[0], 0, box_size[1]]

        # Left panel: field mask. In-field spikes are red, out-of-field blue,
        # stimulus-locked orange.
        axs[0].imshow(
            fields.T.astype(bool), extent=extent,
            origin='lower', cmap=plt.cm.Greys, zorder=0)
        axs[0].scatter(
            spikes_x[in_field_indices], spikes_y[in_field_indices],
            s=dot_size, color='r', zorder=1)
        axs[0].scatter(
            spikes_x[~in_field_indices], spikes_y[~in_field_indices],
            s=dot_size, color='b', zorder=1)
        if stim_times is not None:
            axs[0].scatter(
                stim_spikes_x, stim_spikes_y,
                s=dot_size, color='orange', zorder=1)

        # Middle panel: rate map. Right panel: trajectory with all spikes.
        axs[1].imshow(rate_map.T, extent=extent, origin='lower')
        axs[2].plot(pos_x, pos_y, color='k', alpha=.2, zorder=0)
        axs[2].scatter(spikes_x, spikes_y, s=dot_size, zorder=1)

        # find_contours returns index coordinates where bin i is centred at i.
        # Shift by half a bin so the contours align with the imshow extent.
        dx = box_size[0] / fields.shape[0]
        dy = box_size[1] / fields.shape[1]
        for contour in measure.find_contours(fields, 0.0):
            for ax in axs[1:]:
                ax.plot(
                    (contour[:, 0] + 0.5) * dx, (contour[:, 1] + 0.5) * dy,
                    lw=4, color='y', zorder=3)
        for ax in axs:
            ax.axis('image')
            ax.set_xticks([])
            ax.set_yticks([])

    return fields, in_field_indices, stim_in_field_indices
def plot_stim_field_spikes(row, t1=0, t2=30, z1_base=0, z2_base=5, z1_stim=5, z2_stim=11,
                           colors=('k', 'r')):
    """Raster of stimulus-locked spikes next to their positions relative to the grid fields.

    All times are converted to milliseconds, so ``t1``/``t2`` (raster range)
    and the ``z*`` windows are in ms relative to each stimulus. Spikes in
    ``(t + z1_base, t + z2_base]`` are drawn in ``colors[0]`` and spikes in
    ``(t + z1_stim, t + z2_stim]`` in ``colors[1]``. They appear both in the
    raster (left) and at their interpolated positions over the field
    contours (right).

    Raises
    ------
    ValueError
        If the session has no stimulus times.
    """
    from scipy.stats import gaussian_kde

    spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times = get_data(row)
    if stim_times is None:
        raise ValueError(f'{row.action}: no stimulus times to align spikes to')
    # Work in milliseconds from here on.
    spikes = np.array(spikes) * 1000
    pos_t = np.array(pos_t) * 1000
    stim_times = np.array(stim_times) * 1000

    def collect(t, lo, hi):
        """Spike times in (t + lo, t + hi]."""
        start, stop = np.searchsorted(spikes, [t + lo, t + hi], side='right')
        return spikes[start:stop]

    base_abs, base_rel, base_ids = [], [], []
    stim_abs, stim_rel, stim_ids = [], [], []
    all_rel, all_ids = [], []
    for i, t in enumerate(stim_times):
        window = collect(t, z1_base, z2_base)
        base_abs.extend(window)
        base_rel.extend(window - t)
        base_ids.extend([i] * len(window))

        window = collect(t, z1_stim, z2_stim)
        stim_abs.extend(window)
        stim_rel.extend(window - t)
        stim_ids.extend([i] * len(window))

        window = collect(t, t1, t2)
        all_rel.extend(window - t)
        all_ids.extend([i] * len(window))

    # Positions of the windowed spikes (pos_t is in ms too).
    sx, sy = interp1d(pos_t, pos_x), interp1d(pos_t, pos_y)
    base_x, base_y = sx(base_abs), sy(base_abs)
    stim_x, stim_y = sx(stim_abs), sy(stim_abs)

    fig, axs = plt.subplots(1, 2)
    dot_size = 2

    # Left: raster of spikes within [t1, t2] ms of each pulse, windows highlighted.
    axs[0].scatter(all_rel, all_ids, s=dot_size, color='k', alpha=.5)
    axs[0].scatter(base_rel, base_ids, s=dot_size, color=colors[0], alpha=.8)
    axs[0].scatter(stim_rel, stim_ids, s=dot_size, color=colors[1], alpha=.8)
    # Overlay a KDE of the latencies, rescaled to span the trial axis.
    # gaussian_kde needs at least two distinct samples.
    if len(all_rel) > 1 and np.ptp(all_rel) > 0:
        times = np.arange(t1, t2, .1)
        pdf = gaussian_kde(all_rel, 0.1)(times)
        pdf = pdf - pdf.min()
        if pdf.max() > 0:
            axs[0].plot(times, pdf / pdf.max() * max(all_ids), 'k', lw=1)
    axs[0].set_xlim(t1, t2)
    if len(stim_times) > 0:
        axs[0].set_aspect((t2 - t1) / len(stim_times))

    # Right: windowed spike positions over the trajectory and field contours.
    axs[1].scatter(base_x, base_y, s=dot_size, color=colors[0], zorder=1, alpha=.8)
    axs[1].scatter(stim_x, stim_y, s=dot_size, color=colors[1], zorder=1, alpha=.8)
    axs[1].plot(pos_x, pos_y, color='k', alpha=.2, zorder=0)
    # find_contours uses bin-centred index coordinates; shift by half a bin.
    dx = box_size[0] / fields.shape[0]
    dy = box_size[1] / fields.shape[1]
    for contour in measure.find_contours(fields, 0.0):
        axs[1].plot(
            (contour[:, 0] + 0.5) * dx, (contour[:, 1] + 0.5) * dy,
            lw=1, color='k', zorder=3)
    axs[1].axis('image')
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    despine(axs[0])
    despine(axs[1], left=True, bottom=True)
# Visual check of the 18th baseline unit when sorted by gridness.
# NOTE(review): make_unique_unit_df does not create a ``gridness`` column;
# confirm it is added to these tables before this runs.
compute_field_spikes(baseline_i.sort_values('gridness', ascending=False).iloc[18], plot=True)

# Keep units (unit id + recording day) present both at baseline I and at 11 Hz.
baseline_ids = baseline_i.unit_day.unique()
baseline_ids
stimulated_11_sub = stimulated_11[stimulated_11.unit_day.isin(baseline_ids)]
len(stimulated_11_sub)
baseline_ids_11 = stimulated_11_sub.unit_day.unique()
baseline_i_sub = baseline_i[baseline_i.unit_day.isin(baseline_ids_11)]

# Plot each unit at baseline and at 11 Hz. Rows are paired by sorting both
# tables on unit_day, so verify that each pair really is the same unit/day.
iter_base = baseline_i_sub.sort_values('unit_day', ascending=False).itertuples()
iter_stim = stimulated_11_sub.sort_values('unit_day', ascending=False).itertuples()
for row_base, row_stim in zip(iter_base, iter_stim):
    if row_base.unit_day != row_stim.unit_day:
        raise ValueError(
            f'Mismatched pair: baseline {row_base.unit_day!r} vs 11 Hz {row_stim.unit_day!r}')
    fields, _, _ = compute_field_spikes(row_base, plot=True)
    # Pass surrogate_fields=fields to classify 11 Hz spikes against the baseline fields.
    compute_field_spikes(row_stim, plot=True)
plt.rc('axes', titlesize=12)
plt.rcParams.update({'font.size': 12, 'figure.figsize': (5, 2), 'figure.dpi': 150})

# Example raster/position figure for the 18th 11 Hz unit by gridness.
# Alternative palette: ['#1b9e77', '#d95f02']
example_unit = stimulated_11.sort_values('gridness', ascending=False).iloc[18]
plot_stim_field_spikes(example_unit, colors=['#2166ac', '#b2182b'])

fig = plt.gcf()
figname = 'stim_field_spikes_example'
for extension in ('png', 'svg'):
    fig.savefig(
        output_path / 'figures' / f'{figname}.{extension}',
        bbox_inches='tight', transparent=True)
# Two windows after each pulse, in seconds: 0-5 ms ("base") and 5-11 ms ("stim").
z1_stim = 5e-3
z2_stim = 11e-3
z1_base = 0
z2_base = 5e-3


def _stim_window_out_of_field(units, base_window=(z1_base, z2_base),
                              stim_window=(z1_stim, z2_stim)):
    """Per-unit fraction of stimulus-locked spikes outside grid fields, for both windows.

    NOTE(review): the stored values are ``1 - mean(in_field)``, i.e. the
    OUT-of-field fraction, despite the ``*_in_field`` column names. The names
    are kept because the figures and statistics use them.
    """
    records = []
    for row_stim in units.itertuples():
        _, _, base_in_field = compute_field_spikes(row_stim, z1=base_window[0], z2=base_window[1])
        _, _, stim_in_field = compute_field_spikes(row_stim, z1=stim_window[0], z2=stim_window[1])
        records.append({
            'base_in_field': 1 - base_in_field.mean(),
            'stim_in_field': 1 - stim_in_field.mean(),
            'entity': row_stim.action.split('-')[0],
            'unit_idnum': row_stim.unit_idnum,
        })
    # Explicit columns keep the schema even when ``units`` is empty.
    return pd.DataFrame(records, columns=['base_in_field', 'stim_in_field', 'entity', 'unit_idnum'])


results_stim_stim_11 = _stim_window_out_of_field(stimulated_11)
results_stim_stim_30 = _stim_window_out_of_field(stimulated_30)
results_stim_stim_all = pd.concat([results_stim_stim_11, results_stim_stim_30])

plt.rc('axes', titlesize=12)
plt.rcParams.update({'font.size': 12, 'figure.figsize': (1.7, 3), 'figure.dpi': 150})

# One violin plot per condition: (figure name, data, colors).
violin_panels = [
    ('stim_field_spikes_combined', results_stim_stim_all, None),
    ('stim_field_spikes_11', results_stim_stim_11, ['#1b9e77', '#d95f02']),
    ('stim_field_spikes_30', results_stim_stim_30, ['#7570b3', '#e7298a']),
]
for figname, panel_data, panel_colors in violin_panels:
    fig = plt.figure()
    violinplot(
        panel_data.base_in_field,
        panel_data.stim_in_field,
        colors=panel_colors,
        draw_significance=False
    )
    for extension in ('png', 'svg'):
        fig.savefig(
            output_path / 'figures' / f'{figname}.{extension}',
            bbox_inches='tight', transparent=True)

# Statistics table over the same three groupings.
results_all = {
    'Combined': results_stim_stim_all,
    '11 Hz': results_stim_stim_11,
    '30 Hz': results_stim_stim_30
}
stat, _ = make_statistics_table(results_all, ['base_in_field', 'stim_in_field'])
stat.to_latex(output_path / "statistics" / "statistics.tex")
stat.to_csv(output_path / "statistics" / "statistics.csv")
# Per-unit fraction of spikes inside grid fields: baseline I vs the same unit at 11 Hz.
# NOTE: this rebinds ``results`` (previously the paired table that
# make_unique_unit_df reads by default).
if len(baseline_i_sub) != len(stimulated_11_sub):
    raise ValueError(
        f'Cannot pair units: {len(baseline_i_sub)} baseline vs {len(stimulated_11_sub)} 11 Hz rows')
iter_base = baseline_i_sub.sort_values('unit_day', ascending=False).itertuples()
iter_stim = stimulated_11_sub.sort_values('unit_day', ascending=False).itertuples()
results = []
# Stimulus window in seconds; use z1 = 0, z2 = 5e-3 for the first 5 ms instead.
z1 = 5e-3
z2 = 11e-3
for row_base, row_stim in zip(iter_base, iter_stim):
    if row_base.unit_day != row_stim.unit_day:
        raise ValueError(
            f'Mismatched pair: baseline {row_base.unit_day!r} vs 11 Hz {row_stim.unit_day!r}')
    base_fields, base_in_field, _ = compute_field_spikes(row_base, z1=z1, z2=z2)
    stim_fields, stim_in_field, stim_stim_in_field = compute_field_spikes(row_stim, z1=z1, z2=z2)
    results.append({
        'base_in_field': base_in_field.mean(),
        'stim_in_field': stim_in_field.mean(),
        # np.mean works whether an array or a list is returned here.
        'stim_stim_in_field': np.mean(stim_stim_in_field),
        'gridness_base': row_base.gridness,
        'gridness_stim': row_stim.gridness,
        'action_base': row_base.action,
        'action_stim': row_stim.action
    })
results = pd.DataFrame(results)
results

plt.rc('axes', titlesize=12)
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': (3.5, 3),
    'figure.dpi': 150
})
fig, ax = plt.subplots(1, 1)
sc = ax.scatter(
    results.base_in_field, results.stim_in_field,
    c=results.gridness_base  # or c=results.gridness_stim
)
ax.plot([0, 1], [0, 1], 'k--')
# Values are means of boolean masks, i.e. fractions in [0, 1], not percentages.
ax.set_xlabel('Baseline fraction in field')
ax.set_ylabel('11 Hz fraction in field')
cb = plt.colorbar(mappable=sc, cax=None, ax=ax)
cb.ax.yaxis.set_ticks_position('right')
cb.set_label('Baseline gridness')
# Register the outputs and this notebook with the expipe project.
action = project.require_action("spikes-in-field")
# shutil.copytree replaces distutils.dir_util.copy_tree: distutils is deprecated
# (PEP 632) and removed in Python 3.12. dirs_exist_ok (Python >= 3.8) keeps the
# merge-into-existing-directory behaviour.
shutil.copytree(output_path, str(action.data_path()), dirs_exist_ok=True)
septum_mec.analysis.registration.store_notebook(action, "20_spikes_in_field.ipynb")