import { scaleLinear, ScaleLinear } from 'd3-scale';
import { countBy, mapValues } from 'lodash';
import { mean as statMean, median as statMedian, modeFast } from 'simple-statistics';
import { COLORING_OPTIONS } from 'app/shared/enum/analysis';
import { categoricalMaps } from 'app/shared/styles/categoricalMaps';
import { colorcetMaps, ColorScheme, scales, scientificColorMaps } from 'app/shared/styles/scales';
import {
  CategoricalDataColoringFunction,
  CategoricalNodeColoringFunction,
  ColoringData,
  COLORMAP_TYPE,
  DataColoringFunction,
  LandscapeNode,
  NodeColoringFunction,
  NodeColors,
} from '../Analysis.types';
import { CategoricalColorMaps } from './Coloring.types';

/** Default marker symbol name for scatter plots (its consumers are not visible here). */
export const DEFAULT_SCATTER_PLOT_SYMBOL = 'circle';
/** Default node color as a CSS `#RRGGBBAA` string (gray, alpha 0xDD). */
export const DEFAULT_NODE_COLOR = '#555555DD';
/** Color that `colorMap` returns for missing (null) values; CSS `#RRGGBBAA`. */
export const MISSING_DATA_COLOR = '#777777BB';
/** Default node size. NOTE(review): units are not established here — presumably pixels, confirm. */
export const DEFAULT_NODE_SIZE = 5;

/**
 * How a landscape node's coloring value is derived from the values of its constituent data
 * points. The string values double as the keys of the `aggregationMethods` lookup table, which
 * is indexed with `coloringConfig.aggregationMethod`.
 */
export enum AGGREGATION_METHODS {
  mean = 'mean',
  median = 'median',
  mode = 'mode',
  max = 'max',
  // return the value with maximum distance from zero
  absmax = 'absmax',
}

/**
 * Whether raw or z-score-scaled function values should be requested; chosen through
 * `coloringConfig.scalingStrategy`. The code that acts on it is not visible here.
 */
export enum SCALING_STRATEGY {
  alwaysZScore = 'always z-score', // always request z-score scaled functions
  alwaysRaw = 'always raw', // always request raw values of functions
  smart = 'smart', // request raw values for single features, otherwise z-score
}

/** Global coloring settings read by the functions in this module. */
export const coloringConfig = {
  // Aggregation method: how a landscape node's coloring value is chosen from the values of
  // its constituent data points (see AGGREGATION_METHODS).
  aggregationMethod: AGGREGATION_METHODS.mean,
  // Sequential color scheme, used for raw functions (i.e. a single feature with
  // SCALING_STRATEGY.smart, or an indicator function) and whenever alwaysUseSequential is
  // true. Values are scaled to lie in [0,1] before the color lookup.
  sequentialColorScheme: scales.CET_L20_truncated,
  // Diverging color scheme, used for z-score functions when alwaysUseSequential is false.
  // Values are scaled onto a range that is symmetric around 0; its half-width is derived
  // from the data and clamped by minStandardDeviations / maxStandardDeviations below.
  // Alternative scheme kept for reference:
  // divergingColorScheme: scientificColorMaps.berlin,
  divergingColorScheme: colorcetMaps.CET_D7,
  // How to choose between raw and z-score-scaled functions (see SCALING_STRATEGY).
  scalingStrategy: SCALING_STRATEGY.smart,
  // Always use the sequential color scheme and its [0,1] normalization; otherwise z-score
  // scaled functions use the diverging color scheme.
  alwaysUseSequential: false,
  // Scale the coloring values before aggregating them onto nodes. This should make little
  // difference for the diverging color scheme, whose scaling does not depend on the data's
  // min/max range, but for sequential schemes scaling before aggregation generally yields
  // less contrast.
  // NOTE(review): getNodeColors receives scaleBeforeAggregation as a parameter instead of
  // reading this field; presumably callers pass this value through — confirm.
  scaleBeforeAggregation: false,
  // For z-score data with the diverging colormap the range is always symmetric around 0:
  // [-M, M]. M is the maximum absolute value of the data, clamped to
  // [minStandardDeviations, maxStandardDeviations].
  minStandardDeviations: 0.5,
  maxStandardDeviations: 3,
};

