import { createSlice } from '@reduxjs/toolkit';
import {
  RenderingEngine,
  setVolumesForViewports,
  getRenderingEngine,
  cache
} from '@cornerstonejs/core';
import * as cornerstoneTools from '@cornerstonejs/tools';
import vtkPiecewiseFunction from '@kitware/vtk.js/Common/DataModel/PiecewiseFunction';
import { initializeViewer } from '../viewer/viewer.slice';
import AXIS, { toCornerstoneAxis } from '../../constants/axes';

// cornerstone tools APIs used throughout this slice.
const {
  Enums: { SegmentationRepresentations },
  utilities,
  ToolGroupManager,
  SegmentationDisplayTool,
  segmentation,
} = cornerstoneTools;

// Ids under which the QA rendering engine and tool group are registered with cornerstone.
const RENDERING_ENGINE_ID = 'RENDERING_ENGINE_QA';
const QA_TOOL_GROUP_ID = 'QA_TOOL_GROUP_ID';

// Scalar opacity function shared by every QA volume actor (installed by
// initActor); its control points are rewritten by the setVolume and
// setConfig reducers whenever the background threshold changes.
const backgroundSof = (() => {
  const sof = vtkPiecewiseFunction.newInstance();
  sof.setClamping(false);
  return sof;
})();

// Default state of the `qaviewer` slice.
const initialState = {
  // cornerstone RenderingEngine; created when initializeViewer is fulfilled.
  renderingEngine: null,
  // Last error stored via setError; cleared by clearScene.
  error: undefined,
  // QA tool group; created alongside the rendering engine.
  tools: undefined,
  // Grid layout: rows x cols viewports, one slice each.
  rows: 2,
  cols: 5,
  // Axis along which slices are taken (one of the AXIS constants).
  axis: AXIS.Z,
  // Distance (in slices) between neighbouring viewports; recomputed by
  // calcSlices when autoStep is true.
  step: 2,
  autoStep: true,
  // Slice the grid is centred on; set by setVolume / setConfig.
  refSlice: undefined,
  // Per-viewport descriptors: { id, axis, ready }.
  viewports: [],
  // cornerstone cache ids of the displayed volume and segmentation.
  volumeId: undefined,
  segmentationId: undefined,
  // Background opacity threshold and the pixel range it may take
  // (presumably used to bound a UI control — confirm against caller).
  bkgThreshold: undefined,
  bkgRange: [0, 255]
};

/**
 * Detaches every segmentation representation currently registered on a tool group.
 *
 * @param {string} toolGroupId - tool group to clear.
 * @param {boolean} [rerender=false] - forwarded to
 *   `segmentation.removeSegmentationsFromToolGroup`.
 */
const removeAllSegmentations = (toolGroupId, rerender = false) => {
  const representations = segmentation.state.getSegmentationRepresentations(toolGroupId);
  if(representations == null) return;

  const representationUIDs = representations.map(
    (representation) => representation.segmentationRepresentationUID
  );
  segmentation.removeSegmentationsFromToolGroup(toolGroupId, representationUIDs, rerender);
};

/**
 * Computes the slice index to show in each of the `rows * cols` grid viewports.
 *
 * Slices are spaced `step` apart, centred on `state.refSlice`. When
 * `state.autoStep` is set (or no step is configured) the step is derived from
 * the distance between the reference slice and the nearer volume boundary
 * (minus a fixed margin), and written back to `state.step`.
 *
 * For every axis except X the index is mirrored (`dimension - 1 - idx`);
 * presumably this matches the viewport orientation — TODO confirm.
 *
 * @param {object} state - qaviewer slice state (an Immer draft when called from
 *   a reducer); `state.step` may be updated in place.
 * @returns {number[]|undefined} one index per grid cell, each clamped to
 *   `[0, dimension - 1]`, or `undefined` when there is no rendering engine,
 *   no volume id, or the volume is not in the cache.
 */
