import { createSlice } from '@reduxjs/toolkit';
import { cache } from '@cornerstonejs/core';
import AXIS, { toCornerstoneAxis } from '../../constants/axes';

// Marker radii (canvas pixels) for drawn points: non-NaN values are scaled with
// |value| between R_MIN and R_MAX; points with a NaN value (the source point) use R_NAN.
// R_MAX is also the hit-test tolerance used when picking a point by click.
const R_MIN = 3;
const R_MAX = 7;
const R_NAN = 5;

// Initial Redux state of the viewer slice.
const initialState = {
  error: undefined,         // set via setError
  rows: 3,                  // mosaic grid rows
  cols: 5,                  // mosaic grid columns
  axis: AXIS.Z,             // slicing axis of the mosaic
  step: 2,                  // slice spacing between tiles (recomputed when autoStep is on)
  autoStep: true,           // derive step from refSlice and the volume depth
  roi: undefined,           // index into the points data (`cache.getVolume(pointsId).data`)
  p: 0.05,                  // points whose p-value is >= this threshold are not drawn
  filterArray: 1,           // which p-value array of the ROI entry is used for filtering
  refSlice: undefined,      // slice containing the source point; centre of the mosaic
  volumeId: undefined,      // cornerstone cache id of the image volume
  pointsId: undefined,      // cornerstone cache id of the volume holding the point data
  points: undefined,        // points drawn by the last render, in canvas coordinates
  selectedPoint: undefined, // `idx` of the point picked via setSelectedPoint
  imgWidth: undefined,      // width of the last rendered mosaic
  imgHeight: undefined,     // height of the last rendered mosaic
  output: undefined         // canvas element the mosaic is drawn into
};

/*
const rotate = (p, a) => {
  return [
    p[0] * Math.cos(a) - p[1] * Math.sin(a),
    p[0] * Math.sin(a) + p[1] * Math.cos(a)
  ];
}

const translate = (p, v) => {
  return [ p[0] + v[0], p[1] + v[1] ];
}

const generateCurve = (srcPoint, dstPoint) => {
  const p = [dstPoint[0] - srcPoint[0], dstPoint[1] - srcPoint[1]];
  const a = -Math.atan2(p[1], p[0]);

  const h = srcPoint[0] < dstPoint[0] ? -30 : 30;
  const v = p[0] * Math.cos(a) - p[1] * Math.sin(a);
  let c1 = translate(rotate([ 90 * v / 100, h ], -a), srcPoint);
  let c2 = translate(rotate([ 95 * v / 100, -h], -a), srcPoint);

  return [c1[0], c1[1], c2[0], c2[1], dstPoint[0], dstPoint[1]];
}
*/

/**
 * Render a rows x cols mosaic of slices along `state.axis` into the `state.output`
 * canvas and overlay the points of the current ROI.
 *
 * The source point (the entry whose value is NaN) is joined by a Bezier curve to every
 * other point whose p-value in `p[filterArray]` is below `state.p` and which lies in one
 * of the displayed slices. Curves are red for a non-negative `yMean` and blue otherwise;
 * point markers are sized by |value| and coloured by its sign.
 *
 * Mutates `state`: `refSlice`, `step` (via calcSlices when auto-stepping), `points`,
 * `imgWidth` and `imgHeight`; it also resizes and redraws the output canvas. It does
 * nothing when the volume, points data, output canvas, ROI entry or reference slice is
 * unavailable.
 *
 * @param {object} state - viewer slice state (an Immer draft when called from a reducer)
 */