// Approximates the coloring behavior used in xshop.
// NOTE(review): not referenced anywhere visible in this file, and it lacks
// minStandardDeviations / maxStandardDeviations, so it is not a drop-in replacement for
// coloringConfig.
const xshopColoringConfig = {
  aggregationMethod: AGGREGATION_METHODS.mean,
  sequentialColorScheme: colorcetMaps.CET_R2,
  divergingColorScheme: scientificColorMaps.berlin,
  scalingStrategy: SCALING_STRATEGY.alwaysRaw,
  alwaysUseSequential: true,
  scaleBeforeAggregation: false,
};

/**
 * Builds a clamped piecewise-linear color scale from a color scheme whose entries are
 * `[stop, color]` pairs: the stops become the domain, the colors the range.
 */
export const getColorMap = (colorScheme: ColorScheme): ScaleLinear<string, string, never> => {
  const colors = colorScheme.map((entry) => entry[1]);
  const stops = colorScheme.map((entry) => entry[0]) as number[];
  return scaleLinear(colors).domain(stops).clamp(true);
};

const colorMapSequential = getColorMap(coloringConfig.sequentialColorScheme);
const colorMapDiverging = getColorMap(coloringConfig.divergingColorScheme);

/**
 * Maps a colormap input (nominally in [0,1]; the scales clamp out-of-range inputs) to a color.
 *
 * Missing inputs — `null`, `undefined`, or `NaN` (the placeholder the node aggregators emit
 * when there is nothing to color by) — get MISSING_DATA_COLOR. A d3 continuous scale returns
 * its `unknown` value (undefined by default) for NaN. Without this check the function would
 * return `undefined` despite its `string` return type.
 *
 * @param x - colormap input, or null for missing data
 * @param colorMapType - selects the diverging or (default) sequential scheme
 */
export const colorMap = (
  x: number | null,
  colorMapType: COLORMAP_TYPE = COLORMAP_TYPE.sequential
): string => {
  if (x == null || Number.isNaN(x)) {
    return MISSING_DATA_COLOR;
  }
  if (colorMapType === COLORMAP_TYPE.diverging) {
    return colorMapDiverging(x);
  }
  return colorMapSequential(x);
};

/**
 * Returns the most frequent value in `array`, or `null` when there is no single winner:
 * either the array is empty, or two or more values share the highest count.
 *
 * Counting with a Map keeps the original values (an object-keyed count would stringify them)
 * and needs no special case for the empty array.
 */
const mode = (array: string[]): string | null => {
  const counts = new Map<string, number>();
  array.forEach((value) => {
    counts.set(value, (counts.get(value) ?? 0) + 1);
  });

  const initial: { value: string | null; count: number; tied: boolean } = {
    value: null,
    count: 0,
    tied: false,
  };
  const leader = Array.from(counts).reduce((best, [value, count]) => {
    if (count > best.count) return { value, count, tied: false };
    // another value reaches the current highest count: no dominant value (yet)
    if (count === best.count) return { ...best, tied: true };
    return best;
  }, initial);

  return leader.tied ? null : leader.value;
};

/**
 * Lifts an aggregation function so that it tolerates missing values.
 *
 * `null` and `undefined` entries are dropped before `aggFn` is called. If nothing is left,
 * the result is `null` and `aggFn` is not called at all; aggregators such as
 * `Math.max(...[])` would otherwise yield -Infinity, and others may throw on empty input.
 *
 * @param aggFn - aggregator over the present (non-missing) values
 * @returns a function that aggregates an array which may contain missing values
 */
const handleMissingValues =
  <T>(aggFn: (array: T[]) => T | null) =>
  (array: (T | null)[]): T | null => {
    // eslint-plugin-react presumably mistakes this curried arrow for a function component and
    // its argument for `props`, so it flags `array.filter` as non-destructured props access.
    // eslint-disable-next-line react/destructuring-assignment
    const present = array.filter((x): x is T => x != null);

    if (present.length === 0) return null;
    return aggFn(present);
  };

// Missing-value-tolerant aggregators: each ignores missing entries and returns null when no
// value is left (see handleMissingValues).
const arrayMean = handleMissingValues(statMean);

const arrayMedian = handleMissingValues(statMedian);

// Categorical (string) mode; null when there is no unique most frequent value.
const arrayMode = handleMissingValues(mode);

