{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra\n",
    "using Plots\n",
    "import JSON\n",
    "# using Quaternions\n",
    "using StaticArrays, Rotations\n",
    "using Distributed\n",
    "using StaticArrays, BenchmarkTools\n",
    "using Base.Threads\n",
    "using CUDAnative\n",
    "using CuArrays,CUDAdrv \n",
    "using Test\n",
    "import Base: +, * , -, ^\n",
    "# BASED ON https://github.com/jonhiller/Voxelyze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct Vector3\n",
    "    x::Float32\n",
    "    y::Float32\n",
    "    z::Float32\n",
    "    function Vector3()\n",
    "        x=0.0\n",
    "        y=0.0\n",
    "        z=0.0\n",
    "        new(x,y,z)\n",
    "    end\n",
    "    function Vector3(x,y,z)\n",
    "       new(x,y,z)\n",
    "    end\n",
    "end\n",
    "struct Quaternion\n",
    "    x::Float32\n",
    "    y::Float32\n",
    "    z::Float32\n",
    "    w::Float32\n",
    "    function Quaternion()\n",
    "        x=0.0\n",
    "        y=0.0\n",
    "        z=0.0\n",
    "        w=1.0\n",
    "        new(x,y,z,w)\n",
    "    end\n",
    "    function Quaternion(x,y,z,w)\n",
    "        new(x,y,z,w)\n",
    "    end\n",
    "end\n",
    "\n",
    "+(f::Vector3, g::Vector3)=Vector3(f.x+g.x , f.y+g.y,f.z+g.z )\n",
    "-(f::Vector3, g::Vector3)=Vector3(f.x-g.x , f.y-g.y,f.z-g.z )\n",
    "*(f::Vector3, g::Vector3)=Vector3(f.x*g.x , f.y*g.y,f.z*g.z )\n",
    "\n",
    "+(f::Vector3, g::Number)=Vector3(f.x+g , f.y+g,f.z+g )\n",
    "-(f::Vector3, g::Number)=Vector3(f.x-g , f.y-g,f.z-g )\n",
    "*(f::Vector3, g::Number)=Vector3(f.x*g , f.y*g,f.z*g )\n",
    "\n",
    "+(g::Vector3, f::Number)=Vector3(f.x+g , f.y+g,f.z+g )\n",
    "-(g::Vector3, f::Number)=Vector3(g-f.x , g-f.y,g-f.z )\n",
    "*(g::Vector3, f::Number)=Vector3(f.x*g , f.y*g,f.z*g )\n",
    "\n",
    "addX(f::Vector3, g::Number)=Vector3(f.x+g , f.y,f.z)\n",
    "addY(f::Vector3, g::Number)=Vector3(f.x , f.y+g,f.z )\n",
    "addZ(f::Vector3, g::Number)=Vector3(f.x , f.y,f.z+g )\n",
    "\n",
    "function Base.show(io::IO, v::Vector3)\n",
    "    print(io, \"x:$(v.x), y:$(v.y), z:$(v.z)\")\n",
    "end\n",
    "\n",
    "function Base.show(io::IO, v::Quaternion)\n",
    "    print(io, \"x:$(v.x), y:$(v.y), z:$(v.z), w:$(v.z)\")\n",
    "end\n",
    "\n",
    "Base.Broadcast.broadcastable(q::Vector3) = Ref(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct Node\n",
    "    id::Int32\n",
    "    position::Vector3\n",
    "    restrained::Bool\n",
    "    displacement::Vector3\n",
    "    angle::Vector3\n",
    "    force::Vector3\n",
    "    currPosition::Vector3\n",
    "    orient::Quaternion\n",
    "    linMom::Vector3\n",
    "    angMom::Vector3\n",
    "    intForce::Vector3\n",
    "    intMoment::Vector3\n",
    "    moment::Vector3\n",
    "    function Node()\n",
    "        id=0\n",
    "        position=Vector3()\n",
    "        restrained=false\n",
    "        displacement=Vector3()\n",
    "        angle=Vector3()\n",
    "        force=Vector3()\n",
    "        currPosition=Vector3()\n",
    "        orient=Quaternion()\n",
    "        linMom=Vector3()\n",
    "        angMom=Vector3()\n",
    "        intForce=Vector3()\n",
    "        intMoment=Vector3()\n",
    "        moment=Vector3()\n",
    "        new(id,position,restrained,displacement,angle,force,currPosition,orient,linMom,angMom,intForce,intMoment,moment)\n",
    "    end\n",
    "end\n",
    "struct Edge\n",
    "    id::Int32\n",
    "    source::Int32 #change to Int32\n",
    "    target::Int32\n",
    "    area::Float32\n",
    "    density::Float32\n",
    "    stiffness::Float32\n",
    "    stress::Float32\n",
    "    axis::Vector3\n",
    "    currentRestLength::Float32\n",
    "    pos2::Vector3\n",
    "    angle1v::Vector3\n",
    "    angle2v::Vector3\n",
    "    angle1::Quaternion\n",
    "    angle2::Quaternion\n",
    "    currentTransverseStrainSum::Float32\n",
    "    ## add pos node and negative node\n",
    "    ## add memory cuda??\n",
    "    function Edge()\n",
    "        id=0\n",
    "        source=0\n",
    "        target=0\n",
    "        area=0.0\n",
    "        density=0.0\n",
    "        stiffness=0.0\n",
    "        stress=0.0\n",
    "        axis=Vector3(1.0,0.0,0.0)\n",
    "        currentRestLength=0.0\n",
    "        pos2=Vector3()\n",
    "        angle1v=Vector3()\n",
    "        angle2v=Vector3()\n",
    "        angle1=Quaternion()\n",
    "        angle2=Quaternion()\n",
    "        currentTransverseStrainSum=0.0\n",
    "        \n",
    "        new(id,source,target,area,density,stiffness,stress,axis,currentRestLength,pos2,angle1v,angle2v,angle1,angle2,currentTransverseStrainSum)\n",
    "    end\n",
    "end\n",
    "\n",
    "function Base.show(io::IO, v::Node)\n",
    "    print(io, \"node:$(v.id), position:($(v.position)), restrained:$(v.restrained)\")\n",
    "end\n",
    "\n",
    "function Base.show(io::IO, v::Edge)\n",
    "    print(io, \"edge:$(v.id), source:$(v.source), target:$(v.target), stress:$(v.stress), axis:($(v.axis))\")\n",
    "end\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# allowscalar(false)\n",
    "m=Vector3(1.0,1.0,1.0)\n",
    "mg=CuArray([m,m,m])\n",
    "mm=mg.+m\n",
    "# broadcast(+, mg, m)\n",
    "# broadcast(addX, mg, 1)\n",
    "\n",
    "# m=node\n",
    "# println(m)\n",
    "m=Node()\n",
    "mg=CuArray([m,m,m])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################################\n",
    "setup = Dict()\n",
    "open(\"../json/setupTest.json\", \"r\") do f\n",
    "    global setup\n",
    "    dicttxt = String(read(f))  # file information to string\n",
    "    setup=JSON.parse(dicttxt)  # parse and transform data\n",
    "end\n",
    "\n",
    "setup=setup[\"setup\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function initialize(setup)\n",
    "\tnodes      = setup[\"nodes\"]\n",
    "    edges      = setup[\"edges\"]\n",
    "    \n",
    "    i=1\n",
    "\t# pre-calculate current position\n",
    "\tfor node in nodes\n",
    "        # element=parse(Int,node[\"id\"][2:end])\n",
    "        N_position[i,:]=[node[\"position\"][\"x\"] node[\"position\"][\"y\"] node[\"position\"][\"z\"]]\n",
    "        \n",
    "        N_positionV[i]=Vector3(node[\"position\"][\"x\"],node[\"position\"][\"y\"],node[\"position\"][\"z\"])\n",
    "        \n",
    "        append!(N_degrees_of_freedom,[node[\"degrees_of_freedom\"]])\n",
    "        N_restrained_degrees_of_freedom[i,:]=node[\"restrained_degrees_of_freedom\"]\n",
    "        N_restrained[i]=node[\"restrained_degrees_of_freedom\"][1]\n",
    "        append!(N_displacement,[[node[\"displacement\"][\"x\"] node[\"displacement\"][\"y\"] node[\"displacement\"][\"z\"]]])\n",
    "        append!(N_angle,[[node[\"angle\"][\"x\"] node[\"angle\"][\"y\"] node[\"angle\"][\"z\"]]])\n",
    "        append!(N_force,[[node[\"force\"][\"x\"] node[\"force\"][\"y\"] node[\"force\"][\"z\"]]])\n",
    "        append!(N_currPosition,[[node[\"position\"][\"x\"] node[\"position\"][\"y\"] node[\"position\"][\"z\"]]])\n",
    "        append!(N_orient,[Quat(1.0,0.0,0.0,0.0)])#quat\n",
    "        append!(N_linMom,[[0 0 0]])\n",
    "        append!(N_angMom,[[0 0 0]])\n",
    "        append!(N_intForce,[[0 0 0]])\n",
    "        append!(N_intMoment,[[0 0 0]])\n",
    "        append!(N_moment,[[0 0 0]])\n",
    "        \n",
    "        # for dynamic simulations\n",
    "        append!(N_posTimeSteps,[[]])\n",
    "        append!(N_angTimeSteps,[[]])\n",
    "        \n",
    "        i=i+1\n",
    " \n",
    "\tend \n",
    "    \n",
    "    i=1\n",
    "\t# pre-calculate the axis\n",
    "\tfor edge in edges\n",
    "        # element=parse(Int,edge[\"id\"][2:end])\n",
    "        \n",
    "        # find the nodes that the lements connects\n",
    "        fromNode = nodes[edge[\"source\"]+1]\n",
    "        toNode = nodes[edge[\"target\"]+1]\n",
    "\n",
    "        \n",
    "        node1 = [fromNode[\"position\"][\"x\"] fromNode[\"position\"][\"y\"] fromNode[\"position\"][\"z\"]]\n",
    "        node2 = [toNode[\"position\"][\"x\"] toNode[\"position\"][\"y\"] toNode[\"position\"][\"z\"]]\n",
    "        \n",
    "        length=norm(node2-node1)\n",
    "        axis=normalize(collect(Iterators.flatten(node2-node1)))\n",
    "        \n",
    "        append!(E_source,[edge[\"source\"]+1])\n",
    "        append!(E_target,[edge[\"target\"]+1])\n",
    "        append!(E_area,[edge[\"area\"]])\n",
    "        append!(E_density,[edge[\"density\"]])\n",
    "        append!(E_stiffness,[edge[\"stiffness\"]])\n",
    "        append!(E_stress,[0])\n",
    "        append!(E_axis,[axis])\n",
    "        append!(E_currentRestLength,[length])\n",
    "        append!(E_pos2,[[0 0 0]])\n",
    "        append!(E_angle1v,[[0 0 0]])\n",
    "        append!(E_angle2v,[[0 0 0]])\n",
    "        append!(E_angle1,[Quat(1.0,0,0,0)]) #quat\n",
    "        append!(E_angle2,[Quat(1.0,0,0,0)]) #quat\n",
    "        append!(E_currentTransverseStrainSum,[0])\n",
    "        \n",
    "        # for dynamic simulations\n",
    "        append!(E_stressTimeSteps,[[]])\n",
    "        \n",
    "        i=i+1\n",
    "\tend \n",
    "\t\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "########\n",
    "voxCount=0\n",
    "linkCount=0\n",
    "nodes      = setup[\"nodes\"]\n",
    "edges      = setup[\"edges\"]\n",
    "voxCount=size(nodes)[1]\n",
    "linkCount=size(edges)[1]\n",
    "strain =0 #todooo moveeee\n",
    "\n",
    "############# nodes\n",
    "N_position=fill(Vector3(),voxCount)\n",
    "N_restrained=zeros(Bool, voxCount)\n",
    "N_displacement=fill(Vector3(),voxCount)\n",
    "N_angle=fill(Vector3(),voxCount)\n",
    "N_currPosition=fill(Vector3(),voxCount)\n",
    "N_linMom=fill(Vector3(),voxCount)\n",
    "N_angMom=fill(Vector3(),voxCount)\n",
    "N_intForce=fill(Vector3(),voxCount)\n",
    "N_intMoment=fill(Vector3(),voxCount)\n",
    "N_moment=fill(Vector3(),voxCount)\n",
    "# N_posTimeSteps=[]\n",
    "# N_angTimeSteps=[]\n",
    "N_force=fill(Vector3(),voxCount)\n",
    "N_orient=fill(Quaternion(),voxCount)\n",
    "\n",
    "\n",
    "\n",
    "############# edges\n",
    "E_source=fill(0,linkCount)\n",
    "E_target=fill(0,linkCount)\n",
    "E_area=fill(0.0F0,linkCount)\n",
    "E_density=fill(0.0F0,linkCount)\n",
    "E_stiffness=fill(0.0F0,linkCount)\n",
    "E_stress=fill(0.0F0,linkCount)\n",
    "E_axis=fill(Vector3(1.0,0.0,0.0),linkCount)\n",
    "E_currentRestLength=fill(0.0F0,linkCount)\n",
    "E_pos2=fill(Vector3(),linkCount)\n",
    "E_angle1v=fill(Vector3(),linkCount)\n",
    "E_angle2v=fill(Vector3(),linkCount)\n",
    "E_angle1=fill(Quaternion(),voxCount)\n",
    "E_angle2=fill(Quaternion(),voxCount)\n",
    "E_currentTransverseStrainSum=fill(0.0F0,linkCount)# TODO remove ot incorporate\n",
    "# E_stressTimeSteps=[]\n",
    "\n",
    "\n",
    "\n",
    "initialize(setup)\n",
    "N_position\n",
    "N_positionGPU=CuArray(N_position)\n",
    "N_restrained_degrees_of_freedomGPU=CuArray(N_restrained_degrees_of_freedom)\n",
    "N_restrainedGPU=CuArray(N_restrained)\n",
    "N_positionVGPU=CuArray(N_positionV)\n",
    "# N_positionV .= ifelse.(N_restrained .==true, N_positionV, N_positionV .+ 1 )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function trialGPU!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "    index = (blockIdx().x - 1) * blockDim().x + threadIdx().x\n",
    "    stride = blockDim().x * gridDim().x\n",
    "    for i = index:stride:length(vecGPU)\n",
    "        @inbounds  N_positionVGPU[i]=ifelse(N_restrainedGPU[i], N_positionVGPU[i], N_positionVGPU[i] +vecGPU[i] )\n",
    "    end\n",
    "    return\n",
    "end\n",
    "\n",
    "# function trialGPU!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "#     i = (blockIdx().x-1) * blockDim().x + threadIdx().x\n",
    "#     N_positionVGPU[i]=ifelse(N_restrainedGPU[i], N_positionVGPU[i], N_positionVGPU[i] +vecGPU[i] )\n",
    "#     return\n",
    "# end\n",
    "\n",
    "function bench_gpu!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "    N=length(vecGPU)\n",
    "    numblocks = ceil(Int, N/256)\n",
    "    CuArrays.@sync begin\n",
    "        @cuda threads=256 blocks=numblocks trialGPU!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "    end\n",
    "end\n",
    "\n",
    "# @cuda threads=voxCount trialGPU(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "# N_positionVGPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function trialCPU!(N_positionV, N_restrained, vec)\n",
    "    for i in 1:length(vec)\n",
    "        @inbounds N_positionV[i]=ifelse(N_restrained[i], N_positionV[i], N_positionV[i] +vec[i] )\n",
    "    end\n",
    "    return\n",
    "end\n",
    "\n",
    "# trialCPU!(N_positionV, N_restrained, vec)\n",
    "# N_positionV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "voxCount=2^12\n",
    "\n",
    "############# nodes\n",
    "N_positionV=fill(Vector3(0,0,0),voxCount)\n",
    "N_restrained=zeros(Bool, voxCount)\n",
    "\n",
    "N_restrainedGPU=CuArray(N_restrained)\n",
    "N_positionVGPU=CuArray(N_positionV)\n",
    "\n",
    "vec = fill(Vector3(1,1,1),voxCount)\n",
    "vecGPU = CuArray(vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @btime bench_gpu!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "@btime bench_gpu!(N_positionVGPU, N_restrainedGPU, vecGPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@btime trialCPU!(N_positionV, N_restrained, vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num=[2^8,2^9,2^10]\n",
    "resulstGPU=[34.89e-6,35.6e-6,37.3e-6]\n",
    "resulstCPU=[347.437e-9,676.6e-9,1.38e-6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function bench()\n",
    "    num=[]\n",
    "    resultsGPU=[]\n",
    "    resultsCPU=[]\n",
    "    for i in 8:25\n",
    "        voxCount=2^i\n",
    "\n",
    "        ############# nodes\n",
    "        N_positionV=fill(Vector3(0,0,0),voxCount)\n",
    "        N_restrained=zeros(Bool, voxCount)\n",
    "\n",
    "        N_restrainedGPU=CuArray(N_restrained)\n",
    "        N_positionVGPU=CuArray(N_positionV)\n",
    "\n",
    "        vec = fill(Vector3(1,1,1),voxCount)\n",
    "        vecGPU = CuArray(vec)\n",
    "        \n",
    "        res=@timed bench_gpu!(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "        res1=@timed trialCPU!(N_positionV, N_restrained, vec)\n",
    "        \n",
    "        append!(num,voxCount)\n",
    "        append!(resultsGPU,res[2])\n",
    "        append!(resultsCPU,res1[2])\n",
    "    end\n",
    "    return num,resultsGPU,resultsCPU\n",
    "end\n",
    "num,resultsGPU,resultsCPU=bench()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot(num,[resultsGPU resultsCPU],label = [\"GPU\" \"CPU\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N=100\n",
    "function repeatGPU(N)\n",
    "    for i in 1:N\n",
    "        @cuda threads=voxCount trialGPU(N_positionVGPU, N_restrainedGPU, vecGPU)\n",
    "    end\n",
    "end\n",
    "@btime repeatGPU(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function repeatCPU(N)\n",
    "    for i in 1:N\n",
    "        trialCPU!(N_positionV, N_restrained, vec)\n",
    "    end\n",
    "end\n",
    "@btime repeatCPU(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_positionV"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.2.0",
   "language": "julia",
   "name": "julia-1.2"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.2.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}