import sys
import os
import re

# Announce the start of the patch run on stdout.
sys.stdout.write("INFO: Applying patches via patch_pytorch.py...\n")

def patch_file(path, search_str, replace_str, is_regex=False):
    """Replace ``search_str`` with ``replace_str`` in the file at ``path``.

    Args:
        path: Path of the file to patch.
        search_str: Literal text to find, or a regular expression (applied with
            ``re.DOTALL``) when ``is_regex`` is true.
        replace_str: Replacement text. In regex mode it is an ``re.sub``
            template, so group backreferences are expanded.
        is_regex: Treat ``search_str`` as a regular expression.

    Returns:
        True if the file was rewritten, False otherwise (file missing, empty
        pattern, pattern not found, or literal patch already applied).

    The file is read and written as UTF-8 with ``errors="surrogateescape"``
    and newline translation disabled, so everything outside the replaced text
    (CRLF line endings, non-UTF-8 bytes) is written back unchanged.
    """
    if not os.path.exists(path):
        print(f"SKIPPING: {path} not found.")
        return False
    if not search_str:
        # An empty pattern matches between every character; refuse rather
        # than mangle the file.
        print(f"WARNING: Empty search pattern for {path}; nothing done.")
        return False
    print(f"Patching {path}...")
    with open(path, "r", encoding="utf-8", errors="surrogateescape", newline="") as f:
        content = f.read()

    if is_regex:
        new_content = re.sub(search_str, replace_str, content, flags=re.DOTALL)
    else:
        # "Append" style patches keep the anchor text and add lines after it,
        # so the anchor is still present once patched. Without this guard a
        # second run of the script would append the same lines again.
        if search_str in replace_str and replace_str in content:
            print(f"SKIPPING: {path} already patched.")
            return False
        new_content = content.replace(search_str, replace_str)

    if new_content == content:
        print(f"WARNING: No changes made to {path}. Pattern not found.")
        return False
    with open(path, "w", encoding="utf-8", errors="surrogateescape", newline="") as f:
        f.write(new_content)
    print(f"SUCCESS: Patched {path}")
    return True

# 1. Patch Dependencies.cmake for CUB/Thrust: replace the find_package(CUB)
#    call with hard-coded CUB/Thrust include paths under third_party/, and
#    downgrade the "Cannot find CUB." fatal error to a status message.
# NOTE(review): assumes third_party/cub and third_party/thrust exist in the
# checkout — this script does not check for them.
patch_file(
    "cmake/Dependencies.cmake",
    "find_package(CUB)",
    "set(CUB_FOUND TRUE)\n"
    '  set(CUB_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cub")\n'
    '  set(Thrust_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/thrust")\n'
    "  include_directories(SYSTEM ${Thrust_INCLUDE_DIRS})",
)
patch_file(
    "cmake/Dependencies.cmake",
    'message(FATAL_ERROR "Cannot find CUB.")',
    'message(STATUS "CUB check bypassed")',
)

# 2. Patch public/cuda.cmake for nvToolsExt: wherever the nvToolsExt lookup
#    failure is reported (as a fatal error or as a "Skipping" status), emit a
#    status message and define an empty INTERFACE IMPORTED CUDA::nvToolsExt
#    target instead — presumably so targets linking it still configure.
# NOTE(review): if both messages can be reached in one configure run, the
# target would be defined twice; confirm they sit in exclusive branches.
target_def = (
    'message(STATUS "Defining dummy CUDA::nvToolsExt")\n'
    "  add_library(CUDA::nvToolsExt INTERFACE IMPORTED)"
)
patch_file(
    "cmake/public/cuda.cmake",
    'message(FATAL_ERROR "Failed to find nvToolsExt")',
    target_def,
)
patch_file(
    "cmake/public/cuda.cmake",
    'message(STATUS "Failed to find nvToolsExt (Skipping)")',
    target_def,
)

# 3. Patch select_compute_arch.cmake to support sm_110 (11.0) and sm_120 (12.0):
#    register "Blackwell" right after "Hopper" in the known architecture
#    names, and add "11.0" and "12.0" right after "9.0" in the common list.
# NOTE(review): only the name list gets "Blackwell"; whether the file maps
# that name to concrete arch numbers is not checked here.
# NOTE: the anchor line is kept in each replacement, so re-running the script
# may append these entries again.
patch_file(
    "cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake",
    'list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper")',
    'list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper")\n'
    '  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")',
)
patch_file(
    "cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake",
    'list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0")',
    'list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0")\n'
    '  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "11.0")\n'
    '  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0")',
)

# 4. Fix architecture regex to support multi-digit major versions (11.0+).
#    Both major-version digits in the CMake regex ("[0-9]" before "\\.")
#    become "[0-9]+"; the minor version stays a single digit. Raw strings
#    keep the doubled backslashes exactly as they appear in the .cmake text.
patch_file(
    "cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake",
    r'if(arch_name MATCHES "^([0-9]\\.[0-9]a?(\\([0-9]\\.[0-9]\\))?)$")',
    r'if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$")',
)