// Extrema are computed with reduce rather than Math.max(...array) / Math.min(...array).
// Spreading passes every element as a separate call argument, which throws a RangeError once
// an array exceeds the engine's argument limit. These helpers are applied to whole
// per-data-point value arrays (see getDataColoringBounds), which may be large.
const maxOf = (array: number[]): number =>
  array.reduce((acc, x) => Math.max(acc, x), -Infinity);

const minOf = (array: number[]): number => array.reduce((acc, x) => Math.min(acc, x), Infinity);

const arrayMax = handleMissingValues(maxOf);

const arrayMin = handleMissingValues(minOf);

// Returns the element farthest from zero with its sign preserved (e.g. [-5, 1] -> -5), which
// is what AGGREGATION_METHODS.absmax describes. When +m and -m tie, +m is returned.
const arrayAbsMax = handleMissingValues((array: number[]): number => {
  const max = maxOf(array);
  const min = minOf(array);
  return max >= -min ? max : min;
});

/**
 * Returns `[min, max]` over all values of a data coloring function, ignoring missing values.
 * Returns `[0, 0]` when every value is missing and `[]` when no function is given.
 */
export const getDataColoringBounds = (dataColoringFn: DataColoringFunction): number[] => {
  if (!dataColoringFn) return [];

  const values = Object.values(dataColoringFn);
  const min = arrayMin(values);
  const max = arrayMax(values);

  if (min === null || max === null) {
    // all function values are missing
    return [0, 0];
  }

  return [min, max];
};

/**
 * Half-width M of the symmetric diverging range [-M, M].
 *
 * M is the largest absolute value in `coloringFn`, clamped to
 * [minStandardDeviations, maxStandardDeviations] from coloringConfig. It falls back to
 * minStandardDeviations when every value is missing.
 */
export const getDivergingColoringBound = (coloringFn: NodeColoringFunction): number => {
  const extreme = arrayAbsMax(coloringFn);
  if (extreme === null) {
    return coloringConfig.minStandardDeviations;
  }
  // arrayAbsMax keeps the sign of the extreme value; the bound needs its magnitude
  return Math.max(
    coloringConfig.minStandardDeviations,
    Math.min(coloringConfig.maxStandardDeviations, Math.abs(extreme))
  );
};

/**
 * Node aggregators keyed by AGGREGATION_METHODS value. Each reduces the values of a node's
 * data points (which may contain missing entries) to one value, or to null when every value
 * is missing.
 */
const aggregationMethods = {
  mean: arrayMean,
  median: arrayMedian,
  // modeFast is wrapped like the other aggregators. Unwrapped, missing entries would be
  // counted as a value, and a node with no present values would reach modeFast with an empty
  // array, which simple-statistics rejects by throwing.
  mode: handleMissingValues((array: number[]) => modeFast(array)),
  max: arrayMax,
  absmax: arrayAbsMax,
};

/**
 * Linearly rescales a value from `[dataBounds[0], dataBounds[1]]` onto [0,1].
 * Returns null for a missing value, and 0.5 (the colormap's middle) when the bounds
 * coincide, i.e. the range is degenerate.
 */
const getSequentialScaledValue = (
  averageColorValue: number | null,
  dataBounds: number[]
): number | null => {
  if (averageColorValue === null) return null;

  const [lower, upper] = dataBounds;
  const span = upper - lower;
  return upper === lower ? 0.5 : (averageColorValue - lower) / span;
};

/**
 * Maps a value from the symmetric range [-absmax, absmax] onto [0,1], with 0 landing at 0.5.
 * Returns null for a missing value.
 */
const getDivergingScaledValue = (
  colorValue: number | null,
  absmax = coloringConfig.maxStandardDeviations
): number | null => {
  if (colorValue === null) {
    return null;
  }
  return 0.5 + colorValue / (2 * absmax);
};

/**
 * Returns a copy of `coloringData` whose `colorFunction` values are already mapped to
 * colormap inputs. Diverging data is scaled by its symmetric bound; anything else is scaled
 * with its [min, max] range. Data without a `colorFunction` is returned unchanged.
 */
