spalah.dataframe

SchemaComparer

The SchemaComparer compares two Spark dataframe schemas and finds matched and unmatched columns.

Source code in spalah/dataframe/dataframe.py
class SchemaComparer:
    """
    The SchemaComparer is to compare two spark dataframe schemas and find matched
    and not matched columns.
    """

    def __init__(
        self, source_schema: T.StructType, target_schema: T.StructType
    ) -> None:
        """Constructs all the necessary input attributes for the SchemaComparer object.

        Args:
            source_schema (T.StructType): source schema to match
            target_schema (T.StructType): target schema to match

        Examples:
            >>> from spalah.dataframe import SchemaComparer
            >>> schema_comparer = SchemaComparer(
            ...     source_schema = df_source.schema,
            ...     target_schema = df_target.schema
            ... )
        """
        self._source = self.__import_schema(source_schema)
        self._target = self.__import_schema(target_schema)

        self.matched: List[tuple] = list()
        """List of matched columns"""
        self.not_matched: List[tuple] = list()
        """The list of not matched columns"""

    def __import_schema(self, input_schema: T.StructType) -> Set[tuple]:
        """Import StructType as the flatten set of tuples: (column_name, data_type)

        Args:
            input_schema (T.StructType): Schema to process

        Raises:
            TypeError: if the input schema is a DataFrame
            TypeError: if the input schema is not a StructType

        Returns:
            Set[tuple]: Set of tuples: (column_name, data_type)
        """

        if isinstance(input_schema, DataFrame):
            raise TypeError(
                "One of 'source_schema or 'target_schema' passed as a DataFrame. "
                "Use DataFrame.schema instead"
            )
        elif not isinstance(input_schema, T.StructType):
            raise TypeError(
                "Parameters 'source_schema and 'target_schema' must have a type: StructType"
            )

        return set(flatten_schema(input_schema, True))

    def __match_by_name_and_type(
        self, source: Set[tuple] = set(), target: Set[tuple] = set()
    ) -> Set[tuple]:
        """Matches columns in source and target schemas by name and data type

        Args:
            source (Set[tuple], optional): Flattened source schema. Defaults to set().
            target (Set[tuple], optional): Flattened target schema. Defaults to set().

        Returns:
            Set[tuple]: Fully matched columns as a set of tuples: (column_name, data_type)
        """

        # If source and target are not provided, use class attributes as the input
        _source = self._source if not source else source
        _target = self._target if not target else target

        result = _source & _target

        # Remove matched values of case 1 from further processing
        self.__remove_matched_by_name_and_type(result)

        if not (source and target):
            self.__populate_matched(result)

        return result

    def __remove_matched_by_name_and_type(self, subtract_value: Set[tuple]) -> None:
        """Removes fully matched columns from the further processing

        Args:
            subtract_value (Set[tuple]): Set of matched columns
        """

        self._source = self._source - subtract_value
        self._target = self._target - subtract_value

    def __remove_matched_by_name(self, subtract_value: Set[tuple]) -> None:
        """Removes matched by name columns from the further processing

        Args:
            subtract_value (Set[tuple]): Set of matched columns
        """

        def _remove(input_value: Set[tuple], subtract_value: Set[tuple]) -> Set[tuple]:
            """Internal helper for removal of Tuples by the first member"""
            return {
                (x, y)
                for (x, y) in input_value
                if x.lower() not in [z[0].lower() for z in subtract_value]
            }

        self._source = _remove(self._source, subtract_value)  # type: ignore
        self._target = _remove(self._target, subtract_value)  # type: ignore

    def __lower_column_names(self, base_value: Set[tuple]) -> Set[tuple]:
        """Lower-case all column names of the input set

        Args:
            base_value (Set[tuple]): Input set of columns

        Returns:
            Set[tuple]: Output set of columns with lower-case column names
        """
        return {(x.lower(), y) for (x, y) in base_value}

    def __match_by_name_type_excluding_case(self) -> None:
        """Matches columns in source and target schemas by name and data type
        without taking into account column name case
        """

        _source_lowered = self.__lower_column_names(self._source)
        _target_lowered = self.__lower_column_names(self._target)

        result = self.__match_by_name_and_type(_source_lowered, _target_lowered)

        # Remove matched values of case 2 from further processing
        self.__remove_matched_by_name(result)

        self.__populate_not_matched(
            result,
            "The column exists in source and target schemas but it's name is case-mismatched",
        )

    def __match_by_name_but_not_type(self) -> None:
        """Matches columns in source and target schemas only by column name"""

        x = dict(self._source)  # type: ignore
        y = dict(self._target)  # type: ignore

        result = {(k, f"{x[k]} <=> {y[k]}") for k in x if k in y and x[k] != y[k]}

        # Remove matched values of case 3 from further processing
        self.__remove_matched_by_name(result)  # type: ignore

        self.__populate_not_matched(
            result,  # type: ignore
            "The column exists in source and target schemas but it is not matched by a data type",
        )

    def __process_remaining_non_matched_columns(self) -> None:
        """Process remaining not matched columns"""

        self.__populate_not_matched(
            self._source, "The column exists only in the source schema"
        )

        self.__populate_not_matched(
            self._target, "The column exists only in the target schema"
        )

        self.__remove_matched_by_name(self._source)
        self.__remove_matched_by_name(self._target)

    def __populate_matched(self, input_value: Set[tuple]) -> None:
        """Populate class property 'matched' with a list of fully matched columns

        Args:
            input_value (Set[tuple]): The set of tuples with a list of column names and data types
        """

        for match in input_value:
            self.matched.append(MatchedColumn(name=match[0], data_type=match[1]))

    def __populate_not_matched(self, input_value: Set[tuple], reason: str) -> None:
        """Populate class property 'not_matched' with a list of columns that didn't match for some
        reason with included an actual reason

        Args:
            input_value (Set[tuple]): The set of tuples with a list of column names and data types
            reason (str): Reason for not match

        """

        for match in input_value:
            self.not_matched.append(
                NotMatchedColumn(name=match[0], data_type=match[1], reason=reason)
            )

    def compare(self) -> None:
        """
        Compares the source and target schemas and populates properties `matched` and `not_matched`

        Examples:
            >>> # instantiate schema_comparer firstly, see example above
            >>> schema_comparer.compare()

            Get list of all columns that are matched by name and type:
            >>> schema_comparer.matched
            [MatchedColumn(name='Address.Line1',  data_type='StringType')]

            Get unmatched columns:
            >>> schema_comparer.not_matched
            [
                NotMatchedColumn(
                    name='name',
                    data_type='StringType',
                    reason="The column exists in source and target schemas but it's name is \
case-mismatched"
                ),
                NotMatchedColumn(
                    name='Address.Line2',
                    data_type='StringType',
                    reason='The column exists only in the source schema'
                )
            ]
        """

        # Case 1: find columns that are matched by name and type and remove them
        # from further processing
        self.__match_by_name_and_type()

        # Case 2: find columns whose names match only when case is ignored: ID <-> Id
        self.__match_by_name_type_excluding_case()

        # Case 3: Find columns matched by name, but not by data type
        self.__match_by_name_but_not_type()

        # Case 4: Find columns that exist only in the source or target
        self.__process_remaining_non_matched_columns()

matched: List[tuple] = list() instance-attribute

List of matched columns

not_matched: List[tuple] = list() instance-attribute

The list of not matched columns
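
Both properties are plain Python lists, so they can be filtered like any other list once compare() has run. A minimal sketch, assuming a schema_comparer instance as in the examples below:

>>> only_in_source = [
...     c for c in schema_comparer.not_matched
...     if c.reason == 'The column exists only in the source schema'
... ]
>>> [c.name for c in only_in_source]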

__init__(source_schema, target_schema)

Constructs all the necessary input attributes for the SchemaComparer object.

Parameters:

    source_schema (T.StructType): source schema to match (required)
    target_schema (T.StructType): target schema to match (required)

Examples:

>>> from spalah.dataframe import SchemaComparer
>>> schema_comparer = SchemaComparer(
...     source_schema = df_source.schema,
...     target_schema = df_target.schema
... )
Source code in spalah/dataframe/dataframe.py
def __init__(
    self, source_schema: T.StructType, target_schema: T.StructType
) -> None:
    """Constructs all the necessary input attributes for the SchemaComparer object.

    Args:
        source_schema (T.StructType): source schema to match
        target_schema (T.StructType): target schema to match

    Examples:
        >>> from spalah.dataframe import SchemaComparer
        >>> schema_comparer = SchemaComparer(
        ...     source_schema = df_source.schema,
        ...     target_schema = df_target.schema
        ... )
    """
    self._source = self.__import_schema(source_schema)
    self._target = self.__import_schema(target_schema)

    self.matched: List[tuple] = list()
    """List of matched columns"""
    self.not_matched: List[tuple] = list()
    """The list of not matched columns"""

compare()

Compares the source and target schemas and populates properties matched and not_matched

Examples:

>>> # instantiate schema_comparer firstly, see example above
>>> schema_comparer.compare()

Get list of all columns that are matched by name and type:

>>> schema_comparer.matched
[MatchedColumn(name='Address.Line1',  data_type='StringType')]

Get unmatched columns:

>>> schema_comparer.not_matched
[
    NotMatchedColumn(
        name='name',
        data_type='StringType',
        reason="The column exists in source and target schemas but it's name is case-mismatched"
    ),
    NotMatchedColumn(
        name='Address.Line2',
        data_type='StringType',
        reason='The column exists only in the source schema'
    )
]
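
A minimal end-to-end sketch tying the constructor and compare() together; the toy dataframes are illustrative and assume an active SparkSession named spark:

>>> from spalah.dataframe import SchemaComparer
>>> df_source = spark.sql('SELECT 1 AS ID, "John" AS Name')
>>> df_target = spark.sql('SELECT 1 AS ID, "John" AS name')
>>> schema_comparer = SchemaComparer(
...     source_schema=df_source.schema,
...     target_schema=df_target.schema
... )
>>> schema_comparer.compare()
>>> schema_comparer.matched      # the fully matched 'ID' column
>>> schema_comparer.not_matched  # 'Name' vs 'name' reported as a case mismatch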
Source code in spalah/dataframe/dataframe.py
    def compare(self) -> None:
        """
        Compares the source and target schemas and populates properties `matched` and `not_matched`

        Examples:
            >>> # instantiate schema_comparer firstly, see example above
            >>> schema_comparer.compare()

            Get list of all columns that are matched by name and type:
            >>> schema_comparer.matched
            [MatchedColumn(name='Address.Line1',  data_type='StringType')]

            Get unmatched columns:
            >>> schema_comparer.not_matched
            [
                NotMatchedColumn(
                    name='name',
                    data_type='StringType',
                    reason="The column exists in source and target schemas but it's name is \
case-mismatched"
                ),
                NotMatchedColumn(
                    name='Address.Line2',
                    data_type='StringType',
                    reason='The column exists only in the source schema'
                )
            ]
        """

        # Case 1: find columns that are matched by name and type and remove them
        # from further processing
        self.__match_by_name_and_type()

        # Case 2: find columns whose names match only when case is ignored: ID <-> Id
        self.__match_by_name_type_excluding_case()

        # Case 3: Find columns matched by name, but not by data type
        self.__match_by_name_but_not_type()

        # Case 4: Find columns that exist only in the source or target
        self.__process_remaining_non_matched_columns()

flatten_schema(schema, include_datatype=False, column_prefix=None)

Parses a Spark dataframe schema and returns the list of columns. If the schema is nested, the columns are flattened.

Parameters:

    schema (StructType): Input dataframe schema (required)
    include_datatype (bool, optional): Flag to include column data types. Defaults to False.
    column_prefix (str, optional): Column name prefix. Defaults to None.

Returns:

    list: The list of (flattened) column names

Examples:

>>> from spalah.dataframe import flatten_schema
>>> flatten_schema(schema=df_complex_schema.schema)

returns the list of columns; nested columns are flattened:

>>> ['ID', 'Name', 'Address.Line1', 'Address.Line2']
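
Passing include_datatype=True returns (column_name, data_type) tuples instead of bare names; the exact data-type strings shown here are illustrative and depend on the pyspark version:

>>> flatten_schema(schema=df_complex_schema.schema, include_datatype=True)

returns something like:

>>> [('ID', 'IntegerType()'), ('Name', 'StringType()'), ('Address.Line1', 'StringType()'), ('Address.Line2', 'StringType()')]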
Source code in spalah/dataframe/dataframe.py
def flatten_schema(
    schema: T.StructType,
    include_datatype: bool = False,
    column_prefix: Optional[str] = None,
) -> list:
    """Parses spark dataframe schema and returns the list of columns
    If the schema is nested, the columns are flattened

    Args:
        schema (StructType): Input dataframe schema
        include_datatype (bool, optional): Flag to include column data types. Defaults to False.
        column_prefix (str, optional): Column name prefix. Defaults to None.

    Returns:
        The list of (flattened) column names

    Examples:
        >>> from spalah.dataframe import flatten_schema
        >>> flatten_schema(schema=df_complex_schema.schema)

        returns the list of columns; nested columns are flattened:
        >>> ['ID', 'Name', 'Address.Line1', 'Address.Line2']
    """

    if not isinstance(schema, T.StructType):
        raise TypeError("Parameter schema must be a StructType")

    columns = []

    for column in schema.fields:
        if column_prefix:
            name = column_prefix + "." + column.name
        else:
            name = column.name

        column_data_type = column.dataType

        if isinstance(column_data_type, T.ArrayType):
            column_data_type = column_data_type.elementType

        if isinstance(column_data_type, T.StructType):
            columns += flatten_schema(
                column_data_type,
                include_datatype=include_datatype,
                column_prefix=name,
            )
        else:
            if include_datatype:
                result = name, str(column_data_type)
            else:
                result = name

            columns.append(result)

    return columns

script_dataframe(input_dataframe, suppress_print_output=True)

Generates a script to recreate the dataframe. The script includes the schema and the data.

Parameters:

    input_dataframe (DataFrame): Input spark dataframe (required)
    suppress_print_output (bool, optional): Disable prints to console. Defaults to True.

Raises:

    ValueError: when the dataframe is too large (by default > 20 rows)

Returns:

    str: The script to recreate the dataframe

Examples:

>>> from spalah.dataframe import script_dataframe
>>> script = script_dataframe(input_dataframe=df)
>>> print(script)
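
The generated script is plain Python that can be pasted into another notebook or a test; based on the source below, it roughly has the following shape (the __data and __schema literals here are placeholders for the scripted rows and schema):

    from pyspark.sql import Row
    import datetime
    from decimal import Decimal
    from pyspark.sql.types import *

    # Scripted data and schema:
    __data = [Row(ID=1, Name='John')]               # placeholder: the collected rows
    __schema = {"type": "struct", "fields": []}     # placeholder: the schema as JSON

    outcome_dataframe = spark.createDataFrame(__data, StructType.fromJson(__schema))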
Source code in spalah/dataframe/dataframe.py
def script_dataframe(
    input_dataframe: DataFrame, suppress_print_output: bool = True
) -> str:
    """Generate a script to recreate the dataframe
    The script includes the schema and the data

    Args:
        input_dataframe (DataFrame): Input spark dataframe
        suppress_print_output (bool, optional): Disable prints to console. \
            Defaults to True.

    Raises:
        ValueError: when the dataframe is too large (by default > 20 rows)

    Returns:
        The script to recreate the dataframe

    Examples:
        >>> from spalah.dataframe import script_dataframe
        >>> script = script_dataframe(input_dataframe=df)
        >>> print(script)
    """

    MAX_ROWS_IN_SCRIPT = 20

    __dataframe = input_dataframe

    if __dataframe.count() > MAX_ROWS_IN_SCRIPT:
        raise ValueError(
            "This method is limited to script up "
            f"to {MAX_ROWS_IN_SCRIPT} row(s) per call"
        )

    __schema = input_dataframe.schema.jsonValue()

    __script_lines = [
        "from pyspark.sql import Row",
        "import datetime",
        "from decimal import Decimal",
        "from pyspark.sql.types import *",
        "",
        "# Scripted data and schema:",
        f"__data = {pformat(__dataframe.collect())}",
        "",
        f"__schema = {__schema}",
        "",
        "outcome_dataframe = spark.createDataFrame(__data, StructType.fromJson(__schema))",
    ]

    __final_script = "\n".join(__script_lines)

    if not suppress_print_output:
        print("#", "=" * 80)
        print(
            "# IMPORTANT!!! REMOVE PII DATA BEFORE RE-CREATING IT IN NON-PRODUCTION ENVIRONMENTS",
            " " * 3,
            "#",
        )
        print("#", "=" * 80)
        print("")
        print(__final_script)

    return __final_script

slice_dataframe(input_dataframe, columns_to_include=None, columns_to_exclude=None, nullify_only=False, generate_sql=False, debug=False)

Processes a flat or nested dataframe schema by slicing the schema or nullifying columns.

Parameters:

    input_dataframe (DataFrame): Input dataframe (required)
    columns_to_include (Optional[List]): Columns that must remain in the dataframe unchanged. Defaults to None.
    columns_to_exclude (Optional[List]): Columns that must be removed (or nullified). Defaults to None.
    nullify_only (bool, optional): Nullify columns instead of removing them. Defaults to False.
    generate_sql (bool, optional): Return the resulting projection as a SQL statement instead of a dataframe. Defaults to False.
    debug (bool, optional): For extra debug output. Defaults to False.

Raises:

    TypeError: If 'columns_to_include' or 'columns_to_exclude' is not a list
    ValueError: If the included columns fully overlap with the excluded ones, leaving nothing to return

Returns:

    DataFrame: The processed dataframe (or the generated SQL string when generate_sql=True)

Examples:

>>> from spalah.dataframe import slice_dataframe
>>> df = spark.sql(
...         'SELECT 1 as ID, "John" AS Name,
...         struct("line1" AS Line1, "line2" AS Line2) AS Address'
...     )
>>> df_sliced = slice_dataframe(
...     input_dataframe=df,
...     columns_to_include=["Name", "Address"],
...     columns_to_exclude=["Address.Line2"]
... )

As a result, the dataframe will contain only the columns Name and Address.Line1, because Name and Address are included and the nested element Address.Line2 is excluded.

>>> df_sliced.printSchema()
root
|-- Name: string (nullable = false)
|-- Address: struct (nullable = false)
|    |-- Line1: string (nullable = false)
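
The same slicing can also be emitted as SQL text or applied by nullification instead of removal; a minimal sketch of both options, reusing df from the example above (the exact SQL text depends on the schema):

>>> sql_text = slice_dataframe(
...     input_dataframe=df,
...     columns_to_include=["Name", "Address"],
...     columns_to_exclude=["Address.Line2"],
...     generate_sql=True
... )
>>> print(sql_text)  # a SELECT ... FROM input_dataframe statement

>>> df_nullified = slice_dataframe(
...     input_dataframe=df,
...     columns_to_exclude=["Address.Line2"],
...     nullify_only=True
... )
>>> # Address.Line2 stays in the schema but its values are set to NULL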
Source code in spalah/dataframe/dataframe.py
def slice_dataframe(
    input_dataframe: DataFrame,
    columns_to_include: Optional[List] = None,
    columns_to_exclude: Optional[List] = None,
    nullify_only: bool = False,
    generate_sql: bool = False,
    debug: bool = False,
) -> DataFrame:
    """Process flat or nested schema of the dataframe by slicing the schema
    or nullifying columns

    Args:
        input_dataframe (DataFrame): Input dataframe
        columns_to_include (Optional[List]): Columns that must remain in the dataframe unchanged
        columns_to_exclude (Optional[List]): Columns that must be removed (or nullified)
        nullify_only (bool, optional): Nullify columns instead of removing them. Defaults to False
        generate_sql (bool, optional): Return the resulting projection as a SQL statement instead of a dataframe. Defaults to False
        debug (bool, optional): For extra debug output. Defaults to False.

    Raises:
        TypeError: If 'columns_to_include' or 'columns_to_exclude' is not a list
        ValueError: If the included columns fully overlap with the excluded ones, leaving nothing to return

    Returns:
        DataFrame: The processed dataframe

    Examples:
        >>> from spalah.dataframe import slice_dataframe
        >>> df = spark.sql(
        ...         'SELECT 1 as ID, "John" AS Name,
        ...         struct("line1" AS Line1, "line2" AS Line2) AS Address'
        ...     )
        >>> df_sliced = slice_dataframe(
        ...     input_dataframe=df,
        ...     columns_to_include=["Name", "Address"],
        ...     columns_to_exclude=["Address.Line2"]
        ... )

        As a result, the dataframe will contain only the columns `Name` and `Address.Line1` \
            because `Name` and `Address` are included and the nested element `Address.Line2` is \
            excluded
        >>> df_sliced.printSchema()
        root
        |-- Name: string (nullable = false)
        |-- Address: struct (nullable = false)
        |    |-- Line1: string (nullable = false)
    """

    projection = []
    spark = SparkSession.getActiveSession()

    # Verification of input parameters:

    if input_dataframe and not isinstance(input_dataframe, DataFrame):
        raise TypeError("input_dataframe must be a dataframe")
    elif input_dataframe and isinstance(input_dataframe, DataFrame):
        _schema = input_dataframe.schema
        _table_identifier = "input_dataframe"
        _input_dataframe = input_dataframe

    if not columns_to_include:
        columns_to_include = []

    if not columns_to_exclude:
        columns_to_exclude = []

    if not (type(columns_to_include) is list and type(columns_to_exclude) is list):
        raise TypeError(
            "The type of parameters 'columns_to_include', 'columns_to_exclude' "
            "must be a list"
        )

    if not all(
        isinstance(item, str) for item in columns_to_include + columns_to_exclude
    ):
        raise TypeError(
            "Members of 'columns_to_include' and 'columns_to_exclude' "
            "must be a string"
        )

    if debug:
        print("The list of columns to include:")
        pprint(columns_to_include)

        print("The list of columns to exclude:")
        pprint(columns_to_exclude)

        if nullify_only:
            print("Columns to nullify in the final projection:")
        else:
            print("Columns to include into the final projection:")

    # lower-case items for making further filtering case-insensitive
    columns_to_include = [item.lower() for item in columns_to_include]
    columns_to_exclude = [item.lower() for item in columns_to_exclude]

    for field in _schema.fields:
        node_result = __process_schema_node(
            node=field,
            columns_to_include=columns_to_include,
            columns_to_exclude=columns_to_exclude,
            nullify_only=nullify_only,
            debug=debug,
        )

        if node_result is not None:
            projection.append(node_result)

    if not projection:
        raise ValueError(
            "At least one column should be listed in the "
            + "columns_to_include/columns_to_exclude attributes "
            + "and included column should not directly overlap with excluded one"
        )

    if generate_sql:
        delimiter = ", \n"
        result = f"SELECT \n{delimiter.join( projection)} \nFROM {_table_identifier}"
    else:
        result = _input_dataframe.selectExpr(*projection)

    return result