ATTgt#
[1]:
import numpy as np
import pandas as pd
# show at most 6 rows when a DataFrame is displayed, so the tables in this notebook stay short
pd.set_option('display.max_rows', 6)
from differences import simulate_data, ATTgt
See plotto to configure the plots: the .plot() method takes the same arguments as the mark_plot() function in plotto.
Binary Treatment#
An entity is either treated or not treated, and once treated it remains treated. The groups are identified by the date of the treatment (cohort).
[2]:
# simulate a panel with the default settings; it is indexed by (entity, time) and
# carries the treatment date of each entity in the 'cohort' column
panel_data = simulate_data() # generate data
# bare expression as the last line of the cell: displays the (truncated) table
panel_data
[2]:
| y | x0 | w | cat.0 | cat.1 | effect | cohort | intensity | ||
|---|---|---|---|---|---|---|---|---|---|
| entity | time | ||||||||
| e0 | 1900 | 4.295747 | 0.160245 | 3.729827 | 0 | 0 | 0.00 | 1904.0 | 5.25 |
| 1901 | -4.799873 | -1.933404 | 1.691201 | 0 | 0 | 0.00 | 1904.0 | 5.25 | |
| 1902 | 11.082130 | -0.056813 | 1.786326 | 0 | 0 | 0.00 | 1904.0 | 5.25 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e999 | 1905 | -7.629120 | -0.258604 | 1.801049 | 0 | 0 | 5.25 | 1904.0 | 5.25 |
| 1906 | 39.241097 | 1.504564 | 0.616657 | 0 | 0 | 10.50 | 1904.0 | 5.25 | |
| 1907 | 23.619567 | 0.472740 | 0.674045 | 0 | 0 | 15.75 | 1904.0 | 5.25 |
8000 rows × 8 columns
[3]:
# set up the estimator on the simulated panel; cohort_name points at the column
# holding the treatment date that identifies each group
att_gt = ATTgt(data=panel_data, cohort_name='cohort')
[4]:
# estimate the group-time ATTs; the result has one row per (cohort, base_period, time)
# NOTE(review): in the formula, y is the outcome and x0 is presumably a covariate — confirm
att_gt.fit(
formula='y ~ x0',
)
Computing ATTgt [workers=1] 100%|████████████████████| 14/14 [00:00<00:00, 59.11it/s]
[4]:
| ATTgtResult | |||||||
|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||||
| ATT | std_error | lower | upper | zero_not_in_cband | |||
| cohort | base_period | time | |||||
| 1903 | 1900 | 1901 | 0.447031 | 1.407877 | -2.312357 | 3.206418 | |
| 1901 | 1902 | 0.196806 | 1.293725 | -2.338849 | 2.732461 | ||
| 1902 | 1903 | 2.613823 | 1.295713 | 0.074272 | 5.153373 | * | |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 1904 | 1903 | 1905 | 3.701955 | 1.352406 | 1.051288 | 6.352621 | * |
| 1906 | 10.637781 | 1.302243 | 8.085431 | 13.190131 | * | ||
| 1907 | 13.856575 | 1.353726 | 11.203320 | 16.509830 | * | ||
14 rows × 5 columns
[5]:
# plot the group-time estimates
# NOTE(review): configure_axisX={'format': 'c'} presumably controls how the year labels
# on the x axis are formatted — confirm against the plotto / Vega-Lite axis options
att_gt.plot(configure_axisX={'format': 'c'})
[5]:
Aggregate#
[6]:
# aggregate the group-time estimates by calendar time ...
att_gt.aggregate('time')
# ... and plot that aggregation (same x-axis formatting option as the group-time plot)
att_gt.plot('time', configure_axisX={'format': 'c'})
[6]:
[7]:
# aggregate the group-time estimates by period relative to treatment (event time) ...
att_gt.aggregate('event')
# ... and plot that aggregation
att_gt.plot('event')
[7]:
[8]:
# aggregate the group-time estimates by treatment cohort ...
att_gt.aggregate('cohort')
# ... and plot that aggregation
# NOTE(review): lines=False presumably suppresses the lines connecting the estimates — confirm
att_gt.plot('cohort', lines=False, configure_axisX={'format': 'c'})
[8]:
[9]:
# collapse the group-time estimates into a single overall ATT (one-row result)
att_gt.aggregate('simple')
[9]:
| SimpleAggregation | |||||
|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||
| ATT | std_error | lower | upper | zero_not_in_cband | |
| 0 | 4.008994 | 0.727817 | 2.582498 | 5.43549 | * |
[10]:
# NOTE(review): this cell repeats the preceding 'simple' aggregation verbatim and yields the
# same table — possibly a leftover; confirm whether a different aggregation was intended
att_gt.aggregate('simple')
[10]:
| SimpleAggregation | |||||
|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||
| ATT | std_error | lower | upper | zero_not_in_cband | |
| 0 | 4.008994 | 0.727817 | 2.582498 | 5.43549 | * |
[11]:
# event aggregation reduced to one overall estimate (overall=True) instead of one per relative period
att_gt.aggregate('event', overall=True)
[11]:
| EventAggregationOverall | |||||
|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||
| ATT | std_error | lower | upper | zero_not_in_cband | |
| 0 | 3.775662 | 0.742154 | 2.321068 | 5.230257 | * |
Heterogeneity & Triple Difference#
[12]:
# re-simulate the panel with samples=3; the fit that follows splits the estimation by the
# resulting 'samples' column (values 0, 1, 2)
panel_data = simulate_data(samples=3)
[13]:
# rebuild the estimator on the re-simulated panel
att_gt = ATTgt(data=panel_data, cohort_name='cohort')
# heterogeneity
# the group-time ATTs are computed separately for every value of the 'samples' column;
# the formula contains only the outcome y
att_gt.fit(
formula='y',
split_sample_by='samples'
)
Computing ATTgt for samples = 0 [workers=1]100%|████████████████████| 21/21 [00:00<00:00, 119.56it/s]
Computing ATTgt for samples = 1 [workers=1]100%|████████████████████| 21/21 [00:00<00:00, 116.07it/s]
Computing ATTgt for samples = 2 [workers=1]100%|████████████████████| 21/21 [00:00<00:00, 115.64it/s]
[13]:
| ATTgtResult | ||||||||
|---|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | |||||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||||
| sample_name | cohort | base_period | time | |||||
| samples = 0 | 1902 | 1900 | 1901 | -0.413163 | 2.285396 | -4.892458 | 4.066131 | |
| 1901 | 1902 | 0.016927 | 2.219508 | -4.333228 | 4.367082 | |||
| 1903 | 8.410651 | 2.174143 | 4.149408 | 12.671893 | * | |||
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| samples = 2 | 1905 | 1904 | 1905 | -2.266016 | 2.272358 | -6.719756 | 2.187723 | |
| 1906 | 4.853377 | 2.291364 | 0.362385 | 9.344369 | * | |||
| 1907 | 12.146894 | 2.203472 | 7.828168 | 16.465620 | * | |||
63 rows × 5 columns
[14]:
# event-time aggregation of the split-sample estimates
att_gt.aggregate('event')
# NOTE(review): vertical_line=0 / horizontal_line=0 presumably draw reference lines at
# relative period 0 and at a zero effect — confirm against the plotto options
att_gt.plot('event', vertical_line=0, horizontal_line=0)
[14]:
[15]:
# one overall ATT per sample (the result is indexed by sample_name)
att_gt.aggregate('simple')
[15]:
| SimpleAggregation | |||||
|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||
| ATT | std_error | lower | upper | zero_not_in_cband | |
| sample_name | |||||
| samples = 0 | 11.659525 | 1.145702 | 9.413990 | 13.905061 | * |
| samples = 1 | 13.036015 | 1.231696 | 10.621935 | 15.450095 | * |
| samples = 2 | 14.491206 | 1.284529 | 11.973576 | 17.008836 | * |
[16]:
# plot the per-sample overall ATTs with a customised y axis
# NOTE(review): y_title is given as a list — presumably rendered as a two-line axis title;
# vertical_line=0 presumably marks a zero effect — confirm against the plotto options
att_gt.plot(
'simple',
y_title=['Sample', 'Names'],
vertical_line=0,
configure_axisY={'domain': True, 'ticks': True},
)
[16]:
Triple Difference#
[17]:
# list the labels of the samples created by the split; these labels are what the
# difference= argument of aggregate() is given in the next step
att_gt.sample_names
[17]:
['samples = 0', 'samples = 1', 'samples = 2']
[18]:
# triple difference: time aggregation of 'samples = 1' minus 'samples = 2'
att_gt.aggregate('time', difference=['samples = 1', 'samples = 2'])
[18]:
| DifferenceTimeAggregation | ||||||
|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | |||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||
| difference_between | time | |||||
| samples = 1 - samples = 2 | 1902 | 0.478123 | 1.541316 | -2.542801 | 3.499047 | |
| 1903 | 5.549908 | 1.474024 | 2.660874 | 8.438943 | * | |
| 1904 | -5.869359 | 1.375390 | -8.565075 | -3.173643 | * | |
| 1905 | 1.940128 | 1.286668 | -0.581695 | 4.461951 | ||
| 1906 | -1.888593 | 1.396793 | -4.626258 | 0.849071 | ||
| 1907 | -5.348880 | 1.394556 | -8.082160 | -2.615600 | * | |
[19]:
# plot the difference version of the time aggregation
# NOTE(review): presumably this reuses the difference computed by the preceding aggregate() call — confirm
att_gt.plot('time', difference=True)
[19]:
Multi-valued Treatment#
An entity can be treated with different intensities or not treated at all. The groups are identified by the date of the treatment (cohort) and by the intensity, which represents a second dimension.
[20]:
# simulate a panel for the multi-valued treatment case; compared with the binary case the
# data gains a 'strata' column
# NOTE(review): intensity_by=2 presumably asks for two intensity strata — confirm
panel_data = simulate_data(intensity_by=2) # generate data
# display the (truncated) table
panel_data
[20]:
| y | x0 | w | cat.0 | cat.1 | effect | cohort | strata | intensity | ||
|---|---|---|---|---|---|---|---|---|---|---|
| entity | time | |||||||||
| e0 | 1900 | 14.320181 | 0.160245 | 3.729827 | 0 | 0 | 0.0 | 1904.0 | 0.0 | 10.0 |
| 1901 | 1.118288 | -1.933404 | 1.691201 | 0 | 0 | 0.0 | 1904.0 | 0.0 | 10.0 | |
| 1902 | -0.252880 | -0.056813 | 1.786326 | 0 | 0 | 0.0 | 1904.0 | 0.0 | 10.0 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e999 | 1905 | 16.700577 | -0.258604 | 1.801049 | 0 | 0 | 0.5 | 1904.0 | 1.0 | 0.5 |
| 1906 | 15.151684 | 1.504564 | 0.616657 | 0 | 0 | 1.0 | 1904.0 | 1.0 | 0.5 | |
| 1907 | 0.163186 | 0.472740 | 0.674045 | 0 | 0 | 1.5 | 1904.0 | 1.0 | 0.5 |
8000 rows × 9 columns
[21]:
# in addition to the cohort column, tell the estimator which column holds the stratum
# (the second dimension that identifies a group)
att_gt = ATTgt(data=panel_data, cohort_name='cohort', strata_name='strata')
[22]:
# estimate the group-time ATTs per stratum with a single worker (n_jobs=1); outcome-only formula
att_gt.fit(formula='y', n_jobs=1)
Computing ATTgt [workers=1] 100%|████████████████████| 42/42 [00:00<00:00, 44.39it/s]
[22]:
| ATTgtResult | ||||||||
|---|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | |||||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||||
| stratum | cohort | base_period | time | |||||
| 0 | 1901 | 1900 | 1901 | -1.166568 | 1.488863 | -4.084686 | 1.751550 | |
| 1902 | 9.743155 | 1.628144 | 6.552052 | 12.934258 | * | |||
| 1903 | 19.473230 | 1.626157 | 16.286020 | 22.660440 | * | |||
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1 | 1905 | 1904 | 1905 | 1.454368 | 1.619994 | -1.720762 | 4.629498 | |
| 1906 | 0.030797 | 1.537388 | -2.982428 | 3.044021 | ||||
| 1907 | 1.231228 | 1.572191 | -1.850210 | 4.312665 | ||||
42 rows × 5 columns
Aggregate#
[23]:
# event-time aggregation; the result is indexed by (stratum, relative_period)
att_gt.aggregate('event')
[23]:
| EventAggregation | ||||||
|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | |||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||
| stratum | relative_period | |||||
| 0 | -4 | -0.528218 | 1.521623 | -3.510545 | 2.454109 | |
| -3 | 1.539960 | 1.009962 | -0.439529 | 3.519448 | ||
| -2 | -0.138531 | 0.989853 | -2.078607 | 1.801545 | ||
| ... | ... | ... | ... | ... | ... | ... |
| 1 | 4 | 3.188915 | 1.576782 | 0.098478 | 6.279352 | * |
| 5 | 2.272094 | 1.596461 | -0.856913 | 5.401100 | ||
| 6 | 3.896055 | 1.626878 | 0.707432 | 7.084679 | * | |
22 rows × 5 columns
[24]:
# plot the event-time aggregation of the multi-valued treatment estimates
att_gt.plot('event')
[24]:
[25]:
# difference between stratum 0 and stratum 1 by relative period, with 5000 bootstrap
# iterations; the result reports bootstrap standard errors and a simultaneous confidence band
att_gt.aggregate('event', difference=[0, 1], boot_iterations=5000)
Bootstrap 100%|████████████████████| [00:00<00:00, 10926.23it/s]
[25]:
| DifferenceEventAggregation | ||||||
|---|---|---|---|---|---|---|
| bootstrap | simult. conf. band | |||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||
| difference_between | relative_period | |||||
| 0 - 1 | -4 | -1.121424 | 1.702570 | -5.902428 | 3.659580 | |
| -3 | 1.860588 | 1.272509 | -1.712757 | 5.433934 | ||
| -2 | -0.597033 | 1.240433 | -4.080306 | 2.886239 | ||
| ... | ... | ... | ... | ... | ... | |
| 4 | 37.084082 | 1.716253 | 32.264655 | 41.903510 | * | |
| 5 | 47.373146 | 1.620863 | 42.821584 | 51.924708 | * | |
| 6 | 55.768956 | 1.691237 | 51.019777 | 60.518135 | * | |
11 rows × 5 columns
[26]:
# plot the difference version of the event aggregation
# NOTE(review): presumably this reuses the difference computed by the preceding aggregate() call — confirm
att_gt.plot('event', difference=True)
[26]:
[27]:
# overall ATT aggregation for the multi-valued treatment estimates ...
att_gt.aggregate('simple')
# ... and its plot
att_gt.plot('simple')
[27]:
Heterogeneity & Triple Difference#
[28]:
# simulate a panel that has both dimensions at once: the data carries a 'samples' column
# and a 'strata' column
panel_data = simulate_data(intensity_by=2, samples=3) # generate data
# display the (truncated) table
panel_data
[28]:
| y | x0 | w | cat.0 | cat.1 | effect | cohort | samples | strata | intensity | ||
|---|---|---|---|---|---|---|---|---|---|---|---|
| entity | time | ||||||||||
| e0 | 1900 | 2.347452 | 0.160245 | 3.729827 | 0 | 0 | 0.0 | 1906.0 | 2 | 1.0 | 4.3 |
| 1901 | -2.880331 | -1.933404 | 1.691201 | 0 | 0 | 0.0 | 1906.0 | 2 | 1.0 | 4.3 | |
| 1902 | -6.399243 | -0.056813 | 1.786326 | 0 | 0 | 0.0 | 1906.0 | 2 | 1.0 | 4.3 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e999 | 1905 | -8.514995 | -0.258604 | 1.801049 | 0 | 0 | 0.0 | 1906.0 | 2 | 1.0 | 4.3 |
| 1906 | 3.713039 | 1.504564 | 0.616657 | 0 | 0 | 0.0 | 1906.0 | 2 | 1.0 | 4.3 | |
| 1907 | 26.942832 | 0.472740 | 0.674045 | 0 | 0 | 4.3 | 1906.0 | 2 | 1.0 | 4.3 |
8000 rows × 10 columns
[29]:
# estimator on the panel with strata and samples; groups are identified by cohort and stratum
att_gt = ATTgt(data=panel_data, cohort_name='cohort', strata_name='strata')
[30]:
# estimate per stratum, separately for every value of the 'samples' column
att_gt.fit(formula='y', split_sample_by='samples')
Computing ATTgt for samples = 2 [workers=1]100%|████████████████████| 42/42 [00:00<00:00, 92.14it/s]
Computing ATTgt for samples = 0 [workers=1]100%|████████████████████| 42/42 [00:00<00:00, 95.82it/s]
Computing ATTgt for samples = 1 [workers=1]100%|████████████████████| 42/42 [00:00<00:00, 95.95it/s]
[30]:
| ATTgtResult | |||||||||
|---|---|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||||||
| ATT | std_error | lower | upper | zero_not_in_cband | |||||
| sample_name | stratum | cohort | base_period | time | |||||
| samples = 2 | 0 | 1901 | 1900 | 1901 | -2.313173 | 2.816396 | -7.833208 | 3.206863 | |
| 1902 | 3.324637 | 2.912590 | -2.383933 | 9.033208 | |||||
| 1903 | 11.378922 | 2.777983 | 5.934176 | 16.823669 | * | ||||
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| samples = 1 | 1 | 1906 | 1904 | 1905 | 2.983248 | 2.933005 | -2.765336 | 8.731831 | |
| 1905 | 1906 | -0.334297 | 2.554715 | -5.341447 | 4.672852 | ||||
| 1907 | 0.343971 | 3.023992 | -5.582944 | 6.270887 | |||||
126 rows × 5 columns
[31]:
# calendar-time aggregation of the stratum-by-sample estimates
att_gt.aggregate('time')
# plot it with larger header fonts
# NOTE(review): estimation_details=False presumably hides the estimation summary from the
# plot; 'format':'c' presumably controls the year labels on the x axis — confirm
att_gt.plot(
'time',
configure_header={
'titleFontSize': 18,
'labelFontSize': 15
},
estimation_details=False,
configure_axisX={'format':'c'},
)
[31]:
[32]:
# difference between stratum 0 and stratum 1 by calendar time, reported within each sample
att_gt.aggregate('time', difference=[0, 1])
[32]:
| DifferenceTimeAggregation | |||||||
|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||||
| ATT | std_error | lower | upper | zero_not_in_cband | |||
| sample_name | difference_between | time | |||||
| samples = 2 | 0 - 1 | 1901 | -0.366749 | 2.745978 | -5.748767 | 5.015269 | |
| 1902 | 0.689242 | 3.101619 | -5.389819 | 6.768304 | |||
| 1903 | 1.007773 | 2.002671 | -2.917389 | 4.932936 | |||
| ... | ... | ... | ... | ... | ... | ... | ... |
| samples = 1 | 0 - 1 | 1905 | 5.320888 | 2.212145 | 0.985162 | 9.656613 | * |
| 1906 | 4.642867 | 1.751682 | 1.209633 | 8.076101 | * | ||
| 1907 | 6.223082 | 1.908472 | 2.482545 | 9.963619 | * | ||
21 rows × 5 columns
[33]:
# list the sample labels of this fit; a between-sample difference is requested with these labels
att_gt.sample_names
[33]:
['samples = 2', 'samples = 0', 'samples = 1']
[34]:
# difference 'samples = 2' minus 'samples = 0' by calendar time, reported within each stratum
att_gt.aggregate('time', difference=['samples = 2', 'samples = 0'])
[34]:
| DifferenceTimeAggregation | |||||||
|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | ||||||
| ATT | std_error | lower | upper | zero_not_in_cband | |||
| stratum | difference_between | time | |||||
| 0 | samples = 2 - samples = 0 | 1901 | 1.470721 | 2.102341 | -2.649793 | 5.591234 | |
| 1902 | -10.006937 | 2.011322 | -13.949056 | -6.064819 | * | ||
| 1903 | -6.449823 | 1.612146 | -9.609571 | -3.290075 | * | ||
| ... | ... | ... | ... | ... | ... | ... | ... |
| 1 | samples = 2 - samples = 0 | 1905 | -8.045895 | 1.515554 | -11.016327 | -5.075464 | * |
| 1906 | -4.207633 | 1.397745 | -6.947163 | -1.468103 | * | ||
| 1907 | -4.728102 | 1.423883 | -7.518862 | -1.937342 | * | ||
14 rows × 5 columns
Custom Estimation & DoubleML#
[35]:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
# Two shallow random forests of the same size serve as the nuisance learners.

