import * as THREE from "three"
import { ArcballControls } from "three/addons/controls/ArcballControls.js"
import pako from "pako"
import msgpack from "msgpack-lite"

/**
 * Shared UI palette as CSS hex strings. Also used for the mesh's default
 * (accent), score (success / warning / error) and heatmap out-of-range
 * (magenta) colours.
 */
export const colors = {
  primary: "#147BD1",
  secondary: "#333F48",
  accent: "#C7C9C7", // default mesh colour and ambient light tint
  background: "#1E1E1E",
  success: "#4CAF50", // score colour for passing gauges
  warning: "#FB8C00", // score colour for borderline gauges
  error: "#FF5252", // score colour for failing gauges
  magenta: "#FF00FF", // marks out-of-range heatmap distances
}

// Module-wide cache: the decoded scan payload (null until register() loads
// it) and the View instances keyed by their container element id.
const state = { scanData: null, views: {} }

// msgpack-lite codec for decoding the scan payload. `binarraybuffer` and
// `preset` are msgpack-lite createCodec options (binary values as
// ArrayBuffer; preset extension types) — see the msgpack-lite README.
const codec = msgpack.createCodec({ binarraybuffer: true, preset: true })

// Promise for the scan-data download currently in flight, shared so that
// concurrent register() calls trigger a single fetch. Cleared on failure so
// a later call can retry.
let scanDataRequest = null

/**
 * Downloads, inflates and msgpack-decodes "/scan_data.bin" once, caching the
 * result in `state.scanData`.
 *
 * @returns {Promise<object>} the decoded scan data.
 * @throws whatever fetchBinData / pako.inflate / msgpack.decode throw.
 */
async function loadScanData() {
  if (state.scanData) {
    return state.scanData
  }
  if (!scanDataRequest) {
    scanDataRequest = fetchBinData("/scan_data.bin").then((binData) => {
      const msgPack = pako.inflate(binData)
      return msgpack.decode(new Uint8Array(msgPack), { codec })
    })
  }
  const request = scanDataRequest
  try {
    state.scanData = await request
  } catch (err) {
    // Only clear our own request; a newer retry may already be in flight.
    if (scanDataRequest === request) {
      scanDataRequest = null
    }
    throw err
  }
  return state.scanData
}

/**
 * Creates the viewer for the DOM element with the given id. If a view with
 * that id already exists, it is re-attached and re-framed instead. The
 * shared scan data is loaded on first use.
 *
 * @param {string} id - id of the container element.
 * @returns {Promise<void>} resolves once the view is ready; rejects if the
 *   scan data cannot be fetched/decoded or the view cannot be created.
 */
export async function register(id) {
  await loadScanData()
  // Own-property check so ids like "toString" are not mistaken for views.
  if (Object.prototype.hasOwnProperty.call(state.views, id)) {
    state.views[id].reset()
  } else {
    state.views[id] = new View(id)
  }
}

/**
 * Switches the registered view `id` to the distance-heatmap colouring.
 *
 * @param {string} id - id previously passed to register().
 */
export function showHeatmapColors(id) {
  const { views } = state
  console.assert(id in views)
  views[id].showHeatmapColors()
}

/**
 * Switches the registered view `id` to the per-gauge score colouring.
 *
 * @param {string} id - id previously passed to register().
 */
export function showScoreColors(id) {
  const registered = state.views
  console.assert(id in registered)
  registered[id].showScoreColors()
}

/**
 * Recolours one gauge in the score colouring of the registered view `id`.
 *
 * @param {string} id - id previously passed to register().
 * @param {number} gaugeIndex - index into the scan data's gauge list.
 * @param {*} color - any value accepted by THREE.Color's constructor.
 */
export function setGaugeColor(id, gaugeIndex, color) {
  const registered = state.views
  console.assert(id in registered)
  registered[id].setGaugeColor(gaugeIndex, color)
}

/**
 * Fetches `url` and returns its body as an ArrayBuffer.
 *
 * @param {string} url - resource to download.
 * @returns {Promise<ArrayBuffer>} the response body.
 * @throws {Error} if the server answers with a non-2xx status; network
 *   errors from fetch() are propagated unchanged.
 */
async function fetchBinData(url) {
  const response = await fetch(url)
  // fetch() only rejects on network failure; an error page (404/500) would
  // otherwise be handed to the decompressor as if it were the payload.
  if (!response.ok) {
    throw new Error(
      `Failed to fetch ${url}: HTTP ${response.status} ${response.statusText}`
    )
  }
  return response.arrayBuffer()
}