export const getPrescaledColorDataValues = (coloringData: ColoringData): ColoringData => {
  const { colorFunction } = coloringData;
  if (!colorFunction) {
    return coloringData;
  }

  if (coloringData.scale === COLORMAP_TYPE.diverging) {
    const bound = getDivergingColoringBound(Object.values(colorFunction));
    const scaled = mapValues(colorFunction, (value) => getDivergingScaledValue(value, bound));
    return { ...coloringData, colorFunction: scaled };
  }

  const bounds = getDataColoringBounds(colorFunction);
  const scaled = mapValues(colorFunction, (value) => getSequentialScaledValue(value, bounds));
  return { ...coloringData, colorFunction: scaled };
};

/**
 * Categorical node values: for every node, the dominant category among its data points'
 * `colorFunction` values, or null when there is no unique most frequent category.
 */
const aggregateCategoricalValues = (
  nodes: LandscapeNode[],
  coloringData: ColoringData
): CategoricalNodeColoringFunction =>
  nodes.map(({ attributes }) => {
    const memberValues = attributes.data.map(
      (dataIndex) => coloringData.colorFunction[dataIndex]
    );
    return arrayMode(memberValues);
  });

/**
 * For every node, counts how many of its data points have a value equal to that node's mode
 * (`modes[i]` belongs to `nodes[i]`).
 *
 * NOTE(review): values are read from `coloringData.colorData`, whereas aggregateCategoricalValues
 * derives the modes from `coloringData.colorFunction` — confirm both hold the same categories.
 */
export function modeFrequencyCategoricalValues(
  modes: number[],
  nodes: LandscapeNode[],
  coloringData: ColoringData
): number[] {
  return nodes.map((node, nodeIndex) => {
    const nodeMode = modes[nodeIndex];
    return node.attributes.data.reduce(
      (count, dataIndex) => (coloringData.colorData[dataIndex] === nodeMode ? count + 1 : count),
      0
    );
  });
}

/**
 * For every node, evaluates `valueOf` on each of the node's data points and reduces the
 * results with the method selected by `coloringConfig.aggregationMethod`.
 */
const aggregatePerNode = (
  nodes: LandscapeNode[],
  valueOf: (dataIndex: LandscapeNode['attributes']['data'][number]) => number | null
): NodeColoringFunction => {
  const aggregationFunction = aggregationMethods[coloringConfig.aggregationMethod];
  return nodes.map((node) =>
    aggregationFunction(node.attributes.data.map((dataIndex) => valueOf(dataIndex)))
  );
};

/**
 * Indicator aggregation: a data point contributes 1 when its index occurs in `selection` and
 * 0 otherwise. With mean aggregation, a node's value is the fraction of its points selected.
 */
const aggregateSelectionIndicator = (
  nodes: LandscapeNode[],
  selection: ColoringData['colorData']
): NodeColoringFunction => {
  // Build the lookup once: one Set membership test per data point, instead of a linear
  // `includes` scan of the whole selection for every data point of every node.
  const selected = new Set<unknown>(selection);
  return aggregatePerNode(nodes, (dataIndex) => (selected.has(dataIndex) ? 1 : 0));
};

/** Aggregates the per-data-point values stored in `colorFunction`. */
const aggregateColorFunctionValues = (
  nodes: LandscapeNode[],
  colorFunction: NonNullable<ColoringData['colorFunction']>
): NodeColoringFunction => aggregatePerNode(nodes, (dataIndex) => colorFunction[dataIndex]);

/**
 * Node values used by getNodeColors when `isFeatureColoring` is false:
 * - `selectedDataPoints`: indicator of whether each data index occurs in `colorData`;
 * - `selectedFeatures` (with a `colorFunction`): aggregated `colorFunction` values.
 * In any other case every node gets NaN.
 */
const aggregateDataValues = (
  nodes: LandscapeNode[],
  coloringData: ColoringData
): NodeColoringFunction => {
  if (coloringData.colorBy === COLORING_OPTIONS.selectedDataPoints) {
    return aggregateSelectionIndicator(nodes, coloringData.colorData);
  }
  if (coloringData.colorBy === COLORING_OPTIONS.selectedFeatures && coloringData.colorFunction) {
    return aggregateColorFunctionValues(nodes, coloringData.colorFunction);
  }
  return nodes.map(() => NaN);
};

