import { cache, Enums } from '@cornerstonejs/core';
import vtkPlane from '@kitware/vtk.js/Common/DataModel/Plane';
import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray';
import vtkPolyData from '@kitware/vtk.js/Common/DataModel/PolyData';
import VolumeSlicesActor from './VolumeSlicesActor';
import { BaseActor, ActorsContainer, PolyDataActor } from './BaseActors';
import SliceActor from './SliceActor';
import { AXIS, TRACTS_SLICE_MODE } from '../../../../constants';
import { fromCornerstoneAxis } from '../../../../constants/axes';

import vtkColorTransferFunction from '@kitware/vtk.js/Rendering/Core/ColorTransferFunction';

const { ViewportType } = Enums;

/**
 * Pick the class used for the 3D ("volume") tracts actor for a slice mode.
 *
 * The class references are resolved at call time, so it is fine that the
 * classes are declared further down in this module.
 *
 * @param {*} mode - a TRACTS_SLICE_MODE value.
 * @returns {Function|undefined} actor class, or undefined when `mode` has no
 *   volume actor (e.g. CLUSTERED or an unknown value).
 */
const getVolumeActorConstructor = (mode) => {
  if (mode === TRACTS_SLICE_MODE.DISABLED) return TractsBaseActor;
  if (mode === TRACTS_SLICE_MODE.PRECISE) return PreciseTractsSlicesActor;
  if (mode === TRACTS_SLICE_MODE.APPROXIMATE) return ApproximateTractsSlicesActor;
  if (mode === TRACTS_SLICE_MODE.VOXELIZED) return VoxelizedTractsSlicesActor;
  return undefined;
};

/**
 * Pick the class used for a single-axis slice actor for a slice mode.
 *
 * DISABLED, CLUSTERED and VOXELIZED use the plain SliceActor. PRECISE and
 * APPROXIMATE use the tract-based plane actors defined in this module.
 *
 * @param {*} mode - a TRACTS_SLICE_MODE value.
 * @returns {Function|undefined} actor class, or undefined for an unknown mode.
 */
const getSliceActorConstructor = (mode) => {
  if (mode === TRACTS_SLICE_MODE.DISABLED || mode === TRACTS_SLICE_MODE.CLUSTERED) {
    return SliceActor;
  }
  if (mode === TRACTS_SLICE_MODE.PRECISE) {
    return TractsPlane;
  }
  if (mode === TRACTS_SLICE_MODE.APPROXIMATE) {
    return ApproximateTractsPlane;
  }
  if (mode === TRACTS_SLICE_MODE.VOXELIZED) {
    return SliceActor;
  }
  return undefined;
};

/**
 * Base actor for tract data.
 *
 * Reads `polyData` from the cached volume entry for `volumeId`, keeps it on
 * `this.polyData`, and uses it as the mapper input.
 */
class TractsBaseActor extends BaseActor {
  constructor(renderingEngine, volumeId) {
    super(renderingEngine, volumeId);

    const volumeEntry = cache.getVolume(volumeId);
    const tractsData = volumeEntry.polyData;

    this.uid = volumeId;
    this.polyData = tractsData;
    this.mapper.setInputData(tractsData);
  }
}

// ----------------------------------------------------------------------------

/**
 * Single-axis tract slice done with clipping planes ("precise" mode).
 *
 * The whole tract polyData stays in the mapper. Two opposing clipping planes
 * restrict what is drawn to a one-voxel-thick slab around the current slice
 * index along `axis`.
 */
class TractsPlane extends TractsBaseActor {
  constructor(renderingEngine, volumeId, axis) {
    super(renderingEngine, volumeId);
    this.uid = `VOLUME_${volumeId}#trk#${axis}`;
    this.index = undefined;
    this.axis = axis;
  }

