from datetime import timedelta
from typing import Dict, List, Optional, Tuple, Any, TypedDict, Union, Literal

import pyarrow as pa

from .index import (
    BTree,
    IvfFlat,
    IvfPq,
    IvfSq,
    Bitmap,
    LabelList,
    HnswPq,
    HnswSq,
    FTS,
)
from .io import StorageOptionsProvider
from lance_namespace import (
    ListNamespacesResponse,
    CreateNamespaceResponse,
    DropNamespaceResponse,
    DescribeNamespaceResponse,
    ListTablesResponse,
)
from .remote import ClientConfig

# Alternate names for the HNSW index configurations. These are plain aliases:
# each name is bound to the very same class object as HnswPq / HnswSq (so
# `IvfHnswPq is HnswPq`), not to a subclass.
# NOTE(review): presumably kept so the older "IvfHnsw*" spellings keep
# working — confirm.
IvfHnswPq: type[HnswPq] = HnswPq
IvfHnswSq: type[HnswSq] = HnswSq

class PyExpr:
    """A type-safe DataFusion expression node (Rust-side handle).

    Every builder method below returns a ``PyExpr``, so expressions are
    composed by chaining calls. Leaf nodes come from the module-level
    ``expr_col`` / ``expr_lit`` / ``expr_func`` helpers. ``to_sql`` gives
    the string form.
    """

    # Binary comparisons against another expression
    # (presumably ==, !=, <, <=, >, >= — per the method names).
    def eq(self, other: "PyExpr") -> "PyExpr": ...
    def ne(self, other: "PyExpr") -> "PyExpr": ...
    def lt(self, other: "PyExpr") -> "PyExpr": ...
    def lte(self, other: "PyExpr") -> "PyExpr": ...
    def gt(self, other: "PyExpr") -> "PyExpr": ...
    def gte(self, other: "PyExpr") -> "PyExpr": ...
    # Boolean combinators. The trailing underscore is needed because `and`,
    # `or` and `not` are Python keywords and cannot be method names.
    def and_(self, other: "PyExpr") -> "PyExpr": ...
    def or_(self, other: "PyExpr") -> "PyExpr": ...
    def not_(self) -> "PyExpr": ...
    # Arithmetic against another expression.
    def add(self, other: "PyExpr") -> "PyExpr": ...
    def sub(self, other: "PyExpr") -> "PyExpr": ...
    def mul(self, other: "PyExpr") -> "PyExpr": ...
    def div(self, other: "PyExpr") -> "PyExpr": ...
    # String helpers (presumably case conversion and a substring test).
    def lower(self) -> "PyExpr": ...
    def upper(self) -> "PyExpr": ...
    def contains(self, substr: "PyExpr") -> "PyExpr": ...
    # Convert the expression's value to the given Arrow data type.
    def cast(self, data_type: pa.DataType) -> "PyExpr": ...
    # String rendering of the expression (SQL text, per the method name).
    def to_sql(self) -> str: ...

# Leaf/constructor helpers for building PyExpr trees.
# Reference a column by name.
def expr_col(name: str) -> PyExpr: ...
# Literal scalar; only bool / int / float / str values are accepted here.
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
# Call the function called `name` with argument expressions `args`
# (NOTE(review): which function names are valid is decided by the native side).
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...

class Session:
    """Cache configuration that can be handed to ``connect(session=...)``.

    NOTE(review): passing ``None`` for either size presumably keeps the
    native default for that cache — confirm.
    """

    def __init__(
        self,
        index_cache_size_bytes: Optional[int] = None,
        metadata_cache_size_bytes: Optional[int] = None,
    ) -> None: ...
    # Factory for a session with default settings.
    @staticmethod
    def default() -> "Session": ...
    # Current cache footprint in bytes (presumably across both caches).
    @property
    def size_bytes(self) -> int: ...
    # Approximate number of cached entries.
    @property
    def approx_num_items(self) -> int: ...

