# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Dict

import torch

from . import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops
from .ops.common import OPERATORS_REGISTRY
from .profiler.profiler_dcgm import DCGM_PROFILER_AVAILABLE


def get_features_status() -> Dict[str, str]:
    """Collect the availability of registered operators and related features.

    Returns:
        A dictionary containing, in insertion order:

        * one ``"<OPERATOR_CATEGORY>.<NAME>"`` entry per element of
          ``OPERATORS_REGISTRY``, valued ``"available"`` or ``"unavailable"``
          according to that operator's ``is_available()``;
        * one ``"swiglu.<key>"`` entry per item returned by
          ``ops.swiglu_op._info()``, with the value passed through unchanged;
        * ``"is_triton_available"``, the stringified result of
          ``_is_triton_available()``.
    """
    report = {}
    for operator in OPERATORS_REGISTRY:
        if operator.is_available():
            availability = "available"
        else:
            availability = "unavailable"
        report[f"{operator.OPERATOR_CATEGORY}.{operator.NAME}"] = availability

    swiglu_details = ops.swiglu_op._info()
    report.update(
        {f"swiglu.{key}": value for key, value in swiglu_details.items()}
    )

    report["is_triton_available"] = str(_is_triton_available())
    return report


def print_info():
    """Print a human-readable report about the installation to stdout.

    The report starts with a ``xFormers <version>`` banner, followed by one
    ``name: status`` line per feature. Features are gathered in this order:
    the entries from ``get_features_status()``, the PyTorch version and CUDA
    availability (plus GPU name and compute capability when CUDA is usable),
    DCGM profiler availability, build metadata of the C++ library, the nvcc
    version the extension was built with (when it can be queried), and
    whether this is the open-source build.
    """
    report = get_features_status()
    print(f"xFormers {__version__}")

    # PyTorch / GPU details.
    report["pytorch.version"] = torch.__version__
    cuda_usable = torch.cuda.is_available()
    report["pytorch.cuda"] = "available" if cuda_usable else "not available"
    if cuda_usable:
        current = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(current)
        report["gpu.compute_capability"] = ".".join(map(str, capability))
        report["gpu.name"] = torch.cuda.get_device_name(current)

    if DCGM_PROFILER_AVAILABLE:
        report["dcgm_profiler"] = "available"
    else:
        report["dcgm_profiler"] = "unavailable"

    # Build metadata: prefer what the loaded library exposes; otherwise fall
    # back to the metadata attached to an invalid-library load exception.
    metadata = _cpp_lib._build_metadata
    if metadata is None:
        load_error = _cpp_lib._cpp_library_load_exception
        if isinstance(load_error, _cpp_lib.xFormersInvalidLibException):
            metadata = load_error.build_info

    if metadata is None:
        report["build.info"] = "none"
    else:
        report["build.info"] = "available"
        for field in (
            "cuda_version",
            "hip_version",
            "python_version",
            "torch_version",
        ):
            report[f"build.{field}"] = getattr(metadata, field)
        for env_name, env_value in metadata.build_env.items():
            report[f"build.env.{env_name}"] = env_value

    # Best effort: the entry is simply omitted when the op cannot be
    # resolved or called (AttributeError / RuntimeError).
    try:
        nvcc_parts = torch.ops.xformers._nvcc_build_version()
        report["build.nvcc_version"] = ".".join(
            str(part) for part in nvcc_parts
        )
    except (RuntimeError, AttributeError):
        pass

    report["source.privacy"] = "open source" if _is_opensource else "fairinternal"

    # Left-align the "name:" column to 50 characters so statuses line up.
    for label, value in report.items():
        heading = f"{label}:"
        print(f"{heading:<50} {value}")


# Script entry point: print the report when this module is executed directly.
# NOTE(review): the relative imports at the top of the file presumably mean
# this has to be launched as a module (``python -m <package>.<module>``)
# rather than as a standalone script — confirm.
if __name__ == "__main__":
    print_info()