/**
 * MeshStandardMaterial whose per-vertex colour is a blend of two custom
 * vertex attributes, `color1` and `color2`, mixed by the uniform `t`
 * (0 → color1, 1 → color2). Expects `vertexColors: true` so the `vColor`
 * varying and the colour shader chunks are active.
 */
class Material extends THREE.MeshStandardMaterial {
  /**
   * @param {object} [parameters] - forwarded to THREE.MeshStandardMaterial.
   */
  constructor(parameters) {
    super(parameters)

    // Create the blend uniform once, up front. This makes `t` and
    // `uniforms.t` usable before the first render. It also keeps the value
    // if Three.js recompiles the program and calls onBeforeCompile again.
    const blend = { value: 0.0 }
    this.uniforms = { t: blend }

    this.onBeforeCompile = (shader) => {
      shader.uniforms.t = blend
      shader.vertexShader = `
        attribute vec3 color1;
        attribute vec3 color2;
        uniform float t;

        ${shader.vertexShader}
      `.replace(
        "#include <color_vertex>",
        `
        vColor = mix(color1, color2, t);
        `
      )
      shader.fragmentShader = shader.fragmentShader.replace(
        "#include <color_fragment>",
        `
        diffuseColor.rgb *= vColor;
        `
      )
      // Expose the full compiled uniform set; `t` is still the same object.
      this.uniforms = shader.uniforms
    }
  }

  /** Current blend factor between color1 (0) and color2 (1). */
  get t() {
    return this.uniforms.t.value
  }

  set t(value) {
    this.uniforms.t.value = value
  }
}

/**
 * One Three.js viewport that renders the shared scan mesh into a DOM element.
 *
 * Rendering is on demand. render() runs after camera-control changes and
 * after each colour update; this class has no animation loop.
 */
class View {
  /**
   * @param {string} id - id of the container element; the WebGL canvas is
   *   appended to it and sized to its client area.
   */
  constructor(id) {
    this.id = id
    this.domElement = document.getElementById(this.id)
    this.renderer = new THREE.WebGLRenderer({ alpha: true, antialias: true })
    this.renderer.setSize(
      this.domElement.clientWidth,
      this.domElement.clientHeight
    )
    this.domElement.appendChild(this.renderer.domElement)

    this.scene = new THREE.Scene()
    this.camera = new THREE.PerspectiveCamera(
      45,
      this.domElement.clientWidth / this.domElement.clientHeight,
      1,
      10000
    )
    // Scratch vector reused by render() instead of allocating every frame.
    this.cameraWorldDirection = new THREE.Vector3()

    this.controls = new ArcballControls(this.camera, this.renderer.domElement)
    // Drop the default bindings on mouse buttons 1 and 2, then bind button 2
    // to panning.
    this.controls.unsetMouseAction(1)
    this.controls.unsetMouseAction(2)
    this.controls.setMouseAction("PAN", 2)
    this.controls.addEventListener("change", () => {
      this.render()
    })

    this.mesh = View.createMesh()
    // The key light's position is re-aimed from the camera in render().
    this.keyLight = new THREE.DirectionalLight(0xffffff, 2.0)
    this.keyLight.position.set(0, 0, 1).normalize()
    const ambientLight = new THREE.AmbientLight(colors.accent, 0.5)
    this.scene.add(this.keyLight, ambientLight, this.mesh)

    this.lookAtMesh()
  }

  /**
   * Re-attaches the canvas to the current element with this view's id,
   * resizes the renderer to it and re-frames the mesh.
   */
  reset() {
    this.domElement = document.getElementById(this.id)
    this.domElement.appendChild(this.renderer.domElement)

    const bounds = this.domElement.getBoundingClientRect()
    this.renderer.setSize(bounds.width, bounds.height)

    this.lookAtMesh()
  }

  /**
   * Moves the camera to (0, 0, z) with z far enough that the scene's
   * bounding box fits both the vertical and horizontal field of view.
   * Then updates the far plane and control limits and renders.
   */
  lookAtMesh() {
    this.camera.aspect =
      this.domElement.clientWidth / this.domElement.clientHeight

    const boundingBox = new THREE.Box3().setFromObject(this.scene)
    const size = boundingBox.getSize(new THREE.Vector3())

    // Vertical and horizontal field of view, in radians.
    const fov = this.camera.fov * (Math.PI / 180)
    const fovh = 2 * Math.atan(Math.tan(fov / 2) * this.camera.aspect)
    // Distance that fits half the width / height, plus half the depth.
    // NOTE(review): the camera always targets the origin, so this assumes
    // the mesh is roughly centred on (0, 0, 0) — confirm against the data.
    const dx = size.z / 2 + Math.abs(size.x / 2 / Math.tan(fovh / 2))
    const dy = size.z / 2 + Math.abs(size.y / 2 / Math.tan(fov / 2))
    const cameraZ = Math.max(dx, dy)
    this.camera.position.set(0, 0, cameraZ)

    // Distance from the camera (on +z) to the box face farthest from it.
    const cameraToFarEdge = cameraZ - boundingBox.min.z
    this.camera.far = cameraToFarEdge * 3
    this.camera.updateProjectionMatrix()

    this.controls.target = new THREE.Vector3(0, 0, 0)
    this.controls.maxDistance = cameraToFarEdge * 2
    this.controls.update()

    this.render()
  }