const generateImage = (state) => {
  if(!state.volumeId || !state.pointsId || !state.output) return;

  const vol = cache.getVolume(state.volumeId);
  const points = cache.getVolume(state.pointsId)?.data;
  if(!vol || !points) return;

  const { axis, rows, cols, roi, filterArray } = state;
  const roiData = points[roi];
  if(!roiData) return;

  // The reference slice is the one containing the entry whose p-value is NaN
  // (presumably the ROI itself). findIndex reports "not found" as -1, not undefined.
  const roiIdx = roiData.p[filterArray].findIndex(p => isNaN(p));
  if(roiIdx !== -1) {
    const pos = roiData.pos[roiIdx][axis];
    state.refSlice = Math.trunc(pos / vol.spacing[axis] + vol.dimensions[axis] / 2);
  }
  // Without a reference slice calcSlices would produce NaN slice indices.
  if(state.refSlice === undefined) return;

  const slices = calcSlices(state);
  if(!slices) return;

  // In-plane axes for the slicing axis: ax maps to canvas x, ay to canvas y.
  let ax;
  let ay;
  switch(axis) {
    case AXIS.Z:
      ax = AXIS.X;
      ay = AXIS.Y;
      break;
    case AXIS.Y:
      ax = AXIS.X;
      ay = AXIS.Z;
      break;
    case AXIS.X:
      ax = AXIS.Y;
      ay = AXIS.Z;
      break;
    default:
      return;
  }

  const sx = vol.dimensions[ax];
  const sy = vol.dimensions[ay];
  const cw = sx * cols;
  const ch = sy * rows;

  const tempCanvas = document.createElement("canvas");
  tempCanvas.width = cw;
  tempCanvas.height = ch;

  const ctx = tempCanvas.getContext('2d');

  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, cw, ch);

  const imageData = ctx.getImageData(0, 0, cw, ch);
  const data = imageData.data;

  // Copy each slice into its tile as greyscale (alpha stays 255 from the black fill).
  for(let r = 0; r < rows; ++r) {
    for(let c = 0; c < cols; ++c) {
      const { slice } = vol.createSlice(axis, slices[c + r * cols]);
      const tileX = sx * c;
      const tileY = sy * r;
      slice.forEach((val, idx) => {
        // Rows are flipped vertically: source row j lands on tile row sy - 1 - j, which
        // keeps every row inside its own tile and matches the point y mapping below.
        // NOTE(review): the original author flagged this flip as correct for the Z axis only.
        const x = tileX + idx % sx;
        const y = tileY + sy - 1 - Math.trunc(idx / sx);
        const i = 4 * (x + y * cw);
        data[i] = val;
        data[i + 1] = val;
        data[i + 2] = val;
      });
    }
  }

  ctx.putImageData(imageData, 0, 0);

  // Collect the points lying inside one of the displayed slices, in canvas coordinates.
  const dstPoints = [];
  let srcPoint;

  const spacing = vol.spacing[axis];
  const halfDepth = vol.dimensions[axis] * spacing / 2;
  roiData.pos.forEach((p, pi) => {
    // NaN p-values compare false here, so the source point is never filtered out.
    if(roiData.p[filterArray][pi] >= state.p) return;
    const roiPos = p[axis] + halfDepth;
    slices.forEach((s, i) => {
      if(roiPos >= spacing * (s - 0.5) && roiPos <= spacing * (s + 0.5)) {
        const x = (i % cols) * sx + p[ax] / vol.spacing[ax] + vol.dimensions[ax] / 2;
        const y = Math.trunc(i / cols) * sy + sy - (p[ay] / vol.spacing[ay] + vol.dimensions[ay] / 2);
        const val = roiData.vals[pi];
        const mean = roiData.yMean[pi];

        if(isNaN(val)) srcPoint = { x, y, val: NaN, idx: pi };
        else dstPoints.push({ x, y, val, mean, idx: pi });
      }
    });
  });

  // Draw curves from the source point to every other point.
  ctx.lineWidth = 2;
  if(srcPoint && dstPoints.length > 0) {
    // Tile column/row of the source point (loop invariant).
    const ixs = Math.floor(srcPoint.x / sx);
    const iys = Math.floor(srcPoint.y / sy);

    dstPoints.forEach(p => {
      const ixp = Math.floor(p.x / sx);
      const iyp = Math.floor(p.y / sy);

      const isLeftHemi = p.x < (ixp + 0.5) * sx;

      let cp1;
      let cp2;
      if(ixp < ixs) {
        // target tile is left of the source tile
        if(isLeftHemi) {
          cp1 = [srcPoint.x, (iys - 0.25) * sy];
          cp2 = [ (ixp - 1) * sx, (iys + (iyp < iys ? 0.5 : -0.5) * (iys - iyp)) * sy];
        } else {
          cp1 = [ (ixp + 0.5) * sx, iys * sy ];
          cp2 = [ (ixp + 1.5) * sx, p.y ];
        }
      } else {
        // target tile is the source tile or to its right
        if(isLeftHemi) {
          cp1 = [ (ixp + 0.5) * sx, iys * sy ];
          cp2 = [ (ixp - 0.5) * sx, p.y ];
        } else {
          cp1 = [srcPoint.x, (iys - 0.25) * sy];
          cp2 = [ (ixp + 1) * sx, (iys + (iyp < iys ? 0.5 : -0.5) * (iys - iyp)) * sy];
        }
      }

      ctx.strokeStyle = p.mean >= 0 ? "red" : "blue";
      ctx.beginPath();
      ctx.moveTo(srcPoint.x, srcPoint.y);
      ctx.bezierCurveTo(...cp1, ...cp2, p.x, p.y);
      ctx.stroke();
    });
  }

  // Draw the point markers.
  const { valMax, valMin } = roiData;
  const valRange = valMax - valMin;

  const getRadius = v => {
    if(isNaN(v)) return R_NAN;
    // Guard a degenerate (zero or missing) range and clamp so the radius is always
    // finite and within [R_MIN, R_MAX]; canvas arc()/createRadialGradient() reject
    // negative or non-finite radii.
    const t = valRange > 0
      ? Math.min(Math.max((Math.abs(v) - valMin) / valRange, 0), 1)
      : 0;
    return t * (R_MAX - R_MIN) + R_MIN;
  };

  // The source point is drawn last (on top); it is absent when not in any displayed slice.
  if(srcPoint) dstPoints.push(srcPoint);
  dstPoints.forEach(({ x, y, val }) => {
    const r = getRadius(val);

    const gradient = ctx.createRadialGradient(x, y, 1, x, y, r);
    gradient.addColorStop(0, "white");
    gradient.addColorStop(1, isNaN(val) ? "black" : (val > 0 ? "red" : "blue"));

    ctx.fillStyle = gradient;
    ctx.beginPath();
    ctx.arc(x, y, r, 0, 2 * Math.PI);
    ctx.fill();
  });

  const outputCtx = state.output.getContext('2d');
  state.output.width = cw;
  state.output.height = ch;

  outputCtx.drawImage(tempCanvas, 0, 0, cw, ch);

  state.points = dstPoints;
  state.imgWidth = cw;
  state.imgHeight = ch;

  tempCanvas.remove();
};