/**
 * Mirror image of aggregateDataValues, used for feature coloring:
 * - `selectedFeatures`: indicator of whether each data index occurs in `colorData`;
 * - `selectedDataPoints` (with a `colorFunction`): aggregated `colorFunction` values.
 * In any other case every node gets NaN.
 */
const aggregateFeatureValues = (
  nodes: LandscapeNode[],
  coloringData: ColoringData
): NodeColoringFunction => {
  if (coloringData.colorBy === COLORING_OPTIONS.selectedFeatures) {
    return aggregateSelectionIndicator(nodes, coloringData.colorData);
  }
  if (coloringData.colorBy === COLORING_OPTIONS.selectedDataPoints && coloringData.colorFunction) {
    return aggregateColorFunctionValues(nodes, coloringData.colorFunction);
  }
  return nodes.map(() => NaN);
};

/**
 * Rescales node values onto [0,1] using their own [min, max] range. The input is returned
 * as-is when every value is missing (there is no range to scale with).
 */
const sequentialScaleNodeColorValues = (
  nodeFunction: NodeColoringFunction
): NodeColoringFunction => {
  const lowest = arrayMin(nodeFunction);
  const highest = arrayMax(nodeFunction);
  if (lowest === null || highest === null) return nodeFunction;

  const bounds = [lowest, highest];
  return nodeFunction.map((value) => getSequentialScaledValue(value, bounds));
};

/**
 * Rescales node values onto [0,1] for the diverging colormap. The symmetric bound M is
 * derived from the values themselves, so that [-M, M] maps onto [0, 1].
 */
const divergingScaleNodeColorValues = (nodeFunction: NodeColoringFunction): NodeColoringFunction => {
  const bound = getDivergingColoringBound(nodeFunction);
  return nodeFunction.map((value) => getDivergingScaledValue(value, bound));
};

// export const getNearestColorIndex = (array: ColorPair[], value: number): number =>
//   array.reduce(
//     ({ pair, index }, curr, currentIndex) =>
//       Math.abs(curr[0] - value) < Math.abs(pair[0] - value)
//         ? { pair: curr, index: currentIndex }
//         : {
//             pair,
//             index,
//           },
//     { pair: array[0], index: 0 }
//   ).index;

// Colormap index reserved for nodes without a dominant category (also used for its legend entry).
const NO_DOMINANT_CATEGORY_INDEX = 10;

/**
 * Computes node colors and the legend for categorical coloring.
 *
 * Each node gets the color at its aggregated category's position among the distinct values of
 * `colorFunction`. Nodes without a category get the reserved "no dominant category" color, and
 * in the monochromatic map they are drawn as a border. A node has no category when its
 * aggregate is null (tie or all values missing) or matches none of the distinct values.
 *
 * @returns per-node colors/symbol types, the legend categories, and the aggregated node values
 */
export const getNodeCategoricalColors = (
  coloringData: ColoringData<CategoricalDataColoringFunction>,
  colorMapType: CategoricalColorMaps,
  nodes: LandscapeNode[],
  isFeatureLandscape: boolean
) => {
  const uniqueValues = Array.from(new Set(Object.values(coloringData.colorFunction)));
  const aggregationFunction = isFeatureLandscape
    ? aggregateFeatureValues
    : aggregateCategoricalValues;

  const colorMapFunction = getColorMap(categoricalMaps[colorMapType]);
  const isMonochromatic = colorMapType === CategoricalColorMaps.Monochromatic;

  const aggregatedFunction = aggregationFunction(nodes, coloringData);

  const nodeColors = aggregatedFunction.map((x) => {
    const valueIndex = x === null ? -1 : uniqueValues.findIndex((i) => i === x);
    // An unmatched value (e.g. the NaN placeholder) must not reach the colormap as index -1.
    // The clamped scale would silently map it to the color at the bottom of its domain.
    const hasCategory = valueIndex !== -1;

    return {
      color: colorMapFunction(hasCategory ? valueIndex : NO_DOMINANT_CATEGORY_INDEX),
      type: isMonochromatic && !hasCategory ? 'border' : 'circle',
    };
  });

  return {
    nodeColors,
    categories: uniqueValues
      .map((value, index) => ({
        value,
        color: colorMapFunction(index),
      }))
      .concat({
        value: 'no dominant category',
        color: colorMapFunction(NO_DOMINANT_CATEGORY_INDEX),
      }),
    rawNodeValues: aggregatedFunction,
  };
};

