#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined table function related classes and functions
"""
import pickle
from dataclasses import dataclass, field
import inspect
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Sequence, Union
from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.types import DataType, StructType, _parse_datatype_string
from pyspark.sql.udf import _wrap_function
if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject
    from pyspark.sql._typing import ColumnOrName
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.session import SparkSession
__all__ = [
    "AnalyzeArgument",
    "AnalyzeResult",
    "PartitioningColumn",
    "OrderingColumn",
    "SelectedColumn",
    "SkipRestOfInputTableException",
    "UDTFRegistration",
]
@dataclass(frozen=True)
class AnalyzeArgument:
    """
    The argument for Python UDTF's analyze static method.
    Parameters
    ----------
    dataType : :class:`DataType`
        The argument's data type
    value : any, optional
        The calculated value if the argument is foldable; otherwise None
    isTable : bool
        If True, the argument is a table argument.
    isConstantExpression : bool
        If True, the argument is a constant-foldable scalar expression. Then the 'value' field
        contains None if the argument is a NULL literal, or a non-None value if the argument is a
        non-NULL literal. In this way, we can distinguish between a literal NULL argument and other
        types of arguments such as complex expression trees or table arguments where the 'value'
        field is always None.
    """
    dataType: DataType
    value: Optional[Any]
    isTable: bool
    isConstantExpression: bool
@dataclass(frozen=True)
class PartitioningColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to partition the input table
    by. This can be either the name of a single column from the input table (such as "columnA"), or
    a SQL expression based on the column names of the input table (such as "columnA + columnB").
    Parameters
    ----------
    name : str
        The contents of the partitioning column name or expression represented as a SQL string.
    """
    name: str
@dataclass(frozen=True)
class OrderingColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to order the input partition
    by. This can be either the name of a single column from the input table (such as "columnA"),
    or a SQL expression based on the column names of the input table (such as "columnA + columnB").
    Parameters
    ----------
    name : str
        The contents of the ordering column name or expression represented as a SQL string.
    ascending : bool, default True
        This is if this expression specifies an ascending sorting order.
    overrideNullsFirst : str, optional
        If this is None, use the default behavior to sort NULL values first when sorting in
        ascending order, or last when sorting in descending order. Otherwise, if this is
        True or False, we override the default behavior accordingly.
    """
    name: str
    ascending: bool = True
    overrideNullsFirst: Optional[bool] = None
@dataclass(frozen=True)
class SelectedColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to evaluate against the
    columns in the input TABLE argument. The UDTF then receives one input column for each expression
    in the list, in the order they are listed.
    Parameters
    ----------
    name : str
        The contents of the selected column name or expression represented as a SQL string.
    alias : str, default ''
        If non-empty, this is the alias for the column or expression as visible from the UDTF's
        'eval' method. This is required if the expression is not a simple column reference.
    """
    name: str
    alias: str = ""
# Note: this class is a "dataclass" for purposes of convenience, but it is not marked "frozen"
# because the intention is that users may create subclasses of it for purposes of returning custom
# information from the "analyze" method.
@dataclass
class AnalyzeResult:
    """
    The return of Python UDTF's analyze static method.
    Parameters
    ----------
    schema: :class:`StructType`
        The schema that the Python UDTF will return.
    withSinglePartition: bool
        If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE
        argument to one collection for consumption by exactly one instance of the correpsonding
        UDTF class.
    partitionBy: sequence of :class:`PartitioningColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        partition the input TABLE argument by. In this case, calls to the UDTF may not include any
        explicit PARTITION BY clause, in which case Catalyst will return an error. This option is
        mutually exclusive with 'withSinglePartition'.
    orderBy: sequence of :class:`OrderingColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        sort the input TABLE argument by. Note that the 'partitionBy' list must also be non-empty
        in this case.
    select: sequence of :class:`SelectedColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        evaluate against the columns in the input TABLE argument. The UDTF then receives one input
        attribute for each name in the list, in the order they are listed.
    """
    schema: StructType
    withSinglePartition: bool = False
    partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
    orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
    select: Sequence[SelectedColumn] = field(default_factory=tuple)
class SkipRestOfInputTableException(Exception):
    """
    This represents an exception that the 'eval' method may raise to indicate that it is done
    consuming rows from the current partition of the input table. Then the UDTF's 'terminate'
    method runs (if any).
    """
    pass
def _create_udtf(
    cls: Type,
    returnType: Optional[Union[StructType, str]],
    name: Optional[str] = None,
    evalType: int = PythonEvalType.SQL_TABLE_UDF,
    deterministic: bool = False,
) -> "UserDefinedTableFunction":
    """Create a Python UDTF with the given eval type."""
    udtf_obj = UserDefinedTableFunction(
        cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
    )
    return udtf_obj
def _create_py_udtf(
    cls: Type,
    returnType: Optional[Union[StructType, str]],
    name: Optional[str] = None,
    deterministic: bool = False,
    useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
    """Create a regular or an Arrow-optimized Python UDTF."""
    # Determine whether to create Arrow-optimized UDTFs.
    if useArrow is not None:
        arrow_enabled = useArrow
    else:
        from pyspark.sql import SparkSession
        session = SparkSession._instantiatedSession
        arrow_enabled = False
        if session is not None:
            value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
            if isinstance(value, str) and value.lower() == "true":
                arrow_enabled = True
    eval_type: int = PythonEvalType.SQL_TABLE_UDF
    if arrow_enabled:
        # Return the regular UDTF if the required dependencies are not satisfied.
        try:
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
        except ImportError as e:
            warnings.warn(
                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
                f"Falling back to using regular Python UDTFs.",
                UserWarning,
            )
    return _create_udtf(
        cls=cls,
        returnType=returnType,
        name=name,
        evalType=eval_type,
        deterministic=deterministic,
    )
def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
    """Validate the handler class of a UDTF."""
    if not isinstance(cls, type):
        raise PySparkTypeError(
            error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__}
        )
    if not hasattr(cls, "eval"):
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__}
        )
    has_analyze = hasattr(cls, "analyze")
    has_analyze_staticmethod = has_analyze and isinstance(
        inspect.getattr_static(cls, "analyze"), staticmethod
    )
    if returnType is None and not has_analyze_staticmethod:
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_RETURN_TYPE", message_parameters={"name": cls.__name__}
        )
    if returnType is not None and has_analyze:
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE",
            message_parameters={"name": cls.__name__},
        )
