# Amira Abdel-Rahman
# (c) Massachusetts Institute of Technology 2020


# BASED ON https://github.com/jonhiller/Voxelyze

"""
    updateEdges!(dt, currentTimeStep, E_source, E_target, E_stress, E_axis,
                 E_currentRestLength, E_pos2, E_angle1v, E_angle2v, E_angle1, E_angle2,
                 E_intForce1, E_intMoment1, E_intForce2, E_intMoment2, E_damp, E_smallAngle,
                 E_material, E_strain, E_maxStrain, E_strainOffset, E_currentTransverseArea,
                 E_currentTransverseStrainSum, N_currPosition, N_orient, N_poissonStrain)

CUDA kernel that updates every edge (bond) for one time step (ported from Voxelyze,
see the file header).

For each edge `i`, joining node `E_source[i]` (negative end) to node `E_target[i]`
(positive end), it:
1. refreshes the rest length with `updateTemperature`, then re-expresses the bond in its
   local frame with `orientLink!` (writes `E_pos2`, `E_angle1v`, `E_angle2v`,
   `E_angle1`, `E_angle2`, `E_smallAngle`, `E_damp`);
2. updates `E_currentTransverseArea` / `E_currentTransverseStrainSum` when Poisson
   effects are involved;
3. writes `E_strain`, then `E_stress`, `E_maxStrain`, `E_strainOffset` via `updateStrain`;
4. writes the end forces/moments to `E_intForce1/2` and `E_intMoment1/2`, or zeroes them
   (and `E_stress[i]`) when the bond has failed.

Edges are distributed with a grid-stride loop, so any launch configuration covers all
`length(E_source)` edges. `dt` divides the material damping coefficient. Returns `nothing`.
"""
function updateEdges!(dt,currentTimeStep,E_source,E_target,E_stress,E_axis,
        E_currentRestLength,E_pos2,E_angle1v,E_angle2v,
        E_angle1,E_angle2,E_intForce1,E_intMoment1,E_intForce2,E_intMoment2,E_damp,E_smallAngle,E_material,
        E_strain,E_maxStrain,E_strainOffset,E_currentTransverseArea,E_currentTransverseStrainSum,
        N_currPosition,N_orient,N_poissonStrain)

    index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride = blockDim().x * gridDim().x

    N = length(E_source)
    half = Vector3(0.5,0.5,0.5)      # velocity at the bond centre is half the total velocity
    zeroVec = Vector3(0.0,0.0,0.0)
    prec = 1e16                      # rounding precision handed to roundd for section properties

    for i = index:stride:N
        @inbounds src  = E_source[i]
        @inbounds tgt  = E_target[i]
        @inbounds mat  = E_material[i]
        @inbounds axis = E_axis[i]

        @inbounds pVNeg = N_currPosition[src]
        @inbounds pVPos = N_currPosition[tgt]
        @inbounds oVNeg = N_orient[src]
        @inbounds oVPos = N_orient[tgt]

        # Remember last step's positions/angles so the local velocity used by damping can be derived.
        @inbounds oldPos2    = Vector3(E_pos2[i].x,E_pos2[i].y,E_pos2[i].z)
        @inbounds oldAngle1v = Vector3(E_angle1v[i].x,E_angle1v[i].y,E_angle1v[i].z)
        @inbounds oldAngle2v = Vector3(E_angle2v[i].x,E_angle2v[i].y,E_angle2v[i].z)

        # Refresh the rest length for this time step (see updateTemperature).
        @inbounds E_currentRestLength[i] = updateTemperature(E_currentRestLength[i],currentTimeStep,mat)
        @inbounds l = E_currentRestLength[i]

        # Re-express the bond in its local frame; the total rotation it also returns is not needed here.
        @inbounds E_pos2[i],E_angle1v[i],E_angle2v[i],E_angle1[i],E_angle2[i],_,E_smallAngle[i],E_damp[i] = orientLink!(i,l,pVNeg,pVPos,oVNeg,oVPos,axis,E_smallAngle[i],E_damp[i])
        @inbounds pos2  = E_pos2[i]
        @inbounds ang1v = E_angle1v[i]
        @inbounds ang2v = E_angle2v[i]

        # Deltas for local damping.
        dPos2   = half * (pos2-oldPos2)
        dAngle1 = half * (ang1v-oldAngle1v)
        dAngle2 = half * (ang2v-oldAngle2v)

        # Volume (Poisson) effects. A non-zero strain sum also triggers the update, which
        # catches Poisson's ratio being disabled mid-simulation.
        @inbounds if ((mat.poisson && mat.nu != 0.0) || E_currentTransverseStrainSum[i] != 0.0)
            @inbounds E_currentTransverseArea[i],E_currentTransverseStrainSum[i] = updateTransverseInfo(E_currentTransverseArea[i],E_currentTransverseStrainSum[i],mat,axis,N_poissonStrain[src],N_poissonStrain[tgt])
        end

        strain = pos2.x/l
        @inbounds E_strain[i] = strain

        # Cross-section dimensions and stiffness constants (must be floats).
        h  = roundd(mat.h,prec)
        b  = roundd(mat.b,prec)
        a2 = roundd(mat.a2,prec)
        b1 = roundd(mat.b1,prec)
        b2 = roundd(mat.b2,prec)
        b3 = roundd(mat.b3,prec)

        # Nominal area b*h, or the Poisson-updated transverse area when enabled.
        @inbounds currentTransverseArea = mat.poisson ? E_currentTransverseArea[i] : b*h

        @inbounds E_stress[i],E_maxStrain[i],E_strainOffset[i] = updateStrain(strain,E_maxStrain[i],E_strainOffset[i],mat,E_currentTransverseStrainSum[i])
        @inbounds failed = isFailed(E_maxStrain[i],mat)

        if failed
            # Broken bond: no stress, force or moment. Use `continue` (not `return`) so this
            # thread still processes the remaining edges of its grid-stride loop.
            E_stress[i]     = 0.0
            E_intForce1[i]  = zeroVec
            E_intForce2[i]  = zeroVec
            E_intMoment1[i] = zeroVec
            E_intMoment2[i] = zeroVec
            continue
        end

        # Axial force uses the current stress (instead of -a1*pos2.x) to account for non-linear deformation.
        @inbounds axial = convert(Float64, E_stress[i]*currentTransverseArea)
        shearY = convert(Float64, b1*pos2.y - b2*(ang1v.z + ang2v.z))
        shearZ = convert(Float64, b1*pos2.z + b2*(ang1v.y + ang2v.y))
        forceNeg = Vector3(axial, shearY, shearZ)
        forcePos = Vector3(-axial, -shearY, -shearZ)

        momentNeg = Vector3(convert(Float64, a2*(ang2v.x-ang1v.x)),
                            convert(Float64, -b2*pos2.z-b3*(2.0*ang1v.y+ang2v.y)),
                            convert(Float64, b2*pos2.y - b3*(2.0*ang1v.z + ang2v.z)))
        momentPos = Vector3(convert(Float64, a2*(ang1v.x-ang2v.x)),
                            convert(Float64, -b2*pos2.z- b3*(ang1v.y+2.0*ang2v.y)),
                            convert(Float64, b2*pos2.y - b3*(ang1v.z + 2.0*ang2v.z)))

        ### damping
        @inbounds if E_damp[i]
            sqA1     = convert(Float64,mat.sqA1)
            sqA2xIp  = convert(Float64,mat.sqA2xIp)
            sqB1     = convert(Float64,mat.sqB1)
            sqB2xFMp = convert(Float64,mat.sqB2xFMp)
            sqB3xIp  = convert(Float64,mat.sqB3xIp)

            # Damping coefficient: material dampingM divided by the time step.
            dampingM = convert(Float64,mat.dampingM)/dt
            dampingMultiplier = Vector3(dampingM,dampingM,dampingM)

            posCalc = Vector3(sqA1*dPos2.x,
                              sqB1*dPos2.y - sqB2xFMp*(dAngle1.z+dAngle2.z),
                              sqB1*dPos2.z + sqB2xFMp*(dAngle1.y+dAngle2.y))
            dampForce = dampingMultiplier*posCalc
            forceNeg = forceNeg + dampForce
            forcePos = forcePos - dampForce

            momentNeg -= half*dampingMultiplier*Vector3(-sqA2xIp*(dAngle2.x - dAngle1.x),
                                                         sqB2xFMp*dPos2.z + sqB3xIp*(2*dAngle1.y + dAngle2.y),
                                                         -sqB2xFMp*dPos2.y + sqB3xIp*(2*dAngle1.z + dAngle2.z))
            momentPos -= half*dampingMultiplier*Vector3(sqA2xIp*(dAngle2.x - dAngle1.x),
                                                         sqB2xFMp*dPos2.z + sqB3xIp*(dAngle1.y + 2*dAngle2.y),
                                                         -sqB2xFMp*dPos2.y + sqB3xIp*(dAngle1.z + 2*dAngle2.z))
        else
            # Damping is skipped on the first pass, and whenever orientLink! clears E_damp
            # after switching between small- and large-angle mode; re-enable it for the next step.
            E_damp[i] = true
        end

        # Rotate back out of the bond-local frame. In small-angle mode orientLink! leaves
        # angle1 as the identity, so the negative end needs no rotation.
        @inbounds if !E_smallAngle[i]
            forceNeg  = RotateVec3DInv(E_angle1[i],forceNeg)
            momentNeg = RotateVec3DInv(E_angle1[i],momentNeg)
        end
        @inbounds forcePos  = RotateVec3DInv(E_angle2[i],forcePos)
        @inbounds momentPos = RotateVec3DInv(E_angle2[i],momentPos)

        forceNeg  = toAxisOriginalVector3(forceNeg,axis)
        forcePos  = toAxisOriginalVector3(forcePos,axis)
        momentNeg = toAxisOriginalQuat(momentNeg,axis) # TODO (from original): check toAxisOriginalQuat is right for moments
        momentPos = toAxisOriginalQuat(momentPos,axis)

        @inbounds E_intForce1[i]  = forceNeg
        @inbounds E_intForce2[i]  = forcePos
        @inbounds E_intMoment1[i] = Vector3(convert(Float64,momentNeg.x),convert(Float64,momentNeg.y),convert(Float64,momentNeg.z))
        @inbounds E_intMoment2[i] = Vector3(convert(Float64,momentPos.x),convert(Float64,momentPos.y),convert(Float64,momentPos.z))
    end

    return
