I want to check how we can get information about each partition, such as the total number of records in each partition, on the driver side when a Spark job is submitted with deploy mode yarn cluster, so that I can log it or print it on the console.

zero323
nilesh1212

6 Answers

I'd use the built-in function. It should be as efficient as it gets:

import org.apache.spark.sql.functions.spark_partition_id

df.groupBy(spark_partition_id()).count().show()
Alper t. Turker

You can get the number of records per partition like this:

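import spark.implicits._ // provides .toDF on RDDs of tuples (already imported in spark-shell)

df
  .rdd
  .mapPartitionsWithIndex { case (i, rows) => Iterator((i, rows.size)) }
  .toDF("partition_number", "number_of_records")
  .show()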

But this will also launch a Spark job by itself (because Spark must read the data to count the records).
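For PySpark, the same per-partition count can be written as a plain function with the (index, iterator) signature that RDD.mapPartitionsWithIndex expects. This is a minimal sketch: the names count_rows and simulate_map_partitions_with_index are hypothetical, and the second one only stands in for Spark by treating ordinary Python lists as partitions so the example runs without a cluster. Counting has to iterate over every row, which is why this approach triggers a full job.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def count_rows(index: int, rows: Iterable) -> Iterator[Tuple[int, int]]:
    """Emit one (partition_index, row_count) pair for a single partition.

    Hypothetical helper with the (index, iterator) signature used by
    PySpark's RDD.mapPartitionsWithIndex. Counting consumes the iterator,
    so every row in the partition is read.
    """
    yield index, sum(1 for _ in rows)


def simulate_map_partitions_with_index(
    partitions: List[list], f: Callable[[int, Iterator], Iterable]
) -> list:
    """Local stand-in for rdd.mapPartitionsWithIndex(f).collect().

    Each inner list plays the role of one partition; Spark is not used here.
    """
    result = []
    for i, part in enumerate(partitions):
        result.extend(f(i, iter(part)))
    return result


if __name__ == "__main__":
    # Three fake partitions, the last one empty.
    fake_partitions = [[1, 2, 3], [4, 5], []]
    print(simulate_map_partitions_with_index(fake_partitions, count_rows))
    # -> [(0, 3), (1, 2), (2, 0)]
    # With a real DataFrame the call would be:
    #   df.rdd.mapPartitionsWithIndex(count_rows).collect()
```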

Spark may also be able to read Hive table statistics, but I don't know how to display that metadata.

Raphael Roth

For future PySpark users:

from pyspark.sql.functions import spark_partition_id
rawDf.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()
BishoyM

Spark 1.5 solution:

(sparkPartitionId() exists in org.apache.spark.sql.functions)

import org.apache.spark.sql.functions._ 

df.withColumn("partitionId", sparkPartitionId()).groupBy("partitionId").count.show

As @Raphael Roth mentioned, mapPartitionsWithIndex is the best approach; it works with all versions of Spark since it's an RDD-based approach.
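One practical difference between the two approaches is that grouping by the partition id only produces rows for partitions that contain data, so empty partitions are missing from the result, while mapPartitionsWithIndex reports them with a count of 0. If you use the groupBy version, a small driver-side step can fill the gaps. This is a sketch assuming the grouped result has been collected into (partitionId, count) pairs and the total partition count is known (for example from df.rdd.getNumPartitions()); the helper name fill_empty_partitions is hypothetical.

```python
from typing import Dict, Iterable, List, Tuple


def fill_empty_partitions(
    grouped: Iterable[Tuple[int, int]], num_partitions: int
) -> List[Tuple[int, int]]:
    """Return (partition_id, count) for every id in range(num_partitions).

    `grouped` holds the (partitionId, count) pairs from the groupBy approach,
    which has no entry for partitions without rows; those get a count of 0.
    """
    counts: Dict[int, int] = dict(grouped)
    return [(pid, counts.get(pid, 0)) for pid in range(num_partitions)]


if __name__ == "__main__":
    # Pairs as a groupBy on the partition id might return them when partition 1 is empty.
    grouped_rows = [(0, 5), (2, 3)]
    print(fill_empty_partitions(grouped_rows, 3))
    # -> [(0, 5), (1, 0), (2, 3)]
```

In PySpark, collected Row objects are tuples, so the result of grouped_df.collect() can be passed in directly when it has exactly the two columns partitionId and count.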

Praveen Sripati
Ram Ghadiyaram

Spark/Scala:

val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
val l = a.glom().map(_.length).collect()  // get length of each partition
println((l.min, l.max, l.sum / l.length, l.length))  // check if skewed

PySpark:

num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect()  # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l))  # check if skewed

The same works for a DataFrame, not just an RDD: use df.rdd.glom... in the code above.
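Once the per-partition lengths l are on the driver, the min/max/mean printout can be turned into an explicit skew check that is easy to log. This is a minimal sketch: the helper name partition_skew_report, the max-to-mean ratio as a skew indicator, and the default threshold of 2.0 are illustrative assumptions, not anything Spark defines.

```python
from typing import List


def partition_skew_report(lengths: List[int], threshold: float = 2.0) -> dict:
    """Summarize per-partition record counts, e.g. from rdd.glom().map(len).collect().

    The max/mean ratio and the default threshold are illustrative choices.
    """
    if not lengths:
        raise ValueError("no partitions")
    total = sum(lengths)
    mean = total / len(lengths)
    largest = max(lengths)
    # Guard against division by zero when every partition is empty.
    ratio = largest / mean if mean else 0.0
    return {
        "partitions": len(lengths),
        "total": total,
        "min": min(lengths),
        "max": largest,
        "mean": mean,
        "empty": sum(1 for n in lengths if n == 0),
        "max_to_mean": ratio,
        "skewed": ratio > threshold,
    }


if __name__ == "__main__":
    # Four partitions where one holds most of the data.
    report = partition_skew_report([10, 10, 10, 70])
    print(report["max_to_mean"], report["skewed"])  # -> 2.8 True
```

With the evenly split example above (1e6 elements over 20000 partitions), every partition holds 50 records, so the ratio is 1.0 and nothing is flagged.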

Credits: Mike Dusenberry @ https://issues.apache.org/jira/browse/SPARK-17817

Tagar

PySpark:

from pyspark.sql.functions import spark_partition_id

df.select(spark_partition_id().alias("partitionId")).groupBy("partitionId").count()
rwitzel