# This Python code is a translation of part of the PALM (Permutation Analysis of Linear Models) package, originally developed by Anderson M. Winkler.
# PALM is a tool for permutation-based statistical analysis.
# To learn more about PALM, please visit the official PALM website:
# http://fsl.fmrib.ox.ac.uk/fsl/fslwiki/PALM
# In this Python translation, our primary focus is on accommodating family structure within your dataset.
# The techniques employed in PALM for managing exchangeability blocks are detailed in the following publication:
# Title: Multi-level block permutation
# Authors: Winkler AM, Webster MA, Vidaurre D, Nichols TE, Smith SM.
# Source: Neuroimage. 2015;123:253-68 (Open Access)
# DOI: 10.1016/j.neuroimage.2015.05.092
# Translated by:
# Nick Y. Larsen
# CFIN / Aarhus University
# Sep/2023 (first version)
# We would like to acknowledge Anderson M. Winkler and all contributors to the PALM package for their valuable work in the field of permutation analysis.
import numpy as np
import pandas as pd
import copy
import warnings
######################### PART 0 - hcp2block #########################################################
[docs]
def hcp2block(tmp, blocksfile=None, dz2sib=False, ids=None):
    """
    Convert HCP-style family/twin data into a multi-level exchangeability block structure.

    Parameters:
    -----------
    tmp (pandas.DataFrame):
        Table with the columns 'Subject', 'Mother_ID', 'Father_ID', 'Age_in_Yrs' and either
        'Zygosity' or both 'ZygosityGT' and 'ZygositySR'. The frame is not modified.
    blocksfile (str, optional), default=None:
        Path to save the resulting blocks as a CSV file. Nothing is written when None.
    dz2sib (bool, optional), default=False:
        If True, handle non-monozygotic twins as ordinary siblings.
    ids (list or array-like, optional), default=None:
        Either a boolean mask over the subjects that remain after removal of incomplete rows,
        or a list of subject IDs to keep (in that order). None or empty keeps everyone.

    Returns:
    --------
    tuple
        A tuple containing three elements:
        tab : numpy.ndarray
            One row per subject, in input order, with the columns
            Subject, Mother_ID, Father_ID, sibtype, famid, famtype.
        B : numpy.ndarray
            Block structure with the columns -1, famtype, signed family index, sibtype, Subject.
        famtype : numpy.ndarray
            Family type code per subject.
            NOTE(review): this array is returned in the internal (family-sorted) order, unlike
            `tab` and `B`; kept as in the original implementation - confirm what callers expect.

    Raises:
    -------
    ValueError
        If no subject is left after removing incomplete rows and applying `ids`.
    """
    # Prefer genotyped zygosity and fall back to self-report. `assign` returns a new
    # frame, so the caller's DataFrame is left untouched.
    if 'Zygosity' not in tmp.columns:
        gt = tmp['ZygosityGT']
        tmp = tmp.assign(Zygosity=np.where(gt.isna() | (gt == ' '), tmp['ZygositySR'], gt))

    cols = ['Subject', 'Mother_ID', 'Father_ID', 'Zygosity', 'Age_in_Yrs']
    tab = tmp[cols].copy()

    # A subject is dropped when any of the five fields is NaN or a single blank.
    missing = (tab.isna() | (tab == ' ')).any(axis=1).to_numpy()
    idstodel = tab.loc[missing, 'Subject'].tolist()
    if len(idstodel) > 0:
        print(f"These subjects have data missing and will be removed: {idstodel}")
    tab = tab[~missing]
    age = tab['Age_in_Yrs'].to_numpy()

    # Zygosity strings -> numeric sibling codes (10 non-twin, 100 DZ, 1000 MZ).
    sibtype = np.array([_hcp_sibtype_code(z, dz2sib) for z in tab['Zygosity']], dtype=int)
    tab = tab.iloc[:, :3]

    # Optional subselection of subjects; tab, sibtype and age are kept in step.
    if ids is not None and len(ids) > 0:
        ids = np.asarray(ids)
        if ids.dtype == bool:
            keep = ids
        else:
            ids = ids.astype(int)
            subj = tab['Subject'].to_numpy().astype(int)
            known = np.isin(ids, subj)
            if np.any(~known):
                print(f"These subjects don't exist in the file and will be removed: {ids[~known]}")
            keep = np.array([i for n in ids[known] for i in np.flatnonzero(subj == n)], dtype=int)
        tab = tab.iloc[keep]
        sibtype = sibtype[keep]
        age = age[keep]

    N = tab.shape[0]
    if N == 0:
        raise ValueError("No subjects left after removing missing data and applying 'ids'.")

    # Initial family ID: one per distinct (mother, father) pair.
    par = tab.iloc[:, 1:3].to_numpy()
    famid = np.zeros(N, dtype=int)
    U = np.unique(par, axis=0)
    for u in range(U.shape[0]):
        famid[np.all(par == U[u], axis=1)] = u

    # Merge families that share a parent; the merged family keeps the smallest ID.
    for p in np.unique(par):
        linked = np.unique(famid[np.any(par == p, axis=1)])
        famid[np.isin(famid, linked)] = linked[0]

    F = np.unique(famid)

    def family_types():
        # Family type = sum of sibling codes + number of distinct parents in the family.
        out = np.zeros(N, dtype=int)
        for f in F:
            fidx = famid == f
            out[fidx] = np.sum(sibtype[fidx]) + len(np.unique(par[fidx]))
        return out

    famtype = family_types()

    # A twin whose co-twin is not in the data is treated as a non-twin,
    # after which the family types are recomputed.
    lone = ((sibtype == 100) & (famtype >= 100) & (famtype <= 199)) | \
           ((sibtype == 1000) & (famtype >= 1000) & (famtype <= 1999))
    sibtype[lone] = 10
    famtype = family_types()

    # From here on tab is a plain array: Subject, Mother, Father, sibtype, famid, famtype.
    tab = np.column_stack((tab.to_numpy(), sibtype, famid, famtype))

    # Sort by family, then sibling type, then age; idxback undoes the sort.
    idx = np.lexsort((age, sibtype, famid))
    idxback = np.argsort(idx)
    tab = tab[idx]
    sibtype = sibtype[idx]
    famid = famid[idx]
    famtype = famtype[idx]
    age = age[idx]

    # Family types whose members are freely exchangeable within the family (positive index).
    simple_types = np.concatenate((np.arange(12, 100, 10), [23, 202, 2002]))

    # NOTE(review): the nesting of the special cases below was reconstructed from a
    # flattened listing; the per-sibling updates are independent, so they are applied
    # with boolean masks instead of explicit loops.
    blocks = []
    for f in F:
        fidx = famid == f
        ft = famtype[np.where(fidx)[0][0]]
        if ft in simple_types:
            blocks.append(np.column_stack((famid[fidx] + 1, sibtype[fidx], tab[fidx, 0])))
            continue

        blk = np.column_stack((-(famid[fidx] + 1), sibtype[fidx], tab[fidx, 0]))
        tabx = tab[fidx, 1:3]
        # For each sibling: how many family members share its mother / its father.
        n_mo = np.array([np.sum(tabx[:, 0] == v) for v in tabx[:, 0]])
        n_fa = np.array([np.sum(tabx[:, 1] == v) for v in tabx[:, 1]])

        def shared(m, p):
            return (n_mo == m) & (n_fa == p)

        if ft == 33:
            blk[shared(2, 3) | shared(3, 2), 1] += 1
        elif ft == 53:
            blk[shared(3, 5) | shared(5, 3), 1] += 1
        elif ft == 234:
            blk[shared(1, 3) | shared(3, 1), 1] += 1
        elif ft == 54:
            blk[shared(4, 2), 1] += 1
            blk[shared(1, 3), 1] -= 1
        elif ft == 34:
            pair = shared(2, 2)
            blk[pair, 1] += 1
            if np.any(pair):
                famtype[fidx] = 39
        elif ft == 43:
            # Siblings that have the same two parents as the first sibling.
            same = (tabx[:, 0] == tabx[0, 0]) & (tabx[:, 1] == tabx[0, 1])
            blk[same, 1] += 1
            if np.sum(same) == 2:
                famtype[fidx] = 49
                blk[:, 0] = -blk[:, 0]
        elif ft == 44:
            blk[shared(4, 2), 1] += 1
        elif ft == 223:
            nontwin = sibtype[fidx] == 10
            blk[nontwin, 1] = -blk[nontwin, 1]
        elif ft == 302:
            # Three "DZ" siblings: the two of equal age are the twins, the third is a sibling.
            famtype[fidx] = 212
            _hcp_demote_odd_dz(blk, (0, 1, 2), age[fidx])
        elif ft == 313 or ft == 314:
            new_ft = ft - 100 + 10
            if new_ft == 223:
                new_ft = 229
            famtype[fidx] = new_ft
            _hcp_demote_odd_dz(blk, np.where(blk[:, 1] == 100)[0], age[fidx])
        elif ft == 2023:
            hit = (n_mo == 4) & ((n_fa == 1) | (n_fa == 3))
            if np.any(hit):
                famtype[fidx] = 2029
            flip = hit & (blk[:, 1] == 10)
            blk[flip, 1] = -blk[flip, 1]
        blocks.append(blk)

    B = np.hstack((-np.ones((N, 1), dtype=int), famtype.reshape(-1, 1), np.concatenate(blocks, axis=0)))

    # Sort back to the original order
    B = B[idxback, :]
    tab = tab[idxback, :]
    tab[:, 5] = B[:, 1]

    # Only write the blocks when the caller asked for it.
    if blocksfile is not None and isinstance(blocksfile, str):
        np.savetxt(blocksfile, B, delimiter=',', fmt='%d')

    return tab, B, famtype