end

"""
    run_updateEdges!(dt, currentTimeStep, E_source, E_target, E_stress, E_axis, ...)

Launch the `updateEdges!` kernel over all edges with 256 threads per block and wait for
it to finish (`CUDA.@sync`). All arguments are forwarded unchanged.

If there are no edges (`length(E_source) == 0`), nothing is launched and `nothing` is
returned. A launch with zero blocks is rejected by CUDA.
"""
function run_updateEdges!(dt,currentTimeStep,E_source,E_target,
        E_stress,E_axis,E_currentRestLength,E_pos2,E_angle1v,E_angle2v,
        E_angle1,E_angle2,E_intForce1,E_intMoment1,E_intForce2,E_intMoment2,
        E_damp,E_smallAngle,E_material,
        E_strain,E_maxStrain,E_strainOffset,E_currentTransverseArea,E_currentTransverseStrainSum,
        N_currPosition,N_orient,N_poissonStrain)
    N = length(E_source)
    N == 0 && return nothing

    nthreads = 256
    numblocks = cld(N, nthreads)   # enough blocks to give every edge a thread
    return CUDA.@sync begin
        @cuda threads=nthreads blocks=numblocks updateEdges!(dt,currentTimeStep,E_source,E_target,E_stress,E_axis,E_currentRestLength,E_pos2,E_angle1v,
            E_angle2v,E_angle1,E_angle2,E_intForce1,E_intMoment1,E_intForce2,
            E_intMoment2,E_damp,E_smallAngle,E_material,
            E_strain,E_maxStrain,E_strainOffset,E_currentTransverseArea,E_currentTransverseStrainSum,
            N_currPosition,N_orient,N_poissonStrain)
    end
