# In the time dimension
# Weight the sum of the past
# To predict the future
# Build a 0/1 mask: 1 wherever `causal` is truthy, 0 everywhere else.
# The true-branch argument was missing (`jnp.where(causal, ,0)` does not
# parse); `1` is a best-guess fill consistent with the existing `0` branch.
# NOTE(review): assumes `causal` is a boolean array defined elsewhere
# (True where a position may see another) — TODO confirm.
# NOTE(review): if this mask is meant to be *added* to attention logits
# before a softmax, rather than multiplied into weights, the branches
# should presumably be `0` / `-jnp.inf` instead — verify against usage.
mask = jnp.where(causal, 1, 0)
# You've reached the maximum number of attempts for this puzzle
# Incorrect answer
# Try again
# You are the future