const calcSlices = (state) => {
  if(!state.renderingEngine || !state.volumeId) return;

  const volume = cache.getVolume(state.volumeId);
  if(!volume) return;

  const { rows, cols, axis, refSlice } = state;
  const count = rows * cols;
  const dim = volume.dimensions[axis];
  const lastIdx = dim - 1;

  // Keeps auto-computed slices away from the volume edges.
  const sliceMargin = 12;

  let step = state.step;
  if(state.step === undefined || state.autoStep) {
    step = Math.round(Math.min(refSlice - sliceMargin, dim - refSlice - sliceMargin) / count * 2);
    step = Math.max(step, 1);
    state.step = step;
  }

  // Centre the run of slices on refSlice; if that would start before slice 0,
  // start at the lowest slice that is a whole number of steps below refSlice.
  let startIdx = refSlice - step * Math.floor(count / 2);
  if(startIdx < 0) {
    startIdx = refSlice - step * Math.floor(refSlice / step);
  }

  // Both directions must stay inside the volume: the mirrored branch can go
  // below 0 and the unmirrored (X) branch can run past the last slice.
  const clamp = (idx) => Math.min(lastIdx, Math.max(0, idx));

  const res = [];
  for(let i = 0; i < count; ++i) {
    const idx = startIdx + step * i;
    res.push(clamp(axis !== AXIS.X ? lastIdx - idx : idx));
  }

  return res;
};

/**
 * Moves every grid viewport to its slice and re-attaches the segmentation overlay.
 *
 * Segmentation affects slice indexes, and the volumeId in jumpToSlice options
 * is ignored. So the segmentation representation is removed and recreated
 * every time the slices change.
 *
 * @param {string[]} viewportIds - viewport ids in grid order.
 * @param {number[]|undefined} sliceIdxs - slice index per viewport, as returned
 *   by calcSlices. When missing (no volume / rendering engine yet) this is a no-op.
 * @param {string} segmentationId - segmentation shown as a labelmap overlay.
 * @param {*} axis - one of the AXIS constants; selects which flip is applied.
 */
const resetSlices = (viewportIds, sliceIdxs, segmentationId, axis) => {
  // calcSlices returns undefined until a volume is loaded. Indexing into it
  // below would throw (e.g. when setConfig runs before any volume is set).
  if(!sliceIdxs) return;

  const re = getRenderingEngine(RENDERING_ENGINE_ID);
  if(!re) return;

  removeAllSegmentations(QA_TOOL_GROUP_ID, true);

  viewportIds.forEach((id, i) => {
    re.getViewport(id).resetCamera();
    utilities.viewport.jumpToSlice(
      re.getViewport(id).element,
      { imageIndex: sliceIdxs[i] }
    );
  });

  // Flip and render after a short delay. Presumably this lets the slice jumps
  // above take effect first — TODO confirm the 5 ms delay is sufficient.
  setTimeout(() => {
    viewportIds.forEach((id) => {
      if(axis === AXIS.Z) re.getViewport(id).flip({ flipVertical: true });
      else if(axis === AXIS.X) re.getViewport(id).flip({ flipHorizontal: true });
    });
    re.renderViewports(viewportIds);
  }, 5);

  // Re-add the labelmap: segment 1 drawn yellow with a faint fill.
  segmentation.addSegmentationRepresentations(QA_TOOL_GROUP_ID, [
    {
      segmentationId,
      type: SegmentationRepresentations.Labelmap,
    },
  ]).then(([ segRepId ]) => {
    segmentation.config.color.setColorForSegmentIndex(
      QA_TOOL_GROUP_ID,
      segRepId,
      1,
      [255, 255, 0, 255]
    );
    segmentation.config.setSegmentationRepresentationSpecificConfig(
      QA_TOOL_GROUP_ID,
      segRepId,
      { LABELMAP: { fillAlpha: 0.1, outlineWidthActive: 1 }}
    );
  });
};

/**
 * Volume-actor callback for setVolumesForViewports. It maps colours over the
 * volume's pixel range, installs the shared background opacity function,
 * enables linear interpolation and switches the mapper to maximum-intensity
 * blending. Does nothing when the volume is not in the cache.
 *
 * @param {{ volumeActor: object, volumeId: string }} params - actor and the
 *   cache id of the volume it renders.
 */
const initActor = ({ volumeActor, volumeId }) => {
  const volume = cache.getVolume(volumeId);
  if(volume) {
    const { minPixelValue: low, maxPixelValue: high } = volume.metadata;
    const property = volumeActor.getProperty();
    property.getRGBTransferFunction(0).setMappingRange(low, high);
    property.setScalarOpacity(0, backgroundSof);
    property.setInterpolationTypeToLinear();
    volumeActor.getMapper().setBlendModeToMaximumIntensity();
  }
};

/**
 * Redux slice backing the QA viewer: a `rows` x `cols` grid of viewports, each
 * showing one slice of the same volume along `axis`, with an optional labelmap
 * segmentation drawn on top.
 *
 * NOTE(review): several reducers perform side effects (cornerstone rendering,
 * segmentation state, DOM download). The state also holds non-serializable
 * objects (RenderingEngine, tool group). Redux expects pure reducers and
 * serializable state; consider moving these effects to thunks or listener
 * middleware.
 */
