"""
Spike train correlation plots
-----------------------------
.. autosummary::
    :toctree: toctree/spike_train_correlation/

    plot_corrcoef
    plot_cross_correlation_histogram
"""
# Copyright 2017-2023 by the Viziphant team, see `doc/authors.rst`.
# License: Modified BSD, see LICENSE.txt for details.
from __future__ import division, print_function, unicode_literals
import matplotlib.pyplot as plt
import neo
import numpy as np
import quantities as pq
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
from elephant.utils import check_neo_consistency
def _get_correlation_limits(corrcoef_matrix, correlation_range):
    """
    Compute the ``(vmin, vmax)`` colour-mapping limits for
    :func:`plot_corrcoef`.

    Parameters
    ----------
    corrcoef_matrix : np.ndarray
        2-D correlation coefficient matrix.
    correlation_range : {'auto', 'full'} or tuple of float
        See :func:`plot_corrcoef`.

    Returns
    -------
    vmin, vmax : float
        Lower and upper colour limits.

    Raises
    ------
    ValueError
        If `correlation_range` is not a tuple/list or 'auto' or 'full'.
    """
    if isinstance(correlation_range, str):
        if correlation_range == 'full':
            return -1, 1
        if correlation_range == 'auto':
            # As documented, only the off-diagonal coefficients are used,
            # regardless of whether the diagonal was zeroed beforehand;
            # otherwise the self-correlations (typically 1) would dominate
            # the colour scale when `remove_diagonal` is False.
            off_diagonal = corrcoef_matrix[
                ~np.eye(*corrcoef_matrix.shape, dtype=bool)]
            # Ignore NaN coefficients (e.g. from silent spike trains) so they
            # do not turn the colour limits into NaN.
            finite = np.abs(off_diagonal[np.isfinite(off_diagonal)])
            if finite.size == 0:
                # Nothing to scale to (1x1 matrix or only non-finite
                # values): fall back to the full correlation range.
                return -1, 1
            vmax = finite.max()
            return -vmax, vmax
    elif isinstance(correlation_range, (tuple, list, np.ndarray)):
        vmin, vmax = correlation_range
        return vmin, vmax
    raise ValueError(f"Invalid 'correlation_range' ({correlation_range}). "
                     f"Must be a tuple of float values or 'auto'/'full'.")


