# Amira Abdel-Rahman
# (c) Massachusetts Institute of Technology 2020

# Asymptotic Homogenization Implementation in Julia
# Based on: https://asmedigitalcollection.asme.org/materialstechnology/article-abstract/141/1/011005/368579

# using MAT
# using LinearAlgebra
# using CUDA, IterativeSolvers
# using SparseArrays
# using Krylov

# vars = matread("grid_octet_skel.mat")
# GRID=vars["GRID"];
# STRUT=vars["STRUT"];
# res = 40; # number of voxels per side
# rad = 0.1; # radius of struts

# two options to format effective property results: 'vec' or 'struct'
# outOption = "struct";
# options to print results and/or plot Young's modulus surface
# dispFlag = 1;
# plotFlag = 1;
# props, SH = evaluateCH(CH, dens, outOption, dispFlag);
"""
    KE_from_E(E, nu)

Build the stiffness matrix of an 8-node hexahedral element made of an isotropic
material with Young's modulus `E` and Poisson's ratio `nu`.

The 6x6 isotropic constitutive matrix is formed first and then integrated by
`elementMatVec3D` with half-edge lengths of 0.5 in every direction.

Returns the tuple `(Ke, [E, E, E])`: the element stiffness matrix and the Young's
modulus repeated three times.
"""
function KE_from_E(E, nu)
    scale = E/(1+nu)/(1-2*nu)   # common prefactor of the isotropic stiffness
    d = 1-nu                    # normal (diagonal) entry
    s = (1-2*nu)/2              # shear (diagonal) entry
    D0 = scale * [d   nu  nu  0  0  0;
                  nu  d   nu  0  0  0;
                  nu  nu  d   0  0  0;
                  0   0   0   s  0  0;
                  0   0   0   0  s  0;
                  0   0   0   0  0  s]
    Ke = elementMatVec3D(0.5, 0.5, 0.5, D0)
    return Ke, [E, E, E]
end

"""
    getCH(GRID, STRUT, res, rad)

Voxelize the strut lattice described by `GRID` (node coordinates) and `STRUT`
(node-index pairs) on a `res`^3 grid with strut radius `rad`, homogenize it with
`homogAsymp3D`, and report the effective properties through `evaluateCH`.

Returns `(CH, EH)`: the homogenized 6x6 elasticity matrix and the effective Young's
moduli taken from the `"EH"` entry of the property dictionary.
"""
function getCH(GRID, STRUT, res, rad)
    # Voxel model and its relative density.
    # Alternatively, voxels can be defined directly:
    #   vars = matread("grid_octet_vox.mat"); vox = vars["vox"]
    #   dens = sum(vox)/length(vox)
    vox, dens = generateVoxelLattice(res, rad, GRID, STRUT)

    # lengths of sides of unit cell
    ll = [1, 1, 1]

    # properties of the two isotropic constituent materials
    E = [1e-9 2e9]     # E1, E2
    nu = [0.33 0.33]   # nu1, nu2

    # Two options to define constituent materials: "youngs" or "lame";
    # this changes how the stiffness matrix is assembled.
    # For "lame" use props0 = [lam; mu] with
    #   lam = nu.*E ./ ((1.0 .+ nu).*(1.0 .- 2.0*nu));  mu = E ./ (2.0*(1.0 .+ nu))
    def = "youngs"; props0 = [E; nu]

    # two options for solver: "pcg" or "direct"
    solver = "pcg"

    CH = homogAsymp3D(ll, vox, props0, def, solver)

    # two options to format effective property results: "vec" or "struct"
    outOption = "struct"
    # option to print results
    dispFlag = 1

    # Report the density that was actually computed for the voxel model
    # (it was previously overwritten by a hard-coded 0.5 before this call).
    props, SH = evaluateCH(CH, dens, outOption, dispFlag)
    EH = props["EH"]

    return CH, EH
end

