"""
Unitary Event Analysis (UEA) plots
----------------------------------
Standard plot function for pairwise unitary event analysis results resembling
the original publication. The output is assumed to be from
:func:`elephant.unitary_event_analysis.jointJ_window_analysis` function.
.. autosummary::
:toctree: toctree/unitary_event_analysis/
plot_ue
"""
# Copyright 2017-2023 by the Viziphant team, see `doc/authors.rst`.
# License: Modified BSD, see LICENSE.txt.txt for details.
import numpy as np
import quantities as pq
import matplotlib.pyplot as plt
import string
import elephant.unitary_event_analysis as ue
from collections import namedtuple
FigureUE = namedtuple("FigureUE", ['axes_spike_events',
'axes_spike_rates',
'axes_coincident_events',
'axes_coincidence_rates',
'axes_significance',
'axes_unitary_events'])
plot_params_default = {
# epochs to be marked on the time axis
'events': {},
# figure size
'figsize': (10, 12),
# right margin
'right': 0.9,
# top margin
'top': 0.9,
# bottom margin
'bottom': 0.1,
# left margin
'left': 0.1,
# horizontal white space between subplots
'hspace': 0.5,
# width white space between subplots
'wspace': 0.5,
# font size
'fsize': 12,
# the actual unit ids from the experimental recording
'unit_real_ids': None,
# line width
'lw': 2,
# marker size for the UEs and coincidences
'ms': 5,
# figure title
'suptitle': None,
}
[docs]
def plot_ue(spiketrains, Js_dict, significance_level=0.05,
**plot_params):
"""
Plots the results of pairwise unitary event analysis as a column of six
subplots, comprised of raster plot, peri-stimulus time histogram,
coincident event plot, coincidence rate plot, significance plot and
unitary event plot, respectively.
Parameters
----------
spiketrains : list of list of neo.SpikeTrain
A nested list of trials, neurons and their neo.SpikeTrain objects,
respectively. This should be identical to the one used to generate
Js_dict.
Js_dict : dict
The output of
:func:`elephant.unitary_event_analysis.jointJ_window_analysis`
function. The values of each key has the shape of:
* different window --> 0-axis.
* different pattern hash --> 1-axis;
Dictionary keys:
'Js': list of float
JointSurprise of different given patterns within each window.
'indices': list of list of int
A list of indices of pattern within each window.
'n_emp': list of int
The empirical number of each observed pattern.
'n_exp': list of float
The expected number of each pattern.
'rate_avg': list of float
The average firing rate of each neuron.
significance_level : float
The significance threshold used to determine which coincident events
are classified as unitary events within a window.
**plot_params
User-defined plotting parameters used to update the default plotting
parameter values. The valid keys:
'events' : dict
Epochs to be marked on the time axis.
'figsize' : tuple of int
The dimensions for the figure size.
'right' : float
The size of the right margin.
'top' : float
The size of the top margin.
'bottom' : float
The size of the bottom margin.
'left' : float
The size of the left margin.
'hspace' : flaot
The size of the horizontal white space between subplots.
'wspace' : float
The width of the white space between subplots.
'fsize' : int
The size of the font.
'unit_real_ids' : list of int
The unit ids from the experimental recording.
'lw' : int
The default line width.
'ms' : int
The marker size for the unitary events and coincidences.
Returns
-------
result : FigureUE
The container for Axes objects generated by the function. Individual
axes can be accessed using the following identifiers:
* axes_spike_events : matplotlib.axes.Axes
Contains the elements of the spike events subplot.
* axes_spike_rates : matplotlib.axes.Axes
Contains the elements of the spike rates subplot.
* axes_coincident_events : matplotlib.axes.Axes
Contains the elements of the coincident events subplot.
* axes_coincidence_rates : matplotlib.axes.Axes
Contains the elements of the coincidence rates subplot.
* axes_significance : matplotlib.axes.Axes
Contains the elements of the statistical significance subplot.
* axes_unitary_events : matplotlib.axes.Axes
Contains the elements of the unitary events subplot.
Examples
--------
Unitary Events of homogenous Poisson random processes.
Since we don't expect to find significant correlations in random processes,
we show non-significant events (``significance_level=0.34``). Typically,
in your analyses, the significant level threshold is ~0.05.
.. plot::
:include-source:
import matplotlib.pyplot as plt
import numpy as np
import quantities as pq
import viziphant
from elephant.spike_train_generation import homogeneous_poisson_process
from elephant.unitary_event_analysis import jointJ_window_analysis
np.random.seed(10)
spiketrains1 = [homogeneous_poisson_process(rate=20 * pq.Hz,
t_stop=2 * pq.s) for _ in range(5)]
spiketrains2 = [homogeneous_poisson_process(rate=50 * pq.Hz,
t_stop=2 * pq.s) for _ in range(5)]
spiketrains = np.stack((spiketrains1, spiketrains2), axis=1)
ue_dict = jointJ_window_analysis(spiketrains,
bin_size=5 * pq.ms,
win_size=100 * pq.ms,
win_step=10 * pq.ms)
viziphant.unitary_event_analysis.plot_ue(spiketrains, Js_dict=ue_dict,
significance_level=0.34,
unit_real_ids=['1', '2'])
plt.show()
Refer to `UEA Tutorial <https://elephant.readthedocs.io/en/latest/
tutorials/unitary_event_analysis.html>`_ for real-case scenario.
"""
n_trials = len(spiketrains)
n_neurons = len(spiketrains[0])
input_parameters = Js_dict['input_parameters']
t_start = input_parameters['t_start']
t_stop = input_parameters['t_stop']
bin_size = input_parameters['bin_size']
win_size = input_parameters['win_size']
win_step = input_parameters['win_step']
pattern_hash = input_parameters['pattern_hash']
if len(pattern_hash) > 1:
raise ValueError(f"To not clutter the plots, only one pattern hash is "
f"required; got {pattern_hash}. You can call this "
f"function multiple times for each hash at a time.")
for key in ['Js', 'n_emp', 'n_exp', 'rate_avg']:
Js_dict[key] = Js_dict[key].squeeze()
neurons_participated = ue.inverse_hash_from_pattern(pattern_hash,
N=n_neurons).squeeze()
t_winpos = ue._winpos(t_start=t_start, t_stop=t_stop, win_size=win_size,
win_step=win_step)
Js_sig = ue.jointJ(significance_level)
# figure format
plot_params_user = plot_params
plot_params = plot_params_default.copy()
plot_params.update(plot_params_user)
if plot_params['unit_real_ids'] is None:
plot_params['unit_real_ids'] = ['not specified'] * n_neurons
if len(plot_params['unit_real_ids']) != n_neurons:
raise ValueError('length of unit_ids should be ' +
'equal to number of neurons!')
plt.rcParams.update({'font.size': plot_params['fsize']})
ls = '-'
alpha = 0.5
fig, axes = plt.subplots(nrows=6, sharex=True,
figsize=plot_params['figsize'])
axes[5].sharey(axes[0])
axes[0].sharey(axes[2])
for ax in (axes[0], axes[2], axes[5]):
for n in range(n_neurons):
for tr, data_tr in enumerate(spiketrains):
ax.plot(data_tr[n].rescale('ms').magnitude,
np.full_like(data_tr[n].magnitude,
fill_value=n * n_trials + tr),
'.', markersize=0.5, color='k')
for n in range(1, n_neurons):
# subtract 0.5 to separate the raster plots;
# otherwise, the line crosses the raster spikes
ax.axhline(n * n_trials - 0.5, lw=0.5, color='k')
ymax = max(ax.get_ylim()[1], 2 * n_trials - 0.5)
ax.set_ylim([-0.5, ymax])
ax.set_yticks([n_trials - 0.5, 2 * n_trials - 0.5])
ax.set_yticklabels([1, n_trials], fontsize=plot_params['fsize'])
ax.set_ylabel('Trial', fontsize=plot_params['fsize'])
for i, ax in enumerate(axes):
ax.set_xlim([t_winpos[0], t_winpos[-1] + win_size])
ax.text(-0.05, 1.1, string.ascii_uppercase[i],
transform=ax.transAxes, size=plot_params['fsize'] + 5,
weight='bold')
for key in plot_params['events'].keys():
for event_time in plot_params['events'][key]:
ax.axvline(event_time, ls=ls, color='r', lw=plot_params['lw'],
alpha=alpha)
axes[0].set_title('Spike Events')
axes[0].text(1.0, 1.0, f"Unit {plot_params['unit_real_ids'][-1]}",
fontsize=plot_params['fsize'] // 2,
horizontalalignment='right',
verticalalignment='bottom',
transform=axes[0].transAxes)
axes[0].text(1.0, 0, f"Unit {plot_params['unit_real_ids'][0]}",
fontsize=plot_params['fsize'] // 2,
horizontalalignment='right',
verticalalignment='top',
transform=axes[0].transAxes)
axes[1].set_title('Spike Rates')
for n in range(n_neurons):
axes[1].plot(t_winpos + win_size / 2.,
Js_dict['rate_avg'][:, n].rescale('Hz'),
label=f"Unit {plot_params['unit_real_ids'][n]}",
lw=plot_params['lw'])
axes[1].set_ylabel('Hz', fontsize=plot_params['fsize'])
axes[1].legend(fontsize=plot_params['fsize'] // 2, loc='upper right')
axes[1].locator_params(axis='y', tight=True, nbins=3)
axes[2].set_title('Coincident Events')
for n in range(n_neurons):
if not neurons_participated[n]:
continue
for tr, data_tr in enumerate(spiketrains):
indices = np.unique(Js_dict['indices'][f'trial{tr}'])
axes[2].plot(indices * bin_size,
np.full_like(indices, fill_value=n * n_trials + tr),
ls='', ms=plot_params['ms'], marker='s',
markerfacecolor='none',
markeredgecolor='c')
axes[2].set_ylabel('Trial', fontsize=plot_params['fsize'])
axes[3].set_title('Coincidence Rates')
axes[3].plot(t_winpos + win_size / 2.,
Js_dict['n_emp'] / (
win_size.rescale('s').magnitude * n_trials),
label='Empirical', lw=plot_params['lw'], color='c')
axes[3].plot(t_winpos + win_size / 2.,
Js_dict['n_exp'] / (
win_size.rescale('s').magnitude * n_trials),
label='Expected', lw=plot_params['lw'], color='m')
axes[3].set_ylabel('Hz', fontsize=plot_params['fsize'])
axes[3].legend(fontsize=plot_params['fsize'] // 2, loc='upper right')
axes[3].locator_params(axis='y', tight=True, nbins=3)
axes[4].set_title('Statistical Significance')
axes[4].plot(t_winpos + win_size / 2., Js_dict['Js'], lw=plot_params['lw'],
color='k')
axes[4].axhline(Js_sig, ls='-', color='r')
axes[4].axhline(-Js_sig, ls='-', color='g')
xlim_ax4 = axes[4].get_xlim()[1]
alpha_pos_text = axes[4].text(xlim_ax4, Js_sig, r'$\alpha +$', color='r',
horizontalalignment='right',
verticalalignment='bottom')
alpha_neg_text = axes[4].text(xlim_ax4, -Js_sig, r'$\alpha -$', color='g',
horizontalalignment='right',
verticalalignment='top')
axes[4].set_yticks([ue.jointJ(1 - significance_level), ue.jointJ(0.5),
ue.jointJ(significance_level)])
# Try '1 - 0.34' to see the floating point errors
axes[4].set_yticklabels(np.round([1 - significance_level, 0.5,
significance_level], decimals=6))
# autoscale fix to mind the text positions.
# See https://stackoverflow.com/questions/11545062/
# matplotlib-autoscale-axes-to-include-annotations
plt.get_current_fig_manager().canvas.draw()
for text_handle in (alpha_pos_text, alpha_neg_text):
bbox = text_handle.get_window_extent()
bbox_data = bbox.transformed(axes[4].transData.inverted())
axes[4].update_datalim(bbox_data.corners(), updatex=False)
axes[4].autoscale_view()
mask_nonnan = ~np.isnan(Js_dict['Js'])
significant_win_idx = np.nonzero(Js_dict['Js'][mask_nonnan] >= Js_sig)[0]
t_winpos_significant = t_winpos[mask_nonnan][significant_win_idx]
axes[5].set_title('Unitary Events')
if len(t_winpos_significant) > 0:
for n in range(n_neurons):
if not neurons_participated[n]:
continue
for tr, data_tr in enumerate(spiketrains):
indices = np.unique(Js_dict['indices'][f'trial{tr}'])
indices_significant = []
for t_sig in t_winpos_significant:
mask = (indices * bin_size >= t_sig
) & (indices * bin_size < t_sig + win_size)
indices_significant.append(indices[mask])
indices_significant = np.hstack(indices_significant)
indices_significant = np.unique(indices_significant)
# does nothing if indices_significant is empty
axes[5].plot(indices_significant * bin_size,
np.full_like(indices_significant,
fill_value=n * n_trials + tr),
ms=plot_params['ms'], marker='s', ls='',
mfc='none', mec='r')
axes[5].set_xlabel(f'Time ({t_winpos.dimensionality})',
fontsize=plot_params['fsize'])
for key in plot_params['events'].keys():
for event_time in plot_params['events'][key]:
axes[5].text(event_time - 10 * pq.ms,
axes[5].get_ylim()[0] - 35, key,
fontsize=plot_params['fsize'], color='r')
plt.suptitle(plot_params['suptitle'], fontsize=20)
plt.subplots_adjust(top=plot_params['top'],
right=plot_params['right'],
left=plot_params['left'],
bottom=plot_params['bottom'],
hspace=plot_params['hspace'],
wspace=plot_params['wspace'])
axes = FigureUE(*axes)
return axes