def plot_corrcoef(corrcoef_matrix, axes=None, correlation_range='auto',
                  colormap='bwr', colorbar_aspect=20,
                  colorbar_padding_fraction=0.5, remove_diagonal=True):
    """
    Plots a cross-correlation matrix returned by
    :func:`elephant.spike_train_correlation.correlation_coefficient`
    function with a color bar.

    Parameters
    ----------
    corrcoef_matrix : np.ndarray
        Pearson's correlation coefficient matrix (2-D).
    axes : matplotlib.axes.Axes or None, optional
        Matplotlib axes handle. If None, new axes are created and returned.
        Default: None
    correlation_range : {'auto', 'full'} or tuple of float, optional
        Minimum and maximum correlations to consider for color mapping.
        If tuple, the first element is the minimum and the second
        element is the maximum correlation.
        If 'auto', the maximum absolute value of the finite non-diagonal
        coefficients will be used symmetrically as minimum and maximum
        (falling back to -1.0 and 1.0 if there is no such coefficient).
        If 'full', maximum correlation is set at 1.0 and minimum at -1.0.
        Default: 'auto'
    colormap : str, optional
        Colormap. Default: 'bwr'
    colorbar_aspect : float, optional
        Aspect ratio of the color bar. Default: 20
    colorbar_padding_fraction : float, optional
        Padding between matrix plot and color bar relative to color bar width.
        Default: 0.5
    remove_diagonal : bool, optional
        If True, the values in the main diagonal are replaced with zeros in
        the plotted copy of the matrix (the input is not modified).
        Default: True

    Returns
    -------
    axes : matplotlib.axes.Axes
        The axes containing the matrix plot.

    Raises
    ------
    ValueError
        If `correlation_range` is not tuple or 'auto' or 'full'.

    Examples
    --------
    Create 10 homogeneous random Poisson spike trains of rate `10Hz` and bin
    the spikes into bins of `100ms` width, which is relatively large for such
    a firing rate, so we expect non-zero correlations.

    .. plot::
        :include-source:

        import numpy as np
        import matplotlib.pyplot as plt
        import quantities as pq
        from elephant.spike_train_generation import homogeneous_poisson_process
        from elephant.conversion import BinnedSpikeTrain
        from elephant.spike_train_correlation import correlation_coefficient
        from viziphant.spike_train_correlation import plot_corrcoef

        np.random.seed(0)
        spiketrains = [homogeneous_poisson_process(rate=10*pq.Hz,
                       t_stop=10*pq.s) for _ in range(10)]
        binned_spiketrains = BinnedSpikeTrain(spiketrains, bin_size=100*pq.ms)
        corrcoef_matrix = correlation_coefficient(binned_spiketrains)

        fig, axes = plt.subplots()
        plot_corrcoef(corrcoef_matrix, axes=axes)
        axes.set_xlabel('Neuron')
        axes.set_ylabel('Neuron')
        axes.set_title("Correlation coefficient matrix")
        plt.show()

    """
    # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact while
    # also accepting nested lists.
    corrcoef_matrix = np.asanyarray(corrcoef_matrix)
    if remove_diagonal:
        # Work on a copy so the caller's matrix is not modified in place.
        corrcoef_matrix = corrcoef_matrix.copy()
        np.fill_diagonal(corrcoef_matrix, val=0)

    # Resolve the colour limits before creating any figure, so an invalid
    # `correlation_range` does not leave an empty figure behind.
    vmin, vmax = _get_correlation_limits(corrcoef_matrix, correlation_range)

    if axes is None:
        _, axes = plt.subplots()

    image = axes.imshow(corrcoef_matrix, vmin=vmin, vmax=vmax, cmap=colormap)

    # Append a colour bar axis to the right of the matrix whose width is
    # tied to the matrix height via `colorbar_aspect`.
    divider = make_axes_locatable(axes)
    width = axes_size.AxesY(axes, aspect=1. / colorbar_aspect)
    pad = axes_size.Fraction(colorbar_padding_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    # Use the figure owning `axes` rather than pyplot's "current" figure,
    # which may differ (or not exist) when axes are passed in by the caller.
    axes.figure.colorbar(image, cax=cax)
    return axes
def plot_cross_correlation_histogram(cch, axes=None, units=None, maxlag=None,
                                     legend=None,
                                     title='Cross-correlation histogram'):
    """
    Plot a cross-correlation histogram returned by
    :func:`elephant.spike_train_correlation.cross_correlation_histogram`,
    with the time-lag axis rescaled to `units`.

    Parameters
    ----------
    cch : neo.AnalogSignal or list of neo.AnalogSignal
        Cross-correlation histogram or a list of such. All histograms are
        drawn on the same axes.
    axes : matplotlib.axes.Axes or None, optional
        Matplotlib axes handle. If set to None, new axes are created and
        returned.
        Default: None
    units : pq.Quantity or str or None, optional
        Desired time axis units.
        If None, the ``sampling_period`` units of the first histogram are
        used.
        Default: None
    maxlag : pq.Quantity or None, optional
        Left and right borders of the plot: the time-lag axis is limited to
        ``[-maxlag, maxlag]``.
        Default: None
    legend : str or list of str or None, optional
        The axes legend labels, one per histogram. A legend is drawn only if
        at least one label is not None.
        Default: None
    title : str, optional
        The axes title.
        Default: 'Cross-correlation histogram'

    Returns
    -------
    axes : matplotlib.axes.Axes
        The axes the histograms were drawn on.

    Raises
    ------
    ValueError
        If `cch` is an empty list, or if the number of legend labels does not
        match the number of histograms.

    Notes
    -----
    The y-axis label is taken from the ``'normalization'`` entry of the
    ``'cch_parameters'`` annotation of the first histogram. If that entry is
    missing, no y-axis label is set.

    Examples
    --------
    .. plot::
        :include-source:

        import quantities as pq
        import matplotlib.pyplot as plt
        from elephant.spike_train_generation import homogeneous_poisson_process
        from elephant.conversion import BinnedSpikeTrain
        from elephant.spike_train_correlation import \
            cross_correlation_histogram
        from viziphant.spike_train_correlation import \
            plot_cross_correlation_histogram

        spiketrain1 = homogeneous_poisson_process(rate=10*pq.Hz,
                                                  t_stop=10*pq.s)
        spiketrain2 = homogeneous_poisson_process(rate=10*pq.Hz,
                                                  t_stop=10*pq.s)
        binned_spiketrain1 = BinnedSpikeTrain(spiketrain1, bin_size=100*pq.ms)
        binned_spiketrain2 = BinnedSpikeTrain(spiketrain2, bin_size=100*pq.ms)

        cch, lags = cross_correlation_histogram(binned_spiketrain1,
                                                binned_spiketrain2)

        plot_cross_correlation_histogram(cch)
        plt.show()

    """
    # Normalise the input to a list of signals. The AnalogSignal check must
    # come first: a single signal is itself iterable (sample by sample).
    if isinstance(cch, neo.AnalogSignal):
        cch = [cch]
    else:
        cch = list(cch)
    if len(cch) == 0:
        raise ValueError("The input list of cross-correlation histograms "
                         "is empty.")
    check_neo_consistency(cch, object_type=neo.AnalogSignal)

    if units is None:
        units = cch[0].sampling_period.units
    elif isinstance(units, str):
        units = pq.Quantity(1, units)

    if legend is None:
        legend = [None] * len(cch)
    elif isinstance(legend, str):
        legend = [legend]
    if len(legend) != len(cch):
        raise ValueError("The length of the input list and legend labels do "
                         "not match.")

    # Create the axes only after all inputs are validated, so that a
    # validation error does not leave an empty figure behind.
    if axes is None:
        _, axes = plt.subplots()

    for label, signal in zip(legend, cch):
        cch_times = signal.times.rescale(units).magnitude
        axes.plot(cch_times, signal.magnitude, label=label)

    # Histograms not produced by elephant may lack this annotation; skip the
    # y-label instead of failing with a KeyError.
    normalization = cch[0].annotations.get('cch_parameters', {}).get(
        'normalization')
    if normalization is not None:
        axes.set_ylabel(normalization)
    axes.set_xlabel(f"Time lag ({units.dimensionality})")
    axes.set_title(title)

    if maxlag is not None:
        maxlag = maxlag.rescale(units).magnitude
        axes.set_xlim(-maxlag, maxlag)

    # Show a legend if any label was given, not only when the first one was.
    if any(label is not None for label in legend):
        axes.legend()
    return axes