# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import dataclass, field
from enum import IntEnum

import torch

from ...utils import is_psutil_available, is_torch_xpu_available
from ...utils.logging import logging
from ...utils.metrics import traced


if is_psutil_available():
    import psutil

# This is a temporary token ID used to represent a token that is not yet generated
TMP_TOKEN_ID = -1


# We centralize the logger here to coordinate between logging and progress bar
logger = logging.getLogger("ContinuousBatchingLogger")


def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif is_torch_xpu_available():
        device = torch.device("xpu")
        torch.xpu.empty_cache()
        torch.xpu.synchronize()
        total_memory = torch.xpu.get_device_properties(device).total_memory
        reserved_memory = torch.xpu.memory_reserved(device)
        allocated_memory = torch.xpu.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # MPS memory reporting (PyTorch 2.0+)
        total_memory = torch.mps.driver_allocated_memory()
        allocated_memory = total_memory - torch.mps.recommended_max_memory()
        reserved_memory = 0  # MPS does not track reserved separately
    else:
        device = torch.device("cpu")
        if is_psutil_available():
            total_memory = psutil.virtual_memory().total
            allocated_memory = psutil.Process().memory_info().rss
            reserved_memory = allocated_memory
        else:
            logger.error(
                "Cannot get memory breakdown on CPU without psutil: returning 0 for all memory values. Please install "
                "psutil to get an actual memory breakdown."
            )
            total_memory = 0
            reserved_memory = 0
            allocated_memory = 0

    return device, total_memory, reserved_memory, allocated_memory


class RequestStatus(IntEnum):
    """Status of a generation request through its lifecycle."""

    PENDING = 0
    PREFILLING = 1
    DECODING = 2
    FINISHED = 3
    FAILED = 4


@dataclass
class GenerationOutput:
    """Tracks the output of a generation request.

    Attributes:
        request_id (str): The ID of the generation request.
        prompt_ids (list[int]): The IDs of the prompt tokens.
        generated_tokens (list[int]): The generated tokens.
        logprobs (list[float]): The log probabilities of the generated tokens.
        error (Optional[str]): Any error message associated with the request. When None, the request was successful.
        status (RequestStatus): The status of the request.
        created_time (float): The time the request was created.
        lifespan (tuple[float, float]): The time the request was no longer pending and the time the request finished.
    """

    request_id: str
    prompt_ids: list[int] = field(default_factory=list)
    generated_tokens: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    error: str | None = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.perf_counter)
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)
    timestamps: list[float] | None = None  # Timestamps of the generated tokens

    def is_finished(self) -> bool:
        return self.status == RequestStatus.FINISHED


