"""
Raster and event plots of spike times
-------------------------------------
.. autosummary::
:toctree: toctree/rasterplot/
eventplot
rasterplot
rasterplot_rates
"""
# Copyright 2017-2023 by the Viziphant team, see `doc/authors.rst`.
# License: Modified BSD, see LICENSE.txt for details.
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import quantities as pq
import seaborn as sns
import warnings
import neo
from math import log10, floor
from elephant.statistics import mean_firing_rate
from viziphant.utils import check_same_units
def _round_to_1(x):
rounded = round(x, -int(floor(log10(abs(x)))))
return rounded, rounded > x
def _get_attributes(spiketrains, key_list):
"""
This function returns attribute_array which is of an array of shape
(len(spiketrains), len(key_list)) and consists of numerical ids for each
value of each key for each spike train.
Passed spike trains must be already sorted according to key_list and
key_list must not be empty.
"""
key_count = len(key_list)
attribute_array = np.zeros((len(spiketrains), len(key_list)))
# count all group sizes for all keys in key_list:
while key_count > 0:
key_count -= 1
group_key = key_list[key_count]
i = 0
if group_key in spiketrains[i].annotations:
current_value = spiketrains[i].annotations[group_key]
else:
# use placeholder value when key is not in annotations
# of the current spike train
current_value = '####BLANK####'
ref_value = current_value
values = np.array([])
# count all group sizes for values of current key:
while i < spiketrains.__len__():
if not len(values) or current_value not in values:
values = np.append(values, current_value)
# count group size for a value of the current key:
while i < len(spiketrains) and current_value == ref_value:
attribute_array[i][key_count] = \
np.where(values == current_value)[0][0]
i += 1
if i < len(spiketrains):
if group_key in spiketrains[i].annotations:
current_value = spiketrains[i].annotations[
group_key]
else:
current_value = '####BLANK####'
ref_value = current_value
return attribute_array
[docs]
def rasterplot_rates(spiketrains,
                     key_list=[],
                     groupingdepth=0,
                     spacing=[8, 3],
                     colorkey=0,
                     pophist_mode='color',
                     pophistbins=100,
                     right_histogram=mean_firing_rate,
                     righthist_barwidth=1.01,
                     filter_function=None,
                     histscale=.1,
                     labelkey=None,
                     markerargs={'markersize': 4, 'marker': '.'},
                     separatorargs=[
                         {'linewidth': 2, 'linestyle': '--', 'color': '0.8'},
                         {'linewidth': 1, 'linestyle': '--', 'color': '0.8'}],
                     legend=False,
                     legendargs={'loc': (.98, 1.), 'markerscale': 1.5,
                                 'handletextpad': 0},
                     ax=None,
                     style='ticks',
                     palette=None,
                     context=None,  # paper, notebook, talk, poster
                     ):
    """
    Plot the dot display of spike trains alongside its population histogram
    and the mean firing rate (or a custom function).

    Optional visual aids are offered such as sorting, grouping and color
    coding on the basis of the arrangement in the list of spike trains and
    spike train annotations.
    Changes to the optics of the dot marker, the separators and the legend
    can be applied by providing a dict with the respective parameters.
    Changes and additions to the dot display itself or the two histograms
    are best realized by using the returned axis handles.

    Parameters
    ----------
    spiketrains : list of neo.SpikeTrain or list of list of neo.SpikeTrain
        The list can either contain Neo SpikeTrain objects or lists of Neo
        SpikeTrain objects.
    key_list : str or list of str
        Annotation key(s) for which the spike trains should be ordered.
        When a list of keys is given, the spike trains are ordered
        successively for the keys.
        By default the ordering by the given lists of spike trains has
        priority. This can be bypassed by using an empty string '' as
        list-key at any position in the key_list.
    groupingdepth : int
        * 0: No grouping (default)
        * 1: grouping by the first key in key_list.
          Note that when a list of lists of spike trains is given, the first
          key is by default the list identification key ''. If this is
          unwanted, the empty string '' can be placed at a different
          position in key_list.
        * 2: additional grouping by the second key.

        The groups are separated by whitespace specified in the spacing
        parameter and optionally by a line specified by the separatorargs.
    spacing : int or list of int
        Size of whitespace separating the groups in units of spike trains.
        When groupingdepth == 2, a list of two values can specify the
        distance between the groups in level 1 and level 2. When only one
        value is given, level 2 spacing is set to half the spacing of
        level 1.
        Default: [8, 3]
    colorkey : str or int
        Contrasts values of a key by color. The key can be defined by its
        name or its position in key_list. Note that position 0 points to
        the list identification key ('') when a list of lists of spike
        trains is given, if not otherwise specified in key_list!
        Default: 0
    pophist_mode : str
        * 'total': One population histogram for all drawn spike trains.
        * 'color': In addition to the total population histogram, a
          histogram for each colored subset is drawn (see colorkey).
    pophistbins : int
        Number of bins used for the population histogram.
        Default: 100
    right_histogram : function
        The function gets ONE neo.SpikeTrain object as argument and has to
        return a scalar. For example, the functions in the
        elephant.statistics module can be used.
        Default: mean_firing_rate
        When a function is applied, it is recommended to set the axis label
        accordingly by using the axis handle returned by the function:
        ``axhisty.set_xlabel('Label Name')``
    righthist_barwidth : float
        The bar height (in units of spike trains) of the right side
        histogram.
        Default: 1.01
    filter_function : function
        The function gets ONE neo.SpikeTrain object as argument; if it
        returns True the spike train is included, if False it is excluded.
    histscale : float
        Portion of the figure used for the histograms on the right and upper
        side.
        Default: .1
    labelkey : int or str or None
        * 0, 1: Set the label according to the first or second key in
          key_list. Note that the first key is by default the list
          identification key ('') when a list of lists of spike trains is
          given.
        * '0+1': Two level labeling of 0 and 1.
        * annotation-key: Label each spike train with its value for the
          given key.
        * None: No labeling.

        Note that only groups (-> see groupingdepth) can be labeled as
        bulks. Alternatively, you can color for an annotation key and show a
        legend.
    markerargs : dict
        Arguments dictionary passed on to matplotlib.pyplot.plot(). The
        'linestyle' entry is always set to '' for the dot display. The
        given dict is not modified.
    separatorargs : dict or list of dict or None
        If only one dict is given and groupingdepth == 2, the arguments are
        applied to the separators of both levels. Otherwise the arguments of
        separatorargs[0] are applied to level 1 and separatorargs[1] to
        level 2. The arguments dictionary is passed on to
        matplotlib.pyplot.plot(); 'c' is accepted as an alias of 'color'
        and the color defaults to '0.8'. The given dicts are not modified.
        To turn off the separators, set it to None.
    legend : bool
        Show a legend for the values of colorkey.
    legendargs : dict
        Arguments dictionary passed on to matplotlib.pyplot.legend().
    ax : matplotlib.axes.Axes or None
        The axis onto which to plot. If None (default), a new figure is
        created. The histogram axes are added to the figure of `ax`.
        When an axis is given, the function can't handle the figure
        settings. Therefore it is recommended to call seaborn.set() with
        your preferred settings before creating your matplotlib figure in
        order to control your plotting layout.
    style : str
        seaborn style setting.
        Default: 'ticks'
    palette : str or sequence
        Define the color palette either by its name or use a custom palette
        in a sequence of the form ([r,g,b],[r,g,b],...).
    context : str or None
        seaborn context setting ('paper', 'notebook', 'talk' or 'poster')
        which controls the scaling of labels; see seaborn.set_context.
        If None (default), the current seaborn context is kept.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The handle of the dot display plot.
    axhistx : matplotlib.axes.Axes
        The handle of the histogram plot above the dot display.
    axhisty : matplotlib.axes.Axes
        The handle of the histogram plot on the right hand side.

    See Also
    --------
    rasterplot : simplified raster plot
    eventplot : plot spike times in vertical stripes

    Examples
    --------
    1. Basic Example.

    .. plot::
       :include-source:

       from elephant.spike_train_generation import (
           homogeneous_poisson_process, homogeneous_gamma_process)
       import quantities as pq
       import matplotlib.pyplot as plt
       from viziphant.rasterplot import rasterplot_rates

       spiketrains = [homogeneous_poisson_process(rate=10 * pq.Hz)
                      for _ in range(100)]
       rasterplot_rates(spiketrains)
       plt.show()

    2. Plot visually separated realizations of different neurons.

    .. plot::
       :include-source:

       from elephant.spike_train_generation import (
           homogeneous_poisson_process, homogeneous_gamma_process)
       import quantities as pq
       import matplotlib.pyplot as plt
       from viziphant.rasterplot import rasterplot_rates

       spiketrains1 = [homogeneous_poisson_process(rate=10 * pq.Hz)
                       for _ in range(100)]
       spiketrains2 = [homogeneous_gamma_process(a=3, b=10 * pq.Hz)
                       for _ in range(100)]
       rasterplot_rates([spiketrains1, spiketrains2])
       plt.show()

    3. Add annotations to spike trains.

    .. plot::
       :include-source:

       from elephant.spike_train_generation import (
           homogeneous_poisson_process, homogeneous_gamma_process)
       import quantities as pq
       import matplotlib.pyplot as plt
       from viziphant.rasterplot import rasterplot_rates

       spiketrains1 = [homogeneous_poisson_process(rate=10 * pq.Hz)
                       for _ in range(100)]
       spiketrains2 = [homogeneous_gamma_process(a=3, b=10 * pq.Hz)
                       for _ in range(100)]
       for i, (st1, st2) in enumerate(zip(spiketrains1, spiketrains2)):
           if i % 2 == 1:
               st1.annotations['parity'] = 'odd'
               st2.annotations['parity'] = 'odd'
           else:
               st1.annotations['parity'] = 'even'
               st2.annotations['parity'] = 'even'

       # plot separates the lists and the annotation values within each list
       rasterplot_rates([spiketrains1, spiketrains2], key_list=['parity'],
                        groupingdepth=2, labelkey='0+1')

    The ``''`` key changes the priority of the list grouping:

    .. code-block:: python

       rasterplot_rates([spiketrains1, spiketrains2],
                        key_list=['parity', ''],
                        groupingdepth=2, labelkey='0+1')

    Groups can also be emphasized by an explicit color mode:

    .. code-block:: python

       rasterplot_rates([spiketrains1, spiketrains2],
                        key_list=['', 'parity'],
                        groupingdepth=1, labelkey=0, colorkey='parity',
                        legend=True)
    """

    def ordered_groups(ids):
        # First positions and sizes of the runs of equal ids, ordered by
        # first appearance. np.unique alone orders by id value, which differs
        # from the plotting order whenever ids are not increasing (e.g. a
        # subgroup value that first appears in a later top-level group).
        __, first, counts = np.unique(ids, return_index=True,
                                      return_counts=True)
        order = np.argsort(first)
        return first[order], counts[order]

    # Initialize plotting canvas
    sns.set_style(style)
    if context is not None:
        sns.set_context(context)
    if palette is not None:
        sns.set_palette(palette)
    else:
        palette = sns.color_palette()

    # axis must be created after sns.set() command for style to apply!
    if ax is None:
        __, ax = plt.subplots()
    fig = ax.figure

    margin = 1 - histscale
    left, bottom, width, height = ax.get_position().bounds
    ax.set_position([left, bottom,
                     margin * width, margin * height])
    # Add the histogram axes to the figure owning `ax`; plt.axes() would add
    # them to whichever figure happens to be current.
    axhistx = fig.add_axes([left, bottom + margin * height,
                            margin * width, histscale * height])
    axhisty = fig.add_axes([left + margin * width, bottom,
                            histscale * width, margin * height])
    sns.despine(ax=axhistx)
    sns.despine(ax=axhisty)

    # Whitespace margin around dot display = 2%
    ws_margin = 0.02

    # Control of user entries
    if groupingdepth > 2:
        raise ValueError("Grouping is limited to two layers.")
    groupingdepth = int(groupingdepth)
    # unique key added to the annotations to store list ordering information
    list_key = r"%$\@[#*&/!"

    if isinstance(key_list, str):
        key_list = [key_list]
    else:
        key_list = list(key_list)
    if '' not in key_list:
        key_list = [list_key] + key_list
    else:
        key_list = [list_key if not key else key for key in key_list]

    if isinstance(spacing, (list, tuple)):
        spacing = list(spacing)
        if len(spacing) == 1:
            spacing = [spacing[0], spacing[0] / 2.]
    else:
        spacing = [spacing, spacing / 2.]
    if spacing[0] < spacing[1]:
        raise DeprecationWarning("For reasonable visual aid, spacing between "
                                 "top level group (spacing[0]) must be larger "
                                 "than for subgroups (spacing[1]).")

    if isinstance(colorkey, (int, np.integer)) and len(key_list):
        if colorkey >= len(key_list):
            raise IndexError("An integer colorkey must refer to a position in "
                             "key_list.")
        colorkey = key_list[colorkey]
    else:
        if not colorkey:
            colorkey = list_key
        elif colorkey not in key_list:
            raise AttributeError("colorkey must be in key_list.")

    if legend and not key_list:
        raise AttributeError("Legend requires a non empty key_list.")

    if labelkey == '':
        labelkey = list_key

    if separatorargs is not None:
        if isinstance(separatorargs, dict):
            separatorargs = [separatorargs, separatorargs]
        else:
            separatorargs = list(separatorargs)
            if len(separatorargs) == 1:
                separatorargs = separatorargs * 2
            for args in separatorargs:
                if not isinstance(args, dict):
                    raise TypeError("The parameters must be given as dict.")
        # Work on copies so neither the caller's dicts nor the shared
        # default argument are modified.
        separatorargs = [dict(args) for args in separatorargs]
        for args in separatorargs:
            if 'c' in args:
                # Move the alias so plot() does not receive both c and color.
                args['color'] = args.pop('c')
            elif 'color' not in args:
                args['color'] = '0.8'

    # Copy so the caller's dict (or the shared default) is not modified.
    markerargs = dict(markerargs, linestyle='')

    # Flatten list of lists while keeping the grouping info in annotations
    if isinstance(spiketrains[0], list):
        for list_nbr, st_list in enumerate(spiketrains):
            for st in st_list:
                st.annotations[list_key] = "set {}".format(list_nbr)
        spiketrains = [st for sublist in spiketrains for st in sublist]
    else:
        for st in spiketrains:
            st.annotations[list_key] = "set {}".format(0)
        # With a single list the list key carries no ordering information.
        key_list.remove(list_key)
        key_list.append(list_key)
    # Remember every annotated spike train (including the ones removed by
    # filter_function) so the temporary annotation can be cleaned up.
    annotated_spiketrains = list(spiketrains)

    # Input checks on flattened lists
    if len(key_list) < groupingdepth:
        raise ValueError("Can't group more as keys in key_list.")

    # Filter spike trains according to given filter function
    if filter_function is not None:
        spiketrains = [st for st in spiketrains if filter_function(st)]

    # Initialize plotting parameters
    t_lims = [(st.t_start, st.t_stop) for st in spiketrains]
    tmin = min(t_lims, key=lambda f: f[0])[0]
    tmax = max(t_lims, key=lambda f: f[1])[1]
    period = tmax - tmin
    ax.set_xlim(tmin - ws_margin * period, tmax + ws_margin * period)
    yticks = np.zeros(len(spiketrains))

    # Sort spike trains according to keylist
    def sort_func(x):
        return ['' if key not in x.annotations
                else x.annotations[key] for key in key_list]

    spiketrains = sorted(spiketrains, key=sort_func)

    if len(key_list) > 1:
        attribute_array = _get_attributes(spiketrains, key_list)
    elif len(key_list) == 1:
        attribute_array = np.zeros((len(spiketrains), 2))
        attribute_array[:, 0] = _get_attributes(spiketrains, key_list)[:, 0]
    else:
        attribute_array = np.zeros((len(spiketrains), 1))

    # Define colormap
    if not len(key_list):
        nbr_of_colors = 1
        colorkey = None
    else:
        colorkey = np.where(colorkey == np.array(key_list))[0][0]
        nbr_of_colors = int(max(attribute_array[:, colorkey]) + 1)
    colormap = sns.color_palette(palette, nbr_of_colors)

    # Draw population histogram (upper side)
    colorkeyvalues = np.unique(attribute_array[:, colorkey])

    if pophist_mode == 'color' and len(colorkeyvalues) > 1:
        if len(sns.color_palette()) < len(colorkeyvalues):
            warnings.warn("There are more subsets than can be separated by "
                          "colors in the color palette which might lead to "
                          "confusion")
        max_y = 0
        for value in colorkeyvalues:
            idx = np.where(attribute_array[:, colorkey] == value)[0]
            histout = axhistx.hist(
                np.concatenate([spiketrains[i] for i in idx]),
                pophistbins, histtype='step', linewidth=1,
                color=colormap[int(value)])
            max_y = np.max([max_y, np.max(histout[0])])
    else:  # pophist_mode == 'total'
        if len(colorkeyvalues) > 1:
            # Neutral color, matching the separators when they are drawn.
            sum_color = ('0.8' if separatorargs is None
                         else separatorargs[0]['color'])
        else:
            sum_color = sns.color_palette()[0]
        histout = axhistx.hist(np.concatenate(spiketrains),
                               pophistbins, histtype='step', linewidth=1,
                               color=sum_color)
        max_y = np.max(histout[0])

    # Set ticks and labels for population histogram
    axhistx_ydim, up = _round_to_1(max_y)
    if max_y > axhistx.get_ylim()[-1]:
        axhistx.set_ylim(0, max_y)
    if up and axhistx_ydim > max_y:
        axhistx.set_ylim(0, axhistx_ydim)
    axhistx.set_yticks([axhistx_ydim])
    axhistx.set_yticklabels(['{:.0f}'.format(axhistx_ydim)])
    axhistx.set_ylabel('count')

    # Legend for colorkey
    if legend:
        __, index = np.unique(attribute_array[:, colorkey], return_index=True)
        legend_labels = [spiketrains[i].annotations.get(key_list[colorkey], '')
                         for i in index]
        legend_handles = [0] * len(index)

    # Reshape list into sublists according to groupingdepth
    if groupingdepth > 0:
        value1, index1, counts1 = np.unique(attribute_array[:, 0],
                                            return_index=True,
                                            return_counts=True)
        # Level-1 ids increase with appearance (the list is sorted by the
        # first key), so the id doubles as the position of the group.
        for v1, i1, c1 in zip(value1, index1, counts1):
            v1 = int(v1)
            spiketrains[v1:v1 + c1] = [spiketrains[v1:v1 + c1]]
            if groupingdepth > 1:
                __, counts2 = ordered_groups(attribute_array[i1:i1 + c1, 1])
                for v2, c2 in enumerate(counts2):
                    spiketrains[v1][v2:v2 + c2] = [
                        spiketrains[v1][v2:v2 + c2]]
            else:
                spiketrains[v1] = [spiketrains[v1]]
    else:
        spiketrains = [[spiketrains]]

    # HIERARCHY:
    # [ [ []..[] ] .... [ []..[] ] ]  spiketrains
    #   [ []..[] ]                    LIST
    #     []                          list
    #                                 spike train

    # Loop through lists of lists of spike trains
    for COUNT, SLIST in enumerate(spiketrains):

        # Separator depth 1
        if COUNT and separatorargs is not None:
            linepos = ypos + len(spiketrains[COUNT - 1][-1]) \
                + spacing[0] / 2. - 0.5
            ax.plot(ax.get_xlim(), [linepos] * 2, **separatorargs[0])

        # Loop through lists of spike trains
        for count, slist in enumerate(SLIST):
            nbr_of_drawn_sts = int(
                sum(len(sl) for SL in spiketrains[:COUNT] for sl in SL) +
                sum(len(sl) for sl in SLIST[:count]))

            # Calculate position of next spike train to draw
            prev_spaces = np.sum([len(SLIST_it) - 1
                                  for SLIST_it in spiketrains[:COUNT]])
            ypos = (nbr_of_drawn_sts
                    + int(bool(groupingdepth)) * COUNT * spacing[0]
                    + groupingdepth / 2 * count * spacing[1]
                    + groupingdepth / 2 * prev_spaces * spacing[1])

            # Separator depth 2
            if count and separatorargs is not None:
                linepos = ypos - (spacing[1] + 1) / 2.
                ax.plot(ax.get_xlim(), [linepos] * 2, **separatorargs[1])

            # Loop through spike trains
            for st_count, st in enumerate(slist):
                current_st = nbr_of_drawn_sts + st_count
                annotation_value = int(attribute_array[current_st, colorkey])
                color = colormap[annotation_value]

                # Dot display
                handle = ax.plot(st.times.magnitude,
                                 [st_count + ypos] * len(st),
                                 color=color, **markerargs)
                if legend:
                    legend_handles[annotation_value] = handle[0]

                # Right side histogram bar
                barvalue = right_histogram(st)
                barwidth = righthist_barwidth
                axhisty.barh(y=st_count + ypos,
                             width=barvalue,
                             height=barwidth,
                             color=color,
                             edgecolor=color)

            # Append positions of spike trains to tick list
            ycoords = np.arange(len(slist)) + ypos
            yticks[nbr_of_drawn_sts:nbr_of_drawn_sts + len(slist)] = ycoords

    # Plotting axis
    yrange = yticks[-1] - yticks[0]
    ax.set_ylim(yticks[0] - ws_margin * yrange,
                yticks[-1] + ws_margin * yrange)
    axhistx.set_xlim(ax.get_xlim())
    axhisty.set_ylim(ax.get_ylim())
    ax.set_xlabel(f'Time ({spiketrains[0][0][0].units.dimensionality})')
    axhistx.get_xaxis().set_visible(False)
    axhisty.get_yaxis().set_visible(False)

    # Set ticks and labels for right side histogram
    axhisty_xdim, up = _round_to_1(axhisty.get_xlim()[-1])
    if up:
        # Extend the right histogram so its single tick stays visible.
        axhisty.set_xlim(0, axhisty_xdim)
    axhisty.set_xticks([axhisty_xdim])
    axhisty.set_xticklabels(['{}'.format(axhisty_xdim)])

    # Y labeling
    if key_list and labelkey in key_list + [0, 1, '0+1']:
        if labelkey == key_list[0]:
            if groupingdepth > 0:
                labelkey = 0
        elif len(key_list) > 1 and labelkey == key_list[1]:
            if groupingdepth > 1:
                labelkey = 1

        if isinstance(labelkey, int) or labelkey == '0+1':
            labelpos = [[], []]
            labelname = [[], []]

            # Labeling depth 1 + 2
            if groupingdepth:
                values1, index1, counts1 = np.unique(attribute_array[:, 0],
                                                     return_index=True,
                                                     return_counts=True)

                for v1, i1, c1 in zip(values1, index1, counts1):
                    st = spiketrains[int(v1)][0][0]
                    if key_list[0] in st.annotations:
                        labelname[0] += [str(st.annotations[key_list[0]])]
                        if labelkey == '0+1':
                            labelname[0][-1] += ' ' * 5
                    else:
                        labelname[0] += ['']

                    labelpos[0] += [(yticks[i1] + yticks[i1 + c1 - 1]) / 2.]

                    # Labeling depth 2 (subgroups only exist for depth 2)
                    if groupingdepth > 1 and labelkey and len(key_list) > 1:
                        index2, counts2 = ordered_groups(
                            attribute_array[i1:i1 + c1, 1])

                        for v2, (i2, c2) in enumerate(zip(index2, counts2)):
                            st = spiketrains[int(v1)][v2][0]
                            if key_list[1] in st.annotations:
                                labelname[1] += [
                                    str(st.annotations[key_list[1]])]
                            else:
                                labelname[1] += ['']

                            labelpos[1] += [(yticks[i1 + i2] + yticks[
                                i1 + i2 + c2 - 1]) / 2.]

            # Set labels according to labelkey
            if isinstance(labelkey, int):
                ax.set_yticks(labelpos[1] if labelkey else labelpos[0])
                ax.set_yticklabels(labelname[1] if labelkey else labelname[0])

            elif labelkey == "0+1":
                ax.set_yticks(labelpos[0] + labelpos[1])
                ax.set_yticklabels(labelname[0] + labelname[1])

        else:
            # Annotation key as labelkey: one label per spike train, in the
            # same nested order used to fill yticks.
            labelname = [st.annotations.get(labelkey, '')
                         for SLIST in spiketrains
                         for slist in SLIST
                         for st in slist]
            ax.set_yticks(yticks)
            ax.set_yticklabels(labelname)

    else:
        ax.set_yticks([])

    # Draw legend
    if legend:
        ax.legend(legend_handles, legend_labels, **legendargs)

    # Remove list_key from annotations (also of filtered-out spike trains)
    for st in annotated_spiketrains:
        st.annotations.pop(list_key, None)

    return ax, axhistx, axhisty
