import React, {
  useMemo,
  useState,
  useCallback,
  useRef,
  useEffect,
} from 'react';
import { max, map, pick, keyBy, memoize } from 'lodash';
import styles from './QuestionResults.module.scss';
import standardCohortKeys from 'js/utils/cohort-keys';

import classnames from 'classnames';

import { AxisRight } from '@visx/axis';
import { Grid } from '@visx/grid';
import { scaleBand, scaleLinear } from '@visx/scale';
import { curveMonotoneX as curve } from '@visx/curve';
import LineWithRings from './LineWithRings';

/**
 * Count one response for a cohort.
 *
 * Lazily creates the cohort's tally object, then increments the counter for
 * `responseValue` (used as an object key, so `null` is stored as 'null').
 *
 * @param {object} acc - Tallies keyed by cohort; mutated in place.
 * @param {*} cohort - Cohort key to count the response under.
 * @param {*} responseValue - Response value to increment.
 * @returns {object} The same `acc` object, so it can serve as a reducer step.
 */
function addResponseToCohortAccumulator(acc, cohort, responseValue) {
  if (!acc[cohort]) {
    acc[cohort] = {};
  }
  const tallies = acc[cohort];
  const previous = tallies[responseValue] ?? 0;
  tallies[responseValue] = previous + 1;
  return acc;
}

/**
 * Build the point sequence for one curve.
 *
 * The result is `'start'`, then one `{ answer, count }` point per answer, then
 * `'end'`. `count` is the tally found in `responses` under the answer's value,
 * or 0 when there is none.
 *
 * @param {Array<object>} answers - Answers in x-axis order; each has `value`.
 * @param {object} responses - Tallies keyed by response value.
 * @returns {Array} The sentinel-wrapped point list.
 */
function generateLineData(answers, responses) {
  const points = answers.map((answer) => {
    const count = responses[answer.value] || 0;
    return { answer, count };
  });
  return ['start', ...points, 'end'];
}