  /**
   * Places the key light opposite the camera's view direction, so it lights
   * the mesh from the viewer's side, then draws one frame.
   */
  render() {
    this.camera.getWorldDirection(this.cameraWorldDirection)
    // negate() mutates the scratch vector; it is recomputed on the next call.
    this.keyLight.position.copy(this.cameraWorldDirection.negate())
    this.renderer.render(this.scene, this.camera)
  }

  /** Shows the distance heatmap (color2) and re-renders. */
  showHeatmapColors() {
    this.mesh.material.uniforms.t.value = 1.0
    this.render()
  }

  /** Shows the per-gauge score colours (color1) and re-renders. */
  showScoreColors() {
    this.mesh.material.uniforms.t.value = 0.0
    this.render()
  }

  /**
   * Paints every vertex of one gauge's faces in the score colour set
   * (`color1`) and re-renders.
   *
   * @param {number} gaugeIndex - index into `state.scanData.scanGauges`.
   * @param {*} color - any value accepted by THREE.Color's constructor.
   * @throws {RangeError} if `gaugeIndex` does not name an existing gauge.
   */
  setGaugeColor(gaugeIndex, color) {
    const { scanGauges } = state.scanData
    const gauge = scanGauges[gaugeIndex]
    if (gauge === undefined) {
      throw new RangeError(
        `gauge index ${gaugeIndex} out of range [0, ${scanGauges.length})`
      )
    }

    const colors1 = this.mesh.geometry.getAttribute("color1")
    const indexArray = this.mesh.geometry.getIndex().array
    const paint = new THREE.Color(color)
    View.forEachFaceVertex(indexArray, gauge.faceIndices, (vertexIndex) => {
      colors1.setXYZ(vertexIndex, paint.r, paint.g, paint.b)
    })
    colors1.needsUpdate = true
    this.render()
  }

  /**
   * Calls `callback(vertexIndex)` for each corner of each triangle listed
   * in `faceIndices`. Corners are read from the flat triangle index array,
   * three entries per face. Only corners that exist in `indexArray` are
   * visited.
   *
   * @param {ArrayLike<number>} indexArray - flat triangle index list.
   * @param {Iterable<number>} faceIndices - triangle numbers to visit.
   * @param {(vertexIndex: number) => void} callback
   */
  static forEachFaceVertex(indexArray, faceIndices, callback) {
    for (const faceIndex of faceIndices) {
      const first = Math.max(faceIndex * 3, 0)
      const end = Math.min(faceIndex * 3 + 3, indexArray.length)
      for (let corner = first; corner < end; ++corner) {
        callback(indexArray[corner])
      }
    }
  }

  /**
   * Builds the scan mesh from `state.scanData`. It has two per-vertex colour
   * sets that Material blends with `t`:
   * - `color1`: per-gauge score colour.
   * - `color2`: distance heatmap.
   * Vertices outside every gauge keep the accent colour in both sets.
   *
   * @returns {THREE.Mesh}
   */
  static createMesh() {
    const { scanMesh, scanGauges, analysis } = state.scanData
    const geometry = new THREE.BufferGeometry()

    const positions = new Float32Array(scanMesh.vertexPositions)
    geometry.setAttribute("position", new THREE.BufferAttribute(positions, 3))

    const normals = new Float32Array(scanMesh.vertexNormals)
    geometry.setAttribute("normal", new THREE.BufferAttribute(normals, 3))

    const indices = new Uint32Array(scanMesh.indices)
    geometry.setIndex(new THREE.BufferAttribute(indices, 1))

    const defaultColor = new THREE.Color(colors.accent)
    const createColorAttribute = () => {
      const array = new Float32Array(positions.length)
      for (let i = 0; i < array.length; i += 3) {
        array[i] = defaultColor.r
        array[i + 1] = defaultColor.g
        array[i + 2] = defaultColor.b
      }
      return new THREE.BufferAttribute(array, 3).setUsage(
        THREE.StreamDrawUsage
      )
    }
    const colors1 = createColorAttribute()
    const colors2 = createColorAttribute()
    geometry.setAttribute("color1", colors1)
    geometry.setAttribute("color2", colors2)

    // Heatmap range, presumably in the units of analysis.distances —
    // TODO confirm.
    const heatmapMin = 0
    const heatmapMax = 0.1

    const scoreColor = new THREE.Color()
    const distanceColor = new THREE.Color()
    for (let i = 0; i < scanGauges.length; ++i) {
      const { scoreRestorative } = analysis.gaugeComparisons[i]
      // Round to one decimal before classifying (e.g. 2.96 counts as 3.0).
      const score = Number(scoreRestorative.toFixed(1))
      scoreColor.setStyle(scanScoreTypeColor(scanScoreTypeFromValue(score)))
      View.forEachFaceVertex(
        indices,
        scanGauges[i].faceIndices,
        (vertexIndex) => {
          colors1.setXYZ(vertexIndex, scoreColor.r, scoreColor.g, scoreColor.b)
          heatmapColor(
            distanceColor,
            analysis.distances[vertexIndex],
            heatmapMin,
            heatmapMax
          )
          colors2.setXYZ(
            vertexIndex,
            distanceColor.r,
            distanceColor.g,
            distanceColor.b
          )
        }
      )
    }
    colors1.needsUpdate = true
    colors2.needsUpdate = true

    const material = new Material({
      side: THREE.DoubleSide,
      roughness: 0.4,
      metalness: 0.5,
      vertexColors: true
    })

    return new THREE.Mesh(geometry, material)
  }
}