"""
    generateVoxelLattice(n, radius, node, strut)

Rasterize a strut lattice onto an `n` x `n` x `n` voxel grid.

Voxel `[ix, iy, iz]` has its centre at `((ix-0.5)/n, (iy-0.5)/n, (iz-0.5)/n)`.
A voxel is set to 1 when the distance from its centre to at least one strut (treated
as a line segment between its two end nodes) is less than or equal to `radius`.

# Arguments
- `n`: number of voxels per side.
- `radius`: strut radius, in the same units as the node coordinates.
- `node`: matrix with one node per row; columns 1:3 are the x, y, z coordinates.
- `strut`: matrix with one strut per row; columns 1 and 2 hold the row indices
  (into `node`) of the start and end nodes. Values are converted with `Int`.

# Returns
`(voxel, density)` where `voxel` is an `Array{Int,3}` of zeros and ones and
`density` is `sum(voxel)/length(voxel)`.
"""
function generateVoxelLattice(n, radius, node, strut)
    vox_size = 1/n                 # edge length of one voxel
    voxel = zeros(Int, n, n, n)    # initial grid with zeros

    for s in 1:size(strut, 1)
        p0 = node[Int(strut[s, 1]), :]   # start node of this strut
        p1 = node[Int(strut[s, 2]), :]   # end node of this strut
        ax = p1[1] - p0[1]; ay = p1[2] - p0[2]; az = p1[3] - p0[3]
        axis_len = sqrt(ax^2 + ay^2 + az^2)

        for iz in 1:n, iy in 1:n, ix in 1:n
            # already activated by an earlier strut: nothing to do
            voxel[ix, iy, iz] == 1 && continue

            cx = (ix - 0.5)*vox_size
            cy = (iy - 0.5)*vox_size
            cz = (iz - 0.5)*vox_size

            # vectors from the start node / end node to the voxel centre
            bx = cx - p0[1]; by = cy - p0[2]; bz = cz - p0[3]
            ex = cx - p1[1]; ey = cy - p1[2]; ez = cz - p1[3]

            # Both angles (at the start and at the end node) are acute exactly when
            # both dot products are positive. Testing the sign of the dot product
            # avoids acosd, which throws a DomainError when rounding pushes the
            # normalized cosine slightly above 1 for centres collinear with a strut.
            acute_start = (bx*ax + by*ay + bz*az) > 0
            acute_end = (ex*(-ax) + ey*(-ay) + ez*(-az)) > 0

            if acute_start && acute_end
                # projection falls inside the segment: distance to the line,
                # |A x B| / |A|
                kx = ay*bz - az*by
                ky = az*bx - ax*bz
                kz = ax*by - ay*bx
                dist = sqrt(kx^2 + ky^2 + kz^2) / axis_len
            else
                # otherwise: distance to the nearer end node
                dist = min(sqrt(bx^2 + by^2 + bz^2), sqrt(ex^2 + ey^2 + ez^2))
            end

            # if distance is not larger than the radius, activate the voxel
            if dist <= radius
                voxel[ix, iy, iz] = 1
            end
        end
    end

    Density = sum(voxel)/length(voxel)   # relative density
    return voxel, Density
end

"""
    vecNorm(A)

Return the Euclidean norm of every row of `A`, computed as the element-wise square
root of the sum of squares along dimension 2 (the result keeps that dimension with
length 1).
"""
function vecNorm(A)
    squared = A .^ 2
    row_totals = sum(squared; dims=2)
    return sqrt.(row_totals)
end
    
"""
    getCross(A, B)

Compute the cross product of the vector `A` with every row of the matrix `B`.

Returns a `Float64` matrix of the same size as `B` whose row `r` is
`cross(A, B[r, :])`.
"""
function getCross(A, B)
    nrows, ncols = size(B)
    products = zeros(nrows, ncols)
    for r in 1:nrows
        row = B[r, :]
        products[r, :] .= cross(A, row)
    end
    return products
end