[docs]
def rasterplot(spiketrains, axes=None, histogram_bins=0, title=None,
               color=None, **kwargs):
    """
    Simple and fast raster plot of spike times.

    Parameters
    ----------
    spiketrains : list of neo.SpikeTrain or pq.Quantity, or list of such lists
        A list of `neo.SpikeTrain` objects or quantity arrays with spike
        times. A list of such lists is drawn as separate populations,
        stacked vertically; empty populations are skipped.
    axes : matplotlib.axes.Axes or None, optional
        Matplotlib axes handle. If None, new axes are created and returned.
        Default: None
    histogram_bins : int, optional
        Defines the number of histogram bins. If set to ``0``, no histogram
        is shown.
        Default: 0
    title : str or None, optional
        The axes title.
        Default: None
    color : str or list of str or None, optional
        Raster colors: either one color for all populations or a list with
        one color per population. If None, the matplotlib alias ``c`` given
        in `kwargs` is used instead.
        Default: None
    **kwargs
        Additional parameters passed to matplotlib `scatter` function.

    Returns
    -------
    axes : matplotlib.axes.Axes or np.ndarray of matplotlib.axes.Axes
        A single Axes when only the raster is drawn on one axes, otherwise
        the array of axes (raster first, histogram second).

    See Also
    --------
    rasterplot_rates : advanced raster plot
    eventplot : plot spike times in vertical stripes

    Examples
    --------
    1. Basic example.

    .. plot::
       :include-source:

       import numpy as np
       import quantities as pq
       import matplotlib.pyplot as plt
       from elephant.spike_train_generation import homogeneous_poisson_process
       from viziphant.rasterplot import rasterplot
       np.random.seed(7)
       spiketrains = [homogeneous_poisson_process(rate=10*pq.Hz,
                      t_stop=10*pq.s) for _ in range(10)]
       rasterplot(spiketrains, s=3, c='black')
       plt.show()

    2. Raster plot with a histogram and events.

    .. plot::
       :include-source:

       import neo
       import numpy as np
       import quantities as pq
       import matplotlib.pyplot as plt
       from elephant.spike_train_generation import homogeneous_poisson_process
       from viziphant.rasterplot import rasterplot
       from viziphant.events import add_event
       np.random.seed(7)
       spiketrains = [homogeneous_poisson_process(rate=r * pq.Hz,
                      t_stop=10 * pq.s) for r in range(1, 21)]
       event = neo.Event([0.5, 2.8] * pq.s, labels=['Trig ON', 'Trig OFF'])
       axes = rasterplot(spiketrains, histogram_bins=50, title='Title', s=0.5)
       add_event(axes, event=event)
       plt.show()

    """
    # A flat list of spike trains (neo.SpikeTrain is a pq.Quantity subclass)
    # or plain quantity arrays forms a single population.
    if isinstance(spiketrains[0], pq.Quantity):
        spiketrains = [spiketrains]

    # Always remove the 'c' alias from kwargs so scatter() never receives
    # the color twice.
    c_alias = kwargs.pop('c', None)
    if color is None:
        color = c_alias
    if not isinstance(color, (list, tuple)):
        color = [color] * len(spiketrains)

    # Drop empty populations together with their colors so the remaining
    # populations keep the colors they were assigned.
    kept = [(population, c) for population, c in zip(spiketrains, color)
            if len(population)]
    spiketrains = [population for population, _ in kept]
    color = [c for _, c in kept]

    check_same_units(spiketrains)
    units = spiketrains[0][0].units

    if axes is None:
        nrows = 2 if histogram_bins else 1
        __, axes = plt.subplots(nrows=nrows, ncols=1)
    axes = np.atleast_1d(axes)

    count = 0
    histtype = 'bar' if len(spiketrains) == 1 else 'step'
    for sts_population, c in zip(spiketrains, color):
        sts_population = [st.magnitude for st in sts_population]
        times_population = np.hstack(sts_population)
        # One row per spike train, offset by the rows of previous populations
        ys = np.hstack([np.repeat(i + count, repeats=len(st))
                        for i, st in enumerate(sts_population)])
        axes[0].scatter(times_population, ys, c=c, **kwargs)
        if histogram_bins:
            axes[1].hist(times_population, bins=histogram_bins,
                         histtype=histtype, color=c)
        count += len(sts_population)
    axes[0].set_yticks([0, count - 1])
    axes[0].set_title(title)
    if histogram_bins:
        axes[1].set_ylabel("Spike count")
    axes[-1].set_xlabel(f"Time ({units.dimensionality})")
    if len(axes) == 1:
        axes = axes[0]
    return axes