end

"""
    orientLink!(i, currentRestLength, pVNeg, pVPos, oVNeg, oVPos, axis, smallAngle, damp)

Express a bond (negative end at `pVNeg`/`oVNeg`, positive end at `pVPos`/`oVPos`) in its
local frame along `axis`, with the negative end's orientation rotated to the identity.

Returns `(pos2, angle1v, angle2v, angle1, angle2, totalRot, smallAngle, damp)`:
- `pos2`: relative position of the positive end, with the rest length subtracted from x
  (small-angle mode), or `(length - restLength, 0, 0)` (large-angle mode); rounded with
  `roundd`.
- `angle1`, `angle2`: end orientations in the local frame; `angle1v`, `angle2v` are their
  `ToRotationVector` forms.
- `totalRot`: the accumulated rotation applied to bring the bond into that frame.
- `smallAngle`: updated small-angle flag, switched with hysteresis.
- `damp`: forced to `false` whenever `smallAngle` switches, otherwise passed through.

`i` is accepted but unused. Despite the `!`, nothing is mutated; results are returned.
"""
function orientLink!(i,currentRestLength,pVNeg,pVPos,oVNeg,oVPos,axis,smallAngle,damp)
    # Small-angle thresholds (TODO: change based on scale).
    hysteresis = 1.2
    bendLimit  = 0.05
    extLimit   = 0.50

    # Move everything into the bond's local X-axis frame (digit truncation happens here).
    rel = toAxisXVector3(pVPos - pVNeg, axis)
    q1  = toAxisXQuat(oVNeg, axis)
    q2  = toAxisXQuat(oVPos, axis)

    # Undo the negative end's orientation so that it becomes the identity.
    rot = conjugate(q1)
    rel = RotateVec3D(rot, rel)
    q2  = multiplyQuaternions(rot, q2)
    q1  = Quaternion(0.0,0.0,0.0,1.0)

    # Decide on the small-angle approximation, with hysteresis to avoid flip-flopping.
    turn = (abs(rel.z) + abs(rel.y)) / rel.x
    ext  = abs(1.0 - rel.x / currentRestLength)
    if !smallAngle && turn < bendLimit && ext < extLimit
        smallAngle, damp = true, false
    elseif smallAngle && (turn > hysteresis*bendLimit || ext > hysteresis*extLimit)
        smallAngle, damp = false, false
    end

    if smallAngle
        # Angle1 stays the identity; only the axial offset is removed.
        rel = Vector3(rel.x - currentRestLength, rel.y, rel.z)
    else
        # Large angle: rotate so rel lies on +X, then fold that rotation into totalRot and angle2.
        q1  = FromAngleToPosX(q1, rel)
        rot = multiplyQuaternions(q1, rot)
        q2  = multiplyQuaternions(q1, q2)
        rel = Vector3(lengthVector3(rel) - currentRestLength, 0.0, 0.0)
    end

    q1v = ToRotationVector(q1)
    q2v = ToRotationVector(q2)

    digits = 10e12
    rel = Vector3(roundd(rel.x, digits), roundd(rel.y, digits), roundd(rel.z, digits))

    return rel, q1v, q2v, q1, q2, rot, smallAngle, damp
