aboutsummaryrefslogtreecommitdiff
path: root/notebooks/2024-08-31-espp-model.py
blob: 50c340d3687a39e7ec6fbc4f9ae49de80ede92e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "marimo",
#     "matplotlib==3.10.5",
#     "numpy==2.3.2",
#     "seaborn==0.13.2",
# ]
# ///

import marimo

# marimo version recorded when this notebook file was generated.
__generated_with = "0.14.17"
# The notebook object; every @app.cell function below registers a cell on it.
app = marimo.App()


# Title cell: renders the notebook heading as Markdown, with its code hidden.
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Modelling SentinelOne's Employee Stock Purchase Plan""")
    return


@app.cell
def _():
    # Shared libraries and model parameters used by the other cells.
    import marimo as mo
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    import warnings

    sns.set_style("ticks")

    # Process-wide filter: hide every UserWarning from here on.
    warnings.simplefilter("ignore", UserWarning)

    # --- Plan parameters ---
    DISCOUNT = 0.15  # Discount on stock price (as a fraction)
    PERIOD = 6  # Length of the investment period, in months
    MONTHLY_INVESTMENT = 1300  # Contribution per month (about £1,000)
    TOTAL_INVESTMENT = PERIOD * MONTHLY_INVESTMENT
    TAX_RATE = 0.45  # UK additional rate

    # --- Share-price range explored by the model (inclusive bounds) ---
    PRICE_MIN, PRICE_MAX = 15, 30
    return (
        DISCOUNT,
        PERIOD,
        PRICE_MAX,
        PRICE_MIN,
        TAX_RATE,
        TOTAL_INVESTMENT,
        mo,
        np,
        plt,
        sns,
    )


@app.cell
def _(DISCOUNT, PRICE_MAX, PRICE_MIN, TAX_RATE, TOTAL_INVESTMENT, np):
    # Whole-dollar price grids. np.arange excludes its stop value, so the
    # -0.1 / +0.1 nudges make PRICE_MIN and PRICE_MAX inclusive. Initial
    # prices descend (PRICE_MAX first) and index the rows; final prices
    # ascend and index the columns of the result.
    initial_price = np.arange(PRICE_MAX, PRICE_MIN - 0.1, -1)
    final_price = np.arange(PRICE_MIN, PRICE_MAX + 0.1, 1)

    # Broadcasting a column of initial prices against a row of final prices
    # yields one cell per (initial, final) pair, with shape
    # (len(initial_price), len(final_price)), without tiled copies.
    # The discount is applied to the lower of the two prices (lookback).
    purchase_price = np.minimum(
        initial_price[:, np.newaxis], final_price[np.newaxis, :]
    ) * (1 - DISCOUNT)

    # Only whole shares are bought with the total contribution.
    shares_purchased = np.floor(TOTAL_INVESTMENT / purchase_price)

    # Per-share gain is the final price minus the discounted price paid;
    # the whole gain is taxed at TAX_RATE.
    profit_after_tax = (
        shares_purchased
        * (final_price[np.newaxis, :] - purchase_price)
        * (1 - TAX_RATE)
    )
    return final_price, initial_price, profit_after_tax


@app.cell
def _(
    PERIOD,
    TOTAL_INVESTMENT,
    final_price,
    initial_price,
    np,
    plt,
    profit_after_tax,
    sns,
):
    # Heatmap of after-tax profit: rows follow initial_price, columns follow
    # final_price, matching the layout of profit_after_tax.
    fig, ax = plt.subplots(figsize=(10, 8))

    sns.heatmap(
        profit_after_tax,
        # Pre-formatted "$NN" labels; a list-like makes seaborn label every
        # row/column, so no post-hoc relabelling by tick index is needed.
        xticklabels=[f"${price:.0f}" for price in final_price],
        yticklabels=[f"${price:.0f}" for price in initial_price],
        annot=False,
        cmap="viridis",
        cbar_kws={"label": "Profit"},
        vmin=0,
        # Round the top of the colour scale up to the next $100.
        vmax=np.ceil(profit_after_tax.max() / 100) * 100,
        ax=ax,
    )

    # Format colorbar ticks as dollars with a formatter. set_ticklabels()
    # would pin labels to the current tick positions and makes matplotlib
    # emit a UserWarning.
    cbar = ax.collections[0].colorbar
    cbar.ax.yaxis.set_major_formatter("${x:,g}")

    # Keep tick labels on both axes horizontal.
    ax.tick_params(axis="x", rotation=0)
    ax.tick_params(axis="y", rotation=0)

    # Axis labels and a title summarising the scenario.
    ax.set_xlabel("Final price")
    ax.set_ylabel("Initial price")
    ax.set_title(
        f"ESPP Profit After Tax\n"
        f"Total Investment: ${TOTAL_INVESTMENT:,g} Over {PERIOD} Months\n"
        f"Minimum Profit: ${np.min(profit_after_tax):,.2f}"
    )

    # Last expression: marimo renders the figure as the cell output.
    fig
    return


# Entry point when this file is executed directly rather than imported.
if __name__ == "__main__":
    app.run()