# 5. Patch torch/utils/cpp_extension.py
def patch_cpp_extension_source(content):
    """Return ``content`` with the CUDA-arch fixes for cpp_extension.py applied.

    * Prepends '11.0' and '12.0' to each ``supported_arches = [...]`` list,
      skipping any arch that list already contains. The check is per arch
      and only inside the list literal, so an unrelated '11.0' elsewhere in
      the file does not suppress the patch and no arch is ever duplicated.
    * Rewrites the arch-string -> number conversion so a multi-digit major
      version works ('11.0' -> '110' instead of '1.0'); anything after a
      '+' (e.g. '+PTX') is still stripped.

    Progress is reported with ``print``; no file I/O is performed.
    """
    new_arches = ("11.0", "12.0")

    def _prepend_missing(match):
        head, body, tail = match.groups()
        missing = [a for a in new_arches if f"'{a}'" not in body and f'"{a}"' not in body]
        return head + "".join(f"'{a}', " for a in missing) + body + tail

    patched = re.sub(
        r"(supported_arches\s*=\s*\[)(.*?)(\])",
        _prepend_missing,
        content,
        flags=re.DOTALL,
    )
    if patched != content:
        content = patched
        print("Fixed supported_arches")

    # Literal replacement for the known problematic line: arch[0] + arch[2:]
    # assumes a single-digit major version.
    old_num = 'num = arch[0] + arch[2:].split("+")[0]'
    new_num = "num = arch.split('+')[0].replace('.', '')"
    if old_num in content:
        content = content.replace(old_num, new_num)
        print("Fixed num calculation logic")
    return content


path_cpp = "torch/utils/cpp_extension.py"
if os.path.exists(path_cpp):
    print(f"Applying robust patches to {path_cpp}...")
    with open(path_cpp, "r", encoding="utf-8", newline="") as f:
        original_cpp = f.read()
    content = patch_cpp_extension_source(original_cpp)
    # Only rewrite when something changed, so a re-run leaves the file alone.
    if content != original_cpp:
        with open(path_cpp, "w", encoding="utf-8", newline="") as f:
            f.write(content)
        print(f"SUCCESS: Patched {path_cpp}")
    else:
        print(f"WARNING: No changes made to {path_cpp} (already patched or patterns not found).")
else:
    print(f"SKIPPING: {path_cpp} not found.")

# 6. Downgrade nonnull warnings specifically in CMakeLists.txt.
# The -Wno-error=nonnull flag is inserted right after the return-type flag
# line. That anchor may be spelled "-Werror=return-type" (step 7 below, which
# rewrites -Werror=... to -Wno-error=..., has not run yet) or
# "-Wno-error=return-type" (already rewritten, e.g. by an earlier run), so
# both spellings are accepted. The step is skipped when the nonnull flag is
# already present, so re-running this script does not add it twice.
_NONNULL_FLAG_LINE = 'append_cxx_flag_if_supported("-Wno-error=nonnull" CMAKE_CXX_FLAGS)'
if not os.path.exists("CMakeLists.txt"):
    print("SKIPPING: CMakeLists.txt not found.")
else:
    with open(
        "CMakeLists.txt", "r", encoding="utf-8", errors="surrogateescape", newline=""
    ) as f:
        _root_cmake = f.read()
    if _NONNULL_FLAG_LINE in _root_cmake:
        print("SKIPPING: CMakeLists.txt already contains -Wno-error=nonnull.")
    else:
        for _rt_flag in ("-Wno-error=return-type", "-Werror=return-type"):
            _anchor = f'append_cxx_flag_if_supported("{_rt_flag}" CMAKE_CXX_FLAGS)'
            if _anchor in _root_cmake:
                patch_file("CMakeLists.txt", _anchor, f"{_anchor}\n  {_NONNULL_FLAG_LINE}")
                break
        else:
            print("WARNING: No changes made to CMakeLists.txt. Pattern not found.")

# 7. Downgrade all warnings-as-errors to allow compilation on modern GCC versions.
# This targets all flag definition logic in the PyTorch source and submodules:
# every *.cmake and CMakeLists.txt below the current directory.
# One token-level regex turns "-Werror" into "-Wno-error" and
# "-Werror=<warning>" into "-Wno-error=<warning>". Unlike plain substring
# replacement it also catches "${CMAKE_CXX_FLAGS} -Werror", "PRIVATE -Werror)"
# and back-to-back " -Werror -Werror ", while leaving longer flags such as
# -Werror-implicit-function-declaration untouched.
# NOTE(review): this also rewrites -Werror used in compiler-flag probes
# (e.g. CMAKE_REQUIRED_FLAGS), which can make those probes more permissive.
WERROR_FLAG_RE = re.compile(r"(?<![\w-])-Werror(?![\w-])")


def downgrade_werror(text):
    """Return ``text`` with every standalone -Werror / -Werror=<w> downgraded to -Wno-error."""
    return WERROR_FLAG_RE.sub("-Wno-error", text)


werror_files_patched = 0
for root, dirs, files in os.walk("."):
    # Prune git metadata in place so os.walk never descends into it.
    dirs[:] = [d for d in dirs if d != ".git"]
    for file in files:
        if not (file.endswith(".cmake") or file == "CMakeLists.txt"):
            continue
        path = os.path.join(root, file)
        # utf-8 + surrogateescape + newline="" round-trips every byte
        # (non-UTF-8 content, CRLF endings) except the substituted flags,
        # and keeps one oddly encoded file from aborting the whole sweep.
        try:
            with open(path, "r", encoding="utf-8", errors="surrogateescape", newline="") as f:
                content = f.read()
        except OSError as err:
            print(f"SKIPPING: {path} ({err})")
            continue
        new_content = downgrade_werror(content)
        if new_content != content:
            with open(path, "w", encoding="utf-8", errors="surrogateescape", newline="") as f:
                f.write(new_content)
            print(f"SUCCESS: Patched {path}")
            werror_files_patched += 1
print(f"INFO: Downgraded -Werror in {werror_files_patched} CMake file(s).")

print("INFO: All patches processed.")