/**
 * Find the drawn point (from `state.points`) nearest to canvas coordinates (x, y).
 *
 * Despite its name this returns the point object itself, not an index.
 *
 * @param {object} state - slice state; `state.points` holds `{ x, y, ... }` entries
 * @param {number} x - canvas x coordinate
 * @param {number} y - canvas y coordinate
 * @returns {object|undefined} the nearest point when its squared distance is below
 *   R_MAX * R_MAX; otherwise (or when there are no points) undefined. Ties keep the
 *   earliest point.
 */
const findClosestPointIndex = (state, x, y) => {
  if(!state.points || state.points.length === 0) return undefined;

  // Linear minimum scan. The previous sort used a boolean-returning comparator, which is
  // not a consistent ordering, so the first element was not reliably the nearest one.
  let best;
  let bestDist = Infinity;
  for(const point of state.points) {
    if(!point) continue;
    const dx = point.x - x;
    const dy = point.y - y;
    const dist = dx * dx + dy * dy;
    if(dist < bestDist) {
      bestDist = dist;
      best = point;
    }
  }

  return bestDist < R_MAX * R_MAX ? best : undefined;
};

/**
 * Compute the slice indices shown in the rows x cols mosaic around `state.refSlice`.
 *
 * When `state.autoStep` is set (or no step is defined) the step is chosen so that half
 * of the grid spans the distance from the reference slice to the nearer volume edge
 * (minus a margin); it is at least 1 and is written back to `state.step`.
 *
 * @param {object} state - slice state (volumeId, rows, cols, axis, refSlice, step, autoStep)
 * @returns {number[]|undefined} rows*cols slice indices clamped to
 *   [0, dimensions[axis] - 1], ascending for AXIS.X and reversed for the other axes;
 *   undefined when no volume is loaded.
 */
