import matplotlib.pyplot as plt
import numpy as np
import quantities as pq
from elephant.gpfa import GPFA
from elephant.spike_train_generation import StationaryPoissonProcess
from viziphant.gpfa import plot_dimensions_vs_time

# Seed NumPy's global RNG so the drawn firing rates (and any spike generation
# that draws from the same global RNG) are reproducible across runs.
np.random.seed(24)
n_trials = 10
n_channels = 5

# One entry per trial. Each entry holds one stationary Poisson spike train
# per channel. Every trial draws a fresh set of integer firing rates in
# [1, 100) Hz before its spike trains are generated.
data = [
    [
        StationaryPoissonProcess(rate=channel_rate).generate_spiketrain()
        for channel_rate in np.random.randint(
            low=1, high=100, size=n_channels) * pq.Hz
    ]
    for _ in range(n_trials)
]

# Assign trials to two interleaved conditions: even trial indices form
# 'trial type A', odd ones 'trial type B'. The indices are derived from
# n_trials, so the grouping stays valid if the number of simulated trials
# changes (for n_trials == 10 this yields [0, 2, 4, 6, 8] / [1, 3, 5, 7, 9]).
grouping_dict = {'trial type A': list(range(0, n_trials, 2)),
                 'trial type B': list(range(1, n_trials, 2))}

# Fit GPFA with a 3-dimensional latent space on 20 ms bins, then project
# all trials. Both the orthonormalized and the raw latent trajectories are
# requested, so either can be plotted later.
gpfa = GPFA(bin_size=20 * pq.ms, x_dim=3, verbose=False)
gpfa.fit(data)
results = gpfa.transform(
    data,
    returned_data=['latent_variable_orth', 'latent_variable'],
)

# Plot every latent dimension against time in a single column, using the
# orthonormalized trajectories. All trials are requested, together with the
# averages of the trial groups defined in grouping_dict.
plot_dimensions_vs_time(
    returned_data=results,
    gpfa_instance=gpfa,
    dimensions='all',
    orthonormalized_dimensions=True,
    n_columns=1,
    plot_group_averages=True,
    trials_to_plot='all',
    trial_grouping_dict=grouping_dict)
# Display the figure(s); requires matplotlib.pyplot imported as plt.
plt.show()