  /**
   * Move the clipping slab to the slice that contains `worldPos`.
   *
   * @param {number[]} worldPos - position converted with `this.worldPosToIndex`.
   * @param {boolean} [force] - reinstall the planes even if the slice index
   *   along `axis` did not change.
   */
  setSlices(worldPos, force) {
    const { axis, dimensions, spacing } = this;
    const previousIndex = this.index;

    const sliceIndex = this.worldPosToIndex(worldPos)[axis];
    const unchanged = sliceIndex === previousIndex;
    if (unchanged && !force) {
      return;
    }
    this.index = sliceIndex;

    // Slab edges (±0.5 voxel). The `- dimensions / 2` term presumably matches
    // how the tract coordinates are centred on the volume; confirm against BaseActor.
    const lowerEdge = (sliceIndex - 0.5 - dimensions[axis] / 2) * spacing[axis];
    const upperEdge = (sliceIndex + 0.5 - dimensions[axis] / 2) * spacing[axis];

    // 3-vector that is zero except on the slice axis.
    const alongAxis = (value) => {
      const vec = [0, 0, 0];
      vec[axis] = value;
      return vec;
    };

    // One plane faces +axis at the lower edge and the other faces -axis at
    // the upper edge. Together they bound the slab from both sides.
    const lowerPlane = vtkPlane.newInstance({ normal: alongAxis(1), origin: alongAxis(lowerEdge) });
    const upperPlane = vtkPlane.newInstance({ normal: alongAxis(-1), origin: alongAxis(upperEdge) });

    // Drop the previous slab before installing the new pair.
    this.mapper.removeAllClippingPlanes();
    this.mapper.setClippingPlanes([lowerPlane, upperPlane]);
  }
}

/**
 * Single-axis tract slice built on the CPU ("approximate" mode).
 *
 * Unlike TractsPlane, this does not clip the full tract set. `setSlices`
 * builds a fresh vtkPolyData containing only the line segments that have at
 * least one endpoint inside a thin slab (±0.25 voxel) around the current
 * slice along `axis`. Kept segments are drawn whole, so they may extend past
 * the slab.
 */
class ApproximateTractsPlane extends TractsBaseActor {
  constructor(renderingEngine, volumeId, axis) {
    super(renderingEngine, volumeId);
    this.uid = `VOLUME_${volumeId}#approx#${axis}`;
    this.index = undefined;
    this.axis = axis;
  }

  /**
   * Rebuild the displayed segments for the slice that contains `worldPos`.
   *
   * The lines array is walked as a cell connectivity list
   * (`[n, id0, ..., id(n-1), n, ...]`), so lines with any number of points
   * are handled. Each consecutive pair of points is tested as its own segment
   * and inherits its line's cell colour. For input made entirely of 2-point
   * lines this is exactly one test per line.
   *
   * @param {number[]} worldPos - position converted with `this.worldPosToIndex`.
   * @param {boolean} [force] - rebuild even if the slice index along `axis`
   *   did not change.
   */
  setSlices(worldPos, force) {
    // NOTE(review): `dimensions` / `spacing` come from the base actor (not visible here).
    const { polyData, dimensions, spacing, axis } = this;

    const sliceIndex = this.worldPosToIndex(worldPos)[axis];
    if(!force && this.index === sliceIndex) return;
    this.index = sliceIndex;

    // Slab edges along `axis`, using the same centring as TractsPlane.
    // Only this axis matters, so the other axes' bounds are not computed.
    const lower = (sliceIndex - 0.25 - dimensions[axis] / 2) * spacing[axis];
    const upper = (sliceIndex + 0.25 - dimensions[axis] / 2) * spacing[axis];

    const tractCells   = polyData.getLines().getData();
    const tractPoints  = polyData.getPoints().getData();
    const tractScalars = polyData.getCellData().getScalars();
    const tractColors  = tractScalars.getData();
    // One colour tuple per line cell. Its width is read from the input instead of assuming RGB.
    const nComponents  = tractScalars.getNumberOfComponents();

    // Points are stored as flat xyz triplets.
    const isInSlab = (pointId) => {
      const coord = tractPoints[3 * pointId + axis];
      return lower <= coord && coord <= upper;
    };

    const sliceCells  = [];
    const sliceColors = [];

    // `offset` points at each cell's point count; `cellId` indexes the per-cell colours.
    for(let offset = 0, cellId = 0; offset < tractCells.length; offset += tractCells[offset] + 1, ++cellId) {
      const lastPointPos = offset + tractCells[offset];
      for(let j = offset + 1; j < lastPointPos; ++j) {
        const p0 = tractCells[j];
        const p1 = tractCells[j + 1];
        if(isInSlab(p0) || isInSlab(p1)) {
          sliceCells.push(2, p0, p1);
          const colorStart = cellId * nComponents;
          for(let c = 0; c < nComponents; ++c) {
            sliceColors.push(tractColors[colorStart + c]);
          }
        }
      }
    }

    const colorsArray = vtkDataArray.newInstance({
      numberOfComponents: nComponents,
      values: Uint8Array.from(sliceColors),
    });
    colorsArray.setName('color');

    // Reuse the full point array; only the connectivity and colours are filtered.
    const slicesPolyData = vtkPolyData.newInstance();
    slicesPolyData.getPoints().setData(tractPoints);
    slicesPolyData.getLines().setData(Uint32Array.from(sliceCells));
    slicesPolyData.getCellData().setScalars(colorsArray);

    this.mapper.setInputData(slicesPolyData);
  }
}

