Skip to content

KNN as Polars Expr

KNN related query expressions in Polars.

Functions:

Name Description
is_knn_from

Returns a boolean column marking the rows that are among the k nearest neighbors of the given point.

query_dist_from_kth_nb

Computes the distance of each row to its k-th closest neighbor. This is useful for outlier detection.

query_knn_avg

Takes the target column, and uses feature columns to determine the k nearest neighbors

query_knn_freq_cnt

Takes the index column, and uses feature columns to determine the k nearest neighbors

query_knn_ptwise

Takes the index column, and uses feature columns to determine the k nearest neighbors

query_nb_cnt

Return the number of neighbors within (<=) radius r for each row under the given distance

query_radius_freq_cnt

Takes the index column, and uses features columns to determine distance, finds all neighbors

query_radius_ptwise

Takes the index column, and uses features columns to determine distance, and finds all neighbors

query_radius_ptwise_null_safe

Null-safe variant of query_radius_ptwise. Rows where any feature column is null are

within_dist_from

Returns a boolean column that returns points that are within radius from the given point.

is_knn_from(*features, pt, k, dist='sql2')

Returns a boolean column marking the rows that are among the k nearest neighbors of the given point.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
pt Iterable[float]

The point

required
k int

k nearest neighbor

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

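Note sql2 stands for squared l2. The implementation of this function also accepts `cosine` and `h` / `haversine` (the latter requires a 2-dimensional point and lat/long feature columns).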

'sql2'
Source code in python/polars_ds/exprs/expr_knn.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def is_knn_from(
    *features: str | pl.Expr,
    pt: Iterable[float],
    k: int,
    dist: Distance = "sql2",
) -> pl.Expr:
    """
    Returns a boolean column marking the rows that are among the k nearest neighbors of the
    given point.

    The distance from every row to `pt` is computed with native Polars expressions, and a row
    is marked True when its distance is <= the k-th smallest distance. Because the comparison
    is inclusive, rows tied at the k-th distance are all marked True, so more than k rows may
    be selected.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    pt : Iterable[float]
        The point. Must have the same dimension as the number of feature columns.
    k : int
        k nearest neighbor. Must be >= 1.
    dist : str
        One of `l1`, `l2`, `sql2`, `inf`, `cosine`, `h` / `haversine`. Note `sql2` stands for
        squared l2. For `h` / `haversine`, `pt` must have dimension 2 and the first two feature
        columns are used as lat and long.

    Raises
    ------
    ValueError
        If `k` < 1, if the dimension of `pt` does not match the features, if the Haversine
        inputs are malformed, or if `dist` is not a known distance.
    """
    # Consistent with the other KNN queries in this module: a non-positive k is meaningless.
    # Only plain ints are checked so that any non-int value still reaches `bottom_k` untouched.
    if isinstance(k, int) and k < 1:
        raise ValueError("Input `k` must be >= 1.")

    # For a single point, it is faster to just do it in native polars
    oth = [to_expr(x) for x in features]
    # Materialize once: `pt` is only promised to be an Iterable, and it is needed both for
    # the dimension check and for building the distance expression.
    point = list(pt)
    if not warn_len_compare(point, oth):
        raise ValueError("Dimension does not match.")

    if dist == "l1":
        dist_out = pl.sum_horizontal(
            (e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(point, oth)
        )
    elif dist in ("l2", "sql2"):
        # The squared distance is used for both: only the ordering of distances matters here.
        dist_out = pl.sum_horizontal(
            (e - pl.lit(xi, dtype=pl.Float64)).pow(2) for xi, e in zip(point, oth)
        )
    elif dist == "inf":
        dist_out = pl.max_horizontal(
            (e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(point, oth)
        )
    elif dist == "cosine":
        x_norm = sum(z * z for z in point)
        oth_norm = pl.sum_horizontal(e * e for e in oth)
        dist_out = (
            1.0
            - pl.sum_horizontal(xi * e for xi, e in zip(point, oth)) / (x_norm * oth_norm).sqrt()
        )
    elif dist in ("h", "haversine"):
        from . import haversine

        if (len(point) != 2) or (len(oth) < 2):
            raise ValueError(
                "For Haversine distance, input x must have dimension 2 and 2 other columns"
                " must be provided as lat and long."
            )

        y_lat = pl.lit(point[0], dtype=pl.Float64)
        y_long = pl.lit(point[1], dtype=pl.Float64)
        dist_out = haversine(oth[0], oth[1], y_lat, y_long)
    else:
        raise ValueError(f"Unknown distance function: {dist}")

    # A row qualifies when it is no farther than the k-th closest row.
    return dist_out <= dist_out.bottom_k(k=k).max()

query_dist_from_kth_nb(*features, k, dist='sql2', parallel=False, epsilon=0.0, max_bound=99999.0)

Computes the distance of each row to its k-th closest neighbor. This is useful for outlier detection. E.g. if the average distance to the 5th neighbor is 0.1, then a distance of 0.3 to the 5th neighbor might indicate that the point might be far away from neighboring points, or that it occupies a sparse region in which sample points typically do not appear.

This can be 10% faster and more direct than getting the result from query_knn_ptwise with return_distance = True.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
k int

Number of neighbors to query

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
epsilon float

If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter should increase as the dimension of the vector space increases because higher dimensions allow for errors from more directions.

0.0
max_bound float

Max distance the neighbors must be within

99999.0
Source code in python/polars_ds/exprs/expr_knn.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def query_dist_from_kth_nb(
    *features: str | pl.Expr,
    k: int,
    dist: Distance = "sql2",
    parallel: bool = False,
    epsilon: float = 0.0,
    max_bound: float = 99999.0,
) -> pl.Expr:
    """
    For every row, compute the distance to its k-th closest neighbor.

    A typical use is outlier detection: if rows are on average 0.1 away from their 5th
    neighbor, a row that is 0.3 away from its 5th neighbor is probably isolated, or sits in a
    sparse region where sample points rarely appear.

    This can be 10% faster and more direct than getting the result from `query_knn_ptwise`
    with return_distance = True.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    k : int
        Number of neighbors to query
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    parallel : bool
        Whether to run the k-nearest neighbor query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context.
    epsilon
        If > 0, then it is possible to miss a neighbor within epsilon distance away. This
        parameter should increase as the dimension of the vector space increases because
        higher dimensions allow for errors from more directions.
    max_bound
        Max distance the neighbors must be within
    """
    feature_exprs = [to_expr(col) for col in features]
    # The metric name is lower-cased before it is handed to the plugin.
    plugin_kwargs = {
        "k": k,
        "metric": str(dist).lower(),
        "parallel": parallel,
        "skip_eval": False,
        "max_bound": max_bound,
        "epsilon": epsilon,
    }
    return pl_plugin(
        symbol="pl_dist_from_kth_nb",
        args=feature_exprs,
        kwargs=plugin_kwargs,
    )

query_knn_avg(*features, target, k, dist='sql2', weighted=False, parallel=False, min_bound=1e-09, max_bound=99999.0)

Takes the target column, and uses feature columns to determine the k nearest neighbors to each row. By default, this will return k + 1 neighbors, because the point (the row) itself is a neighbor to itself and this returns k additional neighbors. Any row with a null/NaN will never be a neighbor and will get null as the average.

Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find k neighbors within max_bound, then there will be < k neighbors returned.

This is also known as KNN Regression, but really it is just the average of the K nearest neighbors.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
target str | Expr

Float, must be castable to f64. This should not contain null.

required
k int

Number of neighbors to query

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
weighted bool

If weighted, it will use 1/distance as weights to compute the KNN average. If min_bound is an extremely small value, this will default to 1/(1+distance) as weights to avoid division by 0.

False
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
min_bound float

Min distance (>=) for a neighbor to be part of the average calculation. This prevents "identical" points from being part of the average and prevents division by 0. Note that this filter is applied after getting k nearest neighbors.

1e-09
max_bound float

Max distance the neighbors must be within (<)

99999.0
Source code in python/polars_ds/exprs/expr_knn.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
def query_knn_avg(
    *features: str | pl.Expr,
    target: str | pl.Expr,
    k: int,
    dist: Distance = "sql2",
    weighted: bool = False,
    parallel: bool = False,
    min_bound: float = 1e-9,
    max_bound: float = 99999.0,
) -> pl.Expr:
    """
    Average the target column over the k nearest neighbors of each row, where neighbors are
    determined by the feature columns. By default, this will return k + 1 neighbors, because
    the point (the row) itself is a neighbor to itself and this returns k additional
    neighbors. Any row with a null/NaN will never be a neighbor and will get null as the
    average.

    Note that a default max distance bound of 99999.0 is applied. This means that if we cannot
    find k neighbors within `max_bound`, then there will be < k neighbors returned.

    This is also known as KNN Regression, but really it is just the average of the K nearest
    neighbors.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    target : str | pl.Expr
        Float, must be castable to f64. This should not contain null.
    k : int
        Number of neighbors to query. Must be >= 1.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    weighted : bool
        If weighted, it will use 1/distance as weights to compute the KNN average. If
        min_bound is an extremely small value, this will default to 1/(1+distance) as weights
        to avoid division by 0.
    parallel : bool
        Whether to run the k-nearest neighbor query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context.
    min_bound
        Min distance (>=) for a neighbor to be part of the average calculation. This prevents
        "identical" points from being part of the average and prevents division by 0. Note
        that this filter is applied after getting k nearest neighbors.
    max_bound
        Max distance the neighbors must be within (<)

    Raises
    ------
    ValueError
        If `k` < 1, or if `dist` is `cosine`, `h` or `haversine`.
    """
    if k < 1:
        raise ValueError("Input `k` must be >= 1.")

    if dist in ("cosine", "h", "haversine"):
        raise ValueError(f"Distance {dist} doesn't work with current implementation.")

    feature_exprs = [to_expr(col) for col in features]
    target_expr = to_expr(target).cast(pl.Float64).rechunk()
    # True for rows in which no feature column is null.
    rows_to_keep = ~pl.any_horizontal(col.is_null() for col in feature_exprs)

    # Plugin input order: target, keep-mask, then the feature columns.
    return pl_plugin(
        symbol="pl_knn_avg",
        args=[target_expr, rows_to_keep, *feature_exprs],
        kwargs={
            "k": k,
            "metric": str(dist).lower(),
            "weighted": weighted,
            "parallel": parallel,
            "min_bound": min_bound,
            "max_bound": max_bound,
        },
    )

query_knn_freq_cnt(*features, index, k, dist='sql2', parallel=False, eval_mask=None, data_mask=None, epsilon=0.0, max_bound=99999.0)

Takes the index column, and uses feature columns to determine the k nearest neighbors to each row, and finally returns the number of times a row is a KNN of some other point.

This calls query_knn_ptwise internally. See the docstring of query_knn_ptwise for more info.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
index str | Expr

The column used as index, must be castable to u32

required
k int

Number of neighbors to query

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
return_dist

Not a parameter of this function. It is listed in the docstring by mistake; the internal call to query_knn_ptwise always uses return_dist=False.

required
eval_mask str | Expr | None

Either None or a boolean expression or the name of a boolean column. If not None, this will only evaluate KNN for rows where this is true. This can speed up computation when only results on a subset are needed.

None
data_mask str | Expr | None

Either None or a boolean expression or the name of a boolean column. If none, all rows can be neighbors. If not None, the pool of possible neighbors will be rows where this is true.

None
epsilon float

If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter should increase as the dimension of the vector space increases because higher dimensions allow for errors from more directions.

0.0
max_bound float

Max distance the neighbors must be within

99999.0
Source code in python/polars_ds/exprs/expr_knn.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def query_knn_freq_cnt(
    *features: str | pl.Expr,
    index: str | pl.Expr,
    k: int,
    dist: Distance = "sql2",
    parallel: bool = False,
    eval_mask: str | pl.Expr | None = None,
    data_mask: str | pl.Expr | None = None,
    epsilon: float = 0.0,
    max_bound: float = 99999.0,
) -> pl.Expr:
    """
    Count, for each index value, how many times that row shows up as one of the k nearest
    neighbors of some row, where neighbors are determined by the feature columns.

    This calls `query_knn_ptwise` internally. See the docstring of `query_knn_ptwise` for more
    info.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    index : str | pl.Expr
        The column used as index, must be castable to u32
    k : int
        Number of neighbors to query
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    parallel : bool
        Whether to run the k-nearest neighbor query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context. The same
        flag is also passed to the final `value_counts`.
    eval_mask
        Either None or a boolean expression or the name of a boolean column. If not None, this
        will only evaluate KNN for rows where this is true. This can speed up computation when
        only results on a subset are needed.
    data_mask
        Either None or a boolean expression or the name of a boolean column. If None, all rows
        can be neighbors. If not None, the pool of possible neighbors will be rows where this
        is true.
    epsilon
        If > 0, then it is possible to miss a neighbor within epsilon distance away. This
        parameter should increase as the dimension of the vector space increases because
        higher dimensions allow for errors from more directions.
    max_bound
        Max distance the neighbors must be within
    """
    # Distances are never requested: only the neighbor indices are counted.
    neighbors: pl.Expr = query_knn_ptwise(
        *features,
        index=index,
        k=k,
        dist=dist,
        parallel=parallel,
        return_dist=False,
        eval_mask=eval_mask,
        data_mask=data_mask,
        epsilon=epsilon,
        max_bound=max_bound,
    )
    # Flatten the per-row neighbor results and discard nulls before counting.
    flattened = neighbors.explode().drop_nulls()
    return flattened.value_counts(sort=True, parallel=parallel)

query_knn_ptwise(*features, index, k, dist='sql2', parallel=False, return_dist=False, eval_mask=None, data_mask=None, epsilon=0.0, max_bound=99999.0)

Takes the index column, and uses feature columns to determine the k nearest neighbors to each row. By default, this will return k + 1 neighbors, because the point (the row) itself is a neighbor to itself and this returns k additional neighbors. The only exception to this is when data_mask excludes the point from being a neighbor, in which case, k + 1 distinct neighbors will be returned. Any row with a null/NaN will never be a neighbor and will have null as its neighbor.

Note that the index column must be convertible to u32. If you do not have a u32 column, you can generate one using pl.int_range(..), which should be a step before this. The index column must not contain nulls.

Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find k neighbors within max_bound, then there will be < k neighbors returned.

Also note that this internally builds a kd-tree for fast querying and deallocates it once we are done. If you need to repeatedly run the same query on the same data, then it is not ideal to use this. A specialized external kd-tree structure would be better in that case.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
index str | Expr

The column used as index, must be castable to u32

required
k int

Number of neighbors to query

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
return_dist bool

If true, return a struct with indices and distances.

False
eval_mask str | Expr | None

Either None or a boolean expression or the name of a boolean column. If not None, this will only evaluate KNN for rows where this is true. This can speed up computation when only results on a subset are needed.

None
data_mask str | Expr | None

Either None or a boolean expression or the name of a boolean column. If none, all rows can be neighbors. If not None, the pool of possible neighbors will be rows where this is true.

None
epsilon float

If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter should increase as the dimension of the vector space increases because higher dimensions allow for errors from more directions.

0.0
max_bound float

Max distance the neighbors must be within

99999.0
Source code in python/polars_ds/exprs/expr_knn.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def query_knn_ptwise(
    *features: str | pl.Expr,
    index: str | pl.Expr,
    k: int,
    dist: Distance = "sql2",
    parallel: bool = False,
    return_dist: bool = False,
    eval_mask: str | pl.Expr | None = None,
    data_mask: str | pl.Expr | None = None,
    epsilon: float = 0.0,
    max_bound: float = 99999.0,
) -> pl.Expr:
    """
    Takes the index column, and uses feature columns to determine the k nearest neighbors
    to each row. By default, this will return k + 1 neighbors, because the point (the row) itself
    is a neighbor to itself and this returns k additional neighbors. The only exception to this
    is when data_mask excludes the point from being a neighbor, in which case, k + 1 distinct neighbors will
    be returned. Any row with a null/NaN will never be a neighbor and will have null as its neighbor.

    Note that the index column must be convertible to u32. If you do not have a u32 column,
    you can generate one using pl.int_range(..), which should be a step before this. The index column
    must not contain nulls.

    Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find
    k neighbors within `max_bound`, then there will be < k neighbors returned.

    Also note that this internally builds a kd-tree for fast querying and deallocates it once we
    are done. If you need to repeatedly run the same query on the same data, then it is not
    ideal to use this. A specialized external kd-tree structure would be better in that case.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    index : str | pl.Expr
        The column used as index, must be castable to u32
    k : int
        Number of neighbors to query. Must be >= 1.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    parallel : bool
        Whether to run the k-nearest neighbor query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context.
    return_dist
        If true, return a struct with indices and distances.
    eval_mask
        Either None or a boolean expression or the name of a boolean column. If not None, this will
        only evaluate KNN for rows where this is true. This can speed up computation when only results on a
        subset are needed.
    data_mask
        Either None or a boolean expression or the name of a boolean column. If None, all rows can be
        neighbors. If not None, the pool of possible neighbors will be rows where this is true.
    epsilon
        If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter
        should increase as the dimension of the vector space increases because higher dimensions
        allow for errors from more directions.
    max_bound
        Max distance the neighbors must be within

    Raises
    ------
    ValueError
        If `k` < 1, or if `dist` is `cosine`, `h` or `haversine`.
    """
    if k < 1:
        raise ValueError("Input `k` must be >= 1.")

    if dist in ("cosine", "h", "haversine"):
        raise ValueError(f"Distance {dist} doesn't work with current implementation.")

    idx = to_expr(index).cast(pl.UInt32).rechunk()
    cols = [idx]
    feats: List[pl.Expr] = [to_expr(e) for e in features]

    # A row may serve as a neighbor only if all of its features are non-null and, when a
    # data_mask is given, the mask is true for that row (true means keep).
    if data_mask is not None:
        keep_mask = pl.all_horizontal(to_expr(data_mask), *(f.is_not_null() for f in feats))
    else:
        keep_mask = pl.all_horizontal(f.is_not_null() for f in feats)

    cols.append(keep_mask)
    # The eval mask is an optional extra input; `skip_eval` tells the plugin whether it was sent.
    skip_eval = eval_mask is not None
    if skip_eval:
        cols.append(to_expr(eval_mask))

    cols.extend(feats)
    kwargs = {
        "k": k,
        "metric": str(dist).lower(),
        "parallel": parallel,
        "skip_eval": skip_eval,
        "max_bound": max_bound,
        # Forward the caller's value. It was previously hard-coded to 0.0, which silently
        # ignored the `epsilon` argument.
        "epsilon": epsilon,
    }
    symbol = "pl_knn_ptwise_w_dist" if return_dist else "pl_knn_ptwise"
    return pl_plugin(
        symbol=symbol,
        args=cols,
        kwargs=kwargs,
    )

query_nb_cnt(*features, r, dist='sql2', parallel=False)

Return the number of neighbors within (<=) radius r for each row under the given distance metric. The point itself is always a neighbor of itself.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
r float | Iterable[float] | Expr | str

If this is a scalar, then it will run the query with fixed radius for all rows. If this is a list, then it must have the same height as the dataframe. If this is an expression, it must be an expression representing radius. If this is a str, it must be the name of a column

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
parallel bool

Whether to run the distance query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
Source code in python/polars_ds/exprs/expr_knn.py
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
def query_nb_cnt(
    *features: str | pl.Expr,
    r: float | str | pl.Expr | Iterable[float],
    dist: Distance = "sql2",
    parallel: bool = False,
) -> pl.Expr:
    """
    Return the number of neighbors within (<=) radius r for each row under the given distance
    metric. The point itself is always a neighbor of itself.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    r : float | Iterable[float] | pl.Expr | str
        If this is a scalar, then it will run the query with fixed radius for all rows. If
        this is a list, then it must have the same height as the dataframe. If
        this is an expression, it must be an expression representing radius. If this is a str,
        it must be the name of a column
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2. The name is lower-cased before use.
    parallel : bool
        Whether to run the distance query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context.

    Raises
    ------
    ValueError
        If `dist` is `cosine`, `h` or `haversine`.
    """
    # Normalize the metric name the same way the other KNN queries in this module do, so
    # that the plugin always receives a lower-case name and the unsupported-metric guard
    # also catches differently-cased spellings.
    metric = str(dist).lower()
    if metric in ("cosine", "h", "haversine"):
        raise ValueError(f"Distance `{dist}` doesn't work with current implementation.")

    if isinstance(r, (float, int)):
        # Scalar radius: a length-1 Float64 series literal.
        rad = pl.lit(pl.Series(values=[r], dtype=pl.Float64))
    elif isinstance(r, pl.Expr):
        rad = r
    elif isinstance(r, str):
        rad = pl.col(r)
    else:
        # Any other iterable is treated as one radius per row.
        rad = pl.lit(pl.Series(values=r, dtype=pl.Float64))

    return pl_plugin(
        symbol="pl_nb_cnt",
        args=[rad] + [to_expr(x) for x in features],
        kwargs={
            "k": 0,
            "metric": metric,
            "parallel": parallel,
            "skip_eval": False,
            "skip_data": False,
        },
    )

query_radius_freq_cnt(*features, index, r, dist='sql2', parallel=False)

Takes the index column, and uses features columns to determine distance, finds all neighbors within distance r from each index, and finally finds the count of the number of times the point is within distance r from other points.

This calls query_radius_ptwise internally. See the docstring of query_radius_ptwise for more info.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
index str | Expr

The column used as index, must be castable to u32

required
r float

The radius. Must be a scalar value now.

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
Source code in python/polars_ds/exprs/expr_knn.py
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
def query_radius_freq_cnt(
    *features: str | pl.Expr,
    index: str | pl.Expr,
    r: float,
    dist: Distance = "sql2",
    parallel: bool = False,
) -> pl.Expr:
    """
    Count, for each index value, how many times that point lies within distance r of some
    point, where distance is determined by the feature columns.

    This calls `query_radius_ptwise` internally. See the docstring of `query_radius_ptwise`
    for more info.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    index : str | pl.Expr
        The column used as index, must be castable to u32
    r : float
        The radius. Must be a scalar value now.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    parallel : bool
        Whether to run the k-nearest neighbor query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context. The same
        flag is also passed to the final `value_counts`.
    """
    # Sorting neighbors by distance is switched off: only the counts are used.
    neighbors_in_range = query_radius_ptwise(
        *features,
        index=index,
        r=r,
        dist=dist,
        sort=False,
        parallel=parallel,
    )
    flattened = neighbors_in_range.explode().drop_nulls()
    return flattened.value_counts(sort=True, parallel=parallel)

query_radius_ptwise(*features, index, r, dist='sql2', sort=True, parallel=False)

Takes the index column, and uses features columns to determine distance, and finds all neighbors within distance r from each id. If you only care about neighbor count, you should use query_nb_cnt, which supports expression for radius and is way faster.

Note that the index column must be convertible to u32. If you do not have a u32 ID column, you can generate one using pl.int_range(..), which should be a step before this.

Also note that this internally builds a kd-tree for fast querying and deallocates it once we are done. If you need to repeatedly run the same query on the same data, then it is not ideal to use this. A specialized external kd-tree structure would be better in that case.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
index str | Expr

The column used as index, must be castable to u32

required
r float

The radius. Must be a scalar value now.

required
dist Literal[`l1`, `l2`, `sql2`, `inf`]

Note sql2 stands for squared l2.

'sql2'
sort bool

Whether the neighbors returned should be sorted by the distance. Setting this to False can improve performance by 10-20%.

True
parallel bool

Whether to run the k-nearest neighbor query in parallel. This is recommended when you are running only this expression, and not in group_by() or over() context.

False
Source code in python/polars_ds/exprs/expr_knn.py
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
def query_radius_ptwise(
    *features: str | pl.Expr,
    index: str | pl.Expr,
    r: float,
    dist: Distance = "sql2",
    sort: bool = True,
    parallel: bool = False,
) -> pl.Expr:
    """
    Takes the index column, uses the feature columns to determine distance, and finds all
    neighbors within distance r from each id. If you only care about neighbor count, you should
    use `query_nb_cnt`, which supports expression for radius and is way faster.

    Note that the index column must be convertible to u32. If you do not have a u32 ID column,
    you can generate one using pl.int_range(..), which should be a step before this.

    Also note that this internally builds a kd-tree for fast querying and deallocates it once we
    are done. If you need to repeatedly run the same query on the same data, then it is not
    ideal to use this. A specialized external kd-tree structure would be better in that case.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    index : str | pl.Expr
        The column used as index, must be castable to u32
    r : float
        The radius. Must be a positive scalar value for now.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    sort : bool
        Whether the neighbors returned should be sorted by the distance. Setting this to False can
        improve performance by 10-20%.
    parallel : bool
        Whether to run the radius query in parallel. This is recommended when you
        are running only this expression, and not in group_by() or over() context.

    Raises
    ------
    ValueError
        If `r` is an expression, if `r` is not > 0, or if `dist` is cosine or haversine.
    """

    # The type check has to come first: comparing an expression with a float produces another
    # expression, and using that as an `if` condition fails before the intended error is raised.
    if isinstance(r, pl.Expr):
        raise ValueError("Input `r` must be a scalar now. Expression input is not implemented.")
    if r <= 0.0:
        raise ValueError("Input `r` must be > 0.")

    if dist in ("cosine", "h", "haversine"):
        raise ValueError(f"Distance {dist} doesn't work with current implementation.")

    idx = to_expr(index).cast(pl.UInt32).rechunk()
    metric = str(dist).lower()
    # The index goes first, followed by the feature columns in the order given.
    cols = [idx]
    cols.extend(to_expr(x) for x in features)
    return pl_plugin(
        symbol="pl_query_radius_ptwise",
        args=cols,
        kwargs={"r": r, "metric": metric, "parallel": parallel, "sort": sort},
    )

query_radius_ptwise_null_safe(*features, index, r, dist='sql2', sort=True, parallel=False)

Null-safe variant of query_radius_ptwise. Rows where any feature column is null are excluded from the kd-tree (they cannot be neighbors) and return null in the output list.

The non-null-safe query_radius_ptwise panics on null feature inputs because the kd-tree builder reads features unchecked; use this variant when nulls cannot be ruled out upstream.

Parameters mirror query_radius_ptwise.

Source code in python/polars_ds/exprs/expr_knn.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
def query_radius_ptwise_null_safe(
    *features: str | pl.Expr,
    index: str | pl.Expr,
    r: float,
    dist: Distance = "sql2",
    sort: bool = True,
    parallel: bool = False,
) -> pl.Expr:
    """
    Null-safe variant of `query_radius_ptwise`. Rows where any feature column is null are
    excluded from the kd-tree (they cannot be neighbors) and return null in the output list.

    The non-null-safe `query_radius_ptwise` panics on null feature inputs because the kd-tree
    builder reads features unchecked; use this variant when nulls cannot be ruled out upstream.

    Parameters
    ----------
    *features : str | pl.Expr
        Columns used as features
    index : str | pl.Expr
        The column used as index, it is cast to u32
    r : float
        The radius. Must be a positive scalar value for now.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`]
        Note `sql2` stands for squared l2.
    sort : bool
        Passed to the plugin as `sort`.
    parallel : bool
        Passed to the plugin as `parallel`.

    Raises
    ------
    ValueError
        If `r` is an expression, if `r` is not > 0, or if `dist` is cosine or haversine.
    """
    # The type check has to come first: comparing an expression with a float produces another
    # expression, and using that as an `if` condition fails before the intended error is raised.
    if isinstance(r, pl.Expr):
        raise ValueError("Input `r` must be a scalar now. Expression input is not implemented.")
    if r <= 0.0:
        raise ValueError("Input `r` must be > 0.")

    if dist in ("cosine", "h", "haversine"):
        raise ValueError(f"Distance {dist} doesn't work with current implementation.")

    idx = to_expr(index).cast(pl.UInt32).rechunk()
    feats = [to_expr(e) for e in features]
    # True only for rows in which every feature is non-null.
    keep_mask = pl.all_horizontal(f.is_not_null() for f in feats)

    # Plugin argument order: index, keep mask, then the feature columns.
    cols = [idx, keep_mask]
    cols.extend(feats)
    return pl_plugin(
        symbol="pl_query_radius_ptwise_null_safe",
        args=cols,
        kwargs={"r": r, "metric": str(dist).lower(), "parallel": parallel, "sort": sort},
    )

warn_len_compare(item1, item2)

Compares the lengths of two iterables. If both have a length, returns whether the lengths are equal; otherwise emits a warning and returns True.

Parameters:

Name Type Description Default
item1 Iterable[Any]

Any iterable

required
item2 Iterable[Any]

Any iterable

required
Returns

bool: If both items have len, returns whether or not they have equal size. If either of them has no len, returns True and emits a warning.

required
Source code in python/polars_ds/exprs/expr_knn.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def warn_len_compare(item1: Iterable[Any], item2: Iterable[Any]) -> bool:
    """
    Checks whether two iterables have the same length, when that can be determined.

    Parameters
    ----------
    item1 : Iterable[Any]
        Any iterable
    item2 : Iterable[Any]
        Any iterable

    Returns
    -------
    bool
        When both inputs define `__len__`, whether their lengths are equal. Otherwise a
        warning is emitted and True is returned, because the sizes cannot be compared.
    """
    both_sized = all(hasattr(obj, "__len__") for obj in (item1, item2))
    if not both_sized:
        # At least one input is unsized (e.g. a generator): warn and let the caller proceed.
        warnings.warn(
            "The inputs do not each have len so can't be compared, unexpected results may follow.",
            stacklevel=2,
        )
        return True

    size1 = len(cast(Sequence, item1))
    size2 = len(cast(Sequence, item2))
    return size1 == size2

within_dist_from(*features, pt, r, dist='sql2')

Returns a boolean column that returns points that are within radius from the given point.

Parameters:

Name Type Description Default
*features str | Expr

Other columns used as features

()
pt Iterable[float]

The point

required
r either a float or an expression

The radius to query with. If this is an expression, the radius will be applied row-wise.

required
dist Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`, `haversine`]

Note sql2 stands for squared l2.

'sql2'
Source code in python/polars_ds/exprs/expr_knn.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
def within_dist_from(
    *features: str | pl.Expr,
    pt: Sequence[float] | Iterable[float],
    r: float | pl.Expr,
    dist: Distance = "sql2",
) -> pl.Expr:
    """
    Returns a boolean column that returns points that are within radius from the given point.

    Parameters
    ----------
    *features : str | pl.Expr
        Other columns used as features
    pt : Iterable[float]
        The point
    r : either a float or an expression
        The radius to query with. If this is an expression, the radius will be applied row-wise.
        The radius is compared against the distance as defined by `dist`, so for `sql2` it is
        a bound on the squared l2 distance, while for `l2` it is a bound on the l2 distance.
    dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`, `haversine`]
        Note `sql2` stands for squared l2.

    Raises
    ------
    ValueError
        If the dimension of `pt` does not match the number of features, if the haversine
        inputs do not have the required dimensions, or if `dist` is not a known distance.
    """
    # For a single point, it is faster to just do it in native polars
    oth = [to_expr(x) for x in features]
    if not warn_len_compare(pt, oth):
        raise ValueError("Dimension does not match.")

    if dist == "l1":
        return (
            pl.sum_horizontal((e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth))
            <= r
        )
    elif dist in ("l2", "sql2"):
        sq_dist = pl.sum_horizontal(
            (e - pl.lit(xi, dtype=pl.Float64)).pow(2) for xi, e in zip(pt, oth)
        )
        if dist == "l2":
            # l2 is the square root of the summed squares. Comparing the squared sum
            # against r directly would make `l2` behave exactly like `sql2`.
            return sq_dist.sqrt() <= r
        return sq_dist <= r
    elif dist == "inf":
        return (
            pl.max_horizontal((e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth))
            <= r
        )
    elif dist == "cosine":
        # Cosine distance = 1 - <x, y> / sqrt(|x|^2 * |y|^2)
        x_list = list(pt)
        x_norm = sum(z * z for z in x_list)
        oth_norm = pl.sum_horizontal(e * e for e in oth)
        distN = (
            1.0
            - pl.sum_horizontal(xi * e for xi, e in zip(x_list, oth)) / (x_norm * oth_norm).sqrt()
        )
        return distN <= r
    elif dist in ("h", "haversine"):
        from . import haversine

        pt_as_list = list(pt)
        if (len(pt_as_list) != 2) or (len(oth) < 2):
            raise ValueError(
                "For Haversine distance, input x must have dimension 2 and 2 other columns"
                " must be provided as lat and long."
            )

        # The first two feature columns are used as lat and long, respectively.
        y_lat = pl.lit(pt_as_list[0], dtype=pl.Float64)
        y_long = pl.lit(pt_as_list[1], dtype=pl.Float64)
        dist_out = haversine(oth[0], oth[1], y_lat, y_long)
        return dist_out <= r
    else:
        raise ValueError(f"Unknown distance function: {dist}")