[docs]
def eventplot(spiketrains, axes=None, histogram_bins=0, title=None, **kwargs):
    """
    Draw spike times as vertical event lines, optionally with a histogram.

    Parameters
    ----------
    spiketrains : list of neo.SpikeTrain or pq.Quantity
        A list of `neo.SpikeTrain` objects or quantity arrays with spike
        times; they must share the same units.
    axes : matplotlib.axes.Axes or None
        Matplotlib axes handle (or a sequence of two handles when a
        histogram is drawn). If None, new axes are created and returned.
        Default: None
    histogram_bins : int, optional
        Number of bins of the spike time histogram drawn on the second
        axes. If ``0``, no histogram is shown.
        Default: 0
    title : str or None, optional
        Title of the event plot axes.
        Default: None
    **kwargs
        Additional parameters forwarded to matplotlib `eventplot`.

    Returns
    -------
    axes : matplotlib.axes.Axes or np.ndarray of matplotlib.axes.Axes
        A single Axes when only one axes is used, otherwise the array of
        axes (event plot first, histogram second).

    See Also
    --------
    rasterplot : simplified raster plot
    rasterplot_rates : advanced raster plot

    Examples
    --------
    Basic spike times eventplot.

    .. plot::
       :include-source:

       import numpy as np
       import quantities as pq
       import matplotlib.pyplot as plt
       from elephant.spike_train_generation import homogeneous_poisson_process
       from viziphant.rasterplot import eventplot
       np.random.seed(12)
       spiketrains = [homogeneous_poisson_process(rate=10*pq.Hz,
                      t_stop=10*pq.s) for _ in range(10)]
       eventplot(spiketrains, linelengths=0.75, color='black')
       plt.show()

    To plot with a histogram, provide a value for ``histogram_bins``.
    To compare spike times between different neurons, create a
    `matplotlib.axes.Axes` instance prior to calling the function.
    Additionally, you can add events to the plot with the
    :func:`viziphant.events.add_event` function.

    .. plot::
       :include-source:

       import neo
       import numpy as np
       import quantities as pq
       import matplotlib.pyplot as plt
       from elephant.spike_train_generation import homogeneous_poisson_process
       from viziphant.rasterplot import eventplot
       from viziphant.events import add_event
       np.random.seed(12)
       spiketrains = [homogeneous_poisson_process(rate=5*pq.Hz,
                      t_stop=10*pq.s) for _ in range(20)]
       fig, axes = plt.subplots(2, 2, sharex=True, sharey='row')
       event = neo.Event([0.5, 8]*pq.s, labels=['trig0', 'trig1'])
       eventplot(spiketrains[:10], axes=axes[:, 0], histogram_bins=20,
                 title="Neuron A", linelengths=0.75, linewidths=1)
       add_event(axes[:, 0], event)
       eventplot(spiketrains[10:], axes=axes[:, 1], histogram_bins=20,
                 title="Neuron B", linelengths=0.75, linewidths=1)
       add_event(axes[:, 1], event)
       plt.show()

    """
    check_same_units(spiketrains)
    time_units = spiketrains[0].units
    spike_times = [train.magnitude for train in spiketrains]

    if axes is None:
        n_rows = 2 if histogram_bins else 1
        _, axes = plt.subplots(nrows=n_rows, ncols=1)
    axes = np.atleast_1d(axes)

    event_ax = axes[0]
    event_ax.eventplot(spike_times, **kwargs)
    event_ax.set_yticks([0, len(spike_times) - 1])
    event_ax.set_title(title)

    if histogram_bins:
        hist_ax = axes[1]
        hist_ax.hist(np.hstack(spike_times), bins=histogram_bins)
        hist_ax.set_ylabel("Spike count")

    axes[-1].set_xlabel(f"Time ({time_units.dimensionality})")
    return axes[0] if len(axes) == 1 else axes