## COMPUTE UNIT ELEMENT STIFFNESS MATRIX AND LOAD VECTOR
"""
    assemble_lame(a, b, c)

Integrate the unit element stiffness matrices and load matrices of an 8-node
hexahedral element with half-edge lengths `a`, `b`, `c`, split into the parts that
multiply the Lame parameters lambda and mu.

Integration uses a 3x3x3 Gauss rule; the strain-displacement matrix and Jacobian at
each point come from `strain_disp_matrix`.

Returns `(keLambda, keMu, feLambda, feMu)` with sizes 24x24, 24x24, 24x6 and 24x6.
"""
function assemble_lame(a, b, c)
    # Initialize
    keLambda = zeros(24, 24); keMu = zeros(24, 24)
    feLambda = zeros(24, 6); feMu = zeros(24, 6)
    ww = [5/9, 8/9, 5/9]
    J_ = [-a a a -a -a a a -a; -b -b b b -b -b b b; -c -c -c -c c c c c]'
    # Constitutive matrix contributions
    CMu = diagm(0=>[2, 2, 2, 1, 1, 1])
    # 6x6 matrix with ones in the normal-normal block (must be a matrix, not a
    # length-6 vector, and filled with a broadcast assignment)
    CLambda = zeros(6, 6)
    CLambda[1:3, 1:3] .= 1
    # Three Gauss points in each direction
    xx = [-sqrt(3/5), 0, sqrt(3/5)]; yy = xx; zz = xx
    for ii in eachindex(xx), jj in eachindex(yy), kk in eachindex(zz)
        # integration point (Julia indexes with square brackets)
        x = xx[ii]; y = yy[jj]; z = zz[kk]
        # strain-displacement matrix and Jacobian
        B, J = strain_disp_matrix(x, y, z, J_)
        # Weight factor at this point
        weight = det(J) * ww[ii] * ww[jj] * ww[kk]
        # Element matrices
        keLambda = keLambda + weight * B' * CLambda * B
        keMu = keMu + weight * B' * CMu * B
        # Element loads
        feLambda = feLambda + weight * B' * CLambda
        feMu = feMu + weight * B' * CMu
    end

    return keLambda, keMu, feLambda, feMu
end

"""
    assemble_youngs(nu, a, b, c)

Integrate the element stiffness matrix `ke` (24x24) and load matrix `fe` (24x6) of an
8-node hexahedral element with half-edge lengths `a`, `b`, `c` for an isotropic
material with unit Young's modulus.

Only the second entry of `nu` is used as Poisson's ratio. Integration uses a 3x3x3
Gauss rule; the strain-displacement matrix and Jacobian at each point come from
`strain_disp_matrix`.

Returns `(ke, fe)`.
"""
function assemble_youngs(nu, a, b, c)
    gauss_w = [5/9, 8/9, 5/9]
    gauss_x = [-sqrt(3/5), 0, sqrt(3/5)]
    corners = [-a a a -a -a a a -a; -b -b b b -b -b b b; -c -c -c -c c c c c]'

    # Constitutive matrix with unit Young's modulus
    v = nu[2] #TODO multi-material nu
    d = 1-v           # normal diagonal entry
    g = (1-2*v)/2     # shear diagonal entry
    C = [d v v 0 0 0;
         v d v 0 0 0;
         v v d 0 0 0;
         0 0 0 g 0 0;
         0 0 0 0 g 0;
         0 0 0 0 0 g] / ((1+v)*(1-2*v))

    ke = zeros(24, 24)
    fe = zeros(24, 6)
    # Three Gauss points in each direction
    for ii in eachindex(gauss_x), jj in eachindex(gauss_x), kk in eachindex(gauss_x)
        # strain-displacement matrix and Jacobian at this integration point
        B, J = strain_disp_matrix(gauss_x[ii], gauss_x[jj], gauss_x[kk], corners)
        # Weight factor at this point
        weight = det(J) * gauss_w[ii] * gauss_w[jj] * gauss_w[kk]
        # Element matrix and element loads
        ke = ke + weight * B' * C * B
        fe = fe + weight * B' * C
    end
    return ke, fe
end

