import torch
import flash_attn
import flash_attn_interface
import itertools
import time
import math

import torch.utils.benchmark as benchmark

def round_up_to_power_of_2(x):
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()

def timeit(fn, *args, **kwargs):
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        fn(*args, **kwargs)
    
    # Benchmark using PyTorch Timer
    t = benchmark.Timer(
        stmt='fn(*args, **kwargs)',
        globals={'fn': fn, 'args': args, 'kwargs': kwargs}
    )
    
    # Measure execution time
    measurement = t.timeit(20)  # Runs the function 20 times
    # measurement = t.blocked_autorange(min_run_time=1)
    avg_time = measurement.mean  # Average time in seconds

    return avg_time

def main():
    num_sms = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    max_splits = 129
    check_all_splits = False

    causal = True
    # causal = False
    # dtype=torch.float16
    dtype=torch.bfloat16

    torch.manual_seed(42)  

    model_configs = [
        # ("Gemma-2-2B", 8, 4, 256),
        # ("Gemma-2-9B", 16, 8, 256),
        # ("Gemma-2-27B", 32, 16, 128),
        # ("Qwen-2.5-0.5B", 14, 2, 64),
        # ("Qwen-2.5-1.5B", 12, 2, 128),
        # ("Qwen-2.5-7B", 28, 4, 128),
        # ("Llama-3.1-8B", 32, 8, 128),
        ("Llama-3.1-70B", 64, 8, 128),
        # ("Llama-3.1-405B", 128, 8, 128),
        # ("Llama-3.2-1B", 32, 8, 64),
        # ("Llama-3.2-3B", 24, 8, 128),
        # ("Nemotron-4-15B", 48, 8, 128),
    ]

    all_batch_configs = []

    all_batch_configs.extend(itertools.product(
        # [1024, 2048, 4096, 8192, 16384, 32768, 131072],  # context_seqlen
        [4096, 16384, 65536],  # context_seqlen
        # [131072],  # context_seqlen
        # [i for i in range(1, (num_sms) + 1)], # num_requests
        [1, 4, 8, 16],  # num_requests
        # [1],  # num_requests
        [1, 4, 8, 16],  # query_seqlen
        # [1],  # query_seqlen
    ))

    num_caches = max(reqs for _, reqs, _ in all_batch_configs)
    cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)

    for model_name, nheads_q, nheads_kv, headdim in model_configs:
        k_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )
        v_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )
        print(f"***{model_name}***")
        print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}")
        
        if check_all_splits is False:
            print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")

        for context_seqlen, num_requests, query_seqlen in all_batch_configs:
            bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
            bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)
            blockH = round_up_to_power_of_2(nheads_q//nheads_kv)
            blockM = 128 # true for hdim 128 causal and hdim 64
            blockM_div_H = blockM//blockH
            num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)

            q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
            cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
            cache_seqlens = torch.tensor(
                [context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
            )

            fa2_time_heuristic = timeit(
                flash_attn.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
            ) * 1000. * 1000.
            # fastest_splitk_time = float("inf")
            # fastest_splitk = 0
            # for i in range(1, max_splits):
            #     t = timeit(
            #         flash_attn.flash_attn_with_kvcache,
            #         q=q,
            #         k_cache=k_cache,
            #         v_cache=v_cache,
            #         cache_seqlens=cache_seqlens,
            #         cache_batch_idx=cache_idxs,
            #         causal=causal,
            #         num_splits=i,
            #     ) * 1000. * 1000.
            #     if t < fastest_splitk_time:
            #         fastest_splitk_time = t
            #         fastest_splitk = i

            fa3_time_one_split = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                gqa_parallel=False,
                num_splits=1,
            ) * 1000. * 1000.

            fa3_time_gqa_heuristic = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                gqa_parallel=True,
                num_splits=0,
                max_seqlen_k_hint=context_seqlen
            ) * 1000. * 1000.

            if check_all_splits:
            
                fa3_fastest_num_splits = 0
                fa3_fastest_splitk_time = float("inf")
                
                for num_splits in range(1, max_splits):
                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=num_splits
                    ) * 1000. * 1000.

                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=num_splits
                    )

                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=1
                    )

                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    # print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}")
                    # print (f"splits {num_splits}, time {t:.2f}")

                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

                    if t < fa3_fastest_splitk_time:
                        fa3_fastest_splitk_time = t
                        fa3_fastest_num_splits = num_splits

                fa3_fastest_num_splits_gqa = 0
                fa3_fastest_splitk_time_gqa = float("inf")
                for num_splits in range(1, max_splits):

                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=num_splits
                    ) * 1000. * 1000.

                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=num_splits
                    )

                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=1
                    )

                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    # print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}")
                    # print (f"gqa splits {num_splits}, time {t:.2f}")

                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

                    if t < fa3_fastest_splitk_time_gqa:
                        fa3_fastest_splitk_time_gqa = t
                        fa3_fastest_num_splits_gqa = num_splits
                
                efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
                heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
                # remeasure to smooth anomalies
                if heuristic_ratio > 1.1:

                    fa3_time_gqa_heuristic = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        # num_splits=num_splits_select,
                        # num_splits=1,
                        num_splits=0,
                        max_seqlen_k_hint=context_seqlen
                    ) * 1000. * 1000.

                    fa3_fastest_splitk_time_gqa = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=fa3_fastest_num_splits_gqa
                    ) * 1000. * 1000. 

            if check_all_splits is True:
                print(
                    f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
                    f"FA2:{fa2_time_heuristic:.2f}, "
                    # f"FA2 MANUAL:{fastest_splitk_time:.2f}, "
                    # f"FA2 NUM SPLITS:{fastest_splitk}, "
                    # f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, "
                    # f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, "
                    # f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, "
                    f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
                    f"FA3:{fa3_time_gqa_heuristic:.2f}, "
                    # f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, "
                    # f"FA2 NUM SPLITS:{fastest_splitk}, "
                    # f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, "
                    f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
                    # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
                    f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
                    f"EFF:{efficiency:.2f}, "
                    f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" 
                )

            if check_all_splits is False:
                print(
                    f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
                    f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
                    f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
                    f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
                )



if __name__ == "__main__":
    main()