const calcSlices = (state) => {
  if(!state.volumeId) return;

  const volume = cache.getVolume(state.volumeId);
  if(!volume) return;

  const { rows, cols, axis, refSlice } = state;
  const count = rows * cols;
  const depth = volume.dimensions[axis];

  // Number of slices kept clear of each volume edge when auto-computing the step.
  const sliceMargin = 12;

  let step = state.step;
  if(step === undefined || state.autoStep) {
    step = Math.round(Math.min(refSlice - sliceMargin, depth - refSlice - sliceMargin) / count * 2);
    step = Math.max(step, 1);
    state.step = step;
  }

  let startIdx = refSlice - step * Math.floor(count / 2);
  if(startIdx < 0) {
    // Start at the lowest slice on the same step lattice as refSlice.
    startIdx = refSlice - step * Math.floor(refSlice / step);
  }

  // Clamp so the grid never indexes past either end of the volume.
  const lastIdx = depth - 1;
  const res = Array.from({ length: count }, (_, i) =>
    Math.min(Math.max(startIdx + step * i, 0), lastIdx));

  return axis === AXIS.X ? res : res.reverse();
};

// Build a case reducer that stores the action payload under `key` and re-renders.
const setAndRender = (key) => (state, action) => {
  state[key] = action.payload;
  generateImage(state);
};

/**
 * Redux slice for the fgm viewer. Most setters store their payload and immediately
 * re-render the mosaic through generateImage; setError and setVolume only store.
 */
export const qaViewerSlice = createSlice({
  name: 'fgmviewer',
  initialState,
  reducers: {
    setError(state, { payload }) {
      state.error = payload;
    },
    setVolume(state, { payload }) {
      state.volumeId = payload;
    },
    setPoints(state, { payload }) {
      // New point data: drop the current selection and start at the first ROI.
      state.pointsId = payload;
      state.selectedPoint = undefined;
      state.roi = 0;
    },
    setOutput(state, { payload }) {
      state.autoStep = true;
      state.output = payload;
      generateImage(state);
    },
    setROI: setAndRender('roi'),
    setRows: setAndRender('rows'),
    setCols: setAndRender('cols'),
    setFilterArray: setAndRender('filterArray'),
    setFilterValue: setAndRender('p'),
    setSelectedPoint: {
      reducer(state, { payload }) {
        const { x, y, w, h } = payload;
        // Scale (x, y) from a w x h space to the rendered canvas size before hit-testing.
        const hit = findClosestPointIndex(
          state, state.imgWidth / w * x, state.imgHeight / h * y
        );
        if(hit) state.selectedPoint = hit.idx;
      },
      prepare: (x, y, w, h) => ({ payload: { x, y, w, h } }),
    },
    clearScene: () => initialState,
  },
});

const { actions: qaViewerActions, reducer: qaViewerReducer } = qaViewerSlice;

// Action creators generated by createSlice for the reducers of qaViewerSlice.
// NOTE(review): `saveScene` has no matching reducer, so this export is always
// undefined and calling it throws; it is kept so existing imports still resolve.
export const {
  setError,
  setOutput,
  setPoints,
  setVolume,
  setSelectedPoint,
  setRows,
  setCols,
  setROI,
  setFilterArray,
  setFilterValue,
  clearScene,
  saveScene,
} = qaViewerActions;

export default qaViewerReducer;