"""
    strain_disp_matrix(x, y, z, J_)

Evaluate the strain-displacement matrix of an 8-node hexahedral element at the local
point `(x, y, z)`.

`J_` is the 8x3 matrix that is multiplied by the 3x8 matrix of local shape-function
derivatives to form the Jacobian.

Returns `(B, J)`: the 6x24 strain-displacement matrix (three columns per node, rows
ordered xx, yy, zz, xy, yz, xz) and the 3x3 Jacobian.
"""
function strain_disp_matrix(x, y, z, J_)
    xm = x-1; xp = x+1
    ym = y-1; yp = y+1
    zm = z-1; zp = z+1

    # local derivatives of the eight shape functions
    qx = [-(ym*zm)/8, (ym*zm)/8, -(yp*zm)/8, (yp*zm)/8, (ym*zp)/8, -(ym*zp)/8, (yp*zp)/8, -(yp*zp)/8]
    qy = [-(xm*zm)/8, (xp*zm)/8, -(xp*zm)/8, (xm*zm)/8, (xm*zp)/8, -(xp*zp)/8, (xp*zp)/8, -(xm*zp)/8]
    qz = [-(xm*ym)/8, (xp*ym)/8, -(xp*yp)/8, (xm*yp)/8, (xm*ym)/8, -(xp*ym)/8, (xp*yp)/8, -(xm*yp)/8]

    local_grad = [qx  qy  qz]'
    J = local_grad * J_          # Jacobian
    qxyz = J \ local_grad        # derivatives with respect to global coordinates

    B = zeros(6, 24)
    for node in 1:8
        c0 = 3*(node-1)          # column offset of this node's three dofs
        gx = qxyz[1, node]; gy = qxyz[2, node]; gz = qxyz[3, node]
        B[1, c0+1] = gx
        B[2, c0+2] = gy
        B[3, c0+3] = gz
        B[4, c0+1] = gy; B[4, c0+2] = gx
        B[5, c0+2] = gz; B[5, c0+3] = gy
        B[6, c0+1] = gz; B[6, c0+3] = gx
    end
    return B, J
end

