import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D


def set_axes_equal(ax, zscale=1.5):
    """Give a 3D axis a common scale on x and y, with z optionally compressed.

    This is a work-around for Matplotlib's ``ax.set_aspect('equal')`` and
    ``ax.axis('equal')`` not working for 3D axes.

    Every axis is re-centred on the midpoint of its current limits. x and y
    get a half-range equal to half of the largest current range among the
    three axes. The z half-range is that same value divided by ``zscale``.

    Input
      ax: a matplotlib 3D axis, e.g., as output from plt.gca().
      zscale: factor by which the z half-range is divided. The default of 1.5
        keeps the historical behaviour of this helper (z shown compressed).
        Pass 1.0 for truly equal scaling, so spheres appear as spheres.

    Raises
      ValueError: if ``zscale`` is not positive.
    """
    if zscale <= 0:
        raise ValueError("zscale must be positive, got {0!r}".format(zscale))

    x_limits = ax.get_xlim3d()
    y_limits = ax.get_ylim3d()
    z_limits = ax.get_zlim3d()

    x_range = abs(x_limits[1] - x_limits[0])
    x_middle = np.mean(x_limits)
    y_range = abs(y_limits[1] - y_limits[0])
    y_middle = np.mean(y_limits)
    z_range = abs(z_limits[1] - z_limits[0])
    z_middle = np.mean(z_limits)

    # The plot bounding box is a sphere in the sense of the infinity
    # norm, hence half the max range is called the plot radius.
    plot_radius = 0.5 * max([x_range, y_range, z_range])
    z_radius = plot_radius / zscale

    ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
    ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
    ax.set_zlim3d([z_middle - z_radius, z_middle + z_radius])


def load_results(filepath):
    """Load a JSON result file and return deformed node positions and edges.

    The file must contain a ``"nodes"`` list and an ``"edges"`` list. Each
    node has ``"position"`` and ``"displacement"`` objects with ``"x"``,
    ``"y"`` and ``"z"`` entries. Each edge has ``"source"`` and ``"target"``
    node indices.

    Input
      filepath: path to the JSON file.

    Returns
      (r, conn) where
        r: float array of shape (n_nodes, 3) holding position + displacement
          for every node.
        conn: int array of shape (n_edges, 2) holding the source/target node
          indices of every edge. It is integer-typed even when there are no
          edges, so it can always be used to index ``r``.
    """
    with open(filepath, encoding="utf-8") as f:
        data = json.load(f)

    components = ("x", "y", "z")
    # reshape keeps the (0, 3) shape when the node list is empty.
    r = np.array(
        [
            [node["position"][k] + node["displacement"][k] for k in components]
            for node in data["nodes"]
        ],
        dtype=float,
    ).reshape(-1, 3)

    # Build the connectivity directly with an integer dtype. This avoids
    # re-copying the array on every edge and guarantees an int result
    # (shape (0, 2)) when the edge list is empty.
    conn = np.array(
        [[edge["source"], edge["target"]] for edge in data["edges"]],
        dtype=int,
    ).reshape(-1, 2)

    return (r, conn)


if __name__ == "__main__":

    nfiles = 253
    R = []
    Conn = []

    t = np.arange(nfiles)
    tension = np.zeros_like(t)
    tension[0 : int(nfiles / 6)] = np.linspace(0, 70, int(nfiles / 6))
    tension[int(nfiles / 6) : :] = 70

    for i in range(nfiles):
        print("File {0} of {1}".format(i + 1, nfiles))
        filepath = "./json/dynamicVal/{0}.json".format(i)
        (r, conn) = load_results(filepath)
        fig = plt.figure(figsize=(8, 4))
        gs = gridspec.GridSpec(1, 3)
        ax1 = plt.subplot(gs[:, 0:2])
        for c in conn:
            ax1.plot(r[c, 0], r[c, 1])
        ax = plt.gca()
        ax.set_aspect("equal", adjustable="box")
        plt.xlim([-25, 175])
        plt.ylim([-20, 140])
        ax1.set_xlabel("x [mm]")
        ax1.set_ylabel("y [mm]")
        ax2 = plt.subplot(gs[:, 2])
        plt.plot(t / t.max(), tension)
        plt.plot(t[i] / t.max(), tension[i], "ro")
        ax2.set_xlabel("Time")
        ax2.set_ylabel("Root Tension [N]")
        plt.tight_layout()
        plt.savefig("./json/dynamicVal/fig{0:03d}.png".format(i))
        plt.close()