import gc
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch

import openai
import pytest
import torch
from tensorizer import EncryptionParams

from aphrodite import SamplingParams
from aphrodite.engine.args_tools import EngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, TensorSerializer, is_aphrodite_tensorized,
    load_with_tensorizer, open_stream, serialize_aphrodite_model,
    tensorize_aphrodite_model)
# yapf: enable

from ..conftest import AphroditeRunner
from ..utils import RemoteOpenAIServer
from .conftest import retry_until_skip


prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_aphrodite_model_for_testing.py")


def is_curl_installed():
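    """Return True if the curl CLI is available on PATH."""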
    try:
        # Suppress curl's version banner so it does not pollute test output.
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def get_torch_model(aphrodite_runner: AphroditeRunner):
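    """Unwrap the underlying torch.nn.Module from an AphroditeRunner."""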
    return (aphrodite_runner
            .model
            .llm_engine
            .model_executor
            .driver_worker
            .model_runner
            .model)


def write_keyfile(keyfile_path: str):
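    """Generate a random encryption key and write it to keyfile_path."""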
    encryption_params = EncryptionParams.random()
    pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
    with open(keyfile_path, 'wb') as f:
        f.write(encryption_params.key)


@patch('aphrodite.modeling.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
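    """load_with_tensorizer should construct a TensorizerAgent and return
    the deserialized model."""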
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(aphrodite_runner):
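    """Deserialize a pre-tensorized checkpoint directly from S3."""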
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with aphrodite_runner(model_ref,
                          load_format="tensorizer",
                          model_loader_extra_config=TensorizerConfig(
                              tensorizer_uri=tensorized_path,
                              num_readers=1,
                              s3_endpoint="object.ord1.coreweave.com",
                          )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(prompts,
                                                        sampling_params)

        assert deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_aphrodite_model_has_same_outputs(
        aphrodite_runner, tmp_path):
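    """Serialize a model with encryption, then deserialize it and check
    that the outputs are unchanged."""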
    with aphrodite_runner(model_ref) as aphrodite_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        write_keyfile(key_path)

        outputs = aphrodite_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=key_path
        )
        serialize_aphrodite_model(get_torch_model(aphrodite_model),
                                  config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with aphrodite_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config_for_deserializing) as loaded_aphrodite_model:  # noqa: E501

        deserialized_outputs = loaded_aphrodite_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, aphrodite_runner,
                                                tmp_path):
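    """Serialize a Hugging Face model with TensorSerializer, then load it
    through Aphrodite and check that greedy outputs match."""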
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with aphrodite_runner(model_ref,
                          load_format="tensorizer",
                          model_loader_extra_config=TensorizerConfig(
                              tensorizer_uri=model_path,
                              num_readers=1,
                          )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs


def test_aphrodite_model_can_load_with_lora(aphrodite_runner, tmp_path):
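    """A tensorized model should still accept LoRA adapters at load time."""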
    from huggingface_hub import snapshot_download

    from examples.offline_inference.slora_inference import (
        create_test_prompts, process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with aphrodite_runner(model_ref) as aphrodite_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_aphrodite_model(get_torch_model(aphrodite_model),
                                  TensorizerConfig(tensorizer_uri=model_path))

    with aphrodite_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_aphrodite_model:
        process_requests(loaded_aphrodite_model.model.llm_engine, test_prompts)

        assert loaded_aphrodite_model


def test_load_without_tensorizer_load_format(aphrodite_runner):
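    """A TensorizerConfig without load_format='tensorizer' must raise."""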
    model = None
    with pytest.raises(ValueError):
        model = aphrodite_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(aphrodite_runner, tmp_path):
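    """Serve a tensorized model through the OpenAI-compatible API server
    and issue a completion request against it."""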
    ## Serialize model
    with aphrodite_runner(model_ref) as aphrodite_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_aphrodite_model(get_torch_model(aphrodite_model),
                                  TensorizerConfig(tensorizer_uri=model_path))

        model_loader_extra_config = {
            "tensorizer_uri": str(model_path),
        }

    ## Start OpenAI API server
    openai_args = [
        "--dtype", "float16",
        "--load-format", "tensorizer",
        "--model-loader-extra-config", json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)


def test_raise_value_error_on_invalid_load_format(aphrodite_runner):
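    """Combining a TensorizerConfig with a different load_format must
    raise a ValueError."""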
    model = None
    with pytest.raises(ValueError):
        model = aphrodite_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(aphrodite_runner):
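    """With tensor_parallel_size > 1, a tensorizer URI without a shard
    template (e.g. '%02d') must raise a ValueError."""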
    with pytest.raises(ValueError):
        model_ref = "EleutherAI/pythia-1.4b"
        tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

        aphrodite_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_aphrodite_model_with_tp_has_same_outputs(
    aphrodite_runner, tmp_path):
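    """Serialize with TP=2 and encryption, then deserialize across two
    GPUs and compare against the unsharded base model's outputs."""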
    model_ref = "EleutherAI/pythia-1.4b"
    # record outputs from un-sharded un-tensorized model
    with aphrodite_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # load model with two shards and serialize with encryption
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")
    write_keyfile(key_path)

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_aphrodite_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with aphrodite_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config
            ) as loaded_aphrodite_model:
        deserialized_outputs = loaded_aphrodite_model.generate(
            prompts, sampling_params)

    assert outputs == deserialized_outputs


@retry_until_skip(3)
def test_aphrodite_tensorized_model_has_same_outputs(
    aphrodite_runner, tmp_path):
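    """Round-trip a serialized model, verifying is_aphrodite_tensorized
    detects it and that generation outputs are identical."""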
    gc.collect()
    torch.cuda.empty_cache()
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with aphrodite_runner(model_ref) as aphrodite_model:
        outputs = aphrodite_model.generate(prompts, sampling_params)
        serialize_aphrodite_model(get_torch_model(aphrodite_model), config)

        assert is_aphrodite_tensorized(config)

    with aphrodite_runner(model_ref,
                          load_format="tensorizer",
                          model_loader_extra_config=config
                          ) as loaded_aphrodite_model:
        deserialized_outputs = loaded_aphrodite_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs