import * as React from 'react';
import {
  ColumnDef,
  ColumnMeta,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  InitialTableState,
  Header,
  useReactTable,
} from '@tanstack/react-table';
import {
  HiChevronDown,
  HiChevronUp,
  HiOutlineInformationCircle,
} from 'react-icons/hi';
import { Collapse, Pagination } from '@material-ui/core';

import Show from '../show';
import Tooltip from '../tooltip';
import Switch from '../switch-match';
import TableSkeletonLoader from '../table-skeleton-loader';
import { cn } from '../../utils/tw-merge';

/**
 * Renders the contents of one `<th>` header cell.
 *
 * - Placeholder headers (padding cells generated for column groups) render nothing.
 * - When the column's `meta.tooltipText` is set, the header content is wrapped
 *   in a tooltip and followed by an info icon.
 * - Sortable columns get up/down chevrons highlighting the active direction, a
 *   `title` describing the next sort step, and click + keyboard (Enter/Space)
 *   handlers that cycle the sort order. Non-sortable headers are rendered as
 *   plain, non-interactive content (no button role, not focusable).
 */
function TableHeader<TData>({ header }: { header: Header<TData, unknown> }) {
  type ExtendedMeta = ColumnMeta<TData, unknown> & { tooltipText?: string };

  if (header.isPlaceholder) {
    return null;
  }

  const { column } = header;
  const meta = column.columnDef.meta as ExtendedMeta | undefined;
  const canSort = column.getCanSort();
  // false when unsorted, otherwise 'asc' | 'desc'.
  const sortedDirection = column.getIsSorted();
  const toggleSorting = column.getToggleSortingHandler();

  // Describe what the next activation will do; only shown for sortable columns.
  let sortTitle: string | undefined;
  if (canSort) {
    const nextOrder = column.getNextSortingOrder();
    if (nextOrder === 'asc') {
      sortTitle = 'Sort ascending';
    } else if (nextOrder === 'desc') {
      sortTitle = 'Sort descending';
    } else {
      sortTitle = 'Clear sort';
    }
  }

  // A div with role="button" is not keyboard-operable by default, so mirror
  // the click behaviour for Enter and Space, as a native <button> would.
  const handleKeyDown = (event: React.KeyboardEvent<HTMLDivElement>) => {
    if (event.key !== 'Enter' && event.key !== ' ') {
      return;
    }
    event.preventDefault(); // keep Space from scrolling the page
    toggleSorting?.(event);
  };

  const content = flexRender(column.columnDef.header, header.getContext());

  return (
    <div
      role={canSort ? 'button' : undefined}
      tabIndex={canSort ? 0 : undefined}
      className={cn('flex items-center', {
        'cursor-pointer select-none': canSort,
      })}
      onClick={toggleSorting}
      onKeyDown={canSort ? handleKeyDown : undefined}
      title={sortTitle}
    >
      <Show when={meta?.tooltipText} fallback={content}>
        {(text) => (
          <Tooltip title={<div className="max-w-xs">{text}</div>}>
            <div className="flex items-center">
              {content}
              <HiOutlineInformationCircle className="inline-block ml-1 text-primary-dark" />
            </div>
          </Tooltip>
        )}
      </Show>

      <Show when={canSort}>
        <span
          className={cn('inline-flex flex-col justify-center ml-2', {
            'text-primary-tint3': sortedDirection,
          })}
        >
          <HiChevronUp
            className={cn({
              'text-primary-dark': sortedDirection === 'asc',
            })}
          />
          <HiChevronDown
            className={cn('-mt-1.5', {
              'text-primary-dark': sortedDirection === 'desc',
            })}
          />
        </span>
      </Show>
    </div>
  );
}

/** Attributes that an `onRow` callback can attach to a rendered body row. */
type RowOptions = {
  /** Click handler bound to the row's `<tr>` element. */
  onClick?: (event: React.MouseEvent<HTMLTableRowElement>) => void;
  /** Additional classes merged into the row's `<tr>` class list. */
  className?: string;
};