def _hcp_sibtype_code(zygosity, dz2sib):
    # Map one zygosity string to its sibling code (0 when the string is not recognised).
    z = zygosity.lower()
    if z == 'nottwin':
        return 10
    if z in ('notmz', 'dz'):
        return 10 if dz2sib else 100
    if z == 'mz':
        return 1000
    return 0


def _hcp_demote_odd_dz(blk, rows, ages):
    # Among three rows, recode as non-twin (10) the one whose age differs from the equal-aged pair.
    a, b, c = rows[0], rows[1], rows[2]
    if ages[a] == ages[b]:
        blk[c, 1] = 10
    elif ages[a] == ages[c]:
        blk[b, 1] = 10
    elif ages[b] == ages[c]:
        blk[a, 1] = 10
######################### PART 1 - Reindex #########################################################
import numpy as np
def renumber(B):
    """
    Renumber the block labels of B level by level, restarting the numbering inside each block.

    The distinct values of the first column are replaced by 1, 2, ... (in sorted order), keeping
    the sign of the original value; the remaining columns of each block are then renumbered
    recursively.

    Parameters:
    -----------
    B (numpy.ndarray):
        The 2D input array to be renumbered.

    Returns:
    -----------
    Br (numpy.ndarray):
        The renumbered array, same shape as B.
    addcol (bool):
        For a single-column input, whether the last distinct value occurred on more than one row;
        for wider inputs, whether any recursive call reported True.
    """
    first = B[:, 0]
    labels = np.unique(first)
    has_sublevels = B.shape[1] > 1
    # One flag per block of the current level, filled by the recursive calls.
    flags = np.zeros_like(labels, dtype=bool)
    Br = np.zeros_like(B)

    for k, lab in enumerate(labels):
        rows = first == lab
        # New label is the 1-based rank of the old one, with the old sign.
        Br[rows, 0] = (k + 1) * np.sign(lab)
        if has_sublevels:
            Br[rows, 1:], flags[k] = renumber(B[rows, 1:])
        elif np.sum(rows) > 1:
            # Repeated value at the last level: keep it, forced negative.
            addcol = True
            Br[rows] = -np.abs(B[rows])
        else:
            addcol = False

    if has_sublevels:
        addcol = np.any(flags)
    return Br, addcol
def palm_reindex(B, meth='fixleaves'):
    """
    Reindex the block labels of a 2D array `B` so they are small consecutive integers,
    keeping the sign of every original label.

    Parameters:
    -----------
    B (numpy.ndarray):
        The 2D input array to be reindexed (one column per level).
    meth (str, optional), default='fixleaves':
        The reindexing method (case-insensitive):
        - 'continuous': the first column is numbered 1..k by distinct value; in every later
          column the numbering runs continuously across the blocks of the previous column.
        - 'restart': numbering restarts at 1 inside every block (see `renumber`).
        - 'mixed': all columns but the last as in 'restart', the last column as in 'continuous'.
        - 'fixleaves': as 'restart'; if `renumber` reports that a column must be added,
          a column 1..N is appended and the result is renumbered once more.

    Returns:
    -----------
    Br (numpy.ndarray):
        The reindexed array ('fixleaves' may return one more column than B).

    Raises:
    -----------
    ValueError
        If `meth` is not one of the four methods above.
    """
    meth = meth.lower()

    if meth == 'continuous':
        Br = np.zeros_like(B)
        # First column: rank of each distinct value, with the original sign.
        U = np.unique(B[:, 0])
        for u in range(U.shape[0]):
            idx = B[:, 0] == U[u]
            Br[idx, 0] = (u + 1) * np.sign(U[u])
        # Later columns: walk the blocks of the previous (already reindexed) column
        # and keep counting across them.
        for b in range(1, B.shape[1]):
            Bb = B[:, b]
            Bp = Br[:, b - 1]
            cnt = 1
            for prev in np.unique(Bp):
                idxp = Bp == prev
                for val in np.unique(Bb[idxp]):
                    idx = np.logical_and(Bb == val, idxp)
                    Br[idx, b] = cnt * np.sign(val)
                    cnt += 1
    elif meth == 'restart':
        Br, _ = renumber(B)
    elif meth == 'mixed':
        # palm_reindex returns a single array, so there is nothing to unpack here.
        Ba = palm_reindex(B, 'restart')
        Bb = palm_reindex(B, 'continuous')
        Br = np.hstack((Ba[:, :-1], Bb[:, -1:]))
    elif meth == 'fixleaves':
        Br, addcol = renumber(B)
        if addcol:
            # Add a column of sequential numbers so every leaf is distinct, then renumber.
            col = np.arange(1, Br.shape[0] + 1).reshape(-1, 1)
            Br = np.hstack((Br, col))
            Br, _ = renumber(Br)
    else:
        raise ValueError(f'Unknown method: {meth}')

    return Br