class Connection(object):
    """Handle to a database connection, as returned by ``connect``.

    Namespace and table management are coroutines. ``is_open`` and ``close``
    are plain synchronous calls, matching ``Table.is_open`` / ``Table.close``.
    The native binding implements them synchronously. Declaring them
    ``async`` made type checkers demand an ``await`` on a plain value, which
    fails at runtime.
    """

    uri: str  # URI this connection refers to
    def is_open(self) -> bool: ...
    # NOTE(review): presumably safe to call more than once — confirm.
    def close(self) -> None: ...

    # --- Namespaces ------------------------------------------------------
    # `namespace` is a list of name components; None presumably means the
    # root namespace. `page_token` / `limit` presumably page the listing.
    async def list_namespaces(
        self,
        namespace: Optional[List[str]] = None,
        page_token: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> ListNamespacesResponse: ...
    async def create_namespace(
        self,
        namespace: List[str],
        mode: Optional[str] = None,
        properties: Optional[Dict[str, str]] = None,
    ) -> CreateNamespaceResponse: ...
    async def drop_namespace(
        self,
        namespace: List[str],
        mode: Optional[str] = None,
        behavior: Optional[str] = None,
    ) -> DropNamespaceResponse: ...
    async def describe_namespace(
        self,
        namespace: List[str],
    ) -> DescribeNamespaceResponse: ...

    # --- Table listing ---------------------------------------------------
    async def list_tables(
        self,
        namespace: Optional[List[str]] = None,
        page_token: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> ListTablesResponse: ...
    async def table_names(
        self,
        namespace: Optional[List[str]],
        start_after: Optional[str],
        limit: Optional[int],
    ) -> list[str]: ...  # Deprecated: Use list_tables instead

    # --- Table lifecycle -------------------------------------------------
    # Create a table from a stream of record batches.
    # NOTE(review): valid `mode` strings are defined natively.
    async def create_table(
        self,
        name: str,
        mode: str,
        data: pa.RecordBatchReader,
        namespace: Optional[List[str]] = None,
        storage_options: Optional[Dict[str, str]] = None,
        storage_options_provider: Optional[StorageOptionsProvider] = None,
        location: Optional[str] = None,
    ) -> Table: ...
    # Same as create_table, but only a schema is supplied (no rows).
    async def create_empty_table(
        self,
        name: str,
        mode: str,
        schema: pa.Schema,
        namespace: Optional[List[str]] = None,
        storage_options: Optional[Dict[str, str]] = None,
        storage_options_provider: Optional[StorageOptionsProvider] = None,
        location: Optional[str] = None,
    ) -> Table: ...
    async def open_table(
        self,
        name: str,
        namespace: Optional[List[str]] = None,
        storage_options: Optional[Dict[str, str]] = None,
        storage_options_provider: Optional[StorageOptionsProvider] = None,
        index_cache_size: Optional[int] = None,
        location: Optional[str] = None,
    ) -> Table: ...
    # Clone the table at `source_uri` into `target_table_name`; shallow by
    # default. NOTE(review): presumably at most one of source_version /
    # source_tag should be given — confirm.
    async def clone_table(
        self,
        target_table_name: str,
        source_uri: str,
        target_namespace: Optional[List[str]] = None,
        source_version: Optional[int] = None,
        source_tag: Optional[str] = None,
        is_shallow: bool = True,
    ) -> Table: ...
    async def rename_table(
        self,
        cur_name: str,
        new_name: str,
        cur_namespace: Optional[List[str]] = None,
        new_namespace: Optional[List[str]] = None,
    ) -> None: ...
    async def drop_table(
        self, name: str, namespace: Optional[List[str]] = None
    ) -> None: ...
    async def drop_all_tables(self, namespace: Optional[List[str]] = None) -> None: ...

class Table:
    """Handle to a single table.

    Data and metadata operations are coroutines. ``name``, ``is_open``,
    ``close``, the ``tags`` property and the query builders (``query``,
    ``take_offsets``, ``take_row_ids``, ``vector_search``) are synchronous.
    """

    def name(self) -> str: ...
    def __repr__(self) -> str: ...
    def is_open(self) -> bool: ...
    def close(self) -> None: ...
    async def schema(self) -> pa.Schema: ...
    # Write record batches, either appending or overwriting.
    # NOTE(review): `progress` is untyped here (Any); its expected shape is
    # defined by the native side.
    async def add(
        self,
        data: pa.RecordBatchReader,
        mode: Literal["append", "overwrite"],
        progress: Optional[Any] = None,
    ) -> AddResult: ...
    # NOTE(review): `updates` presumably maps column name -> SQL value
    # expression; `where` is an optional row filter (None = all rows).
    async def update(
        self, updates: Dict[str, str], where: Optional[str]
    ) -> UpdateResult: ...
    async def count_rows(self, filter: Optional[str]) -> int: ...
    # Build an index on `column`. The kind of index is chosen by which config
    # object is passed. `name` and `train` are keyword-only.
    async def create_index(
        self,
        column: str,
        index: Union[
            IvfFlat,
            IvfSq,
            IvfPq,
            HnswPq,
            HnswSq,
            BTree,
            Bitmap,
            LabelList,
            FTS,
        ],
        replace: Optional[bool],
        wait_timeout: Optional[object],
        *,
        name: Optional[str],
        train: Optional[bool],
    ): ...
    # --- Versioning --------------------------------------------------------
    async def list_versions(self) -> List[Dict[str, Any]]: ...
    async def version(self) -> int: ...
    # `version` may be an int or a str (presumably a tag name — see Tags).
    async def checkout(self, version: Union[int, str]): ...
    async def checkout_latest(self): ...
    async def restore(self, version: Optional[Union[int, str]] = None): ...
    # --- Cache warming -----------------------------------------------------
    async def prewarm_index(self, index_name: str) -> None: ...
    async def prewarm_data(self, columns: Optional[List[str]] = None) -> None: ...
    async def list_indices(self) -> list[IndexConfig]: ...
    # --- Row / schema mutation ---------------------------------------------
    async def delete(self, filter: str) -> DeleteResult: ...
    # NOTE(review): tuples are presumably (new column name, SQL expression).
    async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ...
    async def add_columns_with_schema(self, schema: pa.Schema) -> AddColumnsResult: ...
    async def alter_columns(
        self, columns: list[dict[str, Any]]
    ) -> AlterColumnsResult: ...
    # Keyword-only maintenance options; result reports compaction and prune
    # counters (see OptimizeStats).
    async def optimize(
        self,
        *,
        cleanup_since_ms: Optional[int] = None,
        delete_unverified: Optional[bool] = None,
    ) -> OptimizeStats: ...
    async def uri(self) -> str: ...
    # Storage options the table was opened with vs. the most recent ones.
    # NOTE(review): the difference presumably matters when a
    # StorageOptionsProvider refreshes them — confirm.
    async def initial_storage_options(self) -> Optional[Dict[str, str]]: ...
    async def latest_storage_options(self) -> Optional[Dict[str, str]]: ...
    @property
    def tags(self) -> Tags: ...
    # --- Query builders (synchronous; execution happens on the builder) ----
    def query(self) -> Query: ...
    def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
    def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
    def vector_search(self) -> VectorQuery: ...

class Tags:
    """Tag management for a table, reached through the ``Table.tags`` property.

    A tag names a table version: ``create``/``update`` bind a tag to a
    version number, and ``get_version`` resolves a tag back to one.
    """

    # Mapping from tag name (presumably) to its Tag record.
    async def list(self) -> Dict[str, Tag]: ...
    async def get_version(self, tag: str) -> int: ...
    async def create(self, tag: str, version: int): ...
    async def delete(self, tag: str): ...
    async def update(self, tag: str, version: int): ...

class IndexConfig:
    """Description of an existing index, as returned by ``Table.list_indices``."""

    name: str
    # NOTE(review): the exact type-name strings are produced by the native side.
    index_type: str
    # Column(s) the index covers (presumably).
    columns: List[str]

# Open a Connection to `uri`.
# Every parameter is declared without a default, so callers must pass all of
# them (None where unused).
# NOTE(review): api_key / region / host_override / client_config look like
# remote-service options, and read_consistency_interval is presumably in
# seconds — confirm against the native binding.
async def connect(
    uri: str,
    api_key: Optional[str],
    region: Optional[str],
    host_override: Optional[str],
    read_consistency_interval: Optional[float],
    client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
    storage_options: Optional[Dict[str, str]],
    session: Optional[Session],
) -> Connection: ...

class RecordBatchStream:
    """Asynchronous iterator of ``pa.RecordBatch`` values produced by a query.

    Consume with ``async for batch in stream: ...`` (``__aiter__`` returns the
    stream itself). ``schema`` is presumably the schema shared by every batch.
    """

    @property
    def schema(self) -> pa.Schema: ...
    def __aiter__(self) -> "RecordBatchStream": ...
    async def __anext__(self) -> pa.RecordBatch: ...

class Query:
    """Plain (non-vector) query builder, obtained from ``Table.query()``.

    It can be turned into a vector search (``nearest_to``) or a full-text
    search (``nearest_to_text``).

    NOTE(review): the builder methods declare no return type. Presumably they
    mutate this object in place instead of returning a new query; confirm
    against the Rust bindings.
    """

    # Row filter as a SQL string, or as a PyExpr.
    def where(self, filter: str): ...
    def where_expr(self, expr: PyExpr): ...
    # Projections: (name, SQL expression) pairs (presumably), (name, PyExpr)
    # pairs, or plain column names.
    def select(self, columns: List[Tuple[str, str]]): ...
    def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
    def select_columns(self, columns: List[str]): ...
    def limit(self, limit: int): ...
    def offset(self, offset: int): ...
    def fast_search(self): ...
    def with_row_id(self): ...
    def postfilter(self): ...
    def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
    def nearest_to_text(self, query: dict) -> FTSQuery: ...
    async def output_schema(self) -> pa.Schema: ...
    async def execute(
        self, max_batch_length: Optional[int], timeout: Optional[timedelta]
    ) -> RecordBatchStream: ...
    async def explain_plan(self, verbose: Optional[bool]) -> str: ...
    async def analyze_plan(self) -> str: ...
    # Snapshot of the builder's settings as a PyQueryRequest.
    def to_query_request(self) -> PyQueryRequest: ...

class TakeQuery:
    """Query that fetches specific rows.

    Built by ``Table.take_offsets`` or ``Table.take_row_ids``.
    """

    def select(self, columns: List[str]): ...
    def with_row_id(self): ...
    async def output_schema(self) -> pa.Schema: ...
    # Accepts the same optional batching/timeout parameters as Query.execute
    # and FTSQuery.execute. Both default to None, so the previously declared
    # zero-argument form `execute()` remains valid.
    # NOTE(review): mirrors the native binding's optional parameters — confirm.
    async def execute(
        self,
        max_batch_length: Optional[int] = None,
        timeout: Optional[timedelta] = None,
    ) -> RecordBatchStream: ...
    def to_query_request(self) -> PyQueryRequest: ...

class FTSQuery:
    """Full-text search builder.

    Obtained from ``Query.nearest_to_text`` or ``HybridQuery.to_fts_query``.
    Adding a vector with ``nearest_to`` yields a ``HybridQuery``.

    NOTE(review): builder methods declare no return type. Presumably they
    mutate this object in place; confirm against the Rust bindings.
    """

    def where(self, filter: str): ...
    def where_expr(self, expr: PyExpr): ...
    def select(self, columns: List[Tuple[str, str]]): ...
    def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
    def limit(self, limit: int): ...
    def offset(self, offset: int): ...
    def fast_search(self): ...
    def with_row_id(self): ...
    def postfilter(self): ...
    # The text query as a string.
    def get_query(self) -> str: ...
    # Attach a query vector to this builder; returns nothing.
    def add_query_vector(self, query_vec: pa.Array) -> None: ...
    # Combine with a vector search into a hybrid query.
    def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
    async def output_schema(self) -> pa.Schema: ...
    async def execute(
        self, max_batch_length: Optional[int], timeout: Optional[timedelta]
    ) -> RecordBatchStream: ...
    def to_query_request(self) -> PyQueryRequest: ...

class VectorQuery:
    """Vector (nearest-neighbour) search builder.

    Obtained from ``Query.nearest_to``, ``Table.vector_search`` or
    ``HybridQuery.to_vector_query``. Adding a text query with
    ``nearest_to_text`` yields a ``HybridQuery``.

    NOTE(review): builder methods declare no return type. Presumably they
    mutate this object in place; confirm against the Rust bindings.
    """

    async def output_schema(self) -> pa.Schema: ...
    # Same optional batching/timeout parameters as Query.execute and
    # FTSQuery.execute, so every query kind can be executed uniformly. Both
    # default to None, so the previously declared form `execute()` remains
    # valid.
    async def execute(
        self,
        max_batch_length: Optional[int] = None,
        timeout: Optional[timedelta] = None,
    ) -> RecordBatchStream: ...
    def where(self, filter: str): ...
    def where_expr(self, expr: PyExpr): ...
    def select(self, columns: List[Tuple[str, str]]): ...
    def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
    # NOTE(review): takes a single (str, str) tuple, unlike `select`, which
    # takes a list of them — confirm this matches the native signature.
    def select_with_projection(self, columns: Tuple[str, str]): ...
    def limit(self, limit: int): ...
    def offset(self, offset: int): ...
    # Name of the vector column to search.
    def column(self, column: str): ...
    def distance_type(self, distance_type: str): ...
    def postfilter(self): ...
    def refine_factor(self, refine_factor: int): ...
    # Probe count for the vector index: exact (`nprobes`) or as a
    # lower/upper bound.
    def nprobes(self, nprobes: int): ...
    def minimum_nprobes(self, minimum_nprobes: int): ...
    def maximum_nprobes(self, maximum_nprobes: int): ...
    def bypass_vector_index(self): ...
    def nearest_to_text(self, query: dict) -> HybridQuery: ...
    def to_query_request(self) -> PyQueryRequest: ...

class HybridQuery:
    """Combined full-text and vector search builder.

    Created by ``FTSQuery.nearest_to`` or ``VectorQuery.nearest_to_text``.
    No ``execute`` is declared here. The component searches are exposed via
    ``to_vector_query`` / ``to_fts_query``.

    NOTE(review): builder methods declare no return type. Presumably they
    mutate this object in place; confirm against the Rust bindings.
    """

    def where(self, filter: str): ...
    def where_expr(self, expr: PyExpr): ...
    def select(self, columns: List[Tuple[str, str]]): ...
    def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
    def limit(self, limit: int): ...
    def offset(self, offset: int): ...
    def fast_search(self): ...
    def with_row_id(self): ...
    def postfilter(self): ...
    # Vector-search options (same names as on VectorQuery).
    def distance_type(self, distance_type: str): ...
    def refine_factor(self, refine_factor: int): ...
    def nprobes(self, nprobes: int): ...
    def minimum_nprobes(self, minimum_nprobes: int): ...
    def maximum_nprobes(self, maximum_nprobes: int): ...
    def bypass_vector_index(self): ...
    # Component queries.
    def to_vector_query(self) -> VectorQuery: ...
    def to_fts_query(self) -> FTSQuery: ...
    # Read back configured settings.
    def get_limit(self) -> int: ...
    def get_with_row_id(self) -> bool: ...
    def to_query_request(self) -> PyQueryRequest: ...

class FullTextQuery:
    """Opaque full-text query object; no members are declared in this stub.

    Appears as ``PyQueryRequest.full_text_search``.
    """

    pass

class PyQueryRequest:
    """Flattened snapshot of a query builder's settings.

    Produced by the ``to_query_request`` method of each query builder. Every
    field is Optional; ``None`` presumably means "not set on the builder".
    """

    # Paging.
    limit: Optional[int]
    offset: Optional[int]
    # Row filter. NOTE(review): str is presumably SQL; the bytes form's
    # encoding is defined natively.
    filter: Optional[Union[str, bytes]]
    full_text_search: Optional[FullTextQuery]
    select: Optional[Union[str, List[str]]]
    fast_search: Optional[bool]
    with_row_id: Optional[bool]
    # Vector-search settings (mirroring the VectorQuery builder methods).
    column: Optional[str]
    query_vector: Optional[List[pa.Array]]
    minimum_nprobes: Optional[int]
    maximum_nprobes: Optional[int]
    # NOTE(review): presumably a distance range for results — confirm.
    lower_bound: Optional[float]
    upper_bound: Optional[float]
    ef: Optional[int]
    refine_factor: Optional[int]
    distance_type: Optional[str]
    bypass_vector_index: Optional[bool]
    postfilter: Optional[bool]
    norm: Optional[str]

class CompactionStats:
    """Counters for the compaction step of ``Table.optimize`` (``OptimizeStats.compaction``)."""

    fragments_removed: int
    fragments_added: int
    files_removed: int
    files_added: int

class CleanupStats:
    """Counters for an old-version cleanup.

    NOTE(review): no signature in this file returns this type
    (``OptimizeStats.prune`` uses ``RemovalStats``). Confirm it is still used.
    """

    bytes_removed: int
    old_versions: int

class RemovalStats:
    """Counters for the prune step of ``Table.optimize`` (``OptimizeStats.prune``)."""

    bytes_removed: int
    old_versions_removed: int

class OptimizeStats:
    """Result of ``Table.optimize``: compaction and prune counters."""

    compaction: CompactionStats
    prune: RemovalStats

class Tag(TypedDict):
    """Value type of the mapping returned by ``Tags.list``."""

    version: int  # table version the tag points to (presumably)
    manifest_size: int

class AddResult:
    """Returned by ``Table.add``."""

    version: int  # presumably the table version after the write

class DeleteResult:
    """Returned by ``Table.delete``."""

    version: int

class UpdateResult:
    """Returned by ``Table.update``."""

    rows_updated: int
    version: int

class MergeResult:
    """Merge/upsert outcome counters.

    NOTE(review): no signature in this file returns this type.
    """

    version: int
    num_updated_rows: int
    num_inserted_rows: int
    num_deleted_rows: int
    num_attempts: int

class AddColumnsResult:
    """Returned by ``Table.add_columns`` and ``Table.add_columns_with_schema``."""

    version: int

class AlterColumnsResult:
    """Returned by ``Table.alter_columns``."""

    version: int

class DropColumnsResult:
    """Drop-columns outcome.

    NOTE(review): ``Table`` in this file declares no method returning it.
    """

    version: int

class AsyncPermutationBuilder:
    """Builder for a permutation table derived from an existing table.

    Created by ``async_permutation_builder``. Each configuration method
    returns the builder, so calls can be chained. ``execute`` produces the
    resulting ``Table``.
    """

    # NOTE(review): presumably output column name -> source expression.
    def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
    # Keyword-only split sizing. NOTE(review): presumably exactly one of
    # ratios / counts / fixed should be given — confirm.
    def split_random(
        self,
        *,
        ratios: Optional[List[float]] = None,
        counts: Optional[List[int]] = None,
        fixed: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> "AsyncPermutationBuilder": ...
    # Split by hashing `columns`, weighted by `split_weights`; `discard_weight`
    # (default 0) is keyword-only.
    def split_hash(
        self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
    ) -> "AsyncPermutationBuilder": ...
    def split_sequential(
        self,
        *,
        ratios: Optional[List[float]] = None,
        counts: Optional[List[int]] = None,
        fixed: Optional[int] = None,
    ) -> "AsyncPermutationBuilder": ...
    def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
    # Both arguments are required positionally (no defaults), though each may be None.
    def shuffle(
        self, seed: Optional[int], clump_size: Optional[int]
    ) -> "AsyncPermutationBuilder": ...
    def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
    async def execute(self) -> Table: ...

# Start building a permutation of `table`. The result is written to a table
# named `dest_table_name` when the builder is executed (presumably).
def async_permutation_builder(
    table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...
# Serialize a full-text query object to a JSON string.
# NOTE(review): `query` is untyped (Any); accepted types are decided natively.
def fts_query_to_json(query: Any) -> str: ...

class PermutationReader:
    """Reader pairing a base table with a permutation table.

    Only the constructor is declared in this stub.
    """

    def __init__(self, base_table: Table, permutation_table: Table) -> None: ...
