I need to collect partitions/batches from a big PySpark DataFrame so that I can feed them into a neural network iteratively.

My idea was to 1) partition the data, 2) iteratively collect each partition, and 3) transform each collected partition with toPandas().

I am a bit confused by methods like foreachPartition and mapPartitions because I can't iterate over them. Any ideas?
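
To make step 1 of the plan concrete, here is a minimal sketch; the SparkSession and the DataFrame contents are assumptions for illustration, not part of the original question:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1_000_000)  # hypothetical large DataFrame

# 1) partition the data into a fixed number of batches
df = df.repartition(100)

# 2) and 3) -- collecting one partition at a time is the open question;
# a plain df.toPandas() would pull everything to the driver at once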

cadama

1 Answer

You can use mapPartitions to map each partition into a list of elements, then pull those lists to the driver one at a time with toLocalIterator (here rdd is the DataFrame's underlying RDD, i.e. df.rdd):

# one element per partition: the partition's rows gathered into a single list
for partition in rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
    print(len(partition))  # or do something else :-)
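
Because toLocalIterator() fetches one partition at a time, only a single partition needs to fit in driver memory. Each partition arrives as a plain Python list of Row objects, which can be converted to pandas for the training loop. A sketch of that batch loop, where df and train_on_batch are hypothetical stand-ins, not part of the original answer:

import pandas as pd

batches = df.rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator()
for partition in batches:
    if not partition:
        continue  # repartitioning can leave some partitions empty
    pdf = pd.DataFrame([row.asDict() for row in partition])  # Row objects -> pandas
    train_on_batch(pdf)  # hypothetical: feed one batch to the network
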
Mariusz