######################### PART 2 - PALMTREE #########################################################
def palm_permtree(Ptree, nP, CMC=False, maxP=None):
    """
    Generate permutations of a palm tree structure represented by Ptree.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure to be permuted. It is deep-copied before every shuffle.
    nP (int):
        The number of permutations to generate (including the identity).
    CMC (bool, optional), default=False:
        If true, permutations are drawn independently and may repeat (Conditional Monte Carlo).
        Otherwise every new permutation is redrawn until it differs from all previous ones.
    maxP (int, optional), default=None:
        The maximum number of possible permutations. When not supplied it is obtained from
        `palm_maxshuf` and nP is capped at that value.

    Returns:
    ---------
    P (numpy.ndarray):
        Integer array with one column per permutation. The first column is the identity
        permutation and the rows are sorted so that this column is in ascending order.

    Notes:
    ---------
    - If repetitions are not allowed and nP is greater than maxP / 2, a message is printed,
      as it may take a considerable amount of time to find non-repeated permutations.
    - The function uses the 'pickperm' and 'randomperm' helpers.
    """
    if not maxP:
        # maxP not supplied: compute it, and cap nP (the cap is only imposed in this case).
        maxP = palm_maxshuf(Ptree, 'perms')
        if nP > maxP:
            nP = int(maxP)

    # Permutation #1 is no permutation, regardless.
    identity = pickperm(Ptree, np.array([], dtype=int))
    P = np.hstack((identity.reshape(-1, 1), np.zeros((identity.shape[0], nP - 1), dtype=int)))

    if nP > 1:
        allow_repeats = bool(CMC) or nP > maxP
        if not allow_repeats and nP > maxP / 2:
            # Inform the user about the potentially long runtime
            print(f'The maximum number of permutations ({maxP}) is not much larger than\n'
                  f'the number you chose to run ({nP}). This means it may take a while (from\n'
                  f'a few seconds to several minutes) to find non-repeated permutations.\n'
                  'Consider instead running exhaustively all possible permutations. It may be faster.')
        for p in range(1, nP):
            while True:
                shuffled = randomperm(copy.deepcopy(Ptree))
                P[:, p] = pickperm(shuffled, np.array([], dtype=int))
                if allow_repeats:
                    break
                # Accept only if no earlier column equals the new one.
                if not np.any(np.all(P[:, :p] == P[:, p][:, np.newaxis], axis=0)):
                    break

    # The grouping into branches changes the original order; since the first
    # permutation is the identity, sorting the rows by it restores that order.
    idx = np.argsort(P[:, 0])
    return P[idx, :]
def pickperm(Ptree, P):
    """
    Read the current ordering of the leaves out of a palm tree.

    Nothing is shuffled here: the tree is walked in its present order and the leaf
    contents are appended to P.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure: a list (only handled when it has three elements, in which
        case its third element is walked), or a 2D array with three columns (inner nodes,
        third column holds the sub-tree) or with one column (leaves).
    P (numpy.ndarray):
        The indices collected so far.

    Returns:
    ----------
    P (numpy.ndarray):
        The indices collected so far followed by the leaves found under Ptree.
    """
    if isinstance(Ptree, list):
        # List wrapper: only the three-element form carries a sub-tree.
        return pickperm(Ptree[2], P) if len(Ptree) == 3 else P

    width = Ptree.shape[1]
    if width == 3:
        # Inner level: descend into the sub-tree of every branch, in order.
        for branch in Ptree:
            P = pickperm(branch[2], P)
    elif width == 1:
        # Leaf level: append the (flattened) content of every leaf.
        for branch in Ptree:
            P = np.concatenate((P, branch[0]), axis=None)
    return P
def randomperm(Ptree_perm):
    """
    Create a random permutation of a palm tree structure by shuffling its branches.

    The input is modified in place where possible; callers in this file pass a deep copy.

    Parameters:
    -----------
    Ptree_perm (list or numpy.ndarray):
        The palm tree structure to be permuted. A three-element list is first replaced by
        the shuffled version of its third element.

    Returns:
    ----------
    Ptree_perm (numpy.ndarray):
        The shuffled tree. NOTE(review): for a three-element list input the returned
        object is the (shuffled) third element, not the list itself.

    NOTE(review): the nesting below was reconstructed from a listing without indentation;
    in particular the final recursion is assumed to run for every branch - confirm against
    the upstream source.
    """
    # Three-element list: shuffle its third element and carry on with that sub-tree.
    if isinstance(Ptree_perm,list):
        if len(Ptree_perm) == 3:
            Ptree_perm = randomperm(Ptree_perm[2])
    # Get the number of branches in Ptree_perm
    nU = Ptree_perm.shape[0]
    # Loop through each branch
    for u in range(nU):
        # Case 1: the first element of the branch is a plain scalar (see is_single_value)
        if is_single_value(Ptree_perm[u][0]):
            # A NaN scalar is left alone
            if not np.isnan(Ptree_perm[u][0]):
                tmp = 1
                # NOTE(review): np.random.shuffle is called on a scalar here - verify this path is reachable
                np.random.shuffle(Ptree_perm[u][0])
                if np.any(tmp != Ptree_perm[u][0][0]):
                    # NOTE(review): the result of this indexing is not assigned, so the sub-tree is not reordered
                    Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Case 2: the first element of the branch is a list with three elements
        elif isinstance(Ptree_perm[u][0],list) and len(Ptree_perm[u][0])==3:
            tmp = 1
            # Shuffle the three list elements in place
            np.random.shuffle(Ptree_perm[u][0])
            if np.any(tmp != Ptree_perm[u][0][0]):
                # NOTE(review): the result of this indexing is not assigned, so the sub-tree is not reordered
                Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Case 3: anything else is indexed as a 2D array
        else:
            # Reference sequence 1..n (n = number of rows of the first element)
            tmp = np.arange(1,len(Ptree_perm[u][0][:,0])+1,dtype=int)
            # Shuffle the rows of the first element in place
            np.random.shuffle(Ptree_perm[u][0])
            # If the shuffled first column differs from 1..n anywhere ...
            if np.any(tmp != Ptree_perm[u][0][:, 0]):
                # ... reorder the rows of the sub-tree by the third column (1-based indices)
                Ptree_perm[u][2] =Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Make sure the next isn't the last level.
        if Ptree_perm[u][2].shape[1] > 1:
            # Recursively shuffle the sub-tree of the branch
            Ptree_perm[u][2] = randomperm(Ptree_perm[u][2])
    return Ptree_perm
######################### PART 3.1 - Permute PTREE #########################################################
#### Permutation functions
import numpy as np
def palm_maxshuf(Ptree, stype='perms', uselog=False):
    """
    Calculate the maximum number of shufflings (permutations, sign-flips, or both) for a
    given palm tree structure.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure.
    stype (str, optional), default='perms':
        The type of shuffling to count (case-insensitive): 'perms', 'flips' or 'both'.
    uselog (bool, optional), default=False:
        If True, return the natural logarithm of the count instead of the count itself.

    Returns:
    ----------
    maxb (int or float):
        The maximum number of shufflings, or its logarithm when `uselog` is True.
        For 'both' this is the product (sum of logs) of the two counts.

    Raises:
    ----------
    ValueError
        If `stype` is not one of 'perms', 'flips', 'both'.
    """
    stype = stype.lower()
    if stype not in ('perms', 'flips', 'both'):
        raise ValueError(f'Unknown shuffling type: {stype}')
    count_perms = stype in ('perms', 'both')
    count_flips = stype in ('flips', 'both')

    if uselog:
        # Logs add; 0 is the neutral element for the part that is not requested.
        lperms = lmaxpermnode(Ptree, 0) if count_perms else 0
        lflips = lmaxflipnode(Ptree, 0) if count_flips else 0
        return lperms + lflips

    # Counts multiply; 1 is the neutral element for the part that is not requested.
    nperms = maxpermnode(Ptree, 1) if count_perms else 1
    nflips = maxflipnode(Ptree, 1) if count_flips else 1
    return nperms * nflips
def maxpermnode(Ptree, np):
    """
    Calculate the maximum number of permutations within a palm tree node.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure; each branch is indexed as
        [sequence array, (unused here), sub-tree].
    np (int):
        The number of permutations accumulated so far (1 for the initial call).
        The name is kept for compatibility; inside this function it is a count, not NumPy.

    Returns:
    ----------
    n_p (int):
        `np` multiplied by the number of permutations of every level below this node.
    """
    # NOTE(review): unlike lmaxpermnode, this walker does not special-case list-form
    # trees or scalar (NaN) sequences - confirm it is only given array-form trees.
    n_p = np
    for u in range(len(Ptree)):
        # Permutations of the branches of this node (first column is the sequence).
        n_p = n_p * seq2np(Ptree[u][0][:, 0])
        # Descend unless the sub-tree is the last (single-column) level,
        # carrying the accumulated count along.
        if len(Ptree[u][2][0]) > 1:
            n_p = maxpermnode(Ptree[u][2], n_p)
    return n_p
def seq2np(S):
    """
    Calculate the number of distinct permutations of a sequence with repeated values.

    This is the multinomial coefficient len(S)! / (c1! * c2! * ...), where the c's are
    the number of occurrences of each distinct value.

    Parameters:
    -----------
    S (numpy.ndarray):
        The input sequence (1D).

    Returns:
    ----------
    n_p (int):
        The number of distinct permutations of the sequence (exact integer).
    """
    import math  # np.math was only an alias of this module and is gone in NumPy 2

    _, cnt = np.unique(S, return_counts=True)
    repeats = 1
    for c in cnt:
        repeats *= math.factorial(int(c))
    # The division is always exact, so integer arithmetic avoids float overflow/rounding.
    return math.factorial(len(S)) // repeats
def maxflipnode(Ptree, ns):
    """
    Calculate the maximum number of sign-flips within a palm tree node.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure; each branch is indexed as [_, flip entries, sub-tree].
    ns (int):
        The number of sign-flips accumulated so far (1 for the initial call).

    Returns:
    ----------
    ns (int):
        `ns` multiplied, for every branch at and below this node, by 2 raised to the
        length of the branch's second element.
    """
    for branch in Ptree:
        subtree = branch[2]
        # Deeper levels first, unless the sub-tree is the last (single-entry) level.
        if len(subtree[0]) > 1:
            ns = maxflipnode(subtree, ns)
        # Each entry of the second element doubles the count.
        ns = ns * (2 ** len(branch[1]))
    return ns
def lmaxpermnode(Ptree, n_p):
    """
    Calculate the logarithm of the maximum number of permutations within a palm tree node.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure, either a list [sequence, _, sub-tree] or an array of branches.
    n_p (float):
        The log-count accumulated so far (0 for the initial call).

    Returns:
    ----------
    n_p (float):
        `n_p` plus the log-number of permutations of every level at and below this node.
    """
    if isinstance(Ptree, list):
        # List form: a single node whose sequence is the first element.
        n_p = n_p + lseq2np(Ptree[0])
        if Ptree[2].shape[1] > 1:
            n_p = lmaxpermnode(Ptree[2], n_p)
        return n_p

    for u in range(Ptree.shape[0]):
        branch = Ptree[u]
        head = branch[0]
        if isinstance(head, list):
            # Sequence wrapped in a list: its first element is used.
            n_p = n_p + lseq2np(head[0])
            descend = branch[2].shape[1] > 1
        elif is_single_value(head):
            # Scalar sequence (possibly NaN).
            n_p = n_p + lseq2np(head)
            descend = len(branch) > 2 and branch[2].shape[1] > 1
        else:
            # Array sequence: the first column is used.
            n_p = n_p + lseq2np(head[:, 0])
            descend = len(branch) > 2 and branch[2].shape[1] > 1
        # Go one level down unless the sub-tree is the last (single-column) level.
        if descend:
            n_p = lmaxpermnode(branch[2], n_p)
    return n_p
def lseq2np(S):
    """
    Calculate the logarithm of the number of distinct permutations of a sequence.

    Computed as log(n!) - sum(log(c_i!)), with n the length of the sequence and c_i the
    number of occurrences of each distinct value.

    Parameters:
    -----------
    S (numpy.ndarray or scalar):
        The input sequence. A scalar counts as a sequence of length 1; a NaN scalar
        contributes 0.

    Returns:
    ----------
    n_p (float):
        The natural logarithm of the number of permutations for the sequence.
    """
    if is_single_value(S):
        nS = 1
        if np.isnan(S):
            cnt = 0
        else:
            _, cnt = np.unique(S, return_counts=True)
    else:
        nS = len(S)
        _, cnt = np.unique(S, return_counts=True)
    # The table must reach the largest index used below; the default table only
    # covers 0..101, which is too short for levels with more branches than that.
    top = max(nS, int(np.max(cnt)), 101)
    lfac = palm_factorial(top)
    n_p = lfac[nS] - np.sum(lfac[cnt])
    return n_p
def lmaxflipnode(Ptree, ns):
    """
    Calculate the log-count of sign-flips within a palm tree node, in units of doublings.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure; each branch is indexed as [_, flip entries, sub-tree].
    ns (int):
        The value accumulated so far (0 for the initial call).

    Returns:
    ----------
    ns (int):
        `ns` plus, for every branch at and below this node, the length of the branch's
        second element.
    """
    for branch in Ptree:
        subtree = branch[2]
        # Deeper levels first, unless the sub-tree is the last (single-entry) level.
        if len(subtree[0]) > 1:
            ns = lmaxflipnode(subtree, ns)
        ns = ns + len(branch[1])
    return ns
def is_single_value(variable):
    """
    Check whether a variable is a single numeric scalar rather than a container.

    Parameters:
    -----------
    variable (any):
        The object to be checked.

    Returns:
    ----------
    bool:
        True for Python int/float/complex and for NumPy numeric scalars (e.g. np.int64,
        np.float32); False for arrays, lists and everything else.
    """
    # np.number covers NumPy scalars such as np.int64 that do not derive from Python's int.
    return isinstance(variable, (int, float, complex, np.number))
