import {
  Flex, Table, Tbody, Td, Th, Thead, Tr, TableRowProps, Text, Box,
} from '@chakra-ui/react';
import React from 'react';
import {
  useReactTable,
  getCoreRowModel,
  getSortedRowModel,
  flexRender,
  SortingState,
  ColumnDef,
  Row,
  OnChangeFn,
} from '@tanstack/react-table';
import { useWindowVirtualizer } from '@tanstack/react-virtual';

import { ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons';

/**
 * Sortable table whose body rows are virtualised against the window scroll
 * position (via `useWindowVirtualizer`), so only the rows near the viewport
 * (plus an overscan of 20) are mounted.
 *
 * Column `meta` keys read by this component (their types come from a
 * `ColumnMeta` augmentation elsewhere — NOTE(review): confirm location):
 * - `headerProps`: spread onto the column's `<Th>`;
 * - `cellProps`: spread onto each of the column's `<Td>`s;
 * - `sortText.asc` / `sortText.desc`: label shown in parentheses next to the
 *   sort chevron while the column is sorted in that direction.
 *
 * @param data - Rows to display.
 * @param columns - TanStack column definitions.
 * @param getRowProps - Extra props spread onto each default-rendered `<Tr>`.
 *   A `transform` returned here is overridden by the virtual row offset.
 * @param defaultSort - Initial sorting state (empty when omitted).
 * @param renderRow - Replaces the default row rendering. Its output is NOT
 *   given the virtual `translateY` offset, so it must position itself.
 * @param manualSorting - Forwarded to `useReactTable` (TanStack then expects
 *   `data` to already be sorted).
 * @param onSortingChange - Called with the same updater after the internal
 *   sorting state has been updated.
 * @param estimateSize - Row height in px used by the virtualiser; falsy values
 *   fall back to 56.
 */
export const VirtualisedSortableTable = <T extends unknown>(
  {
    data,
    columns,
    getRowProps,
    defaultSort,
    renderRow,
    manualSorting = false,
    onSortingChange,
    estimateSize,
  }
  :{
    data: T[],
    columns: ColumnDef<T, any>[],
    getRowProps?: (row:Row<T>)=> TableRowProps,
    defaultSort?: SortingState,
    renderRow?: (row:Row<T>)=> React.ReactNode,
    manualSorting?: boolean,
    onSortingChange?: OnChangeFn<SortingState>,
    estimateSize?: number;
  },
) => {
  const [sorting, setSorting] = React.useState<SortingState>(defaultSort ?? []);
  const listRef = React.useRef<HTMLDivElement | null>(null);

  const { getFlatHeaders, getRowModel } = useReactTable({
    columns,
    data,
    state: { sorting },
    manualSorting,
    onSortingChange: (updater) => {
      setSorting(updater);
      onSortingChange?.(updater);
    },
    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
  });

  const { rows } = getRowModel();

  // Offset of the table container from the top of the document. The ref is
  // only attached after the first commit, so the first render uses 0 and
  // subsequent renders pick up the measured value.
  const scrollMargin = listRef.current?.offsetTop ?? 0;
  const virtualizer = useWindowVirtualizer({
    // Sized from the table's row model (not `data`) so the virtual items and
    // the rows they index can never disagree about the row count.
    count: rows.length,
    estimateSize: () => estimateSize || 56,
    scrollMargin,
    overscan: 20,
  });

  // The container is as tall as all virtual rows plus 200px of extra room
  // (NOTE(review): presumably for the header row — confirm), so the window
  // scrollbar reflects the whole list even though few rows are mounted.
  return (
    <Box height={`${virtualizer.getTotalSize() + 200}px`} ref={listRef}>
      <Table>
        <Thead>
          <Tr>
            {getFlatHeaders().map((header) => {
              const { column } = header;
              const canSort = column.getCanSort();
              const sortDirection = column.getIsSorted();
              const sortLabel = sortDirection
                ? column.columnDef.meta?.sortText?.[sortDirection]
                : undefined;
              return (
                <Th
                  key={header.id}
                  userSelect="none"
                  cursor={canSort ? 'pointer' : undefined}
                  {...(column.columnDef.meta?.headerProps || {})}
                  _hover={canSort ? { color: 'magnetize.text-2' } : {}}
                >
                  <Flex
                    alignItems="center"
                    onClick={column.getToggleSortingHandler()}
                  >
                    {header.isPlaceholder
                      ? null
                      : flexRender(column.columnDef.header, header.getContext())}
                    {sortDirection && (
                      <>
                        {sortLabel && (
                          <Text pl={1}>
                            (
                            {sortLabel}
                            )
                          </Text>
                        )}
                        {sortDirection === 'asc'
                          ? <ChevronUpIcon ml={1} w={4} h={4} />
                          : <ChevronDownIcon ml={1} w={4} h={4} />}
                      </>
                    )}
                  </Flex>
                </Th>
              );
            })}
          </Tr>
        </Thead>
        <Tbody>
          {virtualizer.getVirtualItems().map((virtualRow, index) => {
            const row = rows[virtualRow.index];
            // Defensive: skip virtual items that no longer map to a row
            // (e.g. if the row model shrank between renders).
            if (!row) {
              return null;
            }

            if (renderRow) {
              // Keyed wrapper so React can reconcile custom rows in this list.
              return (
                <React.Fragment key={row.id}>
                  {renderRow(row)}
                </React.Fragment>
              );
            }

            // Mounted rows stay in normal table flow, so the `index`-th one
            // already sits `index * size` below the top of the tbody. Shift it
            // by the remaining distance to its virtual position (excluding the
            // page offset). This assumes every row has the same height.
            const offset = virtualRow.start - index * virtualRow.size - scrollMargin;
            return (
              <Tr
                key={row.id}
                {...(getRowProps ? getRowProps(row) : {})}
                transform={`translateY(${offset}px)`}
              >
                {row.getVisibleCells().map((cell) => (
                  <Td
                    key={cell.id}
                    {...(cell.column.columnDef.meta?.cellProps || {})}
                  >
                    {flexRender(cell.column.columnDef.cell, cell.getContext())}
                  </Td>
                ))}
              </Tr>
            );
          })}
        </Tbody>
      </Table>
    </Box>
  );
};