"""
    homogAsymp3D(ll, vox, props0, def="youngs", solver="pcg")

Compute the homogenized 6x6 elasticity matrix of a periodic unit cell described by a
voxel model.

# Arguments
- `ll`: lengths of the unit cell along x, y and z.
- `vox`: 3-D array of material indicators, one entry per voxel (0 selects the first
  material, 1 the second).
- `props0`: 2-row matrix of constituent properties, one column per material. For
  `def == "youngs"` row 1 is Young's modulus and row 2 Poisson's ratio (only the
  second Poisson's ratio is used by `assemble_youngs`); for `def == "lame"` row 1 is
  lambda and row 2 is mu.
- `def`: `"youngs"` or `"lame"`; selects how the stiffness matrix is assembled.
- `solver`: `"pcg"` (conjugate gradient via `Krylov.cg`) or `"direct"` (backslash).

# Returns
`CH`, the 6x6 homogenized elasticity matrix; the six load cases are ordered
11, 22, 33, 12, 23, 13.

Throws an error for an unknown `def` or `solver`.
"""
function homogAsymp3D(ll, vox, props0, def="youngs", solver="pcg")
    nelx, nely, nelz = size(vox) # size of voxel model along x, y and z axis
    dx = ll[1]/nelx; dy = ll[2]/nely; dz = ll[3]/nelz
    nel = nelx*nely*nelz

    # Node numbers and element degrees of freedom for full (not periodic) mesh
    nodenrs = reshape(1:(1+nelx)*(1+nely)*(1+nelz),1+nelx,1+nely,1+nelz)
    edofVec = reshape(3*nodenrs[1:end-1,1:end-1,1:end-1] .+ 1,nel,1)

    addx = [0 1 2 3*nelx .+ [3 4 5 0 1 2] -3 -2 -1]
    addxy = 3*(nely+1)*(nelx+1) .+ addx
    edofMat = repeat(edofVec,1,24) .+ repeat([addx addxy],nel,1)

    ## IMPOSE PERIODIC BOUNDARY CONDITIONS
    # Use original edofMat to index into list with the periodic dofs
    nn = (nelx+1)*(nely+1)*(nelz+1) # Total number of nodes
    nnP = (nelx)*(nely)*(nelz)      # Total number of unique nodes
    nnPArray_old = reshape(1:nnP, nelx, nely, nelz)

    nnPArray = zeros(nelx+1, nely+1, nelz+1)
    nnPArray[1:nelx,1:nely,1:nelz] .= nnPArray_old

    # Extend with a mirror of the back border
    nnPArray[end,:,:] = nnPArray[1,:,:]
    # Extend with a mirror of the left border
    nnPArray[:, end, :] = nnPArray[:,1,:]
    # Extend with a mirror of the top border
    nnPArray[:, :, end] = nnPArray[:,:,1]

    # Make a vector into which we can index using edofMat:
    dofVector = zeros(3*nn, 1)
    dofVector[1:3:end] = 3*nnPArray[:] .-2
    dofVector[2:3:end] = 3*nnPArray[:] .-1
    dofVector[3:3:end] = 3*nnPArray[:]
    edof = Int.(dofVector[edofMat])
    ndof = 3 .*nnP

    ## ASSEMBLE GLOBAL STIFFNESS MATRIX K AND LOAD VECTOR F
    # Indexing vectors
    iK = kron(edof,ones(24,1))'
    jK = kron(edof,ones(1,24))'
    iF = repeat(edof',6,1)
    jF = [ones(24,nel); 2 .*ones(24,nel); 3 .*ones(24,nel); 4 .*ones(24,nel); 5 .*ones(24,nel); 6 .*ones(24,nel)]

    # Assemble stiffness matrix and load vector
    if def == "lame"
        # Material properties assigned to voxels with materials.
        # The comparisons must be broadcast (`.==`); `vox == 0` compares the whole
        # array to a scalar and yields a single Bool.
        lambda0 = props0[1,:]
        mu0 = props0[2,:]
        lambda = lambda0[1] .* (vox .== 0) .+ lambda0[2] .* (vox .== 1)
        mu = mu0[1] .* (vox .== 0) .+ mu0[2] .* (vox .== 1)

        # Unit element stiffness matrix and load
        keLambda, keMu, feLambda, feMu = assemble_lame(dx/2, dy/2, dz/2)
        ke = keMu + keLambda # Here the exact ratio does not matter, because
        fe = feMu + feLambda # it is reflected in the load vector
        sK = keLambda[:] * lambda[:]' + keMu[:] * mu[:]'
        sF = feLambda[:] * lambda[:]' + feMu[:] * mu[:]'
    elseif def == "youngs"
        E = props0[1,:]
        E = E[1] .+ vox .*(E[2] .-E[1]) # SIMP
        nu = props0[2,:]

        # Unit element stiffness matrix and load
        ke, fe = assemble_youngs(nu, dx/2, dy/2, dz/2)
        sK = ke[:]*E[:]'
        sF = fe[:]*E[:]'
    else
        error("unavailable option for constituent properties definition")
    end
    # Global stiffness matrix
    K = sparse(iK[:], jK[:], sK[:], ndof, ndof)
    K = (K+K')/2
    # Six load cases corresponding to the six strain cases
    F = sparse(iF[:], jF[:], sF[:], ndof, 6)

    ## SOLUTION
    activedofs = edof[reshape((vox.==0) .| (vox.==1),nelx* nely* nelz ),:]
    activedofs = Int.(sort(unique(activedofs[:])))
    # constrain one node: the first three active dofs stay at zero
    free = activedofs[4:end]
    X = zeros(ndof,6)
    display("Solving")
    if solver == "pcg"
        # solve using the conjugate gradient method (no preconditioner)
        display("started pcg")
        A = K[free,free]
        for i = 1:6 # run once for each loading condition
            b = F[free,i]
            X[free,i] .= Krylov.cg(A, b,atol=1.0e-12,rtol=1.0e-12,itmax =500,verbose=false)[1]
            display(i)
        end
    elseif solver == "direct"
        display("started direct")
        # solve using direct method
        X[free,:] = K[free,free] \ F[free,:]
    else
        error("unavailable option for solver")
    end

    ## ASYMPTOTIC HOMOGENIZATION
    # The displacement vectors corresponding to the unit strain cases
    X0 = zeros(nel, 24, 6)
    # The element displacements for the six unit strains
    X0_e = zeros(24, 6)
    # fix degrees of nodes [1 2 3 5 6 12];
    keep = vcat(4,7:11,13:24)
    X0_e[keep,:] = ke[keep,keep] \ fe[keep,:]
    # load cases: epsilon0_11, _22, _33, _12, _23, _13
    for k = 1:6
        X0[:,:,k] = kron(X0_e[:,k]', ones(nel,1))
    end

    # Per-element difference between the unit-strain displacements and the
    # computed fluctuation field, one nel x 24 matrix per load case
    dU = [X0[:,:,k] .- X[edof .+ (k-1)*ndof] for k = 1:6]

    CH = zeros(6,6)
    volume = prod(ll)
    # Homogenized elasticity tensor
    if def == "lame"
        for i = 1:6
            # the stiffness multiplies the whole difference (X0 - X), exactly as
            # in the "youngs" branch below
            dUL = dU[i]*keLambda
            dUM = dU[i]*keMu
            for j = 1:6
                sum_L = reshape(sum(dUL .* dU[j],dims=2), nelx, nely, nelz)
                sum_M = reshape(sum(dUM .* dU[j],dims=2), nelx, nely, nelz)
                CH[i,j] = 1/volume*sum(lambda .* sum_L + mu .* sum_M)
            end
        end
    elseif def == "youngs"
        for i = 1:6
            dUk = dU[i]*ke
            for j = 1:6
                sum_XkX = reshape(sum(dUk .* dU[j],dims=2), nelx, nely, nelz)
                CH[i,j] = 1/volume*sum(sum_XkX.*E)
            end
        end
    end
    return CH

end

"""
    evaluateCH(CH, dens, outOption, dispFlag)

Derive effective engineering constants from the homogenized elasticity matrix `CH`.

The compliance `SH` is the SVD-based pseudo-inverse of `CH` (singular values not
larger than `1e-15` are dropped). From it the effective Young's moduli `EH`
(entries 1:3 of the diagonal), shear moduli `GH` (entries 4:6 of the diagonal) and
Poisson's ratios `vH` (2x3; row 1 = v12, v13, v23, row 2 = v21, v31, v32) are formed.

# Arguments
- `CH`: 6x6 elasticity matrix.
- `dens`: relative density; only passed through to the output and the printout.
- `outOption`: `"struct"` returns a `Dict` with keys `"CH"`, `"SH"`, `"EH"`, `"GH"`,
  `"vH"`, `"density"`; `"vec"` returns the vector `[EH, GH, vH[:]', dens]`.
- `dispFlag`: print the summary table when non-zero (or `true`).

# Returns
`(props, SH)`.

# Throws
`ArgumentError` if `outOption` is neither `"struct"` nor `"vec"`.
"""
function evaluateCH(CH, dens, outOption, dispFlag)
    # Reject unknown formats up front instead of failing on an undefined `props`
    # at the return statement.
    if !(outOption == "struct" || outOption == "vec")
        throw(ArgumentError("outOption must be \"struct\" or \"vec\", got $(repr(outOption))"))
    end

    U, S, V = svd(CH)
    sigma = S
    k = sum(sigma .> 1e-15)
    SH = (U[:,1:k] * diagm(0=>(1 ./sigma[1:k])) * V[:,1:k]')' # more stable SVD (pseudo)inverse
    EH = [1/SH[1,1], 1/SH[2,2], 1/SH[3,3]] # effective Young's modulus
    GH = [1/SH[4,4], 1/SH[5,5], 1/SH[6,6]] # effective shear modulus
    vH = [-SH[2,1]/SH[1,1]  -SH[3,1]/SH[1,1]  -SH[3,2]/SH[2,2];
         -SH[1,2]/SH[2,2]  -SH[1,3]/SH[3,3]  -SH[2,3]/SH[3,3]] # effective Poisson's ratio

    if outOption == "struct"
        props = Dict("CH"=>CH, "SH"=>SH, "EH"=>EH, "GH"=>GH, "vH"=>vH, "density"=>dens)
    else
        # NOTE(review): this builds a 4-element vector of arrays, not a flat numeric
        # row; kept unchanged for callers -- confirm whether a flat vector is wanted.
        props = [EH, GH, vH[:]', dens]
    end

    # Honour the caller's display flag (the table used to be printed always).
    if dispFlag != 0
        println("\n--------------------------EFFECTIVE PROPERTIES--------------------------\n")
        println("Density: $dens")
        println("Youngs Modulus:____E11_____|____E22_____|____E33_____\n")
        println("               $(EH[1]) | $(EH[2]) | $(EH[3])\n\n")
        # Rows/columns 4, 5, 6 of CH follow the shear ordering 12, 23, 13 used by
        # homogAsymp3D, so the header is labelled in that order.
        println("Shear Modulus:_____G12_____|____G23_____|____G31_____\n")
        println("               $(GH[1]) | $(GH[2]) | $(GH[3])\n\n")
        println("Poissons Ratio:____v12_____|____v13_____|____v23_____\n")
        println("               $(vH[1,1]) | $(vH[1,2]) | $(vH[1,3])\n\n")
        println("               ____v21_____|____v31_____|____v32_____\n")
        println("               $(vH[2,1]) | $(vH[2,2]) | $(vH[2,3])\n\n")
        println("------------------------------------------------------------------------")
    end

    return props, SH
end


## SUB FUNCTION: elementMatVec3D
"""
    elementMatVec3D(a, b, c, DH)

Integrate the 24x24 stiffness matrix of an 8-node hexahedral element with half-edge
lengths `a`, `b`, `c` and 6x6 constitutive matrix `DH`, using a 2x2x2 Gauss rule.

Degrees of freedom are ordered node by node (three per node); the strain rows are
ordered xx, yy, zz, xy, yz, xz.

Returns the element stiffness matrix `Ke`.
"""
function elementMatVec3D(a, b, c, DH)
    GN_x = [-1/sqrt(3), 1/sqrt(3)]; GN_y = GN_x; GN_z = GN_x; GaussWeigh = [1, 1]
    # corner coordinates of the element, one row per node
    corners = [-a a a -a -a a a -a; -b -b b b -b -b b b; -c -c -c -c c c c c]'

    Ke = zeros(24, 24)
    # L maps the 9 displacement-gradient components to the 6 strain components
    L = zeros(6, 9)
    L[1,1] = 1; L[2,5] = 1; L[3,9] = 1
    L[4,2] = 1; L[4,4] = 1; L[5,6] = 1
    L[5,8] = 1; L[6,3] = 1; L[6,7] = 1

    for ii in eachindex(GN_x), jj in eachindex(GN_y), kk in eachindex(GN_z)
        x = GN_x[ii]; y = GN_y[jj]; z = GN_z[kk]
        # local derivatives of the eight shape functions
        dNx = 1/8*[-(1-y)*(1-z), (1-y)*(1-z), (1+y)*(1-z), -(1+y)*(1-z), -(1-y)*(1+z), (1-y)*(1+z), (1+y)*(1+z), -(1+y)*(1+z)]
        dNy = 1/8*[-(1-x)*(1-z), -(1+x)*(1-z), (1+x)*(1-z), (1-x)*(1-z), -(1-x)*(1+z), -(1+x)*(1+z), (1+x)*(1+z), (1-x)*(1+z)]
        dNz = 1/8*[-(1-x)*(1-y), -(1+x)*(1-y), -(1+x)*(1+y), -(1-x)*(1+y), (1-x)*(1-y), (1+x)*(1-y), (1+x)*(1+y), (1-x)*(1+y)]
        J = [dNx dNy dNz]' * corners
        invJ = inv(J)
        # Block-diagonal 9x9 matrix with inv(J) repeated three times. The blocks
        # must be 3x3: in Julia `zeros(3)` is a length-3 vector, which made G 9x5
        # and the product L*G*dN fail with a DimensionMismatch.
        G = zeros(9, 9)
        G[1:3,1:3] = invJ
        G[4:6,4:6] = invJ
        G[7:9,7:9] = invJ
        dN = zeros(9, 24)
        dN[1,1:3:24] = dNx; dN[2,1:3:24] = dNy; dN[3,1:3:24] = dNz
        dN[4,2:3:24] = dNx; dN[5,2:3:24] = dNy; dN[6,2:3:24] = dNz
        dN[7,3:3:24] = dNx; dN[8,3:3:24] = dNy; dN[9,3:3:24] = dNz
        Be = L*G*dN
        Ke = Ke + GaussWeigh[ii]*GaussWeigh[jj]*GaussWeigh[kk]*det(J)*(Be'*DH*Be)
    end
    return Ke
end