def palm_factorial(N=101):
    """
    Build a table of log-factorials: entry n holds log(n!) for n = 0..N.

    Parameters:
    -----------
    N (int, optional), default=101:
        The largest n to tabulate. A value of 1 is treated as 101.

    Returns:
    ----------
    lf (numpy.ndarray):
        Array of length N+1 with lf[n] = log(n!) (lf[0] = 0).
    """
    if N == 1:
        N = 101
    table = np.zeros(N + 1)
    # log(n!) = log(n) + log((n-1)!), accumulated from the previous entry.
    for n in range(1, N + 1):
        table[n] = np.log(n) + table[n - 1]
    return table
######################### PART 3.2 - PALMTREE #########################################################
#### Permute PALM tree
[docs]
def palm_permtree(Ptree, nP, CMC=False, maxP=None):
    """
    Generate permutations of a given palm tree structure represented by Ptree.
    Permutations are created by shuffling the branches of the palm tree.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure to be permuted. It is deep-copied before every shuffle.
    nP (int):
        The number of permutations to generate (including the identity).
    CMC (bool, optional), default=False:
        If true, permutations are drawn independently and may repeat (Conditional Monte Carlo).
        Otherwise every new permutation is redrawn until it differs from all previous ones.
    maxP (int, optional), default=None:
        The maximum number of possible permutations. When not supplied it is obtained from
        `palm_maxshuf` and nP is capped at that value.

    Returns:
    ----------
    P (numpy.ndarray):
        Integer array with one column per permutation. The first column is the identity
        permutation and the rows are sorted so that this column is in ascending order.

    Notes:
    ----------
    - If repetitions are not allowed and nP is greater than maxP / 2, a message is printed,
      as it may take a considerable amount of time to find non-repeated permutations.
    - The function uses the 'pickperm' and 'randomperm' helpers.
    """
    if not maxP:
        # maxP not supplied: compute it, and cap nP (the cap is only imposed in this case).
        maxP = palm_maxshuf(Ptree, 'perms')
        if nP > maxP:
            nP = int(maxP)

    # Permutation #1 is no permutation, regardless.
    identity = pickperm(Ptree, np.array([], dtype=int))
    P = np.hstack((identity.reshape(-1, 1), np.zeros((identity.shape[0], nP - 1), dtype=int)))

    if nP > 1:
        allow_repeats = bool(CMC) or nP > maxP
        if not allow_repeats and nP > maxP / 2:
            # Inform the user about the potentially long runtime
            print(f'The maximum number of permutations ({maxP}) is not much larger than\n'
                  f'the number you chose to run ({nP}). This means it may take a while (from\n'
                  f'a few seconds to several minutes) to find non-repeated permutations.\n'
                  'Consider instead running exhaustively all possible permutations. It may be faster.')
        for p in range(1, nP):
            while True:
                shuffled = randomperm(copy.deepcopy(Ptree))
                P[:, p] = pickperm(shuffled, np.array([], dtype=int))
                if allow_repeats:
                    break
                # Accept only if no earlier column equals the new one.
                if not np.any(np.all(P[:, :p] == P[:, p][:, np.newaxis], axis=0)):
                    break

    # The grouping into branches changes the original order; since the first
    # permutation is the identity, sorting the rows by it restores that order.
    idx = np.argsort(P[:, 0])
    return P[idx, :]
[docs]
def pickperm(Ptree, P):
    """
    Extract a permutation from a palm tree structure.

    Nothing is shuffled here: the tree is walked in its present order and the leaf
    contents are appended to P.

    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure: a list (only handled when it has three elements, in which
        case its third element is walked), or a 2D array with three columns (inner nodes,
        third column holds the sub-tree) or with one column (leaves).
    P (numpy.ndarray):
        The indices collected so far.

    Returns:
    ----------
    P (numpy.ndarray):
        The indices collected so far followed by the leaves found under Ptree.
    """
    if isinstance(Ptree, list):
        # List wrapper: only the three-element form carries a sub-tree.
        return pickperm(Ptree[2], P) if len(Ptree) == 3 else P

    width = Ptree.shape[1]
    if width == 3:
        # Inner level: descend into the sub-tree of every branch, in order.
        for branch in Ptree:
            P = pickperm(branch[2], P)
    elif width == 1:
        # Leaf level: append the (flattened) content of every leaf.
        for branch in Ptree:
            P = np.concatenate((P, branch[0]), axis=None)
    return P
[docs]
def randomperm(Ptree_perm):
    """
    Create a random permutation of a palm tree structure by shuffling its branches.

    The input is modified in place where possible; callers in this file pass a deep copy.

    Parameters:
    -----------
    Ptree_perm (list or numpy.ndarray):
        The palm tree structure to be permuted. A three-element list is first replaced by
        the shuffled version of its third element.

    Returns:
    ----------
    Ptree_perm (numpy.ndarray):
        The shuffled tree. NOTE(review): for a three-element list input the returned
        object is the (shuffled) third element, not the list itself.

    NOTE(review): the nesting below was reconstructed from a listing without indentation;
    in particular the final recursion is assumed to run for every branch - confirm against
    the upstream source.
    """
    # Three-element list: shuffle its third element and carry on with that sub-tree.
    if isinstance(Ptree_perm,list):
        if len(Ptree_perm) == 3:
            Ptree_perm = randomperm(Ptree_perm[2])
    # Get the number of branches in Ptree_perm
    nU = Ptree_perm.shape[0]
    # Loop through each branch
    for u in range(nU):
        # Case 1: the first element of the branch is a plain scalar (see is_single_value)
        if is_single_value(Ptree_perm[u][0]):
            # A NaN scalar is left alone
            if not np.isnan(Ptree_perm[u][0]):
                tmp = 1
                # NOTE(review): np.random.shuffle is called on a scalar here - verify this path is reachable
                np.random.shuffle(Ptree_perm[u][0])
                if np.any(tmp != Ptree_perm[u][0][0]):
                    # NOTE(review): the result of this indexing is not assigned, so the sub-tree is not reordered
                    Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Case 2: the first element of the branch is a list with three elements
        elif isinstance(Ptree_perm[u][0],list) and len(Ptree_perm[u][0])==3:
            tmp = 1
            # Shuffle the three list elements in place
            np.random.shuffle(Ptree_perm[u][0])
            if np.any(tmp != Ptree_perm[u][0][0]):
                # NOTE(review): the result of this indexing is not assigned, so the sub-tree is not reordered
                Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Case 3: anything else is indexed as a 2D array
        else:
            # Reference sequence 1..n (n = number of rows of the first element)
            tmp = np.arange(1,len(Ptree_perm[u][0][:,0])+1,dtype=int)
            # Shuffle the rows of the first element in place
            np.random.shuffle(Ptree_perm[u][0])
            # If the shuffled first column differs from 1..n anywhere ...
            if np.any(tmp != Ptree_perm[u][0][:, 0]):
                # ... reorder the rows of the sub-tree by the third column (1-based indices)
                Ptree_perm[u][2] =Ptree_perm[u][2][Ptree_perm[u][0][:, 2].astype(int) - 1, :]
        # Make sure the next isn't the last level.
        if Ptree_perm[u][2].shape[1] > 1:
            # Recursively shuffle the sub-tree of the branch
            Ptree_perm[u][2] = randomperm(Ptree_perm[u][2])
    return Ptree_perm