end

###################################
"""
    isFailed(strain, mat)

Return `true` when `mat` defines a failure strain (`mat.epsilonFail`, where `-1.0` means
"no failure point") and `strain` is strictly greater than it.
"""
function isFailed(strain,mat)
    limit = mat.epsilonFail
    hasLimit = limit != -1.0
    return hasLimit && strain > limit
end

"""
    stress(strain, transverseStrainSum, mat)

Return the axial stress of material `mat` at `strain` (port of Voxelyze's material
stress function).

- Returns `0.0` once the material has failed (`isFailed`).
- Linear materials, compression and the first segment of a non-linear model use the linear
  law: `roundd(E)*strain`, or `eHat*((1-nu)*strain + nu*transverseStrainSum)` when Poisson
  effects are on (`mat.poisson` and `mat.nu != 0`).
- Otherwise the stress is interpolated piecewise-linearly on (`mat.strainData`,
  `mat.stressData`), extrapolating the last segment.

Assumes `mat.strainData[1]` is the origin (as in Voxelyze's model data), so the first
segment ends at `mat.strainData[2]`.
"""
function stress(strain, transverseStrainSum,mat)
    #reference: http://www.colorado.edu/engineering/CAS/courses.d/Structures.d/IAST.Lect05.d/IAST.Lect05.pdf page 10
    if isFailed(strain,mat)
        return 0.0 # a failure point is set and exceeded: we've broken!
    end

    poissonActive = mat.poisson && mat.nu != 0.0
    DataCount = length(mat.strainData)

    # Compression / first segment / linear materials: simple calculation.
    # Index 2 (not 1): strainData[1] is the origin, so strains inside the first segment
    # must be handled here. The segment loop below would otherwise extrapolate the second
    # segment backwards. `mat.linear` and the length check come first so short data
    # arrays are never indexed out of bounds.
    if mat.linear || DataCount < 2 || strain <= mat.strainData[2]
        if !poissonActive
            prec=10e8 # rounding precision handed to roundd
            return roundd(mat.E,prec)*strain
        else
            return mat.eHat*((1.0-mat.nu)*strain + mat.nu*transverseStrainSum)
        end
    end

    # Non-linear model (combined with a non-zero Poisson's ratio this is experimental).
    # Walk the segments after the first one; the last segment is extrapolated.
    for i = 3:DataCount
        if (strain <= mat.strainData[i] || i==DataCount)
            Perc = (strain-mat.strainData[i-1])/(mat.strainData[i]-mat.strainData[i-1])
            basicStress = mat.stressData[i-1] + Perc*(mat.stressData[i]-mat.stressData[i-1])
            if !poissonActive
                return basicStress
            else # accounting for volumetric effects
                modulus = (mat.stressData[i]-mat.stressData[i-1])/(mat.strainData[i]-mat.strainData[i-1])
                modulusHat = modulus/((1.0-2.0*mat.nu)*(1.0+mat.nu))
                effectiveStrain = basicStress/modulus # strain at which a linear law of this modulus would reach this point
                effectiveTransverseStrainSum = transverseStrainSum*(effectiveStrain/strain)
                return modulusHat*((1.0-mat.nu)*effectiveStrain + mat.nu*effectiveTransverseStrainSum)
            end
        end
    end

    # Not reached when the model has three or more points (the last segment always returns).
    return 0.0
