%load_ext autoreload
%autoreload 2
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import re
import shutil
import pandas as pd
import scipy.stats
import exdir
import expipe
from distutils.dir_util import copy_tree
import septum_mec
import spatial_maps as sp
import head_direction.head as head
import septum_mec.analysis.data_processing as dp
import septum_mec.analysis.registration
from septum_mec.analysis.plotting import violinplot, despine
from spatial_maps.fields import (
find_peaks, calculate_field_centers, separate_fields_by_laplace,
map_pass_to_unit_circle, calculate_field_centers, distance_to_edge_function,
compute_crossings, which_field)
from phase_precession import cl_corr
from spike_statistics.core import permutation_resampling
import matplotlib.mlab as mlab
import scipy.signal as ss
from scipy.interpolate import interp1d
from septum_mec.analysis.plotting import regplot
from skimage import measure
from tqdm.notebook import tqdm_notebook as tqdm
tqdm.pandas()
from scipy.stats import wilcoxon
# %matplotlib notebook
%matplotlib inline
# Locate the expipe project and its registered actions.
project_path = dp.project_path()
project = expipe.get_project(project_path)
actions = project.actions
# Output folders for this analysis (created if missing).
output_path = pathlib.Path("output") / "spikes-in-field"
(output_path / "statistics").mkdir(exist_ok=True, parents=True)
(output_path / "figures").mkdir(exist_ok=True, parents=True)
# Build one table of per-unit statistics joined with the session and unit
# identification tables.
statistics_action = actions['calculate-statistics']
identification_action = actions['identify-neurons']
sessions = pd.read_csv(identification_action.data_path('sessions'))
units = pd.read_csv(identification_action.data_path('units'))
session_units = pd.merge(sessions, units, on='action')
statistics_results = pd.read_csv(statistics_action.data_path('results'))
# Left merge on all shared columns: identified units are kept even when no
# statistics row matches them.
statistics = pd.merge(session_units, statistics_results, how='left')
statistics.head()  # notebook display only; no effect when run as a script
# unit_day = "<unit_idnum>_<second '-'-separated token of the action name>".
# NOTE(review): that token is presumably the recording date — confirm.
statistics['unit_day'] = statistics.apply(lambda x: str(x.unit_idnum) + '_' + x.action.split('-')[1], axis=1)
# Attach stimulus-response results (left merge on shared columns).
stim_response_action = actions['stimulus-response']
stim_response_results = pd.read_csv(stim_response_action.data_path('results'))
statistics = pd.merge(statistics, stim_response_results, how='left')
# Counts rows of the merged table (unit recordings), not unique units.
print('N cells:',statistics.shape[0])
# 95th-percentile shuffling thresholds per unit recording.
shuffling = actions['shuffling']
quantiles_95 = pd.read_csv(shuffling.data_path('quantiles_95'))
quantiles_95.head()  # notebook display only
action_columns = ['action', 'channel_group', 'unit_name']
# Inner merge (pandas default): recordings without shuffling thresholds are
# dropped. Threshold columns get the suffix "_threshold".
data = pd.merge(statistics, quantiles_95, on=action_columns, suffixes=("", "_threshold"))
# log10 ratio of in-field to out-of-field mean rate (inf/NaN if a rate is 0).
data['specificity'] = np.log10(data['in_field_mean_rate'] / data['out_field_mean_rate'])
data.head()  # notebook display only
data.groupby('stimulated').count()['action']  # notebook display: rows per stimulated flag
# Grid-cell criterion: gridness and information rate above their shuffled
# thresholds, gridness above an absolute floor of 0.2, and mean rate < 25.
query = (
'gridness > gridness_threshold and '
'information_rate > information_rate_threshold and '
'gridness > .2 and '
'average_rate < 25'
)
sessions_above_threshold = data.query(query)
print("Number of sessions above threshold", len(sessions_above_threshold))
print("Number of animals", len(sessions_above_threshold.groupby(['entity'])))
# Every recording (from `statistics`, i.e. all sessions) of a unit on a day on
# which it passed the grid-cell criterion in at least one session.
once_a_gridcell = statistics[statistics.unit_day.isin(sessions_above_threshold.unit_day.values)]
print("Number of gridcells", once_a_gridcell.unit_idnum.nunique())
print("Number of gridcell recordings", len(once_a_gridcell))
print("Number of animals", len(once_a_gridcell.groupby(['entity'])))
# Split by condition: baseline before 11 Hz / 30 Hz sessions and the
# corresponding medial-septum ("ms") stimulation sessions.
baseline_i = once_a_gridcell.query('baseline and Hz11')
stimulated_11 = once_a_gridcell.query('stimulated and frequency==11 and stim_location=="ms"')
baseline_ii = once_a_gridcell.query('baseline and Hz30')
stimulated_30 = once_a_gridcell.query('stimulated and frequency==30 and stim_location=="ms"')
print("Number of gridcells in baseline i sessions", len(baseline_i))
print("Number of gridcells in stimulated 11Hz ms sessions", len(stimulated_11))
print("Number of gridcells in baseline ii sessions", len(baseline_ii))
print("Number of gridcells in stimulated 30Hz ms sessions", len(stimulated_30))
baseline_ids = baseline_i.unit_day.unique()
baseline_ids  # notebook display only
# Restrict both frames to unit_days present in both baseline i and
# stimulated 11 Hz, so the two can be compared unit-by-unit.
stimulated_11_sub = stimulated_11[stimulated_11.unit_day.isin(baseline_ids)]
baseline_ids_11 = stimulated_11_sub.unit_day.unique()
baseline_i_sub = baseline_i[baseline_i.unit_day.isin(baseline_ids_11)]
# Analysis parameters.
# NOTE(review): max_speed, min_speed, smoothing_low, smoothing_high and
# speed_binsize are not referenced in this section — possibly leftovers.
max_speed = .5 # m/s only used for speed score
min_speed = 0.02 # m/s only used for speed score
position_sampling_rate = 100 # for interpolation
position_low_pass_frequency = 6 # for low pass filtering of position
box_size = [1.0, 1.0]  # box extent; also used as plot extent and by which_field
bin_size = 0.02  # spatial bin size; scales field contours from bins to box units
smoothing_low = 0.03
smoothing_high = 0.06
speed_binsize = 0.02
stim_mask = True
baseline_duration = 600
# Shared loader used by get_data() for spikes, rate maps, tracking and stim times.
data_loader = dp.Data(
position_sampling_rate=position_sampling_rate,
position_low_pass_frequency=position_low_pass_frequency,
box_size=box_size, bin_size=bin_size,
stim_mask=stim_mask, baseline_duration=baseline_duration
)
def find_grid_fields(rate_map, sigma=3, seed=2.5, min_area=9.0):
    """Segment grid fields that are detected by two independent methods.

    Candidate fields are labelled with
    ``spatial_maps.fields.separate_fields_by_dilation``. Labelled regions
    covering fewer than ``min_area`` bins are discarded. A remaining
    candidate is kept only if it contains at least one of the points
    returned by ``spatial_maps.separate_fields_by_distance`` (the Ismakov
    method).

    Parameters
    ----------
    rate_map : 2D array
        Rate map to segment.
    sigma, seed : float
        Passed through to ``separate_fields_by_dilation``.
    min_area : float, optional
        Minimum field size in bins. Smaller regions are set to 0.
        Defaults to 9.0, the previously hard-coded value.

    Returns
    -------
    fields : integer label array
        A copy of the dilation labels. 0 marks background and rejected
        fields; kept fields retain their original label.
    """
    from scipy import ndimage  # ndimage.measurements is deprecated

    # Candidate fields from the dilation-based segmentation (copied so the
    # returned labels can be edited in place).
    fields = sp.fields.separate_fields_by_dilation(
        rate_map, sigma=sigma, seed=seed).copy()

    # Drop small candidates: field_areas[label] is the bin count of that
    # label, so indexing it with `fields` gives each bin its field's area.
    labels = np.arange(fields.max() + 1)
    field_areas = np.asarray(ndimage.sum(np.ones_like(fields), fields, index=labels))
    fields[field_areas[fields] < min_area] = 0

    # Keep only candidates that contain a point found by the Ismakov method
    # (the second return value, a radius, is not needed here).
    ismakov_points, _ = sp.separate_fields_by_distance(rate_map)
    approved = {fields[tuple(point)] for point in ismakov_points}
    fields[~np.isin(fields, list(approved))] = 0
    return fields
