from abc import ABC, abstractmethod
from typing import Any, Optional

from torch import Tensor
from typing_extensions import Self

from lightning_fabric.utilities.types import CollectibleGroup


class Collective(ABC):
    """Abstract base class describing a collective-communication backend.

    Concrete subclasses provide communication between multiple processes and multiple nodes. Every instance owns
    at most one group: it is created through :meth:`create_group` and released through :meth:`teardown`.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature which is still in development.

    """

    def __init__(self) -> None:
        # Nothing is owned until ``create_group`` stores the result of ``new_group`` here.
        self._group: Optional[CollectibleGroup] = None

    @property
    @abstractmethod
    def rank(self) -> int:
        """Rank, as reported by the concrete implementation."""

    @property
    @abstractmethod
    def world_size(self) -> int:
        """World size, as reported by the concrete implementation."""

    @property
    def group(self) -> CollectibleGroup:
        """Return the group currently owned by this collective.

        Raises:
            RuntimeError: If this collective does not currently own a group.

        """
        owned = self._group
        if owned is not None:
            return owned
        raise RuntimeError(
            f"`{type(self).__name__}` does not own a group. HINT: try `collective.create_group().group`"
        )

    # ------------------------------------------------------------------------------------------------------------
    # Communication primitives.
    #
    # Everything in this section is abstract. NOTE(review): names and parameters look modelled on
    # ``torch.distributed`` - confirm; the precise semantics are whatever the concrete subclass implements.
    # ------------------------------------------------------------------------------------------------------------

    @abstractmethod
    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
        """Abstract hook for the ``broadcast`` operation."""

    @abstractmethod
    def all_reduce(self, tensor: Tensor, op: str) -> Tensor:
        """Abstract hook for the ``all_reduce`` operation; ``op`` is given as a string."""

    @abstractmethod
    def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor:
        """Abstract hook for the ``reduce`` operation; ``op`` is given as a string."""

    @abstractmethod
    def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:
        """Abstract hook for the ``all_gather`` operation."""

    @abstractmethod
    def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]:
        """Abstract hook for the ``gather`` operation; ``dst`` defaults to ``0``."""

    @abstractmethod
    def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor:
        """Abstract hook for the ``scatter`` operation; ``src`` defaults to ``0``."""

    @abstractmethod
    def reduce_scatter(self, output: Tensor, input_list: list[Tensor], op: str) -> Tensor:
        """Abstract hook for the ``reduce_scatter`` operation; ``op`` is given as a string."""

    @abstractmethod
    def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]:
        """Abstract hook for the ``all_to_all`` operation."""

    @abstractmethod
    def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
        """Abstract hook for the ``send`` operation; ``tag`` defaults to ``0``."""

    @abstractmethod
    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
        """Abstract hook for the ``recv`` operation; ``src`` may be ``None`` and ``tag`` defaults to ``0``."""

    @abstractmethod
    def barrier(self, device_ids: Optional[list[int]] = None) -> None:
        """Abstract hook for the ``barrier`` operation; ``device_ids`` is optional."""

    # ------------------------------------------------------------------------------------------------------------
    # Class-level backend hooks, used by ``setup``, ``create_group`` and ``teardown`` below.
    # ------------------------------------------------------------------------------------------------------------

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool:
        """Abstract hook returning a boolean availability flag."""

    @classmethod
    @abstractmethod
    def is_initialized(cls) -> bool:
        """Abstract hook returning a boolean; :meth:`setup` skips :meth:`init_group` when it is true."""

    @classmethod
    @abstractmethod
    def init_group(cls, **kwargs: Any) -> None:
        """Abstract hook invoked by :meth:`setup` with the keyword arguments it received."""

    @classmethod
    @abstractmethod
    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
        """Abstract hook invoked by :meth:`create_group`; its return value becomes the owned group."""

    @classmethod
    @abstractmethod
    def destroy_group(cls, group: CollectibleGroup) -> None:
        """Abstract hook invoked by :meth:`teardown` with the owned group."""

    @classmethod
    @abstractmethod
    def _convert_to_native_op(cls, op: str) -> Any:
        """Abstract hook taking a string ``op``.

        NOTE(review): presumably maps the string to a backend-specific operation object - confirm in subclasses.

        """

    def setup(self, **kwargs: Any) -> Self:
        """Call :meth:`init_group` unless :meth:`is_initialized` already returns true.

        Args:
            **kwargs: Forwarded unchanged to :meth:`init_group`.

        Returns:
            This instance, so that calls can be chained.

        """
        if self.is_initialized():
            return self
        self.init_group(**kwargs)
        return self

    def create_group(self, **kwargs: Any) -> Self:
        """Create a group and take ownership of it.

        The user is expected to have called
        :meth:`~lightning_fabric.plugins.collectives.Collective.init_group` beforehand.

        Args:
            **kwargs: Forwarded unchanged to :meth:`new_group`.

        Returns:
            This instance, so that calls can be chained.

        Raises:
            RuntimeError: If this collective already owns a group.

        """
        if self._group is None:
            self._group = self.new_group(**kwargs)
            return self
        raise RuntimeError(f"`{type(self).__name__}` already owns a group.")

    def teardown(self) -> Self:
        """Destroy the owned group through :meth:`destroy_group` and drop the reference to it.

        Returns:
            This instance, so that calls can be chained.

        Raises:
            RuntimeError: If this collective does not own a group.

        """
        owned = self._group
        if owned is None:
            raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
        self.destroy_group(owned)
        # Only cleared once ``destroy_group`` has returned without raising.
        self._group = None
        return self
