# NOTE: Code has not been fully tested, take it with a grain of salt.
# TODO: Handling repeated eigenvalues has not yet been fully implemented.
#
# See "Sampling matrices from Harish-Chandra--Itzykson--Zuber densities with applications to Quantum inference and differential privacy"
# by Jonathan Leake, Colin McSwiggen, and Nisheeth K. Vishnoi at
# https://dl.acm.org/doi/10.1145/3406325.3451094
# for further details.

import numpy as np
import secrets
import hopsy
import matplotlib.pyplot as plt
import math
from ctypes import c_uint32

class LogLinear:
    """Log-linear model ``x -> np.dot(x, alpha)`` with constant gradient ``alpha``.

    Instances are passed to ``hopsy.Problem`` by HCIZ_sample as the target
    density. Note that both methods return NumPy arrays rather than Python
    floats: for a 1-D ``x`` of length d, ``log_density`` returns an array of
    shape (1,), and ``log_gradient`` returns the stored (d, 1) column.
    """

    def __init__(self, alpha):
        # Copy alpha, then store it as a column vector of shape (len(alpha), 1).
        column = np.array(alpha)
        self._alpha = column.reshape((len(alpha), 1))

    def log_density(self, x):
        """Return the unnormalised log density ``np.dot(x, alpha)``."""
        weights = self._alpha
        return np.dot(x, weights)

    def log_gradient(self, x):
        """Return the gradient, which is the constant column ``alpha`` (``x`` is unused)."""
        return self._alpha

# NOTE: Currently assumes all relevant values of X distinct.
def compute_radii(X: np.array):
    """Compute the coupling radii between consecutive rows of a GT pattern.

    For every k >= 1 and i < k this sets

        V[k, i] = sqrt( -prod_{j=0..k}  (X[k-1, i] - X[k, j])
                        / prod_{j<k, j!=i} (X[k-1, i] - X[k-1, j]) )

    Row k of X is read at columns 0..k. For strictly interlacing rows the
    quantity under the square root is non-negative. Equal values in
    X[k-1, :k] make a denominator zero. These radii are the magnitudes that
    compute_unitary multiplies by random phases.

    Returns:
        n-by-n float array; entries not covered above (row 0, and i >= k) are 0.
    """
    size = len(X)
    radii = np.zeros((size, size))

    for k in range(1, size):
        lower, upper = X[k - 1], X[k]
        for i in range(k):
            pivot = lower[i]
            numer = math.prod(pivot - upper[j] for j in range(k + 1))
            denom = math.prod(pivot - lower[j] for j in range(k) if j != i)
            radii[k, i] = math.sqrt(-numer / denom)

    return radii

# NOTE: Currently assumes all relevant values of X distinct.
def compute_unitary(X: np.array, V: np.array, T: np.array):
    """Assemble a unitary from a GT pattern, its radii, and random phases.

    For each k = 1..n-1 an n-by-n factor is built. Column j (j <= k) has
    entries ``V[k, i] * T[k-1, i] / (X[k, j] - X[k-1, i])`` in rows i < k and
    1 in row k, and is then normalised to unit length. All other columns are
    those of the identity. The factors are multiplied left to right, and the
    product is finally multiplied on the right by ``diag(T[n-1])``, so the
    output has random phases on its eigenvector columns.

    Args:
        X: n-by-n array; row k is read at columns 0..k.
        V: radii, read at V[k, i] for i < k (as produced by compute_radii).
        T: n-by-n array of phases; T[k-1, :k] for each k and all of T[n-1] are used.

    Returns:
        Complex n-by-n array.

    NOTE: assumes X[k, j] != X[k-1, i] for every pair used; otherwise a
    denominator is zero.
    """
    size = len(X)
    result = np.identity(size, dtype='complex')

    for k in range(1, size):
        factor = np.identity(size, dtype='complex')

        # The numerator depends only on the row index i; the denominator on (i, j).
        weights = V[k, :k] * T[k - 1, :k]
        gaps = X[k, :k + 1][np.newaxis, :] - X[k - 1, :k][:, np.newaxis]
        factor[:k, :k + 1] = weights[:, np.newaxis] / gaps
        factor[k, :k + 1] = 1

        factor = factor / np.linalg.norm(factor, axis=0)
        result = result @ factor

    return result @ np.diag(T[size - 1])

