# Apache Spark Optimization
Production patterns for optimizing Apache Spark jobs including partitioning strategies, memory management, shuffle optimization, and performance tuning.
## When to Use This Skill

- Optimizing slow Spark jobs
- Tuning memory and executor configuration
- Implementing efficient partitioning strategies
- Debugging Spark performance issues
- Scaling Spark pipelines for large datasets
- Reducing shuffle and data skew

## Core Concepts

### 1. Spark Execution Model

```
Driver Program
    ↓
Job (triggered by an action)
    ↓
Stages (separated by shuffles)
    ↓
Tasks (one per partition)
```
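To see the stage boundaries directly, here is a minimal sketch (using a synthetic DataFrame, not one of the original examples) that contrasts a narrow and a wide transformation; the `Exchange` node in the second plan marks the shuffle that splits the job into two stages.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("StageBoundaryDemo").getOrCreate()
df = spark.range(1_000_000).withColumn("category", F.col("id") % 10)

# Narrow transformation: runs within a single stage, no shuffle
df.filter(F.col("id") > 100).explain()

# Wide transformation: the groupBy adds an Exchange (shuffle),
# which creates a new stage boundary
df.groupBy("category").count().explain()
```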
### Key Performance Factors

| Factor | Impact | Solution |
|--------|--------|----------|
| Shuffle | Network I/O, disk I/O | Minimize wide transformations |
| Data Skew | Uneven task duration | Salting, broadcast joins |
| Serialization | CPU overhead | Use Kryo, columnar formats |
| Memory | GC pressure, spills | Tune executor memory |
| Partitions | Parallelism | Right-size partitions |

## Quick Start

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Create an optimized Spark session
spark = (SparkSession.builder
    .appName("OptimizedJob")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate())

# Read with optimized settings
df = (spark.read
    .format("parquet")
    .option("mergeSchema", "false")
    .load("s3://bucket/data/"))

# Efficient transformations
result = (df
    .filter(F.col("date") >= "2024-01-01")
    .select("id", "amount", "category")
    .groupBy("category")
    .agg(F.sum("amount").alias("total")))

result.write.mode("overwrite").parquet("s3://bucket/output/")
```
## Patterns

### Pattern 1: Optimal Partitioning

```python
# Calculate an optimal partition count
def calculate_partitions(data_size_gb: float, partition_size_mb: int = 128) -> int:
    """
    Optimal partition size: 128MB - 256MB
    Too few: under-utilization, memory pressure
    Too many: task scheduling overhead
    """
    return max(int(data_size_gb * 1024 / partition_size_mb), 1)
```
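As a usage sketch (the 50 GB input size is an assumed figure, not from the original), the helper can feed directly into a repartition call:

```python
# Assume the input dataset is roughly 50 GB on disk
num_partitions = calculate_partitions(50)   # -> 400 partitions at ~128 MB each
df = df.repartition(num_partitions)
```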
```python
# Repartition for even distribution
df_repartitioned = df.repartition(200, "partition_key")

# Coalesce to reduce partitions (no shuffle)
df_coalesced = df.coalesce(100)

# Partition pruning with predicate pushdown
df = (spark.read.parquet("s3://bucket/data/")
    .filter(F.col("date") == "2024-01-01"))  # Spark pushes this filter down

# Write with partitioning for future queries
(df.write
    .partitionBy("year", "month", "day")
    .mode("overwrite")
    .parquet("s3://bucket/partitioned_output/"))
```
### Pattern 2: Join Optimization

```python
from pyspark.sql import functions as F
from pyspark.sql.types import *

# 1. Broadcast Join - for joining a small table to a large one
# Best when one side is smaller than the broadcast threshold
# (10MB by default, configurable via spark.sql.autoBroadcastJoinThreshold)
small_df = spark.read.parquet("s3://bucket/small_table/")   # < 10MB
large_df = spark.read.parquet("s3://bucket/large_table/")   # TBs

# Explicit broadcast hint
result = large_df.join(
    F.broadcast(small_df),
    on="key",
    how="left"
)
```
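If you prefer Spark's automatic broadcast decision over an explicit hint, the threshold can be tuned or disabled at the session level; the 50 MB value below is just an example, matching the cheat sheet later in this document.

```python
# Raise the size limit below which Spark broadcasts the smaller join side
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))

# Or disable automatic broadcasting entirely and rely on explicit hints
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
```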
```python
# 2. Sort-Merge Join - the default for large tables
# Requires a shuffle, but handles any size
result = large_df1.join(large_df2, on="key", how="inner")

# 3. Bucket Join - pre-bucketed and sorted, no shuffle at join time
# Write bucketed tables
(df.write
    .bucketBy(200, "customer_id")
    .sortBy("customer_id")
    .mode("overwrite")
    .saveAsTable("bucketed_orders"))

# Join bucketed tables (no shuffle!)
orders = spark.table("bucketed_orders")
customers = spark.table("bucketed_customers")  # Same bucket count
result = orders.join(customers, on="customer_id")
```
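The join above assumes `bucketed_customers` was written the same way; a minimal sketch (with an assumed `customers_df` source DataFrame) would be:

```python
# The join only avoids a shuffle if both tables share the same
# bucketing column and the same bucket count
(customers_df.write
    .bucketBy(200, "customer_id")
    .sortBy("customer_id")
    .mode("overwrite")
    .saveAsTable("bucketed_customers"))
```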
```python
# 4. Skew Join Handling

# Enable AQE skew-join optimization
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

# Manual salting for severe skew
def salt_join(df_skewed, df_other, key_col, num_salts=10):
    """Add a random salt to distribute skewed keys across partitions."""
    # Add a random salt to the skewed side
    df_salted = df_skewed.withColumn(
        "salt", (F.rand() * num_salts).cast("int")
    ).withColumn(
        "salted_key",
        F.concat(F.col(key_col).cast("string"), F.lit("_"), F.col("salt").cast("string"))
    )

    # Replicate the other side once per salt value
    df_exploded = df_other.crossJoin(
        spark.range(num_salts).withColumnRenamed("id", "salt")
    ).withColumn(
        "salted_key",
        F.concat(F.col(key_col).cast("string"), F.lit("_"), F.col("salt").cast("string"))
    ).drop(key_col, "salt")

    # Join on the salted key, then drop the helper columns
    return (df_salted.join(df_exploded, on="salted_key", how="inner")
            .drop("salted_key", "salt"))
```
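A hedged usage example (the table names and the skewed key are assumptions for illustration):

```python
# orders is heavily skewed on customer_id (a handful of customers dominate),
# customer_dim is the dimension table being joined in
joined = salt_join(orders, customer_dim, key_col="customer_id", num_salts=20)
```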
### Pattern 3: Caching and Persistence

```python
from pyspark import StorageLevel

# Cache when reusing a DataFrame across multiple actions
df = spark.read.parquet("s3://bucket/data/")
df_filtered = df.filter(F.col("status") == "active")

# Cache in memory (MEMORY_AND_DISK is the default for DataFrames)
df_filtered.cache()

# Or persist with a specific storage level (pick one; a cached DataFrame's
# storage level cannot be changed without unpersisting first)
df_filtered.persist(StorageLevel.DISK_ONLY)

# Force materialization
df_filtered.count()

# Reuse in multiple actions
agg1 = df_filtered.groupBy("category").count()
agg2 = df_filtered.groupBy("region").sum("amount")

# Unpersist when done
df_filtered.unpersist()
```
Storage levels explained:

- `MEMORY_ONLY` - fast, but partitions that don't fit are recomputed
- `MEMORY_AND_DISK` - spills to disk if needed (recommended)
- `MEMORY_ONLY_SER` - serialized, less memory, more CPU (Scala/Java API; PySpark data is always stored serialized)
- `DISK_ONLY` - when memory is tight
- `OFF_HEAP` - Tungsten off-heap memory
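`OFF_HEAP` storage only has somewhere to go if off-heap memory is enabled when the session starts; a minimal configuration sketch (the 2g size is an assumed value to tune per workload):

```python
spark = (SparkSession.builder
    .appName("OffHeapExample")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "2g")   # assumed size
    .getOrCreate())
```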
```python
# Checkpoint to truncate a complex lineage
spark.sparkContext.setCheckpointDir("s3://bucket/checkpoints/")

df_complex = (df
    .join(other_df, "key")
    .groupBy("category")
    .agg(F.sum("amount")))

# checkpoint() returns the checkpointed DataFrame, so reassign it
df_complex = df_complex.checkpoint()  # Breaks lineage, materializes to the checkpoint dir
```
### Pattern 4: Memory Tuning

```bash
# Executor memory configuration
spark-submit --executor-memory 8g --executor-cores 4
```

Approximate memory breakdown for an 8GB executor (`spark.memory.fraction` is applied to the heap minus ~300MB of reserved memory, so the figures below are rounded):

- `spark.memory.fraction = 0.6` - 60% ≈ 4.8GB unified region for execution + storage
- `spark.memory.storageFraction = 0.5` - 50% of that region ≈ 2.4GB protected for cached data
- The remaining ≈ 2.4GB of the unified region goes to execution (shuffles, joins, sorts)
- The other 40% ≈ 3.2GB is user memory for data structures and internal metadata
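As a rough illustration of that arithmetic, here is a small helper (not from the original) that estimates the regions for a given executor size under the default fractions; it ignores the ~300MB reserved region, matching the rounded numbers above.

```python
def estimate_memory_regions(executor_memory_gb: float,
                            memory_fraction: float = 0.6,
                            storage_fraction: float = 0.5) -> dict:
    """Rough estimate of executor memory regions (reserved memory ignored)."""
    unified = executor_memory_gb * memory_fraction
    return {
        "unified_gb": round(unified, 2),                        # execution + storage
        "storage_gb": round(unified * storage_fraction, 2),     # protected cache region
        "execution_gb": round(unified * (1 - storage_fraction), 2),
        "user_gb": round(executor_memory_gb * (1 - memory_fraction), 2),
    }

print(estimate_memory_regions(8))
# -> {'unified_gb': 4.8, 'storage_gb': 2.4, 'execution_gb': 2.4, 'user_gb': 3.2}
```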
```python
spark = (SparkSession.builder
    .config("spark.executor.memory", "8g")
    .config("spark.executor.memoryOverhead", "2g")            # For non-JVM memory (Python workers, native buffers)
    .config("spark.memory.fraction", "0.6")
    .config("spark.memory.storageFraction", "0.5")
    .config("spark.sql.shuffle.partitions", "200")            # Raise for memory-intensive shuffles
    .config("spark.sql.autoBroadcastJoinThreshold", "50MB")   # Broadcast joins for tables up to 50MB
    .config("spark.sql.files.maxPartitionBytes", "128MB")     # Cap input partition size
    .getOrCreate())
```
```python
# Monitor executor memory usage (relies on Spark's internal JVM APIs,
# so behaviour may vary between versions)
def print_memory_usage(spark):
    """Print current executor memory usage."""
    sc = spark.sparkContext
    status = sc._jsc.sc().getExecutorMemoryStatus()
    for executor in status.keySet().toArray():
        mem_status = status.get(executor).get()   # unwrap the Scala Option
        total = mem_status._1() / (1024 ** 3)
        free = mem_status._2() / (1024 ** 3)
        print(f"{executor}: {total:.2f}GB total, {free:.2f}GB free")
```
### Pattern 5: Shuffle Optimization

```python
# Reduce shuffle data size
# ("auto" is supported on Databricks; on open-source Spark set a number
#  and let AQE coalesce partitions at runtime)
spark.conf.set("spark.sql.shuffle.partitions", "auto")
spark.conf.set("spark.shuffle.compress", "true")
spark.conf.set("spark.shuffle.spill.compress", "true")

# Pre-aggregate before the wide aggregation to shrink shuffled data
df_optimized = (df
    # Partial aggregation at a finer granularity first
    .groupBy("key", "partition_col")
    .agg(F.sum("value").alias("partial_sum"))
    # Then the global aggregation
    .groupBy("key")
    .agg(F.sum("partial_sum").alias("total")))

# Avoid heavy shuffles where an approximation is good enough
# BAD: exact distinct shuffles every distinct value
distinct_count = df.select("category").distinct().count()

# GOOD: approximate distinct shuffles only small sketches
approx_count = df.select(F.approx_count_distinct("category")).collect()[0][0]

# Use coalesce instead of repartition when reducing partitions
df_reduced = df.coalesce(10)  # No shuffle

# Optimize shuffle I/O with fast compression (lz4 is the default codec)
spark.conf.set("spark.io.compression.codec", "lz4")
```
### Pattern 6: Data Format Optimization

```python
# Parquet optimizations
(df.write
    .option("compression", "snappy")                  # Fast compression
    .option("parquet.block.size", 128 * 1024 * 1024)  # 128MB row groups
    .parquet("s3://bucket/output/"))

# Column pruning - only read needed columns
df = (spark.read.parquet("s3://bucket/data/")
    .select("id", "amount", "date"))  # Spark only reads these columns

# Predicate pushdown - filter at the storage level
df = (spark.read.parquet("s3://bucket/partitioned/year=2024/")
    .filter(F.col("status") == "active"))  # Pushed down to the Parquet reader

# Delta Lake optimizations (auto-optimize options, Databricks-specific)
(df.write
    .format("delta")
    .option("optimizeWrite", "true")  # Bin-packing
    .option("autoCompact", "true")    # Compact small files
    .mode("overwrite")
    .save("s3://bucket/delta_table/"))

# Z-ordering for multi-dimensional query predicates
spark.sql("""
    OPTIMIZE delta.`s3://bucket/delta_table/`
    ZORDER BY (customer_id, date)
""")
```
### Pattern 7: Monitoring and Debugging

```python
# Make sure whole-stage codegen and Arrow-based pandas conversion are enabled
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Explain the query plan
df.explain(mode="extended")
# Modes: simple, extended, codegen, cost, formatted

# Get physical plan statistics
df.explain(mode="cost")

# Monitor task metrics for active stages
def analyze_stage_metrics(spark):
    """Analyze metrics for currently active stages."""
    status_tracker = spark.sparkContext.statusTracker()

    for stage_id in status_tracker.getActiveStageIds():
        stage_info = status_tracker.getStageInfo(stage_id)
        if stage_info is None:
            continue
        print(f"Stage {stage_id}:")
        print(f"  Tasks: {stage_info.numTasks}")
        print(f"  Completed: {stage_info.numCompletedTasks}")
        print(f"  Failed: {stage_info.numFailedTasks}")
```
```python
# Identify data skew
def check_partition_skew(df):
    """Check for partition skew by counting rows per partition."""
    partition_counts = (df
        .withColumn("partition_id", F.spark_partition_id())
        .groupBy("partition_id")
        .count()
        .orderBy(F.desc("count")))

    partition_counts.show(20)

    stats = partition_counts.select(
        F.min("count").alias("min"),
        F.max("count").alias("max"),
        F.avg("count").alias("avg"),
        F.stddev("count").alias("stddev")
    ).collect()[0]

    skew_ratio = stats["max"] / stats["avg"]
    print(f"Skew ratio: {skew_ratio:.2f}x (>2x indicates skew)")
## Configuration Cheat Sheet

```python
# Production configuration template
spark_configs = {
    # Adaptive Query Execution (AQE)
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true",
    "spark.sql.adaptive.skewJoin.enabled": "true",

    # Memory
    "spark.executor.memory": "8g",
    "spark.executor.memoryOverhead": "2g",
    "spark.memory.fraction": "0.6",
    "spark.memory.storageFraction": "0.5",

    # Parallelism
    "spark.sql.shuffle.partitions": "200",
    "spark.default.parallelism": "200",

    # Serialization
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.sql.execution.arrow.pyspark.enabled": "true",

    # Compression
    "spark.io.compression.codec": "lz4",
    "spark.shuffle.compress": "true",

    # Broadcast
    "spark.sql.autoBroadcastJoinThreshold": "50MB",

    # File handling
    "spark.sql.files.maxPartitionBytes": "128MB",
    "spark.sql.files.openCostInBytes": "4MB",
}
```
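One way to apply the template when building a session (a minimal sketch assuming `SparkSession` is imported as in the Quick Start; the app name is arbitrary):

```python
builder = SparkSession.builder.appName("ProductionJob")
for key, value in spark_configs.items():
    builder = builder.config(key, value)
spark = builder.getOrCreate()
```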
## Best Practices

### Do's

- Enable AQE - adaptive query execution handles many issues automatically
- Use Parquet/Delta - columnar formats with compression
- Broadcast small tables - avoid a shuffle for small joins
- Monitor the Spark UI - check for skew, spills, and GC pressure
- Right-size partitions - aim for 128MB - 256MB per partition

### Don'ts

- Don't collect large data - keep data distributed
- Don't use UDFs unnecessarily - prefer built-in functions
- Don't over-cache - memory is limited
- Don't ignore data skew - it dominates job time
- Don't use `.count()` for existence checks - use `.take(1)` or `.isEmpty()`

## Resources

- Spark Performance Tuning
- Spark Configuration
- Databricks Optimization Guide