From 7cfaf687544af6f5a9587f0dc976fb452ab01336 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 24 Mar 2026 20:08:56 +0000 Subject: [PATCH] refactor: apply is_null_literal --- .../compile/sqlglot/expressions/bool_ops.py | 13 ++++++------ .../sqlglot/expressions/comparison_ops.py | 8 +++---- .../sqlglot/expressions/numeric_ops.py | 21 +++++++++++++++---- .../test_numeric_ops/test_div_numeric/out.sql | 2 +- .../test_floordiv_numeric/out.sql | 1 + .../test_numeric_ops/test_mod_numeric/out.sql | 3 ++- .../test_numeric_ops/test_pow/out.sql | 4 +++- .../sqlglot/expressions/test_numeric_ops.py | 6 ++++++ 8 files changed, 41 insertions(+), 17 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/bool_ops.py b/bigframes/core/compile/sqlglot/expressions/bool_ops.py index cd7f9da408..3b4ecf5431 100644 --- a/bigframes/core/compile/sqlglot/expressions/bool_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/bool_ops.py @@ -18,6 +18,7 @@ from bigframes import dtypes from bigframes import operations as ops +from bigframes.core.compile.sqlglot import sql import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr @@ -29,10 +30,10 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: # For AND, when we encounter a NULL value, we only know when the result is FALSE, # otherwise the result is unknown (NULL). See: truth table at # https://en.wikibooks.org/wiki/Structured_Query_Language/NULLs_and_the_Three_Valued_Logic#AND,_OR - if left.expr == sge.null(): + if sql.is_null_literal(left.expr): condition = sge.EQ(this=right.expr, expression=sge.convert(False)) return sge.If(this=condition, true=right.expr, false=sge.null()) - if right.expr == sge.null(): + if sql.is_null_literal(right.expr): condition = sge.EQ(this=left.expr, expression=sge.convert(False)) return sge.If(this=condition, true=left.expr, false=sge.null()) @@ -46,10 +47,10 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: # For OR, when we encounter a NULL value, we only know when the result is TRUE, # otherwise the result is unknown (NULL). See: truth table at # https://en.wikibooks.org/wiki/Structured_Query_Language/NULLs_and_the_Three_Valued_Logic#AND,_OR - if left.expr == sge.null(): + if sql.is_null_literal(left.expr): condition = sge.EQ(this=right.expr, expression=sge.convert(True)) return sge.If(this=condition, true=right.expr, false=sge.null()) - if right.expr == sge.null(): + if sql.is_null_literal(right.expr): condition = sge.EQ(this=left.expr, expression=sge.convert(True)) return sge.If(this=condition, true=left.expr, false=sge.null()) @@ -64,12 +65,12 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: # maintains the boolean data type. left_expr = left.expr left_dtype = left.dtype - if left_expr == sge.null(): + if sql.is_null_literal(left_expr): left_expr = sge.Cast(this=sge.convert(None), to="BOOLEAN") left_dtype = dtypes.BOOL_DTYPE right_expr = right.expr right_dtype = right.dtype - if right_expr == sge.null(): + if sql.is_null_literal(right_expr): right_expr = sge.Cast(this=sge.convert(None), to="BOOLEAN") right_dtype = dtypes.BOOL_DTYPE diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 7177f9de84..82c264da50 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -102,7 +102,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.ge_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) @@ -112,7 +112,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.gt_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) @@ -122,7 +122,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.lt_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) @@ -132,7 +132,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.le_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index d70ec2ef3f..c5fdbe3c84 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -19,6 +19,7 @@ from bigframes import dtypes from bigframes import operations as ops +from bigframes.core.compile.sqlglot import sql import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler from bigframes.core.compile.sqlglot.expressions.common import round_towards_zero import bigframes.core.compile.sqlglot.expressions.constants as constants @@ -260,6 +261,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: def _int_pow_op( left_expr: sge.Expression, right_expr: sge.Expression ) -> sge.Expression: + if sql.is_null_literal(left_expr) or sql.is_null_literal(right_expr): + return sge.null() + overflow_cond = sge.and_( sge.NEQ(this=left_expr, expression=sge.convert(0)), sge.GT( @@ -292,6 +296,9 @@ def _int_pow_op( def _float_pow_op( left_expr: sge.Expression, right_expr: sge.Expression ) -> sge.Expression: + if sql.is_null_literal(left_expr) or sql.is_null_literal(right_expr): + return sge.null() + # Most conditions here seek to prevent calling BQ POW with inputs that would generate errors. # See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow overflow_cond = sge.and_( @@ -425,7 +432,7 @@ def _(expr: TypedExpr) -> sge.Expression: @register_binary_op(ops.add_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE: @@ -463,6 +470,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.div_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -482,7 +492,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.floordiv_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) @@ -525,6 +535,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.mod_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): + return sge.null() + # In BigQuery returned value has the same sign as X. In pandas, the sign of y is used, so we need to flip the result if sign(x) != sign(y) left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -568,7 +581,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.mul_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() left_expr = _coerce_bool_to_int(left) @@ -594,7 +607,7 @@ def _(expr: TypedExpr, n_digits: TypedExpr) -> sge.Expression: @register_binary_op(ops.sub_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.expr == sge.null() or right.expr == sge.null(): + if sql.is_null_literal(left.expr) or sql.is_null_literal(right.expr): return sge.null() if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql index 3f5ff73326..e2ccf96410 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql @@ -6,7 +6,7 @@ SELECT IEEE_DIVIDE(`int64_col`, `int64_col`) AS `int_div_int`, IEEE_DIVIDE(`int64_col`, 1) AS `int_div_1`, IEEE_DIVIDE(`int64_col`, 0.0) AS `int_div_0`, - IEEE_DIVIDE(`int64_col`, NULL) AS `int_div_null`, + NULL AS `int_div_null`, IEEE_DIVIDE(`int64_col`, `float64_col`) AS `int_div_float`, IEEE_DIVIDE(`float64_col`, `int64_col`) AS `float_div_int`, IEEE_DIVIDE(`float64_col`, 0.0) AS `float_div_0`, diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_numeric/out.sql index c7fa74e48f..8307b1b8ad 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_numeric/out.sql @@ -34,6 +34,7 @@ SELECT THEN CAST('Infinity' AS FLOAT64) * `float64_col` ELSE CAST(FLOOR(IEEE_DIVIDE(`float64_col`, 0.0)) AS INT64) END AS `float_div_0`, + NULL AS `float_div_null`, CASE WHEN CAST(`bool_col` AS INT64) = CAST(0 AS INT64) THEN CAST(0 AS INT64) * `int64_col` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql index 2a79820635..78107415b4 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql @@ -189,5 +189,6 @@ SELECT MOD(CAST(`float64_col` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) ) ELSE MOD(CAST(`float64_col` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) - END AS `float_mod_0` + END AS `float_mod_0`, + NULL AS `float_mod_null` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql index 8f72522262..7202903ebe 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql @@ -241,5 +241,7 @@ SELECT ELSE 1 END ) - END AS `float_pow_1` + END AS `float_pow_1`, + NULL AS `float_pow_null`, + NULL AS `null_pow_float` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 17c2ff98bc..1d2f0a5b44 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -220,6 +220,9 @@ def test_pow(scalar_types_df: bpd.DataFrame, snapshot): bf_df["int_pow_1"] = bf_df["int64_col"] ** 1 bf_df["float_pow_1"] = bf_df["float64_col"] ** 1 + bf_df["float_pow_null"] = bf_df["float64_col"] ** pd.NA + bf_df["null_pow_float"] = pd.NA ** bf_df["float64_col"] + snapshot.assert_match(bf_df.sql, "out.sql") @@ -370,6 +373,7 @@ def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df["int_div_float"] = bf_df["int64_col"] // bf_df["float64_col"] bf_df["float_div_int"] = bf_df["float64_col"] // bf_df["int64_col"] bf_df["float_div_0"] = bf_df["float64_col"] // 0.0 + bf_df["float_div_null"] = bf_df["float64_col"] // pd.NA bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] @@ -437,6 +441,8 @@ def test_mod_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df["float_mod_1"] = bf_df["float64_col"] % 1 bf_df["float_mod_0"] = bf_df["float64_col"] % 0 + bf_df["float_mod_null"] = bf_df["float64_col"] % pd.NA + snapshot.assert_match(bf_df.sql, "out.sql")