# Apache Spark Optimization

Production patterns for optimizing Apache Spark jobs, including partitioning strategies, memory management, shuffle optimization, and performance tuning.
## When to Use This Skill

- Optimizing slow Spark jobs
- Tuning memory and executor configuration
- Implementing efficient partitioning strategies
- Debugging Spark performance issues
- Scaling Spark pipelines for large datasets
- Reducing shuffle and data skew
## Core Concepts

### Spark Execution Model

```
Driver Program
  ↓
Job (triggered by an action)
  ↓
Stages (separated by shuffle boundaries)
  ↓
Tasks (one per partition)
```
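For example, a single action on an aggregated DataFrame launches one job that splits into two stages at the shuffle boundary. A minimal sketch (names are illustrative, not from the original):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("ExecutionModelDemo").getOrCreate()

df = spark.range(1_000_000)                                       # Lazy transformation: no job runs yet
grouped = df.groupBy((F.col("id") % 10).alias("bucket")).count()  # Wide transformation: marks a shuffle boundary

grouped.show()  # Action: triggers one job with two stages, pre- and post-shuffle (AQE may re-plan at runtime)
```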
### Key Performance Factors

| Factor | Impact | Solution |
| --- | --- | --- |
| Shuffle | Network I/O, disk I/O | Minimize wide transformations |
| Data skew | Uneven task duration | Salting, broadcast joins |
| Serialization | CPU overhead | Use Kryo, columnar formats |
| Memory | GC pressure, spills | Tune executor memory |
| Partitions | Parallelism | Right-size partitions |
## Quick Start

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Create an optimized Spark session
spark = (
    SparkSession.builder
    .appName("OptimizedJob")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate()
)

# Read with optimized settings
df = (
    spark.read
    .format("parquet")
    .option("mergeSchema", "false")
    .load("s3://bucket/data/")
)

# Efficient transformations: filter and prune columns before aggregating
result = (
    df
    .filter(F.col("date") >= "2024-01-01")
    .select("id", "amount", "category")
    .groupBy("category")
    .agg(F.sum("amount").alias("total"))
)

result.write.mode("overwrite").parquet("s3://bucket/output/")
```
## Patterns

### Pattern 1: Optimal Partitioning
```python
# Calculate an optimal partition count
def calculate_partitions(data_size_gb: float, partition_size_mb: int = 128) -> int:
    """Target partition size: 128MB-256MB.

    Too few partitions: under-utilization, memory pressure.
    Too many partitions: task-scheduling overhead.
    """
    return max(int(data_size_gb * 1024 / partition_size_mb), 1)

# Repartition for even distribution (full shuffle)
df_repartitioned = df.repartition(200, "partition_key")

# Coalesce to reduce partitions (no shuffle)
df_coalesced = df.coalesce(100)

# Partition pruning with predicate pushdown
df = (
    spark.read.parquet("s3://bucket/data/")
    .filter(F.col("date") == "2024-01-01")  # Spark pushes this down to the reader
)

# Write with partitioning for future queries
(
    df.write
    .partitionBy("year", "month", "day")
    .mode("overwrite")
    .parquet("s3://bucket/partitioned_output/")
)
```
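As a worked example of the helper above: a 50 GB input at the 128 MB target comes out to 400 partitions.

```python
# 50 GB at a 128 MB target: int(50 * 1024 / 128) = 400 partitions
num_partitions = calculate_partitions(50.0)
df_sized = df.repartition(num_partitions, "partition_key")
```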
### Pattern 2: Join Optimization
```python
from pyspark.sql import functions as F

# 1. Broadcast join - small table joins
# Best when one side is small (< 10MB by default; configurable)
small_df = spark.read.parquet("s3://bucket/small_table/")  # < 10MB
large_df = spark.read.parquet("s3://bucket/large_table/")  # TBs

# Explicit broadcast hint
result = large_df.join(F.broadcast(small_df), on="key", how="left")

# 2. Sort-merge join - the default for two large tables
# Requires a shuffle, but handles any size
result = large_df1.join(large_df2, on="key", how="inner")

# 3. Bucket join - pre-shuffled and pre-sorted, so no shuffle at join time
# Write bucketed tables
(
    df.write
    .bucketBy(200, "customer_id")
    .sortBy("customer_id")
    .mode("overwrite")
    .saveAsTable("bucketed_orders")
)

# Join bucketed tables (no shuffle!)
orders = spark.table("bucketed_orders")
customers = spark.table("bucketed_customers")  # Same bucket count
result = orders.join(customers, on="customer_id")
```
```python
# 4. Skew join handling

# Enable AQE skew-join optimization
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

# Manual salting for severe skew
def salt_join(df_skewed, df_other, key_col, num_salts=10):
    """Add a random salt to spread skewed keys across tasks."""
    # Salt the skewed side with a random suffix in [0, num_salts)
    df_salted = df_skewed.withColumn(
        "salt", (F.rand() * num_salts).cast("int")
    ).withColumn(
        "salted_key", F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
    )

    # Replicate the other side once per salt value (uses the active session)
    df_exploded = df_other.crossJoin(
        spark.range(num_salts).withColumnRenamed("id", "salt")
    ).withColumn(
        "salted_key", F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
    )

    # Join on the salted key
    return df_salted.join(df_exploded, on="salted_key", how="inner")
```
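A hypothetical usage sketch (`orders_df` and `customers_df` are illustrative names, not from the original):

```python
# orders_df is heavily skewed on customer_id
joined = salt_join(orders_df, customers_df, "customer_id", num_salts=20)
# The result keeps the helper "salt"/"salted_key" columns and both
# original key columns; select only the columns you need downstream.
```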
### Pattern 3: Caching and Persistence
```python
from pyspark import StorageLevel

# Cache when reusing a DataFrame across multiple actions
df = spark.read.parquet("s3://bucket/data/")
df_filtered = df.filter(F.col("status") == "active")

# Cache in memory (MEMORY_AND_DISK is the default for DataFrames)
df_filtered.cache()

# Or with a specific storage level. PySpark data is always serialized,
# so the *_SER constants were dropped from pyspark.StorageLevel in 3.x;
# MEMORY_AND_DISK_SER exists only on the Scala/Java API.
df_filtered.persist(StorageLevel.MEMORY_AND_DISK)

# Force materialization
df_filtered.count()

# Reuse in multiple actions
agg1 = df_filtered.groupBy("category").count()
agg2 = df_filtered.groupBy("region").sum("amount")

# Unpersist when done
df_filtered.unpersist()
```
Storage levels explained:

- `MEMORY_ONLY` - fastest, but partitions that don't fit are recomputed
- `MEMORY_AND_DISK` - spills to disk if needed (recommended)
- `MEMORY_ONLY_SER` - serialized, less memory, more CPU (Scala/Java API)
- `DISK_ONLY` - when memory is tight
- `OFF_HEAP` - Tungsten off-heap memory
```python
# Checkpoint to truncate a complex lineage
spark.sparkContext.setCheckpointDir("s3://bucket/checkpoints/")

df_complex = (
    df
    .join(other_df, "key")
    .groupBy("category")
    .agg(F.sum("amount"))
)

# checkpoint() returns a new DataFrame; capture it to get the truncated lineage
df_complex = df_complex.checkpoint()  # Breaks lineage, materializes to the checkpoint dir
```
### Pattern 4: Memory Tuning
```bash
# Executor memory configuration
spark-submit --executor-memory 8g --executor-cores 4
```
Memory breakdown for an 8GB executor (simplified; Spark first subtracts ~300MB of reserved memory):

- `spark.memory.fraction = 0.6` → 4.8GB unified region shared by execution and storage
- `spark.memory.storageFraction = 0.5` → 2.4GB of that region protected for cached data
- The remaining 2.4GB is for execution (shuffles, joins, sorts)
- The other 40% (3.2GB) is for user data structures and internal metadata
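A quick sanity check of that arithmetic, under the same simplification (the ~300MB reserved region is ignored):

```python
def unified_memory_breakdown(executor_gb, fraction=0.6, storage_fraction=0.5):
    """Approximate the unified-memory split of an executor heap.

    Simplified sketch: ignores the ~300MB Spark reserves before
    applying spark.memory.fraction.
    """
    unified = executor_gb * fraction      # execution + storage region
    storage = unified * storage_fraction  # evictable cache region
    return {
        "unified_gb": unified,             # 4.8 for an 8GB heap
        "storage_gb": storage,             # 2.4
        "execution_gb": unified - storage, # 2.4
        "user_gb": executor_gb - unified,  # 3.2 for user data / metadata
    }

print(unified_memory_breakdown(8.0))
```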
```python
spark = (
    SparkSession.builder
    .config("spark.executor.memory", "8g")
    .config("spark.executor.memoryOverhead", "2g")  # Non-JVM memory (Python workers, buffers)
    .config("spark.memory.fraction", "0.6")
    .config("spark.memory.storageFraction", "0.5")
    .config("spark.sql.shuffle.partitions", "200")  # Raise for memory-intensive shuffles
    .config("spark.sql.autoBroadcastJoinThreshold", "50MB")  # Broadcast more joins to skip shuffles
    .config("spark.sql.files.maxPartitionBytes", "128MB")
    .getOrCreate()
)
```
```python
# Monitor memory usage (relies on internal JVM APIs; may break across Spark versions)
def print_memory_usage(spark):
    """Print total and free memory per executor."""
    sc = spark.sparkContext
    mem_status = sc._jsc.sc().getExecutorMemoryStatus()
    it = mem_status.keysIterator()
    while it.hasNext():
        executor = it.next()
        status = mem_status.get(executor).get()  # Scala Option -> (total, free) tuple
        total = status._1() / (1024 ** 3)
        free = status._2() / (1024 ** 3)
        print(f"{executor}: {total:.2f}GB total, {free:.2f}GB free")
```
### Pattern 5: Shuffle Optimization
```python
# Reduce shuffle data size. Note: spark.sql.shuffle.partitions = "auto" is
# Databricks-only; in open-source Spark, let AQE coalesce partitions instead.
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.shuffle.compress", "true")
spark.conf.set("spark.shuffle.spill.compress", "true")
```
```python
# Two-phase aggregation: Spark already does map-side partial aggregation
# for sums, but this pattern helps when a single key dominates
df_optimized = (
    df
    # Partial aggregation over a finer-grained key first
    .groupBy("key", "partition_col")
    .agg(F.sum("value").alias("partial_sum"))
    # Then the global aggregation
    .groupBy("key")
    .agg(F.sum("partial_sum").alias("total"))
)
```
```python
# Prefer operations that avoid shuffling the full data

# COSTLY: shuffles every distinct value
distinct_count = df.select("category").distinct().count()

# CHEAPER: approximate distinct via HyperLogLog
# (still an aggregation, but only compact sketches cross the network)
approx_count = df.select(F.approx_count_distinct("category")).collect()[0][0]

# Use coalesce instead of repartition when reducing partitions
df_reduced = df.coalesce(10)  # No shuffle

# Optimize shuffle with fast compression
spark.conf.set("spark.io.compression.codec", "lz4")
```
### Pattern 6: Data Format Optimization
```python
# Parquet optimizations
(
    df.write
    .option("compression", "snappy")                  # Fast compression
    .option("parquet.block.size", 128 * 1024 * 1024)  # 128MB row groups
    .parquet("s3://bucket/output/")
)

# Column pruning: only read the columns you need
df = (
    spark.read.parquet("s3://bucket/data/")
    .select("id", "amount", "date")  # Spark reads only these columns
)

# Predicate pushdown: filter at the storage level
df = (
    spark.read.parquet("s3://bucket/partitioned/year=2024/")
    .filter(F.col("status") == "active")  # Pushed down to the Parquet reader
)

# Delta Lake optimizations
(
    df.write
    .format("delta")
    .option("optimizeWrite", "true")  # Bin-packing
    .option("autoCompact", "true")    # Compact small files
    .mode("overwrite")
    .save("s3://bucket/delta_table/")
)
```
```python
# Z-ordering for multi-dimensional query predicates
spark.sql("""
    OPTIMIZE delta.`s3://bucket/delta_table/`
    ZORDER BY (customer_id, date)
""")
```
### Pattern 7: Monitoring and Debugging
```python
# Enable detailed metrics
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Explain the query plan
df.explain(mode="extended")
# Modes: simple, extended, codegen, cost, formatted

# Physical plan with cost statistics
df.explain(mode="cost")
```
```python
# Monitor task metrics
def analyze_stage_metrics(spark):
    """Print task counts for currently active stages."""
    status_tracker = spark.sparkContext.statusTracker()

    for stage_id in status_tracker.getActiveStageIds():
        stage_info = status_tracker.getStageInfo(stage_id)
        print(f"Stage {stage_id}:")
        print(f"  Tasks: {stage_info.numTasks}")
        print(f"  Completed: {stage_info.numCompletedTasks}")
        print(f"  Failed: {stage_info.numFailedTasks}")
```
```python
# Identify data skew
def check_partition_skew(df):
    """Report row counts per partition and a skew ratio."""
    partition_counts = (
        df
        .withColumn("partition_id", F.spark_partition_id())
        .groupBy("partition_id")
        .count()
        .orderBy(F.desc("count"))
    )

    partition_counts.show(20)

    stats = partition_counts.select(
        F.min("count").alias("min"),
        F.max("count").alias("max"),
        F.avg("count").alias("avg"),
        F.stddev("count").alias("stddev"),
    ).collect()[0]

    skew_ratio = stats["max"] / stats["avg"]
    print(f"Skew ratio: {skew_ratio:.2f}x (>2x indicates skew)")
```
## Configuration Cheat Sheet
```python
# Production configuration template
spark_configs = {
    # Adaptive Query Execution (AQE)
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true",
    "spark.sql.adaptive.skewJoin.enabled": "true",

    # Memory
    "spark.executor.memory": "8g",
    "spark.executor.memoryOverhead": "2g",
    "spark.memory.fraction": "0.6",
    "spark.memory.storageFraction": "0.5",

    # Parallelism
    "spark.sql.shuffle.partitions": "200",
    "spark.default.parallelism": "200",

    # Serialization
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.sql.execution.arrow.pyspark.enabled": "true",

    # Compression
    "spark.io.compression.codec": "lz4",
    "spark.shuffle.compress": "true",

    # Broadcast
    "spark.sql.autoBroadcastJoinThreshold": "50MB",

    # File handling
    "spark.sql.files.maxPartitionBytes": "128MB",
    "spark.sql.files.openCostInBytes": "4MB",
}
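One way to apply the template is to loop the dict into a session builder (a sketch; the app name is a placeholder):

```python
from pyspark.sql import SparkSession

builder = SparkSession.builder.appName("ProductionJob")
for key, value in spark_configs.items():
    builder = builder.config(key, value)

spark = builder.getOrCreate()
```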
## Best Practices

### Do's

- **Enable AQE** - adaptive query execution handles many issues automatically
- **Use Parquet/Delta** - columnar formats with compression
- **Broadcast small tables** - avoid shuffles for small joins
- **Monitor the Spark UI** - check for skew, spills, and GC pressure
- **Right-size partitions** - aim for 128MB-256MB per partition
### Don'ts

- **Don't collect large data** - keep data distributed on the cluster
- **Don't use UDFs unnecessarily** - prefer built-in functions, which stay in the JVM
- **Don't over-cache** - memory is limited
- **Don't ignore data skew** - straggler tasks dominate job time
- **Don't use `.count()` for existence checks** - use `.take(1)` or `.isEmpty()`, as in the sketch below
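A minimal sketch of the last point (`df` and the status filter are hypothetical):

```python
from pyspark.sql import functions as F

# COSTLY: counts every matching row just to test existence
has_failures = df.filter(F.col("status") == "failed").count() > 0

# BETTER: stops after the first matching row
has_failures = len(df.filter(F.col("status") == "failed").take(1)) > 0

# BETTER (PySpark 3.3+): built-in emptiness check
has_failures = not df.filter(F.col("status") == "failed").isEmpty()
```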