# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors


from abc import ABC, abstractmethod
from packaging.version import Version
from typing import Union, List, TYPE_CHECKING

import numpy as np
import pyarrow as pa

if TYPE_CHECKING:
    from ..table import LanceVectorQueryBuilder

ARROW_VERSION = Version(pa.__version__)


class Reranker(ABC):
    def __init__(self, return_score: str = "relevance"):
        """
        Interface for a reranker. A reranker is used to rerank the results from a
        vector and FTS search. This is useful for combining the results from both
        search methods.

        Parameters
        ----------
        return_score : str, default "relevance"
            opntions are "relevance" or "all"
            The type of score to return. If "relevance", will return only the relevance
            score. If "all", will return all scores from the vector and FTS search along
            with the relevance score.

        """
        if return_score not in ["relevance", "all"]:
            raise ValueError("score must be either 'relevance' or 'all'")
        self.score = return_score
        # Set the merge args based on the arrow version here to avoid checking it at
        # each query
        self._concat_tables_args = {"promote_options": "default"}
        if ARROW_VERSION.major <= 13:
            self._concat_tables_args = {"promote": True}

    def __str__(self):
        return self.__class__.__name__

    def rerank_vector(
        self,
        query: str,
        vector_results: pa.Table,
    ):
        """
        Rerank function receives the result from the vector search.
        This isn't mandatory to implement

        Parameters
        ----------
        query : str
            The input query
        vector_results : pa.Table
            The results from the vector search

        Returns
        -------
        pa.Table
            The reranked results
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement rerank_vector"
        )

    def _handle_empty_results(self, results: pa.Table):
        """
        Helper method to handle empty FTS results consistently
        """
        if len(results) > 0:
            return results
        return results.append_column(
            "_relevance_score", pa.array([], type=pa.float32())
        )

    def rerank_fts(
        self,
        query: str,
        fts_results: pa.Table,
    ):
        """
        Rerank function receives the result from the FTS search.
        This isn't mandatory to implement

        Parameters
        ----------
        query : str
            The input query
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The reranked results
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement rerank_fts"
        )

    @abstractmethod
    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank function receives the individual results from the vector and FTS search
        results. You can choose to use any of the results to generate the final results,
        allowing maximum flexibility. This is mandatory to implement

        Parameters
        ----------
        query : str
            The input query
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The reranked results
        """
        pass

    def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
        """
        Merge the results from the vector and FTS search. This is a vanilla merging
        function that just concatenates the results and removes the duplicates.

        NOTE: This doesn't take score into account. It'll keep the instance that was
        encountered first. This is designed for rerankers that don't use the score.
        In case you want to use the score, or support `return_scores="all"` you'll
        have to implement your own merging function.

        Parameters
        ----------
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search
        """
        combined = pa.concat_tables(
            [vector_results, fts_results], **self._concat_tables_args
        )

        # deduplicate
        combined = self._deduplicate(combined)

        return combined

    def rerank_multivector(
        self,
        vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
        query: Union[str, None],  # Some rerankers might not need the query
        deduplicate: bool = False,
    ):
        """
        This is a rerank function that receives the results from multiple
        vector searches. For example, this can be used to combine the
        results of two vector searches with different embeddings.

        Parameters
        ----------
        vector_results : List[pa.Table] or List[LanceVectorQueryBuilder]
            The results from the vector search. Either accepts the query builder
            if the results haven't been executed yet or the results in arrow format.
        query : str or None,
            The input query. Some rerankers might not need the query to rerank.
            In that case, it can be set to None explicitly. This is inteded to
            be handled by the reranker implementations.
        deduplicate : bool, optional
            Whether to deduplicate the results based on the `_rowid` column,
            by default False. Requires `_rowid` to be present in the results.

        Returns
        -------
        pa.Table
            The reranked results
        """
        vector_results = (
            [vector_results] if not isinstance(vector_results, list) else vector_results
        )

        # Make sure all elements are of the same type
        if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
            raise ValueError(
                "All elements in vector_results should be of the same type"
            )

        # avoids circular import
        if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
            vector_results = [result.to_arrow() for result in vector_results]
        elif not isinstance(vector_results[0], pa.Table):
            raise ValueError(
                "vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
            )

        combined = pa.concat_tables(vector_results, **self._concat_tables_args)

        reranked = self.rerank_vector(query, combined)

        # TODO: Allow custom deduplicators here.
        # currently, this'll just keep the first instance.
        if deduplicate:
            if "_rowid" not in combined.column_names:
                raise ValueError(
                    "'_rowid' is required for deduplication. \
                    add _rowid to search results like this: \
                    `search().with_row_id(True)`"
                )
            reranked = self._deduplicate(reranked)

        return reranked

    def _deduplicate(self, table: pa.Table):
        """
        Deduplicate the table based on the `_rowid` column.
        """
        row_id = table.column("_rowid")

        # deduplicate
        mask = np.full((table.shape[0]), False)
        _, mask_indices = np.unique(np.array(row_id), return_index=True)
        mask[mask_indices] = True
        deduped_table = table.filter(mask=mask)

        return deduped_table

    def _merge_and_keep_scores(self, vector_results: pa.Table, fts_results: pa.Table):
        """
        Merge the results from the vector and FTS search and keep the scores.
        This op is slower than just keeping relevance score but can be useful
        for debugging.
        """
        # add nulls to fts results for _distance
        if "_distance" not in fts_results.column_names:
            fts_results = fts_results.append_column(
                "_distance",
                pa.array([None] * len(fts_results), type=pa.float32()),
            )
        # add nulls to vector results for _score
        if "_score" not in vector_results.column_names:
            vector_results = vector_results.append_column(
                "_score",
                pa.array([None] * len(vector_results), type=pa.float32()),
            )

        # combine them and fill the scores
        vector_results_dict = {row["_rowid"]: row for row in vector_results.to_pylist()}
        fts_results_dict = {row["_rowid"]: row for row in fts_results.to_pylist()}

        # merge them into vector_results
        for key, value in fts_results_dict.items():
            if key in vector_results_dict:
                vector_results_dict[key]["_score"] = value["_score"]
            else:
                vector_results_dict[key] = value

        combined = pa.Table.from_pylist(list(vector_results_dict.values()))
        return combined

    def _keep_relevance_score(self, combined_results: pa.Table):
        if self.score == "relevance":
            if "_score" in combined_results.column_names:
                combined_results = combined_results.drop_columns(["_score"])
            if "_distance" in combined_results.column_names:
                combined_results = combined_results.drop_columns(["_distance"])
        return combined_results
