#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Basic graphics - Gaussian Linear Hidden Markov Model
@author: Diego Vidaurre 2023
"""
import numpy as np
import seaborn as sb
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import warnings
from matplotlib import cm, colors
from matplotlib.colors import LogNorm, LinearSegmentedColormap, to_rgba_array
from mpl_toolkits.axes_grid1 import make_axes_locatable
from . import utils
# import utils
[docs]
def show_trans_prob_mat(hmm,only_active_states=False,show_diag=True,show_colorbar=True):
    """Displays the transition probability matrix of a given HMM.

    Parameters:
    -----------
    hmm: HMM object
        An instance of the HMM class containing the transition probability
        matrix to be visualized. The matrix is read from ``hmm.P``; a copy is
        plotted, so the model itself is not modified.
    only_active_states : bool, optional, default=False
        Whether to display only active states or all states in the matrix.
        ``hmm.active_states`` is used as a mask of the states to keep.
    show_diag : bool, optional, default=True
        Whether to display the diagonal elements of the matrix or not.
        If False, the diagonal is set to zero and every row is rescaled to
        sum to one.
    show_colorbar : bool, optional, default=True
        Whether to display the colorbar next to the matrix or not.
    """
    P = np.array(hmm.P, dtype=float)

    if only_active_states:
        # Indexing as P[mask, mask] pairs the row and column indices and
        # returns only the diagonal entries; np.ix_ builds the open mesh that
        # selects the full (active x active) submatrix instead.
        active = np.where(hmm.active_states)[0]
        P = P[np.ix_(active, active)]

    K = P.shape[0]

    if not show_diag:
        np.fill_diagonal(P, 0)
        row_sums = P.sum(axis=1, keepdims=True)
        # Rows without any off-diagonal mass are left at zero rather than
        # being turned into NaN by a 0/0 division.
        np.divide(P, row_sums, out=P, where=row_sums != 0)

    _, axes = plt.subplots()
    g = sb.heatmap(
        ax=axes,
        data=P,
        cmap='bwr',
        xticklabels=np.arange(K),
        yticklabels=np.arange(K),
        square=True,
        cbar=show_colorbar,
    )

    # Thin grid lines between the cells
    for k in range(K):
        g.plot([0, K], [k, k], '-k')
        g.plot([k, k], [0, K], '-k')

    # Thick frame around the whole matrix
    axes.axhline(y=0, color='k', linewidth=4)
    axes.axhline(y=K, color='k', linewidth=4)
    axes.axvline(x=0, color='k', linewidth=4)
    axes.axvline(x=K, color='k', linewidth=4)
[docs]
def show_Gamma(Gamma, line_overlay=None, tlim=None, Hz=1, palette='viridis'):
    """Displays the activity of the hidden states as a function of time.

    The state probabilities are drawn as a stacked area plot (each time point
    is rescaled to sum to one), optionally with a secondary signal overlaid
    as a black line on a twin y-axis.

    Parameters:
    -----------
    Gamma : array of shape (n_samples, n_states)
        The state timeseries probabilities.
    line_overlay : array of shape (n_samples, 1), optional, default=None
        A secondary related data type to overlay as a line.
    tlim : 2x1 array or None, default=None
        The time interval to be displayed, given as sample indices
        (start, stop). If None (default), displays the entire sequence.
    Hz : int, default=1
        The frequency of the signal, in Hz. Sample indices are divided by
        this value to obtain the x-axis.
    palette : str, default='viridis'
        The name of the color palette to use.
    """
    n_states = Gamma.shape[1]

    # Pick one colour per state, evenly spread over the whole palette.
    # The palette length is read from the colormap itself so that palettes
    # with fewer than 256 entries also work.
    cmap = plt.get_cmap(palette)
    lut = cmap(np.arange(0, cmap.N))[:, :3]
    picks = np.round(np.linspace(0.0, cmap.N - 1, n_states)).astype(int)
    state_colors = lut[picks]

    # Restrict the data (and the overlay, if any) to the requested interval
    if tlim is not None:
        data = Gamma[tlim[0] : tlim[1], :]
        line = None if line_overlay is None else line_overlay[tlim[0] : tlim[1]].copy()
    else:
        data = Gamma
        # Without this assignment the overlay was undefined when no interval
        # was requested.
        line = line_overlay

    # Use the length of what is actually plotted, so that an interval that
    # exceeds the data does not break the index.
    n_samples = data.shape[0]
    time_axis = np.arange(n_samples) / Hz

    df = pd.DataFrame(data, index=time_axis)
    df = df.divide(df.sum(axis=1), axis=0)

    # Plot Gamma area
    ax = df.plot(
        kind='area',
        stacked=True,
        ylim=(0,1),
        legend=False,
        color=state_colors
    )

    # Overlay line if given
    if line is not None:
        df2 = pd.DataFrame(line, index=time_axis)
        ax2 = ax.twinx()
        df2.plot(ax=ax2, legend=False, color="black")
        ax2.set(ylabel = '')

    # Adjust axis specifications
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1))
    ax.set(
        title = "",
        xlabel = 'Time [s]',
        ylabel = 'State probability')
    ax.margins(0,0)

    plt.show()
[docs]
def show_temporal_statistic(Gamma,indices,statistic='FO',type_plot='barplot'):
    """Plots a statistic over time for a set of sessions.

    The statistic is computed by calling ``utils.get_<statistic>(Gamma, indices)``
    and the result is expected to be a 2-D array of shape
    (n_sessions, n_states).

    Parameters:
    -----------
    Gamma : array of shape (n_samples, n_states)
        The state timeseries probabilities.
    indices: numpy.ndarray of shape (n_sessions,)
        The session indices to plot.
    statistic: str, default='FO'
        The statistic to compute and plot. Can be 'FO', 'switching_rate' or 'FO_entropy'.
    type_plot: str, default='barplot'
        The type of plot to generate. Can be 'barplot', 'boxplot' or 'matrix'.

    Raises:
    -------
    Exception
        - Statistic is not one of 'FO', 'switching_rate' or 'FO_entropy'.
        - type_plot is 'boxplot' and there are less than 10 sessions.
        - type_plot is 'matrix' and there is only one session.
    """
    # Validate before looking anything up: the name is used to build the
    # helper's name, so it must come from the whitelist.
    if statistic not in ("FO", "switching_rate", "FO_entropy"):
        raise Exception("statistic has to be 'FO','switching_rate' or 'FO_entropy'")

    # Attribute lookup instead of eval(): same helper, no string evaluation
    s = getattr(utils, "get_" + statistic)(Gamma, indices)
    N, K = s.shape

    sb.set(style='whitegrid')

    if type_plot == 'boxplot':
        if N < 10:
            raise Exception("Too few sessions for a boxplot; use barplot")
        sb.boxplot(data=s, palette='plasma')

    elif type_plot == 'barplot':
        # The data is stacked on itself, which leaves the per-state mean
        # (the bar height) unchanged; kept as in the original implementation.
        sb.barplot(data=np.concatenate((s, s)), palette='plasma', errorbar=None)

    elif type_plot == 'matrix':
        if N < 2:
            raise Exception("There is only one session; use barplot")
        fig, ax = plt.subplots()
        labels_x = np.round(np.linspace(0, N, 5)).astype(int)
        pos_x = np.linspace(0, N, 5)
        if K > 10:
            # States are integers: round so that ticks land on actual states
            # instead of fractional positions such as 2.75.
            labels_y = np.round(np.linspace(0, K - 1, 5)).astype(int)
        else:
            labels_y = np.arange(K)
        plt.imshow(s.T, aspect='auto')
        plt.xticks(pos_x, labels_x)
        plt.yticks(labels_y, labels_y)
        ax.set_xlabel('Sessions')
        ax.set_ylabel('States')
        fig.tight_layout()
[docs]
def show_beta(hmm,only_active_states=True,recompute_states=False,
              X=None,Y=None,Gamma=None,show_average=None,alpha=1.0):
    """
    Displays the beta coefficients of a given HMM.
    The beta coefficients can be extracted directly from the HMM structure or reestimated from the data;
    for the latter, X, Y and Gamma need to be provided as parameters.
    This is useful for example if one has run the model on PCA space,
    but wants to show coefficients in the original space.

    The figure is a grid of scatter plots, one panel per pair of states,
    where each point is one beta coefficient and the colour identifies the
    column of Y it belongs to.

    Parameters:
    -----------
    hmm: HMM object
        An instance of the HMM class containing the beta coefficients to be visualized.
    only_active_states: bool, optional, default=True
        If True, only the beta coefficients of active states are shown.
    recompute_states: bool, optional, default=False
        If True, the betas will be recomputed from the data and the state time courses
    X: numpy.ndarray, optional, default=None
        The timeseries of set of variables 1.
    Y: numpy.ndarray, optional, default=None
        The timeseries of set of variables 2.
    Gamma: numpy.ndarray, optional, default=None
        The state time courses
    show_average: bool, optional, default=None
        If True, an additional row of the average beta coefficients is shown;
        this requires X and Y. If None, the average is shown whenever both
        X and Y are provided.
    alpha: float, optional, default=1.0
        The regularisation parameter to be applied if the betas are to be recomputed.
        It is also used for the average coefficients.

    Raises:
    -------
    Exception
        - recompute_states is True and X, Y or Gamma is missing.
        - show_average is True and X or Y is missing.
    """
    if show_average is None:
        show_average = not ((X is None) or (Y is None))
    elif show_average and ((X is None) or (Y is None)):
        raise Exception("The data (X,Y) need to be provided if show_average is True")

    K = hmm.get_betas().shape[2]

    if recompute_states:
        if (Y is None) or (X is None) or (Gamma is None):
            raise Exception("The data (X,Y) and the state time courses (Gamma) need "
                            "to be provided if recompute_states is True")
        (p, q) = (X.shape[1], Y.shape[1])
        beta = np.zeros((p, q, K))
        for k in range(K):
            gamma_k = Gamma[:, k]
            if hmm.hyperparameters["model_mean"] != 'no':
                # State-specific mean of Y, weighted by the state probability.
                # It has to be computed from Y itself: the residual Yr does
                # not exist yet at this point.
                m = (gamma_k @ Y) / np.sum(gamma_k)
                Yr = Y - m
            else:
                Yr = Y
            # Ridge regression of Yr on X, weighted by the state probability
            Xw = X * np.expand_dims(gamma_k, axis=1)
            beta[:, :, k] = np.linalg.solve(Xw.T @ X + alpha * np.eye(p), Xw.T @ Yr)
    else:
        beta = hmm.get_betas()
        (p, q, _) = beta.shape

    if only_active_states:
        beta = beta[:, :, np.where(hmm.active_states)[0]]
        K = beta.shape[2]

    if show_average:
        # State-independent ridge regression on the demeaned Y.
        # NOTE(review): X must have as many columns as the betas have rows;
        # this holds when the betas are recomputed from X - confirm otherwise.
        Yr = Y - np.expand_dims(np.mean(Y, axis=0), axis=0)
        b0 = np.linalg.solve(X.T @ X + alpha * np.eye(p), X.T @ Yr)
        K += 1
        B = np.zeros((p, q, K))
        B[:, :, 0:K-1] = beta
        B[:, :, -1] = b0
    else:
        B = beta

    # Long-format table with one row per (coefficient, panel-state pair):
    # 'x' holds B indexed by the 4th axis, 'y' holds B indexed by the 3rd axis.
    grid_shape = (p, q, K, K)
    state_labels = np.array(['State ' + str(k) for k in range(K)], dtype=object)
    if show_average:
        state_labels[-1] = 'Average'
    variable_labels = np.array([str(j) for j in range(q)], dtype=object)

    def _flatten(values):
        # Column-major flattening, so that all five columns stay aligned
        return np.broadcast_to(values, grid_shape).ravel(order='F')

    df = pd.DataFrame({
        'x': _flatten(B[:, :, np.newaxis, :]),
        'y': _flatten(B[:, :, :, np.newaxis]),
        'Variable': _flatten(variable_labels[np.newaxis, :, np.newaxis, np.newaxis]),
        'beta x': _flatten(state_labels[np.newaxis, np.newaxis, :, np.newaxis]),
        'beta y': _flatten(state_labels[np.newaxis, np.newaxis, np.newaxis, :]),
    })

    g = sb.relplot(x='x',
        y='y',
        s=25,
        hue='Variable',
        col="beta x", row="beta y",
        data=df,
        palette='cool')

    for ax in g.axes_dict.values():
        ax.grid(False, axis='x')
        ax.set_title('')
# def show_r2(r2=None,hmm=None,Gamma=None,X=None,Y=None,indices=None,show_average=False):
# if r2 is None:
# if (Y is None) or (indices is None):
# raise Exception("Y and indices (and maybe X) has to be specified if r2 is not provided")
# r2 = hmm.get_r2(X,Y,Gamma,indices)
# if show_average:
# if (Y is None) or (indices is None):
# raise Exception("Y and indices (and maybe X) has to be specified if the average is to computed")
# r20 = hmm.get_r2(X,Y,Gamma,indices)
# for j in range(N):
# tt_j = range(indices[j,0],indices[j,1])
# if X is not None:
# Xj = np.copy(X[tt_j,:])
# d = np.copy(Y[tt_j,:])
# if self.hyperparameters["model_mean"] == 'shared':
# d -= np.expand_dims(self.mean[0]['Mu'],axis=0)
# if self.hyperparameters["model_beta"] == 'shared':
# d -= (Xj @ self.beta[0]['Mu'])
# for k in range(K):
# if self.hyperparameters["model_mean"] == 'state':
# d -= np.expand_dims(self.mean[k]['Mu'],axis=0) * np.expand_dims(Gamma[:,k],axis=1)
# if self.hyperparameters["model_beta"] == 'state':
# d -= (Xj @ self.beta[k]['Mu']) * np.expand_dims(Gamma[:,k],axis=1)
# d = np.sum(d**2,axis=0)
# d0 = np.copy(Y[tt_j,:])
# if self.hyperparameters["model_mean"] != 'no':
# d0 -= np.expand_dims(m,axis=0)
# d0 = np.sum(d0**2,axis=0)
# r2[j,:] = 1 - (d / d0)
[docs]
def custom_colormap():
    """
    Build a colormap that runs from dark red, through red, orange and warm
    tones, to blue.

    Most of the colour change happens in the first tenth of the range
    (anchors at 0, 0.005, 0.02, 0.04, 0.05 and 0.09); the remainder blends
    into the final blue.

    Returns:
    --------
    A matplotlib LinearSegmentedColormap named 'custom_colormap'.
    """
    def _rgb(rgba):
        # Drop the alpha component of a single RGBA colour
        return to_rgba_array(rgba)[0][:3]

    # Source colormaps the anchor colours are sampled from
    reversed_coolwarm = plt.get_cmap('coolwarm').reversed()
    autumn = plt.get_cmap('autumn')
    reversed_copper = plt.get_cmap('copper').reversed()

    # (position, colour) anchors, in increasing position
    anchors = [
        (0.0, (66/255, 13/255, 9/255)),           # dark red
        (0.005, (1, 0, 0)),                       # red
        (0.02, (1, 0.5, 0)),                      # orange
        (0.040, _rgb(autumn(0.8))),               # warm tone from 'autumn'
        # NOTE: the argument is the integer 1 (not 1.0), kept exactly as is
        (0.05, _rgb(reversed_copper(1))),
        (0.09, _rgb(reversed_coolwarm(0.6))),
        (1, _rgb(reversed_coolwarm(1.0))),        # final blue
    ]

    return LinearSegmentedColormap.from_list('custom_colormap', anchors)
[docs]
def red_colormap():
    """
    Generate a custom colormap consisting of red and warm colors.

    The map runs from a very dark red, through two brighter reds, into two
    warm tones sampled from the 'autumn' colormap.

    Returns:
    --------
    A custom colormap with red and warm color segments.
    """
    autumn_cmap = plt.get_cmap('autumn')

    # Fixed reds, from darkest to brightest
    red2 = (66/255, 13/255, 9/255)
    red0 = (float(120/255), 0, 0)
    red = (1, 0, 0)
    # Warm tones taken from 'autumn'
    warm_color1 = to_rgba_array(autumn_cmap(0.4))[0][:3]
    warm_color2 = to_rgba_array(autumn_cmap(0.7))[0][:3]

    # (position, colour) anchors of the colormap
    cmap_segments = [
        (0.0, red2),
        (0.3, red0),
        (0.5, red),
        (0.7, warm_color1),
        (1, warm_color2),
    ]

    # Create the custom colormap
    custom_cmap = LinearSegmentedColormap.from_list('custom_colormap', cmap_segments)
    return custom_cmap
[docs]
def blue_colormap():
    """
    Build a colormap that starts at a copper tone and turns into blue.

    Returns:
    --------
    A matplotlib LinearSegmentedColormap named 'custom_colormap' with three
    anchors: a copper colour at 0 and two blues at 0.2 and 1.
    """
    def _rgb(rgba):
        # Drop the alpha component of a single RGBA colour
        return to_rgba_array(rgba)[0][:3]

    reversed_coolwarm = plt.get_cmap('coolwarm').reversed()
    reversed_copper = plt.get_cmap('copper').reversed()

    anchors = [
        # NOTE: the argument is the integer 1 (not 1.0), kept exactly as is
        (0, _rgb(reversed_copper(1))),
        (0.2, _rgb(reversed_coolwarm(0.7))),
        (1, _rgb(reversed_coolwarm(1.0))),
    ]

    return LinearSegmentedColormap.from_list('custom_colormap', anchors)
[docs]
def create_cmap_alpha(cmap_list,color_array, alpha):
    """
    Modify the colors in a colormap based on an alpha threshold.

    The entries of the colormap up to the last position whose value in
    ``color_array`` is at or below ``alpha`` are recoloured in three
    consecutive bands: red, a red taken from the reversed 'coolwarm'
    colormap, and orange. Entries after that position are left untouched.

    Parameters:
    -----------
    cmap_list (numpy.ndarray)
        List of colors representing the original colormap (one row per
        colour; the first three columns are overwritten).
    color_array (numpy.ndarray)
        Array of color values corresponding to each colormap entry, of shape
        (1, n_colors); a 1-D array of length n_colors is accepted as well.
    alpha (float)
        Alpha threshold for modifying colors.

    Returns:
    --------
    Modified list of colors representing the colormap with adjusted alpha values.
    The input ``cmap_list`` is not modified. If no value is at or below
    ``alpha``, an unmodified copy is returned.
    """
    cmap_list_alpha = cmap_list.copy()
    _, idx_alpha = np.where(np.atleast_2d(color_array) <= alpha)

    # Nothing lies below the threshold: there is nothing to recolour
    if idx_alpha.size == 0:
        return cmap_list_alpha

    last = idx_alpha[-1]

    coolwarm_cmap = plt.get_cmap('coolwarm').reversed()
    red = (1, 0, 0)
    orange = (1, 0.5, 0)
    red_color1 = to_rgba_array(coolwarm_cmap(0))[0][:3]
    list_red = [red, red_color1, orange]

    # Width of each band
    idx_interval = int(last / (len(list_red) - 1))

    # Start with the first colour everywhere, then overwrite the tail with
    # each following colour; the later bands extend up to and including `last`.
    cmap_list_alpha[:last, :3] = list_red[0]
    for i in range(len(list_red) - 1):
        cmap_list_alpha[idx_interval * (i + 1):last + 1, :3] = list_red[i + 1]

    return cmap_list_alpha
[docs]
def interpolate_colormap(cmap_list):
    """
    Create a new colormap with the modified color_array.

    Each of the first three channels is treated as a sequence of constant
    runs. Every run is replaced by a linear ramp that starts at the run's
    value and approaches (without reaching) the value of the next run; the
    last run is kept constant. Runs are ordered by the position at which
    their value first appears.

    Parameters:
    --------------
    cmap_list (numpy.ndarray):
        Original color array for the colormap, one row per colour.
        NOTE(review): a value that reappears in a channel after a different
        value is counted together with its first run - the function assumes
        each value forms one contiguous run.

    Returns:
    ----------
    modified_cmap (numpy.ndarray):
        Modified colormap array with the same shape as the input. Columns
        beyond the first three are filled with ones.
    """
    modified_cmap = np.ones_like(cmap_list)

    # An empty colormap has nothing to interpolate
    if cmap_list.shape[0] == 0:
        return modified_cmap

    for channel_idx in range(3):
        values, first_positions, counts = np.unique(
            cmap_list[:, channel_idx], return_index=True, return_counts=True
        )

        # np.unique sorts by value; reorder the runs by where they start
        order = np.argsort(first_positions)
        values = values[order]
        counts = counts[order]

        pieces = []
        for start, stop, count in zip(values[:-1], values[1:], counts[:-1]):
            count = int(count)
            # One step short of `stop`, so that the next run begins exactly
            # at its own value
            step = (stop - start) / count
            pieces.append(np.linspace(start, stop - step, count))

        # The last run has no successor and stays constant
        pieces.append(np.full(int(counts[-1]), values[-1]))

        modified_cmap[:, channel_idx] = np.concatenate(pieces)

    return modified_cmap
[docs]
def plot_p_value_matrix(pval, alpha = 0.05, normalize_vals=True, figsize=(9, 5),
                        title_text="Heatmap (p-values)", annot=False,
                        cmap_type='default', cmap_reverse=True, xlabel="", ylabel="",
                        xticklabels=None, none_diagonal = False, num_colors = 259, xlabel_rotation=0):
    """
    Plot a heatmap of p-values.

    Parameters:
    -----------
    pval (numpy.ndarray)
        The p-values data to be plotted. A 0-D or 1-D array is displayed as
        a matrix with a single row.
    alpha (float or None, optional), default=0.05:
        Threshold used by the default colormap: values at or below alpha are
        drawn with the red colormap and values above it with the blue one.
        If None, a colormap without such a jump is used.
    normalize_vals (bool, optional), default=True:
        If True, the colours are mapped on a logarithmic scale from 0.001 to 1.
    figsize (tuple, optional), default=(9, 5):
        Figure size in inches (width, height).
    title_text (str, optional), default="Heatmap (p-values)"
        Title text for the heatmap.
    annot (bool, optional), default=False:
        If True, annotate each cell with the numeric value.
    cmap_type (str, optional), default="default":
        Colormap to use. 'default' selects the custom colormaps of this
        module; any other value is looked up by name in ``matplotlib.cm``.
    cmap_reverse (bool, optional), default=True:
        If True, reverse the colormap. Only applies when cmap_type is not
        'default'.
    xlabel (str, optional), default=""
        X-axis label.
    ylabel (str, optional), default=""
        Y-axis label.
    xticklabels (List[str], optional), default=None:
        If not provided, the tick labels are numbers.
        Else you can define your own labels, e.g., xticklabels=['sex', 'age'].
    none_diagonal (bool, optional), default=False:
        If you want to turn the diagonal into NaN numbers.
    num_colors (int, optional), default=259:
        Define the number of different shades of color.
    xlabel_rotation (int, optional), default=0
        The degree of rotation for the tick labels.

    Raises:
    -------
    ValueError
        If cmap_type is not a colormap known to matplotlib and cmap_reverse
        is True.
    """
    if pval.ndim == 0:
        pval = np.reshape(pval, (1, 1))

    # Rotated labels read better when anchored at their right end
    ha = "right" if xlabel_rotation == 45 else "center"

    fig, axes = plt.subplots(figsize=figsize)
    if len(pval.shape) == 1:
        pval = np.expand_dims(pval, axis=0)

    if cmap_type == 'default':
        if alpha is None and not normalize_vals:
            cmap = cm.coolwarm.reversed()
        elif alpha is None:
            # Smooth custom colormap sampled on a logarithmic grid
            color_array = np.logspace(-3, 0, num_colors).reshape(1, -1)
            cmap_list = custom_colormap()(color_array)[0]
            modified_cmap = interpolate_colormap(cmap_list)
            cmap = LinearSegmentedColormap.from_list('custom_colormap', modified_cmap)
        else:
            # Make a jump in color after alpha: red shades at or below the
            # threshold, blue shades above it
            color_array = np.logspace(-3, 0, num_colors).reshape(1, -1)
            num_elements_red = np.sum(color_array <= alpha)
            num_elements_blue = np.sum(color_array > alpha)
            cmap_red = red_colormap()(np.linspace(0, 1, num_elements_red))
            cmap_blue = blue_colormap()(np.linspace(0, 1, num_elements_blue))
            cmap_list = np.vstack((cmap_red, cmap_blue))
            cmap = LinearSegmentedColormap.from_list('custom_colormap', cmap_list)
    else:
        # Get the colormap dynamically based on the input string
        cmap = getattr(cm, cmap_type, None)
        if cmap_reverse:
            if cmap is None:
                raise ValueError(f"Unknown colormap name: {cmap_type!r}")
            cmap = cmap.reversed()

    if none_diagonal:
        # Work on a copy so that the caller's matrix keeps its diagonal
        pval = np.copy(pval)
        np.fill_diagonal(pval, np.nan)

    if normalize_vals:
        norm = LogNorm(vmin=1e-3, vmax=1)
        heatmap = sb.heatmap(pval, ax=axes, cmap=cmap, annot=annot, fmt=".3f", cbar=False, norm=norm)
    else:
        heatmap = sb.heatmap(pval, ax=axes, cmap=cmap, annot=annot, fmt=".3f", cbar=False)

    # Add labels and title
    axes.set_xlabel(xlabel, fontsize=12)
    axes.set_ylabel(ylabel, fontsize=12)
    axes.set_title(title_text, fontsize=14)

    # Number of tick steps (the number of rows is used for both axes)
    steps = len(pval)

    # Set the x-axis ticks
    if xticklabels is not None:
        axes.set_xticks(np.arange(len(xticklabels)) + 0.5)
        axes.set_xticklabels(xticklabels, rotation=xlabel_rotation, fontsize=10, ha=ha)
    elif pval.shape[1] > 1:
        axes.set_xticks(np.linspace(0, pval.shape[1]-1, steps).astype(int)+0.5)
        axes.set_xticklabels(np.linspace(1, pval.shape[1], steps).astype(int), rotation=xlabel_rotation, fontsize=10, ha=ha)
    else:
        axes.set_xticklabels([])

    # Set the y-axis ticks
    if pval.shape[0] > 1:
        axes.set_yticks(np.linspace(0, pval.shape[0]-1, steps).astype(int)+0.5)
        axes.set_yticklabels(np.linspace(1, pval.shape[0], steps).astype(int), rotation=xlabel_rotation, fontsize=10, ha=ha)
    else:
        axes.set_yticklabels([])

    # Create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(axes)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    if normalize_vals:
        colorbar = plt.colorbar(heatmap.get_children()[0], cax=cax, ticks=np.logspace(-3, 0, num_colors))
        colorbar.update_ticks()

        # Round the tick values to three decimal places
        rounded_ticks = [round(tick, 3) for tick in colorbar.get_ticks()]

        # Values that get a label on the colorbar
        if figsize[-1] == 1:
            tick_positions = [0, 0.001, 0.01, 0.05, 0.3, 1]
        else:
            tick_positions = [0, 0.001, 0.01, 0.05, 0.1, 0.3, 1]

        # Label only the first tick that rounds to each of these values
        seen = set()
        tick_labels = []
        for value in rounded_ticks:
            if value in tick_positions and value not in seen:
                seen.add(value)
                tick_labels.append(f'{value:.3f}')
            else:
                tick_labels.append('')
        labelled = {index for index, label in enumerate(tick_labels) if label != ''}

        colorbar.set_ticklabels(tick_labels)
        colorbar.ax.tick_params(axis='y')

        # Hide the tick marks of the unlabelled ticks. This iterates over the
        # ticks themselves: the list of tick lines holds two lines per tick,
        # so its positions do not correspond to tick indices.
        for idx, tick in enumerate(colorbar.ax.yaxis.get_major_ticks()):
            if idx not in labelled:
                tick.tick1line.set_visible(False)
                tick.tick2line.set_visible(False)
    else:
        # Create a custom colorbar
        colorbar = plt.colorbar(heatmap.get_children()[0], cax=cax)
        # Five ticks evenly spaced between the smallest and largest value
        min_value = np.nanmin(pval)
        max_value = np.nanmax(pval)
        colorbar.set_ticks(np.linspace(min_value, max_value, 5).round(2))

    # Show the plot
    plt.show()
[docs]
def plot_correlation_matrix(corr_vals, performed_tests, normalize_vals=False,
                            figsize=(9, 5), title_text="Correlation Coefficients Heatmap",
                            annot=False, cmap_type='default', cmap_reverse=True, xlabel="", ylabel="",
                            xticklabels=None,xlabel_rotation=45, none_diagonal = False, num_colors = 256):
    """
    Plot a heatmap of correlation coefficients.

    Parameters:
    -----------
    corr_vals (numpy.ndarray)
        Base statistics of correlation coefficients. A 0-D or 1-D array is
        displayed as a matrix with a single row.
    performed_tests (dict)
        Holds information about the different test statistics that have been applied.
        The entries "t_test_cols" and "f_test_cols" must both be empty lists.
    normalize_vals (bool, optional), default=False:
        If True, the colour range is fixed to the interval from -1 to 1.
    figsize (tuple, optional), default=(9, 5):
        Figure size in inches (width, height).
    title_text (str, optional), default="Correlation Coefficients Heatmap"
        Title text for the heatmap.
    annot (bool, optional), default=False:
        If True, annotate each cell with the numeric value.
    cmap_type (str, optional), default='default':
        Colormap to use. 'default' is the reversed 'coolwarm' colormap; any
        other value is looked up by name in ``matplotlib.cm``.
    cmap_reverse (bool, optional), default=True:
        If True, reverse the colormap. Only applies when cmap_type is not
        'default'.
    xlabel (str, optional), default='':
        X-axis label.
    ylabel (str, optional), default='':
        Y-axis label.
    xticklabels (List[str], optional), default=None:
        If not provided, the tick labels are numbers.
        Else, you can define your own labels, e.g., xticklabels=['sex', 'age'].
    xlabel_rotation (int, optional), default=45:
        The degree of rotation for the tick labels.
    none_diagonal (bool, optional), default=False:
        Currently not used by this function.
    num_colors (int, optional), default=256:
        Number of colors to use in the colormap.

    Raises:
    -------
    ValueError
        - t-tests or F-tests have been performed on some columns, so the
          base statistics are not all correlation coefficients.
        - cmap_type is not a colormap known to matplotlib and cmap_reverse
          is True.
    """
    if performed_tests["t_test_cols"]!=[] or performed_tests["f_test_cols"]!=[]:
        raise ValueError("Cannot plot the base statistics for the correlation coefficients because different test statistics have been used.")

    if corr_vals.ndim == 0:
        corr_vals = np.reshape(corr_vals, (1, 1))

    # Rotated labels read better when anchored at their right end
    ha = "right" if xlabel_rotation == 45 else "center"

    # Number of tick steps; taken before a 1-D input is turned into a row,
    # so for a 1-D input this is its length
    steps = len(corr_vals)

    fig, axes = plt.subplots(figsize=figsize)
    if len(corr_vals.shape) == 1:
        corr_vals = np.expand_dims(corr_vals, axis=0)

    if cmap_type == 'default':
        # Sample the reversed 'coolwarm' colormap at num_colors positions
        coolwarm_cmap = cm.coolwarm.reversed()
        color_array = np.linspace(0, 1, num_colors).reshape(1, -1)
        cmap_list = coolwarm_cmap(color_array)[0]
        cmap = colors.ListedColormap(cmap_list)
    else:
        # Get the colormap dynamically based on the input string
        cmap = getattr(cm, cmap_type, None)
        if cmap_reverse:
            if cmap is None:
                raise ValueError(f"Unknown colormap name: {cmap_type!r}")
            cmap = cmap.reversed()

    if normalize_vals:
        # Normalize the data range from -1 to 1
        norm = plt.Normalize(vmin=-1, vmax=1)
        heatmap = sb.heatmap(corr_vals, ax=axes, cmap=cmap, annot=annot, fmt=".3f", cbar=False, norm=norm)
    else:
        heatmap = sb.heatmap(corr_vals, ax=axes, cmap=cmap, annot=annot, fmt=".3f", cbar=False)

    # Add labels and title
    axes.set_xlabel(xlabel, fontsize=12)
    axes.set_ylabel(ylabel, fontsize=12)
    axes.set_title(title_text, fontsize=14)

    # Set the x-axis ticks
    if xticklabels is not None:
        axes.set_xticks(np.arange(len(xticklabels)) + 0.5)
        axes.set_xticklabels(xticklabels, rotation=xlabel_rotation, ha=ha, fontsize=10)
    elif corr_vals.shape[1] > 1:
        axes.set_xticks(np.linspace(0, corr_vals.shape[1]-1, steps).astype(int)+0.5)
        axes.set_xticklabels(np.linspace(1, corr_vals.shape[1], steps).astype(int), rotation=xlabel_rotation, ha=ha, fontsize=10)
    else:
        axes.set_xticklabels([])

    # Set the y-axis ticks
    if corr_vals.shape[0] > 1:
        axes.set_yticks(np.linspace(0, corr_vals.shape[0]-1, steps).astype(int)+0.5)
        axes.set_yticklabels(np.linspace(1, corr_vals.shape[0], steps).astype(int), rotation=xlabel_rotation, ha=ha, fontsize=10)
    else:
        axes.set_yticklabels([])

    # Create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.01 inch.
    divider = make_axes_locatable(axes)
    cax = divider.append_axes("right", size="5%", pad=0.01)

    # Create a custom colorbar
    colorbar = plt.colorbar(heatmap.get_children()[0], cax=cax)

    if normalize_vals:
        colorbar.set_ticks(np.linspace(-1, 1, 7).round(2))
    else:
        # Seven ticks evenly spaced between the (rounded) smallest and
        # largest value of the data
        min_value = np.nanmin(corr_vals).round(2)
        max_value = np.floor(np.nanmax(corr_vals) * 100) / 100
        colorbar.set_ticks(np.linspace(min_value, max_value, 7).round(2))

    # Show the plot
    plt.show()
[docs]
def plot_permutation_distribution(test_statistic, title_text="Permutation Distribution",xlabel="Test Statistic Values",ylabel="Density"):
    """
    Draw a histogram (with a kernel density estimate) of the permutation
    values and mark the observed statistic with a dashed red line.

    Parameters:
    -----------
    test_statistic (numpy.ndarray)
        An array containing the permutation values. Its first element is
        the one marked as the observed statistic.
    title_text (str, optional), default="Permutation Distribution":
        Title text of the plot.
    xlabel (str, optional), default="Test Statistic Values"
        Text of the xlabel.
    ylabel (str, optional), default="Density"
        Text of the ylabel.
    """
    plt.figure()
    sb.histplot(test_statistic, kde=True)

    # Everything below acts on the axes the histogram was drawn on
    ax = plt.gca()
    observed = test_statistic[0]
    ax.axvline(x=observed, color='red', linestyle='--', label='Observed Statistic')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title_text, fontsize=14)
    ax.legend()

    plt.show()
[docs]
def plot_scatter_with_labels(p_values, alpha=0.05, title_text="", xlabel=None, ylabel=None, xlim_start=0.9, ylim_start=0):
    """
    Create a scatter plot to visualize p-values with labels indicating significant points.

    Each p-value is drawn at its (1-based) position on the x-axis and at
    -log(p) on the y-axis.

    Parameters:
    -----------
    p_values (numpy.ndarray)
        An array of p-values. Can be a 1D array, or a 2D array with a single
        row or a single column, which is flattened.
    alpha (float, optional), default=0.05:
        Threshold for significance.
    title_text (str, optional), default="":
        The title text for the plot. If empty, a default title that includes
        alpha is used.
    xlabel (str, optional), default=None:
        The label for the x-axis. If None, 'Index' is used.
    ylabel (str, optional), default=None:
        The label for the y-axis. If None, '-log(p-values)' is used.
    xlim_start (float, optional), default=0.9
        Start position of x-axis limits.
    ylim_start (float, optional), default=0
        Start position of y-axis limits.

    Notes:
    ------
    Points with p-values less than alpha are considered significant and marked with red text.
    """
    p_values = np.asarray(p_values)
    # A single row or column is treated like a 1D array, whatever its length
    if p_values.ndim == 2 and 1 in p_values.shape:
        p_values = p_values.flatten()

    # Significant points (below alpha) get their own hue and marker
    mask = p_values < alpha
    hue = mask.astype(int)
    markers = ["o", "s"]

    positions = np.arange(0, len(p_values)) + 1
    neg_log_p = -np.log(p_values)

    fig, ax = plt.subplots(figsize=(8, 6))
    sb.scatterplot(x=positions, y=neg_log_p, hue=hue, style=hue,
                   markers=markers, s=40, edgecolor='k', linewidth=1, ax=ax)

    # Add labels and title to the plot
    if not title_text:
        ax.set_title(f'Scatter Plot of P-values, alpha={alpha}', fontsize=14)
    else:
        ax.set_title(title_text, fontsize=14)
    ax.set_xlabel('Index' if xlabel is None else xlabel, fontsize=12)
    ax.set_ylabel('-log(p-values)' if ylabel is None else ylabel, fontsize=12)

    # Write the position next to every significant point
    for i, significant in enumerate(mask):
        if significant:
            ax.text(i + 1, neg_log_p[i], str(i+1), ha='center', va='bottom', color='red', fontsize=10)

    # Adjust legend position and font size
    ax.legend(title="Significance", loc="upper right", fontsize=10, bbox_to_anchor=(1.25, 1))

    # Set axis limits to focus on the relevant data range
    ax.set_xlim(xlim_start, len(p_values) + 1)
    ax.set_ylim(ylim_start, np.max(neg_log_p) * 1.2)

    # Show the plot
    plt.tight_layout()
    plt.show()
import seaborn as sns
[docs]
def plot_vpath(viterbi_path, signal=None, idx_data=None, figsize=(7, 4), fontsize_labels=13, fontsize_title=16, yticks=None, time_conversion_rate=None, xlabel="Timepoints", ylabel="", title="Viterbi Path", signal_label="Signal", show_legend=True, vertical_linewidth=1.5):
    """
    Plot a Viterbi path as a stacked area chart, with an optional signal overlay.

    Parameters:
    -----------
    viterbi_path : numpy.ndarray of shape (n_timepoints, n_states)
        The Viterbi path data matrix; each column is drawn as one stacked area.
    signal : array-like, optional
        Signal data to overlay on the plot as a black line. Default is None.
    idx_data : numpy.ndarray, optional
        2D array of indices. The second column of every row except the last
        is marked with a vertical dashed line (presumably the end index of
        each session -- confirm against the caller). Default is None.
    figsize : tuple, optional
        Figure size. Default is (7, 4).
    fontsize_labels : int, optional
        Font size for axis labels and tick labels. Default is 13.
    fontsize_title : int, optional
        Font size for plot title. Default is 16.
    yticks : bool, optional
        If truthy and `signal` is given, y-ticks are placed at the unique
        signal values. If truthy without a signal, Matplotlib's default
        y-ticks are kept. Otherwise the y-ticks are removed. Default is None.
    time_conversion_rate : float, optional
        Number of time steps per second. When given, every x position
        (states, signal and the `idx_data` lines) is divided by this value.
        Default is None.
    xlabel : str, optional
        Label for the x-axis. Default is "Timepoints"; it is replaced by
        "Time (seconds)" when `time_conversion_rate` is given and the
        default label was not overridden.
    ylabel : str, optional
        Label for the y-axis. Default is "".
    title : str, optional
        Title for the plot. Default is "Viterbi Path".
    signal_label : str, optional
        Legend label for the signal line. Default is "Signal".
    show_legend : bool, optional
        Whether to show the legend. Default is True.
    vertical_linewidth : float, optional
        Line width for the vertical gray lines. Default is 1.5.
    """
    num_timepoints = viterbi_path.shape[0]
    num_states = viterbi_path.shape[1]

    # "Set3" has 12 distinct colours. Asking seaborn for more makes it cycle,
    # so request at most 12 and take the remaining colours from "husl".
    set3_size = 12
    state_colors = list(sns.color_palette("Set3", n_colors=min(num_states, set3_size)))
    if num_states > set3_size:
        state_colors.extend(sns.color_palette("husl", n_colors=num_states - set3_size))

    fig, axes = plt.subplots(figsize=figsize)

    in_seconds = time_conversion_rate is not None

    # Plot Viterbi path
    x_values = np.arange(num_timepoints)
    if in_seconds:
        x_values = x_values / time_conversion_rate
        if xlabel == "Timepoints":
            xlabel = "Time (seconds)"
    state_labels = [f'State {i + 1}' for i in range(num_states)]
    axes.stackplot(x_values, viterbi_path.T, colors=state_colors, labels=state_labels)
    axes.set_xlabel(xlabel, fontsize=fontsize_labels)
    axes.set_ylabel(ylabel, fontsize=fontsize_labels)
    axes.set_title(title, fontsize=fontsize_title)

    # Plot signal overlay on the same x scale as the states
    if signal is not None:
        signal_x = np.arange(len(signal))
        if in_seconds:
            signal_x = signal_x / time_conversion_rate
        axes.plot(signal_x, signal, color='black', label=signal_label)

    # Draw vertical gray lines at the interval boundaries; the indices are in
    # time steps, so they must be rescaled when the x-axis is in seconds.
    if idx_data is not None:
        for idx in idx_data[:-1, 1]:
            boundary = idx / time_conversion_rate if in_seconds else idx
            axes.axvline(x=boundary, color='gray', linestyle='--', linewidth=vertical_linewidth)

    # Show legend
    if show_legend:
        axes.legend(title='States', loc='upper left', bbox_to_anchor=(1, 1))

    if yticks:
        # The custom ticks are derived from the signal; without one, keep
        # Matplotlib's default ticks instead of failing.
        if signal is not None:
            levels = np.unique(signal)
            scaled_values = [int(val * len(levels)) for val in levels]
            axes.set_yticks(levels)
            axes.set_yticklabels(scaled_values)
    else:
        # Remove y-axis ticks
        axes.set_yticks([])

    # Remove the frame around the plot
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_visible(False)

    # Adjust tick label font size
    axes.tick_params(axis='both', labelsize=fontsize_labels)
    fig.tight_layout()
    plt.show()
[docs]
def plot_average_probability(Gamma_reconstruct, title='Average probability for each state', fontsize=16, figsize=(7, 5), vertical_lines=None, line_colors=None, highlight_boxes=False):
    """
    Plots the average probability for each state over time.

    Parameters:
    -----------
    Gamma_reconstruct (numpy.ndarray)
        3D array representing reconstructed gamma values.
        Shape: (num_timepoints, num_trials, num_states)
    title (str, optional), default='Average probability for each state':
        Title for the plot.
    fontsize (int, optional), default=16:
        Font size for labels and title.
    figsize (tuple, optional), default=(7, 5):
        Figure size (width, height) in inches.
    vertical_lines (list of tuples, optional), default=None:
        List of (start, end) pairs specifying x positions for vertical lines.
    line_colors (list of str or bool, optional), default=None:
        List of colors for each pair of vertical lines. If True, a random
        color is generated for every pair. Pairs without a color use gray.
    highlight_boxes (bool, optional), default=False:
        Whether to shade the interval between each pair of vertical lines.
    """
    # Average over the trial axis in one step (rounded to 3 decimals)
    Gamma_avg = np.mean(Gamma_reconstruct, axis=1).round(3)

    fig, axes = plt.subplots(1, figsize=figsize)

    # Plot each state as its own labelled line
    for state in range(Gamma_avg.shape[-1]):
        axes.plot(Gamma_avg[:, state], label=f'State {state + 1}')

    # Add vertical lines, line colors, and highlight boxes
    if vertical_lines:
        if isinstance(line_colors, (bool, np.bool_)):
            # `True` requests random colors; `False` falls back to gray
            line_colors = [tuple(np.random.rand(3)) for _ in vertical_lines] if line_colors else None
        for idx, pair in enumerate(vertical_lines):
            color = line_colors[idx] if line_colors and len(line_colors) > idx else 'gray'
            axes.axvline(x=pair[0], color=color, linestyle='--', linewidth=1)
            axes.axvline(x=pair[1], color=color, linestyle='--', linewidth=1)
            if highlight_boxes:
                y_low, y_high = axes.get_ylim()
                rect = plt.Rectangle((pair[0], y_low), pair[1] - pair[0], y_high - y_low, linewidth=0, edgecolor='none', facecolor=color, alpha=0.2)
                axes.add_patch(rect)

    # Add labels and title
    axes.set_xlabel('Timepoints', fontsize=fontsize)
    axes.set_ylabel('Average probability', fontsize=fontsize)
    axes.set_title(title, fontsize=fontsize)

    # Legend for the highlighted boxes. A second call to `legend` replaces the
    # first one, so this legend is re-attached as a plain artist to keep both.
    if highlight_boxes and vertical_lines:
        legend_rect = plt.Rectangle((0, 0), 1, 1, linewidth=0, edgecolor='none', facecolor='gray', alpha=0.2, label='Interval with significant difference')
        box_legend = axes.legend(handles=[legend_rect], loc='upper right')
        axes.add_artist(box_legend)

    # Place legend for the lines to the right of the figure
    axes.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Show the plot
    plt.show()
[docs]
def plot_FO(FO, figsize=(8, 4), fontsize_labels=13, fontsize_title=16, width=0.8,xlabel='Subject',ylabel='Fractional occupancy',title='State Fractional Occupancies', show_legend=True, num_ticks=10):
    """
    Plot fractional occupancies for different states as stacked bars.

    Parameters:
    -----------
    FO (numpy.ndarray):
        Fractional occupancy data matrix; one row per bar (subject/session)
        and one column per state.
    figsize (tuple, optional), default=(8,4):
        Figure size.
    fontsize_labels (int, optional), default=13:
        Font size for axes labels, tick labels and legend.
    fontsize_title (int, optional), default=16:
        Font size for plot title.
    width (float, optional), default=0.8:
        Width of the bars.
    xlabel (str, optional), default='Subject':
        Label for the x-axis.
    ylabel (str, optional), default='Fractional occupancy':
        Label for the y-axis.
    title (str, optional), default='State Fractional Occupancies':
        Title for the plot.
    show_legend (bool, optional), default=True:
        Whether to show the legend.
    num_ticks (int, optional), default=10:
        Number of x-axis ticks used when there are more than 10 rows in `FO`.
    """
    num_sessions = FO.shape[0]
    num_states = FO.shape[1]
    sessions = np.arange(1, num_sessions + 1)

    # "Set3" has 12 distinct colours. Asking seaborn for more makes it cycle,
    # so request at most 12 and take the remaining colours from "husl".
    set3_size = 12
    state_colors = list(sns.color_palette("Set3", n_colors=min(num_states, set3_size)))
    if num_states > set3_size:
        state_colors.extend(sns.color_palette("husl", n_colors=num_states - set3_size))

    fig, axes = plt.subplots(figsize=figsize)

    # Stack the states on top of each other, one bar per row of FO
    bottom = np.zeros(num_sessions)
    for k in range(num_states):
        axes.bar(sessions, FO[:, k], bottom=bottom, color=state_colors[k], width=width, label='State {}'.format(k + 1))
        bottom += FO[:, k]

    axes.set_xlabel(xlabel, fontsize=fontsize_labels)
    axes.set_ylabel(ylabel, fontsize=fontsize_labels)
    axes.set_title(title, fontsize=fontsize_title)

    # One tick per bar when there are at most 10 bars, otherwise `num_ticks`
    # evenly spaced ticks
    n_ticks = num_ticks if num_sessions > 10 else num_sessions
    axes.set_xticks(np.linspace(1, num_sessions, n_ticks).astype(int))
    axes.set_yticks(np.linspace(0, 1, 5))

    # Remove the frame around the plot
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_visible(False)

    # Adjust tick label font size
    axes.tick_params(axis='both', labelsize=fontsize_labels)
    if show_legend:
        axes.legend(fontsize=fontsize_labels, loc='upper left', bbox_to_anchor=(1, 1))
    fig.tight_layout()
    plt.show()
[docs]
def plot_switching_rates(SR, figsize=(8, 4), fontsize_labels=13, fontsize_title=16, width=0.18, xlabel='Subject', ylabel='Switching Rate', title='State Switching Rates', show_legend=True, num_ticks=10):
    """
    Plot switching rates for different states as grouped bars.

    Parameters:
    -----------
    SR (numpy.ndarray):
        Switching rate data matrix; one row per group of bars
        (subject/session) and one column per state.
    figsize (tuple, optional), default=(8, 4):
        Figure size.
    fontsize_labels (int, optional), default=13:
        Font size for axes labels, tick labels and legend.
    fontsize_title (int, optional), default=16:
        Font size for plot title.
    width (float, optional), default=0.18:
        Width of the bars. It is reduced automatically when the bars of one
        group would otherwise spill into the neighbouring group.
    xlabel (str, optional), default='Subject':
        Label for the x-axis.
    ylabel (str, optional), default='Switching Rate':
        Label for the y-axis.
    title (str, optional), default='State Switching Rates':
        Title for the plot.
    show_legend (bool, optional), default=True:
        Whether to show the legend.
    num_ticks (int, optional), default=10:
        Number of x-axis ticks used when there are more than 10 rows in `SR`.
    """
    num_sessions = SR.shape[0]
    num_states = SR.shape[1]
    sessions = np.arange(1, num_sessions + 1)

    # "Set3" has 12 distinct colours. Asking seaborn for more makes it cycle,
    # so request at most 12 and take the remaining colours from "husl".
    set3_size = 12
    state_colors = list(sns.color_palette("Set3", n_colors=min(num_states, set3_size)))
    if num_states > set3_size:
        state_colors.extend(sns.color_palette("husl", n_colors=num_states - set3_size))

    fig, axes = plt.subplots(figsize=figsize, constrained_layout=True)

    # Groups are one x-unit apart, so a group wider than that would overlap
    # the next one. Cap the bar width so a full group spans at most 0.9.
    bar_width = min(width, 0.9 / num_states) if num_states else width
    for k in range(num_states):
        # Centre the group of bars on its integer tick
        offset = (k - (num_states - 1) / 2) * bar_width
        axes.bar(sessions + offset, SR[:, k], bar_width, color=state_colors[k], label='State {}'.format(k + 1))

    axes.set_xlabel(xlabel, fontsize=fontsize_labels)
    axes.set_ylabel(ylabel, fontsize=fontsize_labels)
    axes.set_title(title, fontsize=fontsize_title)

    # One tick per group when there are at most 10 groups, otherwise
    # `num_ticks` evenly spaced ticks
    n_ticks = num_ticks if num_sessions > 10 else num_sessions
    axes.set_xticks(np.linspace(1, num_sessions, n_ticks).astype(int))

    # Remove the frame around the plot
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_visible(False)

    # Adjust tick label font size
    axes.tick_params(axis='both', labelsize=fontsize_labels)
    if show_legend:
        axes.legend(fontsize=fontsize_labels, loc='upper left', bbox_to_anchor=(1, 1))
    plt.show()
[docs]
def plot_state_lifetimes(LT, figsize=(8, 4), fontsize_labels=13, fontsize_title=16, width=0.18, xlabel='Subject', ylabel='Lifetime', title='State Lifetimes', show_legend=True, num_ticks=10):
    """
    Plot state lifetimes for different states as grouped bars.

    Parameters:
    -----------
    LT (numpy.ndarray):
        State lifetime (dwell time) data matrix; one row per group of bars
        (subject/session) and one column per state.
    figsize (tuple, optional), default=(8, 4):
        Figure size.
    fontsize_labels (int, optional), default=13:
        Font size for axes labels, tick labels and legend.
    fontsize_title (int, optional), default=16:
        Font size for plot title.
    width (float, optional), default=0.18:
        Width of the bars. It is reduced automatically when the bars of one
        group would otherwise spill into the neighbouring group.
    xlabel (str, optional), default='Subject':
        Label for the x-axis.
    ylabel (str, optional), default='Lifetime':
        Label for the y-axis.
    title (str, optional), default='State Lifetimes':
        Title for the plot.
    show_legend (bool, optional), default=True:
        Whether to show the legend.
    num_ticks (int, optional), default=10:
        Number of x-axis ticks used when there are more than 10 rows in `LT`.
    """
    num_sessions = LT.shape[0]
    num_states = LT.shape[1]
    sessions = np.arange(1, num_sessions + 1)

    # "Set3" has 12 distinct colours. Asking seaborn for more makes it cycle,
    # so request at most 12 and take the remaining colours from "husl".
    set3_size = 12
    state_colors = list(sns.color_palette("Set3", n_colors=min(num_states, set3_size)))
    if num_states > set3_size:
        state_colors.extend(sns.color_palette("husl", n_colors=num_states - set3_size))

    fig, axes = plt.subplots(figsize=figsize, constrained_layout=True)

    # Groups are one x-unit apart, so a group wider than that would overlap
    # the next one. Cap the bar width so a full group spans at most 0.9.
    bar_width = min(width, 0.9 / num_states) if num_states else width
    for k in range(num_states):
        # Centre the group of bars on its integer tick
        offset = (k - (num_states - 1) / 2) * bar_width
        axes.bar(sessions + offset, LT[:, k], bar_width, color=state_colors[k], label='State {}'.format(k + 1))

    axes.set_xlabel(xlabel, fontsize=fontsize_labels)
    axes.set_ylabel(ylabel, fontsize=fontsize_labels)
    axes.set_title(title, fontsize=fontsize_title)

    # One tick per group when there are at most 10 groups, otherwise
    # `num_ticks` evenly spaced ticks
    n_ticks = num_ticks if num_sessions > 10 else num_sessions
    axes.set_xticks(np.linspace(1, num_sessions, n_ticks).astype(int))

    # Remove the frame around the plot
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_visible(False)

    # Adjust tick label font size
    axes.tick_params(axis='both', labelsize=fontsize_labels)
    if show_legend:
        axes.legend(fontsize=fontsize_labels, loc='upper left', bbox_to_anchor=(1, 1))
    plt.show()
[docs]
def plot_state_prob_and_covariance(init_stateP, TP, state_means, state_FC, cmap='viridis', figsize=(9, 7), num_ticks=5):
    """
    Plot HMM parameters: initial state probabilities, transition
    probabilities, state means and one covariance matrix per state.

    Parameters:
    -----------
    init_stateP : numpy.ndarray
        Initial state probabilities (drawn as a single column).
    TP : numpy.ndarray
        Transition probabilities (2D matrix).
    state_means : numpy.ndarray
        State means (2D matrix).
    state_FC : numpy.ndarray
        State covariances; 3D array whose last axis indexes the matrices,
        each of which is drawn in its own panel.
    cmap : str or Colormap, optional
        The colormap to be used for plotting. Default is 'viridis'.
    figsize : tuple, optional
        Figure size. Default is (9, 7).
    num_ticks : int, optional
        Number of ticks for the colorbars. Also the number of axis ticks
        used for any dimension with more than 10 entries. Default is 5.
    """
    def _tick_positions(n_entries):
        # 0-based tick positions: every entry when there are at most 10,
        # otherwise `num_ticks` evenly spaced entries (duplicates removed).
        count = num_ticks if n_entries > 10 else n_entries
        return np.unique(np.linspace(0, n_entries - 1, count).astype(int))

    def _colorbar_ticks(low, high):
        # Evenly spaced colorbar ticks, rounded for readable labels
        return np.linspace(low, high, num=num_ticks).round(2)

    # Three fixed panels plus one panel per covariance matrix, three per row
    num_cols = 3
    num_plots = 3 + state_FC.shape[2]
    num_rows = (num_plots - 1) // num_cols + 1

    # squeeze=False keeps `axes` two-dimensional even with a single row
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)

    # Plot initial state probabilities
    ax = axes[0, 0]
    im0 = ax.imshow(init_stateP.reshape(-1, 1), cmap=cmap)
    ax.set_title("Initial state probabilities")
    ax.set_xticks([])
    cbar0 = fig.colorbar(im0, ax=ax)
    cbar0.set_ticks(_colorbar_ticks(init_stateP.min(), init_stateP.max()))
    # Labels are derived from the tick positions so both always have the
    # same length (labels are 1-based)
    init_pos = _tick_positions(init_stateP.shape[0])
    ax.set_yticks(init_pos)
    ax.set_yticklabels(init_pos + 1)

    # Plot transition probabilities
    ax = axes[0, 1]
    im1 = ax.imshow(TP, cmap=cmap)
    ax.set_title("Transition probabilities")
    cbar1 = fig.colorbar(im1, ax=ax)
    cbar1.set_ticks(_colorbar_ticks(TP.min(), TP.max()))
    tp_pos = _tick_positions(TP.shape[0])
    ax.set_xticks(tp_pos)
    ax.set_xticklabels(tp_pos + 1)
    ax.set_yticks(tp_pos)
    ax.set_yticklabels(tp_pos + 1)

    # Plot state means
    ax = axes[0, 2]
    im2 = ax.imshow(state_means, cmap=cmap, aspect='auto')
    ax.set_title("State means")
    cbar2 = fig.colorbar(im2, ax=ax)
    cbar2.set_ticks(_colorbar_ticks(state_means.min(), state_means.max()))
    col_pos = _tick_positions(state_means.shape[1])
    row_pos = _tick_positions(state_means.shape[0])
    ax.set_xticks(col_pos)
    ax.set_xticklabels(col_pos + 1)
    # Image rows sit at 0..n-1, so ticks go there and only labels are 1-based
    ax.set_yticks(row_pos)
    ax.set_yticklabels(row_pos + 1)

    # Plot state covariances on a shared colour scale
    min_value = np.min(state_FC)
    max_value = np.max(state_FC)
    cov_pos = _tick_positions(state_FC.shape[0])
    for k in range(num_cols * num_rows - 3):  # also visits the unused panels
        # Panels 0-2 are taken, so covariance k goes into flat slot k + 3
        row_idx, col_idx = divmod(k + 3, num_cols)
        ax = axes[row_idx, col_idx]
        if k < num_plots - 3:
            im = ax.imshow(state_FC[:, :, k], cmap=cmap, vmin=min_value, vmax=max_value)
            ax.set_title("State covariance\nstate #%s" % (k + 1))
            ax.set_xticks(cov_pos)
            ax.set_yticks(cov_pos)
            ax.set_xticklabels(cov_pos + 1)
            ax.set_yticklabels(cov_pos + 1)
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_ticks(_colorbar_ticks(min_value, max_value))
        else:
            ax.axis('off')  # Leave empty plots blank

    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    plt.show()
[docs]
def plot_condition_difference(Gamma_reconstruct, R_trials, title='Average Probability and Difference', fontsize=16, figsize=(9, 2), vertical_lines=None, line_colors=None, highlight_boxes=False):
    """
    Plots the average probability for each state over time for two conditions and their difference.

    Parameters:
    -----------
    Gamma_reconstruct (numpy.ndarray)
        3D array representing reconstructed gamma values. Shape: (num_timepoints, num_trials, num_states)
    R_trials (numpy.ndarray)
        Array with the condition of each trial (flattened before use).
        Should have the same length as the second dimension of Gamma_reconstruct.
        If it holds exactly two distinct values, the smaller one is used as
        condition 1 and the larger as condition 2; otherwise trials coded 1
        and 2 are compared.
    title (str, optional), default='Average Probability and Difference':
        Title for the figure.
    fontsize (int, optional), default=16:
        Font size for the figure title.
    figsize (tuple, optional), default=(9, 2):
        Figure size (width, height).
    vertical_lines (list of tuples, optional), default=None:
        List of (start, end) pairs specifying x positions for vertical lines
        on the difference panel.
    line_colors (list of str or bool, optional), default=None:
        List of colors for each pair of vertical lines. If True, a random
        color is generated for every pair. Pairs without a color use gray.
    highlight_boxes (bool, optional), default=False:
        Whether to shade the interval between each pair of vertical lines.

    Example usage:
    --------------
    plot_condition_difference(Gamma_reconstruct, R_trials, vertical_lines=[(10, 100)], highlight_boxes=True)
    """
    num_timepoints = Gamma_reconstruct.shape[0]
    num_states = Gamma_reconstruct.shape[2]
    R_trials = np.asarray(R_trials).ravel()

    # Accept any two-valued coding (e.g. 0/1); fall back to the 1/2 coding
    observed = np.unique(R_trials)
    condition_values = observed if observed.size == 2 else np.array([1, 2])

    time_ticks = np.linspace(0, num_timepoints - 1, 5).astype(int)
    filt_val = np.zeros((2, num_timepoints, num_states))

    # Create subplots
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Plot for each condition
    for condition, value in enumerate(condition_values):
        # Average over the trials of this condition at every timepoint
        filt_val[condition] = np.mean(Gamma_reconstruct[:, R_trials == value, :], axis=1).round(3)
        panel = axes[condition]
        panel.plot(filt_val[condition, :, :])
        panel.set_title(f"Condition {condition + 1}")
        panel.set_xticks(time_ticks)
        panel.set_yticks(np.linspace(panel.get_ylim()[0], panel.get_ylim()[1], 5).round(2))

    # Element-wise difference between the two conditions
    difference = filt_val[0, :, :] - filt_val[1, :, :]
    diff_panel = axes[2]
    diff_lines = diff_panel.plot(difference)
    diff_panel.set_title("Difference")
    diff_panel.set_xticks(time_ticks)
    diff_panel.set_yticks(np.linspace(diff_panel.get_ylim()[0], diff_panel.get_ylim()[1], 5).round(2))

    # Add vertical lines, line colors, and highlight boxes
    if vertical_lines:
        if isinstance(line_colors, (bool, np.bool_)):
            # `True` requests random colors; `False` falls back to gray
            line_colors = [tuple(np.random.rand(3)) for _ in vertical_lines] if line_colors else None
        for idx, pair in enumerate(vertical_lines):
            color = line_colors[idx] if line_colors and len(line_colors) > idx else 'gray'
            diff_panel.axvline(x=pair[0], color=color, linestyle='--', linewidth=1)
            diff_panel.axvline(x=pair[1], color=color, linestyle='--', linewidth=1)
            if highlight_boxes:
                y_low, y_high = diff_panel.get_ylim()
                rect = plt.Rectangle((pair[0], y_low), pair[1] - pair[0], y_high - y_low, linewidth=0, edgecolor='none', facecolor=color, alpha=0.2)
                diff_panel.add_patch(rect)

    # Set axis labels
    for ax in axes:
        ax.set_xlabel('Timepoints', fontsize=12)
        ax.set_ylabel('Average probability', fontsize=12)

    # Label each state on the right of the last panel. The state lines are
    # passed explicitly so the labels cannot attach to the vertical lines.
    state_labels = [f"State {state+1}" for state in range(num_states)]
    diff_panel.legend(diff_lines, state_labels, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=12)
    fig.suptitle(title, fontsize=fontsize)

    # Show the plot (leave room at the top for the figure title)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
[docs]
def plot_p_values_over_time(pval, figsize=(8, 4), total_time_seconds=None, xlabel="Timepoints",
                            ylabel="P-values (Log Scale)",title_text="P-values over time", xlim_start=0,
                            tick_positions=[0, 0.001, 0.01, 0.05, 0.1, 0.3, 1], num_colors=259,
                            alpha=0.05,plot_style = "line", linewidth=2.5):
    """
    Plot p-values over time with a log-scale y-axis and a colorbar.

    Parameters:
    -----------
    pval (numpy.ndarray):
        One-dimensional array of p-values to be plotted.
    figsize : tuple, optional, default=(8, 4):
        Figure size in inches (width, height).
    total_time_seconds : float, optional, default=None
        Total time duration in seconds. If provided, the p-values are spread
        evenly between 0 and this value on the x-axis.
    xlabel (str, optional), default="Timepoints":
        Label for the x-axis.
    ylabel (str, optional), default="P-values (Log Scale)":
        Label for the y-axis.
    title_text (str, optional), default="P-values over time":
        Title for the plot.
    xlim_start (int, optional), default=0:
        Starting point for the x-axis limit.
    tick_positions (list, optional), default=[0, 0.001, 0.01, 0.05, 0.1, 0.3, 1]:
        P-values that receive a label on the colorbar. The list is only
        read, never modified.
    num_colors (int, optional), default=259:
        Number of log-spaced samples between 1e-3 and 1 used to build the
        colormap and the colorbar ticks.
    alpha (float, optional), default=0.05:
        Significance threshold: colors at or below it come from the red
        colormap, colors above it from the blue one. If None, a single
        continuous colormap is used.
    plot_style (str, optional), default="line":
        One of "line", "scatter" or "scatter_line".
    linewidth (float, optional), default=2.5:
        Width of the lines in the "line" style.

    Raises:
    -------
    ValueError
        If `pval` is not one-dimensional or `plot_style` is not recognised.
    """
    pval = np.asarray(pval)
    if pval.ndim != 1:
        # Raise an exception and stop function execution
        raise ValueError("To use the function 'plot_p_values_over_time', the variable for p-values must be one-dimensional.")
    valid_styles = ("line", "scatter", "scatter_line")
    if plot_style not in valid_styles:
        # An unknown style would otherwise silently produce an empty plot
        raise ValueError(f"plot_style must be one of {valid_styles}, got {plot_style!r}.")

    # Generate time points and a matching upper x-limit. The limit leaves
    # two sample steps after the last point in both units.
    num_points = len(pval)
    if total_time_seconds:
        time_points = np.linspace(0, total_time_seconds, num_points)
        step = total_time_seconds / (num_points - 1) if num_points > 1 else 0
        x_upper = total_time_seconds + 2 * step
    else:
        time_points = np.arange(num_points)
        x_upper = num_points + 1

    # Log-spaced sample values used to build the colormap
    color_array = np.logspace(-3, 0, num_colors).reshape(1, -1)
    if alpha is None:
        # Single continuous colormap
        coolwarm_cmap = custom_colormap()
        cmap_list = coolwarm_cmap(color_array)[0]
        modified_cmap = interpolate_colormap(cmap_list)
        cmap = LinearSegmentedColormap.from_list('custom_colormap', modified_cmap)
    else:
        # Make a jump in color after alpha: red up to alpha, blue above it
        red_cmap = red_colormap()
        blue_cmap = blue_colormap()
        num_elements_red = np.sum(color_array <= alpha)
        num_elements_blue = np.sum(color_array > alpha)
        cmap_red = red_cmap(np.linspace(0, 1, num_elements_red))
        cmap_blue = blue_cmap(np.linspace(0, 1, num_elements_blue))
        # The initial fill only allocates the colour table; both parts are
        # overwritten below
        cmap_list = blue_cmap(color_array)[0]
        cmap_list[:num_elements_red, :] = cmap_red
        cmap_list[num_elements_red:, :] = cmap_blue
        cmap = LinearSegmentedColormap.from_list('custom_colormap', cmap_list)

    _, axes = plt.subplots(figsize=figsize)
    # Map p-values in [1e-3, 1] onto the colormap on a logarithmic scale
    norm = LogNorm(vmin=1e-3, vmax=1)
    if plot_style == "line":
        threshold = 0.05 if alpha is None else alpha
        # Draw one segment per pair of neighbours. A segment takes the color
        # of its end point, unless that point is at or below the threshold,
        # in which case it takes the color of its start point.
        for i in range(num_points - 1):
            reference = pval[i + 1] if pval[i + 1] > threshold else pval[i]
            axes.plot([time_points[i], time_points[i + 1]], [pval[i], pval[i + 1]], color=cmap(norm(reference)), linewidth=linewidth)
    else:
        axes.scatter(time_points, pval, c=pval, cmap=cmap, norm=norm)
        if plot_style == "scatter_line":
            # Draw lines between points
            axes.plot(time_points, pval, color='black', linestyle='-', linewidth=1)

    # Add labels and title
    axes.set_xlabel(xlabel, fontsize=12)
    axes.set_ylabel(ylabel, fontsize=12)
    axes.set_title(title_text, fontsize=14)
    # Set axis limits to focus on the relevant data range
    axes.set_xlim(xlim_start, x_upper)
    axes.set_ylim([0.0008, 1.5])
    # Set y-axis to log scale
    axes.set_yscale('log')
    # Mark specific values on the y-axis
    axes.set_yticks([0.001, 0.01, 0.05, 0.1, 0.3, 1])
    axes.set_yticklabels(['0.001', '0.01', '0.05', '0.1', '0.3', '1'])

    # Add a colorbar to show the correspondence between colors and p-values
    divider = make_axes_locatable(axes)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    colorbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, ticks=np.logspace(-3, 0, num_colors), format="%1.0e")
    colorbar.update_ticks()

    # Label only the first colorbar tick that rounds to each requested value
    labelled = set()
    tick_labels = []
    for tick in colorbar.get_ticks():
        rounded = round(tick, 3)
        if rounded in tick_positions and rounded not in labelled:
            labelled.add(rounded)
            tick_labels.append(f'{rounded:.3f}')
        else:
            tick_labels.append('')
    colorbar.set_ticklabels(tick_labels)

    # Hide the marks of unlabelled ticks. Each tick owns two lines, so work
    # per tick object rather than indexing the flat list of tick lines.
    for tick, label in zip(colorbar.ax.yaxis.get_major_ticks(), tick_labels):
        if not label:
            tick.tick1line.set_visible(False)
            tick.tick2line.set_visible(False)
    plt.show()
[docs]
def plot_p_values_bar(pval,xticklabels=[], figsize=(9, 4), num_colors=256, xlabel="",
                      ylabel="P-values (Log Scale)", title_text="Bar Plot",
                      tick_positions=[0, 0.001, 0.01, 0.05, 0.1, 0.3, 1], top_adjustment=0.9,
                      alpha = 0.05, pad_title=20, xlabel_rotation=45, pval_text_hight_same=False):
    """
    Visualize p-values as a bar plot with a log-scale y-axis and a colorbar.

    Parameters:
    -----------
    pval (numpy.ndarray):
        Array of p-values to be plotted (a single value gives one bar).
    xticklabels (list, optional), default=[]:
        List of categories or variables. If empty, the bars are named
        "Var 1", "Var 2", ... The list is only read, never modified.
    figsize (tuple, optional), default=(9, 4):
        Figure size in inches (width, height).
    num_colors (int, optional), default=256:
        Number of log-spaced samples between 1e-3 and 1 used to build the
        colormap and the colorbar ticks.
    xlabel (str, optional), default="":
        X-axis label.
    ylabel (str, optional), default="P-values (Log Scale)":
        Y-axis label.
    title_text (str, optional), default="Bar Plot":
        Title for the plot.
    tick_positions (list, optional), default=[0, 0.001, 0.01, 0.05, 0.1, 0.3, 1]
        P-values that receive a label on the colorbar. The list is only
        read, never modified.
    top_adjustment (float, optional), default=0.9:
        Upper bound for the top edge of the plot area (figure fraction).
        Lower values leave more space between title and plot.
    alpha (float, optional), default=0.05:
        Significance threshold: colors at or below it come from the red
        colormap, colors above it from the blue one. If None, a single
        continuous colormap is used.
    pad_title (int, optional), default=20:
        Padding for the plot title.
    xlabel_rotation (int, optional), default=45:
        Rotation of the x tick labels in degrees. Labels are right-aligned
        for 45 degrees and centred otherwise.
    pval_text_hight_same (bool, optional), default=False:
        If True, all value labels are drawn at the same height (above the
        tallest bar) instead of above their own bar.
    """
    # Log-spaced sample values used to build the colormap
    color_array = np.logspace(-3, 0, num_colors).reshape(1, -1)
    if alpha is None:
        # Single continuous colormap
        coolwarm_cmap = custom_colormap()
        cmap_list = coolwarm_cmap(color_array)[0]
        cmap_list = interpolate_colormap(cmap_list)
    else:
        # Make a jump in color after alpha: red up to alpha, blue above it
        red_cmap = red_colormap()
        blue_cmap = blue_colormap()
        num_elements_red = np.sum(color_array <= alpha)
        num_elements_blue = np.sum(color_array > alpha)
        cmap_red = red_cmap(np.linspace(0, 1, num_elements_red))
        cmap_blue = blue_cmap(np.linspace(0, 1, num_elements_blue))
        # The initial fill only allocates the colour table; both parts are
        # overwritten below
        cmap_list = blue_cmap(color_array)[0]
        cmap_list[:num_elements_red, :] = cmap_red
        cmap_list[num_elements_red:, :] = cmap_blue
    # Create a LinearSegmentedColormap
    colormap = LinearSegmentedColormap.from_list('custom_colormap', cmap_list)
    norm = LogNorm(vmin=1e-3, vmax=1)

    fig, axes = plt.subplots(figsize=figsize)

    # Build default bar names. `len()` instead of `== []` so that array
    # inputs are not compared element-wise.
    num_bars = 1 if np.size(pval) == 1 else len(pval)
    if xticklabels is None or len(xticklabels) == 0:
        xticklabels = [f"Var {i+1}" for i in range(num_bars)]

    # Plot the bars, colored by p-value on a log scale
    bars = axes.bar(xticklabels, pval, color=colormap(norm(pval)))

    # Add data labels on top of the bars
    # NOTE(review): the +0.5 offset is additive on a log axis, so labels of
    # small p-values sit far above their bar -- confirm this is intended.
    heights = [bar.get_height() for bar in bars]
    shared_height = np.max(np.array(heights)) if pval_text_hight_same and heights else None
    for bar, yval in zip(bars, heights):
        text_base = shared_height if pval_text_hight_same else yval
        axes.text(bar.get_x() + bar.get_width()/2, text_base + 0.5, round(yval, 3), ha='center', va='bottom', color='black', fontweight='bold')

    # Set y-axis to log scale
    axes.set_yscale('log')
    # Customize plot
    axes.set_xlabel(xlabel, fontsize=12)
    axes.set_ylabel(ylabel, fontsize=12)
    axes.set_title(title_text, fontsize=14, pad=pad_title)
    # Set xticks and rotate xtick labels
    axes.set_xticks(np.arange(len(xticklabels)))
    ha = 'right' if xlabel_rotation == 45 else 'center'
    axes.set_xticklabels(xticklabels, rotation=xlabel_rotation, ha=ha)
    # Mark specific values on the y-axis
    axes.set_yticks([0.001, 0.01, 0.05, 0.1, 0.3, 1])
    axes.set_yticklabels(['0.001', '0.01', '0.05', '0.1', '0.3', '1'])

    # Add a colorbar to show the correspondence between colors and p-values
    divider = make_axes_locatable(axes)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    colorbar = plt.colorbar(plt.cm.ScalarMappable(cmap=colormap, norm=norm), cax=cax, ticks=np.logspace(-3, 0, num_colors), format="%1.0e")
    colorbar.update_ticks()

    # Label only the first colorbar tick that rounds to each requested value
    labelled = set()
    tick_labels = []
    for tick in colorbar.get_ticks():
        rounded = round(tick, 3)
        if rounded in tick_positions and rounded not in labelled:
            labelled.add(rounded)
            tick_labels.append(f'{rounded:.3f}')
        else:
            tick_labels.append('')
    colorbar.set_ticklabels(tick_labels)

    # Hide the marks of unlabelled ticks. Each tick owns two lines, so work
    # per tick object rather than indexing the flat list of tick lines.
    for tick, label in zip(colorbar.ax.yaxis.get_major_ticks(), tick_labels):
        if not label:
            tick.tick1line.set_visible(False)
            tick.tick2line.set_visible(False)

    axes.spines['right'].set_visible(False)
    axes.spines['top'].set_visible(False)

    # tight_layout overwrites the subplot parameters, so the extra title
    # space must be applied afterwards. The top edge is only ever lowered,
    # never raised, so the title cannot be pushed out of the figure.
    fig.tight_layout()
    fig.subplots_adjust(top=min(fig.subplotpars.top, top_adjustment))
    plt.show()