#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os
import sys
import numpy as np 
import math
import scipy as sp
from scipy import special # must explicitly import this module 
from scipy import optimize # must explicitly import this module
from scipy import spatial
import pandas as pd
from scipy.linalg import svd
import phate
import scprep
#import magic
from sklearn.decomposition import PCA


# In[2]:


import math
from numpy import linalg as LA


# In[3]:


def scale_down(ideal):
    """Relabel values as consecutive integers 0..k-1 in sorted order.

    Every entry of ``ideal`` is replaced by the rank of its value among the
    sorted distinct values of ``ideal``, e.g. ``[3, 1, 3, 7] -> [1, 0, 1, 2]``.

    Parameters
    ----------
    ideal : 1-D list or numpy array
        Labels to compact. Values must be mutually comparable, because
        ``np.unique`` sorts them.

    Returns
    -------
    numpy.ndarray
        Integer array with one rank per element of ``ideal``.
    """
    # return_inverse yields, for each element, its index into the sorted
    # unique values. That is exactly the rank the old per-element np.where
    # lookup computed, but in O(n log n) instead of O(n * k). The result is
    # always an integer array: the old approach wrote ranks back into a copy
    # of the input, which for numpy string arrays stored them as (possibly
    # truncated) strings.
    _, ranks = np.unique(ideal, return_inverse=True)
    # ravel keeps the result 1-D across numpy versions whose inverse shape differs.
    return np.asarray(ranks).ravel()


# In[4]:


### Building C-PHATE on input


# In[5]:


# Loading affinities and converting to kernels
import scipy.io

# Command-line arguments:
#   argv[1] -> integer value written into the 1-layer (level t -> t+1) link blocks
#   argv[2] -> integer value written into the 2-layer (level t -> t+2) link blocks
#   argv[3] -> input directory; later cells build paths by plain string
#              concatenation (k_dir_input + "file"), so it must end with a
#              path separator.
arg1 = int(sys.argv[1])
arg2 = int(sys.argv[2])
k_dir_input = sys.argv[3]

# Last component of the input directory, used as a suffix on every output file.
# normpath/basename gives the directory's own name whether or not the
# argument carries a trailing separator.
inputName = os.path.basename(os.path.normpath(k_dir_input))

# Output directory <input>/C-PHATEout. exist_ok=True lets the script be rerun
# on the same input without failing with FileExistsError; existing outputs
# are then overwritten by the later to_csv calls.
os.makedirs(os.path.join(k_dir_input, "C-PHATEout"), exist_ok=True)

# Per-level kernels and the 'epsilon' value stored alongside each affinity matrix.
k_list_input = []
epsilon_list_input = []

# In[6]:


# Cluster-assignment table, transposed after loading.
# NOTE(review): presumably N points x T levels after the transpose (per the
# name NxT); later cells read column t as 1-based cluster labels — verify.
NxT_input = scprep.io.load_csv(k_dir_input+"cluster_assigments.csv", sep=",", cell_names=False).T

# All intermediate results go to <input>/C-PHATEout/.
out_dir_input = k_dir_input + "C-PHATEout/"

# One affinity matrix per column of NxT_input, stored as "<k>-affinity-matrix.mat"
# with k starting at 1.
for k in range(1, NxT_input.shape[1] + 1):
    mat = scipy.io.loadmat(k_dir_input + str(k) + "-affinity-matrix.mat")
    # .toarray() densifies the stored matrix (presumably scipy.sparse in the .mat file).
    A = mat['affinity'].toarray()
    pd.DataFrame(A).to_csv(out_dir_input + str(k) + "-affinities_" + inputName + ".csv")

    # Q_j = 1 / (sum of column j). A zero column sum yields inf with a RuntimeWarning.
    Q = 1. / np.sum(A, axis=0)
    pd.DataFrame(Q).to_csv(out_dir_input + str(k) + "-Qs_" + inputName + ".csv")

    # Kernel K = diag(Q) @ A @ diag(Q), i.e. K_ij = Q_i * A_ij * Q_j.
    # Broadcasting gives the same values in O(n^2) without materialising two
    # dense n x n diagonal matrices and doing two O(n^3) matmuls.
    K = Q[:, None] * A * Q[None, :]
    pd.DataFrame(K).to_csv(out_dir_input + str(k) + "-kernels_" + inputName + ".csv")

    k_list_input.append(K)
    epsilon_list_input.append(mat['epsilon'][0][0])

# Extra one-node level after the last real one. Its single self-affinity entry
# is 500, and the 1-layer connectivity cell links the last real level to it.
# NOTE(review): the meaning of the constant 500 is not documented here — confirm.
k_list_input.append(np.array([500]))



# In[7]:


### Assembling the combined multi-level C-PHATE matrix ###


# In[8]:


# Row/column count contributed by each level's kernel.
sizes_input = [kernel.shape[0] for kernel in k_list_input]
n_total_input = sum(sizes_input)

# Square zero matrix covering every level; kernels go on its diagonal.
c_phate_input = np.zeros((n_total_input, n_total_input))

# Place level after level, advancing the running offset by each kernel's size.
base_input = 0
for kernel, size in zip(k_list_input, sizes_input):
    stop_input = base_input + size
    c_phate_input[base_input:stop_input, base_input:stop_input] = kernel
    base_input = stop_input


# In[9]:


## 1-layer connectivity: link each level t to level t+1 (for t >= 1)

base_input = sizes_input[0]

for t in range(1, len(sizes_input) - 1):
    n_here, n_next = sizes_input[t], sizes_input[t + 1]
    here = slice(base_input, base_input + n_here)
    nxt = slice(base_input + n_here, base_input + n_here + n_next)

    # Labels from columns t-1 and t of NxT_input, used as 1-based indices.
    prev = np.array(NxT_input.iloc[:, t - 1])
    fut = np.array(NxT_input.iloc[:, t])

    # Mark every (future cluster, previous cluster) pair observed in the same
    # row with arg1. Duplicate pairs just rewrite the same constant.
    # NOTE(review): [:-1] reproduces the original loop bound, which skips the
    # last row of the assignment table — looks like an off-by-one; confirm.
    matching = np.zeros((n_next, n_here))
    matching[(fut[:-1] - 1).astype(int), (prev[:-1] - 1).astype(int)] = arg1

    # Fill both off-diagonal blocks so the combined matrix stays symmetric.
    c_phate_input[nxt, here] = matching
    c_phate_input[here, nxt] = matching.T

    base_input += n_here


# In[10]:


## 2-layer connectivity: link each level t to level t+2 (for t >= 1)

base_input = sizes_input[0]

for t in range(1, len(sizes_input) - 2):
    n_here, n_mid, n_far = sizes_input[t], sizes_input[t + 1], sizes_input[t + 2]
    here = slice(base_input, base_input + n_here)
    far_start = base_input + n_here + n_mid
    far = slice(far_start, far_start + n_far)

    # Labels from columns t-1 and t+1 of NxT_input, used as 1-based indices.
    prev = np.array(NxT_input.iloc[:, t - 1])
    fut = np.array(NxT_input.iloc[:, t + 1])

    # Mark every (far cluster, current cluster) pair observed in the same row
    # with arg2. Duplicate pairs just rewrite the same constant.
    # NOTE(review): [:-1] reproduces the original loop bound, which skips the
    # last row of the assignment table — looks like an off-by-one; confirm.
    matching = np.zeros((n_far, n_here))
    matching[(fut[:-1] - 1).astype(int), (prev[:-1] - 1).astype(int)] = arg2

    # Fill both off-diagonal blocks so the combined matrix stays symmetric.
    c_phate_input[far, here] = matching
    c_phate_input[here, far] = matching.T

    base_input += n_here


# In[11]:


# Combined matrix with the first level's rows/columns (sizes_input[0] of them)
# removed. A plain slice is enough: it is a view into c_phate_input. The
# zero-filled buffer previously allocated here was discarded immediately by
# this assignment, so it is no longer created.
first_size_input = sizes_input[0]
n_total_input = sum(sizes_input)
c_phateB_input = c_phate_input[first_size_input:n_total_input, first_size_input:n_total_input]


# In[12]:


# Save the trimmed combined matrix to <input>/C-PHATEout/c_phate_<inputName>.csv.
c_phate_out_path = k_dir_input + "C-PHATEout/" + "c_phate_" + inputName + ".csv"
pd.DataFrame(c_phateB_input).to_csv(c_phate_out_path)


# In[13]:


### Visualizing ###


# In[14]:


# 3-D PHATE embedding of the combined matrix. Per the PHATE API,
# knn_dist='precomputed_affinity' treats the input as an affinity matrix
# rather than raw features. n_jobs=-1 requests all available cores.
phate_settings = dict(n_components=3, n_jobs=-1, knn_dist='precomputed_affinity')
phate_op = phate.PHATE(**phate_settings)
data_phate_input = phate_op.fit_transform(c_phateB_input)


# In[15]:


# Save the PHATE coordinates to <input>/C-PHATEout/new_phate_coordinates_<inputName>.csv.
coords_out_path = k_dir_input + "C-PHATEout/" + "new_phate_coordinates_" + inputName + ".csv"
pd.DataFrame(data_phate_input).to_csv(coords_out_path)


# In[ ]:





# In[16]:
