import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D


def set_axes_equal(ax, zscale=1.5):
    """Give a 3D axis a common x/y scale centred on its current limits.

    Matplotlib's ``ax.set_aspect('equal')`` and ``ax.axis('equal')`` do not
    work for 3D axes. This is one workaround: every axis is re-centred on the
    midpoint of its current limits, and x and y are given a half-range equal
    to half of the largest current range among x, y and z. The z half-range
    is that same value divided by ``zscale``. With the default
    ``zscale=1.5`` the z limits therefore span 1/1.5 of the x/y span. Pass
    ``zscale=1`` to get the same span on all three axes.

    Parameters
    ----------
    ax : a matplotlib 3D axis (anything providing ``get_/set_{x,y,z}lim3d``),
        e.g. as output from ``plt.gca()`` on a 3D figure.
    zscale : positive float, divisor applied to the z half-range.

    Raises
    ------
    ValueError
        If ``zscale`` is not strictly positive.
    """
    if not zscale > 0:
        raise ValueError("zscale must be positive, got {0!r}".format(zscale))

    # Rows are x, y, z; columns are (lower, upper).
    limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()], dtype=float)
    middles = limits.mean(axis=1)

    # The plot bounding box is a sphere in the sense of the infinity norm,
    # hence half the largest range is used as the plot radius.
    radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0]))
    z_radius = radius / zscale

    ax.set_xlim3d([middles[0] - radius, middles[0] + radius])
    ax.set_ylim3d([middles[1] - radius, middles[1] + radius])
    ax.set_zlim3d([middles[2] - z_radius, middles[2] + z_radius])


def load_results(filepath):
    """Load deformed node coordinates and edge connectivity from a JSON file.

    The file must contain a ``"nodes"`` list, where each node has
    ``"position"`` and ``"displacement"`` objects keyed ``"x"``, ``"y"`` and
    ``"z"``. It must also contain an ``"edges"`` list, where each edge has
    ``"source"`` and ``"target"`` entries.

    Parameters
    ----------
    filepath : path to the JSON result file.

    Returns
    -------
    (r, conn)
        r : float array of shape (n_nodes, 3), position + displacement per node.
        conn : int array of shape (n_edges, 2), [source, target] per edge.
            It is always integer-typed, even when there are no edges.
    """
    with open(filepath, encoding="utf-8") as f:
        data = json.load(f)

    axes = ("x", "y", "z")

    nodes = data["nodes"]
    # Deformed coordinates = undeformed position + displacement, per axis.
    # reshape keeps the (0, 3) shape when there are no nodes.
    r = np.array(
        [[node["position"][k] + node["displacement"][k] for k in axes] for node in nodes],
        dtype=float,
    ).reshape(len(nodes), 3)

    edges = data["edges"]
    # Build the integer array directly instead of casting a float array
    # inside the loop (which also left empty results as float).
    conn = np.array(
        [[edge["source"], edge["target"]] for edge in edges],
        dtype=int,
    ).reshape(len(edges), 2)

    return (r, conn)


def main(nfiles=1199, pattern="./json/dynamicVal/{0}.json"):
    """Plot node z-coordinate histories and their FFT magnitude spectra.

    Reads ``nfiles`` result files named ``pattern.format(i)`` for
    ``i in range(nfiles)``. For each file it takes the z coordinate
    (position + displacement) of every node. It then plots two of those
    nodes over time and, in a second figure, the magnitude of their FFT,
    with vertical markers at the reference frequencies ``an``.
    """
    # Reference frequencies [Hz], drawn as vertical markers on the spectrum.
    an = np.array([8.17, 51.2, 143])

    # One row per output file, one column per node.
    Z = np.array([load_results(pattern.format(i))[0][:, 2] for i in range(nfiles)])
    n = len(Z)

    # Time between consecutive output files: the per-step dt times the number
    # of steps between outputs. NOTE(review): dt/steps are presumably the
    # simulation's time step and total step count; confirm against the solver.
    dt = 4 * 3.1483100054946037e-7
    steps = 3000000
    filesteps = steps / nfiles
    dT = dt * filesteps
    # Integer-based sample times: a float-step np.arange may yield n + 1
    # samples and then no longer match Z.
    t = np.arange(n) * dT

    # NOTE(review): -3:-1 selects the third- and second-to-last nodes (the
    # last node is excluded); confirm this is the intended pair.
    z_sel = Z[:, -3:-1]

    plt.figure()
    plt.plot(t, 1000 * z_sel)
    plt.xlim([0, 4])
    plt.xlabel("Time [s]")
    plt.ylabel("Displacement [mm]")

    amplitude = np.abs(np.fft.fft(z_sel, axis=0))
    # Bin k of an n-point FFT lies at k / (n * dT) Hz. The previous
    # linspace(0, 1/dT, n) used a spacing of 1 / ((n - 1) * dT).
    f = np.arange(n) / (n * dT)
    a_min, a_max = amplitude.min(), amplitude.max()

    plt.figure()
    plt.semilogy(f, amplitude)
    for a in an:
        plt.semilogy([a, a], [a_min, a_max], "r-.")
    plt.xlim([0, 200])
    plt.xlabel("Frequency [Hz]")
    plt.ylabel("Amplitude")
    plt.show()


if __name__ == "__main__":
    main()