/**
 * Rounds a node value for display to three significant figures, but never to fewer than zero
 * decimal places (|value| >= 100 is rounded to an integer). Missing values become null.
 * Values with |value| < 1e-8 are returned unrounded.
 */
const roundNodeValueForDisplay = (value: number | null | undefined): number | null => {
  if (value == null) {
    return null;
  }

  const decimals = value === 0 ? 0 : Math.max(-Math.floor(Math.log10(Math.abs(value))) + 2, 0);

  // Skip rounding for tiny magnitudes to avoid numerical instability from huge scale factors.
  if (decimals > 10) {
    return value;
  }

  const factor = 10 ** decimals;
  return Math.round(value * factor) / factor;
};

/**
 * Computes node colors and legend information for numeric (sequential or diverging) coloring.
 *
 * @param coloringData - values to color by; `scale` selects the colormap type
 * @param nodes - landscape nodes whose data points are aggregated
 * @param isFeatureColoring - aggregate with aggregateFeatureValues instead of aggregateDataValues
 * @param scaleBeforeAggregation - map per-data-point values to colormap inputs before
 *   aggregating; otherwise the aggregated node values are scaled
 * @returns per-node colors, legend data/color bounds, the colormap type, and the unscaled node
 *   values rounded for display
 */
export const getNodeColors = (
  coloringData: ColoringData<DataColoringFunction>,
  nodes: LandscapeNode[],
  isFeatureColoring: boolean,
  scaleBeforeAggregation: boolean
): NodeColors => {
  const aggregationFunction = isFeatureColoring ? aggregateFeatureValues : aggregateDataValues;
  const colorMapType = coloringData.scale;

  // Unscaled node values, aggregated once. They are shown as raw values and, unless values
  // are prescaled, they are also the basis for the colors and the legend bounds.
  const aggregatedFunction = aggregationFunction(nodes, coloringData);

  // Colormap inputs (nominally in [0,1]; the color scales clamp).
  let colorInputs: NodeColoringFunction;
  if (scaleBeforeAggregation) {
    colorInputs = aggregationFunction(nodes, getPrescaledColorDataValues(coloringData));
  } else if (colorMapType === COLORMAP_TYPE.diverging) {
    colorInputs = divergingScaleNodeColorValues(aggregatedFunction);
  } else {
    colorInputs = sequentialScaleNodeColorValues(aggregatedFunction);
  }
  const nodeColors = colorInputs.map((x) => colorMap(x, colorMapType));

  const rawNodeValues: NodeColoringFunction = aggregatedFunction.map((value) =>
    roundNodeValueForDisplay(value)
  );

  if (colorMapType === COLORMAP_TYPE.sequential) {
    // NOTE(review): when prescaling, the legend bounds come from `colorData`, while the
    // prescaled values come from `colorFunction` (see getPrescaledColorDataValues). Confirm
    // that `colorData` holds the values the legend should show.
    const fnMax = scaleBeforeAggregation
      ? arrayMax(coloringData.colorData)
      : arrayMax(aggregatedFunction);
    const fnMin = scaleBeforeAggregation
      ? arrayMin(coloringData.colorData)
      : arrayMin(aggregatedFunction);

    // A constant function gets a single-color legend: the first node's color, or the
    // missing-data color when there are no nodes at all.
    const colorBounds =
      fnMax === fnMin
        ? [nodeColors[0] ?? MISSING_DATA_COLOR]
        : coloringConfig.sequentialColorScheme.map((colorPair) => colorPair[1]);

    const dataBounds: [number, number] = fnMax === null || fnMin === null ? [0, 0] : [fnMin, fnMax];

    return {
      nodeColors,
      dataBounds,
      colorBounds,
      colorMapType,
      rawNodeValues,
    };
  }

  const scaleBound = scaleBeforeAggregation
    ? getDivergingColoringBound(Object.values(coloringData.colorData))
    : getDivergingColoringBound(aggregatedFunction);

  const colorBounds = coloringConfig.divergingColorScheme.map((colorPair) => colorPair[1]);

  return {
    nodeColors,
    dataBounds: [-scaleBound, scaleBound],
    colorBounds,
    colorMapType,
    rawNodeValues,
  };
};