# Constructs A, b such that A x <= b describes the GT polytope.
# NOTE: Handles repeated values in spec (they produce forced entries).
def construct_GT_polytope(spec: np.array, q: int):
    """Build the Gelfand-Tsetlin (GT) polytope ``{x : A @ x <= b}`` for ``spec``.

    A GT pattern is stored as an n-by-n lower-triangular array. Row i uses
    columns 0..i, the last row is ``spec``, and every entry above it satisfies
    the interlacing bounds ``row[i+1][j] <= row[i][j] <= row[i+1][j+1]``.
    An entry is "forced" when its two neighbours in the row below have the
    same forced value; this only happens for repeated values in ``spec``.
    Forced entries are not variables. Every other entry above the last row
    is a free variable.

    Args:
        spec: the last row; assumed sorted in increasing order.
        q: assumed to be an integer with ``q * spec`` integral. It fixes the
            rational grid used to construct the starting point ``pnt``.

    Returns:
        (A, b, var, pnt):
        - A, b: constraint matrix and right-hand side. Each free entry adds
          two rows (its lower bound, then its upper bound).
        - var: int array; var[i, j] is the column of A for free entry (i, j),
          and -1 otherwise (forced entries, the last row, and j > i).
        - pnt: a GT pattern with last row ``spec``. Each forced entry holds its
          forced value. Each free entry (i, j) is the largest multiple of
          1 / (q * (n - i)) strictly below pnt[i+1, j+1]. Under the
          assumptions above this places it strictly inside its interlacing
          bounds, so ``pnt`` is a valid interior starting point.
    """
    n = len(spec)

    # var[i, j]: column index of free entry (i, j) in A, or -1.
    var = -np.ones((n, n), dtype=int)

    # frc[i, j]: the value entry (i, j) is forced to, or NaN if it is free.
    frc = np.zeros((n, n))

    # pnt[i, j]: starting-point value of entry (i, j) (see docstring).
    pnt = np.zeros((n, n))

    frc[-1] = spec
    pnt[-1] = spec

    # At most n(n-1)/2 free entries, each contributing two inequalities.
    nC2 = n * (n - 1) // 2
    A = np.zeros((2 * nC2, nC2))
    b = np.zeros(2 * nC2)
    con_ct = 0
    var_ct = 0

    for i in range(n - 2, -1, -1):
        for j in range(i + 1):
            # NaN never compares equal, so a free neighbour always makes this entry free.
            if frc[i + 1, j] == frc[i + 1, j + 1]:
                frc[i, j] = frc[i + 1, j]
                pnt[i, j] = pnt[i + 1, j]
                continue

            # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
            frc[i, j] = np.nan

            # pnt[i+1, j+1] is a multiple of 1/(q*(n-i-1)) (or of 1/q if forced),
            # so this scaling yields an integer. Subtracting 1 before flooring
            # picks the largest grid point strictly below the upper bound.
            right_bnd_int = int(np.round(pnt[i + 1, j + 1] * q * (n - i) * (n - (i + 1))))
            pnt[i, j] = ((right_bnd_int - 1) // (n - (i + 1))) / (q * (n - i))

            var[i, j] = var_ct
            var_ct += 1

            # Lower bound row[i+1][j] <= x[i, j], written as -x[i, j] + row[i+1][j] <= 0
            # (or -x[i, j] <= -value when the neighbour is forced).
            A[con_ct, var[i, j]] = -1
            if np.isnan(frc[i + 1, j]):
                A[con_ct, var[i + 1, j]] = 1
                b[con_ct] = 0
            else:
                b[con_ct] = -frc[i + 1, j]
            con_ct += 1

            # Upper bound x[i, j] <= row[i+1][j+1].
            A[con_ct, var[i, j]] = 1
            if np.isnan(frc[i + 1, j + 1]):
                A[con_ct, var[i + 1, j + 1]] = -1
                b[con_ct] = 0
            else:
                b[con_ct] = frc[i + 1, j + 1]
            con_ct += 1

    # Drop the rows/columns that were reserved for entries which turned out forced.
    A = A[:con_ct, :var_ct]
    b = b[:con_ct]

    return (A, b, var, pnt)

# Main function to sample Hermitian matrices (or their unitaries) from an HCIZ distribution.
def HCIZ_sample(spec_orbit: np.array, spec_scaling: np.array, max_denom: int, ret_unitary: bool = True, num_samples: int = 100, rng_seed: "int | None" = None, mc_thinning: int = 100, mc_proposal: hopsy.Proposal = hopsy.BilliardWalkProposal):
    """Sample unitary or Hermitian matrices from an HCIZ distribution.

    A hopsy Markov chain samples free entries of the Gelfand-Tsetlin polytope
    of ``spec_orbit`` (built by construct_GT_polytope) under a LogLinear
    density whose coefficients come from ``spec_scaling``. Each sample is
    completed into a full GT pattern. That pattern, together with uniformly
    random phases, is then turned into a unitary via compute_radii and
    compute_unitary.

    Args:
        spec_orbit: spectrum of the unitary conjugation orbit being sampled;
            must be sorted in increasing order.
        spec_scaling: spectrum of the scaling matrix of the distribution.
        max_denom: the smallest integer q such that q * spec_orbit is an
            integer vector.
        ret_unitary: if True return the sampled unitaries U, otherwise the
            Hermitian matrices U @ diag(spec_orbit) @ U^H.
        num_samples: number of samples requested from the Markov chain.
        rng_seed: seed for both the Markov chain and the phase sampler. If None
            (the default), a fresh 128-bit seed is drawn from ``secrets`` on
            every call. Only its low 32 bits reach hopsy's generator
            (ctypes truncates without overflow checking).
        mc_thinning: samples for the MC to skip between each returned sample.
        mc_proposal: MC random walk type (see hopsy documentation).

    Returns:
        Complex array of shape (m, n, n). Here m is the size of axis 1 of
        hopsy's ``states``; only chain 0 of ``states`` is read.

    Raises:
        ValueError: if spec_orbit is not sorted in increasing order.
    """
    n = len(spec_orbit)

    if np.any(np.diff(spec_orbit) < 0):
        raise ValueError("spec_orbit must be sorted in increasing order")

    # Draw the seed per call. A `secrets.randbits(128)` default in the signature
    # would be evaluated once at definition time and reused by every call.
    if rng_seed is None:
        rng_seed = secrets.randbits(128)

    (A, b, var, pnt) = construct_GT_polytope(spec_orbit, max_denom)

    # Positions (row, col) of the free GT entries and their variable indices.
    rows, cols = np.nonzero(var >= 0)
    idx = var[rows, cols]

    dim = np.size(A, 1)
    p = np.zeros(dim)
    v = np.zeros(dim)
    p[idx] = pnt[rows, cols]
    # A free entry in row i gets coefficient spec_scaling[i] - spec_scaling[i+1].
    # NOTE(review): this appears to come from expanding tr(diag(spec_scaling) H)
    # in GT row sums; confirm against the paper.
    scaling = np.asarray(spec_scaling)
    v[idx] = scaling[rows] - scaling[rows + 1]

    model = LogLinear(v)
    problem = hopsy.Problem(A, b, model)
    mc = hopsy.MarkovChain(problem, proposal=mc_proposal, starting_point=p)
    rng_mc = hopsy.RandomNumberGenerator(seed=c_uint32(rng_seed).value)

    (_acceptance_rate, states) = hopsy.sample(mc, rng_mc, n_samples=num_samples, thinning=mc_thinning)

    rng_T = np.random.default_rng(seed=rng_seed)

    n_out = np.size(states, 1)
    X = np.zeros((n_out, n, n))
    T = np.exp((2 * math.pi * 1j) * rng_T.random((n_out, n, n)))
    out = np.zeros((n_out, n, n), dtype="complex")

    for k in range(n_out):
        # Start from pnt so forced entries (repeated eigenvalues) keep their
        # forced value instead of 0; free entries are then overwritten.
        X[k] = pnt
        X[k, rows, cols] = states[0, k, idx]
        X[k, -1] = spec_orbit

        V_k = compute_radii(X[k])
        U_k = compute_unitary(X[k], V_k, T[k])

        if ret_unitary:
            out[k] = U_k
        else:
            out[k] = U_k @ np.diag(spec_orbit) @ U_k.conjugate().transpose()

    return out

# Example: sample from an HCIZ density on the orbit of diag(0, 1, 2, 3, 4).
# NOTE: spec_orbit must be sorted in increasing order
spec_orbit = np.array([0, 1, 2, 3, 4])
spec_scaling = np.array([10, 0, 0, 0, 0])

# The spectrum above is integral, hence max_denom=1. See HCIZ_sample for the
# meaning of the remaining options.
samples = HCIZ_sample(
    spec_orbit,
    spec_scaling,
    max_denom=1,
    num_samples=1000,
    ret_unitary=True,
)

# Display arrays with two decimal places.
np.set_printoptions(precision=2)

# Show the first three sampled matrices, one after another.
print(*samples[:3], sep="\n")

# Scatter the real part of the (0, 0) entry of every sample against its index.
to_plt = samples[:, 0, 0]
plt.scatter(np.arange(len(to_plt)), to_plt.real)
plt.show()