# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
from datetime import datetime
from pathlib import Path
from typing import Any, Mapping, Optional

import pandas as pd

from nsys_recipe import log
from nsys_recipe.data_service import DataService
from nsys_recipe.lib import helpers, recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.log import logger


class FileAccessStatistics(recipe.Recipe):
    """Recipe summarizing OS runtime (OSRT) file-access activity of reports.

    For each input report the mapper builds a per-file-access summary table
    (process/thread, MPI rank, path, open/close times, host name and
    read/write counters) and a per-event table. The reducer concatenates the
    per-report tables and writes ``files.parquet``, ``file_statistics.parquet``
    and ``file_events.parquet``.
    """

    # A globalTid packs the thread id into its low 24 bits and the process id
    # into the following 24 bits.
    _GLOBAL_TID_FIELD_BITS = 24
    _GLOBAL_TID_FIELD_MASK = 0x00FFFFFF

    @staticmethod
    def _mapper_func(
        report_path: str, parsed_args: argparse.Namespace
    ) -> Optional[tuple[str, pd.DataFrame, pd.DataFrame]]:
        """Extract file-access statistics and events from a single report.

        Args:
            report_path: Path of the report to read.
            parsed_args: Recipe arguments, forwarded to ``DataService``.

        Returns:
            ``(report_stem, file_statistics_df, metadata_events_complete_df)``,
            or ``None`` when the tables could not be read or the report holds
            no OSRT file-access events.
        """
        service = DataService(report_path, parsed_args)

        service.queue_table("OSRT_FILE_ACCESS_DESCRIPTORS")
        service.queue_table("OSRT_FILE_ACCESS_EVENTS")
        service.queue_table("ENUM_OSRT_FILE_ACCESS_EVENT_TYPE")
        service.queue_table("MPI_RANKS")
        service.queue_table("TARGET_INFO_SYSTEM_ENV", ["name", "value"])
        service.queue_table("StringIds")

        df_dict = service.read_queued_tables()
        if df_dict is None:
            return None

        metadata_df, events_df = FileAccessStatistics.get_and_resolve_dfs(df_dict)

        # Attach a thread id to each file access first, as it is needed to look
        # up the MPI rank. A file access may technically be used from several
        # threads; only the first thread seen in the events table is kept.
        tids_df = events_df.groupby("fileAccessId", as_index=False).agg(
            {"threadId": "first"}
        )
        metadata_with_tids_df = metadata_df.merge(
            tids_df, on="fileAccessId", how="inner"
        )

        mpi_ranks_df = df_dict["MPI_RANKS"] if "MPI_RANKS" in df_dict else None
        metadata_with_tids_df = FileAccessStatistics._add_mpi_ranks(
            metadata_with_tids_df, mpi_ranks_df
        )

        # Add the target's machine name (same value for every row of a report).
        hostname_df = df_dict["TARGET_INFO_SYSTEM_ENV"]
        hostname_values = hostname_df.query("name == 'Hostname'")["value"]
        metadata_with_tids_df["machineName"] = (
            hostname_values.iloc[0] if not hostname_values.empty else "Unknown"
        )

        # Join metadata (now including the MPI rank) with every event. threadId
        # is dropped from the metadata side because events_df carries its own
        # per-event threadId, which would otherwise be duplicated.
        metadata_for_merge = metadata_with_tids_df.drop(columns=["threadId"])
        metadata_events_complete_df = metadata_for_merge.merge(
            events_df, on="fileAccessId", how="inner"
        )

        if metadata_events_complete_df.empty:
            logger.warning(
                f"{report_path}: Report was processed successfully but did not"
                " contain OSRT file access data. Was the report recorded with"
                " '--trace=osrt' and '--osrt-file-access=true' options?"
            )
            return None

        file_statistics_df = metadata_with_tids_df[
            [
                "fileAccessId",
                "processId",
                "threadId",
                "mpiRank",
                "filePath",
                "openedAt",
                "closedAt",
                "machineName",
            ]
        ]

        read_counters_df, write_counters_df = (
            FileAccessStatistics.get_counters_statistics(events_df)
        )

        # Add the counter columns to the metadata table, then drop fileAccessId.
        # NOTE(review): counters are joined on (fileAccessId, threadId) with the
        # *first* thread only, so reads/writes issued by other threads on the
        # same file access are not included in these totals — confirm intent.
        file_statistics_df = (
            file_statistics_df.merge(
                read_counters_df, on=["fileAccessId", "threadId"], how="left"
            )
            .merge(write_counters_df, on=["fileAccessId", "threadId"], how="left")
            .drop(columns=["fileAccessId"])
        )

        metadata_events_complete_df.index.name = "Event ID"
        file_statistics_df.index.name = "File Access ID"

        filename = Path(report_path).stem

        return filename, file_statistics_df, metadata_events_complete_df

    @staticmethod
    def _add_mpi_ranks(
        metadata_df: pd.DataFrame, mpi_ranks_df: Optional[pd.DataFrame]
    ) -> pd.DataFrame:
        """Return *metadata_df* with an ``mpiRank`` column added.

        Ranks are matched on ``(processId, threadId)`` after splitting the
        ``globalTid`` column of *mpi_ranks_df* into those two components. The
        column always has the nullable ``Int64`` dtype, holding ``pd.NA`` for
        unmatched rows or when no MPI data is available, so every report
        produces the same schema.

        Args:
            metadata_df: File-access metadata with ``processId`` and
                ``threadId`` columns. Not modified.
            mpi_ranks_df: ``MPI_RANKS`` table (``globalTid`` and ``rank``
                columns), or ``None`` when the table is unavailable.
        """
        if mpi_ranks_df is None or mpi_ranks_df.empty:
            return metadata_df.assign(
                mpiRank=pd.Series(pd.NA, index=metadata_df.index, dtype="Int64")
            )

        bits = FileAccessStatistics._GLOBAL_TID_FIELD_BITS
        mask = FileAccessStatistics._GLOBAL_TID_FIELD_MASK
        # Explicit int64: a plain ``int`` cast may be 32-bit on some platforms
        # and would truncate the 64-bit globalTid values.
        global_tids = mpi_ranks_df["globalTid"].astype("int64").to_numpy()
        ranks_df = pd.DataFrame(
            {
                "decodedProcessId": (global_tids >> bits) & mask,
                "decodedThreadId": global_tids & mask,
                "mpiRank": mpi_ranks_df["rank"].to_numpy(),
            }
        )

        # Left merge so file accesses without a matching rank are kept.
        merged_df = metadata_df.merge(
            ranks_df,
            left_on=["processId", "threadId"],
            right_on=["decodedProcessId", "decodedThreadId"],
            how="left",
        ).drop(columns=["decodedProcessId", "decodedThreadId"])
        # The left merge turns missing ranks into float NaN; normalize to the
        # nullable integer dtype so missing values are pd.NA.
        merged_df["mpiRank"] = merged_df["mpiRank"].astype("Int64")
        return merged_df

    @staticmethod
    def get_and_resolve_dfs(
        df_dict: Mapping[Any, pd.DataFrame],
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Return the file-access metadata and events tables.

        The events table has its ``eventType`` id replaced by the enum name
        (plus an ``eventLabel`` column) and its ``apiCallId`` replaced by the
        resolved ``apiCall`` string.
        """
        metadata_df = df_dict["OSRT_FILE_ACCESS_DESCRIPTORS"]

        events_df = (
            df_dict["OSRT_FILE_ACCESS_EVENTS"]
            .merge(
                df_dict["ENUM_OSRT_FILE_ACCESS_EVENT_TYPE"],
                left_on="eventType",
                right_on="id",
                how="left",
            )
            .drop(columns=["eventType", "id"])
            .rename(columns={"name": "eventType", "label": "eventLabel"})
            .merge(
                df_dict["StringIds"],
                left_on="apiCallId",
                right_on="id",
                how="left",
            )
            .drop(columns=["apiCallId", "id"])
            .rename(columns={"value": "apiCall"})
        )

        return metadata_df, events_df

    @staticmethod
    def get_counters_statistics(
        events_df: pd.DataFrame,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Find total bytes read and written and total read and write operations
        per file access id and thread id.

        Returns:
            ``(read_counters_df, write_counters_df)`` with columns
            ``fileAccessId``, ``threadId`` and ``totalReadOps``/``totalReadBytes``
            (resp. ``totalWriteOps``/``totalWriteBytes``).
        """

        read_counters_df = (
            events_df.query("eventType == 'Read'")
            .groupby(["fileAccessId", "threadId"], as_index=False)
            .agg({"eventType": "count", "bytesProcessed": "sum"})
            .rename(
                columns={
                    "eventType": "totalReadOps",
                    "bytesProcessed": "totalReadBytes",
                }
            )
        )

        write_counters_df = (
            events_df.query("eventType == 'Write'")
            .groupby(["fileAccessId", "threadId"], as_index=False)
            .agg({"eventType": "count", "bytesProcessed": "sum"})
            .rename(
                columns={
                    "eventType": "totalWriteOps",
                    "bytesProcessed": "totalWriteBytes",
                }
            )
        )

        return read_counters_df, write_counters_df

    @log.time("Mapper")
    def mapper_func(
        self, context: recipe.Context
    ) -> list[Optional[tuple[str, pd.DataFrame, pd.DataFrame]]]:
        """Run ``_mapper_func`` over every input report via *context*."""
        return context.wait(
            context.map(
                self._mapper_func,
                self._parsed_args.input,
                parsed_args=self._parsed_args,
            )
        )

    @staticmethod
    def _concat_with_source(
        filenames: tuple[str, ...], dfs: tuple[pd.DataFrame, ...]
    ) -> pd.DataFrame:
        """Tag each frame with its report name and concatenate them."""
        return pd.concat(
            df.assign(sourceReportFile=filename)
            for filename, df in zip(filenames, dfs)
        )

    @log.time("Reducer")
    def reducer_func(
        self, mapper_res: list[Optional[tuple[str, pd.DataFrame, pd.DataFrame]]]
    ) -> None:
        """Merge per-report mapper results and write the parquet outputs.

        Raises:
            ValueError: If no report produced file-access data.
        """
        # Sort by file name so output ordering is deterministic.
        filtered_res = sorted(
            helpers.filter_none_or_empty(mapper_res), key=lambda res: res[0]
        )
        if not filtered_res:
            raise ValueError(
                "None of the input reports contained OSRT file access data;"
                " no output was generated."
            )

        filenames: tuple[str, ...]
        file_statistics_dfs: tuple[pd.DataFrame, ...]
        metadata_events_complete_dfs: tuple[pd.DataFrame, ...]

        filenames, file_statistics_dfs, metadata_events_complete_dfs = zip(
            *filtered_res
        )

        files_df = pd.DataFrame({"File": filenames}).rename_axis("nsys-rep ID")
        files_df.to_parquet(self.add_output_file("files.parquet"))

        file_statistics_df = self._concat_with_source(filenames, file_statistics_dfs)
        file_statistics_df.to_parquet(self.add_output_file("file_statistics.parquet"))

        metadata_events_complete_df = self._concat_with_source(
            filenames, metadata_events_complete_dfs
        )
        metadata_events_complete_df.to_parquet(
            self.add_output_file("file_events.parquet")
        )

    def save_notebook(self) -> None:
        """Create the analysis notebook and copy its display helper."""
        self.create_notebook("file_access_stats.ipynb")
        self.add_notebook_helper_file("nsys_display.py")

    def save_analysis_file(self) -> None:
        """Record the end time and output files, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context: recipe.Context) -> None:
        """Execute the recipe: map, reduce, then save notebook and analysis."""
        super().run(context)

        mapper_res = self.mapper_func(context)
        self.reducer_func(mapper_res)

        self.save_notebook()
        self.save_analysis_file()

    @classmethod
    def get_argument_parser(cls):
        """Return the recipe's argument parser (input, time range, filters)."""
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)

        # Time and NVTX filters cannot be combined.
        filter_group = parser.recipe_group.add_mutually_exclusive_group()
        parser.add_argument_to_group(filter_group, Option.FILTER_TIME)
        parser.add_argument_to_group(filter_group, Option.FILTER_NVTX)

        return parser
