On Coding Your First Attention
•
7
# Install the mlx_micrograd package, which provides the Value class imported below.
# `python -m pip` installs into the same interpreter that will run the script,
# unlike a bare `pip`, which may belong to a different Python on PATH.
python -m pip install mlx_micrograd
from mlx_micrograd.engine import Value
# Demo: build a scalar expression graph from two leaf Values, evaluate it
# (forward pass), then call backward() to obtain dg/da and dg/db.
# Python scalars on the left of an operator (1 + c, 3 * d, 10.0 / f) rely on
# Value's reflected operators (__radd__, __rmul__, __rtruediv__), since the
# built-in int/float operators do not know about Value.
# Leaf inputs; a and b are never reassigned below.
a = Value(-4.0)
b = Value(2.0)
c = a + b
d = a * b + b**3
# Each augmented assignment below builds on the current c / d, so the
# statement order determines the final graph (and the printed numbers).
c += c + 1
c += 1 + c + (-a)
# With a = -4 and b = 2, (b + a) evaluates to -2 and (b - a) to 6 before relu().
d += d * 2 + (b + a).relu()
d += 3 * d + (b - a).relu()
e = c - d
f = e**2
g = f / 2.0
g += 10.0 / f
# .data appears to hold an MLX array rather than a Python float, hence the
# array(..., dtype=float32) form in the expected output.
print(f'{g.data}') # prints array(24.7041, dtype=float32), the outcome of this forward pass
# Backpropagate from g; afterwards a.grad and b.grad hold dg/da and dg/db.
g.backward()
print(f'{a.grad}') # prints array(138.834, dtype=float32), i.e. the numerical value of dg/da
print(f'{b.grad}') # prints array(645.577, dtype=float32), i.e. the numerical value of dg/db