import _ from 'lodash';
import { axisTop } from 'd3-axis';
import { ScaleLinear, scaleLinear } from 'd3-scale';
import { select } from 'd3-selection';
import format, { FormatOptions } from '@dha/number-format';
import { appendIfNotExists, twoSidedTickFormat } from '../helpers';
import { BaseValue } from './types';
import { newDefsID } from './helpers';

/** Horizontal offset (px) between the axis-indicator captions (negative/positive texts and value header) and their anchor x. */
const INDICATOR_PADDING = 10;
/** Horizontal gap (px) between the zero line and each row's value label. */
const LABEL_PADDING = 10;

/**
 * Layout and styling options for the dot plot rendered by `draw`.
 * Options that are omitted (or passed as null/undefined) fall back to DEFAULT_DRAW_OPTIONS.
 */
export type DotPlotDrawOptions = {
    /** Total width (px) used for the x scale; its range is [xMargin, width - xMargin]. */
    width: number;
    /** Length (px) of the axis tick lines; the tick-line gradient is fully transparent at this depth. */
    height: number;
    /** Height (px) of each row; connector, dot and label sit at the row's vertical centre. */
    rowHeight: number;
    /** Vertical gap (px) between consecutive rows. */
    rowGap: number;
    /** Horizontal offset (px) of the axis and its indicator captions inside the axis svg. */
    axisX: number;
    /** Horizontal inset (px) applied to both ends of the x-scale range. */
    xMargin: number;
    /** Vertical offset (px) of the axis inside the axis svg. */
    axisSpace: number;
    /** Colour of connector, dot and label for values below zero. */
    negativeColor: string;
    /** Colour of connector, dot and label for values at or above zero. */
    positiveColor: string;
    /** Caption drawn left of zero in the axis svg when the domain extends below zero. */
    negativeText: string;
    /** Caption drawn right of zero in the axis svg when the domain extends above zero. */
    positiveText: string;
    /** Header text in the axis svg, right-aligned just left of axisX. */
    valueHeader?: string;
    /** Number-format options used for tick labels and row value labels. */
    format?: FormatOptions
}

/**
 * Fallback values merged beneath caller-supplied options in `draw`.
 * There is no `format` entry, so formatting options stay undefined unless the caller provides them.
 */
export const DEFAULT_DRAW_OPTIONS = {
    width: 400,
    height: 300,
    rowHeight: 50,
    rowGap: 8,
    axisX: 0,
    axisSpace: 50,
    xMargin: 30,
    negativeColor: 'black',
    positiveColor: 'black',
    negativeText: '',
    positiveText: '',
    // Default header text shown in the axis svg.
    valueHeader: 'Intactness (%)'
};

/**
 * Builds the linear x scale shared by the axis, dots and indicators.
 *
 * The domain spans the data's smallest and largest `value`, always widened to
 * include 0 so that xScale(0) — where every connector starts — is on-screen.
 * The domain is then rounded outward with `nice()`.
 *
 * @param data - Rows to plot; an empty array yields a [0, 0] domain.
 * @param width - Overall width in px.
 * @param xMargin - Inset in px applied to both ends of the range.
 * @returns A scale mapping values to [xMargin, width - xMargin].
 */
function getXScale(data: BaseValue[], width: number, xMargin: number) {
    const smallest = _.minBy(data, 'value')?.value ?? 0;
    const largest = _.maxBy(data, 'value')?.value ?? 0;

    const lower = Math.min(smallest, 0);
    const upper = Math.max(largest, 0);

    return scaleLinear()
        .domain([lower, upper])
        .range([xMargin, width - xMargin])
        .nice();
}

/**
 * Draws (or refreshes) a top-oriented axis in the `g.axis` group of `svg`.
 *
 * Tick lines are `height` px long and extend downward from the axis. The tick
 * line at value 0 receives the `zero-line` class.
 *
 * @param svg - Container element; `null` makes the d3 calls no-ops.
 * @param xScale - Scale to render.
 * @param height - Tick-line length in px.
 * @param axisX - Horizontal CSS translate applied to the axis group.
 * @param axisSpace - Vertical CSS translate applied to the axis group.
 * @param formatOptions - Passed (or `{}` when absent) to `twoSidedTickFormat`.
 */
function drawAxis(
    svg: Element | null,
    xScale: ScaleLinear<number, number>,
    height: number,
    axisX: number,
    axisSpace: number,
    formatOptions?: FormatOptions
) {
    const tickCount = 7;

    // Each label is formatted with knowledge of the full tick list.
    const axisGenerator = axisTop(xScale)
        .ticks(tickCount)
        .tickSize(-height)
        .tickPadding(10)
        .tickFormat(tick => twoSidedTickFormat(
            tick,
            formatOptions ?? {},
            xScale.ticks(tickCount)
        ));

    const axisGroup = appendIfNotExists<SVGGElement>(select(svg), 'g', 'axis')
        .style('transform', `translate(${axisX}px, ${axisSpace}px)`);
    axisGroup.call(axisGenerator);

    // d3-axis binds each tick's value to its line, so the zero tick can be singled out.
    axisGroup
        .selectAll('.tick line')
        .filter(value => value === 0)
        .classed('zero-line', true);
}

