Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
Expand All @@ -33,19 +34,35 @@
def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build an expression that matches every distinct join-key combination in ``df``.

    Args:
        df: Table holding the rows whose key values should be matched.
        join_cols: Names of the columns that together form the key.

    Returns:
        ``AlwaysFalse()`` when ``df`` has no rows; a single ``In`` on the key
        column when there is exactly one join column; otherwise a disjunction
        of ``And(<equality on the other columns>, In(<folded column>, values))``
        terms, or that lone term when there is only one.
    """
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

    if unique_keys.num_rows == 0:
        return AlwaysFalse()

    if len(join_cols) == 1:
        return In(join_cols[0], unique_keys.column(join_cols[0]).to_pylist())

    def prefix_group_count(candidate: str) -> int:
        # Number of distinct combinations of the remaining columns if
        # `candidate` were the one folded into an In().
        others = [c for c in join_cols if c != candidate]
        return unique_keys.select(others).group_by(others).aggregate([]).num_rows

    # Fold the column that leaves the fewest distinct "prefix" combinations into
    # an In(); this minimises the disjunct count regardless of column order.
    # min() keeps the first such column on ties.
    in_col = min(join_cols, key=prefix_group_count)
    prefix_cols = [c for c in join_cols if c != in_col]

    # One output row per prefix combination, carrying the list of in_col values
    # seen with it; the aggregated column is read back as "<in_col>_list".
    grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")])
    in_values_col = f"{in_col}_list"

    disjuncts: list[BooleanExpression] = []
    for row in grouped.to_pylist():
        eqs = [EqualTo(c, row[c]) for c in prefix_cols]
        # reduce() returns the sole element unchanged when there is only one
        # prefix column, so no special case is needed.
        prefix_pred = functools.reduce(operator.and_, eqs)
        disjuncts.append(And(prefix_pred, In(in_col, row[in_values_col])))

    if len(disjuncts) == 1:
        return disjuncts[0]
    return Or(*disjuncts)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,10 @@ def test_create_match_filter_single_condition() -> None:
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
table = pa.Table.from_pylist(data, schema=schema)
expr = create_match_filter(table, ["order_id", "order_line_id"])
assert expr == And(
EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
)
# Be insensitive to left/right operands
op1 = EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101))
op2 = EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1))
assert expr == And(op1, op2) or expr == And(op2, op1)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
Expand Down