// ----------------------------------------------------------------------------

/**
 * Container of three clipping-based tract slices (TractsPlane), one per axis
 * X, Y and Z, in that order.
 */
class PreciseTractsSlicesActor extends ActorsContainer {
  constructor(renderingEngine, volumeId) {
    const planes = [AXIS.X, AXIS.Y, AXIS.Z]
      .map(axis => new TractsPlane(renderingEngine, volumeId, axis));
    super(planes);
  }
}

// ----------------------------------------------------------------------------

/**
 * Container of three CPU-filtered tract slices (ApproximateTractsPlane), one
 * per axis X, Y and Z, in that order.
 */
class ApproximateTractsSlicesActor extends ActorsContainer {
  constructor(renderingEngine, volumeId) {
    const planes = [AXIS.X, AXIS.Y, AXIS.Z]
      .map(axis => new ApproximateTractsPlane(renderingEngine, volumeId, axis));
    super(planes);
  }
}
/*
class ApproximateTractsSlicesActor extends TractsBaseActor {
  constructor(renderingEngine, volumeId) {
    super(renderingEngine, volumeId);
    this.uid = `VOLUME_${volumeId}#approxSlices`;
  }

  setSlices(worldPos) {
    const { polyData, dimensions, spacing } = this;

    const newIdx = this.worldPosToIndex(worldPos);
    const bounds = [0, 0, 0, 0, 0, 0].map((_, i) => {
      const axis = Math.trunc(i / 2);
      const margin = 0.25 * (i % 2 ? 1 : -1);
      return (newIdx[axis] + margin - dimensions[axis] / 2) * spacing[axis];
    });

    const tractCells  = polyData.getLines().getData();
    const tractPoints = polyData.getPoints().getData();
    const tractColors = polyData.getCellData().getScalars().getData();
    const sliceCells  = [];
    const sliceColors = [];

    const checkPoint = (index) => {
      const i = 3 * tractCells[index];
      for(let j = 0; j < 3; ++j) {
        if(bounds[2 * j] <= tractPoints[i + j] && bounds[2 * j + 1] >= tractPoints[i + j]) return true;
      }
      return false;
    }

    for(let i = 0; i < tractCells.length; i += 3) {
      if(checkPoint(i + 1) || checkPoint(i + 2)) {
        sliceCells.push(2, tractCells[i + 1], tractCells[i + 2]);
        sliceColors.push(tractColors[i], tractColors[i + 1], tractColors[i + 2]);
      }
    }

    const colorsArray = vtkDataArray.newInstance({
      numberOfComponents: 3,
      values: Uint8Array.from(sliceColors),
    });
    colorsArray.setName('color');

    const slicesPolyData = vtkPolyData.newInstance();
    slicesPolyData.getPoints().setData(tractPoints);
    slicesPolyData.getLines().setData(Uint32Array.from(sliceCells));
    slicesPolyData.getCellData().setScalars(colorsArray);

    this.mapper.setInputData(slicesPolyData);
  }
}
*/
// ----------------------------------------------------------------------------

/**
 * Container of two VolumeSlicesActor copies for the same volume.
 *
 * Copy 0 is shifted by +2 × spacing along every axis and copy 1 by
 * -2 × spacing. Each inner actor's uid gets a `#<copy index>` suffix so the
 * two copies do not collide. (Why the shift is needed is not visible here;
 * presumably it keeps the copies off the slice planes. Confirm.)
 */