/**
 * Renders one `g.row` per datum (keyed by `name`) in `svg`.
 *
 * Each row contains:
 * - a connector line from x = 0 to the value;
 * - a dot of radius 7 at the value;
 * - a value label, initially hidden (opacity 0). `highlightRows` reveals labels.
 *
 * Values below zero use `negativeColor`; all others use `positiveColor`.
 *
 * @param svg - Container element; `null` makes the d3 calls no-ops.
 * @param data - Rows to draw, in display order (top to bottom).
 * @param xScale - Horizontal scale.
 * @param options - Supplies row geometry, colours and label format options.
 */
function drawDots(
    svg: Element | null,
    data: BaseValue[],
    xScale: ScaleLinear<number, number>,
    options: DotPlotDrawOptions
) {
    const { rowHeight, rowGap, negativeColor, positiveColor } = options;
    const formatOptions = options.format;
    const zeroX = xScale(0);

    // Vertical centre (px) of the row at position `index`.
    const rowCenterY = (index: number) => index * (rowHeight + rowGap) + rowHeight / 2;
    const colorOf = (datum: BaseValue) => (datum.value < 0 ? negativeColor : positiveColor);
    // Each row's children get a single datum: the row's value plus its position.
    const withRowIndex = (datum: BaseValue, index: number) => [{ ...datum, i: index }];

    const rows = select(svg)
        .selectAll<SVGGElement, BaseValue>('g.row')
        .data(data, row => row.name)
        .join('g')
        .classed('row', true);

    rows.selectAll('line.connector')
        .data(withRowIndex)
        .join('line')
        .classed('connector', true)
        .attr('x1', zeroX)
        .attr('x2', d => xScale(d.value))
        .attr('y1', d => rowCenterY(d.i))
        .attr('y2', d => rowCenterY(d.i))
        .attr('stroke', d => colorOf(d));

    rows.selectAll('circle.dot')
        .data(withRowIndex)
        .join('circle')
        .classed('dot', true)
        .attr('cx', d => xScale(d.value))
        .attr('cy', d => rowCenterY(d.i))
        .attr('r', 7)
        .attr('fill', d => colorOf(d));

    // The label sits on the opposite side of zero from its dot: it starts right
    // of zero for negative values and ends left of zero otherwise.
    // NOTE(review): no explicit '+' is prepended to positive values here; if one
    // is wanted it must come from `formatOptions` — confirm intended behaviour.
    rows.selectAll('text.label')
        .data(withRowIndex)
        .join('text')
        .classed('label', true)
        .attr('x', d => zeroX + (d.value < 0 ? LABEL_PADDING : -LABEL_PADDING))
        .attr('y', d => rowCenterY(d.i))
        .attr('text-anchor', d => (d.value < 0 ? 'start' : 'end'))
        .attr('fill', d => colorOf(d))
        .attr('opacity', 0)
        .text(d => format(d.value, formatOptions));
}

/**
 * Draws the captions in the axis svg.
 *
 * - `text.left` shows `negativeText`, right-aligned INDICATOR_PADDING px left
 *   of zero. It appears only when the domain extends below zero; otherwise it
 *   is removed.
 * - `text.right` shows `positiveText`, starting INDICATOR_PADDING px right of
 *   zero. It appears only when the domain extends above zero; otherwise it is
 *   removed.
 * - `text.label` shows `valueHeader` (or ''), right-aligned at y = 35, just
 *   left of `axisX`.
 *
 * @param indicatorSvg - Element receiving the captions; `null` makes the d3 calls no-ops.
 * @param xScale - Scale whose domain decides which side captions are shown.
 * @param drawOptions - Supplies `axisX`, the caption texts and `valueHeader`.
 */
function drawAxisIndicator(
    indicatorSvg: Element | null,
    xScale: ScaleLinear<number, number>,
    drawOptions: DotPlotDrawOptions
) {
    const { axisX, negativeText, positiveText, valueHeader } = drawOptions;
    const domain = xScale.domain();
    const zero = xScale(0);

    const showNegative = domain[0] < 0;
    if (!showNegative) {
        select(indicatorSvg).select('text.left').remove();
    } else {
        appendIfNotExists<SVGTextElement>(select(indicatorSvg), 'text', 'left')
            .attr('x', zero - INDICATOR_PADDING + axisX)
            .attr('y', 0)
            .attr('text-anchor', 'end')
            .text(negativeText);
    }

    const showPositive = domain[1] > 0;
    if (!showPositive) {
        select(indicatorSvg).select('text.right').remove();
    } else {
        appendIfNotExists<SVGTextElement>(select(indicatorSvg), 'text', 'right')
            .attr('x', zero + INDICATOR_PADDING + axisX)
            .attr('y', 0)
            .text(positiveText);
    }

    appendIfNotExists<SVGTextElement>(select(indicatorSvg), 'text', 'label')
        .attr('x', axisX - INDICATOR_PADDING)
        .attr('y', 35)
        .attr('text-anchor', 'end')
        .text(valueHeader ?? '');
}