/** Attributes that an `onCell` callback can attach to a rendered body cell. */
type CellOptions = {
  /** Click handler bound to the cell's `<td>` element. */
  onClick?: (event: React.MouseEvent<HTMLTableCellElement>) => void;
  /** Additional classes merged into the cell's `<td>` class list. */
  className?: string;
};

/**
 * Pagination configuration for the table.
 *
 * Any truthy value enables pagination; `false` or omitting it disables it.
 * Passing an object enables pagination and overrides the starting
 * `pageIndex` and/or `pageSize`.
 */
type PaginationProps = boolean | { pageIndex?: number; pageSize?: number };

/** Props accepted by the `TableV8` component. */
type TableProps<TData extends object> = {
  /** Rows to render; each item is exposed to renderers as `row.original`. */
  data: Array<TData>;
  /**
   * Column definitions, including header/cell/footer renderers and per-column
   * options. The second `ColumnDef` type argument is the accessor value type,
   * which differs per column, hence `any` here (the usual TanStack pattern).
   * For details see: https://tanstack.com/table/latest/docs/api/core/column-def
   */
  columns: Array<ColumnDef<TData, any>>;

  /** When truthy, a skeleton loader is shown in place of the body rows. */
  isLoading?: boolean;

  /**
   * Enables pagination. Pass `true` to use the defaults (pageSize 10,
   * pageIndex 0), or an object to override `pageSize` and/or `pageIndex`.
   */
  pageOptions?: PaginationProps;
  /** Initial TanStack table state (e.g. initial sorting or expansion). */
  initialState?: InitialTableState;
  /** Controlled column filters, one entry per filtered column id. */
  columnFilters?: Array<{ id: string; value: unknown }>;

  /** Renders the content shown beneath a row while that row is expanded. */
  expandedRowRender?: (row: TData) => React.ReactNode;

  /**
   * Returns extra attributes for a body row. `index` is the row's position
   * among the rendered rows (the current page when pagination is enabled).
   */
  onRow?: (row: TData, index: number) => RowOptions;
  /**
   * Returns extra attributes for a body cell. The first argument is the
   * cell's value (not the cell object); the second is the row's original item.
   */
  onCell?: (value: unknown, row: TData) => CellOptions;
};

/**
 * Generic data table built on TanStack Table v8.
 *
 * - Sorting is opt-in: a top-level column definition without an explicit
 *   `enableSorting` key is given `enableSorting: false`.
 * - Pagination is active only when `pageOptions` is truthy. Page state starts at
 *   `{ pageIndex: 0, pageSize: 10 }`, overridden by `pageOptions` when it is an
 *   object (read once, on mount).
 * - `columnFilters` is passed as controlled state; the filtered row model is
 *   only wired in when at least one filter is present.
 * - While `isLoading` is truthy, a skeleton loader replaces the body rows.
 * - A `<tfoot>` is rendered when any column, including columns nested inside
 *   column groups, defines a `footer`.
 */