@dataclass
class RequestState:
    """Tracks the state of a generation request through its lifecycle.

    Attributes:
        request_id (str): The ID of the generation request.
        initial_tokens (list[int]): The initial prompt tokens.
        num_children (int): The number of children requests
        full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
        prompt_ids (list[int] | None): The tokens IDs currently being processed.
        remaining_prompt_ids (list[int]): The initial tokens IDs remaining to be processed.
        static_outputs (list[int]): The generated tokens.
        allocated_blocks (int): The number of blocks allocated to the request.
        position_offset (int): The current position in the sequence for position_ids.
        status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
                                SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
        max_new_tokens (int | None): The maximum number of new tokens to generate.
        eos_token_id (int): The ID of the end-of-sequence token.
        streaming (bool): Whether to stream tokens as they're generated
        created_time (float): The time the request was created.
        error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
    """

    # Required fields
    request_id: str
    initial_tokens: list[int]  # Initial prompt tokens # TODO: rename this as prefill tokens
    # Optional fields
    record_timestamps: bool = False  # Whether to record timestamps for the generated tokens
    num_children: int = 0  # Number of children requests
    # Internal fields
    tokens_to_process: list[int] = field(default_factory=list)  # Tokens IDs currently being processed
    remaining_prefill_tokens: list[int] = field(
        default_factory=list
    )  # Initial tokens left to process (initialized in __post_init__)
    generated_tokens: list[int] = field(default_factory=list)  # Generated tokens
    allocated_blocks: int = 0  # Number of blocks allocated to the request
    position_offset: int = 0  # Current position in the sequence for position_ids
    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
    max_new_tokens: int | None = 20  # Maximum number of new tokens to generate. None means no limit. Default to 20.
    eos_token_id: int = -1  # ID of the end-of-sequence token
    streaming: bool = False  # Whether to stream tokens as they're generated
    created_time: float = field(default_factory=time.perf_counter)  # Time the request was created
    error: str | None = None  # Error message if the request failed
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)
    _timestamps: list[float] = field(default_factory=list)  # Timestamps of the generated tokens
    _true_initial_tokens: int = 0  # The true number of initial tokens, useful when soft resetting requests
    # TODO: remove the attribute above to _num_initial_tokens once initial_tokens is renamed
    _new_tokens_limit: int = 2147483647  # An int to check the max number of new tokens w/out always comparing w/ None

    def __post_init__(self):
        # If no max length is set, we set an absurdly high value which will never be reached
        self._new_tokens_limit = 2147483647 if self.max_new_tokens is None else self.max_new_tokens
        self.remaining_prefill_tokens = self.initial_tokens[:]

    @property
    def status(self) -> RequestStatus:
        return self._status

    @status.setter
    def status(self, value: RequestStatus):
        if self._status == RequestStatus.PENDING:
            self.lifespan = (time.perf_counter(), -1)
        elif value == RequestStatus.FINISHED:
            self.lifespan = (self.lifespan[0], time.perf_counter())
            self.log_end_of_request()
        self._status = value

    @property
    def timestamps(self) -> list[float] | None:
        return self._timestamps if self.record_timestamps else None

    def log_end_of_request(self):
        prefill_len = len(self.initial_tokens)
        decode_len = self.generated_len()
        start_time = self.lifespan[0] - self.created_time
        end_time = self.lifespan[1] - self.created_time
        logger.info(
            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
        )

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""
        return self.position_offset

    def generated_len(self) -> int:
        """Get the number of tokens generated so far."""
        return len(self.generated_tokens)

    # TODO: this logic seems one token off, check it out
    @traced
    def update_and_check_completion(self, token_id: int) -> bool:
        """Update the request with a newly generated token and check for completion.

        Args:
            token_id: The token ID to add to the output sequence

        Returns:
            bool: True if the request is now complete, False otherwise
        """
        # Only update if we're in decoding state # TODO: seems useless (always true) -- remove this
        if self.status != RequestStatus.DECODING:
            return False

        # If we're recording timestamps, add timestamp to the list
        if self.record_timestamps:
            self._timestamps.append(time.perf_counter())

        # Stop if we reached an EOS token
        is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
        current_len = self.generated_len()

        # Replace the temporary token if we're not finishing due to max length
        # (EOS tokens should still be added to the output)
        if is_eos or (current_len < self._new_tokens_limit):
            self.generated_tokens.append(token_id)
            self.tokens_to_process = [token_id]  # this works for 2 levels of pipelines, but not sure for more
            current_len += 1
        else:
            logger.warning(f"Request {self.request_id} generated a useless token: {token_id}")
            self.generated_tokens.pop()

        if is_eos or current_len >= self._new_tokens_limit:
            self.status = RequestStatus.FINISHED
            return True
        return False  # We still need to process more tokens

    def __repr__(self):
        msg = [
            f"request_id={self.request_id}",
            f"status={self._status}",
            f"out_tokens={self.generated_len()}",
            f"query_length={len(self.tokens_to_process)}",
            f"remaining_tokens={len(self.remaining_prefill_tokens)}",
            f"kv_length={self.position_offset}",
            f"full_prompt_length={len(self.initial_tokens)}",
            f"allocated_blocks={self.allocated_blocks}",
            f"generated_tokens={self.generated_tokens}",
        ]
        return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"

    def to_generation_output(self):
        """Convert the request state to a GenerationOutput object."""
        if self._true_initial_tokens:
            self.generated_tokens = self.initial_tokens[self._true_initial_tokens :] + self.generated_tokens
            self.initial_tokens = self.initial_tokens[: self._true_initial_tokens]
        return GenerationOutput(
            request_id=self.request_id,
            prompt_ids=self.initial_tokens,
            generated_tokens=self.generated_tokens,
            logprobs=[],
            error=self.error,
            status=self.status,
            created_time=self.created_time,
            lifespan=self.lifespan,
            timestamps=self.timestamps,
        )

    def fork(self, new_request_id: str) -> "RequestState":
        """Fork the request into a new request with the same state expect for request_id, created_time and lifespan."""
        t = time.perf_counter()
        new_request = RequestState(
            request_id=new_request_id,
            initial_tokens=self.initial_tokens,
            num_children=self.num_children,
            tokens_to_process=self.tokens_to_process[:],
            generated_tokens=self.generated_tokens[:],
            allocated_blocks=self.allocated_blocks,
            position_offset=self.position_offset,
            _status=self.status,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=self.eos_token_id,
            streaming=self.streaming,
            created_time=t,
            lifespan=(t, -1),
            _timestamps=[],
            error=self.error,
            record_timestamps=self.record_timestamps,
        )
        # Modified by __post_init__
        new_request.remaining_prefill_tokens = self.remaining_prefill_tokens[:]
        return new_request

    def create_equivalent_initial_request(self) -> "RequestState":
        """Creates an equivalent new request by removing the generated tokens and adding them to the initial prompt. The
        created request has THE SAME request_id. Notably, we can retrieve the original request from the created one with
        the _true_initial_tokens attribute."""
        max_new_tokens = None if self.max_new_tokens is None else (self.max_new_tokens - len(self.generated_tokens))
        new_state = RequestState(
            request_id=self.request_id,
            initial_tokens=self.initial_tokens + self.generated_tokens,
            num_children=self.num_children,
            record_timestamps=self.record_timestamps,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.eos_token_id,
            streaming=self.streaming,
        )
        new_state._true_initial_tokens = self._true_initial_tokens + len(self.initial_tokens)
        return new_state


class FutureRequestState:
    """Tracks the current state of a request and the relevant information to update it."""

    # This makes instantiating this class faster
    __slots__ = ("state", "has_new_token", "complete_blocks")

    def __init__(self, state: RequestState, has_new_token: bool, complete_blocks: int) -> None:
        self.state = state
        self.has_new_token = has_new_token
        self.complete_blocks = complete_blocks
