import { useEffect, useRef } from 'react';

import {
	ColumnDef,
	OnChangeFn,
	SortingState,
	getCoreRowModel,
	getPaginationRowModel,
	useReactTable
} from '@tanstack/react-table';
import { notUndefined, useVirtualizer } from '@tanstack/react-virtual';

import { BaseSearch, Filters } from './DataTableFilters';
import { PaginationState } from './DataTablePagination';

/** Direction of a single-column sort. */
export type SortDirection = 'asc' | 'desc';

/**
 * Caller-controlled sort state for the data table.
 * Only a single sort column is represented.
 */
export type SortState = {
	/** Column id to sort by; `useDataTable` maps it to react-table's sort `id`. */
	orderBy: string;
	/** Sort direction; `'desc'` maps to react-table's `desc: true`, anything else to `false`. */
	direction: SortDirection;
};

/** Options that enable row virtualization in `useDataTable`. */
export type DataTableVirtualize = {
	/** Estimated size of a row; the same value is used for every row. */
	estimatedSize: number;
	/** Passed through unchanged as the virtualizer's `overscan` option. */
	overscan: number;
};

/** Last recorded scroll offsets of the table's scroll container. */
type ScrollPosition = {
	/** Vertical offset (the container's `scrollTop`). */
	top: number;
	/** Horizontal offset (the container's `scrollLeft`). */
	left: number;
};

/**
 * Props accepted by `useDataTable`.
 *
 * Sorting and pagination are manual: the table never sorts or slices `data`
 * itself, it only reports the requested state back to the caller.
 */
export type UseDataTableProps<TData, TValue, TSearch extends BaseSearch> = {
	/** Rows to render, in the order and amount they should appear. */
	data: TData[];
	/** react-table column definitions. */
	columns: ColumnDef<TData, TValue>[];
	/** Controlled pagination. When omitted, the table reports a single page. */
	pagination?: {
		/** Total number of pages; a falsy value is reported as 1. */
		pageCount: number;
		/**
		 * Current page. NOTE(review): `useDataTable` uses `offset - 1` as the
		 * page index, so `offset` looks like a 1-based page number — confirm.
		 */
		state: PaginationState;
		/**
		 * Pagination change callback.
		 * NOTE(review): not invoked by `useDataTable` itself — presumably wired
		 * up by the pagination UI.
		 */
		onPaginationChange: (state: PaginationState) => void;
	};
	/**
	 * Search/filter configuration. `useDataTable` only reads `layout`
	 * (defaulting to `'toolbar'`); the remaining fields are presumably
	 * consumed by the filter UI.
	 */
	search?: {
		filters: Filters<TSearch>;
		state: TSearch;
		layout: 'sidebar' | 'toolbar';
		onSearchChange: (search: TSearch) => void;
	};
	/**
	 * Controlled single-column sorting. `onSortingChange` receives the first
	 * requested sort column, or `undefined` when sorting is cleared.
	 */
	sorting?: {
		state: SortState;
		onSortingChange: (state?: SortState) => void;
	};
	/** Enables row virtualization when provided. */
	virtualize?: DataTableVirtualize;
	/** When true, the scroll container's offsets are recorded and re-applied after `data` changes. */
	keepScrollPosition?: boolean;
};

/**
 * Wires a TanStack table and an optional row virtualizer to caller-controlled
 * sorting and pagination state.
 *
 * Sorting and pagination are manual: the hook only translates between the
 * caller's `SortState` / `PaginationState` and react-table's shapes, and
 * forwards sort changes to `sorting.onSortingChange`. Rows are rendered in the
 * order and amount supplied in `data`.
 *
 * @returns The table instance, the search `layout` (default `'toolbar'`), a
 * `parentRef` to attach to the scroll container, the row model's `rows`, the
 * `virtualizer` with its current virtual `items`, and `before` / `after` — the
 * sizes of the unrendered regions preceding the first and following the last
 * virtual item (presumably rendered as spacer rows).
 */
