Skip to content

spalah.dataframe

SchemaComparer(source_schema, target_schema)

The SchemaComparer compares two Spark dataframe schemas and finds matched and not matched columns.

Constructs all the necessary input attributes for the SchemaComparer object.

Parameters:

Name Type Description Default
source_schema StructType

source schema to match

required
target_schema StructType

target schema to match

required

Examples:

>>> from spalah.dataframe import SchemaComparer
>>> schema_comparer = SchemaComparer(
...     source_schema = df_source.schema,
...     target_schema = df_target.schema
... )
Source code in spalah/dataframe/dataframe.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
def __init__(self, source_schema: T.StructType, target_schema: T.StructType) -> None:
    """Initialize a SchemaComparer with the two schemas to be compared.

    Args:
        source_schema (T.StructType): source schema to match
        target_schema (T.StructType): target schema to match

    Examples:
        >>> from spalah.dataframe import SchemaComparer
        >>> schema_comparer = SchemaComparer(
        ...     source_schema = df_source.schema,
        ...     target_schema = df_target.schema
        ... )
    """
    # Normalize both inputs through the internal import helper.
    self._source = self.__import_schema(source_schema)
    self._target = self.__import_schema(target_schema)

    # Result containers; populated later by compare().
    self.matched: List[tuple] = []
    """List of matched columns"""
    self.not_matched: List[tuple] = []
    """The list of not matched columns"""

matched = list() instance-attribute

List of matched columns

not_matched = list() instance-attribute

The list of not matched columns

compare()

Compares the source and target schemas and populates properties matched and not_matched

Examples:

>>> # instantiate schema_comparer firstly, see example above
>>> schema_comparer.compare()

Get list of all columns that are matched by name and type:

>>> schema_comparer.matched
[MatchedColumn(name='Address.Line1',  data_type='StringType')]

Get unmatched columns:

>>> schema_comparer.not_matched
[
    NotMatchedColumn(
        name='name',
        data_type='StringType',
        reason="The column exists in source and target schemas but it's name is case-mismatched"
    ),
    NotMatchedColumn(
        name='Address.Line2',
        data_type='StringType',
        reason='The column exists only in the source schema'
    )
]
Source code in spalah/dataframe/dataframe.py
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
    def compare(self) -> None:
        """
        Compares the source and target schemas and populates properties `matched` and `not_matched`

        Examples:
            >>> # instantiate schema_comparer firstly, see example above
            >>> schema_comparer.compare()

            Get list of all columns that are matched by name and type:
            >>> schema_comparer.matched
            [MatchedColumn(name='Address.Line1',  data_type='StringType')]

            Get unmatched columns:
            >>> schema_comparer.not_matched
            [
                NotMatchedColumn(
                    name='name',
                    data_type='StringType',
                    reason="The column exists in source and target schemas but it's name is \
case-mismatched"
                ),
                NotMatchedColumn(
                    name='Address.Line2',
                    data_type='StringType',
                    reason='The column exists only in the source schema'
                )
            ]
        """

        # Case 1: find columns that are matched by name and type and remove them
        # from further processing
        self.__match_by_name_and_type()

        # Case 2: find columns that match mismatched by name due to case: ID <-> Id
        self.__match_by_name_type_excluding_case()

        # Case 3: Find columns matched by name, but not by data type
        self.__match_by_name_but_not_type()

        # Case 4: Find columns that exists only in the source or target
        self.__process_remaining_non_matched_columns()

flatten_schema(schema, include_datatype=False, column_prefix=None)

Parses a Spark dataframe schema and returns the list of columns. If the schema is nested, the columns are flattened.

Parameters:

Name Type Description Default
schema StructType

Input dataframe schema

required
include_datatype bool

Flag to include column data types. Defaults to False.

False
column_prefix str

Column name prefix. Defaults to None.

None

Returns:

Type Description
list

The list of (flattened) column names

Examples:

>>> from spalah.dataframe import flatten_schema
>>> flatten_schema(schema=df_complex_schema.schema)

returns the list of columns, nested are flattened:

>>> ['ID', 'Name', 'Address.Line1', 'Address.Line2']
Source code in spalah/dataframe/dataframe.py
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def flatten_schema(
    schema: T.StructType,
    include_datatype: bool = False,
    column_prefix: Optional[str] = None,
) -> list:
    """Parses spark dataframe schema and returns the list of columns
    If the schema is nested, the columns are flattened

    Args:
        schema (StructType): Input dataframe schema
        include_datatype (bool, optional): Flag to include column data types. Defaults to False.
        column_prefix (str, optional): Column name prefix. Defaults to None.

    Raises:
        TypeError: If schema is not a StructType

    Returns:
        The list of (flattened) column names; when include_datatype is True, each
        element is a (name, data_type) tuple instead of a plain name

    Examples:
        >>> from spalah.dataframe import flatten_schema
        >>> flatten_schema(schema=df_complex_schema.schema)

        returns the list of columns, nested are flattened:
        >>> ['ID', 'Name', 'Address.Line1', 'Address.Line2']
    """

    if not isinstance(schema, T.StructType):
        raise TypeError("Parameter schema must be a StructType")

    columns = []

    for column in schema.fields:
        # Qualify nested column names with their parent path: "Parent.Child"
        if column_prefix:
            name = column_prefix + "." + column.name
        else:
            name = column.name

        column_data_type = column.dataType

        # Arrays of structs are flattened through their element type
        if isinstance(column_data_type, T.ArrayType):
            column_data_type = column_data_type.elementType

        if isinstance(column_data_type, T.StructType):
            # Recurse into nested structs, carrying the accumulated prefix
            columns += flatten_schema(
                column_data_type,
                include_datatype=include_datatype,
                column_prefix=name,
            )
        else:
            if include_datatype:
                result = name, str(column_data_type)
            else:
                result = name

            columns.append(result)

    return columns

script_dataframe(input_dataframe, suppress_print_output=True)

Generate a script to recreate the dataframe. The script includes the schema and the data.

Parameters:

Name Type Description Default
input_dataframe DataFrame

Input spark dataframe

required
suppress_print_output bool

Disable prints to console. Defaults to True.

True

Raises:

Type Description
ValueError

when the dataframe is too large (by default > 20 rows)

Returns:

Type Description
str

The script to recreate the dataframe

Examples:

>>> from spalah.dataframe import script_dataframe
>>> script = script_dataframe(input_dataframe=df)
>>> print(script)
Source code in spalah/dataframe/dataframe.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
def script_dataframe(input_dataframe: DataFrame, suppress_print_output: bool = True) -> str:
    """Produce a Python script that recreates the given dataframe (schema and data).

    Args:
        input_dataframe (DataFrame): Input spark dataframe
        suppress_print_output (bool, optional): Disable prints to console. \
            Defaults to True.

    Raises:
        ValueError: when the dataframe is too large (by default > 20 rows)

    Returns:
        The script to recreate the dataframe

    Examples:
        >>> from spalah.dataframe import script_dataframe
        >>> script = script_dataframe(input_dataframe=df)
        >>> print(script)
    """

    MAX_ROWS_IN_SCRIPT = 20

    # Collecting a large dataframe to the driver is unsafe, so scripting is
    # capped at MAX_ROWS_IN_SCRIPT rows.
    if input_dataframe.count() > MAX_ROWS_IN_SCRIPT:
        raise ValueError(
            f"This method is limited to script up to {MAX_ROWS_IN_SCRIPT} row(s) per call"
        )

    schema_as_json = input_dataframe.schema.jsonValue()

    script_lines = [
        "from pyspark.sql import Row",
        "import datetime",
        "from decimal import Decimal",
        "from pyspark.sql.types import *",
        "",
        "# Scripted data and schema:",
        f"__data = {pformat(input_dataframe.collect())}",
        "",
        f"__schema = {schema_as_json}",
        "",
        "outcome_dataframe = spark.createDataFrame(__data, StructType.fromJson(__schema))",
    ]

    generated_script = "\n".join(script_lines)

    if not suppress_print_output:
        # Warning banner precedes the script in console output
        print("#", "=" * 80)
        print(
            "# IMPORTANT!!! REMOVE PII DATA BEFORE RE-CREATING IT IN NON-PRODUCTION ENVIRONMENTS",
            " " * 3,
            "#",
        )
        print("#", "=" * 80)
        print("")
        print(generated_script)

    return generated_script

slice_dataframe(input_dataframe, columns_to_include=None, columns_to_exclude=None, nullify_only=False, generate_sql=False, debug=False)

Process flat or nested schema of the dataframe by slicing the schema or nullifying columns

Parameters:

Name Type Description Default
input_dataframe DataFrame

Input dataframe

required
columns_to_include Optional[List]

Columns that must remain in the dataframe unchanged

None
columns_to_exclude Optional[List]

Columns that must be removed (or nullified)

None
nullify_only bool

Nullify columns instead of removing them. Defaults to False

False
generate_sql bool

Return the projection as a SQL statement string instead of a dataframe. Defaults to False

False
debug bool

For extra debug output. Defaults to False.

False

Raises:

Type Description
TypeError

If 'columns_to_include' or 'columns_to_exclude' is not of type list

ValueError

If the included columns overlay excluded columns, so nothing to return

Returns:

Name Type Description
DataFrame DataFrame | str

The processed dataframe

Examples:

>>> from spalah.dataframe import slice_dataframe
>>> df = spark.sql(
...         'SELECT 1 as ID, "John" AS Name,
...         struct("line1" AS Line1, "line2" AS Line2) AS Address'
...     )
>>> df_sliced = slice_dataframe(
...     input_dataframe=df,
...     columns_to_include=["Name", "Address"],
...     columns_to_exclude=["Address.Line2"]
... )

As a result, the dataframe will contain only the columns Name and Address.Line1, because Name and Address are included and the nested element Address.Line2 is excluded.

>>> df_result.printSchema()
root
|-- Name: string (nullable = false)
|-- Address: struct (nullable = false)
|    |-- Line1: string (nullable = false)
Source code in spalah/dataframe/dataframe.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def slice_dataframe(
    input_dataframe: DataFrame,
    columns_to_include: Optional[List] = None,
    columns_to_exclude: Optional[List] = None,
    nullify_only: bool = False,
    generate_sql: bool = False,
    debug: bool = False,
) -> DataFrame | str:
    """Process flat or nested schema of the dataframe by slicing the schema
    or nullifying columns

    Args:
        input_dataframe (DataFrame): Input dataframe
        columns_to_include (Optional[List]): Columns that must remain in the dataframe unchanged
        columns_to_exclude (Optional[List]): Columns that must be removed (or nullified)
        nullify_only (bool, optional): Nullify columns instead of removing them. Defaults to False
        generate_sql (bool, optional): Return the projection as a SQL statement string \
            instead of a dataframe. Defaults to False
        debug (bool, optional): For extra debug output. Defaults to False.

    Raises:
        TypeError: If 'input_dataframe' is not a dataframe, or if 'columns_to_include' \
            or 'columns_to_exclude' are not of type list
        ValueError: If the included columns overlay excluded columns, so nothing to return

    Returns:
        DataFrame | str: The processed dataframe or, when 'generate_sql' is set, \
            the equivalent SQL statement

    Examples:
        >>> from spalah.dataframe import slice_dataframe
        >>> df = spark.sql(
        ...         'SELECT 1 as ID, "John" AS Name,
        ...         struct("line1" AS Line1, "line2" AS Line2) AS Address'
        ...     )
        >>> df_sliced = slice_dataframe(
        ...     input_dataframe=df,
        ...     columns_to_include=["Name", "Address"],
        ...     columns_to_exclude=["Address.Line2"]
        ... )

        As the result, the dataframe will contain only the columns `Name` and `Address.Line1` \
            because `Name` and `Address` are included and a nested element `Address.Line2` is \
            excluded
        >>> df_result.printSchema()
        root
        |-- Name: string (nullable = false)
        |-- Address: struct (nullable = false)
        |    |-- Line1: string (nullable = false)
    """

    projection = []

    # Verification of input parameters:

    # Explicit check (covering None as well) so the failure happens here with a
    # clear message rather than later with a NameError on an unbound variable.
    if not isinstance(input_dataframe, DataFrame):
        raise TypeError("input_dataframe must be a dataframe")

    _schema = input_dataframe.schema
    _table_identifier = "input_dataframe"
    _input_dataframe = input_dataframe

    if not columns_to_include:
        columns_to_include = []

    if not columns_to_exclude:
        columns_to_exclude = []

    if not (isinstance(columns_to_include, list) and isinstance(columns_to_exclude, list)):
        raise TypeError(
            "The type of parameters 'columns_to_include', 'columns_to_exclude' must be a list"
        )

    if not all(isinstance(item, str) for item in columns_to_include + columns_to_exclude):
        raise TypeError("Members of 'columns_to_include' and 'columns_to_exclude' must be a string")

    if debug:
        print("The list of columns to include:")
        pprint(columns_to_include)

        print("The list of columns to exclude:")
        pprint(columns_to_exclude)

        if nullify_only:
            print("Columns to nullify in the final projection:")
        else:
            print("Columns to include into the final projection:")

    # lower-case items for making further filtering case-insensitive
    columns_to_include = [item.lower() for item in columns_to_include]
    columns_to_exclude = [item.lower() for item in columns_to_exclude]

    for field in _schema.fields:
        node_result = __process_schema_node(
            node=field,
            columns_to_include=columns_to_include,
            columns_to_exclude=columns_to_exclude,
            nullify_only=nullify_only,
            debug=debug,
        )

        # A None result means the column is fully excluded from the projection
        if node_result is not None:
            projection.append(node_result)

    if not projection:
        raise ValueError(
            "At least one column should be listed in the "
            + "columns_to_include/columns_to_exclude attributes "
            + "and included column should not directly overlap with excluded one"
        )

    if generate_sql:
        delimiter = ", \n"
        result = f"SELECT \n{delimiter.join(projection)} \nFROM {_table_identifier}"
    else:
        result = _input_dataframe.selectExpr(*projection)

    return result