1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "marimo",
# "matplotlib==3.10.5",
# "numpy==2.3.2",
# "seaborn==0.13.2",
# ]
# ///
import marimo
# marimo version string stored with the notebook (presumably the version that
# last saved this file) — leave it for the marimo editor to maintain.
__generated_with = "0.14.17"
# Notebook application object; each `@app.cell` function below registers a
# cell on it.
app = marimo.App()
@app.cell(hide_code=True)
def _(mo):
    # Title cell: renders the notebook heading as Markdown; its source is
    # hidden in the app view (hide_code=True).
    mo.md(r"""# Modelling SentinelOne's Employee Stock Purchase Plan""")
    return
@app.cell
def _():
    # Shared libraries; the names returned below are injected into the
    # cells that list them as parameters.
    import marimo as mo
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import warnings

    # Global side effects: seaborn plot theme and silencing of UserWarnings.
    sns.set_style("ticks")
    warnings.simplefilter("ignore", UserWarning)

    # Plan parameters.
    PERIOD = 6  # offering period length, in months
    MONTHLY_INVESTMENT = 1300  # monthly contribution, roughly £1,000
    TOTAL_INVESTMENT = PERIOD * MONTHLY_INVESTMENT
    DISCOUNT = 0.15  # discount applied to the share price at purchase
    TAX_RATE = 0.45  # UK additional rate

    # Range of share prices (in $) swept by the model.
    PRICE_MIN, PRICE_MAX = 15, 30
    return (
        DISCOUNT,
        PERIOD,
        PRICE_MAX,
        PRICE_MIN,
        TAX_RATE,
        TOTAL_INVESTMENT,
        mo,
        np,
        plt,
        sns,
    )
@app.cell
def _(DISCOUNT, PRICE_MAX, PRICE_MIN, TAX_RATE, TOTAL_INVESTMENT, np):
    # Model the after-tax ESPP profit over a grid of start/end share prices.
    #
    # Rows of `profit_after_tax` follow `initial_price` (descending, so the
    # highest price is the top row of the heatmap); columns follow
    # `final_price` (ascending). Both use $1 steps from PRICE_MIN to
    # PRICE_MAX inclusive (assumes integer bounds).
    final_price = np.arange(PRICE_MIN, PRICE_MAX + 1, dtype=float)
    # The same grid reversed; copied so the two outputs do not share memory.
    initial_price = final_price[::-1].copy()

    # Shares are bought at the lower of the initial and final price, less the
    # discount. Broadcasting a column against a row yields the full
    # (len(initial_price), len(final_price)) grid without np.tile.
    purchase_price = np.minimum(
        initial_price[:, np.newaxis], final_price[np.newaxis, :]
    ) * (1 - DISCOUNT)
    # Only whole shares are bought; leftover cash is not modelled.
    shares_purchased = np.floor(TOTAL_INVESTMENT / purchase_price)
    # Per-share gain is final price minus purchase price (`final_price`
    # broadcasts across rows); the whole gain is taxed at TAX_RATE.
    profit_after_tax = (
        shares_purchased * (final_price - purchase_price) * (1 - TAX_RATE)
    )
    return final_price, initial_price, profit_after_tax
@app.cell
def _(
    PERIOD,
    TOTAL_INVESTMENT,
    final_price,
    initial_price,
    np,
    plt,
    profit_after_tax,
    sns,
):
    # Heatmap of after-tax profit: rows are the initial price, columns the
    # final price, colour the profit.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        profit_after_tax,
        # Give seaborn ready-formatted labels so each tick is tied to its own
        # price, instead of relabelling afterwards by tick index.
        xticklabels=[f"${price:.0f}" for price in final_price],
        yticklabels=[f"${price:.0f}" for price in initial_price],
        annot=False,
        cmap="viridis",
        # "format" gives the colour bar a "$1,234"-style tick formatter, which
        # avoids set_ticklabels() on a non-fixed locator (a UserWarning in
        # matplotlib) and keeps labels in sync with the ticks.
        cbar_kws={"label": "Profit", "format": "${x:,g}"},
        vmin=0,
        # Round the top of the colour scale up to the next $100.
        vmax=np.ceil(profit_after_tax.max() / 100) * 100,
        ax=ax,
    )
    # Keep both axes' tick labels horizontal.
    ax.tick_params(axis="x", rotation=0)
    ax.tick_params(axis="y", rotation=0)
    ax.set_xlabel("Final price")
    ax.set_ylabel("Initial price")
    ax.set_title(
        "ESPP Profit After Tax\n"
        f"Total Investment: ${TOTAL_INVESTMENT:,g} Over {PERIOD} Months\n"
        f"Minimum Profit: ${np.min(profit_after_tax):,.2f}"
    )
    # Last expression: marimo displays the figure as the cell output.
    fig
    return
# When this file is executed directly as a script, call `app.run()`.
if __name__ == "__main__":
    app.run()
|