
Explorative Data Analysis


Modules:

diagnosis: Data Inspection Assistant and visualizations for Polars DataFrames.
plots

diagnosis

Data Inspection Assistant and visualizations for Polars DataFrames.

Currently, the plotting backend is Altair, but this is subject to change: the backend will be chosen based on which plotting library supports Polars most natively.

Classes:

DIA: Data Inspection Assistant. Most plots are powered by Altair/great_tables. Altair may require additional package downloads.

DIA

Data Inspection Assistant. Most plots are powered by Altair/great_tables. Altair may require additional package downloads.

If you cannot import this module, please try: pip install "polars_ds[plot]"

Note: most plots are sampled by default because (1) plots with too many points are usually hard to read, and (2) interactive plots built from unsampled data become too large to render in a reasonable amount of time. If rendering speed is crucial and you don't need interactivity, use matplotlib.
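The page does not document how DIA draws its sample. As an illustration only, the following pure-Python sketch caps the number of plotted points with uniform sampling without replacement. The helper name sample_for_plot, the default cap of 5000, and the fixed seed are hypothetical choices, not polars_ds parameters.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_for_plot(rows: Sequence[T], max_points: int = 5000, seed: int = 42) -> List[T]:
    """Hypothetical helper (not part of polars_ds): return at most `max_points` rows,
    drawn uniformly without replacement, so an interactive chart stays small enough to render."""
    if max_points < 1:
        raise ValueError("max_points must be >= 1")
    if len(rows) <= max_points:
        return list(rows)  # small inputs are plotted in full
    rng = random.Random(seed)  # a fixed seed gives the same sample (and plot) on every run
    # Sorting the sampled indices keeps the rows in their original order.
    idx = sorted(rng.sample(range(len(rows)), max_points))
    return [rows[i] for i in idx]


points = sample_for_plot(list(range(100_000)), max_points=1000)
print(len(points))  # 1000
```

Keeping the original row order matters for time-ordered data. With Polars data, DataFrame.sample(n=..., seed=...) draws a comparable uniform sample without a custom helper.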

Methods:

col_validation: Generates a validation report based on rules (pl.Expr) that each evaluate to a single boolean per column.
corr: Returns a DataFrame containing correlation information between the subset and all numeric columns.
infer_binary: Infers whether each column is binary.
infer_const: Infers whether each column is constant.
infer_corr: Tries to infer highly correlated columns by computing the correlation between all numerical (including boolean) columns.
infer_dependency: Infers (functional) dependencies using conditional entropy. Only potentially qualifying columns are evaluated.
infer_discrete: Infers discrete columns based on the fraction of unique values and max_val_cnt.
infer_high_null: Infers columns in which the fraction of null values is at least threshold.
infer_k_distinct: Infers whether each column has k distinct values.
infer_prob: Infers columns that can potentially be probabilities. For f32/f64 columns, this checks whether all values are between 0 and 1; for List[f32] or List[f64] columns, it checks whether the column can potentially hold multi-class probabilities.
meta: Returns the internal data of this class as a dictionary.
null_corr: Computes the correlation between "A is null" and "B is null" for all (A, B) combinations in the given subset of columns.
numeric_profile: Creates a numerical profile with a histogram plot. Note that the histograms may have completely different scales on the x-axis.
plot_corr: Plots the correlations as a classic heat map.
plot_dependency: Plots dependency using the result of self.infer_dependency and positively determines
plot_feature_distr: Plots the distribution of the feature with a few statistical details.
row_validation: Generates a validation report based on rules (pl.Expr) that evaluate to one boolean per row.
special_values_report: Checks null, NaN, and non-finite values for float columns. Note that for integers, only null_count can possibly be non-zero.
str_stats: Returns basic statistics about the string columns.
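For intuition about infer_dependency, here is a minimal pure-Python sketch of the conditional entropy H(column | by). It is not the query_cond_entropy implementation used in the source; the function name cond_entropy and the use of natural logarithms (nats) are assumptions for illustration. The result is 0 exactly when every value of by maps to a single value of column, i.e. when column is a function of by.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def cond_entropy(column: Sequence[Hashable], by: Sequence[Hashable]) -> float:
    """Hypothetical helper (not the polars_ds API): H(column | by) in nats.

    Returns 0.0 when `column` is fully determined by `by`.
    """
    if len(column) != len(by):
        raise ValueError("column and by must have the same length")
    n = len(column)
    if n == 0:
        return 0.0
    joint = Counter(zip(by, column))  # counts of each (by, column) pair
    by_counts = Counter(by)  # marginal counts of each `by` value
    h = 0.0
    for (b, _), pair_count in joint.items():
        # -p(b, x) * log(p(x | b)), with p(b, x) = pair_count / n
        # and p(x | b) = pair_count / count(b)
        h -= (pair_count / n) * math.log(pair_count / by_counts[b])
    return h


zip_code = ["a", "a", "b", "b", "c"]
city = ["X", "X", "Y", "Y", "Y"]
print(cond_entropy(city, zip_code))  # prints 0.0: city is a function of zip_code
print(cond_entropy(zip_code, city))  # positive: zip_code is not a function of city
```

In the source listing, infer_dependency sorts the candidate columns by their number of unique values before pairing them with combinations(...), so the column in by never has fewer unique values than the column in column.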

Source code in python/polars_ds/eda/diagnosis.py
class DIA:
    """
    Data Inspection Assistant. Most plots are powered by Altair/great_tables. Altair may require
    additional package downloads.

    If you cannot import this module, please try: pip install "polars_ds[plot]"

    Note: most plots are sampled by default because (1) plots with too many points are usually
    hard to read, and (2) interactive plots built from unsampled data become too large to render
    in a reasonable amount of time. If rendering speed is crucial and you don't need
    interactivity, use matplotlib.
    """

    # --- Static / Class Methods ---

    # --- Methods ---

    def __init__(self, df: PolarsFrame):
        self._frame: pl.LazyFrame = df.lazy()
        self.numerics: List[str] = df.select(cs.numeric()).collect_schema().names()
        self.ints: List[str] = df.select(cs.integer()).collect_schema().names()
        self.floats: List[str] = df.select(cs.float()).collect_schema().names()
        self.strs: List[str] = df.select(cs.string()).collect_schema().names()
        self.bools: List[str] = df.select(cs.boolean()).collect_schema().names()
        self.cats: List[str] = df.select(cs.categorical()).collect_schema().names()

        schema_dict = df.collect_schema()
        columns = schema_dict.names()

        self.list_floats: List[str] = [
            c
            for c, t in schema_dict.items()
            if (t.is_(pl.List(pl.Float32)) or (t.is_(pl.List(pl.Float64))))
        ]
        self.list_bool: List[str] = [
            c for c, t in schema_dict.items() if t.is_(pl.List(pl.Boolean))
        ]
        self.list_str: List[str] = [c for c, t in schema_dict.items() if t.is_(pl.List(pl.String))]
        self.list_ints: List[str] = [
            c
            for c, t in schema_dict.items()
            if t.is_(pl.List(pl.UInt8))
            or t.is_(pl.List(pl.UInt16))
            or t.is_(pl.List(pl.UInt32))
            or t.is_(pl.List(pl.UInt64))
            or t.is_(pl.List(pl.Int8))
            or t.is_(pl.List(pl.Int16))
            or t.is_(pl.List(pl.Int32))
            or t.is_(pl.List(pl.Int64))
        ]

        self.simple_types: List[str] = (
            self.numerics
            + self.strs
            + self.bools
            + self.cats
            + self.list_floats
            + self.list_ints
            + self.list_bool
            + self.list_str
        )
        self.other_types: List[str] = [c for c in columns if c not in self.simple_types]

    def special_values_report(self) -> pl.DataFrame:
        """
        Checks null, NaN, and non-finite values for float columns. Note that for integers, only null_count
        can possibly be non-zero.
        """
        to_check = self.numerics
        frames = [
            self._frame.select(
                pl.lit(c, dtype=pl.String).alias("column"),
                pl.col(c).null_count().alias("null_count"),
                (pl.col(c).null_count() / pl.len()).alias("null%"),
                pl.col(c).is_nan().sum().alias("NaN_count"),
                (pl.col(c).is_nan().sum() / pl.len()).alias("NaN%"),
                pl.col(c).is_infinite().sum().alias("inf_count"),
                (pl.col(c).is_infinite().sum() / pl.len()).alias("Inf%"),
            )
            for c in to_check
        ]
        return pl.concat(pl.collect_all(frames))

    def numeric_profile(
        self, n_bins: int = 20, iqr_multiplier: float = 1.5, histogram: bool = True, gt: bool = True
    ) -> GT | pl.DataFrame:
        """
        Creates a numerical profile with a histogram plot. Notice that the histograms may have
        completely different scales on the x-axis.

        Parameters
        ----------
        n_bins
            Number of bins in the histogram.
        iqr_multiplier
            Interquartile range (IQR) multiplier. The IQR is the range between Q1 and Q3. Values
            below Q1 - iqr_multiplier * IQR or above Q3 + iqr_multiplier * IQR are counted as
            outliers.
        histogram
            Whether to include a histogram column.
        gt
            If true, return the profile as a formatted great_tables GT object; otherwise, return
            a Polars DataFrame.
        """
        to_check = self.numerics

        cuts = [i / n_bins for i in range(n_bins)]
        cuts[0] -= 1e-5
        cuts[-1] += 1e-5

        if histogram:
            columns_needed = [
                [
                    pl.lit(c, dtype=pl.String).alias("column"),
                    pl.col(c).count().alias("non_null_cnt"),
                    (pl.col(c).null_count() / pl.len()).alias("null%"),
                    pl.col(c).mean().cast(pl.Float64).alias("mean"),
                    pl.col(c).std().cast(pl.Float64).alias("std"),
                    pl.col(c).min().cast(pl.Float64).alias("min"),
                    pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
                    pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
                    pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
                    pl.col(c).max().cast(pl.Float64).alias("max"),
                    (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
                    .cast(pl.Float64)
                    .alias("IQR"),
                    pl.any_horizontal(
                        pl.col(c)
                        < pl.col(c).quantile(0.25)
                        - iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                        pl.col(c)
                        > pl.col(c).quantile(0.75)
                        + iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                    )
                    .sum()
                    .alias("outlier_cnt"),
                    pl.struct(
                        ((pl.col(c) - pl.col(c).min()) / (pl.col(c).max() - pl.col(c).min()))
                        .filter(pl.col(c).is_finite())
                        .cut(breaks=cuts, left_closed=True, include_breaks=True)
                        .struct.rename_fields(["brk", "category"])
                        .struct.field("brk")
                        .value_counts()
                        .sort()
                        .struct.field("count")
                        .implode()
                    ).alias("histogram"),
                ]
                for c in to_check
            ]
        else:
            columns_needed = [
                [
                    pl.lit(c, dtype=pl.String).alias("column"),
                    pl.col(c).count().alias("non_null_cnt"),
                    (pl.col(c).null_count() / pl.len()).alias("null%"),
                    pl.col(c).mean().cast(pl.Float64).alias("mean"),
                    pl.col(c).std().cast(pl.Float64).alias("std"),
                    pl.col(c).min().cast(pl.Float64).alias("min"),
                    pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
                    pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
                    pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
                    pl.col(c).max().cast(pl.Float64).alias("max"),
                    (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
                    .cast(pl.Float64)
                    .alias("IQR"),
                    pl.any_horizontal(
                        pl.col(c)
                        < pl.col(c).quantile(0.25)
                        - iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                        pl.col(c)
                        > pl.col(c).quantile(0.75)
                        + iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                    )
                    .sum()
                    .alias("outlier_cnt"),
                ]
                for c in to_check
            ]

        frames = [self._frame.select(*cols) for cols in columns_needed]
        df_final = pl.concat(pl.collect_all(frames))

        if gt:
            gt_out = (
                GT(df_final, rowname_col="column")
                .tab_stubhead("column")
                .fmt_percent(columns="null%")
                .fmt_number(
                    columns=["mean", "std", "min", "q1", "median", "q3", "max", "IQR"], decimals=3
                )
            )
            if histogram:
                return gt_out.fmt_nanoplot(columns="histogram", plot_type="bar")
            return gt_out
        else:
            return df_final

    def col_validation(
        self,
        *rules: Tuple[pl.Expr, str],
    ) -> pl.DataFrame:
        """
        Generates a validation report based on rules (pl.Expr) that each evaluate to a single
        boolean per column.

        Parameters
        ----------
        rules
            Tuples of (pl.Expr, str), where the pl.Expr should evaluate to a single boolean value.
            If the boolean is False, the entire column is considered to violate the rule, and the
            str is reported as the reason.
        """
        rules_to_check = list(rules)
        rules_exprs = [r.name.keep() for r, _ in rules_to_check]
        violation_messages = [msg for _, msg in rules_to_check]

        df_temp = self._frame.select(*rules_exprs).collect()

        if len(df_temp) > 1:
            raise ValueError(
                "Column rules must evaluate to a single boolean expression for each rule. "
                f"But a dataframe of shape {df_temp.shape} is produced."
            )

        df_violation = df_temp.transpose(include_header=True, column_names=["pass"]).with_columns(
            pl.Series(name="__reason__", values=violation_messages)
        )

        return df_violation.filter(pl.col("pass").not_()).select("column", "__reason__")

    def row_validation(
        self,
        *rules: Tuple[pl.Expr, str],
        id_col: str | None = None,
        columns_to_keep: List[str] | None = None,
        all_reasons: bool = False,
    ) -> pl.DataFrame:
        """
        Generates a validation report based on rules (pl.Expr) that evaluate to one boolean
        per row.

        Parameters
        ----------
        rules
            A tuple of (pl.Expr, str), where the pl.Expr should evaluate to a boolean value
            per row. If the boolean is False, then the row is considered a violation. The string
            should be an explanation of the violation.
        id_col
            If None, an "__index__" column will be generated which is the row number.
        columns_to_keep
            Other columns you wish to keep in the final report.
        all_reasons
            If true, the names of all violated rules are returned for each row as a sorted list.
            If false, only the first violated rule (in the order the rules are given) is returned.
        """

        if id_col is None:
            df = self._frame.with_row_index(name="__index__")
            to_keep = ["__index__"]
        else:
            df = self._frame
            to_keep = [id_col]

        rules_to_check = list(rules)
        rules_exprs = [r.alias(n) for r, n in rules_to_check]
        all_rule_names = [n for _, n in rules_to_check]
        # Do not allow duplicate rule names
        existing_names = set()
        for name in all_rule_names:
            if name not in existing_names:
                existing_names.add(name)
            else:
                raise ValueError(f"Rule name {name} is duplicate. Please rename it.")

        # We cannot use list(set(..)) because that might change the order of all_rule_names

        if columns_to_keep is not None:
            to_keep += columns_to_keep

        df_temp = df.select(*to_keep, *rules_exprs).filter(
            # Keep only the violators:
            # pl.all_horizontal(*all_rule_names) = rows that pass all rules
            # pl.all_horizontal(*all_rule_names).not_() = rows that fail at least one rule
            pl.all_horizontal(*all_rule_names).not_()
        )

        if all_reasons:
            reasons = [
                pl.when(pl.col(c)).then(None).otherwise(pl.lit(c, dtype=pl.String))
                for c in all_rule_names
            ]  # When true, return None. When false, return reason

            return df_temp.select(
                *to_keep, pl.concat_list(reasons).list.drop_nulls().list.sort().alias("__reason__")
            ).collect()
        else:
            # df_temp only holds rows that failed at least one rule, so every concatenated list
            # contains at least one False; list.arg_min() returns the index of the first False.
            return df_temp.select(
                *to_keep,
                pl.concat_list(all_rule_names)
                .list.arg_min()
                .replace_strict(old=list(range(len(all_rule_names))), new=all_rule_names)
                .alias("__reason__"),
            ).collect()

    def null_corr(
        self,
        subset: IntoExpr | Iterable[IntoExpr] = pl.all(),
        filter_by: pl.Expr | None = None,
    ) -> pl.DataFrame:
        """
        Computes the correlation between the null indicators A.is_null() and B.is_null() for all
        (A, B) combinations in the given subset of columns.

        If either A or B is entirely null or entirely non-null, its null indicator is constant, so
        the null correlation is not meaningful and will not be computed.

        Parameters
        ----------
        subset
            Anything that can be put into a Polars .select statement. Defaults to pl.all().
        filter_by
            An optional boolean expression used to filter rows before the null correlations are
            computed.
        """

        cols = self._frame.select(subset).collect_schema().names()

        if filter_by is None:
            frame = self._frame.select(pl.col(cols).is_null()).collect()
        else:
            frame = self._frame.filter(filter_by).select(pl.col(cols).is_null()).collect()

        df_null_cnt = frame.sum()
        n = frame.shape[0]

        invalid = set(
            c for c, cnt in zip(df_null_cnt.columns, df_null_cnt.row(0)) if (cnt == 0 or cnt == n)
        )

        xx = []
        yy = []
        for x, y in combinations(cols, 2):
            if not (x in invalid or y in invalid):
                xx.append(x)
                yy.append(y)

        if len(xx) == 0:
            return pl.DataFrame(
                {"column_1": [], "column_2": [], "null_corr": []},
                schema={
                    "column_1": pl.String,
                    "column_2": pl.String,
                    "null_corr": pl.Float64,
                },
            )
        else:
            corrs = frame.select(
                pl.corr(x, y).alias(str(i)) for i, (x, y) in enumerate(zip(xx, yy))
            ).row(0)
            return pl.DataFrame({"column_1": xx, "column_2": yy, "null_corr": corrs}).sort(
                pl.col("null_corr").abs(), descending=True
            )

    def meta(self) -> Dict:
        """
        Returns the internal data of this class, excluding the underlying frame, as a dictionary.
        """
        out = self.__dict__.copy()
        out.pop("_frame")
        return out

    def str_stats(self) -> pl.DataFrame:
        """
        Returns basic statistics about the string columns.
        """
        to_check = self.strs
        frames = [
            self._frame.select(
                pl.lit(c).alias("column"),
                pl.col(c).null_count().alias("null_count"),
                pl.col(c).n_unique().alias("n_unique"),
                pl.col(c).value_counts(sort=True).first().struct.field(c).alias("most_freq"),
                pl.col(c)
                .value_counts(sort=True)
                .first()
                .struct.field("count")
                .alias("most_freq_cnt"),
                pl.col(c).str.len_bytes().min().alias("min_byte_len"),
                pl.col(c).str.len_chars().min().alias("min_char_len"),
                pl.col(c).str.len_bytes().mean().alias("avg_byte_len"),
                pl.col(c).str.len_chars().mean().alias("avg_char_len"),
                pl.col(c).str.len_bytes().max().alias("max_byte_len"),
                pl.col(c).str.len_chars().max().alias("max_char_len"),
                pl.col(c).str.len_bytes().quantile(0.05).alias("5p_byte_len"),
                pl.col(c).str.len_bytes().quantile(0.95).alias("95p_byte_len"),
            )
            for c in to_check
        ]
        return pl.concat(pl.collect_all(frames))

    def corr(
        self, subset: IntoExpr | Iterable[IntoExpr], method: CorrMethod = "pearson"
    ) -> pl.DataFrame:
        """
        Returns a dataframe containing correlation information between the subset and all numeric columns.
        Only numerical columns will be checked.

        Parameters
        ----------
        subset
            Anything that can be put into a Polars .select statement.
        method
            One of ["pearson", "spearman", "xi", "kendall", "bicor"]
        """

        to_check = self._frame.select(subset).collect_schema().names()

        corrs = [
            self._frame.select(
                # This calls corr from .stats
                pl.lit(x).alias("column"),
                *(corr(x, y, method=method).alias(y) for y in self.numerics),
            )
            for x in to_check
        ]

        return pl.concat(pl.collect_all(corrs))

    def plot_corr(
        self, subset: IntoExpr | Iterable[IntoExpr], method: CorrMethod = "pearson"
    ) -> GT:
        """
        Plots the correlations using classic heat maps.

        Parameters
        ----------
        subset
            Anything that can be put into a Polars .select statement.
        method
            One of ["pearson", "spearman", "xi", "kendall", "bicor"]
        """
        corr_values = self.corr(subset, method)
        cols = [c for c in corr_values.columns if c != "column"]
        return (
            GT(corr_values)
            .fmt_number(columns=cols, decimals=3)
            .data_color(
                columns=cols,
                palette=["#0202bd", "#bd0237"],
                domain=[-1, 1],
                alpha=0.5,
                na_color="#000000",
            )
        )

    def infer_prob(self) -> List[str]:
        """
        Infers columns that can potentially be probabilities. For f32/f64 columns, this checks if all values are
        between 0 and 1. For List[f32] or List[f64] columns, this checks whether the column can potentially be
        multi-class probabilities.
        """
        is_ok = (
            self._frame.select(
                *((pl.col(c).is_between(0.0, 1.0).all()).alias(c) for c in self.floats),
                *(
                    (
                        # every value in every list must be non-negative
                        pl.col(c).list.eval((pl.element() >= 0.0).all()).list.first().all()
                        # each row's class probabilities must sum to 1
                        & ((pl.col(c).list.sum() - 1.0).abs() < 1e-6).all()
                        # all rows must have the same number of classes
                        & (pl.col(c).list.len().min() == pl.col(c).list.len().max())
                    ).alias(c)
                    for c in self.list_floats
                ),
            )
            .collect()
            .row(0)
        )

        return [c for c, ok in zip(self.floats + self.list_floats, is_ok) if ok is True]

    @lru_cache
    def infer_high_null(self, threshold: float = 0.75) -> List[str]:
        """
        Infers columns in which the fraction of null values is at least `threshold`.

        Parameters
        ----------
        threshold
            The fraction of nulls (between 0 and 1) at or above which a column is considered high-null.
        """
        is_ok = (
            self._frame.select(
                (pl.col(c).null_count() >= pl.len() * threshold).alias(c)
                for c in self._frame.columns
            )
            .collect()
            .row(0)
        )

        return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

    @lru_cache
    def infer_discrete(self, threshold: float = 0.1, max_val_cnt: int = 100) -> List[str]:
        """
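        Infers discrete columns. A column is considered discrete if it has fewer than
        `max_val_cnt` unique values or if its fraction of unique values is below `threshold`.
        Boolean and categorical columns are always considered discrete.

        Parameters
        ----------
        threshold
            Columns whose fraction of unique values is lower than threshold will be considered
            discrete.
        max_val_cnt
            Columns with fewer unique values than this will be considered discrete.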
        """
        out: List[str] = self.bools + self.cats
        to_check = [c for c in self._frame.columns if c not in out]
        is_ok = (
            self._frame.select(
                (
                    (pl.col(c).n_unique() < max_val_cnt)
                    | (pl.col(c).n_unique() < threshold * pl.len())
                ).alias(c)
                for c in to_check
            )
            .collect()
            .row(0)
        )

        return [c for c, ok in zip(to_check, is_ok) if ok is True]

    @lru_cache
    def infer_const(self, include_null: bool = False) -> List[str]:
        """
        Infers whether the column is constant.

        Parameters
        ----------
        include_null
            If true, a column that is constant apart from null values will also be included.
        """
        if include_null:
            is_ok = (
                self._frame.select(
                    (
                        (pl.col(c).n_unique() == 1)
                        | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == 2))
                    ).alias(c)
                    for c in self._frame.columns
                )
                .collect()
                .row(0)
            )
        else:
            is_ok = (
                self._frame.select(
                    (pl.col(c).n_unique() == 1).alias(c) for c in self._frame.columns
                )
                .collect()
                .row(0)
            )

        return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

    @lru_cache
    def infer_binary(self, include_null: bool = False) -> List[str]:
        """
        Infers whether the column is binary.

        Parameters
        ----------
        include_null
            If true, a column with 2 distinct non-null values plus nulls will also be included.
        """
        if include_null:
            is_ok = (
                self._frame.select(
                    (
                        (pl.col(c).n_unique() == 2)
                        | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == 3))
                    ).alias(c)
                    for c in self._frame.columns
                )
                .collect()
                .row(0)
            )
        else:
            is_ok = (
                self._frame.select(
                    (pl.col(c).n_unique() == 2).alias(c) for c in self._frame.columns
                )
                .collect()
                .row(0)
            )

        return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

    @lru_cache
    def infer_k_distinct(self, k: int, include_null: bool = False) -> List[str]:
        """
        Infers whether the column has k distinct values.

        Parameters
        ----------
        k
            Any positive integer.
        include_null
            If true, a column with k distinct non-null values plus nulls will also be included.
        """
        if k < 1:
            raise ValueError("Input `k` must be >= 1.")

        if include_null:
            is_ok = (
                self._frame.select(
                    (
                        (pl.col(c).n_unique() == k)
                        | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == (k + 1)))
                    ).alias(c)
                    for c in self._frame.columns
                )
                .collect()
                .row(0)
            )
        else:
            is_ok = (
                self._frame.select(
                    (pl.col(c).n_unique() == k).alias(c) for c in self._frame.columns
                )
                .collect()
                .row(0)
            )

        return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

    def infer_corr(self, method: CorrMethod = "pearson") -> pl.DataFrame:
        """
        Tries to infer highly correlated columns by computing the correlation between every pair of
        numerical (including boolean) columns. Boolean columns are cast to UInt8 first.

        Parameters
        ----------
        method
            One of ["pearson", "spearman", "xi", "kendall"]
        """
        to_check = self.numerics + self.bools

        xx = []
        yy = []
        for x, y in combinations(to_check, 2):
            xx.append(x)
            yy.append(y)

        corrs = (
            self._frame.with_columns(pl.col(c).cast(pl.UInt8) for c in self.bools)
            .select(corr(x, y, method=method).alias(f"{i}") for i, (x, y) in enumerate(zip(xx, yy)))
            .collect()
            .row(0)
        )

        return pl.DataFrame({"x": xx, "y": yy, "corr": corrs}).sort(
            pl.col("corr").abs(), descending=True
        )

    def infer_dependency(self, subset: IntoExpr | Iterable[IntoExpr] = pl.all()) -> pl.DataFrame:
        """
        Infers (functional) dependencies using conditional entropy. Only potentially qualifying
        columns are evaluated, i.e. columns of type int, str, categorical, or boolean. Columns
        with a single unique value are skipped.

        A very low conditional entropy means that knowing the column in `by` is enough to infer
        the column in `column`, i.e. the column in `column` is (almost) determined by the column
        in `by`.

        Parameters
        ----------
        subset
            A subset of columns to try running the dependency check. The subset input can be
            anything that can be turned into a Polars selector. The df or the column subset of the df
            may contain columns that cannot be used for dependency detection, e.g. column of list of values.
            Only valid columns will be checked.
        """

        # Infer valid columns to run this detection
        valid = self.ints + self.strs + self.cats + self.bools
        check_frame = self._frame.select(subset)
        all_names = check_frame.collect_schema().names()
        to_check = [x for x in all_names if x in valid]
        n_uniques = check_frame.select(pl.col(c).n_unique() for c in to_check).collect().row(0)

        frame = (
            pl.DataFrame({"column": to_check, "n_unique": n_uniques})
            .filter(pl.col("n_unique") > 1)
            .sort("n_unique")
        )

        check = list(frame["column"])
        if len(check) <= 1:
            warnings.warn(
                f"Not enough valid columns to detect dependency on. Valid column count: {len(check)}. Empty dataframe returned.",
                stacklevel=2,
            )
            return pl.DataFrame(
                {"column": [], "by": [], "cond_entropy": []},
                schema={"column": pl.String, "by": pl.String, "cond_entropy": pl.Float64},
            )

        if len(check) != len(all_names):
            warnings.warn(
                f"The following columns are dropped because they cannot be used in dependency detection: {[f for f in all_names if f not in check]}",
                stacklevel=2,
            )

        # Construct output
        column = []
        by = []
        for x, y in combinations(check, 2):
            column.append(x)
            by.append(y)

        ce = (
            self._frame.select(
                query_cond_entropy(x, y).abs().alias(f"{i}")
                for i, (x, y) in enumerate(zip(column, by))
            )
            .collect()
            .row(0)
        )

        out = pl.DataFrame({"column": column, "by": by, "cond_entropy": ce}).sort("cond_entropy")

        return out

    def plot_dependency(
        self, threshold: float = 0.01, subset: IntoExpr | Iterable[IntoExpr] = pl.all()
    ) -> graphviz.Digraph:
        """
        Plots dependencies found by self.infer_dependency, treating a pair as dependent
        when its conditional entropy is below the threshold.

        Parameters
        ----------
        threshold
            If conditional entropy is < threshold, we draw a line indicating dependency.
        subset
            A subset of columns to try running the dependency check. The subset input can be
            anything that can be turned into a Polars selector
        """

        dep_frame = self.infer_dependency(subset=subset)

        df_local = dep_frame.filter((pl.col("cond_entropy") < threshold)).select(
            pl.col("column").alias("child"),  # c for child
            pl.col("by").alias("parent"),  # p for parent
        )
        cp = df_local.group_by("child").agg(pl.col("parent"))
        pc = df_local.group_by("parent").agg(pl.col("child"))
        child_parent: dict[str, pl.Series] = dict(zip(cp["child"], cp["parent"]))
        parent_child: dict[str, pl.Series] = dict(zip(pc["parent"], pc["child"]))

        dot = graphviz.Digraph(
            "Dependency Plot", comment=f"Conditional Entropy < {threshold:.2f}", format="png"
        )
        for c, par in child_parent.items():
            parents_of_c = set(par)
            for p in par:
                # Does parent p have a child that is also a parent of c? If so, remove p.
                children_of_p = parent_child.get(p, None)
                if children_of_p is not None:
                    if len(parents_of_c.intersection(children_of_p)) > 0:
                        parents_of_c.remove(p)

            dot.node(c)
            for p in parents_of_c:
                dot.node(p)
                dot.edge(p, c)

        return dot

    def plot_feature_distr(
        self,
        feature: str,
        n_bins: int | None = None,
        density: bool = False,
        show_bad_values: bool = True,
        min_: float | pl.Expr | None = None,
        max_: float | pl.Expr | None = None,
        over: str | None = None,
        filter_by: pl.Expr | None = None,
    ) -> alt.Chart:
        """
        Plot distribution of the feature with a few statistical details.

        Parameters
        ----------
        feature
            A string representing a column name
        n_bins
            The number of bins used for histograms. Not used when the feature column is categorical.
        density
            Whether to plot a probability density or not
        show_bad_values
            Whether to show % of bad (null or non-finite) values
        min_
            Whether to ignore values strictly lower than min_
        max_
            Whether to ignore values strictly higher than max_
        over
            Whether to look at the distribution over another categorical column
        filter_by
            An extra condition you may want to impose on the underlying dataset
        """
        if feature not in self.numerics:
            raise ValueError("Input feature must be numeric.")

        conditions = []
        if filter_by is not None:
            conditions.append(filter_by)
        if min_ is not None:
            conditions.append(pl.col(feature) >= min_)
        if max_ is not None:
            conditions.append(pl.col(feature) <= max_)

        if len(conditions) > 0:
            df = self._frame.filter(pl.all_horizontal(*conditions)).select(feature, over).collect()
        else:
            df = self._frame.select(feature, over).collect()

        return plot_feature_distr(
            df=df,
            feature=feature,
            n_bins=n_bins,
            density=density,
            show_bad_values=show_bad_values,
            over=over,
        )

col_validation(*rules)

Generates a validation report from rules, where each rule is a pl.Expr that evaluates to a single boolean per column. The report is a DataFrame with the columns column and __reason__. It lists every column that fails a rule, together with that rule's message.

Parameters:

rules : Tuple[Expr, str], default ()
    A tuple of (pl.Expr, str), where the pl.Expr should evaluate to a single boolean value. If the boolean is False, the entire column is considered to violate the rule, and the str is reported as the reason.
Source code in python/polars_ds/eda/diagnosis.py
def col_validation(
    self,
    *rules: Tuple[pl.Expr, str],
) -> pl.DataFrame:
    """
    Generates a validation report based on rules (pl.Expr), each of which evaluates to a
    single boolean per column.

    Parameters
    ----------
    rules
        A tuple of (pl.Expr, str), where the pl.Expr should evaluate to a single boolean value.
        If the boolean is False, then the entire column is considered to be violating the rule.
    """
    rules_to_check = list(rules)
    rules_exprs = [r.name.keep() for r, _ in rules_to_check]
    violation_messages = [msg for _, msg in rules_to_check]

    df_temp = self._frame.select(*rules_exprs).collect()

    if len(df_temp) > 1:
        raise ValueError(
            "Column rules must evaluate to a single boolean expression for each rule. "
            f"But a dataframe of shape {df_temp.shape} is produced."
        )

    df_violation = df_temp.transpose(include_header=True, column_names=["pass"]).with_columns(
        pl.Series(name="__reason__", values=violation_messages)
    )

    return df_violation.filter(pl.col("pass").not_()).select("column", "__reason__")
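
The contract of col_validation (one boolean per rule, with failing columns reported alongside their message) can be sketched in plain Python without Polars. The names validate_columns, Frame and Rule are hypothetical. A "frame" here is assumed to be a dict mapping column names to values, and each rule is assumed to carry its column name explicitly. In DIA, the column name comes from the expression itself via name.keep().

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-ins: a frame maps column name -> values, and a rule is
# (column name, check returning a single bool, violation message).
Frame = Dict[str, Sequence[object]]
Rule = Tuple[str, Callable[[Sequence[object]], bool], str]


def validate_columns(frame: Frame, *rules: Rule) -> List[Tuple[str, str]]:
    """Return (column, reason) for every rule whose check evaluates to False."""
    report: List[Tuple[str, str]] = []
    for column, check, message in rules:
        passed = check(frame[column])
        if not isinstance(passed, bool):
            # Mirrors the requirement that each rule yields exactly one boolean.
            raise ValueError(f"Rule on {column!r} must return a single bool.")
        if not passed:
            report.append((column, message))
    return report


frame = {"age": [23, 41, None], "score": [0.2, 0.9, 0.7]}
report = validate_columns(
    frame,
    ("age", lambda v: all(x is not None for x in v), "age contains nulls"),
    ("score", lambda v: all(0.0 <= x <= 1.0 for x in v), "score outside [0, 1]"),
)
print(report)  # [('age', 'age contains nulls')]
```

With DIA itself, and assuming `dia` is a DIA instance, the equivalent rule would be passed as a tuple such as `(pl.col("age").null_count() == 0, "age contains nulls")` to `dia.col_validation(...)`.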

corr(subset, method='pearson')

Returns a DataFrame containing correlation information between the subset and all numeric columns. The DataFrame has one row per column in subset and one correlation value per numeric column. Only numerical columns are used as correlation targets.

Parameters:

subset : IntoExpr | Iterable[IntoExpr], required
    Anything that can be put into a Polars .select statement.
method : CorrMethod, default 'pearson'
    One of ["pearson", "spearman", "xi", "kendall", "bicor"].
Source code in python/polars_ds/eda/diagnosis.py
def corr(
    self, subset: IntoExpr | Iterable[IntoExpr], method: CorrMethod = "pearson"
) -> pl.DataFrame:
    """
    Returns a dataframe containing correlation information between the subset and all numeric columns.
    Only numerical columns will be checked.

    Parameters
    ----------
    subset
        Anything that can be put into a Polars .select statement.
    method
        One of ["pearson", "spearman", "xi", "kendall", "bicor"]
    """

    to_check = self._frame.select(subset).collect_schema().names()

    corrs = [
        self._frame.select(
            # This calls corr from .stats
            pl.lit(x).alias("column"),
            *(corr(x, y, method=method).alias(y) for y in self.numerics),
        )
        for x in to_check
    ]

    return pl.concat(pl.collect_all(corrs))

infer_binary(include_null=False) cached

Infers which columns are binary, i.e. which columns have exactly 2 distinct values.

Parameters:

include_null : bool, default False
    If True, a column with 2 distinct non-null values plus nulls is also included.
Source code in python/polars_ds/eda/diagnosis.py
@lru_cache
def infer_binary(self, include_null: bool = False) -> List[str]:
    """
    Infers whether the column is binary.

    Parameters
    ----------
    include_null
        If true, a binary column with 2 non-null distinct values and null will also be included.
    """
    if include_null:
        is_ok = (
            self._frame.select(
                (
                    (pl.col(c).n_unique() == 2)
                    | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == 3))
                ).alias(c)
                for c in self._frame.columns
            )
            .collect()
            .row(0)
        )
    else:
        is_ok = (
            self._frame.select(
                (pl.col(c).n_unique() == 2).alias(c) for c in self._frame.columns
            )
            .collect()
            .row(0)
        )

    return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

infer_const(include_null=False) cached

Infers which columns are constant, i.e. which columns have a single distinct value.

Parameters:

include_null : bool, default False
    If True, a column with one distinct non-null value plus nulls is also included.
Source code in python/polars_ds/eda/diagnosis.py
@lru_cache
def infer_const(self, include_null: bool = False) -> List[str]:
    """
    Infers whether the column is constant.

    Parameters
    ----------
    include_null
        If true, a constant column with null values will also be included.
    """
    if include_null:
        is_ok = (
            self._frame.select(
                (
                    (pl.col(c).n_unique() == 1)
                    | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == 2))
                ).alias(c)
                for c in self._frame.columns
            )
            .collect()
            .row(0)
        )
    else:
        is_ok = (
            self._frame.select(
                (pl.col(c).n_unique() == 1).alias(c) for c in self._frame.columns
            )
            .collect()
            .row(0)
        )

    return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

infer_corr(method='pearson')

Tries to infer highly correlated columns by computing the correlation between every pair of numerical (including boolean) columns. Boolean columns are cast to 0/1 first. The result is sorted by absolute correlation in descending order.

Parameters:

method : CorrMethod, default 'pearson'
    One of ["pearson", "spearman", "xi", "kendall"].
Source code in python/polars_ds/eda/diagnosis.py
def infer_corr(self, method: CorrMethod = "pearson") -> pl.DataFrame:
    """
    Trying to infer highly correlated columns by computing correlation between
    all numerical (including boolean) columns.

    Parameters
    ----------
    method
        One of ["pearson", "spearman", "xi", "kendall"]
    """
    to_check = self.numerics + self.bools

    xx = []
    yy = []
    for x, y in combinations(to_check, 2):
        xx.append(x)
        yy.append(y)

    corrs = (
        self._frame.with_columns(pl.col(c).cast(pl.UInt8) for c in self.bools)
        .select(corr(x, y, method=method).alias(f"{i}") for i, (x, y) in enumerate(zip(xx, yy)))
        .collect()
        .row(0)
    )

    return pl.DataFrame({"x": xx, "y": yy, "corr": corrs}).sort(
        pl.col("corr").abs(), descending=True
    )

infer_dependency(subset=pl.all())

Infers (functional) dependencies using conditional entropy. Only potentially qualifying columns are evaluated, namely columns of type int, str, categorical, or boolean. Columns with fewer than 2 distinct values are skipped.

If the returned conditional entropy is very low, knowing the column listed under "by" is enough to infer the column listed under "column". In other words, the latter is determined by the former.

Parameters:

subset : IntoExpr | Iterable[IntoExpr], default pl.all()
    A subset of columns on which to run the dependency check. Anything that can be turned into a Polars selector is accepted. The frame, or the selected subset, may contain columns that cannot be used for dependency detection (e.g. a column of lists); only valid columns are checked.
Source code in python/polars_ds/eda/diagnosis.py
def infer_dependency(self, subset: IntoExpr | Iterable[IntoExpr] = pl.all()) -> pl.DataFrame:
    """
    Infers (functional) dependency using the method of conditional entropy. This only evaluates
    potential qualifying columns. Potential qualifying columns are columns of type:
    int, str, categorical, or booleans.

    If the returned conditional entropy is very low, knowing the column in
    `by` is enough to infer the column in `column`, i.e. the column in `column`
    is determined by the column in `by`.

    Parameters
    ----------
    subset
        A subset of columns to try running the dependency check. The subset input can be
        anything that can be turned into a Polars selector. The df or the column subset of the df
        may contain columns that cannot be used for dependency detection, e.g. a column of lists.
        Only valid columns will be checked.
    """

    # Infer valid columns to run this detection
    valid = self.ints + self.strs + self.cats + self.bools
    check_frame = self._frame.select(subset)
    all_names = check_frame.collect_schema().names()
    to_check = [x for x in all_names if x in valid]
    n_uniques = check_frame.select(pl.col(c).n_unique() for c in to_check).collect().row(0)

    frame = (
        pl.DataFrame({"column": to_check, "n_unique": n_uniques})
        .filter(pl.col("n_unique") > 1)
        .sort("n_unique")
    )

    check = list(frame["column"])
    if len(check) <= 1:
        warnings.warn(
            f"Not enough valid columns to detect dependency on. Valid column count: {len(check)}. Empty dataframe returned.",
            stacklevel=2,
        )
        return pl.DataFrame(
            {"column": [], "by": [], "cond_entropy": []},
            schema={"column": pl.String, "by": pl.String, "cond_entropy": pl.Float64},
        )

    if len(check) != len(all_names):
        warnings.warn(
            f"The following columns are dropped because they cannot be used in dependency detection: {[f for f in all_names if f not in check]}",
            stacklevel=2,
        )

    # Construct output
    column = []
    by = []
    for x, y in combinations(check, 2):
        column.append(x)
        by.append(y)

    ce = (
        self._frame.select(
            query_cond_entropy(x, y).abs().alias(f"{i}")
            for i, (x, y) in enumerate(zip(column, by))
        )
        .collect()
        .row(0)
    )

    out = pl.DataFrame({"column": column, "by": by, "cond_entropy": ce}).sort("cond_entropy")

    return out
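
The quantity behind infer_dependency is the conditional entropy H(column | by) = -sum over (column, by) of p(column, by) * log(p(column, by) / p(by)). It is zero exactly when every value of "by" maps to a single value of "column". Below is a pure-Python sketch of that estimate from value frequencies. cond_entropy is a hypothetical helper, not the library function. The sketch uses the natural logarithm, which is an assumption. The base used by polars_ds's query_cond_entropy only rescales the value and does not change whether it is zero.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def cond_entropy(column: Sequence[Hashable], by: Sequence[Hashable]) -> float:
    """H(column | by) in nats, estimated from value frequencies."""
    n = len(column)
    joint = Counter(zip(column, by))
    by_counts = Counter(by)
    h = 0.0
    for (_, b), count in joint.items():
        p_joint = count / n
        p_by = by_counts[b] / n
        h -= p_joint * math.log(p_joint / p_by)
    return h


city = ["Paris", "Lyon", "Berlin", "Munich"]
country = ["FR", "FR", "DE", "DE"]
print(cond_entropy(country, by=city))  # 0.0: the city determines the country
print(cond_entropy(city, by=country))  # ln 2: each country leaves 2 equally likely cities
```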

infer_discrete(threshold=0.1, max_val_cnt=100) cached

Infers discrete columns based on the number and the share of unique values. A column is considered discrete if either of the conditions below holds.

Parameters:

threshold : float, default 0.1
    Columns whose unique-value count is lower than threshold times the number of rows are considered discrete.
max_val_cnt : int, default 100
    Columns with fewer than max_val_cnt unique values are considered discrete.
Source code in python/polars_ds/eda/diagnosis.py
@lru_cache
def infer_discrete(self, threshold: float = 0.1, max_val_cnt: int = 100) -> List[str]:
    """
    Infers discrete columns based on unique percentage and max_val_count.

    Parameters
    ----------
    threshold
        Columns with unique percentage lower than threshold will be considered
        discrete
    max_val_cnt
        Max number of unique values the column can have in order for it to be considered
        discrete
    """
    out: List[str] = self.bools + self.cats
    to_check = [c for c in self._frame.columns if c not in out]
    is_ok = (
        self._frame.select(
            (
                (pl.col(c).n_unique() < max_val_cnt)
                | (pl.col(c).n_unique() < threshold * pl.len())
            ).alias(c)
            for c in to_check
        )
        .collect()
        .row(0)
    )

    return [c for c, ok in zip(to_check, is_ok) if ok is True]
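
infer_discrete flags a column when either test passes: the column has fewer than max_val_cnt unique values, or it has fewer unique values than threshold times the row count. Below is a stdlib sketch of that rule with a hypothetical dict-based frame and a hypothetical discrete_columns helper. Unlike the method, the sketch checks every column, including booleans and categoricals.

```python
from typing import Dict, List, Sequence


def discrete_columns(
    frame: Dict[str, Sequence[object]], threshold: float = 0.1, max_val_cnt: int = 100
) -> List[str]:
    """Columns that pass either the absolute or the relative unique-count test."""
    out: List[str] = []
    for name, values in frame.items():
        n_unique = len(set(values))  # None counts as one more distinct value
        few_values = n_unique < max_val_cnt
        low_ratio = n_unique < threshold * len(values)
        if few_values or low_ratio:
            out.append(name)
    return out


frame = {
    "id": list(range(2000)),                  # 2000 unique values
    "bucket": [i % 5 for i in range(2000)],   # 5 unique values
    "user": [i % 150 for i in range(2000)],   # 150 unique values, below 0.1 * 2000
}
print(discrete_columns(frame))  # ['bucket', 'user']
```

The "user" column has more than max_val_cnt unique values, but it still qualifies through the relative test.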

infer_high_null(threshold=0.75) cached

Infers columns in which the fraction of null values is at least threshold.

Parameters:

threshold : float, default 0.75
    Null fraction (0.75 means 75% of rows) at or above which a column is considered high-null.
Source code in python/polars_ds/eda/diagnosis.py
@lru_cache
def infer_high_null(self, threshold: float = 0.75) -> List[str]:
    """
    Infers columns with more than threshold percentage nulls.

    Parameters
    ----------
    threshold
        The threshold above which a column will be considered high null
    """
    is_ok = (
        self._frame.select(
            (pl.col(c).null_count() >= pl.len() * threshold).alias(c)
            for c in self._frame.columns
        )
        .collect()
        .row(0)
    )

    return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]

infer_k_distinct(k, include_null=False) cached

Infers which columns have exactly k distinct values.

Parameters:

k : int, required
    Any positive integer. A ValueError is raised if k < 1.
include_null : bool, default False
    If True, a column with k distinct non-null values plus nulls is also included.
Source code in python/polars_ds/eda/diagnosis.py
@lru_cache
def infer_k_distinct(self, k: int, include_null: bool = False) -> List[str]:
    """
    Infers whether the column has k distinct values.

    Parameters
    ----------
    k
        Any positive integer.
    include_null
        If True, a column with k non-null distinct values plus nulls will also be included.
    """
    if k < 1:
        raise ValueError("Input `k` must be >= 1.")

    if include_null:
        is_ok = (
            self._frame.select(
                (
                    (pl.col(c).n_unique() == k)
                    | ((pl.col(c).null_count() > 0) & (pl.col(c).n_unique() == (k + 1)))
                ).alias(c)
                for c in self._frame.columns
            )
            .collect()
            .row(0)
        )
    else:
        is_ok = (
            self._frame.select(
                (pl.col(c).n_unique() == k).alias(c) for c in self._frame.columns
            )
            .collect()
            .row(0)
        )

    return [c for c, ok in zip(self._frame.columns, is_ok) if ok is True]
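
The null handling in infer_k_distinct relies on Polars' n_unique counting null as a value of its own. The same applies to infer_binary and infer_const, which are the k = 2 and k = 1 cases. As a result, a column with k real values and some nulls reports k + 1 distinct values. The stdlib sketch below reproduces that logic, with Python's set playing the role of n_unique. k_distinct_columns and the dict-based frame are hypothetical.

```python
from typing import Dict, List, Sequence


def k_distinct_columns(
    frame: Dict[str, Sequence[object]], k: int, include_null: bool = False
) -> List[str]:
    if k < 1:
        raise ValueError("Input `k` must be >= 1.")
    out: List[str] = []
    for name, values in frame.items():
        # Like Polars' n_unique, the set treats None as a distinct value.
        n_unique = len(set(values))
        has_null = any(v is None for v in values)
        if n_unique == k or (include_null and has_null and n_unique == k + 1):
            out.append(name)
    return out


frame = {
    "flag": [0, 1, 1, 0],
    "flag_with_null": [0, 1, None, 1],
    "const": [7, 7, 7, 7],
}
print(k_distinct_columns(frame, k=2))                     # ['flag']
print(k_distinct_columns(frame, k=2, include_null=True))  # ['flag', 'flag_with_null']
print(k_distinct_columns(frame, k=1))                     # ['const']
```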

infer_prob()

Infers columns that can potentially be probabilities. For f32/f64 columns, this checks if all values are between 0 and 1. For List[f32] or List[f64] columns, this checks whether the column can potentially be multi-class probabilities.

Source code in python/polars_ds/eda/diagnosis.py
def infer_prob(self) -> List[str]:
    """
    Infers columns that can potentially be probabilities. For f32/f64 columns, this checks if all values are
    between 0 and 1. For List[f32] or List[f64] columns, this checks whether the column can potentially be
    multi-class probabilities.
    """
    is_ok = (
        self._frame.select(
            *((pl.col(c).is_between(0.0, 1.0).all()).alias(c) for c in self.floats),
            *(
                (
                    (
                        pl.col(c).list.eval((pl.element() >= 0.0).all()).list.first()
                    )  # every number must be non-negative
                    & ((pl.col(c).list.sum() - 1.0).abs() < 1e-6)  # class prob must sum to 1
                    & (
                        pl.col(c).list.len().min() == pl.col(c).list.len().max()
                    )  # class prob column must have the same length
                ).alias(c)
                for c in self.list_floats
            ),
        )
        .collect()
        .row(0)
    )

    return [c for c, ok in zip(self.floats + self.list_floats, is_ok) if ok is True]
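
The checks described for infer_prob can be sketched in plain Python and applied to every row. A float column must stay within [0, 1]. A list column must have rows of equal length with non-negative entries that sum to 1. The helper names are hypothetical, and the 1e-6 tolerance is taken from the source above.

```python
from typing import Sequence


def looks_like_probabilities(values: Sequence[float]) -> bool:
    """Float column: every value lies in [0, 1]."""
    return all(0.0 <= v <= 1.0 for v in values)


def looks_like_class_probabilities(rows: Sequence[Sequence[float]], tol: float = 1e-6) -> bool:
    """List column: equal row lengths, non-negative entries, each row sums to 1."""
    if not rows:
        return False
    same_length = len({len(r) for r in rows}) == 1
    rows_ok = all(all(p >= 0.0 for p in r) and abs(sum(r) - 1.0) < tol for r in rows)
    return same_length and rows_ok


print(looks_like_probabilities([0.1, 0.9, 0.5]))                      # True
print(looks_like_class_probabilities([[0.2, 0.8], [0.5, 0.5]]))       # True
print(looks_like_class_probabilities([[0.2, 0.7], [0.5, 0.5]]))       # False: a row sums to 0.9
print(looks_like_class_probabilities([[0.2, 0.8], [0.1, 0.2, 0.7]]))  # False: lengths differ
```

The tolerance matters because sums such as 0.1 + 0.2 + 0.7 are not exactly 1.0 in floating point.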

meta()

Returns internal data in this class as a dictionary.

Source code in python/polars_ds/eda/diagnosis.py
def meta(self) -> Dict:
    """
    Returns internal data in this class as a dictionary.
    """
    out = self.__dict__.copy()
    out.pop("_frame")
    return out

null_corr(subset=pl.all(), filter_by=None)

Computes the correlation between "A is null" and "B is null" for every pair (A, B) of columns in the given subset. Results are sorted by absolute correlation in descending order.

If A or B is entirely null or has no nulls at all, the pair is skipped. In that case its null indicator is constant, so the correlation would not be meaningful.

Parameters:

subset : IntoExpr | Iterable[IntoExpr], default pl.all()
    Anything that can be put into a Polars .select statement.
filter_by : Expr | None, default None
    An optional boolean expression. If given, only rows that satisfy it are used.
Source code in python/polars_ds/eda/diagnosis.py
def null_corr(
    self,
    subset: IntoExpr | Iterable[IntoExpr] = pl.all(),
    filter_by: pl.Expr | None = None,
) -> pl.DataFrame:
    """
    Computes the correlation between A is null and B is null for all (A, B) combinations
    in the given subset of columns.

    If either A or B is all null or all non-null, the null correlation will not be
    computed, since the value is not going to be meaningful.

    Parameters
    ----------
    subset
        Anything that can be put into a Polars .select statement. Defaults to pl.all()
    filter_by
        A boolean expression
    """

    cols = self._frame.select(subset).collect_schema().names()

    if filter_by is None:
        frame = self._frame.select(pl.col(cols).is_null()).collect()
    else:
        frame = self._frame.filter(filter_by).select(pl.col(cols).is_null()).collect()

    df_null_cnt = frame.sum()
    n = frame.shape[0]

    invalid = set(
        c for c, cnt in zip(df_null_cnt.columns, df_null_cnt.row(0)) if (cnt == 0 or cnt == n)
    )

    xx = []
    yy = []
    for x, y in combinations(cols, 2):
        if not (x in invalid or y in invalid):
            xx.append(x)
            yy.append(y)

    if len(xx) == 0:
        return pl.DataFrame(
            {"column_1": [], "column_2": [], "null_corr": []},
            schema={
                "column_1": pl.String,
                "column_2": pl.String,
                "null_corr": pl.Float64,
            },
        )
    else:
        corrs = frame.select(
            pl.corr(x, y).alias(str(i)) for i, (x, y) in enumerate(zip(xx, yy))
        ).row(0)
        return pl.DataFrame({"column_1": xx, "column_2": yy, "null_corr": corrs}).sort(
            pl.col("null_corr").abs(), descending=True
        )

numeric_profile(n_bins=20, iqr_multiplier=1.5, histogram=True, gt=True)

Creates a numerical profile with a histogram plot. Notice that the histograms may have completely different scales on the x-axis.

Parameters:

n_bins : int, default 20
    Number of bins in the histogram.
iqr_multiplier : float, default 1.5
    Interquartile range (IQR) multiplier. The IQR is the range between Q1 and Q3. Values more than iqr_multiplier * IQR below Q1 or above Q3 are counted as outliers.
histogram : bool, default True
    Whether to include a histogram column.
gt : bool, default True
    If True, return the profile as a formatted great_tables GT table. Otherwise, return a plain Polars DataFrame.
Source code in python/polars_ds/eda/diagnosis.py
def numeric_profile(
    self, n_bins: int = 20, iqr_multiplier: float = 1.5, histogram: bool = True, gt: bool = True
) -> GT | pl.DataFrame:
    """
    Creates a numerical profile with a histogram plot. Notice that the histograms may have
    completely different scales on the x-axis.

    Parameters
    ----------
    n_bins
        Bins in the histogram
    iqr_multiplier
        Interquartile range (IQR) multiplier. The IQR is the range between Q1 and Q3;
        values more than iqr_multiplier * IQR below Q1 or above Q3 are counted as
        outliers.
    histogram
        Whether to show a histogram or not
    gt
        Whether to show the table as a formatted Great Table or not
    """
    to_check = self.numerics

    cuts = [i / n_bins for i in range(n_bins)]
    cuts[0] -= 1e-5
    cuts[-1] += 1e-5

    if histogram:
        columns_needed = [
            [
                pl.lit(c, dtype=pl.String).alias("column"),
                pl.col(c).count().alias("non_null_cnt"),
                (pl.col(c).null_count() / pl.len()).alias("null%"),
                pl.col(c).mean().cast(pl.Float64).alias("mean"),
                pl.col(c).std().cast(pl.Float64).alias("std"),
                pl.col(c).min().cast(pl.Float64).cast(pl.Float64).alias("min"),
                pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
                pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
                pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
                pl.col(c).max().cast(pl.Float64).alias("max"),
                (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
                .cast(pl.Float64)
                .alias("IQR"),
                pl.any_horizontal(
                    pl.col(c)
                    < pl.col(c).quantile(0.25)
                    - iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                    pl.col(c)
                    > pl.col(c).quantile(0.75)
                    + iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                )
                .sum()
                .alias("outlier_cnt"),
                pl.struct(
                    ((pl.col(c) - pl.col(c).min()) / (pl.col(c).max() - pl.col(c).min()))
                    .filter(pl.col(c).is_finite())
                    .cut(breaks=cuts, left_closed=True, include_breaks=True)
                    .struct.rename_fields(["brk", "category"])
                    .struct.field("brk")
                    .value_counts()
                    .sort()
                    .struct.field("count")
                    .implode()
                ).alias("histogram"),
            ]
            for c in to_check
        ]
    else:
        columns_needed = [
            [
                pl.lit(c, dtype=pl.String).alias("column"),
                pl.col(c).count().alias("non_null_cnt"),
                (pl.col(c).null_count() / pl.len()).alias("null%"),
                pl.col(c).mean().cast(pl.Float64).alias("mean"),
                pl.col(c).std().cast(pl.Float64).alias("std"),
                pl.col(c).min().cast(pl.Float64).cast(pl.Float64).alias("min"),
                pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
                pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
                pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
                pl.col(c).max().cast(pl.Float64).alias("max"),
                (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
                .cast(pl.Float64)
                .alias("IQR"),
                pl.any_horizontal(
                    pl.col(c)
                    < pl.col(c).quantile(0.25)
                    - iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                    pl.col(c)
                    > pl.col(c).quantile(0.75)
                    + iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
                )
                .sum()
                .alias("outlier_cnt"),
            ]
            for c in to_check
        ]

    frames = [self._frame.select(*cols) for cols in columns_needed]
    df_final = pl.concat(pl.collect_all(frames))

    if gt:
        gt_out = (
            GT(df_final, rowname_col="column")
            .tab_stubhead("column")
            .fmt_percent(columns="null%")
            .fmt_number(
                columns=["mean", "std", "min", "q1", "median", "q3", "max", "IQR"], decimals=3
            )
        )
        if histogram:
            return gt_out.fmt_nanoplot(columns="histogram", plot_type="bar")
        return gt_out
    else:
        return df_final
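
The outlier_cnt and histogram columns of numeric_profile can be approximated in plain Python. The sketch below counts values outside [Q1 - m * IQR, Q3 + m * IQR]. It then bins min-max-scaled values into n_bins equal-width bins. It assumes linear-interpolation quantiles, which may not match Polars' default quantile interpolation, so outlier counts can differ slightly on real data. The helper names are hypothetical, and non-finite values (which the method filters out before binning) are not handled here.

```python
from typing import List, Sequence


def quantile_linear(sorted_values: Sequence[float], q: float) -> float:
    """Quantile by linear interpolation between neighbouring order statistics."""
    pos = q * (len(sorted_values) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    return sorted_values[lo] + (sorted_values[hi] - sorted_values[lo]) * (pos - lo)


def iqr_outlier_count(values: Sequence[float], iqr_multiplier: float = 1.5) -> int:
    s = sorted(values)
    q1, q3 = quantile_linear(s, 0.25), quantile_linear(s, 0.75)
    iqr = q3 - q1
    low, high = q1 - iqr_multiplier * iqr, q3 + iqr_multiplier * iqr
    return sum(1 for v in values if v < low or v > high)


def minmax_histogram(values: Sequence[float], n_bins: int = 20) -> List[int]:
    """Counts per equal-width bin after scaling values to [0, 1]."""
    lo, hi = min(values), max(values)
    counts = [0] * n_bins
    if hi == lo:
        counts[0] = len(values)  # degenerate column: everything lands in one bin
        return counts
    for v in values:
        scaled = (v - lo) / (hi - lo)
        counts[min(int(scaled * n_bins), n_bins - 1)] += 1  # the maximum joins the last bin
    return counts


print(iqr_outlier_count([1, 2, 3, 4, 5, 6, 7, 8, 9, 100]))  # 1 (only 100)
print(minmax_histogram(list(range(10)), n_bins=5))         # [2, 2, 2, 2, 2]
```

Min-max scaling is why, as noted above, the histograms of different columns can have completely different scales on the x-axis.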

plot_corr(subset, method='pearson')

Plots the correlations using classic heat maps.

Parameters:

subset : IntoExpr | Iterable[IntoExpr], required
    Anything that can be put into a Polars .select statement.
method : CorrMethod, default 'pearson'
    One of ["pearson", "spearman", "xi", "kendall", "bicor"].
Source code in python/polars_ds/eda/diagnosis.py
def plot_corr(
    self, subset: IntoExpr | Iterable[IntoExpr], method: CorrMethod = "pearson"
) -> GT:
    """
    Plots the correlations using classic heat maps.

    Parameters
    ----------
    subset
        Anything that can be put into a Polars .select statement.
    method
        One of ["pearson", "spearman", "xi", "kendall", "bicor"]
    """
    corr_values = self.corr(subset, method)
    cols = [c for c in corr_values.columns if c != "column"]
    return (
        GT(corr_values)
        .fmt_number(columns=cols, decimals=3)
        .data_color(
            columns=cols,
            palette=["#0202bd", "#bd0237"],
            domain=[-1, 1],
            alpha=0.5,
            na_color="#000000",
        )
    )

plot_dependency(threshold=0.01, subset=pl.all())

Plots dependencies found by self.infer_dependency as a graphviz Digraph, treating a pair as dependent when its conditional entropy is below threshold. A parent is omitted when one of its children is also a parent of the same column, because the dependency then already follows through that child.

Parameters:

threshold : float, default 0.01
    If the conditional entropy is below threshold, an edge indicating dependency is drawn.
subset : IntoExpr | Iterable[IntoExpr], default pl.all()
    A subset of columns on which to run the dependency check. Anything that can be turned into a Polars selector is accepted.
Source code in python/polars_ds/eda/diagnosis.py
def plot_dependency(
    self, threshold: float = 0.01, subset: IntoExpr | Iterable[IntoExpr] = pl.all()
) -> graphviz.Digraph:
    """
    Plot dependency using the result of self.infer_dependency and positively determines
    dependency by the threshold.

    Parameters
    ----------
    threshold
        If conditional entropy is < threshold, we draw a line indicating dependency.
    subset
        A subset of columns to try running the dependency check. The subset input can be
        anything that can be turned into a Polars selector
    """

    dep_frame = self.infer_dependency(subset=subset)

    df_local = dep_frame.filter((pl.col("cond_entropy") < threshold)).select(
        pl.col("column").alias("child"),  # c for child
        pl.col("by").alias("parent"),  # p for parent
    )
    cp = df_local.group_by("child").agg(pl.col("parent"))
    pc = df_local.group_by("parent").agg(pl.col("child"))
    child_parent: dict[str, pl.Series] = dict(zip(cp["child"], cp["parent"]))
    parent_child: dict[str, pl.Series] = dict(zip(pc["parent"], pc["child"]))

    dot = graphviz.Digraph(
        "Dependency Plot", comment=f"Conditional Entropy < {threshold:.2f}", format="png"
    )
    for c, par in child_parent.items():
        parents_of_c = set(par)
        for p in par:
            # Does parent p have a child that is also a parent of c? If so, remove p.
            children_of_p = parent_child.get(p, None)
            if children_of_p is not None:
                if len(parents_of_c.intersection(children_of_p)) > 0:
                    parents_of_c.remove(p)

        dot.node(c)
        for p in parents_of_c:
            dot.node(p)
            dot.edge(p, c)

    return dot
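The pruning loop drops a parent p of column c whenever one of p's children is still listed as a parent of c, so a chain a -> b -> c is drawn without the shortcut edge a -> c. The sketch below applies the same rule to plain (child, parent) pairs so the surviving edges can be checked without Graphviz or Polars; prune_edges is a hypothetical helper, not part of polars_ds.

```python
from collections import defaultdict


def prune_edges(pairs: list[tuple[str, str]]) -> set[tuple[str, str]]:
    """Apply plot_dependency's pruning rule to (child, parent) pairs; return (parent, child) edges."""
    child_parent: dict[str, list[str]] = defaultdict(list)
    parent_child: dict[str, list[str]] = defaultdict(list)
    for child, parent in pairs:
        child_parent[child].append(parent)
        parent_child[parent].append(child)

    edges: set[tuple[str, str]] = set()
    for c, parents in child_parent.items():
        kept = set(parents)
        for p in parents:
            # Drop p if one of p's children is still listed as a parent of c.
            if kept.intersection(parent_child.get(p, [])):
                kept.discard(p)
        edges.update((p, c) for p in kept)
    return edges


# b depends on a, c depends on b, and c also depends on a directly.
edges = prune_edges([("b", "a"), ("c", "b"), ("c", "a")])
```

Here edges is {("a", "b"), ("b", "c")}: the direct a -> c edge is removed because it is implied by the path through b.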

plot_feature_distr(feature, n_bins=None, density=False, show_bad_values=True, min_=None, max_=None, over=None, filter_by=None)

Plot distribution of the feature with a few statistical details.

Parameters:

Name Type Description Default
feature str

A string representing a column name

required
n_bins int | None

The maximum number of bins used for histograms. Only numeric features are supported; a non-numeric feature raises a ValueError.

None
density bool

Whether to plot a probability density or not

False
show_bad_values bool

Whether to show % of bad (null or non-finite) values

True
min_ float | Expr | None

If set, values strictly lower than min_ are ignored.

None
max_ float | Expr | None

If set, values strictly higher than max_ are ignored.

None
over str | None

If set, the name of a categorical column; the distribution is shown separately for each of its values.

None
filter_by Expr | None

An extra condition you may want to impose on the underlying dataset

None
Source code in python/polars_ds/eda/diagnosis.py
def plot_feature_distr(
    self,
    feature: str,
    n_bins: int | None = None,
    density: bool = False,
    show_bad_values: bool = True,
    min_: float | pl.Expr | None = None,
    max_: float | pl.Expr | None = None,
    over: str | None = None,
    filter_by: pl.Expr | None = None,
) -> alt.Chart:
    """
    Plot distribution of the feature with a few statistical details.

    Parameters
    ----------
    feature
        A string representing a column name
    n_bins
        The maximum number of bins used for histograms. Only numeric features are supported.
    density
        Whether to plot a probability density or not
    show_bad_values
        Whether to show % of bad (null or non-finite) values
    min_
        If set, values strictly lower than min_ are ignored.
    max_
        If set, values strictly higher than max_ are ignored.
    over
        If set, the name of a categorical column; the distribution is shown separately
        for each of its values.
    filter_by
        An extra condition you may want to impose on the underlying dataset
    """
    if feature not in self.numerics:
        raise ValueError("Input feature must be numeric.")

    conditions = []
    if filter_by is not None:
        conditions.append(filter_by)
    if min_ is not None:
        conditions.append(pl.col(feature) >= min_)
    if max_ is not None:
        conditions.append(pl.col(feature) <= max_)

    if len(conditions) > 0:
        df = self._frame.filter(pl.all_horizontal(*conditions)).select(feature, over).collect()
    else:
        df = self._frame.select(feature, over).collect()

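    return plot_feature_distr(
        df=df,
        feature=feature,
        # The plotting function compares n_bins with 2, so None would raise a TypeError;
        # fall back to its default of 10 bins.
        n_bins=10 if n_bins is None else n_bins,
        density=density,
        show_bad_values=show_bad_values,
        over=over,
    )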

row_validation(*rules, id_col=None, columns_to_keep=None, all_reasons=False)

Generates a validation report based on rules (pl.Expr) which evaluate to booleans per row.

Parameters:

Name Type Description Default
rules Tuple[Expr, str]

A tuple of (pl.Expr, str), where the pl.Expr should evaluate to a boolean value per row. If the boolean is False, then the row is considered a violation. The string should be an explanation of the violation.

()
id_col str | None

If None, an "__index__" column containing the row number will be generated.

None
columns_to_keep List[str] | None

Other columns you wish to keep in the final report.

None
all_reasons bool

If True, every violated rule is returned for each failing row, sorted by rule name. If False, only the first violated rule (in the order the rules were passed) is returned.

False
Source code in python/polars_ds/eda/diagnosis.py
def row_validation(
    self,
    *rules: Tuple[pl.Expr, str],
    id_col: str | None = None,
    columns_to_keep: List[str] | None = None,
    all_reasons: bool = False,
) -> pl.DataFrame:
    """
    Generates a validation report based on rules (pl.Expr) which evaluate to booleans
    per row.

    Parameters
    ----------
    rules
        A tuple of (pl.Expr, str), where the pl.Expr should evaluate to a boolean value
        per row. If the boolean is False, then the row is considered a violation. The string
        should be an explanation of the violation.
    id_col
        If None, an "__index__" column will be generated which is the row number.
    columns_to_keep
        Other columns you wish to keep in the final report.
    all_reasons
        If true, all reasons for violations will be returned. If false, only 1 will be returned.
    """

    if id_col is None:
        df = self._frame.with_row_index(name="__index__")
        to_keep = ["__index__"]
    else:
        df = self._frame
        to_keep = [id_col]

    rules_to_check = list(rules)
    rules_exprs = [r.alias(n) for r, n in rules_to_check]
    all_rule_names = [n for _, n in rules_to_check]
    # Do not allow duplicate rule names
    existing_names = set()
    for name in all_rule_names:
        if name not in existing_names:
            existing_names.add(name)
        else:
            raise ValueError(f"Rule name {name} is duplicate. Please rename it.")

    # We cannot use list(set(..)) because that might change the order of all_rule_names

    if columns_to_keep is not None:
        to_keep += columns_to_keep

    df_temp = df.select(*to_keep, *rules_exprs).filter(
        # Filter to the violators.
        # pl.all_horizontal(*all_rule_names) = people who pass all rules
        # pl.all_horizontal(*all_rule_names).not_() = people who failed any one of the rules
        pl.all_horizontal(*all_rule_names).not_()
    )

    if all_reasons:
        reasons = [
            pl.when(pl.col(c)).then(None).otherwise(pl.lit(c, dtype=pl.String))
            for c in all_rule_names
        ]  # When true, return None. When false, return reason

        return df_temp.select(
            *to_keep, pl.concat_list(reasons).list.drop_nulls().list.sort().alias("__reason__")
        ).collect()
    else:
        # df_temp = all people who failed any one of the rules. So there must be at least one 0 in concat-ed list.
        return df_temp.select(
            *to_keep,
            pl.concat_list(all_rule_names)
            .list.arg_min()
            .replace_strict(old=list(range(len(all_rule_names))), new=all_rule_names)
            .alias("__reason__"),
        ).collect()
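A row is reported when any rule's expression is False. With all_reasons=False, arg_min over the row's booleans picks the first failed rule in the order the rules were passed; with all_reasons=True, every failed rule is collected and the list is sorted. The pure-Python sketch below mirrors that behavior on a list of dicts, using the row position as the id (like the generated __index__ column). validate_rows and the sample rules are hypothetical; the real method takes Polars expressions.

```python
from typing import Any, Callable

Row = dict[str, Any]
Rule = tuple[Callable[[Row], bool], str]


def validate_rows(rows: list[Row], *rules: Rule, all_reasons: bool = False) -> list[tuple[int, Any]]:
    """Return (row_index, reason) for every row that fails at least one rule."""
    names = [name for _, name in rules]
    if len(set(names)) != len(names):
        raise ValueError("Rule names must be unique.")
    report = []
    for index, row in enumerate(rows):
        failed = [name for check, name in rules if not check(row)]
        if failed:
            # all_reasons=True: every failed rule, sorted; otherwise the first failed rule.
            report.append((index, sorted(failed) if all_reasons else failed[0]))
    return report


rows = [
    {"age": 30, "email": "a@x.com"},
    {"age": -1, "email": "b"},
    {"age": 5, "email": "c"},
]
rules = (
    (lambda r: r["age"] >= 0, "negative age"),
    (lambda r: "@" in r["email"], "invalid email"),
)
first_reason = validate_rows(rows, *rules)
all_failures = validate_rows(rows, *rules, all_reasons=True)
```

Row 1 fails both rules, so it reports "negative age" (declared first) by default, and both names in sorted order when all_reasons=True.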

special_values_report()

Counts null, NaN, and infinite values (with their percentages) for each numeric column. Note that for integer columns, only null_count can be non-zero.

Source code in python/polars_ds/eda/diagnosis.py
def special_values_report(self) -> pl.DataFrame:
    """
    Checks null, NaN, and non-finite values for float columns. Note that for integers, only null_count
    can possibly be non-zero.
    """
    to_check = self.numerics
    frames = [
        self._frame.select(
            pl.lit(c, dtype=pl.String).alias("column"),
            pl.col(c).null_count().alias("null_count"),
            (pl.col(c).null_count() / pl.len()).alias("null%"),
            pl.col(c).is_nan().sum().alias("NaN_count"),
            (pl.col(c).is_nan().sum() / pl.len()).alias("NaN%"),
            pl.col(c).is_infinite().sum().alias("inf_count"),
            (pl.col(c).is_infinite().sum() / pl.len()).alias("Inf%"),
        )
        for c in to_check
    ]
    return pl.concat(pl.collect_all(frames))
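For every numeric column the report divides the null, NaN and infinite counts by the number of rows. The following pure-Python sketch reproduces those six columns for a single list of values, assuming None stands for null; special_values is a hypothetical name, whereas the method builds one lazy query per column and collects them together with pl.collect_all.

```python
import math
from typing import Optional


def special_values(values: list[Optional[float]]) -> dict[str, float]:
    """Counts and shares of nulls (None), NaNs and infinities in one column."""
    n = len(values)
    nulls = sum(v is None for v in values)
    nans = sum(v is not None and math.isnan(v) for v in values)
    # Both +inf and -inf count as infinite.
    infs = sum(v is not None and math.isinf(v) for v in values)
    return {
        "null_count": nulls,
        "null%": nulls / n,
        "NaN_count": nans,
        "NaN%": nans / n,
        "inf_count": infs,
        "Inf%": infs / n,
    }


report = special_values([1.0, None, float("nan"), float("inf"), -float("inf")])
```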

str_stats()

Returns basic statistics about the string columns.

Source code in python/polars_ds/eda/diagnosis.py
def str_stats(self) -> pl.DataFrame:
    """
    Returns basic statistics about the string columns.
    """
    to_check = self.strs
    frames = [
        self._frame.select(
            pl.lit(c).alias("column"),
            pl.col(c).null_count().alias("null_count"),
            pl.col(c).n_unique().alias("n_unique"),
            pl.col(c).value_counts(sort=True).first().struct.field(c).alias("most_freq"),
            pl.col(c)
            .value_counts(sort=True)
            .first()
            .struct.field("count")
            .alias("most_freq_cnt"),
            pl.col(c).str.len_bytes().min().alias("min_byte_len"),
            pl.col(c).str.len_chars().min().alias("min_char_len"),
            pl.col(c).str.len_bytes().mean().alias("avg_byte_len"),
            pl.col(c).str.len_chars().mean().alias("avg_char_len"),
            pl.col(c).str.len_bytes().max().alias("max_byte_len"),
            pl.col(c).str.len_chars().max().alias("max_char_len"),
            pl.col(c).str.len_bytes().quantile(0.05).alias("5p_byte_len"),
            pl.col(c).str.len_bytes().quantile(0.95).alias("95p_byte_len"),
        )
        for c in to_check
    ]
    return pl.concat(pl.collect_all(frames))
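The *_byte_len and *_char_len columns differ only for non-ASCII text: str.len_bytes counts UTF-8 bytes, while str.len_chars counts characters (Unicode code points, which is also what Python's len counts). A small pure-Python illustration of the same distinction, using the hypothetical helper byte_and_char_len:

```python
def byte_and_char_len(s: str) -> tuple[int, int]:
    """Return (UTF-8 byte length, number of Unicode code points) of a string."""
    return len(s.encode("utf-8")), len(s)


samples = ["abc", "caf\u00e9", "\u65e5\u672c"]  # "abc", "café", "日本"
lengths = {s: byte_and_char_len(s) for s in samples}
```

For ASCII strings both numbers agree; "é" takes 2 bytes and each CJK character 3 bytes, so min/avg/max byte lengths can exceed the character lengths.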

plots

Functions:

Name Description
plot_feature_distr

Plot distribution of the feature with a few statistical details.

plot_lin_reg

Plots the linear regression line between x and target.

plot_pca

Creates a scatter plot based on the reduced dimensions via PCA, and colors it by the values of by.

plot_prob_calibration

Plots probability calibration of score(s) with respect to the binary target.

plot_roc_auc

Plots the ROC curve(s) of one or more probability predictions against a binary target.

plot_feature_distr(*, feature, n_bins=10, density=False, show_bad_values=True, over=None, df=None)

Plot distribution of the feature with a few statistical details.

Parameters:

Name Type Description Default
df DataFrame | LazyFrame | None

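Either an eager or lazy Polars DataFrame. Required when feature is a column name or when over is used.

None
feature str | Iterable[float]

Either a column name in df, or the feature values themselves as an iterable of floats

required
n_bins int

The maximum number of bins used for histograms. Must be greater than 2.

10
density bool

Whether to plot a probability density or not

False
show_bad_values bool

Whether to show % of bad (null or non-finite) values

True
over str | None

If set, the name of a categorical column in df; the distribution is shown separately for each of its values. Cannot be used when feature is an iterable of values.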

None
Source code in python/polars_ds/eda/plots.py
def plot_feature_distr(
    *,
    feature: str | Iterable[float],
    n_bins: int = 10,
    density: bool = False,
    show_bad_values: bool = True,
    over: str | None = None,
    df: pl.DataFrame | pl.LazyFrame | None = None,
) -> alt.Chart:
    """
    Plot distribution of the feature with a few statistical details.

    Parameters
    ----------
    df
        Either an eager or lazy Polars Dataframe
    feature
        A string representing a column name
    n_bins
        The max number of bins used for histograms.
    density
        Whether to plot a probability density or not
    show_bad_values
        Whether to show % of bad (null or non-finite) values
    over
        Whether to look at the distribution over another categorical column
    """

    if n_bins <= 2:
        raise ValueError("Input `n_bins` must be > 2.")

    if over is not None and df is None:
        raise ValueError("Input `over` can only be used when df is not None.")

    if isinstance(feature, str):
        if df is None:
            raise ValueError("If `feature` is str, then df cannot be none.")
        feat = feature
        if over is None:
            data = df.lazy().select(pl.col(feat).cast(pl.Float64)).collect()
        else:
            data = df.lazy().select(pl.col(feat).cast(pl.Float64), over).collect()
    else:
        if over is None:
            data = pl.Series(name="feature", values=feature, dtype=pl.Float64).to_frame()
            feat = "feature"
        else:
            raise ValueError("If input `feature` is a Series, then `over` cannot be used.")

    # selection = alt.selection_point(fields=['species'], bind='legend')
    # .filter(pl.col(feat).is_not_null())
    if density:
        if over is None:
            chart = (
                alt.Chart(data)
                .transform_density(
                    feat,
                    as_=[feat, "density"],
                )
                .mark_area()
                .encode(
                    x=f"{feat}:Q",
                    y=alt.Y("density:Q").stack(None),
                )
            )
        else:
            selection = alt.selection_point(fields=[over], bind="legend")
            chart = (
                alt.Chart(data)
                .transform_density(feat, as_=[feat, "density"], groupby=[over])
                .mark_area()
                .encode(
                    x=f"{feat}:Q",
                    y=alt.Y("density:Q").stack(None),
                    color=over,
                    opacity=alt.condition(selection, alt.value(0.8), alt.value(0.2)),
                )
                .add_params(selection)
            )
    else:
        if over is None:
            chart = (
                alt.Chart(data)
                .mark_bar()
                .encode(
                    alt.X(f"{feat}:Q").bin(maxbins=n_bins).title(feat),
                    y=alt.Y("count()").stack(None),
                )
            )
        else:
            selection = alt.selection_point(fields=[over], bind="legend")
            chart = (
                alt.Chart(data)
                .mark_bar()
                .encode(
                    alt.X(f"{feat}:Q").bin(maxbins=n_bins).title(feat),
                    y=alt.Y("count()").stack(None),
                    color=over,
                    opacity=alt.condition(selection, alt.value(0.8), alt.value(0.2)),
                )
                .add_params(selection)
            )

    if over is None:
        p5, median, mean, p95, min_, max_, cnt, null_cnt, not_finite = data.select(
            p5=pl.col(feat).quantile(0.05),
            median=pl.col(feat).median(),
            mean=pl.col(feat).mean(),
            p95=pl.col(feat).quantile(0.95),
            min=pl.col(feat).min(),
            max=pl.col(feat).max(),
            cnt=pl.len(),
            null_cnt=pl.col(feat).null_count(),
            not_finite=pl.col(feat).is_finite().not_().sum(),
        ).row(0)

        # stats overlay
        df_stats = pl.DataFrame(
            {"names": ["p5", "median", "avg", "p95"], "stats": [p5, median, mean, p95]}
        )

        stats_base = alt.Chart(df_stats)
        stats_chart = stats_base.mark_rule(color="#f086ab").encode(
            x=alt.X("stats").title(""),
            tooltip=[
                alt.Tooltip("names:N", title="Stats"),
                alt.Tooltip("stats:Q", title="Value"),
            ],
        )

        chart = chart + stats_chart
        if show_bad_values:
            df_bad_values = pl.DataFrame(
                {
                    "names": [""],
                    "pcts": [(null_cnt + not_finite) / cnt],
                }
            )

            bad_values_chart = (
                alt.Chart(df_bad_values)
                .mark_bar(opacity=0.7)
                .encode(
                    x=alt.X("pcts:Q", scale=alt.Scale(domain=[0, 1]))
                    .axis(format=".0%")
                    .title("Null or Non-Finite %"),
                    y=alt.Y("names:N").title(""),
                    tooltip=[
                        alt.Tooltip("pcts:Q", title="Null or Non-Finite %"),
                    ],
                )
            )

            return alt.vconcat(chart, bad_values_chart)
        else:
            return chart

    else:  # over is not None
        if show_bad_values:
            df_bad = data.group_by(over).agg(
                pcts=(pl.col(feat).null_count() + pl.col(feat).is_finite().not_().sum()) / pl.len()
            )
            bad_values_chart = (
                alt.Chart(df_bad)
                .mark_bar(opacity=0.7)
                .encode(
                    x=alt.X("pcts:Q", scale=alt.Scale(domain=[0, 1]))
                    .axis(format=".0%")
                    .title("Null or Non-Finite %"),
                    y=alt.Y(f"{over}:N"),
                    tooltip=[
                        alt.Tooltip("pcts:Q", title="Null or Non-Finite %"),
                    ],
                )
            )
            return alt.vconcat(chart, bad_values_chart)
        else:
            return chart
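The chart overlays rules at p5, median, mean and p95, and the bar underneath shows the share of rows that are null or non-finite. A pure-Python sketch of those numbers follows. It drops null and non-finite values before computing the statistics and uses linear-interpolation quantiles; both are choices of this sketch, and Polars' default quantile interpolation may give slightly different values on small samples. quantile and distr_summary are hypothetical names.

```python
import math
from typing import Optional


def quantile(sorted_vals: list[float], q: float) -> float:
    """Linear-interpolation quantile of a sorted, non-empty list."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def distr_summary(values: list[Optional[float]]) -> dict[str, float]:
    """Overlay statistics plus the share of null or non-finite values."""
    finite = sorted(v for v in values if v is not None and math.isfinite(v))
    bad = len(values) - len(finite)  # nulls, NaNs and infinities
    return {
        "p5": quantile(finite, 0.05),
        "median": quantile(finite, 0.5),
        "avg": sum(finite) / len(finite),
        "p95": quantile(finite, 0.95),
        "bad%": bad / len(values),
    }


summary = distr_summary([float(i) for i in range(101)] + [None, float("nan"), float("inf")])
```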

plot_lin_reg(df, x, target, add_bias=False, weights=None, max_points=20000, show_lin_reg_eq=True)

Plots the linear regression line between x and target.

Parameters:

Name Type Description Default
df DataFrame | LazyFrame

Either an eager or lazy Polars DataFrame

required
x str

The predictor variable

required
target str

The target variable

required
add_bias bool

Whether to add a bias (intercept) term to the linear regression

False
weights str | None

Name of the column holding weights for the linear regression

None
max_points int

The maximum number of points displayed. This only affects the number of points on the plot; the linear regression is still fit on the entire dataset.

20000
show_lin_reg_eq bool

Whether to show the linear regression equation at the bottom or not

True

Source code in python/polars_ds/eda/plots.py
def plot_lin_reg(
    df: pl.DataFrame | pl.LazyFrame,
    x: str,
    target: str,
    add_bias: bool = False,
    weights: str | None = None,
    max_points: int = 20_000,
    show_lin_reg_eq: bool = True,
) -> alt.Chart:
    """
    Plots the linear regression line between x and target.

    Parameters
    ----------
    df
        Either an eager or lazy Polars DataFrame
    x
        The predictor variable
    target
        The target variable
    add_bias
        Whether to add bias in the linear regression
    weights
        Weights for the linear regression
    max_points
        The max number of points to be displayed. Notice that this only affects the number of points
        on the plot. The linear regression will still be fit on the entire dataset.
    show_lin_reg_eq
        Whether to show the linear regression equation at the bottom or not
    """

    to_select = [x, target] if weights is None else [x, target, weights]
    temp = df.lazy().select(*to_select)

    xx = pl.col(x)
    yy = pl.col(target)
    # Although using simple_lin_reg might seem to be able to reduce some code here,
    # it adds complexity because of output type and the r2 query.
    # A little bit of code dup is reasonable.
    if add_bias:
        if weights is None:
            x_mean = xx.mean()
            y_mean = yy.mean()
            beta = (xx - x_mean).dot(yy - y_mean) / (xx - x_mean).dot(xx - x_mean)
            alpha = y_mean - beta * x_mean
        else:
            w = pl.col(weights)
            w_sum = w.sum()
            x_wmean = w.dot(xx) / w_sum
            y_wmean = w.dot(yy) / w_sum
            beta = w.dot((xx - x_wmean) * (yy - y_wmean)) / (w.dot((xx - x_wmean).pow(2)))
            alpha = y_wmean - beta * x_wmean
    else:
        if weights is None:
            beta = xx.dot(yy) / xx.dot(xx)
        else:
            w = pl.col(weights)
            beta = w.dot(xx * yy) / w.dot(xx.pow(2))

        alpha = pl.lit(0, dtype=pl.Float64)

    beta, alpha, r2, length = (
        temp.select(
            beta.alias("beta"),
            alpha.alias("alpha"),
            query_r2(yy, xx * beta + alpha).alias("r2"),
            pl.len(),
        )
        .collect()
        .row(0)
    )

    df_need = temp.select(
        xx,
        yy,
        (xx * beta + alpha).alias("y_pred"),
    )
    # Sample down if len(temp) > max_points
    df_sampled = sa.sample(df_need, value=max_points) if length > max_points else df_need.collect()

    x_title = [x]
    if show_lin_reg_eq:
        if add_bias and alpha > 0:
            reg_info = f"y = {beta:.4f} * x + {round(alpha, 4) if add_bias else ''}, r2 = {r2:.4f}"
        elif add_bias and alpha < 0:
            reg_info = (
                f"y = {beta:.4f} * x - {abs(round(alpha, 4)) if add_bias else ''}, r2 = {r2:.4f}"
            )
        else:
            reg_info = f"y = {beta:.4f} * x, r2 = {r2:.4f}"

        x_title.append(reg_info)

    chart = alt.Chart(df_sampled).mark_point().encode(alt.X(x).scale(zero=False), alt.Y(target))
    return chart + chart.mark_line().encode(
        alt.X(x, title=x_title).scale(zero=False),
        alt.Y("y_pred"),
    )
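The regression is solved in closed form for one predictor. With add_bias=True the slope is the (weighted) covariance of x and target over the (weighted) variance of x, and the intercept is alpha = mean(y) - beta * mean(x); with add_bias=False the line goes through the origin, with beta = sum(w*x*y) / sum(w*x^2) and alpha = 0. The pure-Python sketch below follows those formulas, treating the unweighted case as all weights equal to 1. It reports r2 as 1 - SS_res / SS_tot, which is an assumption about what query_r2 computes; fit_line is a hypothetical name.

```python
from typing import Optional, Sequence


def fit_line(
    x: Sequence[float],
    y: Sequence[float],
    add_bias: bool = False,
    w: Optional[Sequence[float]] = None,
) -> tuple[float, float, float]:
    """Return (beta, alpha, r2) for the least-squares line y ~ beta * x + alpha."""
    w = [1.0] * len(x) if w is None else list(w)  # unweighted == all weights 1
    if add_bias:
        w_sum = sum(w)
        x_m = sum(wi * xi for wi, xi in zip(w, x)) / w_sum
        y_m = sum(wi * yi for wi, yi in zip(w, y)) / w_sum
        sxy = sum(wi * (xi - x_m) * (yi - y_m) for wi, xi, yi in zip(w, x, y))
        sxx = sum(wi * (xi - x_m) ** 2 for wi, xi in zip(w, x))
        beta = sxy / sxx
        alpha = y_m - beta * x_m
    else:
        # Regression through the origin: no intercept term.
        sxy = sum(wi * xi * yi for wi, xi, yi in zip(w, x, y))
        sxx = sum(wi * xi * xi for wi, xi in zip(w, x))
        beta, alpha = sxy / sxx, 0.0
    # Coefficient of determination: 1 - SS_res / SS_tot (unweighted).
    y_bar = sum(y) / len(y)
    ss_res = sum((yi - (beta * xi + alpha)) ** 2 for xi, yi in zip(x, y))
    ss_tot = sum((yi - y_bar) ** 2 for yi in y)
    return beta, alpha, 1 - ss_res / ss_tot
```

For example, fit_line([1, 2, 3, 4], [3, 5, 7, 9], add_bias=True) recovers beta = 2 and alpha = 1 with r2 = 1.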

plot_pca(df, features, by, center=True, dim=2, filter_by=None, max_points=10000, **kwargs)

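Creates a scatter plot based on the reduced dimensions via PCA, and colors it by the values of by.

Parameters:

Name Type Description Default
df DataFrame | LazyFrame

Either an eager or lazy Polars DataFrame

required
features List[str]

List of feature names. At least 2 are required.

required
by IntoExpr

Color the 2-D PCA plot by the values in this column

required
center bool

Whether to automatically center the features

True
dim int

Number of principal components. Must be 2 or 3, but only the 2-D plot is implemented at the moment; dim=3 raises NotImplementedError.

2
filter_by Expr | None

An optional boolean expression used to filter rows before running PCA

None
max_points int

The maximum number of points displayed. If the data has more rows than this limit, it is sampled.

10000
kwargs

Anything else will be passed to Altair's encode method

{}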

Source code in python/polars_ds/eda/plots.py
def plot_pca(
    df: pl.DataFrame | pl.LazyFrame,
    features: List[str],
    by: IntoExpr,
    center: bool = True,
    dim: int = 2,
    filter_by: pl.Expr | None = None,
    max_points: int = 10_000,
    **kwargs,
) -> alt.Chart:
    """
    Creates a scatter plot based on the reduced dimensions via PCA, and colors it by `by`.

    Parameters
    ----------
    df
        Either an eager or lazy Polars DataFrame
    features
        List of feature names
    by
        Color the 2-D PCA plot by the values in the column
    center
        Whether to automatically center the features
    dim
        Number of principal components. Must be 2 or 3; only 2 is implemented at the moment.
    filter_by
        A boolean expression
    max_points
        The max number of points to be displayed. If data > this limit, the data will be sampled.
    kwargs
        Anything else that will be passed to Altair encode function
    """
    if len(features) < 2:
        raise ValueError("You must pass >= 2 features.")
    if dim not in (2, 3):
        raise ValueError("Dim must be 2 or 3.")

    frame = df if filter_by is None else df.filter(filter_by)

    temp = frame.select(principal_components(*features, center=center, k=dim).alias("pc"), by)
    df_plot = sa.sample(temp, value=max_points).unnest("pc")

    if dim == 2:
        return (
            alt.Chart(df_plot).mark_circle(size=60).encode(x="pc1", y="pc2", color=by, **kwargs)
        )  # .interactive()
    else:  # 3d
        raise NotImplementedError
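The chart plots the first two principal components. As a dependency-free illustration of what the first component is, the sketch below centers the data (as with center=True), forms the sample covariance matrix and finds its leading eigenvector by power iteration. Power iteration is a swapped-in textbook method chosen for brevity, not necessarily how principal_components computes its output; leading_pc is a hypothetical name, and the sign of the returned vector is arbitrary.

```python
import math


def leading_pc(rows: list[list[float]], iters: int = 500) -> list[float]:
    """Unit vector along the leading principal component of rows (n samples x d features)."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    centered = [[r[j] - means[j] for j in range(d)] for r in rows]
    # Sample covariance matrix (d x d) of the centered data.
    cov = [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)] for i in range(d)]

    def matvec(v: list[float]) -> list[float]:
        return [sum(cov[i][j] * v[j] for j in range(d)) for i in range(d)]

    best, best_var = [0.0] * d, -1.0
    # Power iteration from each basis vector: at least one start is not orthogonal
    # to the top eigenvector, and that run ends with the largest Rayleigh quotient.
    for start in range(d):
        v = [1.0 if j == start else 0.0 for j in range(d)]
        for _ in range(iters):
            w = matvec(v)
            norm = math.sqrt(sum(x * x for x in w))
            if norm == 0.0:
                break
            v = [x / norm for x in w]
        var = sum(a * b for a, b in zip(v, matvec(v)))  # variance of the data along v
        if var > best_var:
            best, best_var = v, var
    return best
```

Projecting each centered row onto the returned vector gives that row's pc1 coordinate; a second component would need an extra step such as deflation, which is omitted here.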

plot_prob_calibration(*, target, score=None, name=None, scores=None, names=None, n_bins=10)

Plots probability calibration of score(s) with respect to the binary target.

Parameters:

Name Type Description Default
target Iterable[int]

The target binary variable

required
score Series | None

The probability score values

None
name str | None

The name of the probability score values. Required when score is given.

None
scores List[Series] | None

If score is None, a list of probability scores; the calibration curves of all of them are drawn in one plot.

None
names List[str] | None

If scores is populated, this must be a list of corresponding score names.

None
n_bins int

Number of quantile bins for the score(s).

10
Source code in python/polars_ds/eda/plots.py
def plot_prob_calibration(
    *,
    target: Iterable[int],
    score: pl.Series | None = None,
    name: str | None = None,
    scores: List[pl.Series] | None = None,
    names: List[str] | None = None,
    n_bins: int = 10,
) -> alt.Chart:
    """
    Plots probability calibration of score(s) with respect to the binary target.

    Parameters
    ----------
    target
        The target binary variable
    score
        The probability score values
    name
        The name of the probability score values
    scores
        If score is None, and scores is a list of probability scores, this will
        generate a plot with all probability calibrations.
    names
        If scores is populated, this must be a list of corresponding score names.
    n_bins
        Number of quantile bins for the score(s).
    """

    if score is not None:
        if name is None:
            raise ValueError("If `score` is not None, then `name` must not be none.")
        else:
            new_dict = {name: pl.Series(values=score)}

    else:  # score is None
        if (scores is None) or (names is None):
            raise ValueError("If `score` is None, then `scores` and `names` must be populated.")

        if hasattr(scores, "__len__") and (hasattr(names, "__len__")):
            if len(scores) != len(names):
                raise ValueError("Input `scores` and `names` must have the same length.")

            new_dict = {n: pl.Series(values=s) for n, s in zip(names, scores)}
        else:
            raise ValueError("Input `scores` and `names` must be iterables with a length.")

    target_series = pl.Series(name="__actual__", values=target)

    if any(len(s) != len(target_series) for s in new_dict.values()):
        raise ValueError("All input `score(s)` and `target` must have the same length.")

    new_dict["__actual__"] = target_series

    df = pl.from_dict(new_dict)
    perfect_line = pl.int_range(1, 100, step=5, eager=True) / 100
    df_line = pl.DataFrame(
        {"mean_predicted_prob": perfect_line, "fraction_of_positives": perfect_line}
    ).with_columns(score=pl.lit(" y=x", dtype=pl.String), __point__=pl.lit(False, dtype=pl.Boolean))

    df_socres = [
        df.select(s, "__actual__")
        .with_columns(
            pl.col(s).qcut(n_bins, labels=[str(i) for i in range(n_bins)]).alias("__qcuts__")
        )
        .group_by("__qcuts__")
        .agg(
            mean_predicted_prob=pl.col(s).mean().cast(pl.Float64),
            fraction_of_positives=pl.col("__actual__").mean().cast(pl.Float64),
        )
        .sort("__qcuts__")
        .select(
            "mean_predicted_prob",
            "fraction_of_positives",
            score=pl.lit(s, dtype=pl.String),
            __point__=pl.lit(False, dtype=pl.Boolean),
        )
        for s in new_dict.keys()
        if s != "__actual__"
    ]

    chart1 = (
        alt.Chart(df_line)
        .mark_line(point=False)
        .encode(
            x="mean_predicted_prob:Q",
            y="fraction_of_positives:Q",
            color="score:N",
            strokeDash=alt.value([5, 5]),
        )
    )
    selection = alt.selection_point(fields=["score"], bind="legend")
    chart2 = (
        alt.Chart(pl.concat(df_socres))
        .mark_line(point=True)
        .encode(
            x="mean_predicted_prob:Q",
            y="fraction_of_positives:Q",
            color="score:N",
            tooltip=["mean_predicted_prob", "fraction_of_positives"],
            opacity=alt.condition(selection, alt.value(0.8), alt.value(0.1)),
        )
        .add_params(selection)
    )

    return chart1 + chart2
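For each score, the rows are cut into n_bins quantile bins; in every bin the chart plots the mean predicted probability against the observed fraction of positives, and a well-calibrated score tracks the dashed y = x line. The pure-Python sketch below computes those points, approximating qcut by sorting the scores and splitting them into equal-count groups, which can place tied scores differently than qcut does. calibration_points is a hypothetical name.

```python
def calibration_points(
    scores: list[float], target: list[int], n_bins: int = 10
) -> list[tuple[float, float]]:
    """(mean predicted probability, fraction of positives) per equal-count score bin."""
    if len(scores) != len(target):
        raise ValueError("scores and target must have the same length")
    if n_bins < 1:
        raise ValueError("n_bins must be positive")
    pairs = sorted(zip(scores, target))
    n = len(pairs)
    points = []
    for b in range(n_bins):
        # Bin b covers sorted positions [b * n // n_bins, (b + 1) * n // n_bins).
        chunk = pairs[b * n // n_bins:(b + 1) * n // n_bins]
        if not chunk:
            continue
        mean_pred = sum(s for s, _ in chunk) / len(chunk)
        frac_pos = sum(t for _, t in chunk) / len(chunk)
        points.append((mean_pred, frac_pos))
    return points
```

A bin whose mean predicted probability is 0.25 and whose observed positive rate is also 0.25 sits exactly on the y = x reference line.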

plot_roc_auc(*, target, pred=None, name=None, preds=None, names=None, show_auc=True, estimator_name='', n_decimals=4, **kwargs)

Parameters:

Name Type Description Default
target Iterable[int]

A column which has the actual binary target information

required
pred Series | None

The prediction probability variable

None
name str | None

The name for the prediction. Required when pred is given.

None
preds List[Series] | None

A list of prediction probability variables, used when pred is None

None
names List[str] | None

The names for the predictions in preds; must have the same length as preds

None
show_auc bool

Whether to show the AUC value or not

True
estimator_name str

The name of the estimator to be shown in the plot

''
n_decimals int

Number of decimal places the AUC is rounded to when show_auc is True

4
kwargs

Other keyword arguments passed to Altair's mark_line

{}
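The rest of the function body is cut off in this section, so its exact computation is not reproduced here. For reference, the source below starts from a one-row frame with tpr = fpr = 0; a common way to get the remaining (FPR, TPR) points of an ROC curve is to sweep a threshold down through the sorted predictions, and the AUC is then the trapezoid-rule area under those points. This is a generic textbook method, not necessarily what plot_roc_auc does; roc_points and auc are hypothetical names.

```python
def roc_points(target: list[int], pred: list[float]) -> list[tuple[float, float]]:
    """(fpr, tpr) pairs, one per distinct prediction value, from the highest threshold down."""
    pos = sum(target)
    neg = len(target) - pos
    if pos == 0 or neg == 0:
        raise ValueError("target must contain both classes")
    pairs = sorted(zip(pred, target), reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    for i, (p, t) in enumerate(pairs):
        tp += t
        fp += 1 - t
        # Emit a point only after the last of a run of tied predictions.
        if i == len(pairs) - 1 or pairs[i + 1][0] != p:
            points.append((fp / neg, tp / pos))
    return points


def auc(points: list[tuple[float, float]]) -> float:
    """Trapezoid-rule area under the (fpr, tpr) curve."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


pts = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

Here auc(pts) is 0.75; tied predictions produce a single diagonal step, which is why a constant predictor scores 0.5.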
Source code in python/polars_ds/eda/plots.py
def plot_roc_auc(
    *,
    target: Iterable[int],
    pred: pl.Series | None = None,
    name: str | None = None,
    preds: List[pl.Series] | None = None,
    names: List[str] | None = None,
    show_auc: bool = True,
    estimator_name: str = "",
    n_decimals: int = 4,
    **kwargs,
) -> alt.Chart:
    """
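    Plots ROC curves for one or more binary classifiers.

    Pass either a single prediction via `pred` and `name`, or several via `preds`
    and `names`. If `pred` is given, `preds` and `names` are ignored.

    Parameters
    ----------
    target
        The actual binary target (0/1)
    pred
        Predicted probabilities for a single model. Requires `name`.
    name
        Legend label for `pred`
    preds
        Predicted probabilities for several models. Used only when `pred` is None.
    names
        Legend labels for `preds`, in the same order
    show_auc
        Whether to append the AUC to each legend label
    estimator_name
        The name of the estimator to be shown in the plot. Not used by the
        current implementation.
    n_decimals
        Number of decimal places to round the AUC to when `show_auc` is True
    kwargs
        Other keyword arguments passed to Altair's `mark_line` for the ROC curves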
    """

    if pred is not None:
        if name is None:
            raise ValueError("If `pred` is not None, then `name` must not be None.")
        else:
            new_dict = {name: pl.Series(values=pred)}

    else:  # pred is None
        if (preds is None) or (names is None):
            raise ValueError("If `pred` is None, then `preds` and `names` must be populated.")

        if hasattr(preds, "__len__") and (hasattr(names, "__len__")):
            if len(preds) != len(names):
                raise ValueError("Input `preds` and `names` must have the same length.")

            new_dict = {n: pl.Series(values=s) for n, s in zip(names, preds)}
        else:
            raise ValueError("Input `preds` and `names` must be iterables with a length.")

    target_series = pl.Series(name="__actual__", values=target)

    if any(len(s) != len(target_series) for s in new_dict.values()):
        raise ValueError("All inputs in `pred`/`preds` and `target` must have the same length.")

    pred_names = list(new_dict.keys())
    new_dict["__actual__"] = target_series
    df_tmp = pl.from_dict(new_dict)
    dfs = []

    for p in pred_names:
        zero = pl.DataFrame(
            {
                "tpr": [0.0],
                "fpr": [0.0],
            },
            schema={
                "tpr": pl.Float64,
                "fpr": pl.Float64,
            },
        )
        tpr_fpr = (
            df_tmp.select(tpr_fpr=query_tpr_fpr("__actual__", p).reverse())
            .unnest("tpr_fpr")
            .select("tpr", "fpr")
        )

        text = p
        if show_auc:
            auc = tpr_fpr.select(integrate_trapz("tpr", "fpr")).item(0, 0)
            text += f" (AUC = {round(auc, n_decimals)})"

        dfs.append(pl.concat([zero, tpr_fpr]).with_columns(name=pl.lit(text, dtype=pl.String)))

    # Reference diagonal (random classifier), spanning (0, 0) to (1, 1).
    perfect_line = pl.int_range(0, 101, step=5, eager=True) / 100
    df_line = pl.DataFrame({"fpr": perfect_line, "tpr": perfect_line}).with_columns(
        name=pl.lit(" y=x", dtype=pl.String),
    )
    chart1 = (
        alt.Chart(df_line)
        .mark_line()
        .encode(
            x=alt.X("fpr", title="False Positive Rate"),
            y=alt.Y("tpr", title="True Positive Rate"),
            color="name:N",
            strokeDash=alt.value([5, 5]),
        )
    )
    selection = alt.selection_point(fields=["name"], bind="legend")
    chart2 = (
        alt.Chart(pl.concat(dfs))
        .mark_line(interpolate="step", **kwargs)
        .encode(
            x=alt.X("fpr", title="False Positive Rate"),
            y=alt.Y("tpr", title="True Positive Rate"),
            color="name:N",
            opacity=alt.condition(selection, alt.value(0.8), alt.value(0.1)),
        )
        .add_params(selection)
    )

    return chart1 + chart2
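The AUC in each legend label comes from integrating TPR over FPR with the trapezoidal rule (integrate_trapz("tpr", "fpr")), applied to the curve returned by query_tpr_fpr. The plain-Python sketch below shows the same idea. `roc_points` and `trapz_auc` are hypothetical helpers, not polars_ds APIs. The exact threshold and tie handling inside query_tpr_fpr may differ.

```python
from typing import List, Tuple


def roc_points(target: List[int], score: List[float]) -> List[Tuple[float, float]]:
    """(fpr, tpr) pairs from (0, 0) to (1, 1), one per distinct threshold.

    Assumes both classes are present in `target` (otherwise division by zero).
    """
    pos = sum(target)
    neg = len(target) - pos
    # Sweep thresholds from the highest score down.
    pairs = sorted(zip(score, target), key=lambda t: t[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        thr = pairs[i][0]
        # Consume every row tied at this threshold before emitting a point.
        while i < len(pairs) and pairs[i][0] == thr:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))
    return points


def trapz_auc(points: List[Tuple[float, float]]) -> float:
    """Trapezoidal rule: integrate tpr (y) over fpr (x)."""
    return sum(
        (x1 - x0) * (y0 + y1) / 2
        for (x0, y0), (x1, y1) in zip(points, points[1:])
    )
```

For example, trapz_auc(roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])) evaluates to 0.75. The chart draws the curve with interpolate="step", whereas the trapezoidal rule joins consecutive points with straight segments. The two differ only where tied scores move FPR and TPR at the same time.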