######################### PART 3.3 - Permute PTREE #########################################################
###### Function that integrates the above functions in part 3
import warnings
import numpy as np
[docs]
def palm_shuftree(Ptree, nP, CMC=False, EE=True):
    """
    Produce the set of permutations for a palm tree.
    The number of possible permutations is obtained (in log form) from
    palm_maxshuf and reported on screen. If nP is 0 or at least as large as
    that number, the exhaustive set is requested from palm_permtree;
    otherwise nP permutations are requested.
    Parameters:
    -----------
    Ptree (list):
        The palm tree structure.
    nP (int):
        The number of permutations requested; 0 asks for all of them.
    CMC (bool, optional), default=False:
        Conditional Monte Carlo flag, passed through to palm_permtree when
        a subset of the permutations is drawn. In the exhaustive case an
        empty list is passed in its place.
    EE (bool, optional), default=True:
        Exchangeable errors flag. When False no permutations are generated
        and an empty list is returned.
    Returns:
    ----------
    Pset (list):
        Whatever palm_permtree returns, or an empty list when EE is False.
    """
    def _as_count(value):
        # A finite count is rounded to an int; an infinite one is kept as is.
        return int(round(value)) if value != np.inf else value

    # Upper bounds on the number of permutations and sign-flips
    maxS = 1
    maxP = 1
    if EE:
        lmaxP = palm_maxshuf(Ptree, 'perms', True)
        maxP = np.exp(lmaxP)
        # When exp() overflows, report the exponent instead of "inf"
        if np.isinf(maxP):
            print('Number of possible permutations is exp({}).'.format(lmaxP))
        else:
            print('Number of possible permutations is {}.'.format(maxP))
    maxB = maxP * maxS
    # Label used in the progress message
    whatshuf = 'permutations only'
    Pset = []
    if nP == 0 or nP >= maxB:
        # Exhaustive run: the request covers every possible shuffling,
        # so CMC plays no role here.
        print('Generating {} shufflings ({}).'.format(maxB, whatshuf))
        if EE:
            limit = _as_count(maxP)
            Pset = palm_permtree(Ptree, limit, [], limit)
        return Pset
    if nP < maxB:
        # Only a subset of the possible shufflings is drawn.
        print('Generating {} shufflings ({}).'.format(nP, whatshuf))
        if EE:
            limit = _as_count(maxP)
            wanted = limit if nP >= maxP else nP
            Pset = palm_permtree(Ptree, wanted, CMC, limit)
    return Pset
######################### PART 4 - quick_perm #########################################################
[docs]
def palm_quickperms(EB, M=None, nP=1000, CMC=False, EE=True):
    """
    Build a set of permutations from a matrix of exchangeability blocks.
    The block matrix is reindexed with the 'fixleaves' method, converted to
    a palm tree, and the tree is handed to palm_shuftree.
    Parameters:
    -----------
    EB (numpy.ndarray)
        Block structure representing relationships between subjects.
    M (numpy.ndarray, optional), default=None:
        The matrix of attributes. It is accepted but not forwarded to
        palm_tree by this function.
    nP (int), default=1000:
        The number of permutations to generate.
    CMC (bool, optional), default=False:
        Conditional Monte Carlo flag, forwarded to palm_shuftree.
    EE (bool, optional), default=True:
        Exchangeable errors flag, forwarded to palm_shuftree.
    Returns:
    -----------
    Pset (list):
        The permutations returned by palm_shuftree.
    """
    # Hide the overflow warning raised by exp(); the filter is added to the
    # global warnings filter list and stays there after this call.
    warnings.filterwarnings("ignore", message="overflow encountered in exp")
    # Reindex the blocks, then build the tree from the reindexed matrix
    reindexed = palm_reindex(EB, 'fixleaves')
    tree = palm_tree(reindexed)
    # NOTE(review): the original carried a disabled `Pset - 1` step, which
    # suggests the returned indices start at 1 - confirm with callers.
    return palm_shuftree(tree, nP, CMC, EE)
######################### Helper functions #########################################################
[docs]
def palm_maxshuf(Ptree, stype='perms', uselog=False):
    """
    Calculate the maximum number of permutations for a given palm tree structure.
    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The palm tree structure.
    stype (str, optional), default='perms':
        The type of shuffling to count. Only 'perms' (permutations) is
        implemented; the name is matched case-insensitively.
    uselog (bool, optional), default=False:
        If True, return the natural logarithm of the count (computed by
        lmaxpermnode); otherwise return the count itself (maxpermnode).
    Returns:
    -----------
    maxb (int or float):
        The maximum number of permutations, or its logarithm when uselog is True.
    Raises:
    -----------
    ValueError:
        If stype is anything other than 'perms'.
    """
    kind = stype.lower() if isinstance(stype, str) else stype
    # Reject unsupported types up front; previously this fell through and
    # failed with an UnboundLocalError on the return statement.
    if kind != 'perms':
        raise ValueError(
            "Unsupported shuffling type {!r}; only 'perms' is implemented.".format(stype)
        )
    if uselog:
        return lmaxpermnode(Ptree, 0)
    return maxpermnode(Ptree, 1)
[docs]
def maxpermnode(Ptree, np):
    """
    Calculate the maximum number of permutations within a palm tree node.
    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The branches of the node. Each branch is indexed as
        branch[0] (shuffle key, first column used) and branch[2] (children).
    np (int):
        The number of permutations accumulated so far (start with 1).
        Note that inside this function the name hides the numpy module.
    Returns:
    -----------
    n_p (int):
        The accumulated number of permutations after visiting this node
        and all nested nodes.
    """
    # Start from the running product passed in by the caller; the original
    # code read n_p before it was ever assigned.
    n_p = np
    for branch in Ptree:
        n_p = n_p * seq2np(branch[0][:, 0])
        # Descend only when the next level is not the last one
        if len(branch[2][0]) > 1:
            # Pass the running product down so nested counts multiply into it
            n_p = maxpermnode(branch[2], n_p)
    return n_p
[docs]
def seq2np(S):
    """
    Calculate the number of distinct permutations of a sequence.
    This is the multinomial coefficient n! / (c1! * c2! * ...), where n is
    the number of elements and c1, c2, ... are the counts of each distinct
    value.
    Parameters:
    -----------
    S (numpy.ndarray):
        The input sequence.
    Returns:
    -----------
    n_p (int):
        The number of distinct permutations of the sequence.
    """
    # Local import: math.factorial replaces the np.math alias, which is no
    # longer available in NumPy 2.0 and never accepted arrays anyway.
    import math

    values = np.ravel(S)
    _, counts = np.unique(values, return_counts=True)
    repeats = 1
    for c in counts:
        repeats *= math.factorial(int(c))
    # The division is exact, so integer arithmetic avoids float overflow
    return math.factorial(int(values.size)) // repeats