// Score classification used to pick a gauge's colour; each value equals its
// key.
const scanScoreType = {
  failure: "failure", // shown with colors.error
  warning: "warning", // shown with colors.warning
  success: "success", // shown with colors.success
}

/**
 * Classifies a score into a `scanScoreType`.
 * - [1, 3): failure.
 * - [3, 4): warning.
 * - Anything else: success. This includes values below 1 and NaN.
 *   NOTE(review): confirm that scores below 1 are meant to count as success.
 *
 * @param {number|string} scoreValue - compared with numbers, so numeric
 *   strings such as "2.5" are accepted.
 * @returns {string} one of the `scanScoreType` values.
 */
function scanScoreTypeFromValue(scoreValue) {
  const bands = [
    [1, 3, scanScoreType.failure],
    [3, 4, scanScoreType.warning]
  ]
  const match = bands.find(
    ([low, high]) => scoreValue >= low && scoreValue < high
  )
  return match ? match[2] : scanScoreType.success
}

/**
 * Maps a `scanScoreType` value to its palette colour (CSS hex string).
 *
 * @param {string} scoreType - a `scanScoreType` value.
 * @returns {string|undefined} the colour, or undefined for unknown types.
 */
function scanScoreTypeColor(scoreType) {
  // Built per call so it reflects the current `colors` values.
  const palette = new Map([
    [scanScoreType.failure, colors.error],
    [scanScoreType.warning, colors.warning],
    [scanScoreType.success, colors.success]
  ])
  return palette.get(scoreType)
}

/**
 * Writes a heatmap colour for `distance` into `outColor`. The ramp runs
 * from dark blue at `minDistance` through cyan, yellow and orange to dark
 * red at `maxDistance`.
 *
 * @param {THREE.Color} outColor - receives the colour (mutated in place).
 * @param {number} distance - value to colour.
 * @param {number} minDistance - lower end of the ramp (inclusive).
 * @param {number} maxDistance - upper end of the ramp (inclusive).
 */
function heatmapColor(outColor, distance, minDistance, maxDistance) {
  // Written as a negated in-range test so NaN / undefined distances also
  // get the out-of-range colour instead of NaN channels.
  if (!(distance >= minDistance && distance <= maxDistance)) {
    outColor.setStyle(colors.magenta)
    return
  }

  const span = maxDistance - minDistance
  // A zero-width range would compute 0 / 0; map it to the low end instead.
  const x = span > 0 ? (distance - minDistance) / span : 0

  // Piecewise-linear channel ramps; clamp() trims them to [0, 1].
  const r = x < 0.7 ? 4.0 * x - 1.5 : -4.0 * x + 4.5
  const g = x < 0.15 ? 4.0 * x - 0.15 : x < 0.5 ? 4.0 * x : -4.0 * x + 3.5
  const b = x < 0.2 ? 4.0 * x + 0.5 : -4.0 * x + 1.8

  outColor.setRGB(clamp(r), clamp(g), clamp(b))
}

/**
 * Restricts `val` to [min, max] (defaults [0, 1]). The lower bound is
 * applied first, so with min > max the result is always `max`. NaN passes
 * through unchanged.
 */
function clamp(val, min = 0, max = 1) {
  const atLeastMin = Math.max(min, val)
  return Math.min(max, atLeastMin)
}
