#!/usr/bin/env python
# since cabal is rather picky about wildcards in 'extra-source-files',
# we have to fill this in ourselves; this script automates that
import os, re, subprocess

# one-element cache for the preferred locale encoding (filled lazily, and
# only ever on Python 2)
ensure_str_encoding = []
def ensure_str(string):
    '''Return the argument as a Unicode string.

      - on Python 3, a `str` is returned unchanged and anything else
        triggers a `TypeError`;
      - on Python 2, a `unicode` is returned unchanged and anything else is
        decoded using the preferred locale encoding.
    '''
    # only Python 2's 'str' (a byte string) has both 'decode' and 'encode'
    is_python2 = bool(getattr(str, "decode", None) and
                      getattr(str, "encode", None))
    if not is_python2:
        if not isinstance(string, str):
            raise TypeError("not an instance of 'str': " + repr(string))
        return string
    # Python 2 from here on
    if isinstance(string, unicode):
        return string
    if len(ensure_str_encoding) == 0:
        # look up the encoding once and remember it for later calls
        import locale
        ensure_str_encoding.append(locale.getpreferredencoding(False))
    encoding = ensure_str_encoding[0]
    return string.decode(encoding)

def rename(src_filename, dest_filename):
    '''Rename a file, overwriting the destination (also on Windows).

    Everywhere except Windows this is a plain `os.rename`.  On Windows the
    Win32 function `MoveFileExW` is called through `ctypes` instead, and a
    failure is reported by raising `ctypes.WinError()`.'''
    import os
    if os.name != "nt":
        os.rename(src_filename, dest_filename)
        return
    import ctypes
    # flag value 0x1 is MOVEFILE_REPLACE_EXISTING in the Win32 API
    replace_existing = 0x1
    move_file_ex = ctypes.windll.kernel32.MoveFileExW
    moved = move_file_ex(
        ensure_str(src_filename),
        ensure_str(dest_filename),
        ctypes.c_ulong(replace_existing),
    )
    if not moved:
        raise ctypes.WinError()

def read_file(filename, binary=False):
    '''Read and return the entire contents of a file.

    With `binary` set the file is opened in mode "rb" and the data is
    returned exactly as read.  Otherwise it is opened in mode "rt" and the
    data is passed through `ensure_str` before being returned.'''
    mode = "rb" if binary else "rt"
    with open(filename, mode) as stream:
        data = stream.read()
    if binary:
        return data
    return ensure_str(data)

def write_file(filename, contents, binary=False, safe=True):
    '''Write `contents` to the file `filename`.

    With `safe` false the file is simply opened (mode "wb" if `binary`,
    else "wt") and written in place; text contents are passed through
    `ensure_str` first.

    With `safe` true (the default) the contents are first written to a
    file inside a fresh temporary directory created next to `filename`,
    and that file is then moved over `filename` with `rename`, making the
    update as atomic as possible.  The temporary directory is removed
    afterwards on a best-effort basis.'''
    if safe:
        import os, shutil, tempfile
        try:
            scratch_dir = tempfile.mkdtemp(
                suffix=".tmp",
                prefix="." + os.path.basename(filename) + ".",
                dir=os.path.dirname(filename),
            )
            staged_filename = os.path.join(scratch_dir, "file.tmp")
            write_file(staged_filename, contents, binary, safe=False)
            rename(staged_filename, filename)
        finally:
            # best-effort cleanup: any failure here is ignored, including
            # the NameError raised when 'mkdtemp' itself failed and
            # 'scratch_dir' was therefore never bound
            try:
                shutil.rmtree(scratch_dir)
            except Exception:
                pass
        return
    # unsafe mode: write directly to the destination
    if binary:
        mode = "wb"
    else:
        mode = "wt"
        contents = ensure_str(contents)
    with open(filename, mode) as stream:
        stream.write(contents)

def find_cabal_fn():
    '''Obtain the filename of the `*.cabal` file.

    The current directory is searched for regular files whose name is made
    up of word characters and hyphens followed by the literal extension
    `.cabal`.

    Returns the name of the single matching file.  Raises `Exception` if
    there is no such file or if there is more than one.'''
    fns = [fn for fn in os.listdir(".")
           # the dot must be escaped, otherwise names that merely end in
           # "cabal" (e.g. "foo_cabal") would be accepted as well
           if re.match(r"[\w-]+\.cabal$", fn) and os.path.isfile(fn)]
    if not fns:
        raise Exception("can't find .cabal file in current directory")
    if len(fns) > 1:
        raise Exception("too many .cabal files in current directory")
    return fns[0]

def git_tracked_files():
    '''Obtain the list of files tracked by Git.

    Runs `git ls-tree -r -z --name-only HEAD` and returns the listed paths
    as a list of strings (decoded as UTF-8).  Raises
    `subprocess.CalledProcessError` if the command exits with a non-zero
    status.'''
    # '-z' makes Git terminate each path with NUL and print it verbatim;
    # without it, paths containing unusual characters are printed quoted
    # and escaped, so they would no longer look like real file names
    output = subprocess.check_output(
        ["git", "ls-tree", "-r", "-z", "--name-only", "HEAD"]
    ).decode("utf-8")
    # the output ends with a terminator, so drop the empty trailing piece
    return [fn for fn in output.split("\0") if fn]

# indentation used for every generated entry
indent = " " * 4

# directory whose tracked files are listed in the .cabal file
dir_name = "tests"

# extensions that are always included by default
src_extensions = [
    ".hs",
    ".stderr",
    ".stdout",
]

# additional source files not covered by 'src_extensions': every tracked
# file below 'dir_name' except '.gitignore' files and except those files
# directly inside 'dir_name' that the wildcard patterns already match
srcs = [indent + fn + "\n" for fn in git_tracked_files()
        if  os.path.basename(fn) != ".gitignore"
        and fn.startswith(dir_name + "/")
        and (os.path.dirname(fn) != dir_name or
             os.path.splitext(fn)[1] not in src_extensions)]

# convert the extensions into patterns
src_patterns = [indent + dir_name + "/*" + ext + "\n"
                for ext in src_extensions]

# update the .cabal file
cabal_fn = find_cabal_fn()
# read the raw bytes and decode them with the same encoding that is used
# for writing below; decoding with the locale's encoding instead would
# garble non-ASCII text whenever that encoding is not UTF-8
contents = read_file(cabal_fn, binary=True).decode("utf8")
# reading in binary mode skips newline translation, so do it by hand
contents = contents.replace("\r\n", "\n")
new_entries = "\n" + "".join(src_patterns + srcs)
# replace the first run of consecutive 'dir_name/...' lines; a function is
# used as the replacement so that backslashes in file names are inserted
# literally rather than being interpreted as escape sequences
contents = re.sub(r"\n(\s*" + re.escape(dir_name) + r"/.*\n)+",
                  lambda match: new_entries,
                  contents, count=1)
write_file(cabal_fn,
           contents.encode("utf8"),
           # don't use Windows line-endings
           binary=True)
