I have a dataframe with the following schema:
|-- query: string (nullable = true)
|-- groupx: string (nullable = true)
|-- groupx_score: double (nullable = true)
|-- groupy: string (nullable = true)
|-- groupy_score: double (nullable = true)
|-- groupz: string (nullable = true)
|-- groupz_score: double (nullable = true)
For each row, I want to take the maximum score across all groups (max(groupx_score, ..., groupz_score)) and write the name of the winning group into another column. To find the maximum score per row, I copied the function from the following answer:
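from pyspark.sql import functions as F

def row_max_with_name(*cols):
    # Pair each value with its column name; greatest() compares structs field by field,
    # so the highest value wins and its column name is carried along with it.
    cols_ = [F.struct(F.col(c).alias("value"), F.lit(c).alias("col")) for c in cols]
    return F.greatest(*cols_).alias("greatest({0})".format(",".join(cols)))

df = df.withColumn('max', row_max_with_name("groupx_score", "groupy_score", "groupz_score"))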
The resulting max column is a struct that holds the highest score in each row together with the name of the score column it came from:
+-----+------+------------+------+------------+------+------------+---------------------+
|query|groupx|groupx_score|groupy|groupy_score|groupz|groupz_score|max                  |
+-----+------+------------+------+------------+------+------------+---------------------+
|Q1   |x1    |0.13        |y1    |0.16        |z1    |0.004       |[0.16, groupy_score] |
|Q2   |x2    |0.159       |y2    |0.157       |z2    |0.003       |[0.159, groupx_score]|
|Q3   |x3    |0.01        |y3    |0.155       |z3    |0.002       |[0.155, groupy_score]|
+-----+------+------------+------+------------+------+------------+---------------------+
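Why greatest() works on these structs: Spark orders structs field by field, so value is compared first and the col field only breaks ties. As a rough plain-Python illustration (not Spark code), tuples are ordered the same way, so max over (score, column_name) tuples reproduces the max column for the three sample rows. The rows dictionary and the row_max_with_name_py helper are hypothetical, typed in by hand from the sample data.

```python
# Plain-Python analogue of F.greatest over F.struct(value, col):
# tuples compare element by element, just as Spark compares struct fields.
rows = {
    "Q1": {"groupx_score": 0.13, "groupy_score": 0.16, "groupz_score": 0.004},
    "Q2": {"groupx_score": 0.159, "groupy_score": 0.157, "groupz_score": 0.003},
    "Q3": {"groupx_score": 0.01, "groupy_score": 0.155, "groupz_score": 0.002},
}
score_cols = ["groupx_score", "groupy_score", "groupz_score"]

def row_max_with_name_py(row, cols):
    """Return (max_value, column_name); the name is compared only when values tie."""
    return max((row[c], c) for c in cols)

for query, row in rows.items():
    print(query, row_max_with_name_py(row, score_cols))
```

For Q2 this gives (0.159, 'groupx_score'), since groupx_score (0.159) beats groupy_score (0.157).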
Now I want to create a new column called winner whose value is taken from the group column named in the max struct. For example, in the first row (Q1) the maximum score is 0.16 and it came from groupy, so I want to copy that row's groupy value (y1) into the new column:
+-----+------+------------+------+------------+------+------------+---------------------+------+
|query|groupx|groupx_score|groupy|groupy_score|groupz|groupz_score|max                  |winner|
+-----+------+------------+------+------------+------+------------+---------------------+------+
|Q1   |x1    |0.13        |y1    |0.16        |z1    |0.004       |[0.16, groupy_score] |y1    |
|Q2   |x2    |0.159       |y2    |0.157       |z2    |0.003       |[0.159, groupx_score]|x2    |
|Q3   |x3    |0.01        |y3    |0.155       |z3    |0.002       |[0.155, groupy_score]|y3    |
+-----+------+------------+------+------------+------+------------+---------------------+------+
I was thinking I could do something like this:
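from pyspark.sql import functions as F

# Name of the winning score column, e.g. 'groupy_score'
df = df.withColumn('winner_score', F.col('max.col'))
# Map the score column's prefix back to the matching group column, one branch per group
winner_group = (
    F.when(F.split(F.col('winner_score'), '_').getItem(0) == 'groupx', F.col('groupx'))
    .when(F.split(F.col('winner_score'), '_').getItem(0) == 'groupy', F.col('groupy'))
    .otherwise(F.col('groupz'))
)
df = df.withColumn('winner', winner_group)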
It seems to do the job, but I don't like it: my real data has more than twenty groups, that number may grow to a hundred, and this approach needs one when branch per group. Is there a way to do this without listing every group by hand?
Thank you!
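One possible direction, sketched under assumptions: put each group's value itself into the struct next to its score, so the winner comes straight out of the max with no per-group lookup afterwards. The plain-Python sketch below (not Spark code) shows the idea, assuming every group column g has a matching g + "_score" column; the rows list, the SUFFIX constant and the winner helper are hypothetical names for illustration.

```python
# Sketch: carry each group's value alongside its score, then take the max.
# Assumes the naming convention <group> / <group>_score from the schema.
rows = [
    {"query": "Q1", "groupx": "x1", "groupx_score": 0.13,
     "groupy": "y1", "groupy_score": 0.16,
     "groupz": "z1", "groupz_score": 0.004},
    {"query": "Q2", "groupx": "x2", "groupx_score": 0.159,
     "groupy": "y2", "groupy_score": 0.157,
     "groupz": "z2", "groupz_score": 0.003},
    {"query": "Q3", "groupx": "x3", "groupx_score": 0.01,
     "groupy": "y3", "groupy_score": 0.155,
     "groupz": "z3", "groupz_score": 0.002},
]

# Derive the group names from the column names instead of listing them by hand
# (in Spark the same list could be built from df.columns).
SUFFIX = "_score"
columns = list(rows[0])
groups = [c[: -len(SUFFIX)] for c in columns if c.endswith(SUFFIX)]

def winner(row, groups):
    """Return the value of the group column whose score is highest in this row."""
    # (score, value) tuples compare the score first; the value only breaks ties.
    _, best_value = max((row[g + SUFFIX], row[g]) for g in groups)
    return best_value

for row in rows:
    print(row["query"], winner(row, groups))
```

A PySpark translation of the same idea might look like F.greatest(*[F.struct(F.col(g + "_score").alias("value"), F.col(g).alias("name")) for g in groups])["name"], with groups built from df.columns as above; note that F.greatest needs at least two arguments.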