const QuestionResults = ({
  question,
  filter,
  participantCohortMap = {},
  cohorts,
  activeCohort: activeCohortContext,
  tag: Tag,
  className,
  width,
  height,
  margin,
  primaryResultsOverlayText,
  ...props
}) => {
  const answerValue = (a) => +a.value;

  const xMax = width - margin.right - margin.left;
  const yMax = height - margin.bottom - margin.top;
  const cohortColorIndices = {
    ...keyBy(
      cohorts?.map((c, i) => ({ ...c, index: i + 2 })) ?? [],
      (c) => c.value ?? standardCohortKeys.NULL
    ),
    [standardCohortKeys.CUMULATIVE]: { index: 1 },
  };

  const cohortKeys = useMemo(
    () => Object.keys(filter?.cohorts ?? {}),
    [filter?.cohorts]
  );

  const applyCohortFilter = cohortKeys.length > 0;
  const resultsByCohort = useMemo(() => {
    return question.responses.reduce((acc, r) => {
      const cohort = participantCohortMap[r.participant_id];
      if (applyCohortFilter && cohortKeys.indexOf(cohort) === -1) {
        return acc;
      }
      const responseValue = r.is_na ? null : +r.value;

      acc = addResponseToCohortAccumulator(
        acc,
        standardCohortKeys.CUMULATIVE,
        responseValue
      );
      acc = addResponseToCohortAccumulator(
        acc,
        cohort ?? standardCohortKeys.NULL,
        responseValue
      );

      return acc;
    }, {});
  }, [question, participantCohortMap, applyCohortFilter, cohortKeys]);

  const answers = useMemo(
    () => question.answers.filter((a) => !a.is_na),
    [question.answers]
  );

  const cohortLineData = useMemo(() => {
    const cohortData =
      cohortKeys.length > 0
        ? pick(resultsByCohort, [standardCohortKeys.CUMULATIVE, ...cohortKeys])
        : resultsByCohort;

    const cohortLineData = map(cohortData, (r, key) => ({
      key,
      data: generateLineData(answers, r),
    })).reverse();

    return cohortLineData;
  }, [resultsByCohort, cohortKeys, answers]);

  const resetCohortTimeout = useRef(null);
  const [activeLocalCohort, setActiveLocalCohort] = useState(null);
  const memoizedHandleCohortOver = React.useMemo(
    () =>
      memoize((cohortKey) => () => {
        clearTimeout(resetCohortTimeout.current);
        resetCohortTimeout.current = null;
        setActiveLocalCohort(cohortKey);
      }),
    [resetCohortTimeout, setActiveLocalCohort]
  );

  const handleCohortOver = (cohortKey) => {
    clearTimeout(resetCohortTimeout.current);
    resetCohortTimeout.current = null;
    setActiveLocalCohort(cohortKey);
  };

  const handleCohortOut = useCallback(() => {
    clearTimeout(resetCohortTimeout.current);
    resetCohortTimeout.current = setTimeout(() => setActiveLocalCohort(), 100);
  }, [resetCohortTimeout, setActiveLocalCohort]);

  const activeCohort =
    activeLocalCohort ?? activeCohortContext ?? standardCohortKeys.CUMULATIVE;

  const data = cohortLineData.find((d) => d.key === activeCohort)?.data ?? [];

  const { x, y, xScale, yScale, numTicksY, naResponseCount } = useMemo(() => {
    const cumulativeResponses =
      resultsByCohort[standardCohortKeys.CUMULATIVE] ?? {};
    const naResponseCount = cumulativeResponses[null] || 0;

    const maxResponses =
      max(map(cumulativeResponses, (x, key) => (key === null ? 0 : x))) || 1;

    const numTicksY = Math.min(4, Math.max(1, maxResponses));

    const xScale = scaleBand({
      range: [0, xMax],
      domain: answers.map(answerValue),
      padding: 0,
    });

    const yScale = scaleLinear({
      range: [yMax, 0],
      domain: [
        0,
        Object.keys(cumulativeResponses).length > 0
          ? Math.max(1.15, 1.15 * maxResponses)
          : 1.15,
      ],
    });

    const x = (d) => {
      switch (d) {
        case 'start':
          return 0;
        case 'end':
          return xMax;
        default:
      }

      return xScale(d.answer.value) + xScale.bandwidth() / 2;
    };

    const y = (d) => {
      switch (d) {
        case 'start':
        case 'end':
          return yMax;
        default:
      }
      return yScale(d.count);
    };

    return {
      x,
      y,
      xScale,
      yScale,
      numTicksY,
      naResponseCount,
    };
  }, [answers, resultsByCohort, xMax, yMax]);

  const applyOverlay = !resultsByCohort[standardCohortKeys.CUMULATIVE];

  const rotateAnswers = answers.length > 7;

  return (
    <Tag className={classnames(className, 'question-results')} {...props}>
      <div
        className={classnames({
          'question-results-wrapper': true,
          'no-answers-overlay': applyOverlay,
        })}
      >
        <div className="question-text">
          <p>{question.text}</p>
        </div>
        {naResponseCount !== 0 && (
          <div className="position-relative mb-4">
            <div className={styles.naBadge}>
              <span className={styles.naCount}>NA: {naResponseCount}</span>
            </div>
          </div>
        )}
        <div className="primary-results-wrapper">
          <svg viewBox={`0 0 ${width} ${height}`} className="primary-results">
            <g transform={`translate(${margin.left}, ${margin.top})`}>
              <Grid
                top={0}
                left={0}
                xScale={xScale}
                yScale={yScale}
                numTicksRows={numTicksY}
                width={xMax}
                height={yMax}
                strokeDasharray="2,2"
                stroke="rgba(200,200,200,1)"
                xOffset={xScale.bandwidth() / 2}
              />
              {cohortLineData.length > 2 ? (
                <>
                  <g className="cohort-data">
                    {cohortLineData.map((cohortData) => {
                      return (
                        <LineWithRings
                          key={cohortData.key}
                          data={cohortData.data}
                          x={x}
                          y={y}
                          curve={curve}
                          onMouseOver={memoizedHandleCohortOver(cohortData.key)}
                          onMouseOut={handleCohortOut}
                          className={`cohort-curve palette-chart-${
                            cohortColorIndices[cohortData.key].index
                          }`}
                        />
                      );
                    })}
                  </g>

                  <LineWithRings
                    data={data}
                    x={x}
                    y={y}
                    curve={curve}
                    onMouseOver={() => handleCohortOver(activeLocalCohort)}
                    onMouseOut={handleCohortOut}
                    className={`palette-chart-${cohortColorIndices[activeCohort].index}`}
                  />
                </>
              ) : (
                <LineWithRings
                  data={data}
                  x={x}
                  y={y}
                  curve={curve}
                  className={`palette-chart-${cohortColorIndices[activeCohort].index}`}
                />
              )}

              <AxisRight
                top={0}
                right={xMax}
                scale={yScale}
                numTicks={numTicksY}
                tickFormat={yScale.tickFormat(numTicksY, '.0f')}
                hideZero
                axisClassName="axis-y"
                axisLineClassName="axis-y-line"
                tickClassName="axis-y-tick"
                tickLabelProps={() => ({
                  dx: '-.5em',
                  dy: '-.5em',
                  className: 'axis-y-label',
                })}
              />
            </g>
          </svg>
          {applyOverlay && (
            <div className="no-response-overlay">
              <div>No responses</div>
            </div>
          )}
          {primaryResultsOverlayText !== undefined &&
            primaryResultsOverlayText !== null && (
              <div className="primary-results-overlay">
                {primaryResultsOverlayText}
              </div>
            )}
        </div>
        <ol
          className="answers"
          style={{ gridTemplateColumns: `repeat(${answers.length}, 1fr)` }}
        >
          {answers.map((a) => {
            return (
              <li className="answer" key={`answer-label-${a.id}`}>
                <div
                  className={classnames('answer-text', {
                    rotate: rotateAnswers,
                  })}
                  title={a.text}
                >
                  {a.text}
                </div>
              </li>
            );
          })}
        </ol>
      </div>
    </Tag>
  );
};

// Fallback props for the chart.
// NOTE(review): React 18.3+ deprecates `defaultProps` on function components
// and React 19 ignores it; consider expressing these as parameter defaults.
QuestionResults.defaultProps = {
  tag: 'article',
  width: 400,
  height: 125,
  // Listed in CSS shorthand order: top, right, bottom, left.
  margin: { top: 7, right: 0, bottom: 5, left: 0 },
};

export default QuestionResults;