end

"""
    updateTransverseInfo(currentTransverseArea, currentTransverseStrainSum, mat, axis,
                         poissonsStrainNeg, poissonsStrainPos)

Return `(area, strainSum)` for a bond: the averages of `transverseArea` and
`transverseStrainSum` evaluated at the negative and positive end strains.

The incoming `currentTransverseArea` / `currentTransverseStrainSum` values are not read.
"""
function updateTransverseInfo(currentTransverseArea,currentTransverseStrainSum,mat,axis,poissonsStrainNeg,poissonsStrainPos)
    areaNeg = transverseArea(mat,axis,poissonsStrainNeg)
    areaPos = transverseArea(mat,axis,poissonsStrainPos)
    sumNeg  = transverseStrainSum(mat,axis,poissonsStrainNeg)
    sumPos  = transverseStrainSum(mat,axis,poissonsStrainPos)
    return 0.5*(areaNeg+areaPos), 0.5*(sumNeg+sumPos)
end

"""
    strainEnergy(mat, forceNeg, momentNeg, momentPos)

Return the elastic strain energy stored in a bond. It is the sum of four terms:
- tensile, from `forceNeg.x` and `mat.a1`;
- torsion, from `momentNeg.x` and `mat.a2`;
- bending about z and bending about y, each from both end moments and `mat.b3`.
"""
function strainEnergy(mat,forceNeg,momentNeg,momentPos)
    tensile = forceNeg.x*forceNeg.x/(2.0*mat.a1)
    torsion = momentNeg.x*momentNeg.x/(2.0*mat.a2)
    bendZ = (momentNeg.z*momentNeg.z - momentNeg.z*momentPos.z + momentPos.z*momentPos.z)/(3.0*mat.b3)
    bendY = (momentNeg.y*momentNeg.y - momentNeg.y*momentPos.y + momentPos.y*momentPos.y)/(3.0*mat.b3)
    return tensile + torsion + bendZ + bendY
end

