# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------

import coloredlogs
from cuda import cudart
from demo_utils import init_pipeline, parse_arguments, repeat_prompt
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_type
from pipeline_img2img_xl import Img2ImgXLPipeline
from pipeline_txt2img_xl import Txt2ImgXLPipeline

def _share_trt_device_memory(pipelines):
    """Back the TensorRT engines of all given pipelines with one device buffer.

    A single buffer is allocated, sized to the largest value reported by
    ``backend.max_device_memory()`` across the pipelines, and every pipeline's
    engines are then activated on that same buffer.

    Args:
        pipelines: Pipelines whose ``backend`` provides ``max_device_memory()``
            and ``activate_engines()``.

    Raises:
        RuntimeError: If ``cudart.cudaMalloc`` does not report ``cudaSuccess``.
    """
    max_device_memory = max(pipeline.backend.max_device_memory() for pipeline in pipelines)

    # cuda-python returns (status, pointer); a failed allocation must not be
    # handed to the engines as if it were a valid device pointer.
    err, shared_device_memory = cudart.cudaMalloc(max_device_memory)
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMalloc of {max_device_memory} bytes for TensorRT engines failed: {err}")

    # NOTE(review): this buffer is never freed in this script; whether
    # pipeline.teardown() releases it is not visible from here -- confirm.
    for pipeline in pipelines:
        pipeline.backend.activate_engines(shared_device_memory)


def main():
    """Run the Stable Diffusion XL demo: build pipelines, warm up, run once, tear down.

    Command line options are read through ``parse_arguments``. The base
    text-to-image pipeline is always created; the refiner image-to-image
    pipeline is created only when ``--enable_refiner`` style option
    (``args.enable_refiner``) is set.

    Raises:
        ValueError: If the number of prompts exceeds the maximum batch size.
        RuntimeError: If the shared TensorRT device buffer cannot be allocated.
    """
    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
    prompt, negative_prompt = repeat_prompt(args)

    image_height = args.height
    image_width = args.width

    # Register TensorRT plugins
    engine_type = get_engine_type(args.engine)
    if engine_type == EngineType.TRT:
        from trt_utilities import init_trt_plugins

        init_trt_plugins()

    # Dynamic shape or any image side above 512 lowers the batch limit from 16 to 4.
    limited_batch = args.build_dynamic_shape or image_height > 512 or image_width > 512
    max_batch_size = 4 if limited_batch else 16

    batch_size = len(prompt)
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} is larger than allowed {max_batch_size}. "
            "If dynamic shape is used or image height/width is larger than 512, then maximum batch size is 4"
        )

    # The base pipeline only needs its own VAE when no refiner follows it.
    base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner)
    base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)
    pipelines = [base]

    refiner = None
    if args.enable_refiner:
        refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True)
        refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
        pipelines.append(refiner)

    if engine_type == EngineType.TRT:
        _share_trt_device_memory(pipelines)

    for pipeline in pipelines:
        pipeline.load_resources(image_height, image_width, batch_size)

    def run_sd_xl_inference(enable_refiner: bool, warmup=False):
        """Run the base pipeline and, optionally, the refiner on its output.

        Args:
            enable_refiner: When true, the base pipeline is asked for
                ``"latents"`` and these are passed to the refiner; otherwise the
                base pipeline is asked for ``"images"`` directly.
            warmup: Forwarded unchanged to each pipeline's ``run``.

        Returns:
            A ``(images, elapsed)`` tuple. ``elapsed`` is the time reported by
            the base run, plus the time reported by the refiner run when it is
            enabled.
        """
        images, time_base = base.run(
            prompt,
            negative_prompt,
            image_height,
            image_width,
            warmup=warmup,
            denoising_steps=args.denoising_steps,
            guidance=args.guidance,
            seed=args.seed,
            return_type="latents" if enable_refiner else "images",
        )

        if not enable_refiner:
            return images, time_base

        images, time_refiner = refiner.run(
            prompt,
            negative_prompt,
            images,
            image_height,
            image_width,
            warmup=warmup,
            denoising_steps=args.denoising_steps,
            guidance=args.guidance,
            seed=args.seed,
        )
        return images, time_base + time_refiner

    # The nested try/finally keeps the original success-path order
    # (base teardown, e2e summary, refiner teardown) while also tearing the
    # pipelines down when warm-up or inference raises.
    try:
        try:
            if not args.disable_cuda_graph:
                # inference once to get cuda graph
                run_sd_xl_inference(args.enable_refiner, warmup=True)

            print("[I] Warming up ..")
            for _ in range(args.num_warmup_runs):
                run_sd_xl_inference(args.enable_refiner, warmup=True)

            print("[I] Running StableDiffusion XL pipeline")
            if args.nvtx_profile:
                cudart.cudaProfilerStart()
            _, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False)
            if args.nvtx_profile:
                cudart.cudaProfilerStop()
        finally:
            base.teardown()

        if args.enable_refiner:
            # Combined base + refiner time for the measured run.
            print("|------------|--------------|")
            print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time))
            print("|------------|--------------|")
    finally:
        if refiner is not None:
            refiner.teardown()


if __name__ == "__main__":
    main()