[docs]
def maxflipnode(Ptree, ns):
    """
    Count the sign-flips available in a palm tree node and below it.
    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The branches of the node. Each branch is indexed as
        branch[1] (sign-flip entries) and branch[2] (children).
    ns (int):
        The number of sign-flips accumulated so far (start with 1).
    Returns:
    -----------
    ns (int):
        The accumulated number of sign-flips.
    """
    for branch in Ptree:
        children = branch[2]
        # Nested levels are counted first, unless the children are leaves
        if len(children[0]) > 1:
            ns = maxflipnode(children, ns)
        # Each entry of the middle element doubles the count
        flips_here = 2 ** len(branch[1])
        ns = ns * flips_here
    return ns
[docs]
def lmaxpermnode(Ptree, n_p):
    """
    Calculate the logarithm of the maximum number of permutations within a palm tree node.
    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        Either a single node given as a list [key, middle, children] (the
        form returned by palm_tree), or a 2D object array with one branch
        per row, each row indexed the same way.
    n_p (int or float):
        The logarithm accumulated so far (start with 0).
    Returns:
    -----------
    n_p (float):
        The accumulated logarithm of the number of permutations.
    """
    # A list is one node; an array holds one branch per row.
    if isinstance(Ptree, list):
        branches = [Ptree]
    else:
        branches = [Ptree[u] for u in range(Ptree.shape[0])]
    for branch in branches:
        # Only the first column of the key identifies interchangeable
        # branches. The top-level node previously passed the whole key
        # array to lseq2np, which counted the index columns as well.
        n_p = n_p + lseq2np(_shuffle_labels(branch[0]))
        # Descend unless the children are the last level
        if len(branch) > 2 and branch[2].shape[1] > 1:
            n_p = lmaxpermnode(branch[2], n_p)
    return n_p


def _shuffle_labels(S):
    """Return the part of a node's shuffle key S that lseq2np should count."""
    if isinstance(S, list):
        # List keys (e.g. [1, 1, 1]) carry their label in the first entry
        return S[0]
    if is_single_value(S) or np.ndim(S) < 2:
        # Scalars (including NaN) and 1D sequences are counted as they are
        return S
    # 2D keys: the labels are in the first column
    return S[:, 0]
[docs]
def lseq2np(S):
    """
    Calculate the logarithm of the number of permutations for a given sequence.
    The result is log(n!) minus the sum of log(c!) over the counts c of
    each distinct value, using the table built by palm_factorial.
    Parameters:
    -----------
    S (numpy.ndarray or scalar):
        The input sequence, or a single value (a NaN counts as one element
        with no repeated values).
    Returns:
    -----------
    n_p (float):
        The logarithm of the number of permutations for the sequence.
    """
    if is_single_value(S):
        nS = 1
        if np.isnan(S):
            cnt = 0
        else:
            _, cnt = np.unique(S, return_counts=True)
    else:
        nS = len(S)
        _, cnt = np.unique(S, return_counts=True)
    cnt = np.atleast_1d(cnt)
    # Size the table for the largest index that will be looked up. The fixed
    # default table (entries 0..101) raised IndexError for longer sequences.
    largest = max(nS, int(cnt.max()) if cnt.size else 0)
    lfac = palm_factorial(max(101, largest))
    n_p = lfac[nS] - np.sum(lfac[cnt])
    return n_p
[docs]
def lmaxflipnode(Ptree, ns):
    """
    Accumulate the number of sign-flip entries in a palm tree node and below it.
    Parameters:
    -----------
    Ptree (list or numpy.ndarray):
        The branches of the node. Each branch is indexed as
        branch[1] (sign-flip entries) and branch[2] (children).
    ns (int):
        The value accumulated so far (start with 0).
    Returns:
    -----------
    ns (int):
        The accumulated total of len(branch[1]) over all visited branches.
    """
    for branch in Ptree:
        below = branch[2]
        # Visit nested levels first, unless the children are leaves
        if len(below[0]) > 1:
            ns = lmaxflipnode(below, ns)
        # One unit is added per entry of the middle element
        ns = ns + len(branch[1])
    return ns
[docs]
def is_single_value(variable):
    """
    Tell whether the argument is a plain scalar number.
    Parameters:
    -----------
    variable (any):
        The object to be checked.
    Returns:
    -----------
    bool:
        True if the object is an instance of int, float or complex
        (including their subclasses); False otherwise, e.g. for lists
        and numpy arrays.
    """
    scalar_types = (int, float, complex)
    return isinstance(variable, scalar_types)
[docs]
def palm_factorial(N=101):
    """
    Build a lookup table of log-factorials.
    Parameters:
    -----------
    N (int, optional), default=101:
        The largest number covered by the table. A value of 1 is replaced
        by 101.
    Returns:
    -----------
    lf (numpy.ndarray):
        An array of length N + 1 where lf[n] equals log(n!), with lf[0] = 0.
    """
    size = 101 if N == 1 else N
    # Entry 0 stays at zero, since log(0!) = 0
    table = np.zeros(size + 1)
    # log(k!) = log((k-1)!) + log(k)
    for k in range(1, size + 1):
        table[k] = table[k - 1] + np.log(k)
    return table
[docs]
def renumber(B):
    """
    Renumber the block labels of B level by level.
    The distinct values of the first column are replaced by 1, 2, ... (in
    sorted order) with their sign preserved, and the remaining columns are
    renumbered recursively within each of those blocks. In a single-column
    input, a value that occurs more than once is replaced by minus its
    absolute value.
    Parameters:
    -----------
    B (numpy.ndarray):
        The 2D input array to be renumbered.
    Returns:
    --------
    Br (numpy.ndarray):
        The renumbered array, same shape and dtype as B.
    addcol (bool):
        For a single column: whether the last distinct value occurs more
        than once. For several columns: whether any recursive call
        reported True.
    """
    first_col = B[:, 0]
    labels = np.unique(first_col)
    multi_column = B.shape[1] > 1
    # One flag per block, filled in by the recursive calls
    nested_flags = np.zeros_like(labels, dtype=bool)
    Br = np.zeros_like(B)
    for pos, label in enumerate(labels):
        rows = first_col == label
        # New label is the 1-based position, carrying the original sign
        Br[rows, 0] = (pos + 1) * np.sign(label)
        if multi_column:
            # Renumber the deeper levels inside this block
            sub_block, flag = renumber(B[rows, 1:])
            Br[rows, 1:] = sub_block
            nested_flags[pos] = flag
        elif np.sum(rows) > 1:
            # Repeated value in the last column
            addcol = True
            Br[rows] = -np.abs(B[rows])
        else:
            addcol = False
    if multi_column:
        addcol = np.any(nested_flags)
    return Br, addcol
