Effective .transform in Python

9/1/19

Merging summary statistics back into a table is a common task. The obvious solution is to compute the statistics in a new table and then merge that table back in. With pandas, however, we don’t have to do that: groupby plus .transform does it in one step.

import pandas as pd
import numpy as np

np.random.seed(124)
df = pd.DataFrame({
    'group':np.random.choice(['a', 'b'], 5),
    'x':np.random.randint(100, 200, 5)
}).sort_values(by='group'); df

group x
0 a 141
1 a 164
4 a 178
2 b 120
3 b 128

Here I’ve got some rows with a grouping column called group, and I’d like to calculate the sum of x within each group and add it to my table as a new column.

df['g_sum'] = df.groupby('group')['x'].transform('sum'); df

group x g_sum
0 a 141 483
1 a 164 483
4 a 178 483
2 b 120 248
3 b 128 248
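
For comparison, here is a rough sketch of the two-step route mentioned at the top: summarise into a new table, then merge it back in. The frame is typed out by hand to match the example, and the names summary, merged and transformed are my own, chosen purely for illustration.

```python
import pandas as pd

# Same values and row labels as the example frame, written out literally.
df = pd.DataFrame(
    {"group": ["a", "a", "a", "b", "b"], "x": [141, 164, 178, 120, 128]},
    index=[0, 1, 4, 2, 3],
)

# Two-step route: build a one-row-per-group summary table, then merge it back.
summary = (
    df.groupby("group", as_index=False)["x"]
    .sum()
    .rename(columns={"x": "g_sum"})
)
merged = df.merge(summary, on="group", how="left")

# One-step route: let .transform broadcast the group sum onto every row.
transformed = df.assign(g_sum=df.groupby("group")["x"].transform("sum"))

print(merged)
print(transformed)
```

Both routes give the same g_sum values. One difference worth knowing: merging on a column hands back a fresh 0..n-1 index, whereas .transform keeps the original row labels.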

A key idea is that .transform applies its function argument to each group separately and returns a result with the same length and index as the input, which is why it can be assigned straight back as a new column.
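
To make the size point concrete, here is a small sketch contrasting .agg with .transform on the same grouped column. The frame is again typed out by hand, and per_group and per_row are illustrative names of mine.

```python
import pandas as pd

# Same values and row labels as the example frame, written out literally.
df = pd.DataFrame(
    {"group": ["a", "a", "a", "b", "b"], "x": [141, 164, 178, 120, 128]},
    index=[0, 1, 4, 2, 3],
)

grouped = df.groupby("group")["x"]

# .agg collapses each group to a single value, indexed by the group label.
per_group = grouped.agg("sum")

# .transform repeats each group's value on that group's rows,
# so the result lines up with the original frame.
per_row = grouped.transform("sum")

print(per_group.shape)  # one entry per group
print(per_row.shape)    # one entry per original row
```

Because per_row shares the frame's index, assigning it to a column needs no merge or reindexing.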

We can also use .transform with a user-defined function:

def max_minus_one(col):
    """Return the maximum of col, minus one."""
    return np.max(col) - 1

df['udf'] = df.groupby('group')['x'].transform(max_minus_one); df

group x g_sum udf
0 a 141 483 177
1 a 164 483 177
4 a 178 483 177
2 b 120 248 127
3 b 128 248 127

…and even lambda functions

df['use_lam'] = df.groupby('group')['x'].transform(lambda x: np.max(x) - 1); df

group x g_sum udf use_lam
0 a 141 483 177 177
1 a 164 483 177 177
4 a 178 483 177 177
2 b 120 248 127 127
3 b 128 248 127 127
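
The functions above all return one number per group, which .transform then repeats across that group's rows. The function can also return a result the same length as the group. As a minimal sketch, here is a within-group centering of x; the column name centered is my own and not part of the example above.

```python
import pandas as pd

# Same values and row labels as the example frame, written out literally.
df = pd.DataFrame(
    {"group": ["a", "a", "a", "b", "b"], "x": [141, 164, 178, 120, 128]},
    index=[0, 1, 4, 2, 3],
)

# Each group's Series is passed in as s; subtracting its mean returns
# a Series of the same length, which .transform stitches back together.
df["centered"] = df.groupby("group")["x"].transform(lambda s: s - s.mean())

print(df)
```

Each row now holds its distance from its own group's mean, so the centered values sum to zero within each group.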