{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO: Pandarallel will run on 8 workers.\n", "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n" ] } ], "source": [ "import exdir\n", "import matplotlib.pyplot as plt\n", "import matplotlib.mlab as mlab\n", "from scipy.interpolate import interp1d\n", "import os\n", "import expipe\n", "import pathlib\n", "import numpy as np\n", "import pnnmec\n", "import spatial_maps as sp\n", "import head_direction.head as head\n", "import re\n", "import joblib\n", "import multiprocessing\n", "from distutils.dir_util import copy_tree\n", "import copy\n", "import pandas as pd\n", "from scipy.io import loadmat\n", "from spatial_maps.fields import (\n", " find_peaks, calculate_field_centers, separate_fields_by_laplace, \n", " map_pass_to_unit_circle, calculate_field_centers, distance_to_edge_function, \n", " which_field, compute_crossings)\n", "from phase_precession import cl_corr\n", "import pnnmec.spikes\n", "from spike_statistics.core import permutation_resampling\n", "import scipy\n", "import scipy.signal as ss\n", "from scipy.interpolate import interp1d\n", "from skimage import measure\n", "from tqdm.notebook import tqdm_notebook as tqdm\n", "tqdm.pandas()\n", "import seaborn as sns\n", "from pandarallel import pandarallel\n", "pandarallel.initialize(progress_bar=False)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "sample_rate = 250\n", "\n", "max_speed = .5 # m/s only used for speed score\n", "min_speed = 0.02 # m/s only used for speed score\n", "\n", "box_size = [1.0, 1.0]\n", "bin_size = 0.02\n", "\n", "speed_binsize = 0.02\n", "\n", "smoothing = 0.04" ] }, { "cell_type": "code", "execution_count": 4, 
"metadata": {}, "outputs": [], "source": [ "# %matplotlib notebook\n", "%matplotlib inline" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
 "output_path = pathlib.Path(\"output\") / \"phase-precession\"\n",
 "(output_path / \"statistics\").mkdir(exist_ok=True, parents=True)\n",
 "(output_path / \"figures\").mkdir(exist_ok=True, parents=True)" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data_path = pathlib.Path('sargolini2006/all_data/')" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
 "# Load every .mat file, grouped by recording (\"action\") and file type.\n",
 "# File stems must split into exactly <action>_<ftype>; EGF files are skipped.\n",
 "data = {}\n",
 "for fname in data_path.iterdir():\n",
 "    if not fname.is_file():\n",
 "        continue\n",
 "    try:\n",
 "        action, ftype = fname.stem.split('_')\n",
 "    except Exception:\n",
 "        print(fname)\n",
 "        raise\n",
 "    if ftype == 'EGF':\n",
 "        continue\n",
 "    if action not in data:\n",
 "        data[action] = {}\n",
 "    data[action][ftype] = loadmat(fname)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
 "def rm_nans(*args):\n",
 "    \"\"\"\n",
 "    Remove NaNs from all corresponding arrays.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    args : arrays, lists or quantities which should have removed nans in\n",
 "        all the same indices (NaN positions are taken along the first axis)\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    out : list with one array per arg, with the union of NaN positions removed.\n",
 "        np.delete is called without an axis, so each array is flattened first;\n",
 "        this is meant for column vectors such as the ones loadmat returns.\n",
 "    \"\"\"\n",
 "    nan_indices = set()\n",
 "    for arg in args:\n",
 "        nan_indices.update(np.where(np.isnan(arg))[0].tolist())\n",
 "    # Force an integer dtype: np.unique([]) is float64 and np.delete rejects\n",
 "    # non-integer index arrays (DeprecationWarning, an error in newer numpy).\n",
 "    nan_indices = np.array(sorted(nan_indices), dtype=int)\n",
 "    return [np.delete(arg, nan_indices) for arg in args]" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:18: DeprecationWarning: using a non-integer array as obj in delete will result in an error in the future\n" ] } ], "source": [ "analog, cells = [], []\n", "for k, v 
in data.items():\n",
 "    if 'POS' not in v:\n",
 "        continue\n",
 "    # skip recordings without any LFP channel\n",
 "    if set(['EEG', 'EG2']).intersection(set(v.keys())) == set():\n",
 "        continue\n",
 "    x, y, t = rm_nans(v['POS']['posx'], v['POS']['posy'], v['POS']['post'])\n",
 "    analog.append({\n",
 "        'action': k,\n",
 "        'eeg': None if 'EEG' not in v else v['EEG']['EEG'],\n",
 "        'eeg2': None if 'EG2' not in v else v['EG2']['EEG'],\n",
 "        'x': x,\n",
 "        'y': y,\n",
 "        't': t\n",
 "    })\n",
 "    # unit keys start with 'T'; kk[1] is taken as the channel and kk[3:] as the unit\n",
 "    for kk, vv in v.items():\n",
 "        if kk.startswith('T'):\n",
 "            cells.append({\n",
 "                'action': k,\n",
 "                'channel': kk[1],\n",
 "                'unit': kk[3:],\n",
 "                'spikes': vv['cellTS']\n",
 "            })\n",
 "analog = pd.DataFrame(analog)\n",
 "cells = pd.DataFrame(cells)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Analysis and plotting helpers" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
 "from scipy.signal import butter, filtfilt\n",
 "\n",
 "def butter_bandpass(lowcut, highcut, fs, order=5):\n",
 "    \"\"\"Return (b, a) coefficients of a Butterworth band-pass filter.\n",
 "\n",
 "    lowcut, highcut and fs must be in the same units (e.g. Hz).\n",
 "    \"\"\"\n",
 "    nyq = 0.5 * fs\n",
 "    low = lowcut / nyq\n",
 "    high = highcut / nyq\n",
 "    b, a = butter(order, [low, high], btype='band')\n",
 "    return b, a\n",
 "\n",
 "\n",
 "def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):\n",
 "    \"\"\"Band-pass filter data forward and backward (filtfilt, zero phase).\"\"\"\n",
 "    b, a = butter_bandpass(lowcut, highcut, fs, order=order)\n",
 "    y = filtfilt(b, a, data)\n",
 "    return y" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
 "def compute_spike_phase(lfp, times):\n",
 "    \"\"\"Return a callable mapping time to the instantaneous phase of lfp.\n",
 "\n",
 "    The phase is the angle of the Hilbert analytic signal of lfp (which should\n",
 "    already be band-pass filtered), sampled at times. The unwrapped phase is\n",
 "    interpolated and the result is wrapped back to [-pi, pi]. Interpolating the\n",
 "    wrapped phase directly gives bogus values (near 0) for query times that fall\n",
 "    between a sample near +pi and the next sample near -pi.\n",
 "    Query times outside [times[0], times[-1]] raise ValueError (interp1d).\n",
 "    \"\"\"\n",
 "    analytic_signal = ss.hilbert(lfp)\n",
 "    unwrapped_phase = np.unwrap(np.angle(analytic_signal))\n",
 "    phase_interp = interp1d(times, unwrapped_phase)\n",
 "\n",
 "    def spike_phase(t):\n",
 "        return np.angle(np.exp(1j * phase_interp(t)))\n",
 "\n",
 "    return spike_phase" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
 "def find_grid_fields(rate_map, sigma=3, seed=2.5, min_area=9):\n",
 "    \"\"\"Label fields found both by dilation segmentation and the Ismakov method.\n",
 "\n",
 "    Fields are segmented with sp.fields.separate_fields_by_dilation, and fields\n",
 "    covering fewer than min_area bins are discarded. A remaining field is kept\n",
 "    only if it contains one of the points returned by\n",
 "    sp.separate_fields_by_distance.\n",
 "\n",
 "    Returns a label array shaped like rate_map in which 0 marks background.\n",
 "    Rejected fields are zeroed without relabelling, so the surviving labels\n",
 "    need not be contiguous and label 1 may be absent.\n",
 "    \"\"\"\n",
 "    fields = sp.fields.separate_fields_by_dilation(rate_map, sigma=sigma, seed=seed).copy()\n",
 "\n",
 "    # look up the area (in bins) of each bin's field and drop the small fields\n",
 "    fields_areas = scipy.ndimage.sum(\n",
 "        np.ones_like(fields), fields, index=np.arange(fields.max() + 1))\n",
 "    fields[fields_areas[fields] < min_area] = 0\n",
 "\n",
 "    # keep only fields that contain a point found by the Ismakov method\n",
 "    ismakov_points, _ = sp.separate_fields_by_distance(rate_map)\n",
 "    approved_fields = {fields[tuple(point)] for point in ismakov_points}\n",
 "\n",
 "    for field_id in np.arange(1, fields.max() + 1):\n",
 "        if field_id not in approved_fields:\n",
 "            fields[fields == field_id] = 0\n",
 "\n",
 "    return fields" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
 "def normalize(a):\n",
 "    \"\"\"Linearly rescale a so that its minimum is 0 and its maximum is 1.\"\"\"\n",
 "    _a = a - a.min()\n",
 "    return _a / _a.max()" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
 "def distance(x, y):\n",
 "    \"\"\"Cumulative path length along the trajectory (x, y).\n",
 "\n",
 "    The first element is 0 (distance from the first point to itself), so the\n",
 "    result has the same length as x.\n",
 "    \"\"\"\n",
 "    steps = np.hypot(np.diff(x), np.diff(y))\n",
 "    return np.concatenate(([0], np.cumsum(steps)))" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
 "def model(x, slope, phi0):\n",
 "    \"\"\"Linear phase model 2*pi*slope*x + phi0 used to draw the regression line.\"\"\"\n",
 "    return 2 * np.pi * slope * x + phi0" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
 "# Channel -> suffix of the analog LFP column: '' selects 'eeg', '2' selects 'eeg2'.\n",
 "# NOTE(review): i < 4 sends channel 4 to 'eeg2'; confirm whether i <= 4 was intended.\n",
 "channel_to_eeg = {str(i): '' if i < 4 else '2' for i in reversed(range(1,9))}" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'8': '2', '7': '2', '6': '2', '5': '2', '4': '2', '3': '', '2': '', '1': ''}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "channel_to_eeg" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], 