[docs]
def palm_reindex(B, meth='fixleaves'):
    """
    Reindex a 2D array of block labels with one of several numbering schemes.
    Parameters:
    -----------
    B (numpy.ndarray):
        The 2D input array to be reindexed.
    meth (str, optional), default='fixleaves':
        The reindexing method (case-insensitive):
        - 'continuous': the first column is relabelled 1, 2, ... by sorted
          distinct value (sign preserved). In each later column the labels
          are numbered with a counter that keeps increasing across the
          blocks of the previous column and restarts at 1 for each column.
        - 'restart': the result of renumber(B).
        - 'mixed': all columns but the last from 'restart', and the last
          column from 'continuous'.
        - 'fixleaves': the result of renumber(B); if renumber reports that
          a column is needed, a column 1..N is appended and the array is
          renumbered once more.
    Returns:
    --------
    Br (numpy.ndarray):
        The reindexed array.
    Raises:
    --------
    ValueError:
        If meth is not one of the methods above.
    """
    meth = meth.lower()
    if meth == 'continuous':
        Br = np.zeros_like(B)
        # First column: 1-based position of each distinct value, sign kept
        for pos, label in enumerate(np.unique(B[:, 0])):
            Br[B[:, 0] == label, 0] = (pos + 1) * np.sign(label)
        # Remaining columns: number the labels inside each parent block,
        # letting the counter run on from one parent block to the next
        for b in range(1, B.shape[1]):
            current = B[:, b]
            parent = Br[:, b - 1]
            cnt = 1
            for parent_label in np.unique(parent):
                in_parent = parent == parent_label
                for label in np.unique(current[in_parent]):
                    rows = np.logical_and(current == label, in_parent)
                    Br[rows, b] = cnt * np.sign(label)
                    cnt += 1
    elif meth == 'restart':
        # Renumber each block separately, starting from 1
        Br, _ = renumber(B)
    elif meth == 'mixed':
        # palm_reindex returns a single array, so there is nothing to unpack
        Ba = palm_reindex(B, 'restart')
        Bb = palm_reindex(B, 'continuous')
        Br = np.hstack((Ba[:, :-1], Bb[:, -1:]))
    elif meth == 'fixleaves':
        Br, addcol = renumber(B)
        if addcol:
            # Append a column of sequential numbers and renumber again
            col = np.arange(1, Br.shape[0] + 1).reshape(-1, 1)
            Br = np.hstack((Br, col))
            Br, _ = renumber(Br)
    else:
        raise ValueError(f'Unknown method: {meth}')
    return Br
[docs]
def palm_tree(B, M=None):
    """
    Build the top-level palm tree node from a matrix of block definitions.
    The first column of B decides whether the top level is shuffled as a
    whole (first entry positive); the remaining columns are handed to
    maketree, which builds the nested branches.
    Parameters:
    -----------
    B (numpy.ndarray):
        The multi-level block definitions, one row per observation.
    M (numpy.ndarray, optional):
        A design matrix with one row per observation. Defaults to None, in
        which case a single column 1..N is used.
    Returns:
    --------
    list
        A list with three elements:
        - Ptree[0] : the first value returned by maketree for the top level.
        - Ptree[1] : a boolean array of zeros with one entry per branch when
          the first entry of B is positive, otherwise an empty list.
        - Ptree[2] : the branches returned by maketree.
    Raises:
    --------
    ValueError:
        If M is given and its number of rows differs from that of B.
    """
    n_obs = B.shape[0]
    if M is None:
        # No design given: use sequential values 1..N in a single column
        M = np.arange(1, n_obs + 1).reshape(-1, 1)
    elif n_obs != M.shape[0]:
        raise ValueError("The two inputs must have the same number of rows.")
    # Observation indices, numbered from 1
    O = np.arange(1, M.shape[0] + 1).reshape(-1, 1)
    # A positive first entry marks the top level as a whole block
    wholeblock = B[0, 0] > 0
    key, branches = maketree(B[:, 1:], M, O, wholeblock, wholeblock)
    if wholeblock:
        middle = np.zeros(branches.shape[0], dtype=bool)
    else:
        middle = []
    return [key, middle, branches]
[docs]
def maketree(B, M, O, wholeblock, nosf):
    """
    Recursively build the branches of a palm tree from block definitions.
    Parameters:
    -----------
    B (numpy.ndarray):
        Block definitions for this level and below; the first column defines
        the branches of this level.
    M (numpy.ndarray):
        The design matrix rows belonging to these observations.
    O (numpy.ndarray):
        Observation indices belonging to these observations.
    wholeblock (bool):
        Whether the branches of this level are shuffled as whole blocks.
    nosf (bool):
        Whether sign-flip placeholders are suppressed at this level.
    Returns:
    -----------
    S (numpy.ndarray, list or float):
        The shuffle key of this level: an (nU, 3) integer array when
        wholeblock is True and there are several branches, [1, 1, 1] when
        wholeblock is True and there is one branch, and NaN otherwise.
    Ptree (numpy.ndarray):
        Object array with one row per branch. With further levels below,
        each row is [key, middle, children]; at the last level each row
        holds only the observation indices of the branch.
    """
    B1 = B[:, 0]
    U = np.unique(B1)
    nU = len(U)
    has_children = B.shape[1] > 1
    # Three slots per branch when there are deeper levels, otherwise one
    Ptree = np.empty((nU, 3 if has_children else 1), dtype=object)
    for u in range(nU):
        idx = B1 == U[u]
        if has_children:
            # The sign of the label tells whether the sub-block shuffles as a whole
            wholeblockb = B[np.where(idx)[0][0], 0] > 0
            Ptree[u][0], Ptree[u][2] = maketree(B[idx, 1:], M[idx], O[idx], wholeblockb, wholeblockb or nosf)
            Ptree[u][1] = []
            # A placeholder of zeros is stored only when sign-flips are not
            # suppressed, the children are not leaves, and the child key is
            # not NaN. The NaN test must cope with a scalar key; the previous
            # code subscripted it and raised TypeError.
            if (not nosf
                    and Ptree[u][2].shape[1] > 1
                    and not _key_starts_with_nan(Ptree[u][0])):
                Ptree[u][1] = np.zeros(Ptree[u][2].shape[0], dtype=int)
        else:
            # Last level: the branch holds its observation indices
            Ptree[u][0] = O[idx]
    if wholeblock and nU > 1:
        # Sort the rows by block label, then by the first design column
        combined_array = np.column_stack((B1, M))
        sorted_indices = np.argsort(combined_array[:, 0])
        B1M = combined_array[sorted_indices]
        sorted_indices = np.lexsort((B1M[:, 1], B1M[:, 0]))
        B1M_sorted = B1M[sorted_indices]
        # One row per branch, holding that branch's design values
        Ms = B1M_sorted[:, 1:]
        Msre = Ms.reshape(nU, int(Ms.size / nU))
        # Branches with identical design rows receive the same label
        _, S = np.unique(Msre, axis=0, return_inverse=True)
        # Keep the labels 1D regardless of the NumPy version
        S = S.reshape(-1)
        # Put in ascending order and (un)shuffle the branches accordingly
        idx = np.argsort(S)
        S = np.column_stack((S[idx], np.arange(0, S.shape[0]), np.arange(0, S.shape[0]))) + 1
        Ptree = Ptree[idx, :]
    elif wholeblock and nU == 1:
        S = [1, 1, 1]
    else:
        # Within-block level: nothing to shuffle as a whole
        S = np.nan
    return S, Ptree


def _key_starts_with_nan(S):
    """Return True when the shuffle key S is NaN or its first entry is NaN."""
    return bool(np.isnan(np.ravel(S)[0]))