# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List, Optional, Type

from pytest import mark, param, raises, skip

from hydra.core.default_element import InputDefault
from hydra.core.object_type import ObjectType
from hydra.plugins.config_source import ConfigLoadError, ConfigSource


class ConfigSourceTestSuite:
    """
    Shared pytest test cases for ``ConfigSource`` implementations.

    Every test receives ``type_`` (the ConfigSource class under test) and
    ``path`` (the location handed to its constructor). Neither is defined in
    this class; presumably the test class that inherits from this suite
    parametrizes them -- confirm against the plugin tests that use it.

    Methods whose name does not start with ``test_`` are helpers and are not
    collected by pytest under its default naming rules.
    """

    def skip_overlap_config_path_name(self) -> bool:
        """
        Some config source plugins do not support config name and path overlap.
        For example the following may not be allowed:
        (dataset exists both as a config object and a config group)
        /dataset.yaml
        /dataset/cifar.yaml

        Overriding and returning True here will disable testing of this scenario
        by assuming the dataset config (dataset.yaml) is not present.
        """
        return False

    def _skip_if_overlap_unsupported(self, type_: Type[ConfigSource]) -> None:
        """
        Skip the running test when the plugin opted out of overlap testing.

        :param type_: ConfigSource class under test, used only for the skip message.
        """
        if self.skip_overlap_config_path_name():
            skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                "with overlapping names."
            )

    @staticmethod
    def _assert_sorted_and_contains(actual: List[str], expected: List[str]) -> None:
        """
        Assert that ``actual`` is sorted and contains every item of ``expected``.

        This is a containment check, not an equality check: entries in
        ``actual`` that are not listed in ``expected`` are tolerated
        (presumably because the test config tree holds more entries than each
        case enumerates -- confirm before tightening this).

        :param actual: listing returned by the config source.
        :param expected: names that must be present in ``actual``.
        """
        # Collect all misses first so a failure reports every absent name at once.
        missing = [name for name in expected if name not in actual]
        assert not missing, f"{missing} not found in {actual}"
        assert actual == sorted(actual), f"listing is not sorted: {actual}"

    def test_not_available(self, type_: Type[ConfigSource], path: str) -> None:
        """A source pointing at a non-existent location must report unavailable."""
        scheme = type_(provider="foo", path=path).scheme()
        # Test is meaningless for StructuredConfigSource
        if scheme == "structured":
            return
        src = type_(provider="foo", path=f"{scheme}://___NOT_FOUND___")
        assert not src.available()

    @mark.parametrize(
        "config_path, expected",
        [
            param("", True, id="empty"),
            param("dataset", True, id="dataset"),
            param("optimizer", True, id="optimizer"),
            param(
                "configs_with_defaults_list",
                True,
                id="configs_with_defaults_list",
            ),
            param("dataset/imagenet", False, id="dataset/imagenet"),
            param("level1", True, id="level1"),
            param("level1/level2", True, id="level1/level2"),
            param("level1/level2/nested1", False, id="level1/level2/nested1"),
            param("not_found", False, id="not_found"),
        ],
    )
    def test_is_group(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """``is_group`` must return ``expected`` for ``config_path``."""
        src = type_(provider="foo", path=path)
        ret = src.is_group(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path, expected",
        [
            ("", False),
            ("optimizer", False),
            ("dataset/imagenet", True),
            ("dataset/imagenet.yaml", True),
            ("dataset/imagenet.foobar", False),
            ("configs_with_defaults_list/global_package", True),
            ("configs_with_defaults_list/group_package", True),
            ("level1", False),
            ("level1/level2", False),
            ("level1/level2/nested1", True),
            ("not_found", False),
        ],
    )
    def test_is_config(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """``is_config`` must return ``expected`` for ``config_path``."""
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path, expected",
        [
            ("dataset", True),
        ],
    )
    def test_is_config_with_overlap_name(
        self, type_: Type[ConfigSource], path: str, config_path: str, expected: bool
    ) -> None:
        """
        ``is_config`` for a name that is both a config and a config group.

        Skipped when ``skip_overlap_config_path_name()`` returns True.
        """
        self._skip_if_overlap_unsupported(type_)
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(
        "config_path,results_filter,expected",
        [
            # groups
            ("", ObjectType.GROUP, ["dataset", "level1", "optimizer"]),
            ("dataset", ObjectType.GROUP, []),
            ("optimizer", ObjectType.GROUP, []),
            ("level1", ObjectType.GROUP, ["level2"]),
            ("level1/level2", ObjectType.GROUP, []),
            # Configs
            ("", ObjectType.CONFIG, ["config_without_group"]),
            ("dataset", ObjectType.CONFIG, ["cifar10", "imagenet"]),
            ("optimizer", ObjectType.CONFIG, ["adam", "nesterov"]),
            ("level1", ObjectType.CONFIG, []),
            ("level1/level2", ObjectType.CONFIG, ["nested1", "nested2"]),
            # both
            ("", None, ["config_without_group", "dataset", "level1", "optimizer"]),
            ("dataset", None, ["cifar10", "imagenet"]),
            ("optimizer", None, ["adam", "nesterov"]),
            ("level1", None, ["level2"]),
            ("level1/level2", None, ["nested1", "nested2"]),
        ],
    )
    def test_list(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """
        ``list`` must return a sorted listing containing every ``expected`` name.

        ``results_filter`` is forwarded to ``list`` unchanged; the cases pass
        ``ObjectType.GROUP``, ``ObjectType.CONFIG`` or ``None``.
        """
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        self._assert_sorted_and_contains(ret, expected)

    @mark.parametrize(
        "config_path,results_filter,expected",
        [
            # Configs
            ("", ObjectType.CONFIG, ["dataset"]),
        ],
    )
    def test_list_with_overlap_name(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """
        ``list`` for a name that is both a config and a config group.

        Skipped when ``skip_overlap_config_path_name()`` returns True.
        """
        self._skip_if_overlap_unsupported(type_)
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        self._assert_sorted_and_contains(ret, expected)

    @mark.parametrize(
        "config_path,expected_config,expected_defaults_list,expected_package",
        [
            param(
                "config_without_group",
                {"group": False},
                None,
                None,
                id="config_without_group",
            ),
            param(
                "config_with_unicode",
                {"group": "数据库"},
                None,
                None,
                id="config_with_unicode",
            ),
            param(
                "dataset/imagenet",
                {"name": "imagenet", "path": "/datasets/imagenet"},
                None,
                None,
                id="dataset/imagenet",
            ),
            param(
                "dataset/cifar10",
                {"name": "cifar10", "path": "/datasets/cifar10"},
                None,
                None,
                id="dataset/cifar10",
            ),
            param(
                "dataset/not_found",
                raises(ConfigLoadError),
                None,
                None,
                id="dataset/not_found",
            ),
            param(
                "level1/level2/nested1",
                {"l1_l2_n1": True},
                None,
                None,
                id="level1/level2/nested1",
            ),
            param(
                "level1/level2/nested2",
                {"l1_l2_n2": True},
                None,
                None,
                id="level1/level2/nested2",
            ),
            param(
                "config_with_defaults_list",
                {
                    "defaults": [{"dataset": "imagenet"}],
                    "key": "value",
                },
                None,
                None,
                id="config_with_defaults_list",
            ),
            param(
                "configs_with_defaults_list/global_package",
                {
                    "defaults": [{"foo": "bar"}],
                    "x": 10,
                },
                None,
                "_global_",
                id="configs_with_defaults_list/global_package",
            ),
            param(
                "configs_with_defaults_list/group_package",
                {
                    "defaults": [{"foo": "bar"}],
                    "x": 10,
                },
                None,
                "_group_",
                id="configs_with_defaults_list/group_package",
            ),
        ],
    )
    def test_source_load_config(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_defaults_list: Optional[List[InputDefault]],
        expected_package: Any,
        expected_config: Any,
        recwarn: Any,
    ) -> None:
        """
        ``load_config`` must return the expected config, package and defaults list.

        ``expected_config`` is either a dict to compare the loaded config
        against, or a ``raises(...)`` context manager when loading is expected
        to fail. ``recwarn`` is requested as a fixture but not inspected here.
        """
        assert issubclass(type_, ConfigSource)
        src = type_(provider="foo", path=path)
        if isinstance(expected_config, dict):
            ret = src.load_config(config_path=config_path)
            assert ret.config == expected_config
            assert ret.header["package"] == expected_package
            assert ret.defaults_list == expected_defaults_list
        else:
            # Not a dict: the parameter is the ``raises`` context for the failure case.
            with expected_config:
                src.load_config(config_path=config_path)

    @mark.parametrize(
        "config_path, expected_result, expected_package",
        [
            param("package_test/none", {"foo": "bar"}, None, id="none"),
            param("package_test/explicit", {"foo": "bar"}, "a.b", id="explicit"),
            param("package_test/global", {"foo": "bar"}, "_global_", id="global"),
            param("package_test/group", {"foo": "bar"}, "_group_", id="group"),
            param(
                "package_test/group_name",
                {"foo": "bar"},
                "foo._group_._name_",
                id="group_name",
            ),
            param("package_test/name", {"foo": "bar"}, "_name_", id="name"),
        ],
    )
    def test_package_behavior(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_result: Any,
        expected_package: Optional[str],
    ) -> None:
        """The loaded header's ``package`` and the config body must match expectations."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path=config_path)
        assert cfg.header["package"] == expected_package
        assert cfg.config == expected_result

    def test_default_package_for_primary_config(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """Loading ``primary_config`` must yield a header whose ``package`` is None."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path="primary_config")
        assert cfg.header["package"] is None

    def test_primary_config_with_non_global_package(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """Loading ``primary_config_with_non_global_package`` must yield package ``foo``."""
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path="primary_config_with_non_global_package")
        assert cfg.header["package"] == "foo"