def get_data(row, smoothing=0.04):
    """Load spikes, tracking, rate map, grid fields and stimulus times for one unit.

    Parameters
    ----------
    row : namedtuple or pandas Series
        Must provide ``action``, ``channel_group`` and ``unit_name``.
    smoothing : float, optional
        Smoothing passed to ``data_loader.rate_map``. Defaults to the
        previously hard-coded 0.04.

    Returns
    -------
    tuple
        ``(spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times)``.
        ``spikes`` is a numpy array restricted to the open interval
        ``(pos_t.min(), pos_t.max())``. ``fields`` comes from
        ``find_grid_fields(rate_map)``. ``stim_times`` is a numpy array, or
        None when the loader returns None.
    """
    spikes = data_loader.spike_train(row.action, row.channel_group, row.unit_name)
    rate_map = data_loader.rate_map(
        row.action, row.channel_group, row.unit_name, smoothing=smoothing)
    tracking = data_loader.tracking(row.action)
    pos_x, pos_y, pos_t = (tracking.get(key) for key in ('x', 'y', 't'))

    stim_times = data_loader.stim_times(row.action)
    if stim_times is not None:
        stim_times = np.array(stim_times)

    # Keep only spikes inside the tracking interval so callers can map them
    # onto positions with interp1d, which raises outside its range by default.
    spikes = np.array(spikes)
    spikes = spikes[(spikes > pos_t.min()) & (spikes < pos_t.max())]

    fields = find_grid_fields(rate_map)
    return spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times
