import { ImageVolume, eventTarget, Enums } from '@cornerstonejs/core';
import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray';
import vtkMath from '@kitware/vtk.js/Common/Core/Math';
import vtkImageData from '@kitware/vtk.js/Common/DataModel/ImageData';
import vtkPolyData from '@kitware/vtk.js/Common/DataModel/PolyData';
import { triggerVolumeLoadedEvent } from '../cornerstone/fetcher';
import { getSegmentColor } from './trkLoader';
import { TRACTS_COLOR_MODE } from '../../constants';

// Cornerstone event used to detect when this volume has been added to the cache.
const VOLUME_CACHE_VOLUME_ADDED = Enums.Events.VOLUME_CACHE_VOLUME_ADDED;

/**
 * Walks the 3D Bresenham line from v1 to v2 (both endpoints inclusive) and calls
 * `callback(x, y, z)` once per voxel, starting at v1 and ending at v2.
 *
 * The axis with the largest extent drives the walk (ties prefer x, then y); the
 * two other axes advance through integer error terms.
 *
 * Endpoints are rounded to integers first: with fractional endpoints the driving
 * coordinate could never land exactly on its target and the walk would never end.
 * Non-finite endpoints (NaN / Infinity) produce no callbacks for the same reason.
 *
 * @param {ArrayLike<number>} v1 - start voxel [x, y, z]
 * @param {ArrayLike<number>} v2 - end voxel [x, y, z]
 * @param {(x: number, y: number, z: number) => void} callback
 */
const Bresenham3D = (v1, v2, callback) => {
  const from = [0, 1, 2].map((axis) => Math.round(v1[axis]));
  const to = [0, 1, 2].map((axis) => Math.round(v2[axis]));
  if (![...from, ...to].every(Number.isFinite)) return;

  const delta = from.map((c, axis) => Math.abs(to[axis] - c));
  const step = from.map((c, axis) => (to[axis] > c ? 1 : -1));
  const [dx, dy, dz] = delta;

  let major;
  if (dx >= dy && dx >= dz) {
    major = 0;
  } else if (dy >= dx && dy >= dz) {
    major = 1;
  } else {
    major = 2;
  }
  const minors = [0, 1, 2].filter((axis) => axis !== major);
  const errors = minors.map((axis) => 2 * delta[axis] - delta[major]);

  const cur = from.slice();
  callback(cur[0], cur[1], cur[2]);

  while (cur[major] !== to[major]) {
    cur[major] += step[major];
    for (let k = 0; k < minors.length; ++k) {
      const axis = minors[k];
      if (errors[k] >= 0) {
        cur[axis] += step[axis];
        errors[k] -= 2 * delta[major];
      }
      errors[k] += 2 * delta[axis];
    }
    callback(cur[0], cur[1], cur[2]);
  }
};

/**
 * Component-wise difference `to - from`, i.e. the vector pointing from `from`
 * to `to`. The result has the length and array type of `to`.
 */
const vec = (from, to) => to.map((component, axis) => component - from[axis]);

/** Dot product of two equal-length vectors; 0 for empty input. */
const dot = (v1, v2) => {
  let sum = 0;
  for (let i = 0; i < v1.length; i += 1) {
    sum += v1[i] * v2[i];
  }
  return sum;
};

/** Cross product `a × b` of two 3-component vectors. */
const cross = ([ax, ay, az], [bx, by, bz]) => [
  ay * bz - az * by,
  az * bx - ax * bz,
  ax * by - ay * bx,
];

/**
 * Precomputes the data segmentRectIntersection needs for a rectangle given by
 * three corners `[A, B, D]`: the plane normal `n = AB × AD`, the edge vectors
 * `ab` / `ad`, and their squared lengths `abl` / `adl`.
 */
const createRect = (points) => {
  const corner = points[0];
  const ab = vec(corner, points[1]);
  const ad = vec(corner, points[2]);

  return {
    points,
    n: cross(ab, ad),
    ab,
    ad,
    abl: dot(ab, ab),
    adl: dot(ad, ad),
  };
};

