import matplotlib.pyplot as plt
import numpy as np
import quantities as pq
from elephant.cell_assembly_detection import cell_assembly_detection
from elephant.conversion import BinnedSpikeTrain
from elephant.spike_train_generation import compound_poisson_process
import viziphant

# Example script: generate synthetic spike trains, run cell assembly
# detection on their binned form, and plot the first detected patterns.

# Seed NumPy's global RNG before any data is generated
# (presumably so the example produces the same figure on every run).
np.random.seed(30)

# Parameters of the synthetic data and of the binning step.
firing_rate = 15 * pq.Hz
duration = 5 * pq.s
bin_width = 10 * pq.ms
# NOTE(review): per the elephant docs, entry j is the probability of an event
# with j synchronous spikes, so a 7-element list should yield 6 spike trains
# with mostly single spikes and occasional fully synchronous events -- confirm
# against the installed elephant version.
sync_amplitudes = [0, 0.95, 0, 0, 0, 0, 0.05]

spiketrains = compound_poisson_process(
    rate=firing_rate,
    amplitude_distribution=sync_amplitudes,
    t_stop=duration,
)

# Bin the spike trains; the return value of rescale() is not captured, so the
# script relies on it updating `bst` in place before detection runs.
bst = BinnedSpikeTrain(spiketrains, bin_size=bin_width)
bst.rescale('ms')

patterns = cell_assembly_detection(bst, max_lag=2)

# Only the first two detected patterns are drawn (slicing is safe even if
# fewer than two were found).
first_patterns = patterns[:2]
viziphant.patterns.plot_patterns(
    spiketrains,
    patterns=first_patterns,
    circle_sizes=(3, 30, 40),
)
plt.show()