export const qaViewerSlice = createSlice({
  name: 'qaviewer',
  initialState,
  reducers: {
    // payload: viewport id. Marks the viewport ready and adds it to the QA tool group.
    setViewportReady: (state, action) => {
      const viewport = state.viewports.find(v => v.id === action.payload);
      if(viewport) {
        viewport.ready = true;
        state.tools.addViewport(viewport.id, RENDERING_ENGINE_ID);
      }
    },
    // payload: error value to store (cleared by clearScene).
    setError: (state, action) => {
      state.error = action.payload;
    },
    // payload: id of a volume already in the cornerstone cache.
    // - Centres refSlice on the current axis.
    // - Derives bkgRange from the volume's pixel range (lower bound at least 1).
    // - On the first volume only: seeds bkgThreshold at 30% of the max pixel
    //   value and rebuilds backgroundSof with points (0 -> 0), (threshold -> 1).
    setVolume: (state, action) => {
      const vol = cache.getVolume(action.payload);
      state.volumeId = action.payload;
      state.refSlice = Math.floor((vol.dimensions[state.axis] - 1) / 2);
      state.bkgRange = [ Math.max(vol.metadata.minPixelValue, 1), vol.metadata.maxPixelValue ];
      if(state.bkgThreshold === undefined) {
        state.bkgThreshold = Math.floor(vol.metadata.maxPixelValue * 0.3);
        backgroundSof.removeAllPoints();
        backgroundSof.addPointLong(0, 0.0, 1.0, 1.0);
        backgroundSof.addPointLong(state.bkgThreshold, 1.0, 1.0, 1.0);
      }
    },
    // payload: volume id to use as a labelmap segmentation. Detaches the
    // representations currently on the QA tool group, then registers the
    // segmentation with cornerstone unless it is already registered.
    setSegmentation: (state, action) => {
      if(state.segmentationId) removeAllSegmentations(QA_TOOL_GROUP_ID);

      const volumeId = action.payload;
      state.segmentationId = volumeId;
      if(segmentation.state.getSegmentation(volumeId)) return;

      segmentation.addSegmentations([{
        segmentationId: volumeId,
        representation: {
          type: SegmentationRepresentations.Labelmap,
          data: { volumeId }
        }
      }]);
    },
    // Forgets the current volume, segmentation and error, removes all actors
    // from the grid viewports and all segmentation representations.
    clearScene: (state) => {
      state.volumeId = undefined;
      state.segmentationId = undefined;
      state.error = undefined;
      state.viewports.forEach(({ id }) => {
        state.renderingEngine.getViewport(id)?.removeAllActors();
      });
      removeAllSegmentations(QA_TOOL_GROUP_ID, true);
    },
    // Loads the current volume into every grid viewport. Once loaded, it moves
    // them to their slices and re-adds the segmentation. No-op without a
    // rendering engine or a cached volume.
    createScene: (state) => {
      const volume = cache.getVolume(state.volumeId);

      if(!state.renderingEngine || !volume) return;
      // Copy what the async callback needs: the Immer draft `state` must not be
      // read after the reducer returns.
      const viewportIds = state.viewports.map(v => v.id);
      const segmentationId = state.segmentationId;
      const axis = state.axis;
      const sliceIdxs = calcSlices(state);

      state.renderingEngine.resize(false, true);

      setVolumesForViewports(
        state.renderingEngine,
        [ { volumeId: state.volumeId, callback: initActor } ],
        viewportIds
      ).then(() => {
        resetSlices(viewportIds, sliceIdxs, segmentationId, axis);
      });
    },
    // payload: file name without extension.
    // - Tiles the grid viewports' canvases into one PNG.
    // - Labels each tile with its slice position.
    // - Triggers a download as `<payload>.png`, or `image.png` when no name is given.
    saveScene: (state, action) => {
      if(!state.volumeId) return;

      const { canvas: baseCanvas } = state.renderingEngine.getViewport(state.viewports[0].id);
      const cw = baseCanvas.width;
      const ch = baseCanvas.height;
      const vol = cache.getVolume(state.volumeId);

      const tempCanvas = document.createElement("canvas");
      tempCanvas.width = cw * state.cols;
      tempCanvas.height = ch * state.rows;

      const ctx = tempCanvas.getContext('2d');
      ctx.fillStyle = 'black';
      ctx.fillRect(0, 0, tempCanvas.width, tempCanvas.height);

      ctx.fillStyle = "#9ccef9";
      ctx.font = '18px sans-serif';

      let i = 0;
      for(let r = 0; r < state.rows; ++r) {
        for(let c = 0; c < state.cols; ++c) {
          const vp = state.renderingEngine.getViewport(state.viewports[i].id);
          const dim = vol.dimensions[state.viewports[i].axis];
          const spacing = vol.spacing[state.viewports[i].axis];
          // Offset of the shown slice from the volume centre, scaled by voxel spacing.
          const sliceIdx = Math.round((vp.getCurrentImageIdIndex() - dim / 2) * spacing);

          const text = `z = ${sliceIdx}`;
          const tdx = cw / 2 - ctx.measureText(text).width / 2;
          const tdy = 18 + 10; // font size + margin

          ctx.drawImage(vp.canvas, cw * c, ch * r);
          ctx.fillText(text, cw * c + tdx, ch * r + tdy, cw);
          ++i;
        }
      }

      // Relabel the data URL's MIME type (presumably so browsers download rather than display it).
      const image = tempCanvas.toDataURL('image/png').replace("image/png", "image/octet-stream");

      const link = document.createElement('a');
      // Apply the fallback to the name itself: `payload + '.png'` is always a
      // truthy string, so a trailing `|| 'image.png'` could never take effect.
      link.download = `${action.payload || 'image'}.png`;
      link.href = image;
      link.click();

      link.remove();
      tempCanvas.remove();
    },
    // Applies grid, slice and background settings.
    // - Rebuilds backgroundSof when the threshold changes.
    // - Re-orients every viewport when the axis changes.
    // - Grows the viewport list, or shrinks it and repositions the remaining viewports.
    setConfig: {
      reducer: (state, action) => {
        const { rows, cols, step, autoStep, axis, refSlice, bkgThreshold } = action.payload;
        state.rows = rows;
        state.cols = cols;
        state.step = step;
        state.autoStep = autoStep;
        state.refSlice = refSlice;
        if(state.bkgThreshold !== bkgThreshold) {
          backgroundSof.removeAllPoints();
          if(bkgThreshold > 0) backgroundSof.addPointLong(0, 0.0, 1.0, 1.0);
          backgroundSof.addPointLong(bkgThreshold, 1, 1.0, 1.0);
          state.bkgThreshold = bkgThreshold;
        }

        if(axis !== state.axis) {
          state.axis = axis;
          state.viewports.forEach(v => {
            v.axis = axis;
            v.ready = false;
            state.renderingEngine.getViewport(v.id).setOptions({
              orientation: toCornerstoneAxis(axis)
            }, false);
          });
        }

        if(state.viewports.length < rows * cols) {
          for(let i = state.viewports.length; i < rows * cols; ++i) {
            state.viewports.push({
              id: `vp-${i}`,
              axis: state.axis,
              ready: false
            });
          }
        } else {
          state.viewports.splice(rows * cols, state.viewports.length - rows * cols);
          // calcSlices returns undefined until a volume is loaded; in that case
          // there is nothing to reposition, and resetSlices would throw.
          const sliceIdxs = calcSlices(state);
          if(sliceIdxs) {
            resetSlices(
              state.viewports.map(v => v.id),
              sliceIdxs,
              state.segmentationId,
              state.axis
            );
          }
        }
      },
      prepare: (rows, cols, step, autoStep, axis, refSlice, bkgThreshold) =>
        ({ payload: { rows, cols, step, autoStep, axis, refSlice, bkgThreshold }})
    },
  },
  extraReducers: (builder) => {
    // Once the viewer is initialized: create the QA rendering engine, the tool
    // group with SegmentationDisplayTool enabled, and the initial grid of
    // viewport descriptors.
    builder.addCase(initializeViewer.fulfilled, (state, action) => {
      state.renderingEngine = new RenderingEngine(RENDERING_ENGINE_ID);

      const tools = ToolGroupManager.createToolGroup(QA_TOOL_GROUP_ID);
      tools.addTool(SegmentationDisplayTool.toolName);
      tools.setToolEnabled(SegmentationDisplayTool.toolName);
      state.tools = tools;

      for(let i = 0; i < state.rows * state.cols; ++i) {
        state.viewports.push({ id: `vp-${i}`, axis: state.axis, ready: false });
      }
    });
  }
});

// Action creators generated for each reducer, plus the slice reducer itself.
export const {
  clearScene,
  createScene,
  saveScene,
  setConfig,
  setError,
  setSegmentation,
  setViewportReady,
  setVolume,
} = qaViewerSlice.actions;
export default qaViewerSlice.reducer;