def compute_field_spikes(row, plot=False, z1=5e-3, z2=11e-3, surrogate_fields=None):
    """Classify a unit's spikes as falling inside or outside its grid fields.

    Spike times are mapped to positions by linear interpolation of the
    tracking, then looked up in the field label array with ``which_field``.
    Spikes in the window ``(t + z1, t + z2]`` after every stimulus time ``t``
    are classified separately.

    Parameters
    ----------
    row : namedtuple or pandas Series
        Unit/session row passed to ``get_data``. When ``plot`` is True,
        ``action``, ``channel_group``, ``unit_idnum`` and ``gridness`` are
        also read for the title.
    plot : bool
        If True, draw a 1x3 figure:
        (0) the field mask with in-field (red), out-of-field (blue) and
        post-stimulus (orange) spikes;
        (1) the rate map;
        (2) the path with all spikes.
        Field contours are overlaid on panels 1 and 2.
    z1, z2 : float
        Start and end of the post-stimulus window, in the spike-time unit
        (presumably seconds, i.e. 5-11 ms by default — confirm).
    surrogate_fields : array, optional
        Field label array to use instead of the unit's own fields.

    Returns
    -------
    fields : array
        The field label array used for classification.
    in_field_indices : bool array
        One entry per spike returned by ``get_data``.
    stim_in_field_indices : bool array or list
        One entry per post-stimulus spike. An empty list when the session has
        no stimulus times.
    """
    spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times = get_data(row)
    if surrogate_fields is not None:
        fields = surrogate_fields

    # Position of the animal at arbitrary times, by linear interpolation.
    sx, sy = interp1d(pos_t, pos_x), interp1d(pos_t, pos_y)

    stim_in_field_indices = []
    if stim_times is not None:
        # Collect spikes in (t + z1, t + z2] after each stimulus.
        # NOTE: searchsorted assumes `spikes` is sorted in time.
        stim_spikes = []
        for t in stim_times:
            start, stop = np.searchsorted(spikes, [t + z1, t + z2], side='right')
            stim_spikes.extend(spikes[start:stop].tolist())
        stim_spikes_x = sx(stim_spikes)
        stim_spikes_y = sy(stim_spikes)
        stim_in_field_indices = which_field(
            stim_spikes_x, stim_spikes_y, fields, box_size).astype(bool)

    spikes_x = sx(spikes)
    spikes_y = sy(spikes)
    in_field_indices = which_field(spikes_x, spikes_y, fields, box_size).astype(bool)

    if plot:
        dot_size = 10
        _, axs = plt.subplots(1, 3, figsize=(16, 9))
        axs[1].set_title(f'{row.action} {row.channel_group} {row.unit_idnum}, G={row.gridness:.3f}')

        # Panel 0: field mask with in-field / out-of-field / stimulus spikes.
        axs[0].imshow(
            fields.T.astype(bool), extent=[0, box_size[0], 0, box_size[1]],
            origin='lower', cmap=plt.cm.Greys, zorder=0)
        axs[0].scatter(
            spikes_x[in_field_indices], spikes_y[in_field_indices],
            s=dot_size, color='r', zorder=1)
        axs[0].scatter(
            spikes_x[~in_field_indices], spikes_y[~in_field_indices],
            s=dot_size, color='b', zorder=1)
        if stim_times is not None:
            axs[0].scatter(
                stim_spikes_x, stim_spikes_y,
                s=dot_size, color='orange', zorder=1)

        # Panel 1: rate map. Panel 2: path with all spikes (same positions
        # as spikes_x/spikes_y; no need to interpolate again).
        axs[1].imshow(rate_map.T, extent=[0, box_size[0], 0, box_size[1]], origin='lower')
        axs[2].plot(pos_x, pos_y, color='k', alpha=.2, zorder=0)
        axs[2].scatter(spikes_x, spikes_y, s=dot_size, zorder=1)

        # Field outlines: find_contours returns bin coordinates, so scale by
        # bin_size to match the box extent used by imshow.
        contours = measure.find_contours(fields, 0.0)
        for ax in axs[1:]:
            for contour in contours:
                ax.plot(
                    contour[:, 0] * bin_size, contour[:, 1] * bin_size,
                    lw=4, color='y', zorder=3)
        for ax in axs:
            ax.axis('image')
            ax.set_xticks([])
            ax.set_yticks([])
    return fields, in_field_indices, stim_in_field_indices