"""
    updateStrain(axialStrain, maxStrain, strainOffset, mat, currentTransverseStrainSum)

Return `(stress, maxStrain, strainOffset)` for a bond at axial strain `axialStrain`
(port of Voxelyze's link strain update).

- Linear materials track the maximum strain and evaluate `stress` directly.
- Non-linear materials follow the stress-strain curve while loading into new territory
  (`axialStrain > maxStrain`) and record the plastic `strainOffset`.
- When backed off, they unload linearly, shifted by that offset.
"""
function updateStrain( axialStrain,maxStrain,strainOffset,mat,currentTransverseStrainSum)
    poissonActive = mat.poisson && mat.nu != 0.0

    if mat.linear
        if axialStrain > maxStrain
            maxStrain = axialStrain # remember this maximum for easy reference
        end
        return stress(axialStrain, currentTransverseStrainSum,mat),maxStrain,strainOffset
    end

    returnStress = 0.0
    if axialStrain > maxStrain # new territory on the stress/strain curve
        maxStrain = axialStrain
        returnStress = stress(axialStrain, currentTransverseStrainSum,mat)

        # Precalculate the strain offset for when we back off.
        if poissonActive
            strainOffset = maxStrain-stress(axialStrain, 0.0,mat)/(mat.eHat*(1.0-mat.nu))
        else
            strainOffset = maxStrain-returnStress/mat.E
        end
    else
        # Backed off a non-linear material: treat it as linear with the plastic strain offset.
        relativeStrain = axialStrain-strainOffset
        if poissonActive
            # Force the linear law (Voxelyze calls stress(..., forceLinear=true) here).
            # Calling stress() would put relativeStrain back on the non-linear curve.
            if isFailed(relativeStrain,mat)
                returnStress = 0.0
            else
                returnStress = mat.eHat*((1.0-mat.nu)*relativeStrain + mat.nu*currentTransverseStrainSum)
            end
        else
            returnStress = mat.E*relativeStrain
        end
    end
    return returnStress,maxStrain,strainOffset
end

"""
    transverseStrainSum(mat, axis, poissonsStrain)

Return the sum of the two Poisson strain components perpendicular to the bond axis.
The first non-zero component of `axis` selects the direction: for x, return
`poissonsStrain.y + poissonsStrain.z`; for y, `x + z`; for z, `x + y`.

Returns `0.0` when Poisson effects are off (`!mat.poisson` or `mat.nu == 0`) or when
`axis` is all zeros. The result is always a Float64, keeping the function type-stable
for GPU code.
"""
function transverseStrainSum( mat,axis,poissonsStrain)
    if (!mat.poisson || mat.nu == 0.0)
        return 0.0 # Float64, like the other branch (an Int literal made the return type a Union)
    end

    psVec = poissonsStrain

    val=0.0 #todo change for multiple degrees of freedom
    if (axis.x!=0.0)
        val=val+psVec.y+psVec.z
    elseif (axis.y!=0.0)
        val=val+psVec.x+psVec.z
    elseif (axis.z!=0.0)
        val=val+psVec.x+psVec.y
    end
    return val
end

"""
    transverseArea(mat, axis, poissonsStrain)

Return the transverse cross-sectional area of a bond, based on the square section
`mat.b * mat.b`.

When Poisson effects are on (`mat.poisson` and `mat.nu != 0`), the area is scaled by
`(1 + strain)` for each of the two strain components perpendicular to the first non-zero
component of `axis`.
"""
function transverseArea(mat,axis,poissonsStrain)
    side = mat.b # todo change later to nominal size (mat.nominalSize)
    nominal = side*side

    (mat.poisson && mat.nu != 0.0) || return nominal

    ps = poissonsStrain
    if axis.x != 0.0
        return nominal*(1.0+ps.y)*(1.0+ps.z)
    elseif axis.y != 0.0
        return nominal*(1.0+ps.x)*(1.0+ps.z)
    elseif axis.z != 0.0
        return nominal*(1.0+ps.x)*(1.0+ps.y)
    end
    return nominal
end



# function axialStiffness(pVNeg,pVPos,axis,mat,currentTransverseArea,strain,currentRestLength) 
#     if (mat.isXyzIndependent()) 
#         return mat.a1;
# 	else 
# 		# updateRestLength();
# 		updateTransverseInfo(pVNeg,pVPos,axis)

# 		return (mat.eHat*currentTransverseArea/((strain+1.0)*currentRestLength)); # _a1;
#     end
# end
###########################################################################