# regressor used as the outcome model
outcome_model = RandomForestRegressor(max_depth=2, n_estimators=100)

# classifier used as the propensity score model
pscore_model = RandomForestClassifier(max_depth=2, n_estimators=100)
[36]:
from functools import partial
from differences.did.double_ml import aiptw_double_ml_did_panel
# Bind the two learners (the models used for cross-fitting) and the outcome type to the
# DoubleML DiD estimator with functools.partial; the resulting callable is stored as `aiptw`.
aiptw = partial(
    aiptw_double_ml_did_panel,
    outcome_type='continuous',
    outcome_model=outcome_model,
    pscore_model=pscore_model,
)
[37]:
# now, instead of selecting from the menu of options available for the est_method,
# we plug in the function we composed above
# NOTE(review): att_gt is still the estimator built with strata_name='strata'; no
# split_sample_by is passed here, so the result is indexed by stratum only
# NOTE(review): n_jobs=-1 presumably means "use all available cores" (the recorded run
# reports workers=10) — confirm
att_gt.fit(
formula='y ~ x0',
n_jobs=-1,
est_method=aiptw, # <--- plug in here
)
Computing ATTgt [workers=10] 100%|████████████████████| 42/42 [00:27<00:00, 1.53it/s]
[37]:
| ATTgtResult | ||||||||
|---|---|---|---|---|---|---|---|---|
| analytic | pointwise conf. band | |||||||
| ATT | std_error | lower | upper | zero_not_in_cband | ||||
| stratum | cohort | base_period | time | |||||
| 0 | 1901 | 1900 | 1901 | -2.809343 | 1.806964 | -6.350927 | 0.732242 | |
| 1902 | 5.300236 | 1.724092 | 1.921077 | 8.679394 | * | |||
| 1903 | 10.981719 | 1.759577 | 7.533012 | 14.430427 | * | |||
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1 | 1906 | 1904 | 1905 | -2.160258 | 1.731272 | -5.553488 | 1.232972 | |
| 1905 | 1906 | 0.015155 | 1.767796 | -3.449661 | 3.479971 | |||
| 1907 | 4.812429 | 1.722523 | 1.436346 | 8.188512 | * | |||
42 rows × 5 columns