/**
 * Strokes the tick lines of the first `g.axis` in `svg` with a vertical fade.
 *
 * A single `<linearGradient>` is kept in the svg's `<defs>` and reused across
 * calls. Lines are solid `color` from the top down to `y`, then fade to
 * transparent by `height`.
 *
 * @param svg - Root element holding the axis; `null` makes every d3 call a no-op.
 * @param y - Depth (px, from the top of the tick lines) at which the fade begins.
 * @param height - Depth (px) at which the lines become fully transparent.
 * @param color - Solid colour above `y`; re-applied on every call.
 */
function drawGradient(
    svg: Element | null,
    y: number,
    height: number,
    color: string
) {
    const root = select(svg);

    let defs = root.select<SVGDefsElement>('defs');
    if (defs.empty()) {
        defs = root.append('defs');
    }

    // Create the gradient whenever it is missing, not only when <defs> is missing.
    // Otherwise, reading the id from a <defs> without a gradient would throw.
    let gradientID: string;
    if (defs.select('linearGradient').empty()) {
        // IDs must be unique across the whole document, so generate a fresh one.
        gradientID = newDefsID('dot-plot-gradient');
        const created = defs.append('linearGradient')
            .attr('id', gradientID)
            // userSpaceOnUse: coordinates are pixels in the referencing line's own
            // coordinate system. objectBoundingBox would be ignored for a vertical
            // line, whose bounding box has zero width.
            .attr('gradientUnits', 'userSpaceOnUse');
        created.append('stop')
            .classed('stop-top', true)
            .attr('offset', 0);
        created.append('stop')
            .classed('stop-middle', true);
        created.append('stop')
            .classed('stop-bottom', true)
            .attr('offset', 1)
            .attr('stop-color', 'rgba(0,0,0,0)');
    } else {
        gradientID = defs.select('linearGradient').attr('id');
    }

    const gradient = defs.select('linearGradient')
        .attr('x1', 0)
        .attr('x2', 0)
        .attr('y1', 0)
        .attr('y2', height);

    // Re-apply the colour every call so a changed `color` takes effect on redraw.
    gradient.select('stop.stop-top')
        .attr('stop-color', color);
    gradient.select('stop.stop-middle')
        .attr('stop-color', color)
        // The middle stop marks where the fade starts. SVG clamps offsets above 1,
        // so y > height simply keeps the lines solid.
        .attr('offset', height === 0 ? 0 : y / height);

    root.select('g.axis')
        .selectAll('.tick line')
        .attr('stroke', `url(#${gradientID})`);
}

/**
 * Renders the dot plot into `svg` and its axis and captions into `axisSvg`.
 *
 * The axis is drawn twice:
 * - into `axisSvg`, offset by `axisX`/`axisSpace`, together with the
 *   indicator captions;
 * - into `svg` at the origin, where its `height`-long tick lines are given a
 *   fading stroke that starts below the last row.
 *
 * @param svg - Plot container; null/undefined makes drawing a no-op.
 * @param axisSvg - Axis container; null/undefined makes drawing a no-op.
 * @param data - Rows to plot, top to bottom.
 * @param drawOptions - Overrides for DEFAULT_DRAW_OPTIONS; nil values are ignored.
 */
export function draw(
    svg: Element | null | undefined,
    axisSvg: Element | null | undefined,
    data: BaseValue[],
    drawOptions?: Partial<DotPlotDrawOptions>
): void {
    // Drop explicit null/undefined overrides so those keys keep their defaults.
    const overrides = _.pickBy(drawOptions, o => !_.isNil(o));
    const options: DotPlotDrawOptions = { ...DEFAULT_DRAW_OPTIONS, ...overrides };

    const plotRoot = svg ?? null;
    const axisRoot = axisSvg ?? null;
    const xScale = getXScale(data, options.width, options.xMargin);

    drawAxis(axisRoot, xScale, options.height, options.axisX, options.axisSpace, options.format);
    drawAxis(plotRoot, xScale, options.height, 0, 0, options.format);

    const rowsBottom = data.length * (options.rowHeight + options.rowGap);
    drawGradient(plotRoot, rowsBottom, options.height, 'black');

    drawAxisIndicator(axisRoot, xScale, options);
    drawDots(plotRoot, data, xScale, options);
}

/**
 * Shows the value labels of the named rows and hides all others.
 *
 * @param svg - Plot container; null/undefined makes this a no-op.
 * @param names - Names of the rows whose labels should be visible.
 */
export function highlightRows(
    svg: Element | null | undefined,
    names: (string | undefined)[]
): void {
    const visible = new Set(names);
    const labels = select(svg ?? null)
        .selectAll<SVGTextElement, BaseValue>('text.label');

    labels.attr('opacity', label => (visible.has(label.name) ? '1' : '0'));
}