"source": [
 "def compute_statistics(row):\n",
 "    \"\"\"Compute rate, spatial and burst statistics for one unit.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    row : pandas.Series\n",
 "        Row of `cells` with fields action, channel, unit and spikes.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    dict or None\n",
 "        Statistics keyed by name, together with action, channel and unit so\n",
 "        they can be merged back into `cells`. None when tracking has NaNs.\n",
 "    \"\"\"\n",
 "    a = analog.query(f'action==\"{row.action}\"')\n",
 "    get_values = lambda x: a.get(x).values[0]\n",
 "    pos_x, pos_y, pos_t = map(get_values, ['x', 'y', 't'])\n",
 "    box_size_, bin_size_ = sp.maps._adjust_bin_size(box_size=box_size, bin_size=bin_size)\n",
 "    xbins, ybins = sp.maps._make_bins(box_size_, bin_size_)\n",
 "\n",
 "    pos_x, pos_y, pos_t = pos_x.ravel(), pos_y.ravel(), pos_t.ravel()\n",
 "\n",
 "    if sum(np.isnan(np.concatenate([pos_x, pos_y]))) > 0:\n",
 "        print(\n",
 "            'Nan in position',\n",
 "            f'nnanx = {sum(np.isnan(pos_x))}',\n",
 "            f'nnany = {sum(np.isnan(pos_y))}',\n",
 "            f'shape = {pos_x.shape}',\n",
 "            row.action, row.channel, row.unit)\n",
 "        return\n",
 "\n",
 "    pos_x, pos_y = normalize(pos_x), normalize(pos_y)\n",
 "\n",
 "    occupancy_map = sp.maps._occupancy_map(pos_x, pos_y, pos_t, xbins, ybins)\n",
 "    occupancy_map = sp.maps.smooth_map(occupancy_map, bin_size=bin_size_, smoothing=smoothing)\n",
 "\n",
 "    spikes = row.spikes.ravel()\n",
 "\n",
 "    spike_map = sp.maps._spike_map(pos_x, pos_y, pos_t, spikes, xbins, ybins)\n",
 "    spike_map = sp.maps.smooth_map(spike_map, bin_size=bin_size_, smoothing=smoothing)\n",
 "\n",
 "    rate_map = spike_map / occupancy_map\n",
 "\n",
 "    fields = find_grid_fields(rate_map, sigma=3, seed=2.5)\n",
 "    field_ids = np.unique(fields[fields != 0])\n",
 "\n",
 "    prob_dist = sp.stats.prob_dist(pos_x, pos_y, bins=(xbins, ybins))\n",
 "\n",
 "    average_rate = len(spikes) / (pos_t.max() - pos_t.min())\n",
 "\n",
 "    max_rate = rate_map.max()\n",
 "\n",
 "    out_field_mean_rate = rate_map[fields == 0].mean()\n",
 "    if field_ids.size > 0:\n",
 "        in_field_mean_rate = rate_map[fields != 0].mean()\n",
 "        # find_grid_fields zeroes rejected fields without relabelling, so label 1\n",
 "        # is often gone, which produced 'Mean of empty slice' NaNs here. Use the\n",
 "        # surviving field with the highest mean rate instead.\n",
 "        # NOTE(review): assumes label 1 was meant to be the highest-rate field.\n",
 "        max_field_mean_rate = np.nanmax([rate_map[fields == i].mean() for i in field_ids])\n",
 "    else:\n",
 "        in_field_mean_rate = np.nan\n",
 "        max_field_mean_rate = np.nan\n",
 "\n",
 "    interspike_interval = np.diff(spikes)\n",
 "    interspike_interval_cv = interspike_interval.std() / interspike_interval.mean()\n",
 "\n",
 "    autocorrelogram = sp.autocorrelation(rate_map)\n",
 "    peaks = sp.fields.find_peaks(autocorrelogram)\n",
 "    real_peaks = peaks * bin_size\n",
 "    autocorrelogram_box_size = np.array(box_size) * autocorrelogram.shape[0] / rate_map.shape[0]\n",
 "    spacing, orientation = sp.spacing_and_orientation(real_peaks, autocorrelogram_box_size)\n",
 "    orientation *= 180 / np.pi\n",
 "\n",
 "    selectivity = sp.stats.selectivity(rate_map, prob_dist)\n",
 "    sparsity = sp.stats.sparsity(rate_map, prob_dist)\n",
 "    gridness = sp.gridness(rate_map)\n",
 "    information_rate = sp.stats.information_rate(rate_map, prob_dist)\n",
 "\n",
 "    single_spikes, bursts, bursty_spikes = pnnmec.spikes.find_bursts(spikes, threshold=0.01)\n",
 "    n_single, n_bursts, n_bursty = np.sum(single_spikes), np.sum(bursts), np.sum(bursty_spikes)\n",
 "    burst_event_ratio = n_bursts / (n_single + n_bursts)\n",
 "    bursty_spike_ratio = n_bursty / (n_bursty + n_single)\n",
 "    # guard cells without bursts (previously 0 / 0 with a RuntimeWarning)\n",
 "    mean_spikes_per_burst = n_bursty / n_bursts if n_bursts > 0 else np.nan\n",
 "    results = {\n",
 "        'action': row.action,\n",
 "        'channel': row.channel,\n",
 "        'unit': row.unit,\n",
 "        'average_rate': average_rate,\n",
 "        'out_field_mean_rate': out_field_mean_rate,\n",
 "        'in_field_mean_rate': in_field_mean_rate,\n",
 "        'max_field_mean_rate': max_field_mean_rate,\n",
 "        'max_rate': max_rate,\n",
 "        'sparsity': sparsity,\n",
 "        'selectivity': selectivity,\n",
 "        'interspike_interval_cv': interspike_interval_cv,\n",
 "        'burst_event_ratio': burst_event_ratio,\n",
 "        'bursty_spike_ratio': bursty_spike_ratio,\n",
 "        'mean_spikes_per_burst': mean_spikes_per_burst,\n",
 "        'gridness': gridness,\n",
 "        'information_rate': information_rate,\n",
 "        'spacing': spacing,\n",
 "        'orientation': orientation\n",
 "    }\n",
 "    return results" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in 
double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / 
rcount)\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:45: RuntimeWarning: Mean of empty slice.\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * 
np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:68: RuntimeWarning: invalid value encountered in long_scalars\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " 
return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/ipykernel_launcher.py:68: RuntimeWarning: invalid value encountered in long_scalars\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: divide by zero encountered in log2\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n", "/home/mikkel/apps/expipe-project/spatial-maps/spatial_maps/stats.py:13: RuntimeWarning: invalid value encountered in multiply\n", " return (np.nansum(np.ravel(tmp_rate_map * np.log2(tmp_rate_map/avg_rate) *\n" ] } ], "source": [ "cells_stats = cells.parallel_apply(compute_statistics, axis=1, result_type='expand')" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
 "# Drop statistics merged by a previous run first, so re-running this cell does\n",
 "# not create suffixed (_x/_y) duplicate columns such as gridness_x.\n",
 "stat_columns = cells_stats.columns.difference(['action', 'channel', 'unit'])\n",
 "cells = (\n",
 "    cells\n",
 "    .drop(columns=stat_columns, errors='ignore')\n",
 "    .merge(cells_stats, on=['action', 'channel', 'unit'])\n",
 ")" ] }, { "cell_type": "code", 
"execution_count": 21, "metadata": {}, "outputs": [], "source": [
 "def plot_spikes_and_rate_map(row, rate_map, fields, pos_x, pos_y, sy, sx, spikes, box_size, bin_size, dot_size, flim, output_path, save, axs):\n",
 "    \"\"\"Draw the rate map and the trajectory with spikes on the lower panels of a 2x2 grid.\n",
 "\n",
 "    Field contours are overlaid on every panel except the first. When save is\n",
 "    true, the figure owning axs is written as PNG and SVG to output_path/'figures'.\n",
 "    \"\"\"\n",
 "    contours = measure.find_contours(fields, 0.8)\n",
 "\n",
 "    # Display the image and plot all contours found\n",
 "    axs[1][0].imshow(rate_map.T, extent=[0, box_size[0], 0, box_size[1]], origin='lower')\n",
 "    axs[1][1].plot(pos_x, pos_y, color='k', alpha=.2, zorder=1000)\n",
 "    axs[1][1].scatter(sx(spikes), sy(spikes), s=dot_size, zorder=10001)\n",
 "\n",
 "    for ax in axs.ravel()[1:]:\n",
 "        for contour in contours:\n",
 "            ax.plot(contour[:, 0] * bin_size, contour[:, 1] * bin_size, linewidth=2)\n",
 "\n",
 "    for ax in axs.ravel()[1:]:\n",
 "        ax.axis('image')\n",
 "        ax.set_xticks([])\n",
 "        ax.set_yticks([])\n",
 "\n",
 "    if save:\n",
 "        # No figure is passed in, so take it from the axes instead of relying on\n",
 "        # an undefined global `fig`. Rows carry `unit`, not `unit_id`.\n",
 "        fig = axs.ravel()[0].figure\n",
 "        figname = f'{row.action}_{row.channel}_{row.unit}_f{flim[0]}-{flim[1]}'\n",
 "        for ext in ('png', 'svg'):\n",
 "            fig.savefig(\n",
 "                output_path / 'figures' / f'{figname}.{ext}',\n",
 "                bbox_inches='tight', transparent=True)\n",
 "\n",
 "def plot_spike_phase(spike_dist, spike_phase, dot_size, slope, phi0, circ_lin_corr, pval, RR, plot_regression_line, ax):\n",
 "    \"\"\"Scatter spike phase against distance, repeated one cycle up (+2*pi).\n",
 "\n",
 "    When plot_regression_line is true, the fitted line model(spike_dist, slope, phi0)\n",
 "    is drawn with the correlation, p-value and R in its label.\n",
 "    \"\"\"\n",
 "    p = ax.scatter(spike_dist, spike_phase, s=dot_size)\n",
 "    ax.scatter(\n",
 "        spike_dist, spike_phase + 2 * np.pi,\n",
 "        s=dot_size, color=p.get_facecolor()[0])\n",
 "    ax.set_yticks([-np.pi, np.pi, 3*np.pi])\n",
 "    ax.set_yticklabels([r'$-\\pi$', r'$\\pi$', r'$3\\pi$'])\n",
 "    if plot_regression_line:\n",
 "        line_fit = model(spike_dist, slope, phi0)\n",
 "        ax.plot(spike_dist, line_fit, lw=2, label=\n",
 "            f'corr = {circ_lin_corr:.3f}, '\n",
 "            f'pvalue = {pval:.3f}, '\n",
 "            f'R = {RR:.3f}')" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
 "def compute_rate_crossings(spikes, times, threshold=0):\n",
 "    \"\"\"Find contiguous periods where the instantaneous firing rate is high.\n",
 "\n",
 "    The rate is estimated with a 100 ms Gaussian kernel, sampled at the smallest\n",
 "    interval in times. A sample counts as supra-threshold when the rate exceeds\n",
 "    mean + threshold * std.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    rate : array, the instantaneous rate (magnitude only)\n",
 "    indices : array of int, indices into rate above the threshold\n",
 "    enter, exit : arrays of int, first and last (inclusive) index into rate\n",
 "        of each contiguous supra-threshold run; equal length, possibly empty.\n",
 "    NOTE(review): the indices refer to the rate array; callers apply them to the\n",
 "    tracking arrays, which is only valid if both share the same time base.\n",
 "    \"\"\"\n",
 "    from elephant.statistics import instantaneous_rate\n",
 "    from elephant.kernels import GaussianKernel\n",
 "    import quantities as pq\n",
 "    import neo\n",
 "    spikes = neo.SpikeTrain(spikes, t_stop=times[-1], units='s')\n",
 "\n",
 "    kernel = GaussianKernel(100 * pq.ms)\n",
 "\n",
 "    rate = instantaneous_rate(spikes, np.diff(times).min() * pq.s, kernel=kernel)\n",
 "\n",
 "    rate = rate.magnitude\n",
 "    mean = np.mean(rate)\n",
 "    std = np.std(rate)\n",
 "    indices = np.where(rate > mean + std * threshold)[0]\n",
 "\n",
 "    # Split the supra-threshold indices into contiguous runs. Deriving exits from\n",
 "    # the next run's enter (as before) dropped the last run, and missed a\n",
 "    # first run starting at index 0 or 1.\n",
 "    if len(indices) == 0:\n",
 "        return rate, indices, indices, indices\n",
 "    breaks = np.where(np.diff(indices) > 1)[0]\n",
 "    enters = indices[np.concatenate(([0], breaks + 1))]\n",
 "    exits = indices[np.concatenate((breaks, [len(indices) - 1]))]\n",
 "    return rate, indices, enters, exits\n",
 "\n",
 "\n",
 "def plot_map_spikes(rate_map, fields, spikes, x, y, t, box_size, dot_size=1, axs=None):\n",
 "    \"\"\"Plot the rate map (axs[0]) and the trajectory with spikes (axs[1]), both with field contours.\n",
 "\n",
 "    Creates a 1x2 figure when axs is None. Returns the axes.\n",
 "    \"\"\"\n",
 "    if axs is None:\n",
 "        fig, axs = plt.subplots(1, 2)\n",
 "    contours = measure.find_contours(fields, 0.8)\n",
 "    sx, sy = interp1d(t, x), interp1d(t, y)\n",
 "    # Display the image and plot all contours found\n",
 "    axs[0].imshow(rate_map.T, extent=[0, box_size[0], 0, box_size[1]], origin='lower')\n",
 "    axs[1].plot(x, y, color='k', alpha=.2, zorder=1000)\n",
 "    axs[1].scatter(sx(spikes), sy(spikes), s=dot_size, zorder=10001)\n",
 "\n",
 "    for ax in axs:\n",
 "        for contour in contours:\n",
 "            ax.plot(contour[:, 0] * bin_size, contour[:, 1] * bin_size, linewidth=2)\n",
 "\n",
 "    for ax in axs:\n",
 "        ax.axis('image')\n",
 "        ax.set_xticks([])\n",
 "        ax.set_yticks([])\n",
 "    return axs" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [
 "def compute_data(row, flim=[6,10]):\n",
 "    \"\"\"Prepare tracking, spikes, rate map, fields and an LFP phase function for one unit.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    (spike_phase_func, spikes, pos_x, pos_y, pos_t, rate_map, fields), or a list\n",
 "    of seven Nones when neither LFP channel is available. spike_phase_func maps\n",
 "    time to the phase of the LFP band-passed to flim.\n",
 "    \"\"\"\n",
 "    a = analog.query(f'action==\"{row.action}\"')\n",
 "    lfp_key = f'eeg{channel_to_eeg[row.channel]}'\n",
 "    lfp = a.get(lfp_key).values[0]\n",
 "    if lfp is None:\n",
 "        print('Wrong hemisphere?', row.action, row.channel, row.unit)\n",
 "        other_key = 'eeg2' if lfp_key == 'eeg' else 'eeg'\n",
 "        lfp = a.get(other_key).values[0]\n",
 "        if lfp is not None:\n",
 "            print('Warning: using lfp from other hemisphere')\n",
 "        else:\n",
 "            return [None] * 7\n",
 "    lfp = lfp.ravel()\n",
 "    get_values = lambda x: a.get(x).values[0]\n",
 "    pos_x, pos_y, pos_t = map(get_values, ['x', 'y', 't'])\n",
 "\n",
 "    pos_x, pos_y, pos_t = pos_x.ravel(), pos_y.ravel(), pos_t.ravel()\n",
 "    pos_x, pos_y = normalize(pos_x), normalize(pos_y)\n",
 "\n",
 "    box_size_, bin_size_ = sp.maps._adjust_bin_size(box_size=box_size, bin_size=bin_size)\n",
 "    xbins, ybins = sp.maps._make_bins(box_size_, bin_size_)\n",
 "    occupancy_map = sp.maps._occupancy_map(pos_x, pos_y, pos_t, xbins, ybins)\n",
 "    occupancy_map = sp.maps.smooth_map(occupancy_map, bin_size=bin_size_, smoothing=smoothing)\n",
 "\n",
 "    # keep only spikes inside the tracked time window\n",
 "    spikes = row.spikes.ravel()\n",
 "    spikes = spikes[(spikes >= pos_t[0]) & (spikes <= pos_t[-1])]\n",
 "\n",
 "    spike_map = sp.maps._spike_map(pos_x, pos_y, pos_t, spikes, xbins, ybins)\n",
 "    spike_map = sp.maps.smooth_map(spike_map, bin_size=bin_size_, smoothing=smoothing)\n",
 "\n",
 "    rate_map = spike_map / occupancy_map\n",
 "\n",
 "    filtered_lfp = butter_bandpass_filter(\n",
 "        lfp, *flim, fs=sample_rate, order=3)\n",
 "\n",
 "    # NOTE(review): assumes the LFP starts at t = 0 and is sampled at sample_rate\n",
 "    times = np.arange(len(lfp)) / sample_rate\n",
 "    spike_phase_func = compute_spike_phase(filtered_lfp, times)\n",
 "\n",
 "    fields = find_grid_fields(rate_map, sigma=3, seed=2.5)\n",
 "\n",
 "    return spike_phase_func, spikes, pos_x, pos_y, pos_t, rate_map, fields" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
 "def compute_phase_precession_runs(row, flim=[6, 10], field_num=None, slope_limit_dist=[-1,1], slope_limit_dur=[-1,1], norm=False, store_runs=False,\n",
 "                                  plot=False, plot_regression_line=True, save=False, dot_size=1, crossing_type='field'):\n",
 "    '''\n",
 "    Fit circular-linear regressions of spike phase against distance and against\n",
 "    time for every pass through a field.\n",
 "\n",
 "    row : pandas.Series\n",
 "        row of `cells` (action, channel, unit, spikes)\n",
 "    flim : [flow, fhigh]\n",
 "        band-pass limits for the LFP used to compute spike phase\n",
 "    field_num : int\n",
 "        restrict the analysis to this field label\n",
 "    slope_limit_dist, slope_limit_dur : [low, high]\n",
 "        slope search intervals passed to cl_corr for the distance / time fits\n",
 "    norm : bool\n",
 "        normalize the distance travelled in each pass to [0, 1]\n",
 "    store_runs : bool\n",
 "        also store per-spike distance, duration and phase in each result\n",
 "    plot : str or bool\n",
 "        \"all\" or \"scatter\"; any other truthy value behaves like \"scatter\"\n",
 "    plot_regression_line : bool\n",
 "        plot the regression line in scatter\n",
 "    save : bool\n",
 "        save plot (only with plot=\"all\")\n",
 "    dot_size : size of scatter dots\n",
 "    crossing_type : str\n",
 "        type of signal to compute crossings must be \"rate\" or \"field\"\n",
 "\n",
 "    Returns a list with one dict per pass containing at least 5 spikes, or None\n",
 "    when no LFP is available.\n",
 "    '''\n",
 "    spike_phase_func, spikes, pos_x, pos_y, pos_t, rate_map, fields = compute_data(\n",
 "        row, flim)\n",
 "    if spike_phase_func is None:\n",
 "        return\n",
 "\n",
 "    if field_num is not None:\n",
 "        fields = np.where(fields == field_num, fields, 0)\n",
 "\n",
 "    if crossing_type == 'field':\n",
 "        in_field_indices = which_field(pos_x, pos_y, fields, box_size)\n",
 "        in_field_enter, in_field_exit = compute_crossings(in_field_indices)\n",
 "    elif crossing_type == 'rate':\n",
 "        # NOTE(review): these index the instantaneous-rate array but are applied\n",
 "        # to the tracking arrays below; confirm both share a time base.\n",
 "        _, _, in_field_enter, in_field_exit = compute_rate_crossings(spikes, pos_t)\n",
 "    else:\n",
 "        raise ValueError(f'crossing_type must be \"rate\" or \"field\", got {crossing_type!r}')\n",
 "\n",
 "    spikes = np.array(spikes)\n",
 "    spikes = spikes[(spikes > pos_t.min()) & (spikes < pos_t.max())]\n",
 "\n",
 "    title = f'{row.action} {row.channel} {row.unit}'\n",
 "    if plot == 'all':\n",
 "        fig, axs = plt.subplots(2, 2)\n",
 "        plt.suptitle(title)\n",
 "    elif plot:\n",
 "        # \"scatter\" (plot=True previously hit a NameError on axs)\n",
 "        fig, ax = plt.subplots(1, 1)\n",
 "        axs = [[ax]]\n",
 "        ax.set_title(title)\n",
 "\n",
 "    sx, sy = interp1d(pos_t, pos_x), interp1d(pos_t, pos_y)\n",
 "    results = []\n",
 "    for en, ex in zip(in_field_enter, in_field_exit):\n",
 "        x, y, t = pos_x[en:ex+1], pos_y[en:ex+1], pos_t[en:ex+1]\n",
 "        if len(t) <= 1:\n",
 "            continue\n",
 "        s = spikes[(spikes > t[0]) & (spikes < t[-1])]\n",
 "        if len(s) < 5:\n",
 "            continue\n",
 "\n",
 "        spike_phase = spike_phase_func(s)\n",
 "\n",
 "        dist = distance(x, y)\n",
 "        if norm:\n",
 "            t_to_dist = interp1d(t, normalize(dist))\n",
 "        else:\n",
 "            t_to_dist = interp1d(t, dist)\n",
 "\n",
 "        spike_dist = t_to_dist(s)\n",
 "        spike_dur = s - t[0]\n",
 "\n",
 "        circ_lin_corr_dist, pval_dist, slope_dist, phi0_dist, RR_dist = cl_corr(\n",
 "            spike_dist, spike_phase, *slope_limit_dist, return_pval=True)\n",
 "        circ_lin_corr_dur, pval_dur, slope_dur, phi0_dur, RR_dur = cl_corr(\n",
 "            spike_dur, spike_phase, *slope_limit_dur, return_pval=True)\n",
 "        result_run = {\n",
 "            'action': row.action,\n",
 "            'channel': row.channel,\n",
 "            'unit': row.unit,\n",
 "            'slope_limit_dist': slope_limit_dist,\n",
 "            'slope_limit_dur': slope_limit_dur,\n",
 "            'flim': flim,\n",
 "            'circ_lin_corr_dist': circ_lin_corr_dist,\n",
 "            'pval_dist': pval_dist,\n",
 "            'slope_dist': slope_dist,\n",
 "            'phi0_dist': phi0_dist,\n",
 "            'RR_dist': RR_dist,\n",
 "            'circ_lin_corr_dur': circ_lin_corr_dur,\n",
 "            'pval_dur': pval_dur,\n",
 "            'slope_dur': slope_dur,\n",
 "            'phi0_dur': phi0_dur,\n",
 "            'RR_dur': RR_dur,\n",
 "        }\n",
 "        if store_runs:\n",
 "            result_run.update({\n",
 "                'spike_dist': spike_dist,\n",
 "                'spike_dur': spike_dur,\n",
 "                'spike_phase': spike_phase\n",
 "            })\n",
 "        results.append(result_run)\n",
 "        if plot:\n",
 "            plot_spike_phase(spike_dist, spike_phase, dot_size, slope_dist, phi0_dist, circ_lin_corr_dist, pval_dist, RR_dist, plot_regression_line, axs[0][0])\n",
 "            if plot == 'all':\n",
 "                axs[0][1].plot(x, y)\n",
 "                axs[0][1].scatter(sx(s), sy(s), s=dot_size, color='r', zorder=100000)\n",
 "\n",
 "    if plot == 'all':\n",
 "        plot_spikes_and_rate_map(row, rate_map, fields, pos_x, pos_y, sy, sx, spikes, box_size, bin_size, dot_size, flim, output_path, save, axs)\n",
 "\n",
 "    return results" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [
 "plt.rc('axes', titlesize=12)\n",
 "plt.rcParams.update({\n",
 "    'font.size': 12,\n",
 "    'figure.figsize': (8, 6),\n",
 "    'figure.dpi': 150\n",
 "})" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "row = cells.sort_values('gridness', ascending=False).iloc[0]" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mikkel/.virtualenvs/expipe/lib/python3.6/site-packages/elephant/statistics.py:835: UserWarning: Instantaneous firing rate approximation contains negative values, possibly caused due to 
machine precision errors.\n", " warnings.warn(\"Instantaneous firing rate approximation contains \"\n", "/home/mikkel/apps/expipe-project/phase-precession/phase_precession/corr_cc.py:35: RuntimeWarning: invalid value encountered in double_scalars\n", " rho = num / den\t# correlation coefficient\n", "/home/mikkel/apps/expipe-project/phase-precession/phase_precession/corr_cc.py:41: RuntimeWarning: invalid value encountered in double_scalars\n", " ts = np.sqrt((n * l20 * l02) / l22) * rho\n", "/home/mikkel/apps/expipe-project/phase-precession/phase_precession/corr_cc.py:84: RuntimeWarning: invalid value encountered in double_scalars\n", " rho = n* (R_aminusb - R_aplusb) / den\n", "/home/mikkel/apps/expipe-project/phase-precession/phase_precession/corr_cc.py:92: RuntimeWarning: invalid value encountered in double_scalars\n", " ts = np.sqrt((n * l20 * l02) / l22) * rho\n" ] }, { "data": { "text/html": [ "
\n", " | action | \n", "channel | \n", "unit | \n", "slope_limit_dist | \n", "slope_limit_dur | \n", "flim | \n", "circ_lin_corr_dist | \n", "pval_dist | \n", "slope_dist | \n", "phi0_dist | \n", "RR_dist | \n", "circ_lin_corr_dur | \n", "pval_dur | \n", "slope_dur | \n", "phi0_dur | \n", "RR_dur | \n", "spike_dist | \n", "spike_dur | \n", "spike_phase | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "-0.057014 | \n", "0.878528 | \n", "-19.999995 | \n", "2.907198 | \n", "0.165186 | \n", "0.064951 | \n", "0.836527 | \n", "-0.958127 | \n", "4.483757 | \n", "0.388724 | \n", "[0.0010296753702724094, 0.001817031819833984, ... | \n", "[0.02847916666666306, 0.0478541666666632, 0.32... | \n", "[1.2000947270293003, -2.3461875911029755, 2.72... | \n", "
1 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.131524 | \n", "0.563850 | \n", "3.885511 | \n", "0.086344 | \n", "0.334172 | \n", "-0.072538 | \n", "0.752826 | \n", "-0.406430 | \n", "2.614404 | \n", "0.244867 | \n", "[0.049559492047359446, 0.09828326225275039, 0.... | \n", "[0.0644270833333227, 0.15601041666665605, 0.17... | \n", "[1.7077833533155713, 2.1292424841744704, -0.91... | \n", "
2 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.188106 | \n", "0.283691 | \n", "15.242794 | \n", "3.825074 | \n", "0.226141 | \n", "0.182584 | \n", "0.323239 | \n", "0.296876 | \n", "4.297967 | \n", "0.233404 | \n", "[0.0031409918384566787, 0.0037455194326765674,... | \n", "[0.08618749999998876, 0.09827083333332176, 0.1... | \n", "[-0.856849038171823, 0.8129509104365495, 1.751... | \n", "
3 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "-0.163215 | \n", "0.658520 | \n", "-3.456314 | \n", "1.836874 | \n", "0.309684 | \n", "-0.129448 | \n", "0.728305 | \n", "-0.982799 | \n", "1.897397 | \n", "0.289831 | \n", "[0.008278960511390213, 0.011733955097490707, 0... | \n", "[0.02919791666667315, 0.05157291666667341, 0.1... | \n", "[1.706753592099695, -1.4574268368886827, 1.170... | \n", "
4 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "-0.115085 | \n", "0.641511 | \n", "-5.432632 | \n", "6.027699 | \n", "0.196856 | \n", "-0.092728 | \n", "0.709784 | \n", "-0.699412 | \n", "5.567881 | \n", "0.212273 | \n", "[0.017771141003186743, 0.04181666630671976, 0.... | \n", "[0.09056249999999721, 0.20097916666666293, 0.2... | \n", "[0.10783296702320369, 2.496075744564212, -2.38... | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
104 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.445130 | \n", "0.295925 | \n", "13.499483 | \n", "4.692939 | \n", "0.528767 | \n", "0.409080 | \n", "0.333955 | \n", "0.688516 | \n", "4.652379 | \n", "0.504698 | \n", "[0.0032512325138726733, 0.0035440202351017212,... | \n", "[0.08663541666658148, 0.09234374999994088, 0.1... | \n", "[-2.2651495674824766, -1.140410096840983, 0.33... | \n", "
105 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.079285 | \n", "0.771025 | \n", "-11.773172 | \n", "2.538374 | \n", "0.215287 | \n", "-0.069550 | \n", "0.787010 | \n", "-0.164432 | \n", "5.055850 | \n", "0.217357 | \n", "[0.015594637525964256, 0.019669560707982706, 0... | \n", "[0.05721874999994725, 0.07401041666662422, 0.0... | \n", "[2.4722928434603078, -1.3131957520032655, 0.44... | \n", "
106 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.452029 | \n", "0.131712 | \n", "8.737603 | \n", "5.069350 | \n", "0.653502 | \n", "0.459736 | \n", "0.141185 | \n", "0.605074 | \n", "5.505432 | \n", "0.617714 | \n", "[0.0017002443178809272, 0.0018301956937179047,... | \n", "[0.05375000000003638, 0.06277083333327482, 0.0... | \n", "[0.2297067793093115, -2.041161329643188, -0.41... | \n", "
107 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.699380 | \n", "0.083434 | \n", "7.217852 | \n", "4.917169 | \n", "0.760941 | \n", "0.682985 | \n", "0.088925 | \n", "0.999996 | \n", "6.024202 | \n", "0.753521 | \n", "[0.03589687648547473, 0.048606408433837324, 0.... | \n", "[0.11198958333329756, 0.16214583333328392, 0.1... | \n", "[0.22944773386426603, 0.7291734013827129, 2.44... | \n", "
108 | \n", "11138-20040502 | \n", "3 | \n", "2 | \n", "[-20, 20] | \n", "[-1, 1] | \n", "[20, 25] | \n", "0.237864 | \n", "0.164136 | \n", "6.027968 | \n", "5.642149 | \n", "0.286743 | \n", "0.345264 | \n", "0.048720 | \n", "0.183132 | \n", "5.899758 | \n", "0.284605 | \n", "[0.010510995966230776, 0.027472739743439972, 0... | \n", "[0.04120833333331575, 0.10558333333335668, 0.1... | \n", "[0.13522186477845663, 0.7241862965321739, -0.4... | \n", "
109 rows × 19 columns
\n", "