export default function TableV8<TData extends object>({
  data,
  columns: rawColumns,
  isLoading,
  onRow,
  onCell,
  pageOptions,
  expandedRowRender,
  initialState,
  columnFilters = [],
}: TableProps<TData>) {
  // Pagination state; the table only consumes it when `pageOptions` is truthy.
  const [pagination, setPagination] = React.useState({
    pageIndex: 0,
    pageSize: 10,
    ...(typeof pageOptions === 'object' ? pageOptions : {}),
  });

  // Disable sorting by default and let users explicitly enable sorting for the
  // required columns.
  // NOTE(review): only top-level definitions are rewritten; leaf columns nested
  // inside a group keep TanStack's default sortability — confirm this is intended.
  const columns = React.useMemo(
    () =>
      rawColumns.map((col) =>
        'enableSorting' in col ? col : { ...col, enableSorting: false },
      ),
    [rawColumns],
  );

  const table = useReactTable({
    data,
    columns,
    state: {
      columnFilters,
      // Pagination is controlled state only when pagination is enabled.
      ...(pageOptions ? { pagination } : {}),
    },
    initialState: {
      ...initialState,
    },
    onPaginationChange: setPagination,
    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getExpandedRowModel: getExpandedRowModel(),
    ...(pageOptions ? { getPaginationRowModel: getPaginationRowModel() } : {}),
    ...(columnFilters.length > 0
      ? { getFilteredRowModel: getFilteredRowModel() }
      : {}),
  });

  // Check every column, including those nested in column groups, so footers
  // defined only on leaf columns still cause the footer section to render.
  const hasFooter = table
    .getAllFlatColumns()
    .some((column) => Boolean(column.columnDef.footer));

  // The skeleton needs as many columns as the widest header row.
  const loadingItemsNumber = table
    .getHeaderGroups()
    .reduce((max, headerGroup) => Math.max(max, headerGroup.headers.length), 0);

  return (
    <div className="flex flex-col flex-1 overflow-hidden">
      <div className="flex-1 mb-4 overflow-y-auto">
        <table className="w-full">
          <thead>
            {table.getHeaderGroups().map((headerGroup) => (
              <tr key={headerGroup.id}>
                {headerGroup.headers.map((header) => (
                  <th
                    key={header.id}
                    colSpan={header.colSpan}
                    className="sticky top-0 z-10 px-4 py-3 text-xs font-normal text-left bg-primary-extraLight text-secondary-dark"
                    style={{ width: header.column.columnDef.size }}
                  >
                    <TableHeader header={header} />
                  </th>
                ))}
              </tr>
            ))}
          </thead>

          <tbody>
            <Switch>
              <Switch.Match when={isLoading}>
                <TableSkeletonLoader columns={loadingItemsNumber} />
              </Switch.Match>

              <Switch.Match when={data}>
                {table.getRowModel().rows.map((row, idx) => {
                  // Invoke the caller's hooks once per row/cell rather than
                  // once per attribute they feed.
                  const rowOptions = onRow?.(row.original, idx);
                  const visibleCells = row.getVisibleCells();

                  return (
                    <React.Fragment key={row.id}>
                      <tr
                        className={cn(
                          'text-sm text-text-primary ',
                          rowOptions?.className,
                        )}
                        onClick={rowOptions?.onClick}
                      >
                        {visibleCells.map((cell) => {
                          const cellOptions = onCell?.(
                            cell.getValue(),
                            row.original,
                          );

                          return (
                            <td
                              key={cell.id}
                              className={cn(
                                'px-4 py-2 font-normal',
                                cellOptions?.className,
                              )}
                              onClick={cellOptions?.onClick}
                            >
                              {flexRender(
                                cell.column.columnDef.cell,
                                cell.getContext(),
                              )}
                            </td>
                          );
                        })}
                      </tr>

                      <Show when={row.getIsExpanded()}>
                        <tr>
                          {/* Span the cells actually rendered for this row:
                              columns.length counts top-level defs, which is
                              wrong with column groups or hidden columns. */}
                          <td colSpan={visibleCells.length}>
                            <Collapse
                              unmountOnExit
                              in={row.getIsExpanded()}
                              className="bg-gray-100 border-y"
                            >
                              {expandedRowRender?.(row.original) ?? null}
                            </Collapse>
                          </td>
                        </tr>
                      </Show>
                    </React.Fragment>
                  );
                })}
              </Switch.Match>
            </Switch>
          </tbody>

          <Show when={hasFooter}>
            <tfoot>
              {table.getFooterGroups().map((footerGroup) => (
                <tr key={footerGroup.id}>
                  {footerGroup.headers.map((header) => (
                    <td
                      key={header.id}
                      // Group footers span their child columns, matching <thead>.
                      colSpan={header.colSpan}
                      className="px-4 py-2 font-medium border-primary-extraLight border-y text-text-primary"
                    >
                      <Show when={!header.isPlaceholder}>
                        {flexRender(
                          header.column.columnDef.footer,
                          header.getContext(),
                        )}
                      </Show>
                    </td>
                  ))}
                </tr>
              ))}
            </tfoot>
          </Show>
        </table>
      </div>

      <Show when={pageOptions}>
        <Pagination
          shape="rounded"
          variant="outlined"
          count={table.getPageCount()}
          page={table.getState().pagination.pageIndex + 1}
          onChange={(_, page) => {
            table.setPageIndex(page - 1);
          }}
        />
      </Show>
    </div>
  );
}