/**
 * Intersects the segment p1→p2 with the plane through p0 with normal n.
 *
 * Solves `p1 + t * (p2 - p1)` on the plane for `t`. The hit is accepted only when
 * `0 <= t <= 1`, i.e. when it lies on the segment itself.
 *
 * @param {number[]} p1 - segment start
 * @param {number[]} p2 - segment end
 * @param {number[]} p0 - any point on the plane
 * @param {number[]} n - plane normal (need not be unit length)
 * @returns {number[]|undefined} the intersection point, or undefined when the
 *   segment is parallel to the plane or does not reach it.
 */
const segmentPlaneIntersection = (p1, p2, p0, n) => {
  const numerator = dot(vec(p1, p0), n);
  const denominator = dot(vec(p1, p2), n);
  if (denominator === 0) return undefined;

  const t = numerator / denominator;
  // One-sided at 0: t in [-1, 0) is a hit on the extension *behind* p1, which a
  // symmetric |t| <= 1 check would wrongly accept. The negated form also rejects NaN.
  if (!(t >= 0 && t <= 1)) return undefined;

  return p1.map((p, i) => p + (p2[i] - p) * t);
};

/**
 * Intersects the segment p1→p2 with a rectangle prepared by createRect.
 *
 * The plane hit is projected onto the AB and AD edges. It counts only when both
 * projections fall strictly inside the edge (so points exactly on the border are
 * rejected).
 *
 * @returns {number[]|undefined} the intersection point, or undefined.
 */
const segmentRectIntersection = (p1, p2, rect) => {
  const corner = rect.points[0];
  const hit = segmentPlaneIntersection(p1, p2, corner, rect.n);
  if (!hit) {
    return undefined;
  }

  const offset = vec(corner, hit);
  const alongAb = dot(offset, rect.ab);
  const alongAd = dot(offset, rect.ad);
  const insideAb = alongAb > 0 && alongAb < rect.abl;
  const insideAd = alongAd > 0 && alongAd < rect.adl;

  return insideAb && insideAd ? hit : undefined;
};

/**
 * Normal of the plane spanned by the first three points of `rect`:
 * `(rect[1] - rect[0]) × (rect[2] - rect[0])` (not normalized).
 */
const norm = (rect) => {
  const corner = rect[0];
  const edgeB = vec(corner, rect[1]);
  const edgeD = vec(corner, rect[2]);
  return cross(edgeB, edgeD);
};

// TODO: reimplement the filtering and the other geometry processing in this file using vtk.js filters

/**
 * Cornerstone ImageVolume that rasterizes tractography tracks into an RGBA voxel
 * volume.
 *
 * Line data uses the vtk cell-array layout: each segment is the three entries
 * `[2, pointIdA, pointIdB]`, and its RGB colour sits at the same offset in the
 * matching scalars array. A track descriptor `{ index, numPoints, ... }` refers to
 * its first segment by `index` and owns `numPoints - 1` segments.
 *
 * Data flows through three stages:
 *   sampled  - tracks after quality decimation (setQuality),
 *   visible  - sampled tracks of the visible clusters (setVisibleClusters),
 *   filtered - visible tracks crossing the filter rectangle (filter).
 * The filtered lines are then voxelized into `scalarData` (_createVoxelVolume).
 */
class TrkImageVolume extends ImageVolume {
  /**
   * @param {object} options
   * @param {object} options.polyData - vtkPolyData with the track points, line cells
   *   and per-cell RGB scalars.
   * @param {number} [options.quality] - not applied at construction; it is
   *   destructured only so it is not forwarded to ImageVolume. Use setQuality().
   * @param {...*} options.props - remaining ImageVolume props; `metadata.tracks` is
   *   expected to describe every track of `polyData`.
   */
  constructor({ polyData, quality, ...props }) {
    super(props);
    this.clusters = undefined;
    this.clusterColors = undefined;
    this.colorMode = TRACTS_COLOR_MODE.NORMALS;
    this.sourceData = polyData;

    this.sampledTracks = this.metadata.tracks;
    this.sampledLines = polyData.getLines().getData();
    this.sampledScalars = polyData.getCellData().getScalars().getData();

    // visibleTracks[i].index points into visibleLines / visibleScalars.
    this.visibleTracks = this.sampledTracks;
    this.visibleLines = this.sampledLines;
    this.visibleScalars = this.sampledScalars;
    this.visibleClusters = undefined;

    const da = vtkDataArray.newInstance({
      numberOfComponents: 3,
      values: polyData.getCellData().getScalars().getData(),
      name: 'Scalars',
    });
    this.overrideColor = undefined;

    // Working poly data rewritten by filter()/setColor(); it shares the source points.
    this.polyData = vtkPolyData.newInstance();
    this.polyData.setPoints(polyData.getPoints());
    this.polyData.getLines().setData(polyData.getLines().getData());
    this.polyData.getCellData().setScalars(da);

    this.filterScalars = da.getData();
    this.filterRect = undefined;

    // Voxel grid size, fixed by the first _createVoxelVolume() call.
    this.voxelGridDims = undefined;

    this.csRenderable = false;
    this.isReady = false;

    const onAdded = ({ detail }) => {
      // Needed to display track colours properly, since cornerstone doesn't
      // support RGB/RGBA point data.
      if (props.volumeId === detail.volume.volumeId) {
        this._rebuildImageData();
        triggerVolumeLoadedEvent(this);
        eventTarget.removeEventListener(VOLUME_CACHE_VOLUME_ADDED, onAdded);
      }
    };

    eventTarget.addEventListener(VOLUME_CACHE_VOLUME_ADDED, onAdded);

    const start = Date.now();
    this._createVoxelVolume();
    this.metadata.voxelizationTime = Date.now() - start;
  }

