import { useMemo } from 'react'

import union from 'lodash/union'
import {
    type UseGetManyHookValue,
    type DataProvider,
    useDataProvider,
    type GetManyParams,
} from 'ra-core'
import {
    type QueryClient,
    useQueryClient,
    useQuery,
    type UseQueryOptions,
    hashQueryKey,
} from 'react-query'

import { type DataRecord, type Identifier } from 'appTypes'

/**
 * Normalize an `ids` param to an array so that every consumer sees the same list.
 *
 * `null`/`undefined` become an empty array (instead of `[undefined]`), an array
 * is returned as-is, and a lone id is wrapped in an array.
 */
function toIdArray<T>(ids: T | T[] | null | undefined): T[] {
    if (ids == null) {
        return []
    }
    return Array.isArray(ids) ? ids : [ids]
}

/**
 * Fetch several records of `resource` by id, merging concurrent calls.
 *
 * The actual fetch goes through `callGetManyQueries`, which collects the calls
 * made before its zero-delay timer fires and aggregates their ids.
 *
 * While the query is pending, records already present in the `getOne` cache
 * are used as placeholder data, but only when every requested id is cached.
 * On success, each fetched record is copied into the `getOne` cache when that
 * cache has no entry for it yet. After that, `options.onSuccess` (if any) is
 * called.
 *
 * Retries are always disabled: `retry: false` is applied after `options`, so a
 * caller-supplied `retry` is ignored.
 *
 * @param resource - Resource name passed to the data provider.
 * @param params - `ids` to fetch and optional `meta` forwarded to the data provider.
 * @param options - Extra react-query options. `onSuccess` is chained after the
 *   cache population rather than replacing it.
 * @returns The react-query result for the requested records.
 */
export const useGetManyAggregate = <RecordType extends DataRecord = any>(
    resource: string,
    params: GetManyParams,
    options: UseQueryOptions<RecordType[], Error> = {},
): UseGetManyHookValue<RecordType> => {
    const dataProvider = useDataProvider()
    const queryClient = useQueryClient()
    const queryCache = queryClient.getQueryCache()
    const { ids, meta } = params
    // A single normalized list feeds the placeholder, the query key and the fetch.
    // Memoized on `ids` so the placeholder memo below stays stable across renders.
    const idList = useMemo(() => toIdArray(ids), [ids])

    const placeholderData = useMemo(() => {
        const records = idList.map((id) => {
            const queryHash = hashQueryKey([resource, 'getOne', { id: String(id), meta }])
            return queryCache.get<RecordType>(queryHash)?.state?.data
        })
        // Only offer a placeholder when every requested record is already cached.
        if (records.some((record) => record === undefined)) {
            return undefined
        }
        return records as RecordType[]
    }, [idList, queryCache, resource, meta])

    return useQuery<RecordType[], Error, RecordType[]>(
        [
            resource,
            'getMany',
            {
                ids: idList.map((id) => String(id)),
                meta,
            },
        ],
        () =>
            new Promise((resolve, reject) => {
                if (idList.length === 0) {
                    // no need to call the dataProvider
                    return resolve([])
                }
                // debounced / batched fetch
                return callGetManyQueries({
                    resource,
                    ids: idList,
                    meta,
                    resolve,
                    reject,
                    dataProvider,
                    queryClient,
                })
            }),
        {
            placeholderData,
            ...options,
            onSuccess: (data) => {
                // optimistically populate the getOne cache, keeping any existing entry
                ;(data ?? []).forEach((record) => {
                    queryClient.setQueryData(
                        [resource, 'getOne', { id: String(record.id), meta }],
                        (oldRecord) => oldRecord ?? record,
                    )
                })

                options?.onSuccess?.(data)
            },
            retry: false,
        },
    )
}

/**
 * Coalesce successive calls into a single invocation of `fn`.
 *
 * Every call queues its argument and re-arms a zero-delay timer. When the timer
 * finally fires, `fn` receives all queued arguments in call order and the
 * queue starts empty again.
 *
 * @param fn - Receives the array of arguments collected since the last flush.
 * @returns A function that queues its single argument for the next flush.
 */
const batch = (fn: (args: any[]) => void) => {
    let pending: any[] = []
    let timer: ReturnType<typeof setTimeout> | null = null

    return (arg: any): void => {
        pending.push(arg)
        if (timer !== null) {
            clearTimeout(timer)
        }
        timer = setTimeout(() => {
            timer = null
            // Detach the queue BEFORE invoking `fn`. If `fn` throws, the flushed
            // arguments must not be replayed as part of the next batch.
            const flushed = pending
            pending = []
            fn(flushed)
        }, 0)
    }
}

/**
 * One pending `getMany` request queued by `useGetManyAggregate`, waiting to be
 * merged with the other requests of the same batch.
 */