class VoxelizedTractsSlicesActor extends ActorsContainer {
  constructor(renderingEngine, volumeId) {
    super([
      new VolumeSlicesActor(renderingEngine, volumeId),
      new VolumeSlicesActor(renderingEngine, volumeId),
    ]);

    const { spacing } = cache.getVolume(volumeId);

    this.actors.forEach((container, copyIdx) => {
      const sign = copyIdx % 2 ? -1 : 1;
      const shift = spacing.map(step => sign * step * 2);

      container.actors.forEach(sliceActor => {
        sliceActor.uid = `${sliceActor.uid}#${copyIdx}`;
        sliceActor.actor.setPosition(...shift);
      });
    });
  }
}

// ----------------------------------------------------------------------------

/**
 * Tracts actor that builds its per-axis slice actors and its 3D actor from
 * the current TRACTS_SLICE_MODE.
 *
 * The 3D actor (`this.volume`, managed by the base class) is rebuilt and
 * re-attached to every VOLUME_3D viewport whenever the mode changes or
 * `modified()` is called.
 */
class TractsSlicesCombined extends PolyDataActor {
  constructor(renderingEngine, volumeId, mode) {
    super(renderingEngine, volumeId);
    this.mode = mode;
    this.cameraPos = undefined;
  }

  /**
   * Create the slice actor for `axis` according to `this.mode`.
   *
   * @param {*} axis - axis identifier passed through to the actor constructor.
   * @returns {object} the new slice actor.
   * @throws {Error} when the mode has no slice actor class.
   */
  createSliceActor(axis) {
    const { renderingEngine, volumeId, mode } = this;

    const SliceActorClass = getSliceActorConstructor(mode);
    if (!SliceActorClass) {
      throw new Error(`Invalid tracts slices mode: ${mode}`);
    }

    const sliceActor = new SliceActorClass(renderingEngine, volumeId, axis);

    // These modes use nearest-neighbour interpolation. The tract-based
    // slice actors keep their default properties.
    const nearestModes = [
      TRACTS_SLICE_MODE.DISABLED,
      TRACTS_SLICE_MODE.VOXELIZED,
      TRACTS_SLICE_MODE.CLUSTERED,
    ];
    if (nearestModes.some(m => m === mode)) {
      sliceActor.setProperty(p => p.setInterpolationTypeToNearest());
    }

    return sliceActor;
  }

  /**
   * Create the 3D tracts actor according to `this.mode`.
   * Lighting is enabled for every mode except VOXELIZED.
   *
   * @returns {object} the new actor.
   * @throws {Error} when the mode has no 3D actor class (e.g. CLUSTERED).
   */
  createVolumeActor() {
    const { volumeId, renderingEngine, mode } = this;

    const VolumeActorClass = getVolumeActorConstructor(mode);
    if (!VolumeActorClass) {
      throw new Error(`Invalid tracts slices mode: ${mode}`);
    }

    const volumeActor = new VolumeActorClass(renderingEngine, volumeId);
    if (mode === TRACTS_SLICE_MODE.VOXELIZED) {
      return volumeActor;
    }

    volumeActor.setProperty(p => p.setLighting(true));
    return volumeActor;
  }

  /** Switch to `mode` and rebuild the 3D actor. */
  setTractsMode(mode) {
    this.mode = mode;
    this.updateVolumeActor();
  }

  /** Propagate to the base class, then rebuild the 3D actor. */
  modified() {
    super.modified();
    this.updateVolumeActor();
  }

  /**
   * Detach the current 3D actor from all VOLUME_3D viewports. In CLUSTERED
   * mode it is left detached (not deleted). Otherwise it is deleted,
   * recreated for the current mode, and attached to those viewports again.
   * Does nothing while `this.volume` is not set.
   */
  updateVolumeActor() {
    if (!this.volume) return;

    const volume3dViewports = this.renderingEngine
      .getViewports()
      .filter(viewport => viewport.type === ViewportType.VOLUME_3D);

    for (const viewport of volume3dViewports) {
      this.volume.removeFromViewport(viewport.id);
    }

    if (this.mode === TRACTS_SLICE_MODE.CLUSTERED) return;

    this.volume.delete();
    this.volume = this.createVolumeActor();

    for (const viewport of volume3dViewports) {
      this.volume.addToViewport(viewport.id);
    }
  }
}

export { TRACTS_SLICE_MODE };
export default TractsSlicesCombined;