  /**
   * Groups tracks by cluster label and makes every cluster visible.
   *
   * @param {{labels: Array, colors: *}} clusters - `labels[i]` is the cluster id of
   *   `metadata.tracks[i]`; `colors[id]` is the colour used for cluster `id` in
   *   COLORMAP mode.
   */
  setClusterData(clusters) {
    this.clusters = this.metadata.tracks.reduce((groups, track, i) => {
      const clusterId = clusters.labels[i];
      if (!groups[clusterId]) groups[clusterId] = [i];
      else groups[clusterId].push(i);
      return groups;
    }, {});
    this.clusterColors = clusters.colors;
    // Object keys are strings, so visible cluster ids are strings as well.
    this.visibleClusters = Object.keys(this.clusters);
  }

  /**
   * Rebuilds the visible tracks/lines/scalars from the sampled data, keeping only
   * the tracks of the listed clusters, then re-applies the current filter.
   *
   * @param {string[]|undefined} visibleList - cluster ids as produced by
   *   setClusterData; falsy (or no cluster data) shows every sampled track.
   */
  setVisibleClusters(visibleList) {
    if (!this.clusters || !visibleList) {
      this.visibleTracks = this.sampledTracks;
      this.visibleLines = this.sampledLines;
      this.visibleScalars = this.sampledScalars;
    } else {
      const tracks = [];
      const lines = [];
      const scalars = [];
      const useClusterColors = this.colorMode === TRACTS_COLOR_MODE.COLORMAP;

      Object.entries(this.clusters)
        .filter(([clusterId]) => visibleList.includes(clusterId))
        .forEach(([clusterId, trackIndices]) => {
          trackIndices.forEach((trackIndex) => {
            const track = this.sampledTracks[trackIndex];
            const segmentCount = Math.max(track.numPoints - 1, 0);
            const start = track.index;
            const end = start + 3 * segmentCount;

            // Re-index the track so filter() can find it inside visibleLines.
            tracks.push({ ...track, index: lines.length });
            lines.push(...this.sampledLines.slice(start, end));

            if (useClusterColors) {
              const color = this.clusterColors[clusterId];
              for (let s = 0; s < segmentCount; ++s) scalars.push(...color);
            } else {
              scalars.push(...this.sampledScalars.slice(start, end));
            }
          });
        });

      this.visibleTracks = tracks;
      this.visibleLines = lines;
      this.visibleScalars = scalars;
    }

    this.visibleClusters = visibleList;
    this.filter(this.filterRect);
  }