def plot_stim_field_spikes(row, t1=0, t2=30, z1_base=0, z2_base=5, z1_stim=5, z2_stim=11, colors=('k', 'r')):
    """Plot stimulus-locked spikes as a raster and at their positions in the box.

    Spike, tracking and stimulus times are multiplied by 1000 before
    windowing, so all window arguments are in that scaled unit (ms if the
    loader returns seconds). Windows are half-open, ``(t + start, t + stop]``,
    relative to each stimulus time ``t``.

    Left panel: raster of every spike in ``(t + t1, t + t2]`` (black,
    alpha 0.5). Spikes in the "base" window ``(z1_base, z2_base]`` are
    drawn in ``colors[0]`` and those in the "stim" window
    ``(z1_stim, z2_stim]`` in ``colors[1]``. A Gaussian KDE of all raster
    spikes, rescaled to the raster height, is overlaid.

    Right panel: the animal's path, the base/stim-window spikes at their
    interpolated positions, and the grid-field contours.

    Parameters
    ----------
    row : namedtuple or pandas Series
        Unit/session row passed to ``get_data``.
    t1, t2 : float
        Raster window relative to each stimulus.
    z1_base, z2_base, z1_stim, z2_stim : float
        Highlighted windows relative to each stimulus.
    colors : sequence of two colours
        Colours for the base and stim windows. The default is a tuple to
        avoid a shared mutable default. The name shadows the module-level
        ``matplotlib.colors`` import inside this function.

    Raises
    ------
    ValueError
        If the session has no stimulus times.
    """
    from scipy.stats import gaussian_kde

    spikes, pos_x, pos_y, pos_t, rate_map, fields, stim_times = get_data(row)
    if stim_times is None:
        raise ValueError(f'{row.action} has no stimulus times')
    spikes = np.array(spikes) * 1000
    pos_t = np.array(pos_t) * 1000
    stim_times = np.array(stim_times) * 1000

    sx, sy = interp1d(pos_t, pos_x), interp1d(pos_t, pos_y)

    def _window(t, start, stop):
        # Spikes in (t + start, t + stop]; searchsorted assumes sorted spikes.
        lo, hi = np.searchsorted(spikes, [t + start, t + stop], side='right')
        return spikes[lo:hi]

    stim_spikes_base, stim_spikes_base_plot, stim_ids_base = [], [], []
    stim_spikes_stim, stim_spikes_stim_plot, stim_ids_stim = [], [], []
    stim_spikes_all, stim_ids_all = [], []
    for i, t in enumerate(stim_times):
        base = _window(t, z1_base, z2_base)
        stim_spikes_base.extend(base.tolist())
        stim_spikes_base_plot.extend(base - t)
        stim_ids_base.extend([i] * len(base))

        stim = _window(t, z1_stim, z2_stim)
        stim_spikes_stim.extend(stim.tolist())
        stim_spikes_stim_plot.extend(stim - t)
        stim_ids_stim.extend([i] * len(stim))

        window_all = _window(t, t1, t2)
        stim_spikes_all.extend((window_all - t).tolist())
        stim_ids_all.extend([i] * len(window_all))

    dot_size = 2
    _, axs = plt.subplots(1, 2)

    # Left panel: stimulus-aligned raster.
    axs[0].scatter(stim_spikes_all, stim_ids_all, s=dot_size, color='k', alpha=.5)
    axs[0].scatter(stim_spikes_base_plot, stim_ids_base, s=dot_size, color=colors[0], alpha=.8)
    axs[0].scatter(stim_spikes_stim_plot, stim_ids_stim, s=dot_size, color=colors[1], alpha=.8)

    # Smoothed spike density, rescaled to 0..(highest stimulus index with a
    # spike) so it overlays the raster. gaussian_kde needs at least two
    # distinct samples, and a flat density cannot be rescaled.
    if np.unique(stim_spikes_all).size > 1:
        times = np.arange(t1, t2, .1)
        pdf = gaussian_kde(stim_spikes_all, 0.1)(times)
        pdf = pdf - pdf.min()
        if pdf.max() > 0:
            pdf = pdf / pdf.max() * max(stim_ids_all)
            axs[0].plot(times, pdf, 'k', lw=1)
    axs[0].set_xlim(t1, t2)

    # Right panel: window spikes at their interpolated positions plus path
    # and field contours (contours are in bins; scale by bin_size).
    axs[1].scatter(
        sx(stim_spikes_base), sy(stim_spikes_base),
        s=dot_size, color=colors[0], zorder=1, alpha=.8)
    axs[1].scatter(
        sx(stim_spikes_stim), sy(stim_spikes_stim),
        s=dot_size, color=colors[1], zorder=1, alpha=.8)
    axs[1].plot(pos_x, pos_y, color='k', alpha=.2, zorder=0)
    for contour in measure.find_contours(fields, 0.0):
        axs[1].plot(
            contour[:, 0] * bin_size, contour[:, 1] * bin_size,
            lw=1, color='k', zorder=3)

    # Aspect ratio depends on the number of stimuli; skip it when there are none.
    if len(stim_times) > 0:
        axs[0].set_aspect((t2 - t1) / len(stim_times))
    axs[1].axis('image')
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    despine(axs[0])
    despine(axs[1], left=True, bottom=True)
# Inspect one high-gridness baseline recording (19th-highest gridness).
compute_field_spikes(baseline_i.sort_values('gridness', ascending=False).iloc[18], plot=True)

# Plot each baseline-i recording next to the 11 Hz stimulated recording of
# the same unit on the same day. Rows are paired explicitly per unit_day.
# Zipping two independently sorted frames would misalign every later pair as
# soon as one unit_day has a different number of rows in the two frames.
for unit_day in sorted(baseline_i_sub.unit_day.unique(), reverse=True):
    rows_base = baseline_i_sub[baseline_i_sub.unit_day == unit_day]
    rows_stim = stimulated_11_sub[stimulated_11_sub.unit_day == unit_day]
    for row_base, row_stim in zip(rows_base.itertuples(), rows_stim.itertuples()):
        # `fields` is kept so the stimulated session can optionally be scored
        # against the baseline fields by passing surrogate_fields=fields.
        fields, _, _ = compute_field_spikes(row_base, plot=True)
        compute_field_spikes(row_stim, plot=True)  # , surrogate_fields=fields)