import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from aphrodite.common.sequence import SampleLogprobs

# Per-prompt model output: (generated token ids, decoded text).
TokensText = Tuple[List[int], str]


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """
    Compare the two sequences generated by different models,
    which should be equal.

    Arguments:

    * outputs_0_lst: First list of per-prompt (token_ids, text) outputs
    * outputs_1_lst: Second list of per-prompt (token_ids, text) outputs
    * name_0: name of output list #0, used in failure messages
    * name_1: name of output list #1, used in failure messages

    Raises AssertionError if the two lists differ in length, or if for any
    prompt the decoded text or the token ids are not exactly equal.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of {name_0} outputs ({len(outputs_0_lst)}) does not match "
        f"number of {name_1} outputs ({len(outputs_1_lst)})")

    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # The text and token outputs should exactly match
        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")

        assert output_str_0 == output_str_1, fail_msg
        # By this point the texts are equal, so a message showing only the
        # texts would print two identical strings; include the ids so the
        # actual difference is visible.
        assert output_ids_0 == output_ids_1, (
            f"{fail_msg}"
            f"\n{name_0} token ids:\t{output_ids_0}"
            f"\n{name_1} token ids:\t{output_ids_1}")


# Per-prompt model output compared by check_logprobs_close:
# (generated token ids, decoded text, per-position logprobs or None).
# Each per-position entry maps token id -> logprob, given either as a plain
# list of dicts or as aphrodite's SampleLogprobs.
TokensTextLogprobs = Tuple[
    List[int],
    str,
    Optional[Union[List[Dict[int, float]], SampleLogprobs]],
]


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[TokensTextLogprobs],
    outputs_1_lst: Sequence[TokensTextLogprobs],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt, generated tokens are compared position by position.
    At the first position where the two tokens differ, each model's token
    must appear among the logprobs the other model reported for that
    position; comparison of that prompt then stops, since the sequences
    diverge from there on.

    Arguments:

    * outputs_0_lst: First sequence to compare
    * outputs_1_lst: Second sequence to compare
    * name_0: sequence #0 name
    * name_1: sequence #1 name
    * num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                 sequence #0 tokens & logprobs to discard
                                 before comparison, i.e. all
                                 of sequence #1 will be compared to
                                 sequence #0 beginning at index
                                 num_outputs_0_skip_tokens
    * warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                        mismatch between the two sequences

    Raises:

    * ValueError: if num_outputs_0_skip_tokens is negative
    * AssertionError: if the two lists differ in length, or if at the first
                      diverging position either side has no logprobs or
                      does not contain the other side's token
    """
    # Validate arguments once, up front, so that a bad value is rejected
    # even when there are no prompts to compare.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of {name_0} outputs ({len(outputs_0_lst)}) does not match "
        f"number of {name_1} outputs ({len(outputs_1_lst)})")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Missing logprobs become per-position None placeholders, so a
        # token mismatch on that side fails the "is not None" check below.
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
                logprobs_elem_0 = logprobs_0[idx]
                logprobs_elem_1 = logprobs_1[idx]

                # Each predicted token must be in top N logprobs of the other
                fail_msg = (
                    f"Test{prompt_idx}:"
                    f"\nMatched tokens:\t{output_ids_0[:idx]}"
                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

                assert logprobs_elem_0 is not None, fail_msg
                assert logprobs_elem_1 is not None, fail_msg
                assert output_id_0 in logprobs_elem_1, fail_msg
                assert output_id_1 in logprobs_elem_0, fail_msg

                if warn_on_mismatch:
                    with warnings.catch_warnings():
                        # This ensures that repeated warnings are shown
                        # in the output, not just the first occurrence
                        warnings.simplefilter("always")

                        warnings.warn(fail_msg, stacklevel=2)

                # Break out since sequences will now diverge.
                break
        else:
            # Reached only when no token mismatch triggered the break above.
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")

                with warnings.catch_warnings():
                    # This ensures that repeated warnings are shown
                    # in the output, not just the first occurrence
                    warnings.simplefilter("always")

                    warnings.warn(fail_msg, stacklevel=2)