  /**
   * Decimates every track by keeping every q-th point, where
   * `q = max(MAX_QUALITY - quality, 1)`. With q === 1 the full-resolution source
   * data is used again. Re-applies cluster visibility and the filter afterwards.
   *
   * @param {number} quality - higher keeps more points.
   */
  setQuality(quality) {
    const q = Math.max(TrkImageVolume.MAX_QUALITY - quality, 1);

    if (q === 1) {
      this.sampledLines = this.sourceData.getLines().getData();
      this.sampledScalars = this.sourceData.getCellData().getScalars().getData();
      this.sampledTracks = this.metadata.tracks;
    } else {
      const vertexArray = this.sourceData.getPoints().getData();
      this.sampledLines = [];
      this.sampledTracks = [];

      this.metadata.tracks.forEach((track, k) => {
        // Point id of the track's first point: track.index / 3 segments precede it,
        // and each of the k earlier tracks has one more point than segments.
        const firstPoint = track.index / 3 + k;
        const trackIndex = this.sampledLines.length;
        let startPoint;
        // TODO: always add last point of the track
        for (let j = 0; j < track.numPoints - 1; ++j) {
          if (j % q !== 0) continue;
          if (startPoint !== undefined) {
            this.sampledLines.push(2, startPoint, firstPoint + j);
          }
          startPoint = firstPoint + j;
        }

        this.sampledTracks.push({
          ...track,
          index: trackIndex,
          numPoints: (this.sampledLines.length - trackIndex) / 3 + 1,
        });
      });

      this.sampledScalars = [];
      for (let i = 0; i < this.sampledLines.length; i += 3) {
        this.sampledScalars.push(...getSegmentColor(this.sampledLines, vertexArray, i));
      }
    }

    this.setVisibleClusters(this.visibleClusters);
  }

  /**
   * Switches the colour mode (TRACTS_COLOR_MODE) and rebuilds the visible data.
   * @param {*} mode - a TRACTS_COLOR_MODE value.
   */
  setColorMode(mode) {
    this.colorMode = mode;
    this.setVisibleClusters(this.visibleClusters);
  }

  /** Drops the override colour and restores per-segment colours. */
  setDefaultColor() {
    this.setColor(undefined);
  }

  /**
   * Paints every filtered segment with `color`, or restores the colours kept in
   * `filterScalars` when `color` is falsy, then re-voxelizes the volume.
   *
   * @param {number[]|undefined} color - colour components written into Uint8 arrays.
   */
  setColor(color) {
    this.overrideColor = color;
    const scalars = this.polyData.getCellData().getScalars();

    if (this.overrideColor) {
      const cellCount = scalars.getData().length / 3;
      const colorData = [];
      for (let i = 0; i < cellCount; ++i) {
        colorData.push(...color);
      }
      scalars.setData(Uint8Array.from(colorData));
    } else {
      scalars.setData(Uint8Array.from(this.filterScalars));
    }

    this._createVoxelVolume();
    this._rebuildImageData();
  }

  /**
   * Keeps only the visible tracks that cross `rect`, then re-colours and
   * re-voxelizes through setColor().
   *
   * @param {Array<number[]>|undefined} rect - three corners [A, B, D] (B and D
   *   adjacent to A); the containment test projects onto AB and AD, so those edges
   *   are presumably perpendicular. Falsy shows every visible track.
   */
  filter(rect) {
    this.filterRect = rect;
    let fCellArray = [];
    let fScalars = [];

    if (!rect) {
      fCellArray = this.visibleLines;
      fScalars = this.visibleScalars;
    } else {
      const vertexArray = this.sourceData.getPoints().getData();
      const lines = this.visibleLines;
      const scalars = this.visibleScalars;
      const r = createRect(rect);
      const pointAt = (pointId) => {
        const o = 3 * pointId;
        return [vertexArray[o], vertexArray[o + 1], vertexArray[o + 2]];
      };

      this.visibleTracks.forEach((track) => {
        // A track owns numPoints - 1 segments; iterating numPoints of them would
        // read the first segment of the next track.
        const start = track.index;
        const end = start + 3 * Math.max(track.numPoints - 1, 0);

        let hasIntersection = false;
        for (let j = start; j < end; j += 3) {
          if (segmentRectIntersection(pointAt(lines[j + 1]), pointAt(lines[j + 2]), r)) {
            hasIntersection = true;
            break;
          }
        }
        if (!hasIntersection) return;

        for (let j = start; j < end; j += 3) {
          fCellArray.push(2, lines[j + 1], lines[j + 2]);
          fScalars.push(scalars[j], scalars[j + 1], scalars[j + 2]);
        }
      });
    }

    this.filterScalars = fScalars;
    this.polyData.getLines().setData(Uint32Array.from(fCellArray));
    this.polyData.getCellData().getScalars().setData(Uint8Array.from(fScalars));
    this.setColor(this.overrideColor);
  }

  /**
   * Rasterizes the line cells of `this.polyData` into an RGBA voxel grid and stores
   * it in `scalarData`, `dimensions` and `origin`.
   *
   * Each segment is drawn with Bresenham3D. In COLORMAP mode a voxel takes the
   * colour of the last segment drawn through it. Otherwise the absolute per-axis
   * extent of every segment through the voxel is accumulated and normalized into
   * an RGB direction colour. When `overrideColor` is set, it paints every non-empty
   * voxel. Empty voxels stay all-zero (alpha 0).
   *
   * @param {number} [scale=1.0] - grid resolution multiplier.
   */
  _createVoxelVolume(scale = 1.0) {
    // Derive the grid size once. Recomputing it from this.dimensions, which this
    // method overwrites, grew the grid by one voxel per axis on every call.
    if (this.voxelGridDims === undefined) {
      this.voxelGridDims = this.dimensions.map((d) => Math.floor(d + 1));
    }
    const sourceDims = this.voxelGridDims;
    const targetDims = sourceDims.map((d) => Math.round(d * scale));
    const [nx, ny, nz] = targetDims;
    const spacing = this.spacing.map((s) => s / scale);

    // Zero-initialized; writes outside its bounds are ignored instead of growing it.
    const voxelData = new Float64Array(4 * nx * ny * nz);

    const mapToVoxel = (v) =>
      v.map((x, i) =>
        Math.round((x * spacing[i] + sourceDims[i] / 2) / sourceDims[i] * targetDims[i])
      );

    const lineArray = this.polyData.getLines().getData();
    const vertexArray = this.polyData.getPoints().getData();
    const colorsArray = this.filterScalars;
    const useSegmentColors = this.colorMode === TRACTS_COLOR_MODE.COLORMAP;
    const voxelOf = (pointId) => {
      const o = 3 * pointId;
      return mapToVoxel([vertexArray[o], vertexArray[o + 1], vertexArray[o + 2]]);
    };

    for (let j = 0; j < lineArray.length; j += 3) {
      const v1 = voxelOf(lineArray[j + 1]);
      const v2 = voxelOf(lineArray[j + 2]);

      Bresenham3D(v1, v2, (x, y, z) => {
        // Drop voxels outside the grid rather than wrapping them into another row/slice.
        if (!(x >= 0 && x < nx && y >= 0 && y < ny && z >= 0 && z < nz)) return;
        const idx = 4 * (x + nx * (y + ny * z));

        for (let k = 0; k < 3; ++k) {
          if (useSegmentColors) {
            voxelData[idx + k] = colorsArray[j + k];
          } else {
            voxelData[idx + k] += Math.abs(v2[k] - v1[k]);
          }
        }
      });
    }

    const outputData = new Uint8Array(voxelData.length);
    for (let i = 0; i < voxelData.length; i += 4) {
      const rgb = [voxelData[i], voxelData[i + 1], voxelData[i + 2]];
      if (rgb.every((c) => c === 0)) continue;

      let color = rgb;
      if (this.overrideColor) {
        color = this.overrideColor;
      } else if (!useSegmentColors) {
        vtkMath.normalize(rgb);
        color = rgb.map((c) => Math.round(c * 255));
      }

      outputData[i] = color[0];
      outputData[i + 1] = color[1];
      outputData[i + 2] = color[2];
      outputData[i + 3] = 255;
    }

    this.scalarData = outputData;
    this.origin = targetDims.map((d) => -d / 2.0);
    this.dimensions = targetDims;
    this.metadata.minPixelValue = 0;
    this.metadata.maxPixelValue = 1;
  }

  /**
   * Wraps `scalarData` in a 4-component vtkImageData using the current dimensions,
   * spacing, direction and origin, and marks the volume ready.
   */
  _rebuildImageData() {
    const scalarArray = vtkDataArray.newInstance({
      name: 'RGBAPixels',
      numberOfComponents: 4,
      values: this.scalarData,
    });

    const imageData = vtkImageData.newInstance();
    imageData.setDimensions(this.dimensions);
    imageData.setSpacing(this.spacing);
    imageData.setDirection(this.direction);
    imageData.setOrigin(this.origin);
    imageData.getPointData().setScalars(scalarArray);

    this.imageData = imageData;
    this.isReady = true;
  }
}

TrkImageVolume.MAX_QUALITY = 4;

export default TrkImageVolume;
