Data Analysis with Python and PySpark

Notes and code written while reading Data Analysis with Python and PySpark by Jonathan Rioux.

Link to book

Preface: Environment setup

  1. Install homebrew
  2. Install Java and Spark via brew install apache-spark
    • Save yourself a headache by running sudo ln -sfn $(brew --prefix)/opt/openjdk@11/libexec/openjdk.jdk /Library/Java/JavaVirtualMachines/openjdk-11.jdk after installation. This makes sure system Java wrappers find the Java development kit (JDK) associated with this package.
  3. Create conda environment via conda create -n pyspark python=3.8 pandas pyspark=3.0.0
  4. Activate the conda environment via conda activate pyspark
  5. Install jupyter notebook to your environment via conda install -c conda-forge notebook
    • Add the environment to your notebook:
      conda install -c anaconda ipykernel
      python -m ipykernel install --user --name=pyspark

Chapter 1: Basics

Structure

  • PySpark is slower than native Scala code. At its core, PySpark code needs to be translated into instructions for the JVM (Java Virtual Machine), and this translation causes a speed bottleneck.

Terminology

  • "Workers" are called executors in Spark lingo. They perform the actual work on the machines.
  • The master node manages these workers.
  • The driver program is your program: it contains the instructions (your code) that the workers are going to run.
  • The cluster manager is a program that plans the capacity it will allocate to the driver program. Spark provides its own manager, but it can use others like YARN, Mesos, Kubernetes, etc.
  • Any directions about capacity (machines and executors) are encoded in a SparkContext object which represents the connection to our Spark cluster.
  • The master allocates data to the executors, which are the processes that run computations and store data for the application.
  • Executors sit on worker nodes, which are the actual machines.
  • Executor = actual worker, worker node = workbench that the executor performs work on.

Example of 4 workers working together to calculate the average of one column:

[figure: four workers computing the average of a column]

Lazy vs Eager Evaluation

Python, R, and Java are eagerly evaluated. Spark is lazily evaluated.

Spark distinguishes between transformations and actions.

Transformations are:

  • Adding a column to a table
  • Performing aggregations
  • Computing stats
  • Training a ML model on data
  • Reading data

Actions are:

  • Printing information to the screen (i.e. show)
  • Writing data to a hard drive or a cloud bucket (i.e. write).

A Spark program will avoid performing any data work until an action triggers the computation chain. Before that, the master caches your instructions. The benefits (illustrated in the sketch after this list) are:

  1. Storing instructions in memory takes less space than storing intermediate data frames.
  2. Caching the tasks allows the master to optimize the work between the executors more efficiently.
  3. If one node fails during processing, Spark can recreate missing chunks of the data by referring to the cached instructions. Simply put, it handles the data recovery part.
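
A minimal sketch of this lazy behavior (assuming a running SparkSession named spark and a hypothetical text file at ./data/example.txt; not from the book):

import pyspark.sql.functions as F

# Transformations: Spark only records these instructions, no data is processed yet.
df = spark.read.text("./data/example.txt")
lengths = df.select(F.length(F.col("value")).alias("length"))

# Action: only now does Spark optimize the whole chain and run it on the executors.
lengths.show(5)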

Pyspark program in a nutshell

  1. We first encode our instructions in Python code, forming a driver program.
  2. When submitting our program (or launching a PySpark shell), the cluster manager allocates resources for us to use. Those will stay constant for the duration of the program.
  3. The master ingests your code and translates it into Spark instructions. Those instructions are either transformations or actions.
  4. Once the master reaches an action, it optimizes the whole computation chain and splits the work between executors. Executors are the processes performing the actual data work, and they reside on machines labeled worker nodes.

Chapter 2: First Data Program in Pyspark

Most data-driven applications follow the Extract-Transform-Load (ETL) pipeline:

  1. Ingest or read the data we wish to work with.
  2. Transform the data via a few simple instructions or a very complex machine learning model
  3. Export the resulting data, either into a file to be fed into an app or by summarizing our findings into a visualization.

SparkSession entry point

  • SparkSession provides an entry point to Spark.
    • Wraps SparkContext and provides functionality for interacting with the data.
  • Can be used as a normal object imported from a library in Python.
  • SparkSession builder: builder pattern with set of methods to create a configurable object.

Creating a SparkSession entry point from scratch

from pyspark.sql import SparkSession

spark = (SparkSession
         .builder
         .appName("Analyzing the vocabulary of Pride and Prejudice.")
         .getOrCreate())

Your first Pyspark application


sparkContext can be invoked from the SparkSession object like below.

(Older code may present sparkContext as an sc variable)

In [7]:
sc = spark.sparkContext
sc
Out[7]:
SparkContext (Spark UI)
  Version: v3.0.0
  Master: local[2]
  AppName: Analyzing the vocabulary of Pride and Prejudice.

Setting the log level

  • Spark defaults to WARN.
  • Can change via spark.sparkContext.setLogLevel(KEYWORD)

Log level keywords

  • OFF: No logging at all (not recommended).
  • FATAL: Only fatal errors. A fatal error will crash your Spark cluster.
  • ERROR: My personal favorite; shows FATAL as well as other useful (but recoverable) errors.
  • WARN: Adds warnings (and there are quite a lot of them).
  • INFO: Gives you runtime information, such as repartitioning and data recovery (see chapter 1).
  • DEBUG: Provides debug information on your jobs.
  • TRACE: Traces your jobs (more verbose debug logs). Can be quite pedagogic, but very annoying.
  • ALL: Everything that PySpark can spit, it will spit. As useful as OFF.

In [8]:
spark.sparkContext.setLogLevel('ERROR')

Application Design

Goal: What are the most popular words in Jane Austen's Pride and Prejudice?

Steps:

  1. Read: Read the input data (we’re assuming a plain text file)
  2. Tokenize: Tokenize each word
  3. Clean: Remove any punctuation and/or tokens that aren’t words.
  4. Count: Count the frequency of each word present in the text
  5. Answer: Return the top 10 (or 20, 50, 100)

Data Exploration

PySpark provides two main structures for storing data when performing manipulations:

  1. The Resilient Distributed Dataset (or RDD)
  2. The data frame: a stricter version of the RDD. It makes heavy use of the concept of columns, where you perform operations on columns instead of on records (as with an RDD); see the sketch below.
    • More common than the RDD.
    • Syntax is similar to SQL

RDD vs Dataframe
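
A hypothetical sketch of the contrast (the toy data is made up): the same "add 1" operation expressed record-by-record on an RDD and column-wise on a data frame.

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([[1], [2], [3]], ["value"])

# RDD style (row-major): apply a function to each record.
print(df.rdd.map(lambda row: row[0] + 1).collect())  # [2, 3, 4]

# Data frame style (column-major): apply an operation to a column.
df.select((F.col("value") + 1).alias("value_plus_one")).show()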

Reading a dataframe with spark.read

Reading data into a data frame is done through the DataFrameReader object, which we can access through spark.read.

value: string is the single column of the resulting data frame, with the text stored within it (one line of the book per record)

In [13]:
book = spark.read.text("data/Ch02/1342-0.txt")
book
Out[13]:
DataFrame[value: string]
In [15]:
# Check schema
display(book.printSchema())

display(book.dtypes)
root
 |-- value: string (nullable = true)

None
[('value', 'string')]

Showing a dataframe with spark.show()

The show() method takes three optional parameters.

  1. n can be set to any positive integer, and will display that number of rows.
  2. truncate, if set to True, will truncate the columns to display only 20 characters. Set it to False to display the whole length, or to any positive integer to truncate to a specific number of characters.
  3. vertical takes a Boolean value and, when set to True, will display each record as a small table. If you need to check some records in detail, this is a very useful option.
In [23]:
# play with params
book.show(2, truncate=False, vertical=True)
-RECORD 0-------------------------------------------------------------------
 value | The Project Gutenberg EBook of Pride and Prejudice, by Jane Austen 
-RECORD 1-------------------------------------------------------------------
 value |                                                                    
only showing top 2 rows

Lazy vs Eager Evaluation

  • By default, you need to call show() to see a data frame's contents. This follows Spark's idea of lazy evaluation: nothing is computed until an action requires it.
  • Since Spark 2.4.0, you can configure the SparkSession object to support printing to screen. This may be helpful when learning:
from pyspark.sql import SparkSession

spark = (SparkSession.builder
                     .config("spark.sql.repl.eagerEval.enabled", "True")
                     .getOrCreate())

Tokenizing sentences with select() and split()

select() selects data, similar to SQL's SELECT. The syntax is similar to pandas (col comes from pyspark.sql.functions):

book.select(book.value)
book.select(book["value"])
book.select(col("value"))
book.select("value")

split() transforms a string column into an array column containing n string elements (i.e., tokens). Note that its pattern argument uses Java (JVM) regular expressions rather than Python's re syntax.

alias() renames transformed columns for easier reference. When applied to a column, it takes a single string as an argument.

Another way to set an alias is to call .withColumnRenamed() on the data frame. If you just want to rename a column without changing the rest of the data frame, use .withColumnRenamed().

In [45]:
from pyspark.sql.functions import col, split

# Read, tokenize and alias the column
lines = book.select(split(col('value'), " ").alias("line"))

display(lines)

lines.printSchema()

lines.show(5)
DataFrame[line: array<string>]
root
 |-- line: array (nullable = true)
 |    |-- element: string (containsNull = true)

+--------------------+
|                line|
+--------------------+
|[The, Project, Gu...|
|                  []|
|[This, eBook, is,...|
|[almost, no, rest...|
|[re-use, it, unde...|
+--------------------+
only showing top 5 rows

In [46]:
# Changing alias name using withColumnRenamed
alternative = lines.withColumnRenamed("line", 
                                      "here is an alternate alias")
alternative.printSchema()
root
 |-- here is an alternate alias: array (nullable = true)
 |    |-- element: string (containsNull = true)

Reshaping data with explode()

When applied to a column containing a container-like data structure (such as an array), explode() will take each element and give it its own row.

[figure: explode() gives each element of an array column its own row]

In [49]:
# Explode column of arrays into rows of elements

from pyspark.sql.functions import explode, col

words = lines.select(explode(col("line")).alias("word"))
words.show(10)
+----------+
|      word|
+----------+
|       The|
|   Project|
| Gutenberg|
|     EBook|
|        of|
|     Pride|
|       and|
|Prejudice,|
|        by|
|      Jane|
+----------+
only showing top 10 rows

String normalization

In [71]:
from pyspark.sql.functions import lower, regexp_extract

# Lowercase
words_lower = words.select(lower("word").alias("word_lower"))
words_lower.show()

# Naive punctuation normalization using regex
word_norm = words_lower.select(regexp_extract(col("word_lower"), "[a-z]*", 0).alias("word_normalized"))
word_norm.show()
+----------+
|word_lower|
+----------+
|       the|
|   project|
| gutenberg|
|     ebook|
|        of|
|     pride|
|       and|
|prejudice,|
|        by|
|      jane|
|    austen|
|          |
|      this|
|     ebook|
|        is|
|       for|
|       the|
|       use|
|        of|
|    anyone|
+----------+
only showing top 20 rows

+---------------+
|word_normalized|
+---------------+
|            the|
|        project|
|      gutenberg|
|          ebook|
|             of|
|          pride|
|            and|
|      prejudice|
|             by|
|           jane|
|         austen|
|               |
|           this|
|          ebook|
|             is|
|            for|
|            the|
|            use|
|             of|
|         anyone|
+---------------+
only showing top 20 rows

Filtering data

In [91]:
# Remove empty records

word_nonull = word_norm.filter(col("word_normalized") != "") \
                       .withColumnRenamed('word_normalized', 'word_nonull')
word_nonull.show()
+-----------+
|word_nonull|
+-----------+
|        the|
|    project|
|  gutenberg|
|      ebook|
|         of|
|      pride|
|        and|
|  prejudice|
|         by|
|       jane|
|     austen|
|       this|
|      ebook|
|         is|
|        for|
|        the|
|        use|
|         of|
|     anyone|
|   anywhere|
+-----------+
only showing top 20 rows

Exercises

2.1

Rewrite the following code snippet, removing the withColumnRenamed method. Which version is clearer and easier to read?

from pyspark.sql.functions import col, length

# The `length` function returns the number of characters in a string column.

ex21 = (
    spark.read.text("./data/Ch02/1342-0.txt")
    .select(length(col("value")))
    .withColumnRenamed("length(value)", "number_of_char")
)
In [77]:
from pyspark.sql.functions import col, length
ex21 = (
    spark.read.text("./data/Ch02/1342-0.txt")
    .select(length(col("value")).alias('values'))
)
ex21.show(5)
+------+
|values|
+------+
|    66|
|     0|
|    64|
|    68|
|    67|
+------+
only showing top 5 rows

2.2

The following code blocks gives an error. What is the problem and how can you solve it?

from pyspark.sql.functions import col, greatest

ex22.printSchema()
# root
#  |-- key: string (containsNull = true)
#  |-- value1: long (containsNull = true)
#  |-- value2: long (containsNull = true)

# `greatest` will return the greatest value of the list of column names,
# skipping null value

# The following statement will return an error
ex22.select(
    greatest(col("value1"), col("value2")).alias("maximum_value")
).select(
    "key", "max_value"
)

Answer

The first select() keeps only the maximum_value column, so the second select() refers to columns that no longer exist: key was dropped, and the alias is maximum_value, not max_value. The fix is to select key alongside the greatest() column in a single select() (and to use a consistent alias).

2.3

Let’s take our words_nonull data frame, available in listing 2.19. You can load the code in the repository (code/Ch02/end_of_chapter.py) into your REPL to get the data frame loaded.

a) Remove all of the occurrences of the word "is"

b) (Challenge) Using the length function explained in exercise 2.1, keep only the words with more than 3 characters.

In [102]:
# 1. Remove all of the occurrences of the word "is",
# 2. Using the length function explained in exercise 2.1, keep only the words with more than 3 characters.
word_nonull.filter(col("word_nonull") != "is") \
           .filter(length(col("word_nonull")) > 3) \
           .withColumnRenamed('word_nonull', 'words_greater_than_3') \
           .show()
+--------------------+
|words_greater_than_3|
+--------------------+
|             project|
|           gutenberg|
|               ebook|
|               pride|
|           prejudice|
|                jane|
|              austen|
|                this|
|               ebook|
|              anyone|
|            anywhere|
|                cost|
|                with|
|              almost|
|        restrictions|
|          whatsoever|
|                copy|
|                give|
|                away|
|               under|
+--------------------+
only showing top 20 rows

2.4

Remove the words is, not, the and if from your list of words, using a single where() method on the words_nonull data frame (see exercise 2.3). Write the code to do so.

In [103]:
word_nonull.where(~col("word_nonull").isin(['is', 'not', 'the', 'if'])) \
           .show()
+-----------+
|word_nonull|
+-----------+
|    project|
|  gutenberg|
|      ebook|
|         of|
|      pride|
|        and|
|  prejudice|
|         by|
|       jane|
|     austen|
|       this|
|      ebook|
|        for|
|        use|
|         of|
|     anyone|
|   anywhere|
|         at|
|         no|
|       cost|
+-----------+
only showing top 20 rows

2.5

One of your friends comes to you with the following code. They have no idea why it doesn’t work. Can you diagnose the problem, explain why it is an error, and provide a fix?

from pyspark.sql.functions import col, split

book = spark.read.text("./data/ch02/1342-0.txt")

book = book.printSchema()

lines = book.select(split(book.value, " ").alias("line"))

words = lines.select(explode(col("line")).alias("word"))
Answer

They're assigning the output of book.printSchema() (which returns None) to book, overwriting the data frame. (The snippet also uses explode without importing it.)

Solution
In [113]:
from pyspark.sql.functions import col, explode, split

book = spark.read.text("./data/ch02/1342-0.txt")

# Don't assign it back to `book`
book.printSchema()

lines = book.select(split(book.value, " ").alias("line"))

words = lines.select(explode(col("line")).alias("word"))

words.show()
root
 |-- value: string (nullable = true)

+----------+
|      word|
+----------+
|       The|
|   Project|
| Gutenberg|
|     EBook|
|        of|
|     Pride|
|       and|
|Prejudice,|
|        by|
|      Jane|
|    Austen|
|          |
|      This|
|     eBook|
|        is|
|       for|
|       the|
|       use|
|        of|
|    anyone|
+----------+
only showing top 20 rows

Chapter 3: Submitting and scaling your first PySpark program

In [12]:
# Set up
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName(
    "Analyzing the vocabulary of Pride and Prejudice."
).getOrCreate()

spark.sparkContext.setLogLevel("ERROR")
In [13]:
# Data Frame Setup
# Set up
from pyspark.sql.functions import col, split, lower, explode, regexp_extract

book = spark.read.text("data/Ch02/1342-0.txt")
lines = book.select(split(col("value"), " ").alias("line"))
words = lines.select(explode(col("line")).alias("word"))
words_lower = words.select(lower("word").alias("word_lower"))
word_norm = words_lower.select(
    regexp_extract(col("word_lower"), "[a-z]*", 0).alias("word_normalized")
)
word_nonull = word_norm.filter(col("word_normalized") != "").withColumnRenamed(
    "word_normalized", "word_nonull"
)

Aggregation: groupBy and count

  • GroupedData allows you to perform an aggregate function on each group.
  • Use groupby to count record occurrences, passing the columns we want to group by. The returned value is a GroupedData object, not a DataFrame. Once you apply an aggregate function to it, like count(), it returns a DataFrame.
    • Note that groupby and groupBy are the same thing.
  • You can sort the output with orderBy.
    • Note that orderBy only exists as camel case.
In [16]:
groups = word_nonull.groupBy(col("word_nonull"))
display(groups)

results = groups.count().orderBy("count", ascending=False)
results.show()
<pyspark.sql.group.GroupedData at 0x11b77b910>
+-----------+-----+
|word_nonull|count|
+-----------+-----+
|        the| 4480|
|         to| 4218|
|         of| 3711|
|        and| 3504|
|        her| 2199|
|          a| 1982|
|         in| 1909|
|        was| 1838|
|          i| 1750|
|        she| 1668|
|       that| 1487|
|         it| 1482|
|        not| 1427|
|        you| 1301|
|         he| 1296|
|         be| 1257|
|        his| 1247|
|         as| 1174|
|        had| 1170|
|       with| 1092|
+-----------+-----+
only showing top 20 rows

Writing to file: csv

  • The data frame has a write method, which can be chained with csv.
  • By default it writes a bunch of separate files (one file per partition) plus a _SUCCESS file.
  • Use coalesce(1) to combine the output into a single partition/file.
  • Use .mode('overwrite') to force the write.

TIP: Never assume that your data frame will keep the same ordering of records unless you explicitly ask via orderBy().

In [21]:
# Write multiple partitions + success file
results.write.mode("overwrite").csv("./output/results")

# Concatenate into 1 file, then write to disk
results.coalesce(1).write.mode("overwrite").csv("./output/result_single_partition")

Streamlining the code by chaining

Method chaining

In PySpark, every transformation returns an object, which is why we need to assign the result to a variable. This means that PySpark doesn’t perform modifications in place.

In [36]:
# qualified import; import the whole module
import pyspark.sql.functions as F

# chain methods together instead of multiple variables
results = (
    spark.read.text("./data/ch02/1342-0.txt")
    .select(F.split(F.col("value"), " ").alias("line"))
    .select(F.explode(F.col("line")).alias("word"))
    .select(F.lower(F.col("word")).alias("word"))
    .select(F.regexp_extract(F.col("word"), "[a-z']*", 0).alias("word"))
    .where(F.col("word") != "")
    .groupby("word")
    .count()
)

Submitting code in batch mode using spark-submit

When wrapping a script to be executed with spark-submit rather than with the pyspark command, you'll need to define your SparkSession first.

In [39]:
# This can be wrapped into a `word_counter.py` file and be executed
# using `spark-submit`

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.appName(
    "Analyzing the vocabulary of Pride and Prejudice."
).getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

results = (
    spark.read.text("./data/ch02/*.txt")
    .select(F.split(F.col("value"), " ").alias("line"))
    .select(F.explode(F.col("line")).alias("word"))
    .select(F.lower(F.col("word")).alias("word"))
    .select(F.regexp_extract(F.col("word"), "[a-z']*", 0).alias("word"))
    .where(F.col("word") != "")
    .groupby("word")
    .count()
    .orderBy("count", ascending=False)
)

results.show()
+----+-----+
|word|count|
+----+-----+
| the|38895|
| and|23919|
|  of|21199|
|  to|20526|
|   a|14464|
|   i|13973|
|  in|12777|
|that| 9623|
|  it| 9099|
| was| 8920|
| her| 7923|
|  my| 7385|
| his| 6642|
|with| 6575|
|  he| 6444|
|  as| 6439|
| you| 6295|
| had| 5718|
| she| 5617|
| for| 5425|
+----+-----+
only showing top 20 rows

Exercises

See chapter 3 code

Chapter 4: Analyzing tabular data with pyspark.sql

Summary

  • PySpark uses the SparkReader object to read any kind of data directly in a data frame. The specialized CSV SparkReader is used to ingest comma-separated value (CSV) files. Just like when reading text, the only mandatory parameter is the source location.
  • The CSV format is very versatile, so PySpark provides many optional parameters to account for this flexibility. The most important ones are the field delimiter, the record delimiter, and the quotation character. All of those parameters have sensible defaults.
  • PySpark can infer the schema of a CSV file by setting the inferSchema optional parameter to True. PySpark accomplishes this by reading the data twice: once to set the appropriate type for each column, and another time to ingest the data in the inferred format.
  • Tabular data is represented in a data frame as a series of columns, each having a name and a type. Since the data frame is a column-major data structure, the concept of row is less relevant.
  • You can use Python code to explore the data efficiently, using the column list as any Python list to expose the elements of the data frame of interest.
  • The most common operations on a data frame are the selection, deletion, and creation of columns. In PySpark, the methods used are select(), drop(), and withColumn(), respectively.
  • select can be used for column re-ordering by passing a re-ordered list of columns.
  • You can rename columns one by one with the withColumnRenamed() method, or all at once by using the toDF() method.
  • You can display a summary of the columns with the describe() or summary() method. describe() has a fixed set of metrics, while summary() also accepts the statistics you want (as strings, such as "count", "min", or "25%") as parameters and applies them to all columns.

On dataframes

PySpark operates either on the whole data frame objects (via methods such as select() and groupby()) or on Column objects (for instance when using a function like split()).

  • The data frame is column-major, so its API focuses on manipulating the columns to transform the data.
  • Hence with data transformations, think about what operations to do and which columns will be impacted.
  • RDDs, on the other hand, are row-major. With them, you think about items with attributes, to which you apply functions.
In [1]:
# setup
import os
import numpy as np

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

Data Source Info

For this exercise, we’ll use some open data from the Government of Canada, more specifically the CRTC (Canadian Radio-television and Telecommunications Commission). Every broadcaster is mandated to provide a complete log of the programs, commercials and all, showcased to the Canadian public.

This gives us a lot of potential questions to answer, but we’ll select one specific one: which channels have the highest and lowest proportion of commercials?

Creating a data frame

spark.createDataFrame

  • 1st param: data (list of lists, pandas dataframe, RDD)
  • 2nd param: schema (i.e. think column headers in SQL); this can be a list of column names or an explicit schema (see the sketch after the example below)
  • The master node knows the structure of the data frame, but the actual data lives on the worker nodes (i.e. in cluster memory)
In [4]:
# Example creating a data frame with toy data
my_grocery_list = [
    ["Banana", 2, 1.74],
    ["Apple", 4, 2.04],
    ["Carrot", 1, 1.09],
    ["Cake", 1, 10.99],
]

df_grocery_list = spark.createDataFrame(my_grocery_list, ["Item", "Quantity", "Price"])

df_grocery_list.printSchema()
root
 |-- Item: string (nullable = true)
 |-- Quantity: long (nullable = true)
 |-- Price: double (nullable = true)
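
As an aside (a sketch, not from this part of the book): the schema argument can also be an explicit StructType instead of a plain list of column names, which avoids relying on type inference. This reuses my_grocery_list and spark from above.

from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType

explicit_schema = StructType([
    StructField("Item", StringType(), True),
    StructField("Quantity", LongType(), True),
    StructField("Price", DoubleType(), True),
])

# Same data, but the column types are stated up front instead of inferred.
spark.createDataFrame(my_grocery_list, explicit_schema).printSchema()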

Reading a data frame

Data frame structure

Composed of row delimiter (e.g. newline \n) and column delimiter (e.g. tabs \t for TSVs)

In [48]:
DIRECTORY = "./data/Ch04"
logs = spark.read.csv(
    os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8_sample.CSV"),
    sep="|",  # default is ","
    quote='"',  # default is double quote.
    header=True,  # set first row as column names
    inferSchema=True,  # infer column types from the data (default False)
)
In [49]:
logs.printSchema()
root
 |-- BroadcastLogID: integer (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)


Exercises

4.1

Take the following file, called sample.csv, and read it into a dataframe.

Item,Quantity,Price
$Banana, organic$,1,0.99
Pear,7,1.24
$Cake, chocolate$,1,14.50
In [17]:
sample = spark.read.csv(
    os.path.join(DIRECTORY, "ch4_exercise.csv"),
    sep=",",
    header=True,
    quote="$",
    inferSchema=True,
)

sample.show()
+---------------+--------+-----+
|           Item|Quantity|Price|
+---------------+--------+-----+
|Banana, organic|       1| 0.99|
|           Pear|       7| 1.24|
|Cake, chocolate|       1| 14.5|
+---------------+--------+-----+

4.2

Re-read the data in a logs_raw data frame, taking inspiration from the code in listing 4.3, this time without passing any optional parameters. Print the first 5 rows of data, as well as the schema. What are the differences in terms of data and schema between logs and logs_raw?

In [46]:
DIRECTORY = "./data/Ch04"
raw_logs = spark.read.csv(
    os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8_sample.CSV"),
)
raw_logs.show(5, False)  # False = show entire contents
raw_logs.printSchema()

# Result shows entire row concatenated into one column (_c0). Not what we want.
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|_c0                                                                                                                                                                                                                                                                                                                                                                                                                                        |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|BroadcastLogID|LogServiceID|LogDate|SequenceNO|AudienceTargetAgeID|AudienceTargetEthnicID|CategoryID|ClosedCaptionID|CountryOfOriginID|DubDramaCreditID|EthnicProgramID|ProductionSourceID|ProgramClassID|FilmClassificationID|ExhibitionID|Duration|EndTime|LogEntryDate|ProductionNO|ProgramTitle|StartTime|Subtitle|NetworkAffiliationID|SpecialAttentionID|BroadcastOriginPointID|CompositionID|Producer1|Producer2|Language1|Language2|
|1196192316|3157|2018-08-01|1|4||13|3|3|||10|19||2|02:00:00.0000000|08:00:00.0000000|2018-08-01|A39082|Newlywed and Dead|06:00:00.0000000||||||||94|                                                                                                                                                                                                                                                                                        |
|1196192317|3157|2018-08-01|2||||1|||||20|||00:00:30.0000000|06:13:45.0000000|2018-08-01||15-SPECIALTY CHANNELS-Canadian Generic|06:13:15.0000000|||||||||                                                                                                                                                                                                                                                                                  |
|1196192318|3157|2018-08-01|3||||1|||||3|||00:00:15.0000000|06:14:00.0000000|2018-08-01||3-PROCTER & GAMBLE INC-Anti-Perspirant 3rd|06:13:45.0000000|||||||||                                                                                                                                                                                                                                                                               |
|1196192319|3157|2018-08-01|4||||1|||||3|||00:00:15.0000000|06:14:15.0000000|2018-08-01||12-CREDIT KARMA-Bank/Credit Union/Trust 3rd|06:14:00.0000000|||||||||                                                                                                                                                                                                                                                                              |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 5 rows

root
 |-- _c0: string (nullable = true)


Exploring the shape of our data universe

About Star Schema

Wiki:

In computing, the star schema is the simplest style of data mart schema and is the approach most widely used to develop data warehouses and dimensional data marts. The star schema consists of one or more fact tables referencing any number of dimension tables.

Star schemas are common in the relational database world because of normalization, a process used to avoid duplicating pieces of data and improve data integrity.

Spark uses denormalized tables (ie fat tables). Why? Mainly because it is easier to run analyses on a single table.

  • If you do need to analyze a complex star schema, your best bet is to work with a database manager to get a denormalized table.

select-ing what we want to see

Four ways to select columns in PySpark, all equivalent in terms of results

In [50]:
# Using the string to column conversion
logs.select("BroadCastLogID", "LogServiceID", "LogDate")
logs.select(
    *["BroadCastLogID", "LogServiceID", "LogDate"]
)  # Unpack list with star prefix

# Passing the column object explicitly
logs.select(F.col("BroadCastLogID"), F.col("LogServiceID"), F.col("LogDate"))
logs.select(
    *[F.col("BroadCastLogID"), F.col("LogServiceID"), F.col("LogDate")]
)  # Unpack list with star prefix
Out[50]:
DataFrame[BroadCastLogID: int, LogServiceID: int, LogDate: string]

Because of the width of our data frame, we could split our columns into manageable sets of three to keep the output tidy on the screen. This gives a high-level view of what the data frame contains.

In [51]:
# Splitting columns in groups of three using numpy
display("Columns in groups of three")
column_split = np.array_split(np.array(logs.columns), len(logs.columns) // 3)
display(column_split)

# Show columns in groups of three
display("Table display in column groups of three")
for x in column_split:
    logs.select(*x).show(5, False)
'Columns in groups of three'
[array(['BroadcastLogID', 'LogServiceID', 'LogDate'], dtype='<U22'),
 array(['SequenceNO', 'AudienceTargetAgeID', 'AudienceTargetEthnicID'],
       dtype='<U22'),
 array(['CategoryID', 'ClosedCaptionID', 'CountryOfOriginID'], dtype='<U22'),
 array(['DubDramaCreditID', 'EthnicProgramID', 'ProductionSourceID'],
       dtype='<U22'),
 array(['ProgramClassID', 'FilmClassificationID', 'ExhibitionID'],
       dtype='<U22'),
 array(['Duration', 'EndTime', 'LogEntryDate'], dtype='<U22'),
 array(['ProductionNO', 'ProgramTitle', 'StartTime'], dtype='<U22'),
 array(['Subtitle', 'NetworkAffiliationID', 'SpecialAttentionID'],
       dtype='<U22'),
 array(['BroadcastOriginPointID', 'CompositionID', 'Producer1'],
       dtype='<U22'),
 array(['Producer2', 'Language1', 'Language2'], dtype='<U22')]
'Table display in column groups of three'
+--------------+------------+----------+
|BroadcastLogID|LogServiceID|LogDate   |
+--------------+------------+----------+
|1196192316    |3157        |2018-08-01|
|1196192317    |3157        |2018-08-01|
|1196192318    |3157        |2018-08-01|
|1196192319    |3157        |2018-08-01|
|1196192320    |3157        |2018-08-01|
+--------------+------------+----------+
only showing top 5 rows

+----------+-------------------+----------------------+
|SequenceNO|AudienceTargetAgeID|AudienceTargetEthnicID|
+----------+-------------------+----------------------+
|1         |4                  |null                  |
|2         |null               |null                  |
|3         |null               |null                  |
|4         |null               |null                  |
|5         |null               |null                  |
+----------+-------------------+----------------------+
only showing top 5 rows

+----------+---------------+-----------------+
|CategoryID|ClosedCaptionID|CountryOfOriginID|
+----------+---------------+-----------------+
|13        |3              |3                |
|null      |1              |null             |
|null      |1              |null             |
|null      |1              |null             |
|null      |1              |null             |
+----------+---------------+-----------------+
only showing top 5 rows

+----------------+---------------+------------------+
|DubDramaCreditID|EthnicProgramID|ProductionSourceID|
+----------------+---------------+------------------+
|null            |null           |10                |
|null            |null           |null              |
|null            |null           |null              |
|null            |null           |null              |
|null            |null           |null              |
+----------------+---------------+------------------+
only showing top 5 rows

+--------------+--------------------+------------+
|ProgramClassID|FilmClassificationID|ExhibitionID|
+--------------+--------------------+------------+
|19            |null                |2           |
|20            |null                |null        |
|3             |null                |null        |
|3             |null                |null        |
|3             |null                |null        |
+--------------+--------------------+------------+
only showing top 5 rows

+----------------+----------------+------------+
|Duration        |EndTime         |LogEntryDate|
+----------------+----------------+------------+
|02:00:00.0000000|08:00:00.0000000|2018-08-01  |
|00:00:30.0000000|06:13:45.0000000|2018-08-01  |
|00:00:15.0000000|06:14:00.0000000|2018-08-01  |
|00:00:15.0000000|06:14:15.0000000|2018-08-01  |
|00:00:15.0000000|06:14:30.0000000|2018-08-01  |
+----------------+----------------+------------+
only showing top 5 rows

+------------+-------------------------------------------+----------------+
|ProductionNO|ProgramTitle                               |StartTime       |
+------------+-------------------------------------------+----------------+
|A39082      |Newlywed and Dead                          |06:00:00.0000000|
|null        |15-SPECIALTY CHANNELS-Canadian Generic     |06:13:15.0000000|
|null        |3-PROCTER & GAMBLE INC-Anti-Perspirant 3rd |06:13:45.0000000|
|null        |12-CREDIT KARMA-Bank/Credit Union/Trust 3rd|06:14:00.0000000|
|null        |3-L'OREAL CANADA-Hair Products 3rd         |06:14:15.0000000|
+------------+-------------------------------------------+----------------+
only showing top 5 rows

+--------+--------------------+------------------+
|Subtitle|NetworkAffiliationID|SpecialAttentionID|
+--------+--------------------+------------------+
|null    |null                |null              |
|null    |null                |null              |
|null    |null                |null              |
|null    |null                |null              |
|null    |null                |null              |
+--------+--------------------+------------------+
only showing top 5 rows

+----------------------+-------------+---------+
|BroadcastOriginPointID|CompositionID|Producer1|
+----------------------+-------------+---------+
|null                  |null         |null     |
|null                  |null         |null     |
|null                  |null         |null     |
|null                  |null         |null     |
|null                  |null         |null     |
+----------------------+-------------+---------+
only showing top 5 rows

+---------+---------+---------+
|Producer2|Language1|Language2|
+---------+---------+---------+
|null     |94       |null     |
|null     |null     |null     |
|null     |null     |null     |
|null     |null     |null     |
|null     |null     |null     |
+---------+---------+---------+
only showing top 5 rows

drop-ing columns we don't need

Remove BroadcastLogID (the primary key isn't needed in a single-table analysis) and SequenceNO. drop() returns a new data frame.

Warning with drop: Unlike select(), where selecting a column that doesn’t exist will return a runtime error, dropping a non-existent column is a no-op. PySpark will just ignore the columns it doesn’t find. Careful with the spelling of your column names!
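
A quick sketch of that difference (the misspelled column name is hypothetical):

# drop() silently ignores a column that doesn't exist; no error is raised.
logs.drop("NotARealColumn").printSchema()

# select() with the same misspelling would raise an AnalysisException:
# logs.select("NotARealColumn")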

In [53]:
logs = logs.drop("BroadCastLogID", "SequenceNo")

assert all(col not in logs.columns for col in ["BroadCastLogID", "SequenceNo"])

An alternative to the above: use select() with a list comprehension.

In [52]:
logs = logs.select(
    *[col for col in logs.columns if col not in ["BroadCastLogID", "SequenceNo"]]
)

assert all(col not in logs.columns for col in ["BroadCastLogID", "SequenceNo"])

Exercises

4.3

Create a new data frame logs_clean that contains only the columns that do not end with ID

In [74]:
print([col for col in logs.columns if col[-2:] != "ID"])
['LogDate', 'SequenceNO', 'Duration', 'EndTime', 'LogEntryDate', 'ProductionNO', 'ProgramTitle', 'StartTime', 'Subtitle', 'Producer1', 'Producer2', 'Language1', 'Language2']
In [98]:
# Load original CSV again
DIRECTORY = "./data/Ch04"
logs = spark.read.csv(
    os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8_sample.CSV"),
    sep="|",  # default is ","
    quote='"',  # default is double quote.
    header=True,  # set first row as column names
    inferSchema=True,  # infer column types from the data (default False)
)

# Filter to columns that don't end with "ID"
logs_no_id = logs.select(*[col for col in logs.columns if col[-2:].lower() != "id"])
print("Filtered results (not end with 'ID')")
logs_no_id.printSchema()

assert all("id" not in col[-2:] for col in logs_no_id.columns)
Filtered results (not end with 'ID')
root
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)

Creating new columns with withColumn

1. Check the data type of 'Duration' column

In [103]:
logs.select(F.col("Duration")).show(5)

print(
    "dtype of 'Duration' column is 'string'. Best to convert to timestamp:\n",
    logs.select(F.col("Duration")).dtypes,
)
+----------------+
|        Duration|
+----------------+
|02:00:00.0000000|
|00:00:30.0000000|
|00:00:15.0000000|
|00:00:15.0000000|
|00:00:15.0000000|
+----------------+
only showing top 5 rows

dtype of 'Duration' column is 'string'. Best to convert to timestamp:
 [('Duration', 'string')]

2. Extract time features from the Duration column, showing only distinct entries

In [116]:
logs.select(
    F.col("Duration"),
    F.col("Duration").substr(1, 2).cast("int").alias("hours"),
    F.col("Duration").substr(4, 2).cast("int").alias("minutes"),
    F.col("Duration").substr(7, 2).cast("int").alias("seconds"),
    # Add final column converting duration into total seconds
    (
        F.col("Duration").substr(1, 2).cast("int") * 60 * 60
        + F.col("Duration").substr(4, 2).cast("int") * 60
        + F.col("Duration").substr(7, 2).cast("int")
    ).alias("duration_seconds"),
).distinct().show(
    5
)  # only show distinct entries
+----------------+-----+-------+-------+----------------+
|        Duration|hours|minutes|seconds|duration_seconds|
+----------------+-----+-------+-------+----------------+
|00:00:19.0000000|    0|      0|     19|              19|
|00:07:09.0000000|    0|      7|      9|             429|
|00:53:26.0000000|    0|     53|     26|            3206|
|00:30:43.0000000|    0|     30|     43|            1843|
|00:02:41.0000000|    0|      2|     41|             161|
+----------------+-----+-------+-------+----------------+
only showing top 5 rows

3. Use withColumn() to add 'duration_seconds' to original data frame

In [122]:
logs = logs.withColumn(
    "duration_seconds",
    F.col("Duration").substr(1, 2).cast("int") * 60 * 60
    + F.col("Duration").substr(4, 2).cast("int") * 60
    + F.col("Duration").substr(7, 2).cast("int"),
)

assert "duration_seconds" in logs.columns

Warning: If you create a column with withColumn() and give it a name that already exists in your data frame, PySpark will happily overwrite the column.
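
A small sketch of that overwrite (using the columns created above):

# Reusing the name "Duration" silently replaces the original string column
# with the integer duration_seconds values; no warning is raised.
logs.withColumn("Duration", F.col("duration_seconds")).select("Duration").printSchema()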

Batch renaming with toDF()

In [124]:
logs.toDF(*[x.lower() for x in logs.columns]).printSchema()
root
 |-- broadcastlogid: integer (nullable = true)
 |-- logserviceid: integer (nullable = true)
 |-- logdate: string (nullable = true)
 |-- sequenceno: integer (nullable = true)
 |-- audiencetargetageid: integer (nullable = true)
 |-- audiencetargetethnicid: integer (nullable = true)
 |-- categoryid: integer (nullable = true)
 |-- closedcaptionid: integer (nullable = true)
 |-- countryoforiginid: integer (nullable = true)
 |-- dubdramacreditid: integer (nullable = true)
 |-- ethnicprogramid: integer (nullable = true)
 |-- productionsourceid: integer (nullable = true)
 |-- programclassid: integer (nullable = true)
 |-- filmclassificationid: integer (nullable = true)
 |-- exhibitionid: integer (nullable = true)
 |-- duration: string (nullable = true)
 |-- endtime: string (nullable = true)
 |-- logentrydate: string (nullable = true)
 |-- productionno: string (nullable = true)
 |-- programtitle: string (nullable = true)
 |-- starttime: string (nullable = true)
 |-- subtitle: string (nullable = true)
 |-- networkaffiliationid: integer (nullable = true)
 |-- specialattentionid: integer (nullable = true)
 |-- broadcastoriginpointid: integer (nullable = true)
 |-- compositionid: integer (nullable = true)
 |-- producer1: string (nullable = true)
 |-- producer2: string (nullable = true)
 |-- language1: integer (nullable = true)
 |-- language2: integer (nullable = true)
 |-- duration_seconds: integer (nullable = true)

Sorting column order with sorted()

In [125]:
logs.select(sorted(logs.columns)).printSchema()
root
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- BroadcastLogID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- duration_seconds: integer (nullable = true)

Getting a high level summary of your dataframe with describe and summary

  • describe only works for numerical and string columns
In [127]:
# Show stats for the first three columns
for i in logs.columns[:3]:
    logs.describe(i).show()
+-------+--------------------+
|summary|      BroadcastLogID|
+-------+--------------------+
|  count|              238945|
|   mean|1.2168651122760174E9|
| stddev| 1.496913424143109E7|
|    min|          1195788151|
|    max|          1249431576|
+-------+--------------------+

+-------+------------------+
|summary|      LogServiceID|
+-------+------------------+
|  count|            238945|
|   mean| 3450.890284375065|
| stddev|199.50673962555592|
|    min|              3157|
|    max|              3925|
+-------+------------------+

+-------+----------+
|summary|   LogDate|
+-------+----------+
|  count|    238945|
|   mean|      null|
| stddev|      null|
|    min|2018-08-01|
|    max|2018-08-01|
+-------+----------+

  • summary shows extra stats like the 25%, 50%, and 75% percentiles; it can also be limited to specific statistics (see the sketch after the output below)
In [135]:
# Show stats for the first three columns
for i in logs.columns[:3]:
    logs.select(i).summary().show()
+-------+--------------------+
|summary|      BroadcastLogID|
+-------+--------------------+
|  count|              238945|
|   mean|1.2168651122760174E9|
| stddev| 1.496913424143109E7|
|    min|          1195788151|
|    25%|          1249431576|
|    50%|          1213242718|
|    75%|          1226220081|
|    max|          1249431576|
+-------+--------------------+

+-------+------------------+
|summary|      LogServiceID|
+-------+------------------+
|  count|            238945|
|   mean| 3450.890284375065|
| stddev|199.50673962555592|
|    min|              3157|
|    25%|              3287|
|    50%|              3379|
|    75%|              3627|
|    max|              3925|
+-------+------------------+

+-------+----------+
|summary|   LogDate|
+-------+----------+
|  count|    238945|
|   mean|      null|
| stddev|      null|
|    min|2018-08-01|
|    25%|      null|
|    50%|      null|
|    75%|      null|
|    max|2018-08-01|
+-------+----------+
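
summary() can also be limited to specific statistics passed as strings, which is cheaper on a wide data frame (a sketch using the same column):

# Only compute the statistics we care about.
logs.select("LogServiceID").summary("count", "min", "25%", "75%", "max").show()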

Write to file

In [138]:
# write checkpoint file
logs.coalesce(1).write.mode("overwrite").csv("./output/ch04/logs.csv", header=True)
In [137]:
logs.printSchema()
root
 |-- BroadcastLogID: integer (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)
 |-- duration_seconds: integer (nullable = true)

Chapter 5: Joining and Grouping Data

In [59]:
# Set up
import os
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
import pyspark.sql.functions as F


spark = SparkSession.builder.getOrCreate()

# Read the data
DIRECTORY = "./data/Ch04"
logs = spark.read.csv(
    "./output/ch04/logs.csv", # read in data transformed in Ch04
    sep=",",  # default is ","
    quote='"',  # default is double quote.
    header=True,  # set first row as column names
    inferSchema=True,  # infer column types from the data (default False)
)
logs.printSchema()


# Read link table and filter to only primary channels (ie. PrimaryFG == 1)
log_identifier = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables", "LogIdentifier.csv"),
    sep="|",
    header=True,
    inferSchema=True,
)
log_identifier = log_identifier.where(F.col("PrimaryFG") == 1)


# Show results
log_identifier.printSchema()
log_identifier.show(5)
print("Unique primary channels: ", log_identifier.count())
root
 |-- BroadcastLogID: integer (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)
 |-- duration_seconds: integer (nullable = true)

root
 |-- LogIdentifierID: string (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- PrimaryFG: integer (nullable = true)

+---------------+------------+---------+
|LogIdentifierID|LogServiceID|PrimaryFG|
+---------------+------------+---------+
|           13ST|        3157|        1|
|         2000SM|        3466|        1|
|           70SM|        3883|        1|
|           80SM|        3590|        1|
|           90SM|        3470|        1|
+---------------+------------+---------+
only showing top 5 rows

Unique primary channels:  758

Understanding the join recipe

[LEFT].join(
    [RIGHT],
    on=[PREDICATES],
    how=[METHOD]
)

Important points

  1. If one record in the left table resolves the predicate with more than one record in the right table (or vice versa), this record will be duplicated in the joined table (see the toy sketch after this list).
  2. If one record in the left or in the right table does not resolve the predicate with any record in the other table, it will not be present in the resulting table, unless the join method specifies a protocol for failed predicates.
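
A toy sketch of point 1 (hypothetical data, assuming the SparkSession above): the left record with key 1 matches two right records and is therefore duplicated; key 2 matches nothing and is dropped by the default inner join.

toy_left = spark.createDataFrame([[1, "a"], [2, "b"]], ["key", "left_value"])
toy_right = spark.createDataFrame([[1, "x"], [1, "y"]], ["key", "right_value"])

# key 1 appears twice in the result, key 2 not at all.
toy_left.join(toy_right, on="key").show()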

Pyspark helpers in join logic

  • You can put multiple predicates into a list; they are combined with a logical AND, like:
    [
          left["col1"] == right["colA"], 
          left["col2"] > right["colB"],  # value on left table is greater than the right
          left["col3"] != right["colC"]
      ]
    
  • You can test equality just by specifying the column name, or a list of column names

Setting up join logic with how

  1. cross - returns a record for every possible pair of records (the Cartesian product). Not common.
  2. inner - returns the record if the predicate is true, otherwise drops it. The most common; PySpark's join default.
  3. left & right - similar to inner, except for what happens with false predicates:
    • A left join also adds the unmatched records from the left table to the joined table, filling the columns coming from the right table with None.
    • A right join does the same for the unmatched records of the right table, filling the left table's columns with None.
  4. outer - adds the unmatched records from both the left and the right table, padding with None.
  5. left_semi - same as an inner join, but keeps only the columns of the left table.
  6. left_anti - returns only the records of the left table that don't match the predicate with any record in the right table; the opposite of left_semi (see the sketch after this list).
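
A hypothetical sketch contrasting a few of these methods on toy data frames (redefined here so the snippet is self-contained):

toy_left = spark.createDataFrame([[1, "a"], [2, "b"]], ["key", "left_value"])
toy_right = spark.createDataFrame([[1, "x"]], ["key", "right_value"])

# inner: only key 1 survives, with columns from both sides.
toy_left.join(toy_right, on="key", how="inner").show()

# left: key 2 is kept too, with right_value padded with null.
toy_left.join(toy_right, on="key", how="left").show()

# left_semi: only matching records, and only the left table's columns.
toy_left.join(toy_right, on="key", how="left_semi").show()

# left_anti: only the left records with no match (key 2), left columns only.
toy_left.join(toy_right, on="key", how="left_anti").show()
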
In [60]:
# Join `logs` with `log_identifier` using the 'LogServiceID' column
joined = logs.join(log_identifier, on="LogServiceID", how="inner")
In [61]:
# Additionally join CategoryID and ProgramClassID table
# Use left joins since keys may not be available in the link table.

# CategoryID
cd_category = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables", "CD_Category.csv"),
    sep="|",
    header=True,
    inferSchema=True,
).select(
    "CategoryID",
    "CategoryCD",
    F.col("EnglishDescription").alias("Category_Description"),
)

# ProgramClass
cd_program_class = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables", "CD_ProgramClass.csv"),
    sep="|",
    header=True,
    inferSchema=True,
).select(
    "ProgramClassID",
    "ProgramClassCD",
    F.col("EnglishDescription").alias("ProgramClass_Description"),
)


# Join all to joined table
full_log = joined.join(cd_category, "CategoryID", how="left",).join(
    cd_program_class, "ProgramClassID", how="left",
)


# Check if additional columns were joined to original log data frame
full_log.printSchema()
root
 |-- ProgramClassID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- BroadcastLogID: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)
 |-- duration_seconds: integer (nullable = true)
 |-- LogIdentifierID: string (nullable = true)
 |-- PrimaryFG: integer (nullable = true)
 |-- CategoryCD: string (nullable = true)
 |-- Category_Description: string (nullable = true)
 |-- ProgramClassCD: string (nullable = true)
 |-- ProgramClass_Description: string (nullable = true)

Warning: What happens when joining columns in a distributed environment

To be able to process a comparison between records, the data needs to be on the same machine. If not, PySpark will move the data in an operation called a shuffle, which is slow and expensive. More on join strategies in later chapters.

Warning: Joining tables with identically named columns leads to errors downstream

PySpark happily joins the two data frames together but fails when we try to work with the ambiguous column.

In [62]:
# Joining two tables with the same LogServiceID column
logs_and_channels_verbose = logs.join(
    log_identifier, logs["LogServiceID"] == log_identifier["LogServiceID"]
)
logs_and_channels_verbose.printSchema()


print(
    'Joined table now has two "LogServiceID" columns: ',
    [col for col in logs_and_channels_verbose.columns if col == "LogServiceID"],
    "\n",
)
print('Selecting "LogServiceID" will now throw an error')


# Selecting "LogServiceID" will throw an error
try:
    logs_and_channels_verbose.select("LogServiceID")
except AnalysisException as err:
    print("AnalysisException: ", err)
root
 |-- BroadcastLogID: integer (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: string (nullable = true)
 |-- SequenceNO: integer (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: string (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttentionID: integer (nullable = true)
 |-- BroadcastOriginPointID: integer (nullable = true)
 |-- CompositionID: integer (nullable = true)
 |-- Producer1: string (nullable = true)
 |-- Producer2: string (nullable = true)
 |-- Language1: integer (nullable = true)
 |-- Language2: integer (nullable = true)
 |-- duration_seconds: integer (nullable = true)
 |-- LogIdentifierID: string (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- PrimaryFG: integer (nullable = true)

Joined table now has two "LogServiceID" columns:  ['LogServiceID', 'LogServiceID'] 

Selecting "LogServiceID" will now throw an error
AnalysisException:  Reference 'LogServiceID' is ambiguous, could be: LogServiceID, LogServiceID.;

Solutions for preventing ambiguous column references

  1. Use the simplified syntax (i.e. pass the name of the join column as a string, or a list of names). PySpark automatically removes the second instance of the predicate column. Only works for equi-joins.
     logs_and_channels = logs.join(log_identifier, "LogServiceID")
  2. Refer to the column through the data frame it came from.
     logs_and_channels_verbose.select(log_identifier["LogServiceID"])
  3. Use the Column object directly

     logs_and_channels_verbose = logs.alias("left").join(
         log_identifier.alias("right"),
         logs["LogServiceID"] == log_identifier["LogServiceID"],
     )

     logs_and_channels_verbose.drop(F.col("right.LogServiceID")).select(
         "LogServiceID"
     )

Advanced groupby with GroupedData

Goal: What channels have the most and least proportion of commercials?

Task:

  1. Get the number of seconds when the program is a commercial.
  2. Get total number of seconds.

groupby on multiple columns

  • groupby() returns a GroupedData object, not a data frame, so you can't call show() on it.
  • You can "show" by running summary functions on it, like F.sum.
  • GroupedData object holds all non-key columns in a group cell (see fig 5.7)

grouped

agg() vs sum()

  • agg can take an arbitrary number of aggregate functions
  • You can alias resulting columns, unlike sum
In [65]:
# Group by ProgramClassCD and ProgramClass_Description, sum total duration for each

full_log.groupby("ProgramClassCD", "ProgramClass_Description").agg(
    F.sum("duration_seconds").alias("duration_total")
).orderBy("duration_total", ascending=False).show(100, False)


# Another way by passing dictionary to agg
# full_log.groupby("ProgramClassCD", "ProgramClass_Description").agg(
#     {"duration_seconds": "sum"}
# ).withColumnRenamed("sum(duration_seconds)", "duration_total").orderBy(
#     "duration_total", ascending=False
# ).show(
#     100, False
# )
+--------------+--------------------------------------+--------------+
|ProgramClassCD|ProgramClass_Description              |duration_total|
+--------------+--------------------------------------+--------------+
|PGR           |PROGRAM                               |20992510      |
|COM           |COMMERCIAL MESSAGE                    |3519163       |
|PFS           |PROGRAM FIRST SEGMENT                 |1344762       |
|SEG           |SEGMENT OF A PROGRAM                  |1205998       |
|PRC           |PROMOTION OF UPCOMING CANADIAN PROGRAM|880600        |
|PGI           |PROGRAM INFOMERCIAL                   |679182        |
|PRO           |PROMOTION OF NON-CANADIAN PROGRAM     |335701        |
|OFF           |SCHEDULED OFF AIR TIME PERIOD         |142279        |
|ID            |NETWORK IDENTIFICATION MESSAGE        |74926         |
|NRN           |No recognized nationality             |59686         |
|MAG           |MAGAZINE PROGRAM                      |57622         |
|PSA           |PUBLIC SERVICE ANNOUNCEMENT           |51214         |
|SO            |MAY IDENTIFY THE SIGN ON\OFF OF A DAY |32509         |
|OFT           |OFF AIR DUE TO TECHNICAL DIFFICULTY   |18263         |
|LOC           |LOCAL ADVERTISING                     |13294         |
|MVC           |MUSIC VIDEO CLIP                      |7907          |
|REG           |REGIONAL                              |6749          |
|MER           |MERCHANDISING                         |1680          |
|SPO           |SPONSORSHIP MESSAGE                   |1544          |
|SOL           |SOLICITATION MESSAGE                  |596           |
|MOS           |Mosaic                                |null          |
|COR           |CORNERSTONE                           |null          |
+--------------+--------------------------------------+--------------+

Using agg with custom column definitions

when logic:

(
F.when([BOOLEAN TEST], [RESULT IF TRUE])
 .when([ANOTHER BOOLEAN TEST], [RESULT IF TRUE])
 .otherwise([DEFAULT RESULT, WILL DEFAULT TO null IF OMITTED])
)
In [78]:
# Goal: Compute only the commercial time for each program


# Create custom column logic - get duration_seconds if ProgramClassCD matches an item in
# the list
is_commercial = F.when(
    F.trim(F.col("ProgramClassCD")).isin(
        ["COM", "PRC", "PGI", "PRO", "LOC", "SPO", "MER", "SOL"]
    ),
    F.col("duration_seconds"),
).otherwise(0)


# Use custom column logic to build a duration_commercial column,
# along with duration_total
commercial_time = (
    full_log.groupby("LogIdentifierID")
    .agg(
        F.sum(is_commercial).alias("duration_commercial"),
        F.sum("duration_seconds").alias("duration_total"),
    )
    .withColumn(
        "commercial_ratio", F.col("duration_commercial") / F.col("duration_total")
    )
)

commercial_time.orderBy("commercial_ratio", ascending=False).show(20, False)
+---------------+-------------------+--------------+------------------+
|LogIdentifierID|duration_commercial|duration_total|commercial_ratio  |
+---------------+-------------------+--------------+------------------+
|CIMT           |775                |775           |1.0               |
|TELENO         |17790              |17790         |1.0               |
|MSET           |2700               |2700          |1.0               |
|HPITV          |13                 |13            |1.0               |
|TLNSP          |15480              |15480         |1.0               |
|TANG           |8125               |8125          |1.0               |
|MMAX           |23333              |23582         |0.9894410991434145|
|MPLU           |20587              |20912         |0.9844586840091814|
|INVST          |20094              |20470         |0.9816316560820714|
|ZTÉLÉ          |21542              |21965         |0.9807420896881403|
|RAPT           |17916              |18279         |0.9801411455768915|
|CANALD         |21437              |21875         |0.9799771428571429|
|ONEBMS         |18084              |18522         |0.9763524457402009|
|CANALVIE       |20780              |21309         |0.975174808766249 |
|unis           |11630              |11998         |0.9693282213702283|
|CIVM           |11370              |11802         |0.9633960345704118|
|TV5            |10759              |11220         |0.9589126559714795|
|LEAF           |11526              |12034         |0.9577862722286854|
|VISION         |12946              |13621         |0.950444167094927 |
|CJIL           |3904               |4213          |0.9266555898409684|
+---------------+-------------------+--------------+------------------+
only showing top 20 rows

Dropping unwanted records - dropna + fillna

dropna

params
  1. how, which can take the value any or all. With any, PySpark drops records where at least one of the fields is null. With all, only records where all fields are null are removed. The default is any.
  2. thresh takes an integer value. If set (its default is None), PySpark will ignore the how parameter and only drop the records with fewer than thresh non-null values (see the sketch after this list).
  3. subset takes an optional list of columns that dropna() will use to make its decision.
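
thresh isn't used in the cell below; a minimal sketch of its effect on the commercial_time frame defined above:

# Keep rows with at least 3 non-null values among the listed columns;
# when thresh is set, the how parameter is ignored
commercial_time.dropna(
    thresh=3, subset=["duration_commercial", "duration_total", "commercial_ratio"]
).count()
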
In [83]:
# Drop records that have a commercial_ratio of null

c_time_no_null = commercial_time.dropna(subset=["commercial_ratio"])
c_time_no_null.orderBy("commercial_ratio", ascending=False).show()


# Check record counts for each
print("Records in commercial_time: ", commercial_time.count())
print("Records in c_time_no_null: ", c_time_no_null.count())
+---------------+-------------------+--------------+------------------+
|LogIdentifierID|duration_commercial|duration_total|  commercial_ratio|
+---------------+-------------------+--------------+------------------+
|          HPITV|                 13|            13|               1.0|
|           CIMT|                775|           775|               1.0|
|           MSET|               2700|          2700|               1.0|
|          TLNSP|              15480|         15480|               1.0|
|         TELENO|              17790|         17790|               1.0|
|           TANG|               8125|          8125|               1.0|
|           MMAX|              23333|         23582|0.9894410991434145|
|           MPLU|              20587|         20912|0.9844586840091814|
|          INVST|              20094|         20470|0.9816316560820714|
|          ZTÉLÉ|              21542|         21965|0.9807420896881403|
|           RAPT|              17916|         18279|0.9801411455768915|
|         CANALD|              21437|         21875|0.9799771428571429|
|         ONEBMS|              18084|         18522|0.9763524457402009|
|       CANALVIE|              20780|         21309| 0.975174808766249|
|           unis|              11630|         11998|0.9693282213702283|
|           CIVM|              11370|         11802|0.9633960345704118|
|            TV5|              10759|         11220|0.9589126559714795|
|           LEAF|              11526|         12034|0.9577862722286854|
|         VISION|              12946|         13621| 0.950444167094927|
|           CJIL|               3904|          4213|0.9266555898409684|
+---------------+-------------------+--------------+------------------+
only showing top 20 rows

Records in commercial_time:  324
Records in c_time_no_null:  322

fillna

params

  1. value, either a Python int, float, string or bool.
  2. subset, which columns to fill

Tip: You can fill nulls differently for each column by passing a dictionary:

answer_no_null = answer.fillna(
    {"duration_commercial": 0, "duration_total": 0, "commercial_ratio": 0}
)
In [84]:
# Fill null fields

c_time_fill_null = commercial_time.fillna(0)
c_time_fill_null.orderBy("commercial_ratio", ascending=False).show()


# Check record counts for each
print("Records in commercial_time: ", commercial_time.count())
print("Records in c_time_no_null: ", c_time_fill_null.count())
+---------------+-------------------+--------------+------------------+
|LogIdentifierID|duration_commercial|duration_total|  commercial_ratio|
+---------------+-------------------+--------------+------------------+
|           CIMT|                775|           775|               1.0|
|           MSET|               2700|          2700|               1.0|
|          TLNSP|              15480|         15480|               1.0|
|          HPITV|                 13|            13|               1.0|
|         TELENO|              17790|         17790|               1.0|
|           TANG|               8125|          8125|               1.0|
|           MMAX|              23333|         23582|0.9894410991434145|
|           MPLU|              20587|         20912|0.9844586840091814|
|          INVST|              20094|         20470|0.9816316560820714|
|          ZTÉLÉ|              21542|         21965|0.9807420896881403|
|           RAPT|              17916|         18279|0.9801411455768915|
|         CANALD|              21437|         21875|0.9799771428571429|
|         ONEBMS|              18084|         18522|0.9763524457402009|
|       CANALVIE|              20780|         21309| 0.975174808766249|
|           unis|              11630|         11998|0.9693282213702283|
|           CIVM|              11370|         11802|0.9633960345704118|
|            TV5|              10759|         11220|0.9589126559714795|
|           LEAF|              11526|         12034|0.9577862722286854|
|         VISION|              12946|         13621| 0.950444167094927|
|           CJIL|               3904|          4213|0.9266555898409684|
+---------------+-------------------+--------------+------------------+
only showing top 20 rows

Records in commercial_time:  324
Records in c_time_fill_null:  324

Chapter 6: Multi-dimensional data frames: using PySpark with JSON data

In [184]:
# Set up
import os
import numpy as np
import json
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
import pyspark.sql.functions as F
import pyspark.sql.types as T


spark = SparkSession.builder.getOrCreate()

Reading the Data

For this chapter, we use a JSON dump of the information about the TV Show Silicon Valley, from TV Maze.

JSON params

  • No need for delimiters like CSV
  • No need to infer data type
  • Contains hierarchical data, unlike CSVs
  • Single JSON: one JSON document, one line, one record.
  • Multiple JSON (multiLine): one JSON document, one FILE, one record.
In [185]:
# Import a single JSON document
sv = "data/ch06/shows-silicon-valley.json"
shows = spark.read.json(sv)
display(shows.count())


# Read multiple JSON documents using multiLine param
three_shows = spark.read.json("data/ch06/shows-*.json", multiLine=True)
display(three_shows.count())
1
3
In [186]:
# Inspect the schema
shows.printSchema()
root
 |-- _embedded: struct (nullable = true)
 |    |-- episodes: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- _links: struct (nullable = true)
 |    |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |    |-- href: string (nullable = true)
 |    |    |    |-- airdate: string (nullable = true)
 |    |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |    |-- airtime: string (nullable = true)
 |    |    |    |-- id: long (nullable = true)
 |    |    |    |-- image: struct (nullable = true)
 |    |    |    |    |-- medium: string (nullable = true)
 |    |    |    |    |-- original: string (nullable = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- number: long (nullable = true)
 |    |    |    |-- runtime: long (nullable = true)
 |    |    |    |-- season: long (nullable = true)
 |    |    |    |-- summary: string (nullable = true)
 |    |    |    |-- url: string (nullable = true)
 |-- _links: struct (nullable = true)
 |    |-- previousepisode: struct (nullable = true)
 |    |    |-- href: string (nullable = true)
 |    |-- self: struct (nullable = true)
 |    |    |-- href: string (nullable = true)
 |-- externals: struct (nullable = true)
 |    |-- imdb: string (nullable = true)
 |    |-- thetvdb: long (nullable = true)
 |    |-- tvrage: long (nullable = true)
 |-- genres: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- id: long (nullable = true)
 |-- image: struct (nullable = true)
 |    |-- medium: string (nullable = true)
 |    |-- original: string (nullable = true)
 |-- language: string (nullable = true)
 |-- name: string (nullable = true)
 |-- network: struct (nullable = true)
 |    |-- country: struct (nullable = true)
 |    |    |-- code: string (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- timezone: string (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- name: string (nullable = true)
 |-- officialSite: string (nullable = true)
 |-- premiered: string (nullable = true)
 |-- rating: struct (nullable = true)
 |    |-- average: double (nullable = true)
 |-- runtime: long (nullable = true)
 |-- schedule: struct (nullable = true)
 |    |-- days: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- time: string (nullable = true)
 |-- status: string (nullable = true)
 |-- summary: string (nullable = true)
 |-- type: string (nullable = true)
 |-- updated: long (nullable = true)
 |-- url: string (nullable = true)
 |-- webChannel: string (nullable = true)
 |-- weight: long (nullable = true)

Spark's complex column types: array, map and struct

array

  • PySpark arrays are containers for values of the same type, unlike JSON.
  • PySpark will not raise an error if you try to read an array-type column with multiple types. Instead, it will simply default to the lowest common denominator, usually the string.
  • Many array functions are available from pyspark.sql.functions
In [187]:
# Selecting the name and genres columns of the shows dataframe

import pyspark.sql.functions as F

array_subset = shows.select("name", "genres")
array_subset.show(1, False)
+--------------+--------+
|name          |genres  |
+--------------+--------+
|Silicon Valley|[Comedy]|
+--------------+--------+

In [188]:
# Multiple methods to extract the same array

array_subset = array_subset.select(
    "name",
    array_subset.genres[0].alias("dot_and_index"),
    F.col("genres")[0].alias("col_and_index"),
    array_subset.genres.getItem(0).alias("dot_and_method"),
    F.col("genres").getItem(0).alias("col_and_method"),
)

array_subset.show()
+--------------+-------------+-------------+--------------+--------------+
|          name|dot_and_index|col_and_index|dot_and_method|col_and_method|
+--------------+-------------+-------------+--------------+--------------+
|Silicon Valley|       Comedy|       Comedy|        Comedy|        Comedy|
+--------------+-------------+-------------+--------------+--------------+

WARNING: Although the square bracket approach looks very Pythonic, you can’t use it as a slicing tool. PySpark will accept only one integer as an index.

Creating an array column

  1. Create three literal columns with lit() (scalar columns), then combine them with array() into an array of possible genres.
  2. Use the function array_repeat() to create a column repeating the "Comedy" string five times.
In [218]:
"""
1. Create three literal columns (using lit() to create scalar columns, 
   then make_array() to ) to create an array of possible genres.
2. Use the function array_repeat() to create a column repeating the "Comedy" string
"""

array_subset_repeated = array_subset.select(
    "name",
    F.lit("Comedy").alias("one"),
    F.lit("Horror").alias("two"),
    F.lit("Drama").alias("three"),
    F.col("dot_and_index"),
).select(
    "name",
    F.array("one", "two", "three").alias("Some_Genres"),
    F.array_repeat("dot_and_index", 5).alias("Repeated_Genres"),
)

array_subset_repeated.show(1, False)
+--------------+-----------------------+----------------------------------------+
|name          |Some_Genres            |Repeated_Genres                         |
+--------------+-----------------------+----------------------------------------+
|Silicon Valley|[Comedy, Horror, Drama]|[Comedy, Comedy, Comedy, Comedy, Comedy]|
+--------------+-----------------------+----------------------------------------+

Use F.size to show the number of elements in an array

In [190]:
array_subset_repeated.select(
    "name", F.size("Some_Genres"), F.size("Repeated_Genres")
).show()
+--------------+-----------------+---------------------+
|          name|size(Some_Genres)|size(Repeated_Genres)|
+--------------+-----------------+---------------------+
|Silicon Valley|                3|                    5|
+--------------+-----------------+---------------------+

Use F.array_distinct() to remove duplicates (like SQL)

In [191]:
array_subset_repeated.select(
    "name",
    F.array_distinct("Some_Genres"),
    F.array_distinct("Repeated_Genres")
).show(1, False)
+--------------+---------------------------+-------------------------------+
|name          |array_distinct(Some_Genres)|array_distinct(Repeated_Genres)|
+--------------+---------------------------+-------------------------------+
|Silicon Valley|[Comedy, Horror, Drama]    |[Comedy]                       |
+--------------+---------------------------+-------------------------------+

Use F.array_intersect to show common values across arrays

In [192]:
array_subset_repeated = array_subset_repeated.select(
    "name", 
    F.array_intersect("Some_Genres", "Repeated_Genres").alias("Genres")
)

array_subset_repeated.show()
+--------------+--------+
|          name|  Genres|
+--------------+--------+
|Silicon Valley|[Comedy]|
+--------------+--------+

Use array_position() to get the position of the item in an array if it exists

WARNING: array_position() is 1-based, unlike Python lists or element extraction from arrays (e.g. array_subset.genres[0] or getItem(0))

In [193]:
# When using array_position(), the first item of the array has position 1, 
# not 0 like in python.
array_subset_repeated.select(
    "name",
    F.array_position("Genres", "Comedy").alias("Genres"),
).show()
+--------------+------+
|          name|Genres|
+--------------+------+
|Silicon Valley|     1|
+--------------+------+

map

  • Like a typed Python dictionary: you have keys and values, just like in a dictionary
  • Like arrays, all keys need to be of the same type and all values need to be of the same type
  • Values can usually be null, but keys can't (as in Python); besides map_from_arrays() used in the next cell, create_map() can also build one (see the sketch below)
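
A minimal sketch of the create_map() alternative on the shows frame (this cell is my own illustration, not from the book):

# create_map() takes alternating key and value columns
shows.select(
    F.create_map(
        F.lit("name"), F.col("name"),
        F.lit("language"), F.col("language"),
        F.lit("type"), F.col("type"),
    ).alias("mapped")
).show(1, False)
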
In [194]:
# Creating a map from two arrays: one for the keys, one for the values. 
# This creates a hash-map within the column record.

# 1. Create two columns of arrays
columns = ["name", "language", "type"]
shows_map = shows.select(
    *[F.lit(column) for column in columns],
    F.array(*columns).alias("values")
)
shows_map = shows_map.select(F.array(*columns).alias("keys"), "values")
print("Two columns of arays")
shows_map.show(1, False)

# 2. Map them together using one array as the key, and other as value
shows_map = shows_map.select(
    F.map_from_arrays("keys", "values").alias("mapped")
)
shows_map.printSchema()
print("1 column of map")
shows_map.show(1, False)

# 3. 3 ways to select a key in a map column
print("3 ways to select a key in a map")
shows_map.select(
    F.col("mapped.name"), # dot_notation with col
    F.col("mapped")["name"], # Python dictionary style
    shows_map.mapped["name"] # dot_notation to get the column + bracket
).show()
Two columns of arrays
+----------------------+-----------------------------------+
|keys                  |values                             |
+----------------------+-----------------------------------+
|[name, language, type]|[Silicon Valley, English, Scripted]|
+----------------------+-----------------------------------+

root
 |-- mapped: map (nullable = false)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

1 column of map
+---------------------------------------------------------------+
|mapped                                                         |
+---------------------------------------------------------------+
|[name -> Silicon Valley, language -> English, type -> Scripted]|
+---------------------------------------------------------------+

3 ways to select a key in a map
+--------------+--------------+--------------+
|          name|  mapped[name]|  mapped[name]|
+--------------+--------------+--------------+
|Silicon Valley|Silicon Valley|Silicon Valley|
+--------------+--------------+--------------+

struct

  • Similar to a JSON object: each key is a string, and each field's value can be of a different type.
  • Unlike array & map, the number of fields and their names are known ahead of time

In [195]:
# "schedule" column contain array of strings and a string
shows.select("schedule").printSchema()
root
 |-- schedule: struct (nullable = true)
 |    |-- days: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- time: string (nullable = true)

In [196]:
# A more complex struct
shows.select("_embedded").printSchema()
root
 |-- _embedded: struct (nullable = true)
 |    |-- episodes: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- _links: struct (nullable = true)
 |    |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |    |-- href: string (nullable = true)
 |    |    |    |-- airdate: string (nullable = true)
 |    |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |    |-- airtime: string (nullable = true)
 |    |    |    |-- id: long (nullable = true)
 |    |    |    |-- image: struct (nullable = true)
 |    |    |    |    |-- medium: string (nullable = true)
 |    |    |    |    |-- original: string (nullable = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- number: long (nullable = true)
 |    |    |    |-- runtime: long (nullable = true)
 |    |    |    |-- season: long (nullable = true)
 |    |    |    |-- summary: string (nullable = true)
 |    |    |    |-- url: string (nullable = true)

Above struct visualized:

In [197]:
# Drop useless _embedded column and promote the fields within
shows_clean = shows.withColumn("episodes", F.col("_embedded.episodes")).drop(
    "_embedded"
)
shows_clean.select("episodes").printSchema()
root
 |-- episodes: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _links: struct (nullable = true)
 |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |-- href: string (nullable = true)
 |    |    |-- airdate: string (nullable = true)
 |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |-- airtime: string (nullable = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- image: struct (nullable = true)
 |    |    |    |-- medium: string (nullable = true)
 |    |    |    |-- original: string (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- number: long (nullable = true)
 |    |    |-- runtime: long (nullable = true)
 |    |    |-- season: long (nullable = true)
 |    |    |-- summary: string (nullable = true)
 |    |    |-- url: string (nullable = true)

Using explode to split arrays into rows

In [198]:
# "episodes.name" == array of strings
episodes_name = shows_clean.select(F.col("episodes.name"))
episodes_name.printSchema()

# Just showing episodes_name is messy, so explode the array to show the names
episodes_name.select(F.explode("name").alias("name")).show(3, False)
root
 |-- name: array (nullable = true)
 |    |-- element: string (containsNull = true)

+-------------------------+
|name                     |
+-------------------------+
|Minimum Viable Product   |
|The Cap Table            |
|Articles of Incorporation|
+-------------------------+
only showing top 3 rows

How to define and use a schema with a PySpark data frame

  • Can be built either 1) programmatically, or 2) as a DDL-style string (a DDL sketch follows the StructField example below)
  • Type objects used to build a schema live in pyspark.sql.types, usually imported as T.

Two object types in pyspark.sql.types

  1. type objects - represent a column of a certain type (e.g. LongType(), DecimalType(precision, scale), ArrayType(StringType()), etc.)
  2. field objects - represent an arbitrary number of named fields (e.g. StructField())
    • 2 mandatory params, name (str) and dataType (type)

Putting it all together:

T.StructField("summary", T.StringType())
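
The DDL-style alternative mentioned above expresses the same fields as a plain string; a minimal sketch (the DDL string is my own illustration, not taken from the book):

import pyspark.sql.types as T

# Programmatic version of a two-field schema
programmatic = T.StructType(
    [
        T.StructField("summary", T.StringType()),
        T.StructField("weight", T.LongType()),
    ]
)

# DDL-style equivalent; spark.read.json(..., schema=...) accepts either form
ddl_style = "summary STRING, weight BIGINT"
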
In [199]:
# For reference
shows.select("_embedded").printSchema()
root
 |-- _embedded: struct (nullable = true)
 |    |-- episodes: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- _links: struct (nullable = true)
 |    |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |    |-- href: string (nullable = true)
 |    |    |    |-- airdate: string (nullable = true)
 |    |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |    |-- airtime: string (nullable = true)
 |    |    |    |-- id: long (nullable = true)
 |    |    |    |-- image: struct (nullable = true)
 |    |    |    |    |-- medium: string (nullable = true)
 |    |    |    |    |-- original: string (nullable = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- number: long (nullable = true)
 |    |    |    |-- runtime: long (nullable = true)
 |    |    |    |-- season: long (nullable = true)
 |    |    |    |-- summary: string (nullable = true)
 |    |    |    |-- url: string (nullable = true)

Building the entire schema from scratch

In [200]:
# Full schema from scratch

# episode links
episode_links_schema = T.StructType(
    [T.StructField("self", T.StructType([T.StructField("href", T.StringType())]))]
)

# episode image
episode_image_schema = T.StructType(
    [
        T.StructField("medium", T.StringType()),
        T.StructField("original", T.StringType()),
    ]
)

# episode metadata
episode_schema = T.StructType(
    [
        T.StructField("_links", episode_links_schema),
        T.StructField("airdate", T.DateType()),
        T.StructField("airstamp", T.TimestampType()),
        T.StructField("airtime", T.StringType()),
        T.StructField("id", T.StringType()),
        T.StructField("image", episode_image_schema),
        T.StructField("name", T.StringType()),
        T.StructField("number", T.LongType()),
        T.StructField("runtime", T.LongType()),
        T.StructField("season", T.LongType()),
        T.StructField("summary", T.StringType()),
        T.StructField("url", T.StringType()),
    ]
)

# set top level array
embedded_schema = T.StructType([T.StructField("episodes", T.ArrayType(episode_schema))])

# network
network_schema = T.StructType(
    [
        T.StructField(
            "country",
            T.StructType(
                [
                    T.StructField("code", T.StringType()),
                    T.StructField("name", T.StringType()),
                    T.StructField("timezone", T.StringType()),
                ]
            ),
        ),
        T.StructField("id", T.LongType()),
        T.StructField("name", T.StringType()),
    ]
)

# shows (with embedded_schema and network_schema)
shows_schema = T.StructType(
    [
        T.StructField("_embedded", embedded_schema),
        T.StructField("language", T.StringType()),
        T.StructField("name", T.StringType()),
        T.StructField("network", network_schema),
        T.StructField("officialSite", T.StringType()),
        T.StructField("premiered", T.StringType()),
        T.StructField(
            "rating", T.StructType([T.StructField("average", T.DoubleType())])
        ),
        T.StructField("runtime", T.LongType()),
        T.StructField(
            "schedule",
            T.StructType(
                [
                    T.StructField("days", T.ArrayType(T.StringType())),
                    T.StructField("time", T.StringType()),
                ]
            ),
        ),
        T.StructField("status", T.StringType()),
        T.StructField("summary", T.StringType()),
        T.StructField("type", T.StringType()),
        T.StructField("updated", T.LongType()),
        T.StructField("url", T.StringType()),
        T.StructField("webChannel", T.StringType()),
        T.StructField("weight", T.LongType()),
    ]
)

Reading JSON with a strict schema

Read the JSON file using the schema that we built up:

  • mode="FAILFAST" is a param to throw an error if it reads a malformed record versus the schema provided.
  • If reading non-standard date/timestamp format, you'll need to pass the right format to dateFormat or timestampFormat.

Default for mode parameter is PERMISSIVE, which sets malformed records to null.
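
A minimal sketch of passing explicit patterns along with the schema (the format strings and the shows_strict name are illustrative assumptions, not from the book):

# Strict read: fail on any record that doesn't fit the schema, and parse
# dates/timestamps with explicit patterns instead of the defaults
shows_strict = spark.read.json(
    "./data/Ch06/shows-silicon-valley.json",
    schema=shows_schema,
    mode="FAILFAST",
    dateFormat="yyyy-MM-dd",
    timestampFormat="yyyy-MM-dd'T'HH:mm:ssXXX",
)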

In [201]:
shows_with_schema = spark.read.json("./data/Ch06/shows-silicon-valley.json",
                                   schema=shows_schema,
                                   mode="FAILFAST")

# Check format for modified columns:
for column in ["airdate", "airstamp"]:
    shows_with_schema.select(f"_embedded.episodes.{column}") \
                     .select(F.explode(column)) \
                     .show(5, False)
+----------+
|col       |
+----------+
|2014-04-06|
|2014-04-13|
|2014-04-20|
|2014-04-27|
|2014-05-04|
+----------+
only showing top 5 rows

+-------------------+
|col                |
+-------------------+
|2014-04-06 22:00:00|
|2014-04-13 22:00:00|
|2014-04-20 22:00:00|
|2014-04-27 22:00:00|
|2014-05-04 22:00:00|
+-------------------+
only showing top 5 rows

Example of FAILFAST error due to conflicting schema

In [202]:
from py4j.protocol import Py4JJavaError

shows_schema2 = T.StructType(
    [
        T.StructField("_embedded", embedded_schema),
        T.StructField("language", T.StringType()),
        T.StructField("name", T.StringType()),
        T.StructField("network", network_schema),
        T.StructField("officialSite", T.StringType()),
        T.StructField("premiered", T.StringType()),
        T.StructField(
            "rating", T.StructType([T.StructField("average", T.DoubleType())])
        ),
        T.StructField("runtime", T.LongType()),
        T.StructField(
            "schedule",
            T.StructType(
                [
                    T.StructField("days", T.ArrayType(T.StringType())),
                    T.StructField("time", T.StringType()),
                ]
            ),
        ),
        T.StructField("status", T.StringType()),
        T.StructField("summary", T.StringType()),
        T.StructField("type", T.LongType()),         # switch to LongType
        T.StructField("updated", T.LongType()),      # switch to LongType
        T.StructField("url", T.LongType()),          # switch to LongType
        T.StructField("webChannel", T.StringType()),
        T.StructField("weight", T.LongType()),
    ]
)

shows_with_schema_wrong = spark.read.json(
    "data/Ch06/shows-silicon-valley.json", schema=shows_schema2, mode="FAILFAST",
)

try:
    shows_with_schema_wrong.show()
except Py4JJavaError:
    pass

# Huge Spark ERROR stacktrace, relevant bit:
#
# Caused by: java.lang.RuntimeException: Failed to parse a value for data type
#   bigint (current token: VALUE_STRING).

Defining your schema in JSON

StructType comes with two methods for exporting its content into a JSON-esque format.

  1. json() outputs a string containing the json formatted schema
  2. jsonValue() returns the schema as a dictionary
In [203]:
from pprint import pprint

pprint(shows_with_schema.select('schedule').schema.jsonValue())
{'fields': [{'metadata': {},
             'name': 'schedule',
             'nullable': True,
             'type': {'fields': [{'metadata': {},
                                  'name': 'days',
                                  'nullable': True,
                                  'type': {'containsNull': True,
                                           'elementType': 'string',
                                           'type': 'array'}},
                                 {'metadata': {},
                                  'name': 'time',
                                  'nullable': True,
                                  'type': 'string'}],
                      'type': 'struct'}}],
 'type': 'struct'}

You can call jsonValue() on a complex schema to see its JSON representation. This is helpful when trying to remember the fields each complex type expects:

Array types

  1. containsNull,
  2. elementType,
  3. type (always array)
In [204]:
pprint(T.StructField("array_example", T.ArrayType(T.StringType())).jsonValue())
{'metadata': {},
 'name': 'array_example',
 'nullable': True,
 'type': {'containsNull': True, 'elementType': 'string', 'type': 'array'}}

Map types

  1. keyType
  2. type (always map)
  3. valueContainsNull
  4. valueType
In [205]:
# Example 1
pprint(
    T.StructField("map_example", T.MapType(T.StringType(), T.LongType())).jsonValue()
)
{'metadata': {},
 'name': 'map_example',
 'nullable': True,
 'type': {'keyType': 'string',
          'type': 'map',
          'valueContainsNull': True,
          'valueType': 'long'}}
In [206]:
# With both
pprint(
    T.StructType(
        [
            T.StructField("map_example", T.MapType(T.StringType(), T.LongType())),
            T.StructField("array_example", T.ArrayType(T.StringType())),
        ]
    ).jsonValue()
)
{'fields': [{'metadata': {},
             'name': 'map_example',
             'nullable': True,
             'type': {'keyType': 'string',
                      'type': 'map',
                      'valueContainsNull': True,
                      'valueType': 'long'}},
            {'metadata': {},
             'name': 'array_example',
             'nullable': True,
             'type': {'containsNull': True,
                      'elementType': 'string',
                      'type': 'array'}}],
 'type': 'struct'}

Finally, we can close the loop by making sure that our JSON schema is consistent with the one currently being used. For this, we'll export the schema of shows_with_schema as a JSON string, load it as a JSON object, and then use the StructType.fromJson() method to re-create the schema.

In [207]:
other_shows_schema = T.StructType.fromJson(json.loads(shows_with_schema.schema.json()))

print(other_shows_schema == shows_with_schema.schema)  # True
True

Reducing duplicate data with complex data types

Hierarchical vs. 2-D row-column models

If we were to make the shows data frame in a traditional relational database, we could have a shows table linked to an episodes table using a star schema.

shows table

  show_id | name
  --------|---------------
  143     | silicon valley

episodes table, joined to shows by show_id

  show_id | episode_id | name
  --------|------------|--------------------------
  143     | 1          | Minimum Viable Product
  143     | 2          | The Cap Table
  143     | 3          | Articles of Incorporation

episodes could be extended with more columns, but duplicate values start to appear:

  show_id | episode_id | name                      | genre  | day
  --------|------------|---------------------------|--------|-------
  143     | 1          | Minimum Viable Product    | Comedy | Sunday
  143     | 2          | The Cap Table             | Comedy | Sunday
  143     | 3          | Articles of Incorporation | Comedy | Sunday

In contrast, a hierarchical data frame contains complex columns such as arrays and structs:

  • each record represents a show;
  • a show has multiple episodes (array of structs column);
  • each episode has many fields (struct column within the array);
  • each show can have multiple genres (array of string column)
  • each show has a schedule (struct column);
  • each schedule belonging to a show can have multiple days (array), but a single time (string).
shows data frame using a hierarchical model

How to use explode and collect operations to go from hierarchical to tabular and back

We will now revisit the exploding operation by generalizing it to maps, looking at the behavior when your data frame has multiple columns, and seeing the different options PySpark provides for exploding.

In [208]:
# Exploding _embedded.episodes
episodes = shows.select("id", F.explode("_embedded.episodes").alias("episodes"))
episodes.printSchema()
episodes.show(5)
root
 |-- id: long (nullable = true)
 |-- episodes: struct (nullable = true)
 |    |-- _links: struct (nullable = true)
 |    |    |-- self: struct (nullable = true)
 |    |    |    |-- href: string (nullable = true)
 |    |-- airdate: string (nullable = true)
 |    |-- airstamp: timestamp (nullable = true)
 |    |-- airtime: string (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- image: struct (nullable = true)
 |    |    |-- medium: string (nullable = true)
 |    |    |-- original: string (nullable = true)
 |    |-- name: string (nullable = true)
 |    |-- number: long (nullable = true)
 |    |-- runtime: long (nullable = true)
 |    |-- season: long (nullable = true)
 |    |-- summary: string (nullable = true)
 |    |-- url: string (nullable = true)

+---+--------------------+
| id|            episodes|
+---+--------------------+
|143|[[[http://api.tvm...|
|143|[[[http://api.tvm...|
|143|[[[http://api.tvm...|
|143|[[[http://api.tvm...|
|143|[[[http://api.tvm...|
+---+--------------------+
only showing top 5 rows

Exploding a map

  • keys and values are exploded into two different fields
  • posexplode(): explodes the column and also returns an additional position column (LongType) before the data, containing the array positions
  • explode() / posexplode() skip null values (see the sketch after this list)
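
explode_outer() and posexplode_outer() keep those null values instead; a minimal sketch on a toy frame (the frame is an illustrative assumption):

# explode() drops the null entry; explode_outer() keeps it as a row with a null value
df = spark.createDataFrame([(1, ["a", "b"]), (2, None)], ["id", "letters"])

df.select("id", F.explode("letters")).show()        # only rows for id=1
df.select("id", F.explode_outer("letters")).show()  # id=2 kept, exploded column is null
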
In [209]:
episode_name_id = shows.select(
    F.map_from_arrays(
        F.col("_embedded.episodes.id"), F.col("_embedded.episodes.name")
    ).alias("name_id")
)

episode_name_id = episode_name_id.select(
    F.posexplode("name_id").alias("position", "id", "name")
)

episode_name_id.show(5, False)
+--------+-----+-------------------------+
|position|id   |name                     |
+--------+-----+-------------------------+
|0       |10897|Minimum Viable Product   |
|1       |10898|The Cap Table            |
|2       |10899|Articles of Incorporation|
|3       |10900|Fiduciary Duties         |
|4       |10901|Signaling Risk           |
+--------+-----+-------------------------+
only showing top 5 rows

collect-ing records into a complex column

collect_list() and collect_set()
  • take a column as an argument and return an array column
  • collect_list = one array element per record
  • collect_set = one array element per distinct record (like a Python set); see the sketch after this list
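
collect_set() isn't used in the next cell; a minimal sketch of the difference on a toy frame (the frame is an illustrative assumption):

# collect_list keeps duplicates; collect_set de-duplicates (order not guaranteed)
df = spark.createDataFrame(
    [(143, "Comedy"), (143, "Comedy"), (143, "Drama")], ["id", "genre"]
)

df.groupby("id").agg(
    F.collect_list("genre").alias("as_list"),  # [Comedy, Comedy, Drama]
    F.collect_set("genre").alias("as_set"),    # [Comedy, Drama]
).show(1, False)
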
In [210]:
collected = episodes.groupby("id").agg(F.collect_list("episodes").alias("episodes"))
print(collected.count())
collected.printSchema()
1
root
 |-- id: long (nullable = true)
 |-- episodes: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- _links: struct (nullable = true)
 |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |-- href: string (nullable = true)
 |    |    |-- airdate: string (nullable = true)
 |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |-- airtime: string (nullable = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- image: struct (nullable = true)
 |    |    |    |-- medium: string (nullable = true)
 |    |    |    |-- original: string (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- number: long (nullable = true)
 |    |    |-- runtime: long (nullable = true)
 |    |    |-- season: long (nullable = true)
 |    |    |-- summary: string (nullable = true)
 |    |    |-- url: string (nullable = true)

Building your own hierarchies with struct()

The struct() function takes columns as parameters and returns a struct column whose fields are the columns that were passed in.

In [211]:
# Creating a struct column

struct_ex = shows.select(
    F.struct(
        F.col("status"), F.col("weight"), F.lit(True).alias("has_watched")
    ).alias("info")
)

struct_ex.show(1, False)

struct_ex.printSchema()
+-----------------+
|info             |
+-----------------+
|[Ended, 96, true]|
+-----------------+

root
 |-- info: struct (nullable = false)
 |    |-- status: string (nullable = true)
 |    |-- weight: long (nullable = true)
 |    |-- has_watched: boolean (nullable = false)

In [212]:
shows.printSchema()
root
 |-- _embedded: struct (nullable = true)
 |    |-- episodes: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- _links: struct (nullable = true)
 |    |    |    |    |-- self: struct (nullable = true)
 |    |    |    |    |    |-- href: string (nullable = true)
 |    |    |    |-- airdate: string (nullable = true)
 |    |    |    |-- airstamp: timestamp (nullable = true)
 |    |    |    |-- airtime: string (nullable = true)
 |    |    |    |-- id: long (nullable = true)
 |    |    |    |-- image: struct (nullable = true)
 |    |    |    |    |-- medium: string (nullable = true)
 |    |    |    |    |-- original: string (nullable = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- number: long (nullable = true)
 |    |    |    |-- runtime: long (nullable = true)
 |    |    |    |-- season: long (nullable = true)
 |    |    |    |-- summary: string (nullable = true)
 |    |    |    |-- url: string (nullable = true)
 |-- _links: struct (nullable = true)
 |    |-- previousepisode: struct (nullable = true)
 |    |    |-- href: string (nullable = true)
 |    |-- self: struct (nullable = true)
 |    |    |-- href: string (nullable = true)
 |-- externals: struct (nullable = true)
 |    |-- imdb: string (nullable = true)
 |    |-- thetvdb: long (nullable = true)
 |    |-- tvrage: long (nullable = true)
 |-- genres: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- id: long (nullable = true)
 |-- image: struct (nullable = true)
 |    |-- medium: string (nullable = true)
 |    |-- original: string (nullable = true)
 |-- language: string (nullable = true)
 |-- name: string (nullable = true)
 |-- network: struct (nullable = true)
 |    |-- country: struct (nullable = true)
 |    |    |-- code: string (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- timezone: string (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- name: string (nullable = true)
 |-- officialSite: string (nullable = true)
 |-- premiered: string (nullable = true)
 |-- rating: struct (nullable = true)
 |    |-- average: double (nullable = true)
 |-- runtime: long (nullable = true)
 |-- schedule: struct (nullable = true)
 |    |-- days: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- time: string (nullable = true)
 |-- status: string (nullable = true)
 |-- summary: string (nullable = true)
 |-- type: string (nullable = true)
 |-- updated: long (nullable = true)
 |-- url: string (nullable = true)
 |-- webChannel: string (nullable = true)
 |-- weight: long (nullable = true)


Chapter 7: Bilingual PySpark: blending Python and SQL

This chapter is dedicated to using SQL with, and on top of, PySpark. I cover how we can move from one language to the other, how we can use SQL-like syntax within data frame methods to speed up code, and some of the trade-offs involved. Finally, we blend Python and SQL code together to get the best of both worlds.

Summary

  • Spark provides an SQL API for data manipulation. This API supports ANSI SQL.
  • PySpark’s data frames need to be registered as views or tables before they can be queried with Spark SQL. You can give them a different name than the data frame you’re registering.
  • Spark SQL queries can be inserted in a PySpark program through the spark.sql function, where spark is the running SparkSession.
  • Spark SQL table references are kept in a Catalog, which contains the metadata for all tables accessible to Spark SQL.
  • PySpark will accept SQL-style clauses in where(), expr() and selectExpr(), which can simplify the syntax for complex filtering and selection.
  • When using Spark SQL queries with user-provided input, be careful about sanitizing the inputs to avoid potential SQL injection attacks.

Data Sources

We will be using a periodic table of elements database for the initial section, followed by a public data set from Backblaze containing hard drive data and statistics.

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
import pyspark.sql.functions as F
import pyspark.sql.types as T
import numpy as np

spark = SparkSession.builder.getOrCreate()
In [2]:
# Read in table of elements data
elements = spark.read.csv(
    "data/Ch07/Periodic_Table_Of_Elements.csv",
    header=True,
    inferSchema=True,
)

# Inspect the data frame
elements.printSchema()

# View the data frame in chunks of 3-4 columns
# column_split = np.array_split(np.array(elements.columns), len(elements.columns) // 3)

# for x in column_split:
#     elements.select(*x).show(3, False)
root
 |-- AtomicNumber: integer (nullable = true)
 |-- Element: string (nullable = true)
 |-- Symbol: string (nullable = true)
 |-- AtomicMass: double (nullable = true)
 |-- NumberofNeutrons: integer (nullable = true)
 |-- NumberofProtons: integer (nullable = true)
 |-- NumberofElectrons: integer (nullable = true)
 |-- Period: integer (nullable = true)
 |-- Group: integer (nullable = true)
 |-- Phase: string (nullable = true)
 |-- Radioactive: string (nullable = true)
 |-- Natural: string (nullable = true)
 |-- Metal: string (nullable = true)
 |-- Nonmetal: string (nullable = true)
 |-- Metalloid: string (nullable = true)
 |-- Type: string (nullable = true)
 |-- AtomicRadius: double (nullable = true)
 |-- Electronegativity: double (nullable = true)
 |-- FirstIonization: double (nullable = true)
 |-- Density: double (nullable = true)
 |-- MeltingPoint: double (nullable = true)
 |-- BoilingPoint: double (nullable = true)
 |-- NumberOfIsotopes: integer (nullable = true)
 |-- Discoverer: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- SpecificHeat: double (nullable = true)
 |-- NumberofShells: integer (nullable = true)
 |-- NumberofValence: integer (nullable = true)

pyspark.sql vs SQL

Order of execution

The code below filters for records where the phase column equals "liq", then groups by period and counts.

SQL equivalent would be:

SELECT
  period,
  count(*)
FROM elements
WHERE phase = "liq"
GROUP BY period;
In [3]:
elements.where(F.col("phase") == "liq").groupby("period").count().show()
+------+-----+
|period|count|
+------+-----+
|     6|    1|
|     4|    1|
+------+-----+

Using SQL queries on a data frame

  • In order to allow a data frame to be queried via SQL, we need to register it as a table or view.
  • Spark SQL does not have visibility over the variables Python assigns.
  • Use createOrReplaceTempView() to register a data frame as a Spark SQL view. Functionally equivalent to CREATE OR REPLACE TEMP VIEW in SQL.
In [4]:
# Directly querying a data frame SQL-style does not work
try:
    spark.sql(
        "select period, count(*) from elements where phase='liq' group by period"
    ).show(5)
except AnalysisException as e:
    print(e)
Table or view not found: elements; line 1 pos 29;
'Aggregate ['period], ['period, unresolvedalias(count(1), None)]
+- 'Filter ('phase = liq)
   +- 'UnresolvedRelation [elements]

In [5]:
# Using createOrReplaceTempView

elements.createOrReplaceTempView("elements")

spark.sql(
    "select period, count(*) from elements where phase='liq' group by period"
).show(5)
+------+--------+
|period|count(1)|
+------+--------+
|     6|       1|
|     4|       1|
+------+--------+

Table vs View concept

In SQL, they are distinct concepts: a table is materialized (its data is stored), while a view is computed on the fly. Spark's temp views are conceptually closer to a view than a table. Spark SQL also has tables, but we will not be using them, preferring to read and materialize our data into a data frame (see the sketch below for the contrast).
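
A minimal sketch of that contrast (the managed-table line is illustrative only and commented out, since these notes stick to views):

# Temp view (what these notes use): only metadata is registered,
# and it disappears with the SparkSession
elements.createOrReplaceTempView("elements")

# Managed table (not used here): the data itself would be written to the
# warehouse and persist across sessions
# elements.write.saveAsTable("elements_table")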

Using the Spark catalog for multiple views

  • The Spark catalog mainly deals with managing the metadata of SQL tables and views, and their level of caching.
  • The catalog manages the views we've registered and can drop them.
In [6]:
# Access the catalog attribute of the SparkSession
spark.catalog

# List tables we've registered
display(spark.catalog.listTables())

# Drop a table
spark.catalog.dropTempView("elements")
display(spark.catalog.listTables())
[Table(name='elements', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]
[]

Data Source - Backblaze Data Set

(Note: Only reading in Q3 data due to local compute)

In [11]:
# Read backblaze data set into a data frame and register a SQL view

DATA_DIRECTORY = "./data/Ch07/"

# q1 = spark.read.csv(
#     DATA_DIRECTORY + "drive_stats_2019_Q1", header=True, inferSchema=True
# )
# q2 = spark.read.csv(
#     DATA_DIRECTORY + "data_Q2_2019", header=True, inferSchema=True
# )
q3 = spark.read.csv(
    DATA_DIRECTORY + "data_Q3_2019", header=True, inferSchema=True
)
# q4 = spark.read.csv(
#     DATA_DIRECTORY + "data_Q4_2019", header=True, inferSchema=True
# )

# Q4 has two more fields than the rest

# q4_fields_extra = set(q4.columns) - set(q1.columns)

# for i in q4_fields_extra:
#     q1 = q1.withColumn(i, F.lit(None).cast(T.StringType()))
#     q2 = q2.withColumn(i, F.lit(None).cast(T.StringType()))
#     q3 = q3.withColumn(i, F.lit(None).cast(T.StringType()))


# Union the data frames

# if you are only using the minimal set of data, use this version
backblaze_2019 = q3

# if you are using the full set of data, use this version
# backblaze_2019 = (
#     q1.select(q4.columns)
#     .union(q2.select(q4.columns))
#     .union(q3.select(q4.columns))
#     .union(q4)
# )

# Cast every SMART column to long according to the Backblaze schema
q = backblaze_2019.select(
    [
        F.col(x).cast(T.LongType()) if x.startswith("smart") else F.col(x)
        for x in backblaze_2019.columns
    ]
)

# Register the view
backblaze_2019.createOrReplaceTempView("backblaze_stats_2019")
In [21]:
backblaze_2019.printSchema()
root
 |-- date: string (nullable = true)
 |-- serial_number: string (nullable = true)
 |-- model: string (nullable = true)
 |-- capacity_bytes: long (nullable = true)
 |-- failure: integer (nullable = true)
 |-- smart_1_normalized: integer (nullable = true)
 |-- smart_1_raw: integer (nullable = true)
 |-- smart_2_normalized: integer (nullable = true)
 |-- smart_2_raw: integer (nullable = true)
 |-- smart_3_normalized: integer (nullable = true)
 |-- smart_3_raw: integer (nullable = true)
 |-- smart_4_normalized: integer (nullable = true)
 |-- smart_4_raw: integer (nullable = true)
 |-- smart_5_normalized: integer (nullable = true)
 |-- smart_5_raw: integer (nullable = true)
 |-- smart_7_normalized: integer (nullable = true)
 |-- smart_7_raw: long (nullable = true)
 |-- smart_8_normalized: integer (nullable = true)
 |-- smart_8_raw: integer (nullable = true)
 |-- smart_9_normalized: integer (nullable = true)
 |-- smart_9_raw: integer (nullable = true)
 |-- smart_10_normalized: integer (nullable = true)
 |-- smart_10_raw: integer (nullable = true)
 |-- smart_11_normalized: integer (nullable = true)
 |-- smart_11_raw: integer (nullable = true)
 |-- smart_12_normalized: integer (nullable = true)
 |-- smart_12_raw: integer (nullable = true)
 |-- smart_13_normalized: string (nullable = true)
 |-- smart_13_raw: string (nullable = true)
 |-- smart_15_normalized: string (nullable = true)
 |-- smart_15_raw: string (nullable = true)
 |-- smart_16_normalized: integer (nullable = true)
 |-- smart_16_raw: integer (nullable = true)
 |-- smart_17_normalized: integer (nullable = true)
 |-- smart_17_raw: integer (nullable = true)
 |-- smart_22_normalized: integer (nullable = true)
 |-- smart_22_raw: integer (nullable = true)
 |-- smart_23_normalized: integer (nullable = true)
 |-- smart_23_raw: integer (nullable = true)
 |-- smart_24_normalized: integer (nullable = true)
 |-- smart_24_raw: integer (nullable = true)
 |-- smart_168_normalized: integer (nullable = true)
 |-- smart_168_raw: integer (nullable = true)
 |-- smart_170_normalized: integer (nullable = true)
 |-- smart_170_raw: long (nullable = true)
 |-- smart_173_normalized: integer (nullable = true)
 |-- smart_173_raw: long (nullable = true)
 |-- smart_174_normalized: integer (nullable = true)
 |-- smart_174_raw: integer (nullable = true)
 |-- smart_177_normalized: integer (nullable = true)
 |-- smart_177_raw: integer (nullable = true)
 |-- smart_179_normalized: string (nullable = true)
 |-- smart_179_raw: string (nullable = true)
 |-- smart_181_normalized: string (nullable = true)
 |-- smart_181_raw: string (nullable = true)
 |-- smart_182_normalized: string (nullable = true)
 |-- smart_182_raw: string (nullable = true)
 |-- smart_183_normalized: integer (nullable = true)
 |-- smart_183_raw: integer (nullable = true)
 |-- smart_184_normalized: integer (nullable = true)
 |-- smart_184_raw: integer (nullable = true)
 |-- smart_187_normalized: integer (nullable = true)
 |-- smart_187_raw: integer (nullable = true)
 |-- smart_188_normalized: integer (nullable = true)
 |-- smart_188_raw: long (nullable = true)
 |-- smart_189_normalized: integer (nullable = true)
 |-- smart_189_raw: integer (nullable = true)
 |-- smart_190_normalized: integer (nullable = true)
 |-- smart_190_raw: integer (nullable = true)
 |-- smart_191_normalized: integer (nullable = true)
 |-- smart_191_raw: integer (nullable = true)
 |-- smart_192_normalized: integer (nullable = true)
 |-- smart_192_raw: integer (nullable = true)
 |-- smart_193_normalized: integer (nullable = true)
 |-- smart_193_raw: integer (nullable = true)
 |-- smart_194_normalized: integer (nullable = true)
 |-- smart_194_raw: integer (nullable = true)
 |-- smart_195_normalized: integer (nullable = true)
 |-- smart_195_raw: integer (nullable = true)
 |-- smart_196_normalized: integer (nullable = true)
 |-- smart_196_raw: integer (nullable = true)
 |-- smart_197_normalized: integer (nullable = true)
 |-- smart_197_raw: integer (nullable = true)
 |-- smart_198_normalized: integer (nullable = true)
 |-- smart_198_raw: integer (nullable = true)
 |-- smart_199_normalized: integer (nullable = true)
 |-- smart_199_raw: integer (nullable = true)
 |-- smart_200_normalized: integer (nullable = true)
 |-- smart_200_raw: integer (nullable = true)
 |-- smart_201_normalized: string (nullable = true)
 |-- smart_201_raw: string (nullable = true)
 |-- smart_218_normalized: integer (nullable = true)
 |-- smart_218_raw: integer (nullable = true)
 |-- smart_220_normalized: integer (nullable = true)
 |-- smart_220_raw: integer (nullable = true)
 |-- smart_222_normalized: integer (nullable = true)
 |-- smart_222_raw: integer (nullable = true)
 |-- smart_223_normalized: integer (nullable = true)
 |-- smart_223_raw: integer (nullable = true)
 |-- smart_224_normalized: integer (nullable = true)
 |-- smart_224_raw: integer (nullable = true)
 |-- smart_225_normalized: integer (nullable = true)
 |-- smart_225_raw: integer (nullable = true)
 |-- smart_226_normalized: integer (nullable = true)
 |-- smart_226_raw: integer (nullable = true)
 |-- smart_231_normalized: integer (nullable = true)
 |-- smart_231_raw: long (nullable = true)
 |-- smart_232_normalized: integer (nullable = true)
 |-- smart_232_raw: long (nullable = true)
 |-- smart_233_normalized: integer (nullable = true)
 |-- smart_233_raw: integer (nullable = true)
 |-- smart_235_normalized: integer (nullable = true)
 |-- smart_235_raw: long (nullable = true)
 |-- smart_240_normalized: integer (nullable = true)
 |-- smart_240_raw: long (nullable = true)
 |-- smart_241_normalized: integer (nullable = true)
 |-- smart_241_raw: long (nullable = true)
 |-- smart_242_normalized: integer (nullable = true)
 |-- smart_242_raw: long (nullable = true)
 |-- smart_250_normalized: integer (nullable = true)
 |-- smart_250_raw: integer (nullable = true)
 |-- smart_251_normalized: integer (nullable = true)
 |-- smart_251_raw: integer (nullable = true)
 |-- smart_252_normalized: integer (nullable = true)
 |-- smart_252_raw: integer (nullable = true)
 |-- smart_254_normalized: integer (nullable = true)
 |-- smart_254_raw: integer (nullable = true)
 |-- smart_255_normalized: string (nullable = true)
 |-- smart_255_raw: string (nullable = true)

select and where

Use select and where to show a few hard drive serial numbers that have failed at some point (failure = 1).

In [27]:
# In SQL, the query is written as: 1) select columns, then 2) filter
spark.sql("select serial_number, model, capacity_bytes from backblaze_stats_2019 where failure = 1").show(5)

# In PySpark, the chain is written as: 1) filter, then 2) select columns
backblaze_2019.where("failure=1").select(
    F.col('serial_number'),
    F.col('model'),
    F.col('capacity_bytes')
).show(5)
+-------------+-------------+--------------+
|serial_number|        model|capacity_bytes|
+-------------+-------------+--------------+
|     ZA10MCJ5|  ST8000DM002| 8001563222016|
|     ZCH07T9K|ST12000NM0007|12000138625024|
|     ZCH0CA7Z|ST12000NM0007|12000138625024|
|     Z302F381|  ST4000DM000| 4000787030016|
|     ZCH0B3Z2|ST12000NM0007|12000138625024|
+-------------+-------------+--------------+
only showing top 5 rows

+-------------+-------------+--------------+
|serial_number|        model|capacity_bytes|
+-------------+-------------+--------------+
|     ZA10MCJ5|  ST8000DM002| 8001563222016|
|     ZCH07T9K|ST12000NM0007|12000138625024|
|     ZCH0CA7Z|ST12000NM0007|12000138625024|
|     Z302F381|  ST4000DM000| 4000787030016|
|     ZCH0B3Z2|ST12000NM0007|12000138625024|
+-------------+-------------+--------------+
only showing top 5 rows

groupby and orderby

Look at the capacity in gigabytes of the hard drives included in the data, by model.

In [39]:
# Groupby and order in SQL
spark.sql(
    """
    SELECT
        model,
        min(capacity_bytes / pow(1024, 3)) min_GB,
        max(capacity_bytes / pow(1024, 3)) max_GB
    FROM backblaze_stats_2019
    GROUP BY model
    ORDER BY max_GB DESC
"""
).show(5)
+--------------------+--------------------+-------+
|               model|              min_GB| max_GB|
+--------------------+--------------------+-------+
| TOSHIBA MG07ACA14TA|             13039.0|13039.0|
|HGST HUH721212ALE600|             11176.0|11176.0|
|       ST12000NM0117|             11176.0|11176.0|
|       ST12000NM0007|-9.31322574615478...|11176.0|
|HGST HUH721212ALN604|-9.31322574615478...|11176.0|
+--------------------+--------------------+-------+
only showing top 5 rows

In [40]:
# PySpark
backblaze_2019.groupby(F.col("model")).agg(
    F.min(F.col("capacity_bytes") / F.pow(F.lit(1024), 3)).alias("min_GB"),
    F.max(F.col("capacity_bytes") / F.pow(F.lit(1024), 3)).alias("max_GB"),
).orderBy(F.col("max_GB"), ascending=False).show(5)
+--------------------+--------------------+-------+
|               model|              min_GB| max_GB|
+--------------------+--------------------+-------+
| TOSHIBA MG07ACA14TA|             13039.0|13039.0|
|HGST HUH721212ALE600|             11176.0|11176.0|
|       ST12000NM0117|             11176.0|11176.0|
|       ST12000NM0007|-9.31322574615478...|11176.0|
|HGST HUH721212ALN604|-9.31322574615478...|11176.0|
+--------------------+--------------------+-------+
only showing top 5 rows

Filtering after grouping with having

having in SQL is a condition block used after grouping is done.

Filter the grouped results to keep only the models whose min_GB and max_GB differ.

In [44]:
spark.sql(
    """
    SELECT
        model,
        min(capacity_bytes / pow(1024, 3)) min_GB,
        max(capacity_bytes / pow(1024, 3)) max_GB
    FROM backblaze_stats_2019
    GROUP BY model
    HAVING min_GB <> max_GB
    ORDER BY max_GB DESC
"""
).show(5)
+--------------------+--------------------+-----------------+
|               model|              min_GB|           max_GB|
+--------------------+--------------------+-----------------+
|       ST12000NM0007|-9.31322574615478...|          11176.0|
|HGST HUH721212ALN604|-9.31322574615478...|          11176.0|
|HGST HUH721010ALE600|-9.31322574615478...|           9314.0|
|       ST10000NM0086|-9.31322574615478...|           9314.0|
|         ST8000DM002|-9.31322574615478...|7452.036460876465|
+--------------------+--------------------+-----------------+
only showing top 5 rows

In [45]:
backblaze_2019.groupby(F.col("model")).agg(
    F.min(F.col("capacity_bytes") / F.pow(F.lit(1024), 3)).alias("min_GB"),
    F.max(F.col("capacity_bytes") / F.pow(F.lit(1024), 3)).alias("max_GB"),
).where(F.col("min_GB") != F.col("max_GB")).orderBy(
    F.col("max_GB"), ascending=False
).show(5)
+--------------------+--------------------+-----------------+
|               model|              min_GB|           max_GB|
+--------------------+--------------------+-----------------+
|       ST12000NM0007|-9.31322574615478...|          11176.0|
|HGST HUH721212ALN604|-9.31322574615478...|          11176.0|
|HGST HUH721010ALE600|-9.31322574615478...|           9314.0|
|       ST10000NM0086|-9.31322574615478...|           9314.0|
|         ST8000DM002|-9.31322574615478...|7452.036460876465|
+--------------------+--------------------+-----------------+
only showing top 5 rows

Saving tables/views using create

  • With SQL, prefix query with CREATE TABLE/VIEW
    • creating a table will materialize the data
    • creating a view will only keep the query
  • With PySpark, just save to variable

Compute the number of days of operation a model has and the number of drive failures it has had

In [56]:
# SQL

spark.catalog.dropTempView('drive_days')
spark.catalog.dropTempView('failures')

spark.sql(
    """
    CREATE TEMP VIEW drive_days AS
        SELECT model, count(*) AS drive_days
        FROM backblaze_stats_2019
        GROUP BY model
""")

spark.sql(
    """
    CREATE TEMP VIEW failures AS
        SELECT model, count(*) AS failures
        FROM backblaze_stats_2019
        WHERE failure = 1
        GROUP BY model
""")
Out[56]:
DataFrame[]
In [57]:
# PySpark

drive_days = backblaze_2019.groupBy(F.col("model")).agg(
    F.count(F.col("*")).alias("drive_days")
)

failures = (
    backblaze_2019.where(F.col("failure") == 1)
    .groupBy(F.col("model"))
    .agg(F.count(F.col("*")).alias("failures"))
)
In [59]:
failures.show(5)
+-------------------+--------+
|              model|failures|
+-------------------+--------+
|        ST4000DM000|      72|
|      ST12000NM0007|     365|
|        ST8000DM005|       1|
|TOSHIBA MQ01ABF050M|       5|
|       ST8000NM0055|      50|
+-------------------+--------+
only showing top 5 rows

Adding data to table using UNION and JOIN

  • SQL's UNION removes duplicate records, while PySpark's union() doesn't.
  • PySpark's union() is equal to SQL's UNION ALL.
  • To get the SQL UNION equivalent in PySpark, run distinct() after union() (see the toy sketch after the two cells below).

(Note: Not running 2 cells below since I only loaded Q3 data)

In [ ]:
columns_backblaze = ", ".join(q4.columns)

q1.createOrReplaceTempView("Q1")
q1.createOrReplaceTempView("Q2")
q1.createOrReplaceTempView("Q3")
q1.createOrReplaceTempView("Q4")

spark.sql(
    """
    CREATE VIEW backblaze_2019 AS
    SELECT {col} FROM Q1 UNION ALL
    SELECT {col} FROM Q2 UNION ALL
    SELECT {col} FROM Q3 UNION ALL
    SELECT {col} FROM Q4
""".format(
        col=columns_backblaze
    )
)
In [ ]:
backblaze_2019 = (
    q1.select(q4.columns)
    .union(q2.select(q4.columns))
    .union(q3.select(q4.columns))
    .union(q4)
)
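
As a toy, hedged sketch of the union vs. UNION ALL point above (not the book's code): union() keeps duplicates, and chaining distinct() reproduces SQL's UNION.

In [ ]:
# Toy sketch: PySpark union() == SQL UNION ALL; add distinct() for SQL UNION
left = spark.createDataFrame([(1,), (2,)], ["n"])
right = spark.createDataFrame([(2,), (3,)], ["n"])

left.union(right).show()             # 4 rows: 1, 2, 2, 3 (duplicates kept)
left.union(right).distinct().show()  # 3 rows: 1, 2, 3 (SQL UNION behaviour)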

Joining drive_days and failures tables together

In [60]:
spark.sql(
    """
    SELECT
        drive_days.model,
        drive_days,
        failures
    FROM drive_days
    LEFT JOIN failures
    ON
        drive_days.model = failures.model
"""
).show(5)
+-------------+----------+--------+
|        model|drive_days|failures|
+-------------+----------+--------+
|  ST9250315AS|        89|    null|
|  ST4000DM000|   1796728|      72|
|ST12000NM0007|   3212635|     365|
|   ST320LT007|        89|    null|
|  ST8000DM005|      2280|       1|
+-------------+----------+--------+
only showing top 5 rows

In [61]:
drive_days.join(failures, on="model", how="left").show(5)
+-------------+----------+--------+
|        model|drive_days|failures|
+-------------+----------+--------+
|  ST9250315AS|        89|    null|
|  ST4000DM000|   1796728|      72|
|ST12000NM0007|   3212635|     365|
|   ST320LT007|        89|    null|
|  ST8000DM005|      2280|       1|
+-------------+----------+--------+
only showing top 5 rows

Organizing code with subqueries and common table expressions (CTE)

Take drive_days and failures table definitions and bundle them into a single query using CTE.

In [65]:
spark.sql("""
    WITH drive_days as (
        SELECT
            model,
            count(*) AS drive_days
        FROM backblaze_stats_2019
        GROUP BY model),
    failures as (
        SELECT
            model,
            count(*) AS failures
        FROM backblaze_stats_2019
        WHERE failure = 1
        GROUP BY model)
        
    SELECT
        drive_days.model,
        failures / drive_days failure_rate
    FROM drive_days
    INNER JOIN failures
    ON drive_days.model = failures.model
    ORDER BY failure_rate DESC
""").show(5)
+--------------------+--------------------+
|               model|        failure_rate|
+--------------------+--------------------+
|       ST12000NM0117|0.019305019305019305|
|Seagate BarraCuda...|6.341154090044388E-4|
|  TOSHIBA MQ01ABF050|5.579360828423496E-4|
|         ST8000DM005|4.385964912280702E-4|
|          ST500LM030| 4.19639110365086E-4|
+--------------------+--------------------+
only showing top 5 rows

In [66]:
# A CTE is sort of similar to a Python function
def failure_rate(drive_stats):
    drive_days = drive_stats.groupby(F.col("model")).agg(
        F.count(F.col("*")).alias("drive_days")
    )

    failures = (
        drive_stats.where(F.col("failure") == 1)
        .groupby(F.col("model"))
        .agg(F.count(F.col("*")).alias("failures"))
    )
    answer = (
        drive_days.join(failures, on="model", how="inner")
        .withColumn("failure_rate", F.col("failures") / F.col("drive_days"))
        .orderBy(F.col("failure_rate").desc())
    )
    return answer


failure_rate(backblaze_2019).show(5)

print("drive_days" in dir())
+--------------------+----------+--------+--------------------+
|               model|drive_days|failures|        failure_rate|
+--------------------+----------+--------+--------------------+
|       ST12000NM0117|       259|       5|0.019305019305019305|
|Seagate BarraCuda...|      1577|       1|6.341154090044388E-4|
|  TOSHIBA MQ01ABF050|     44808|      25|5.579360828423496E-4|
|         ST8000DM005|      2280|       1|4.385964912280702E-4|
|          ST500LM030|     21447|       9| 4.19639110365086E-4|
+--------------------+----------+--------+--------------------+
only showing top 5 rows

True

Mix and match PySpark and SQL code

This section will build on the code we’ve written so far. We’re going to write a function that, for a given capacity, will return the top 3 most reliable drives according to our failure rate.

selectExpr() is just like select(), but it will process SQL-style expressions. Nice because it removes the need for the F.col()-style syntax.

expr() wraps a SQL-style expression into a PySpark Column. It can be used in lieu of F.col() when you want to modify a column.

In [68]:
# Data Ingestion using Python

from functools import reduce

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

DATA_DIRECTORY = "./data/Ch07/"

DATA_FILES = [
#     "drive_stats_2019_Q1",
#     "data_Q2_2019",
    "data_Q3_2019",
#     "data_Q4_2019",
]

data = [
    spark.read.csv(DATA_DIRECTORY + file, header=True, inferSchema=True)
    for file in DATA_FILES
]

common_columns = list(
    reduce(lambda x, y: x.intersection(y), [set(df.columns) for df in data])
)

assert set(["model", "capacity_bytes", "date", "failure"]).issubset(
    set(common_columns)
)

full_data = reduce(
    lambda x, y: x.select(common_columns).union(y.select(common_columns)), data
)
In [71]:
# Processing data for the query function with selectExpr

full_data = full_data.selectExpr( # <===
    "model", "capacity_bytes / pow(1024, 3) capacity_GB", "date", "failure"
)

drive_days = full_data.groupby("model", "capacity_GB").agg(
    F.count("*").alias("drive_days")
)

failures = (
    full_data.where("failure = 1")
    .groupby("model", "capacity_GB")
    .agg(F.count("*").alias("failures"))
)

summarized_data = (
    drive_days.join(failures, on=["model", "capacity_GB"], how="left")
    .fillna(0.0, ["failures"])
    .selectExpr("model", "capacity_GB", "failures / drive_days failure_rate")
    .cache()
)
In [72]:
# creating failures variable with expr

failures = (
    full_data.where("failure = 1")
    .groupby("model", "capacity_GB")
    .agg(F.expr("count(*) failures")) # <===
)
In [73]:
# Turning failure_rate in to a function using a mix of PySpark and SQL syntax

def most_reliable_drive_for_capacity(data, capacity_GB=2048, precision=0.25, top_n=3):
    """Returns the top 3 drives for a given approximate capacity.

    Given a capacity in GB and a precision as a decimal number, we keep the N
    drives where:

    - the capacity is between capacity / (1 + precision) and capacity * (1 + precision)
    - the failure rate is the lowest

    """
    capacity_min = capacity_GB / (1 + precision)
    capacity_max = capacity_GB * (1 + precision)

    answer = (
        data.where(f"capacity_GB between {capacity_min} and {capacity_max}")
        .orderBy("failure_rate", "capacity_GB", ascending=[True, False])
        .limit(top_n)
    )

    return answer
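
A hedged usage sketch, reusing the summarized_data frame built above (the capacity value is arbitrary):

In [ ]:
# Top 3 most reliable models around 11,176 GB (roughly a 12 TB drive)
most_reliable_drive_for_capacity(summarized_data, capacity_GB=11176.0).show()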

Chapter 8: RDD and user-defined functions

Summary

  • The resilient distributed dataset allows for better flexibility compared to the records and columns approach of the data frame.
  • The most low-level and flexible way of running Python code within the distributed Spark environment is to use the RDD. With an RDD, you have no structure imposed on your data; you need to manage type information yourself and code defensively against potential exceptions.
  • The API for data processing on the RDD is heavily inspired by the MapReduce framework. You use higher-order functions such as map(), filter() and reduce() on the objects of the RDD.
  • The data frame’s most basic Python code promotion functionality, called the (PySpark) UDF, emulates the "map" part of the RDD. You use it as a scalar function, taking Column objects as parameters and returning a single Column.

Terminology

Resilient Distributed Dataset (RDD)

  • Bag of elements, independent, no schema
  • Flexible with what you want to do but no safeguards

User-defined functions (UDF)

  • Simple way to promote Python functions to be used on a data frame.

RDD's Pros

  1. When you have an unordered collection of Python objects that can be pickled
  2. Unordered key-value pairs, e.g. a Python dict

Example: Creating an RDD from a Python list

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

collection = [1, "two", 3.0, ("four", 4), {"five": 5}]

sc = spark.sparkContext

collection_rdd = sc.parallelize(collection)

print(collection_rdd)
ParallelCollectionRDD[1] at readRDDFromFile at PythonRDD.scala:262

Manipulating data with map, filter and reduce

  • Each takes a function as its only parameter, i.e. they are higher-order functions.

map

  • Applies one function to every object in the RDD
  • Need to be careful with types the mapped function doesn't support
In [23]:
# Map a simple function to each element of an RDD.
# This will raise an error because not all of the elements are integers

from py4j.protocol import Py4JJavaError
import re


def add_one(value):
    return value + 1


collection_rdd = collection_rdd.map(add_one)

try:
    print(collection_rdd.collect())
except Py4JJavaError as e:
    pass

# Stack trace galore! The important bit, you'll get one of the following:
# TypeError: can only concatenate str (not "int") to str
# TypeError: unsupported operand type(s) for +: 'dict' and 'int'
# TypeError: can only concatenate tuple (not "int") to tuple
In [37]:
# Safer option with a try/except inside the function
def safer_add_one(value):
    try:
        return value + 1
    except TypeError:
        return value
    
# reset rdd
collection_rdd = sc.parallelize(collection)
print("Before: ", collection)

# run safe adding method
collection_rdd = collection_rdd.map(safer_add_one)
print("After : ", collection_rdd.collect())
Before:  [1, 'two', 3.0, ('four', 4), {'five': 5}]
After :  [2, 'two', 4.0, ('four', 4), {'five': 5}]

filter

In [48]:
# Filtering RDD with lambda function to keep only int and floats

collection_rdd = sc.parallelize(collection)


collection_rdd = collection_rdd.filter(lambda x: isinstance(x, (float, int)))
print(collection_rdd.collect())


# Alternative: Creating a separate function

collection_rdd = sc.parallelize(collection)

def is_string(elem):
    return True if isinstance(elem, str) else False

collection_rdd = collection_rdd.filter(is_string)
print(collection_rdd.collect())
[1, 3.0]
['two']

reduce

  • Used for summarization (ie. groupby and agg with dataframe)
  • Takes 2 elements and returns 1 element. If the list has more than 2 elements, it takes the first 2, then applies the function to that result and the third element, and so forth.

In [53]:
# Add list of numbers through reduce

from operator import add

collection_rdd = sc.parallelize(range(10))
print(collection_rdd.reduce(add))
45

Use commutative and associative functions (for distributed computing)

Only give reduce commutative and associative functions (see the sketch after this list).

  • Commutative function: the order of the arguments doesn't matter, i.e. f(a, b) = f(b, a)
  • Associative function: the grouping of the arguments doesn't matter, i.e. f(f(a, b), c) = f(a, f(b, c))
    • subtract is not associative because (a - b) - c != a - (b - c)
  • add, multiply, min and max are both associative and commutative
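
A minimal, hedged sketch (not from the book) of why this matters: reduce() applies the function within each partition first and then across the partition results, so a non-associative, non-commutative function such as subtraction can give different answers depending on partitioning, while add does not. It reuses the sc SparkContext defined earlier; the partition counts are arbitrary.

In [ ]:
from operator import add, sub

nums = list(range(10))

# Subtraction: the result depends on how the data is split into partitions
print(sc.parallelize(nums, 1).reduce(sub))  # left fold in a single partition: -45
print(sc.parallelize(nums, 2).reduce(sub))  # two partitions: a different number

# Addition is associative and commutative, so partitioning doesn't matter
print(sc.parallelize(nums, 1).reduce(add))  # 45
print(sc.parallelize(nums, 2).reduce(add))  # 45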

User-defined Functions

  • UDFs allow you to implement custom functions on PySpark data frame columns

Using UDF to create a fractions function

In [57]:
import pyspark.sql.functions as F
import pyspark.sql.types as T

fractions = [[x,y] for x in range(100) for y in range(1, 100)]
frac_df = spark.createDataFrame(fractions, ["numerator", "denominator"])

frac_df = frac_df.select(
    F.array(F.col("numerator"), F.col("denominator")).alias("fraction"),
)

frac_df.show(5, False)
+--------+
|fraction|
+--------+
|[0, 1]  |
|[0, 2]  |
|[0, 3]  |
|[0, 4]  |
|[0, 5]  |
+--------+
only showing top 5 rows

Using typed Python functions

This section will create a function to reduce a fraction and one to transform a fraction into a floating-point number.

In [59]:
from fractions import Fraction
from typing import Tuple, Optional

Frac = Tuple[int, int]

def py_reduce_fraction(frac: Frac) -> Optional[Frac]:
    """Reduce a fraction represented as a 2-tuple of integers"""
    num, denom = frac
    if denom:
        answer = Fraction(num, denom)
        return answer.numerator, answer.denominator
    return None

assert py_reduce_fraction((3,6)) == (1, 2)
assert py_reduce_fraction((1, 0)) is None
In [60]:
def py_fraction_to_float(frac: Frac) -> Optional[float]:
    """Transforms a fraction represented as a 2-tuple of integer into a float"""
    num, denom = frac
    if denom:
        return num / denom
    return None

assert py_fraction_to_float((2, 8)) == 0.25
assert py_fraction_to_float((10, 0)) is None

Promoting Python functions to udf

The udf() function takes two parameters.

  1. The function you want to promote.
  2. Optionally, the return type of the generated UDF.
Option 1: Creating a UDF explicitly with udf() and applying it to a data frame
In [63]:
SparkFrac = T.ArrayType(T.LongType())

# Promote python func to udf, passing SparkFrac type alias
reduce_fraction = F.udf(py_reduce_fraction, SparkFrac)

# apply to existing dataframe
frac_df = frac_df.withColumn(
    "reduced_fraction", reduce_fraction(F.col("fraction"))
)

frac_df.show(5, False)
+--------+----------------+
|fraction|reduced_fraction|
+--------+----------------+
|[0, 1]  |[0, 1]          |
|[0, 2]  |[0, 1]          |
|[0, 3]  |[0, 1]          |
|[0, 4]  |[0, 1]          |
|[0, 5]  |[0, 1]          |
+--------+----------------+
only showing top 5 rows

Option 2: Creating a UDF directly using the udf() decorator
In [67]:
@F.udf(T.DoubleType())
def fraction_to_float(frac: Frac) -> Optional[float]:
    num, denom = frac
    if denom:
        return num / denom
    return None


frac_df = frac_df.withColumn(
    "fraction_float", fraction_to_float(F.col("reduced_fraction"))
)

frac_df.select("reduced_fraction", "fraction_float").distinct().show(5, False)

assert fraction_to_float.func((1, 2)) == 0.5
+----------------+-------------------+
|reduced_fraction|fraction_float     |
+----------------+-------------------+
|[3, 50]         |0.06               |
|[3, 67]         |0.04477611940298507|
|[7, 76]         |0.09210526315789473|
|[9, 23]         |0.391304347826087  |
|[9, 25]         |0.36               |
+----------------+-------------------+
only showing top 5 rows

Chapter 9: Using Pandas UDF

Summary

  • Pandas UDFs allow you to take code that works on Pandas data frames and scale it to the Spark Data Frame structure. Efficient serialization between the two data structures is ensured by PyArrow.
  • We can group Pandas UDFs into two main families, depending on the level of control we need over the batches. Series and Iterator of Series UDFs (and Iterator of data frame/mapInPandas) batch efficiently, with the user having no control over the batch composition.
  • If you need control over the content of each batch, you can use grouped data UDFs with the split-apply-combine programming pattern. PySpark provides access to the values inside each batch of a GroupedData object either as Series (group aggregate UDF) or as data frame (group map UDF).
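
As a minimal, hedged sketch of the grouped map (split-apply-combine) pattern mentioned above — not the book's code, and using a toy data frame with the applyInPandas() method — the function runs once per group on a pandas DataFrame and returns a pandas DataFrame:

In [ ]:
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [("a", 1.0), ("a", 2.0), ("b", 10.0), ("b", 14.0)], ["key", "value"]
)

def demean(pdf: pd.DataFrame) -> pd.DataFrame:
    # Subtract each group's mean from its values
    return pdf.assign(value=pdf["value"] - pdf["value"].mean())

# group "a" becomes -0.5 / 0.5, group "b" becomes -2.0 / 2.0
df.groupby("key").applyInPandas(demean, schema="key string, value double").show()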

Requirements

This chapter will use:

  1. pandas
  2. scikit-learn
  3. PyArrow

The chapter assumes you are using PySpark 3.0 and above.

Column transformation using Series UDF

Types of Series UDF

Series to Series

  • Takes Column objects, converts them to pandas Series, and returns a Series object that gets promoted back to a Column object.

Iterator of Series to Iterator of Series

  • The Column is batched, then fed in as an Iterator object.
  • Takes a single Column, returns a single Column.
  • Good when you need to initialize an expensive state once, before iterating over the batches.

Iterator of multiple Series to Iterator of Series

  • Takes multiple Columns as input but preserves the iterator pattern.
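
A minimal, hedged toy sketch of this variant (not from the book; the PyArrow/Java 11 errata further down may apply locally): each element of the iterator is a tuple of pandas Series, one per input column.

In [ ]:
from typing import Iterator, Tuple

import pandas as pd
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1, 2), (3, 4), (5, 6)], ["x", "y"])

@F.pandas_udf(T.LongType())
def multiply(batches: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    # One tuple of Series per batch, in the same order as the input columns
    for x, y in batches:
        yield x * y

df.select(multiply(F.col("x"), F.col("y")).alias("x_times_y")).show()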

Dataset - Google BigQuery

We will use the National Oceanic and Atmospheric Administration (NOAA) Global Surface Summary of the Day (GSOD) dataset.

Steps to data connection
  1. Install and configure the connector (if necessary), following the vendor’s documentation.
  2. Customize the SparkReader object to account for the new data source type.
  3. Read the data, authenticating as needed.
Installation + Configuration

After setting up a Google Cloud Platform account, initialize PySpark with the BigQuery connector enabled.

Errata

The code below doesn't work due to a lot of issues with PyArrow compatibility with Java 11. I've skipped this part and just downloaded the dataset from the author's GitHub.

Reference:

In [ ]:
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()
conf.set(
    "spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true"
)
conf.set("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
conf.set(
    "spark.jars.packages",
    "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.19.1",
)

# spark = (
#     SparkSession.builder
#     .config(
#         "spark.jars.packages",
#         "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.19.1",
#     )
#     .config(
#         "spark.driver.extraJavaOptions",
#         "-Dio.netty.tryReflectionSetAccessible=true"
#     )
#     .config(
#         "spark.executor.extraJavaOptions",
#         "-Dio.netty.tryReflectionSetAccessible=true"
#     )
#     .getOrCreate()
# )

spark = SparkSession.builder.config(conf=conf).getOrCreate()

After initializing, read the stations and gsod tables for 2010 to 2020

In [ ]:
from functools import reduce
import pyspark.sql.functions as F


def read_df_from_bq(year):
    return (
        spark.read.format("bigquery").option(
            "table", f"bigquery-public-data.noaa_gsod.gsod{year}"
        )
        .option("credentialsFile", "/Users/taichinakatani/dotfiles/keys/bq-key.json")
        .option("parentProject", "still-vim-244001")
        .load()
    )


# Because gsod2020 has an additional date column that the previous years do not have,
# unionByName will fill the values with null
gsod = (
    reduce(
        lambda x, y: x.unionByName(y, allowMissingColumns=True),
        [read_df_from_bq(year) for year in range(2020, 2021)],
    )
    .dropna(subset=["year", "mo", "da", "temp"])
    .where(F.col("temp") != 9999.9)
    .drop("date")
)
In [ ]:
gsod.select(F.col('year')).show(5)

Read data locally

In [3]:
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
conf = SparkConf()
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Read from local parquet instead
gsod = spark.read.load("data/gsod_noaa/gsod2018.parquet")

Series to Series UDF

  • Python UDFs work on one record at a time, while scalar (Series) UDFs work on one Series at a time and are written with pandas code.
  • Pandas has simpler data types than PySpark, so we need to be careful to align the types; pandas_udf helps with this.

Converting Fahrenheit to Celsius with a S-to-S UDF

Errata

Using the pandas_udf decorator is killing the kernel for some reason.

In [5]:
import pandas as pd
import pyspark.sql.types as T
import pyspark.sql.functions as F

# note the syntax "pandas_udf" and how it returns a pd.Series
# @F.pandas_udf(T.DoubleType())
def f_to_c(degrees: pd.Series) -> pd.Series:
    """Transforms Farhenheit to Celcius."""
    return (degrees - 32) * 5 / 9
In [6]:
gsod = gsod.withColumn("temp_c", f_to_c(F.col("temp")))
gsod.select("temp", "temp_c").distinct().show(5)
+----+------------------+
|temp|            temp_c|
+----+------------------+
|37.2|2.8888888888888906|
|71.6|21.999999999999996|
|53.5|11.944444444444445|
|24.7|-4.055555555555555|
|70.4|21.333333333333336|
+----+------------------+
only showing top 5 rows

Iterator of Series UDF

  • signature goes from (pd.Series) → pd.Series to (Iterator[pd.Series]) → Iterator[pd.Series]
  • Since we are working with an Iterator of Series, we are explicitly iterating over each batch one by one. PySpark will take care of distributing the work for us.
  • Uses yield rather than return so the function returns an iterator
In [ ]:
from time import sleep
from typing import Iterator


@F.pandas_udf(T.DoubleType())
def f_to_c2(degrees: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Transforms Farhenheit to Celcius."""
    sleep(5)
    for batch in degrees:
        yield (batch - 32) * 5 / 9


gsod.select(
    "temp", f_to_c2(F.col("temp")).alias("temp_c")
).distinct().show(5)

# +-----+-------------------+
# | temp|             temp_c|
# +-----+-------------------+
# | 37.2| 2.8888888888888906|
# | 85.9| 29.944444444444443|
# | 53.5| 11.944444444444445|
# | 71.6| 21.999999999999996|
# |-27.6|-33.111111111111114|
# +-----+-------------------+
# only showing top 5 rows

Chapter 10: Window functions

Dataset

We will use the National Oceanic and Atmospheric Administration (NOAA) Global Surface Summary of the Day (GSOD) dataset.

In [6]:
# Setup

from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


conf = SparkConf()
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Read from local parquet
gsod = spark.read.parquet("data/gsod_noaa/gsod*.parquet")

Summarizing data with over

Q: When was the lowest temperature recorded each year?

In [13]:
# Using vanilla groupBy, we can get the lowest temperature but not when.

coldest_temp = gsod.groupby("year").agg(F.min("temp").alias("temp"))
coldest_temp.orderBy("temp").show()


# Using left-semi self-join to get the "when"
# Self-joins are generally an anti-pattern because they are SLOW.

coldest_when = gsod.join(coldest_temp, how="left_semi", on=["year", "temp"]) \
                   .select("stn", "year", "mo", "da", "temp")
coldest_when.orderBy("year", "mo", "da").show()
+----+------+
|year|  temp|
+----+------+
|2019|-114.7|
|2017|-114.7|
|2012|-113.5|
|2018|-113.5|
|2016|-111.7|
|2013|-110.7|
|2010|-110.7|
|2014|-110.5|
|2015|-110.2|
|2011|-106.8|
|2020|-105.0|
+----+------+

+------+----+---+---+------+
|   stn|year| mo| da|  temp|
+------+----+---+---+------+
|896060|2010| 06| 03|-110.7|
|896060|2011| 05| 19|-106.8|
|896060|2012| 06| 11|-113.5|
|895770|2013| 07| 31|-110.7|
|896060|2014| 08| 20|-110.5|
|895360|2015| 07| 12|-110.2|
|896060|2015| 08| 21|-110.2|
|896060|2015| 08| 27|-110.2|
|896060|2016| 07| 11|-111.7|
|896250|2017| 06| 20|-114.7|
|896060|2018| 08| 27|-113.5|
|895770|2019| 06| 15|-114.7|
|896060|2020| 08| 11|-105.0|
|896250|2020| 08| 13|-105.0|
+------+----+---+---+------+

In [8]:
# Using a window function instead

from pyspark.sql.window import Window

# To partition according to the values of one or more columns, 
# we pass the column name (or a Column object) to the partitionBy() method.
each_year = Window.partitionBy("year")

# Window is a builder class, just like SparkSession.builder
print(each_year)
<pyspark.sql.window.WindowSpec object at 0x139ed3250>

Using a window function

  • The each_year window spec makes the aggregate function F.min("temp") run over each year, rather than over the entire data frame.
  • F.min("temp").over(each_year) attaches that year's minimum temperature to every row; we then keep only the rows whose temp matches that minimum.
In [9]:
# Use the each_year builder class

gsod.withColumn("min_temp", F.min("temp").over(each_year)).where(
    "temp = min_temp"
).select("year", "mo", "da", "stn", "temp").orderBy(
    "year", "mo", "da"
).show()
+----+---+---+------+------+
|year| mo| da|   stn|  temp|
+----+---+---+------+------+
|2010| 06| 03|896060|-110.7|
|2011| 05| 19|896060|-106.8|
|2012| 06| 11|896060|-113.5|
|2013| 07| 31|895770|-110.7|
|2014| 08| 20|896060|-110.5|
|2015| 07| 12|895360|-110.2|
|2015| 08| 21|896060|-110.2|
|2015| 08| 27|896060|-110.2|
|2016| 07| 11|896060|-111.7|
|2017| 06| 20|896250|-114.7|
|2018| 08| 27|896060|-113.5|
|2019| 06| 15|895770|-114.7|
|2020| 08| 11|896060|-105.0|
|2020| 08| 13|896250|-105.0|
+----+---+---+------+------+

Bonus:

  • partitionBy() can be used on more than one column (see the sketch after the next cell)
  • You can also directly use a window function inside a select:
In [10]:
# Using window function inside a select
gsod.select(
    "year",
    "mo",
    "da",
    "stn",
    "temp",
    F.min("temp").over(each_year).alias("min_temp"),
).where("temp = min_temp").drop("min_temp").orderBy(
    "year", "mo", "da"
).show()
+----+---+---+------+------+
|year| mo| da|   stn|  temp|
+----+---+---+------+------+
|2010| 06| 03|896060|-110.7|
|2011| 05| 19|896060|-106.8|
|2012| 06| 11|896060|-113.5|
|2013| 07| 31|895770|-110.7|
|2014| 08| 20|896060|-110.5|
|2015| 07| 12|895360|-110.2|
|2015| 08| 21|896060|-110.2|
|2015| 08| 27|896060|-110.2|
|2016| 07| 11|896060|-111.7|
|2017| 06| 20|896250|-114.7|
|2018| 08| 27|896060|-113.5|
|2019| 06| 15|895770|-114.7|
|2020| 08| 11|896060|-105.0|
|2020| 08| 13|896250|-105.0|
+----+---+---+------+------+
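
A minimal sketch of the first bonus point, partitioning by more than one column (the choice of year and mo here is arbitrary):

In [ ]:
# Coldest reading per year AND month
each_year_and_month = Window.partitionBy("year", "mo")

gsod.select(
    "year",
    "mo",
    "da",
    "stn",
    "temp",
    F.min("temp").over(each_year_and_month).alias("min_temp"),
).where("temp = min_temp").orderBy("year", "mo").show(5)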

Ranking functions

  • Rank functions rank records based on the value of a field.
  • Functions: rank(), dense_rank(), percent_rank(), ntile() and row_number()
In [16]:
# Load lightweight dataset
gsod_light = spark.read.parquet("data/Window/gsod_light.parquet")
In [17]:
# Inspect
gsod_light.printSchema()
gsod_light.show()
root
 |-- stn: string (nullable = true)
 |-- year: string (nullable = true)
 |-- mo: string (nullable = true)
 |-- da: string (nullable = true)
 |-- temp: double (nullable = true)
 |-- count_temp: long (nullable = true)

+------+----+---+---+----+----------+
|   stn|year| mo| da|temp|count_temp|
+------+----+---+---+----+----------+
|994979|2017| 12| 11|21.3|        21|
|998012|2017| 03| 02|31.4|        24|
|719200|2017| 10| 09|60.5|        11|
|917350|2018| 04| 21|82.6|         9|
|076470|2018| 06| 07|65.0|        24|
|996470|2018| 03| 12|55.6|        12|
|041680|2019| 02| 19|16.1|        15|
|949110|2019| 11| 23|54.9|        14|
|998252|2019| 04| 18|44.7|        11|
|998166|2019| 03| 20|34.8|        12|
+------+----+---+---+----+----------+

rank & dense_rank

  • rank gives Olympic ranking (non-consecutive, when you have multiple records that tie for a rank, the next one will be offset by the number of ties)
  • dense_rank ranks consecutively. Ties share the same rank, but there won’t be any gap between the ranks. Useful when you just want a cardinal position over a window.

In [31]:
# Create new window, partitioning by year and ordering by number of temperature readings
temp_per_year_asc = Window.partitionBy("year").orderBy("count_temp")
temp_per_month_asc = Window.partitionBy("mo").orderBy("count_temp")


# Using rank() with the window, we get the rank according to the value of the count_temp column
print("Using rank()")
gsod_light.withColumn("rank_tpm", F.rank().over(temp_per_month_asc)).show()


# Using dense_rank() instead to get consecutive ranking by month
print("Using dense_rank()")
gsod_light.withColumn("rank_tpm", F.dense_rank().over(temp_per_month_asc)).show()
Using rank()
+------+----+---+---+----+----------+--------+
|   stn|year| mo| da|temp|count_temp|rank_tpm|
+------+----+---+---+----+----------+--------+
|949110|2019| 11| 23|54.9|        14|       1|
|996470|2018| 03| 12|55.6|        12|       1|
|998166|2019| 03| 20|34.8|        12|       1|
|998012|2017| 03| 02|31.4|        24|       3|
|041680|2019| 02| 19|16.1|        15|       1|
|076470|2018| 06| 07|65.0|        24|       1|
|719200|2017| 10| 09|60.5|        11|       1|
|994979|2017| 12| 11|21.3|        21|       1|
|917350|2018| 04| 21|82.6|         9|       1|
|998252|2019| 04| 18|44.7|        11|       2|
+------+----+---+---+----+----------+--------+

Using dense_rank()
+------+----+---+---+----+----------+--------+
|   stn|year| mo| da|temp|count_temp|rank_tpm|
+------+----+---+---+----+----------+--------+
|949110|2019| 11| 23|54.9|        14|       1|
|996470|2018| 03| 12|55.6|        12|       1|
|998166|2019| 03| 20|34.8|        12|       1|
|998012|2017| 03| 02|31.4|        24|       2|
|041680|2019| 02| 19|16.1|        15|       1|
|076470|2018| 06| 07|65.0|        24|       1|
|719200|2017| 10| 09|60.5|        11|       1|
|994979|2017| 12| 11|21.3|        21|       1|
|917350|2018| 04| 21|82.6|         9|       1|
|998252|2019| 04| 18|44.7|        11|       2|
+------+----+---+---+----+----------+--------+

percent_rank

For every window, percent_rank() computes the percentage rank (0 to 1) based on the ordered value.

formula = (# records with a lower value than the current one) / (# records in the window - 1)

In [34]:
temp_each_year = each_year.orderBy("temp")


gsod_light.withColumn("rank_tpm", F.percent_rank().over(temp_each_year)).show()
+------+----+---+---+----+----------+------------------+
|   stn|year| mo| da|temp|count_temp|          rank_tpm|
+------+----+---+---+----+----------+------------------+
|041680|2019| 02| 19|16.1|        15|               0.0|
|998166|2019| 03| 20|34.8|        12|0.3333333333333333|
|998252|2019| 04| 18|44.7|        11|0.6666666666666666|
|949110|2019| 11| 23|54.9|        14|               1.0|
|994979|2017| 12| 11|21.3|        21|               0.0|
|998012|2017| 03| 02|31.4|        24|               0.5|
|719200|2017| 10| 09|60.5|        11|               1.0|
|996470|2018| 03| 12|55.6|        12|               0.0|
|076470|2018| 06| 07|65.0|        24|               0.5|
|917350|2018| 04| 21|82.6|         9|               1.0|
+------+----+---+---+----+----------+------------------+

ntile()

Splits each ordered window into n roughly equal tiles and returns each record's tile number (1 to n).

In [35]:
gsod_light.withColumn("rank_tpm", F.ntile(2).over(temp_each_year)).show()
+------+----+---+---+----+----------+--------+
|   stn|year| mo| da|temp|count_temp|rank_tpm|
+------+----+---+---+----+----------+--------+
|041680|2019| 02| 19|16.1|        15|       1|
|998166|2019| 03| 20|34.8|        12|       1|
|998252|2019| 04| 18|44.7|        11|       2|
|949110|2019| 11| 23|54.9|        14|       2|
|994979|2017| 12| 11|21.3|        21|       1|
|998012|2017| 03| 02|31.4|        24|       1|
|719200|2017| 10| 09|60.5|        11|       2|
|996470|2018| 03| 12|55.6|        12|       1|
|076470|2018| 06| 07|65.0|        24|       1|
|917350|2018| 04| 21|82.6|         9|       2|
+------+----+---+---+----+----------+--------+

row_number()

Given an ordered window, it gives an increasing row number regardless of ties.

In [39]:
gsod_light.withColumn("row_number", F.row_number().over(temp_each_year)).show()
+------+----+---+---+----+----------+----------+
|   stn|year| mo| da|temp|count_temp|row_number|
+------+----+---+---+----+----------+----------+
|041680|2019| 02| 19|16.1|        15|         1|
|998166|2019| 03| 20|34.8|        12|         2|
|998252|2019| 04| 18|44.7|        11|         3|
|949110|2019| 11| 23|54.9|        14|         4|
|994979|2017| 12| 11|21.3|        21|         1|
|998012|2017| 03| 02|31.4|        24|         2|
|719200|2017| 10| 09|60.5|        11|         3|
|996470|2018| 03| 12|55.6|        12|         1|
|076470|2018| 06| 07|65.0|        24|         2|
|917350|2018| 04| 21|82.6|         9|         3|
+------+----+---+---+----+----------+----------+

In [38]:
# Creating a window with a descending ordered column

temp_per_month_desc = Window.partitionBy("mo").orderBy(F.col("count_temp").desc())

gsod_light.withColumn("row_number", F.row_number().over(temp_per_month_desc)).show()
+------+----+---+---+----+----------+----------+
|   stn|year| mo| da|temp|count_temp|row_number|
+------+----+---+---+----+----------+----------+
|949110|2019| 11| 23|54.9|        14|         1|
|998012|2017| 03| 02|31.4|        24|         1|
|996470|2018| 03| 12|55.6|        12|         2|
|998166|2019| 03| 20|34.8|        12|         3|
|041680|2019| 02| 19|16.1|        15|         1|
|076470|2018| 06| 07|65.0|        24|         1|
|719200|2017| 10| 09|60.5|        11|         1|
|994979|2017| 12| 11|21.3|        21|         1|
|998252|2019| 04| 18|44.7|        11|         1|
|917350|2018| 04| 21|82.6|         9|         2|
+------+----+---+---+----+----------+----------+

Analytic functions: looking back and ahead

lag and lead

The two most important functions of the analytics functions family are called lag(col, n=1, default=None) and lead(col, n=1, default=None), which will give you the value of the col column of the n-th record before and after the record you’re over, respectively.

In [50]:
# Get temp of previous two records using lag()

print("Temp of previous two records over each year")
gsod_light.withColumn(
    "previous_temp", F.lag("temp").over(temp_each_year)
).withColumn(
    "previous_temp_2", F.lag("temp", 2).over(temp_each_year)
).show()


print("Temp delta of previous record over each year")
gsod_light.withColumn(
    "previous_temp_delta", F.round(F.col("temp") - F.lag("temp").over(temp_each_year), 2)
).select(["year", "mo", "temp", "previous_temp_delta"]).show()
Temp of previous two records over each year
+------+----+---+---+----+----------+-------------+---------------+
|   stn|year| mo| da|temp|count_temp|previous_temp|previous_temp_2|
+------+----+---+---+----+----------+-------------+---------------+
|041680|2019| 02| 19|16.1|        15|         null|           null|
|998166|2019| 03| 20|34.8|        12|         16.1|           null|
|998252|2019| 04| 18|44.7|        11|         34.8|           16.1|
|949110|2019| 11| 23|54.9|        14|         44.7|           34.8|
|994979|2017| 12| 11|21.3|        21|         null|           null|
|998012|2017| 03| 02|31.4|        24|         21.3|           null|
|719200|2017| 10| 09|60.5|        11|         31.4|           21.3|
|996470|2018| 03| 12|55.6|        12|         null|           null|
|076470|2018| 06| 07|65.0|        24|         55.6|           null|
|917350|2018| 04| 21|82.6|         9|         65.0|           55.6|
+------+----+---+---+----+----------+-------------+---------------+

Temp delta of previous record over each year
+----+---+----+-------------------+
|year| mo|temp|previous_temp_delta|
+----+---+----+-------------------+
|2019| 02|16.1|               null|
|2019| 03|34.8|               18.7|
|2019| 04|44.7|                9.9|
|2019| 11|54.9|               10.2|
|2017| 12|21.3|               null|
|2017| 03|31.4|               10.1|
|2017| 10|60.5|               29.1|
|2018| 03|55.6|               null|
|2018| 06|65.0|                9.4|
|2018| 04|82.6|               17.6|
+----+---+----+-------------------+
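
lead() works symmetrically, looking ahead instead of behind; a minimal sketch reusing the temp_each_year window:

In [ ]:
# Get the temp of the next record over each year
gsod_light.withColumn(
    "next_temp", F.lead("temp").over(temp_each_year)
).select("year", "mo", "temp", "next_temp").show()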

cume_dist()

  • Provides a cumulative distribution rather than a ranking. Useful for exploratory analysis of a variable's cumulative distribution.
  • Does not rank, but provides the cumulative distribution function F(x) for the records in the data frame.
In [52]:
print("Percent rank vs. Cumulative distribution of temperature over each year")
gsod_light.withColumn(
    "percen_rank" , F.percent_rank().over(temp_each_year)
).withColumn("cume_dist", F.cume_dist().over(temp_each_year)).show()
Percent rank vs. Cumulative distribution of temperature over each year
+------+----+---+---+----+----------+------------------+------------------+
|   stn|year| mo| da|temp|count_temp|       percen_rank|         cume_dist|
+------+----+---+---+----+----------+------------------+------------------+
|041680|2019| 02| 19|16.1|        15|               0.0|              0.25|
|998166|2019| 03| 20|34.8|        12|0.3333333333333333|               0.5|
|998252|2019| 04| 18|44.7|        11|0.6666666666666666|              0.75|
|949110|2019| 11| 23|54.9|        14|               1.0|               1.0|
|994979|2017| 12| 11|21.3|        21|               0.0|0.3333333333333333|
|998012|2017| 03| 02|31.4|        24|               0.5|0.6666666666666666|
|719200|2017| 10| 09|60.5|        11|               1.0|               1.0|
|996470|2018| 03| 12|55.6|        12|               0.0|0.3333333333333333|
|076470|2018| 06| 07|65.0|        24|               0.5|0.6666666666666666|
|917350|2018| 04| 21|82.6|         9|               1.0|               1.0|
+------+----+---+---+----+----------+------------------+------------------+

...more to come