interface GetManyCallArgs {
    // What to fetch.
    /** Resource name passed to `dataProvider.getMany`. */
    resource: string
    /** Ids requested by this particular caller. */
    ids: Identifier[]
    /** Optional `meta` forwarded to the data provider. */
    meta?: any

    // How to settle this caller's promise.
    /** Called with the subset of fetched records this caller asked for. */
    resolve: (data: any[]) => void
    /** Called with the fetch error. */
    reject: (error?: any) => void

    // Services used to perform the fetch.
    dataProvider: DataProvider
    queryClient: QueryClient
}

/**
 * Keep only the records whose id is in `ids`. Ids are compared as strings, so a
 * numeric id matches its string form.
 */
const pickRequestedRecords = (ids: Identifier[], records: any[]): any[] => {
    const wanted = new Set(ids.map((id) => String(id)))
    return records.filter((record) => wanted.has(String(record.id)))
}

/** Resolve every call of a group with the subset of `records` it requested. */
const resolveCalls = (calls: GetManyCallArgs[], records: any[]): void => {
    calls.forEach(({ ids, resolve }) => {
        resolve(pickRequestedRecords(ids, records))
    })
}

/** Reject every call of a group with the same error. */
const rejectCalls = (calls: GetManyCallArgs[], error: unknown): void => {
    calls.forEach(({ reject }) => {
        reject(error)
    })
}

/**
 * Split a batch into groups that can share one `getMany` request.
 *
 * Calls are grouped by resource AND by `meta`. Merging calls with different
 * `meta` would hand one caller records fetched with another caller's meta.
 * `hashQueryKey` is the same hashing react-query applies to query keys, so
 * metas that react-query treats as equal land in the same group. Groups keep
 * the order in which they first appear in the batch.
 */
const groupCalls = (calls: GetManyCallArgs[]): GetManyCallArgs[][] => {
    const groups = new Map<string, GetManyCallArgs[]>()
    calls.forEach((callArgs) => {
        const key = hashQueryKey([callArgs.resource, callArgs.meta])
        const group = groups.get(key)
        if (group) {
            group.push(callArgs)
        } else {
            groups.set(key, [callArgs])
        }
    })
    return Array.from(groups.values())
}

/**
 * Fetch the union of the ids requested by one group and settle every call in
 * the group. Every call is resolved with its own records or rejected with the
 * fetch error, including when the data provider throws synchronously.
 */
const fetchGroup = (group: GetManyCallArgs[]): void => {
    const [first] = group
    if (!first) {
        return
    }
    // Within a group, all calls share the resource and (hash-equal) meta; the
    // services are taken from the group itself rather than from the whole batch.
    const { resource, meta, dataProvider, queryClient } = first

    const aggregatedIds = group
        .reduce((acc, { ids }) => union(acc, ids), []) // concat + unique
        .filter((v) => v != null && v !== '') // remove null values

    if (aggregatedIds.length === 0) {
        // no need to call the data provider if all the ids are null
        group.forEach(({ resolve }) => {
            resolve([])
        })
        return
    }

    const someCallHasAllAggregatedIds = group.some(
        ({ ids }) => JSON.stringify(ids) === JSON.stringify(aggregatedIds),
    )
    if (someCallHasAllAggregatedIds) {
        // Either there was no aggregation, or one call already asks for every id.
        // That call's own getMany query has the same key as the one fetchQuery
        // would use, and it is still in flight waiting on this batch, so
        // fetchQuery would end up waiting on itself. Call the dataProvider
        // directly instead.
        // `Promise.resolve().then` turns a synchronous throw into a rejection.
        // The trailing `.catch` also covers errors raised while resolving.
        Promise.resolve()
            .then(() => dataProvider.getMany<any>(resource, { ids: aggregatedIds, meta }))
            .then(({ data }) => resolveCalls(group, data))
            .catch((error) => rejectCalls(group, error))
        return
    }

    queryClient
        .fetchQuery<any[], Error, any[]>(
            [
                resource,
                'getMany',
                {
                    ids: aggregatedIds.map((id) => String(id)),
                    meta,
                },
            ],
            () =>
                dataProvider
                    .getMany<any>(resource, { ids: aggregatedIds, meta })
                    .then(({ data }) => data),
        )
        .then((data) => resolveCalls(group, data))
        .catch((error) => rejectCalls(group, error))
}

/**
 * Batched entry point used by `useGetManyAggregate`.
 *
 * Each queued call is grouped with calls for the same resource and meta.
 * Each group triggers at most one fetch for the deduplicated union of its ids.
 */
const callGetManyQueries = batch((calls: GetManyCallArgs[]) => {
    groupCalls(calls).forEach(fetchGroup)
})
