import os
from typing import Optional, Literal
from types import ModuleType
import enum
from packaging import version

import torch

# collect system information
# Feature-detect PyTorch's scaled-dot-product attention instead of comparing
# version strings: under PEP 440 a pre-release build such as "2.0.0a0+<hash>"
# sorts below "2.0.0", so a version check would report SDP as unavailable even
# when `torch.nn.functional.scaled_dot_product_attention` is present.
SDP_IS_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")

# xformers is optional: probe for it without letting a missing or broken
# installation stop this module from importing.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Deliberately broader than ImportError, since a present-but-broken install
    # may fail with other error types. Unlike the previous bare `except:`, this
    # no longer swallows KeyboardInterrupt / SystemExit.
    XFORMERS_IS_AVAILBLE = False

# Correctly spelled alias; the misspelled name above is kept because other
# modules may already import it.
XFORMERS_IS_AVAILABLE = XFORMERS_IS_AVAILBLE


class AttnMode(enum.Enum):
    """Attention backends this module can select between.

    Member values are fixed integers (0, 1, 2) in declaration order.
    """

    SDP = 0  # PyTorch scaled-dot-product attention
    XFORMERS = 1  # attention backed by the xformers package
    VANILLA = 2  # fallback implementation, used when neither of the above is chosen


class Config:
    """Module-wide attention settings, held as class attributes (never instantiated here)."""

    # The imported xformers module, or None when it has not been assigned.
    xformers: Optional[ModuleType] = None
    # Selected attention backend; starts as VANILLA and may be overridden at import time.
    attn_mode: AttnMode = AttnMode.VANILLA


# initialize attention mode
# Preference order: xformers, then PyTorch SDP, then the vanilla fallback.
if not (XFORMERS_IS_AVAILBLE or SDP_IS_AVAILABLE):
    # Config.attn_mode keeps its class-level default in this case.
    print("both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
elif XFORMERS_IS_AVAILBLE:
    Config.attn_mode = AttnMode.XFORMERS
    print("use xformers attention as default")
    # Expose the imported xformers module through Config.
    Config.xformers = xformers
else:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")


# user-specified attention mode
# The ATTN_MODE environment variable ("vanilla", "sdp" or "xformers") overrides
# the default chosen above. Validation uses explicit exceptions rather than
# `assert`, because asserts are stripped under `python -O`. Without them, an
# invalid value would silently select vanilla, and an unavailable backend
# would be selected anyway.
ATTN_MODE = os.environ.get("ATTN_MODE", None)
if ATTN_MODE is not None:
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError("ATTN_MODE='sdp' was requested but sdp attention is not available")
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError("ATTN_MODE='xformers' was requested but xformers is not available")
        Config.attn_mode = AttnMode.XFORMERS
    elif ATTN_MODE == "vanilla":
        Config.attn_mode = AttnMode.VANILLA
    else:
        raise ValueError(
            f"invalid ATTN_MODE {ATTN_MODE!r}; expected one of 'vanilla', 'sdp', 'xformers'"
        )
    print(f"set attention mode to {ATTN_MODE}")
else:
    print("keep default attention mode")