export function useDataTable<TData, TValue, TSearch extends BaseSearch = never>({
	data,
	columns,
	pagination,
	search,
	sorting,
	virtualize,
	keepScrollPosition
}: UseDataTableProps<TData, TValue, TSearch>) {
	// react-table models sorting as an array; only a single column is supported here.
	// Uncontrolled sorting is represented as "not sorted" ([]), never as `undefined`.
	const tableSorting: SortingState = sorting?.state
		? [{ id: sorting.state.orderBy, desc: sorting.state.direction === 'desc' }]
		: [];

	const onSortingChange: OnChangeFn<SortingState> = updater => {
		// react-table passes either the next state or a function of the previous state.
		const next = typeof updater === 'function' ? updater(tableSorting) : updater;
		const [first] = next;

		// Only the first sort column is forwarded; an empty array clears sorting.
		sorting?.onSortingChange(
			first ? { orderBy: first.id, direction: first.desc ? 'desc' : 'asc' } : undefined
		);
	};

	const table = useReactTable({
		data,
		columns,
		// A missing or zero page count is reported as a single page.
		pageCount: pagination?.pageCount || 1,
		onSortingChange: sorting ? onSortingChange : undefined,
		getCoreRowModel: getCoreRowModel(),
		getPaginationRowModel: getPaginationRowModel(),
		// `data` is already sorted and paginated by the caller; react-table must not reorder or slice it.
		manualSorting: true,
		manualPagination: true,
		state: {
			// `state` is shallow-merged over react-table's internal state, so an explicit
			// `undefined` would wipe the built-in default and break APIs that read it
			// (e.g. `getCanNextPage()` destructures `pagination`). Only set it when controlled.
			...(pagination && {
				pagination: {
					// NOTE(review): `offset` appears to be a 1-based page number — confirm.
					pageIndex: pagination.state.offset - 1,
					pageSize: pagination.state.limit
				}
			}),
			sorting: tableSorting
		}
	});

	const layout = search?.layout ?? 'toolbar';

	const parentRef = useRef<HTMLDivElement>(null);
	// Last observed scroll offsets of the container; persists across re-renders.
	const scrollPositionRef = useRef<ScrollPosition>({ top: 0, left: 0 });

	useEffect(() => {
		// Capture the node once: the listener must be removed from the same element it was
		// added to, even if the ref has been cleared or re-pointed by the time cleanup runs.
		const element = parentRef.current;
		if (!element || !keepScrollPosition) {
			return undefined;
		}

		// Re-apply the last saved offsets (this effect re-runs whenever `data` changes).
		element.scrollTop = scrollPositionRef.current.top;
		element.scrollLeft = scrollPositionRef.current.left;

		const handleScroll = () => {
			scrollPositionRef.current = { top: element.scrollTop, left: element.scrollLeft };
		};

		element.addEventListener('scroll', handleScroll);
		return () => element.removeEventListener('scroll', handleScroll);
	}, [data, keepScrollPosition]);

	const { rows } = table.getRowModel();

	// Virtualization is opt-in; every row gets the same estimated size.
	const virtualizer = useVirtualizer({
		enabled: !!virtualize,
		count: rows.length,
		getScrollElement: () => parentRef.current,
		estimateSize: () => virtualize?.estimatedSize || 0,
		overscan: virtualize?.overscan
	});

	const items = virtualizer.getVirtualItems();

	// Space before the first and after the last virtual item that is not covered by rendered rows.
	const [before, after] =
		items.length > 0
			? [
					notUndefined(items[0]).start - virtualizer.options.scrollMargin,
					virtualizer.getTotalSize() - notUndefined(items[items.length - 1]).end
			  ]
			: [0, 0];

	return {
		table,
		layout,
		parentRef,
		rows,
		virtualizer,
		items,
		before,
		after
	};
}