[docs]class UserDefinedTableFunction:
    """
    User-defined table function in Python
    .. versionadded:: 3.5.0
    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udtf` to create this instance.
    This API is evolving.
    """
    def __init__(
        self,
        func: Type,
        returnType: Optional[Union[StructType, str]],
        name: Optional[str] = None,
        evalType: int = PythonEvalType.SQL_TABLE_UDF,
        deterministic: bool = False,
    ):
        _validate_udtf_handler(func, returnType)
        self.func = func
        self._returnType = returnType
        self._returnType_placeholder: Optional[StructType] = None
        self._inputTypes_placeholder = None
        self._judtf_placeholder = None
        self._name = name or func.__name__
        self.evalType = evalType
        self.deterministic = deterministic
    @property
    def returnType(self) -> Optional[StructType]:
        if self._returnType is None:
            return None
        # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string.
        # This makes sure this is called after SparkContext is initialized.
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, str):
                parsed = _parse_datatype_string(self._returnType)
            else:
                parsed = self._returnType
            if not isinstance(parsed, StructType):
                raise PySparkTypeError(
                    error_class="UDTF_RETURN_TYPE_MISMATCH",
                    message_parameters={
                        "name": self._name,
                        "return_type": f"{parsed}",
                    },
                )
            self._returnType_placeholder = parsed
        return self._returnType_placeholder
    @property
    def _judtf(self) -> "JavaObject":
        if self._judtf_placeholder is None:
            self._judtf_placeholder = self._create_judtf(self.func)
        return self._judtf_placeholder
    def _create_judtf(self, func: Type) -> "JavaObject":
        from pyspark.sql import SparkSession
        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext
        try:
            wrapped_func = _wrap_function(sc, func)
        except pickle.PicklingError as e:
            if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
                raise PySparkPicklingError(
                    error_class="UDTF_SERIALIZATION_ERROR",
                    message_parameters={
                        "name": self._name,
                        "message": "it appears that you are attempting to reference SparkSession "
                        "inside a UDTF. SparkSession can only be used on the driver, "
                        "not in code that runs on workers. Please remove the reference "
                        "and try again.",
                    },
                ) from None
            raise PySparkPicklingError(
                error_class="UDTF_SERIALIZATION_ERROR",
                message_parameters={
                    "name": self._name,
                    "message": "Please check the stack trace and make sure the "
                    "function is serializable.",
                },
            )
        assert sc._jvm is not None
        if self.returnType is None:
            judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
                self._name, wrapped_func, self.evalType, self.deterministic
            )
        else:
            jdt = spark._jsparkSession.parseDataType(self.returnType.json())
            judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
                self._name, wrapped_func, jdt, self.evalType, self.deterministic
            )
        return judtf
    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame":
        from pyspark.sql.classic.column import _to_java_column, _to_java_expr, _to_seq
        from pyspark.sql import DataFrame, SparkSession
        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext
        assert sc._jvm is not None
        jcols = [_to_java_column(arg) for arg in args] + [
            sc._jvm.Column(
                sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression(
                    key, _to_java_expr(value)
                )
            )
            for key, value in kwargs.items()
        ]
        judtf = self._judtf
        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols))
        return DataFrame(jPythonUDTF, spark)
[docs]    def asDeterministic(self) -> "UserDefinedTableFunction":
        """
        Updates UserDefinedTableFunction to deterministic.
        """
        # Explicitly clean the cache to create a JVM UDTF instance.
        self._judtf_placeholder = None
        self.deterministic = True
        return self  
[docs]class UDTFRegistration:
    """
    Wrapper for user-defined table function registration. This instance can be accessed by
    :attr:`spark.udtf` or :attr:`sqlContext.udtf`.
    .. versionadded:: 3.5.0
    """
    def __init__(self, sparkSession: "SparkSession"):
        self.sparkSession = sparkSession
[docs]    def register(
        self,
        name: str,
        f: "UserDefinedTableFunction",
    ) -> "UserDefinedTableFunction":
        """Register a Python user-defined table function as a SQL table function.
        .. versionadded:: 3.5.0
        Parameters
        ----------
        name : str
            The name of the user-defined table function in SQL statements.
        f : function or :meth:`pyspark.sql.functions.udtf`
            The user-defined table function.
        Returns
        -------
        function
            The registered user-defined table function.
        Notes
        -----
        Spark uses the return type of the given user-defined table function as the return
        type of the registered user-defined function.
        To register a nondeterministic Python table function, users need to first build
        a nondeterministic user-defined table function and then register it as a SQL function.
        Examples
        --------
        >>> from pyspark.sql.functions import udtf
        >>> @udtf(returnType="c1: int, c2: int")
        ... class PlusOne:
        ...     def eval(self, x: int):
        ...         yield x, x + 1
        ...
        >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
        >>> spark.sql("SELECT * FROM plus_one(1)").collect()
        [Row(c1=1, c2=2)]
        Use it with lateral join
        >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
        [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
        """
        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
            raise PySparkTypeError(
                error_class="INVALID_UDTF_EVAL_TYPE",
                message_parameters={
                    "name": name,
                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
                },
            )
        register_udtf = _create_udtf(
            cls=f.func,
            returnType=f.returnType,
            name=name,
            evalType=f.evalType,
            deterministic=f.deterministic,
        )
        self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
        return register_udtf  
def _test() -> None:
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf
    globs = pyspark.sql.udtf.__dict__.copy()
    spark = SparkSession.builder.master("local[4]").appName("sql.udtf tests").getOrCreate()
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.udtf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)
if __name__ == "__main__":
    _test()