
2. Neural Networks

In part two of the course notes we change both (1) the goal, from learning price/sentiment to learning sequences, and (2) the function class, from generalized linear models to deep neural networks with torch, in order to study the foundations of machine learning's "age of research" from 2012 to 2020. We build off of teenygrad's numpy-like capability developed in part one and abstract two more tasks away from the research scientist:

  1. optim.sgd and Tensor.backward(), providing iterative optimization via differentiation
  2. cuBLAS-like kernels, providing acceleration on manycore processors such as GPUs. We will then use these to train other language models with different inductive biases (invariances) such as RNNs, LSTMs, BERTs, and GPTs, culminating in nanogpt. (A sketch of both abstractions through torch's API follows this list.)
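
Before re-implementing these two abstractions ourselves, here is what they look like through torch's API, which sections 2.2 and 2.3 re-implement. This is a minimal runnable sketch on a toy least-squares problem; the model and data are placeholders, not anything from the course.

# the two abstractions through torch's API (minimal sketch of what 2.2 and 2.3 re-implement)
import torch

w = torch.randn(3, requires_grad=True)          # parameters
opt = torch.optim.SGD([w], lr=0.1)              # 1. iterative optimization via differentiation
x, y = torch.randn(8, 3), torch.randn(8)        # toy data
for _ in range(10):
    loss = ((x @ w - y) ** 2).mean()            # forward pass; 2. device kernels (e.g. cuBLAS) run the matmul
    opt.zero_grad()
    loss.backward()                             # reverse-mode autodiff populates w.grad
    opt.step()                                  # w <- w - lr * w.grad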

This will prepare us for part three of the course notes, where we modify the language implementation of the deep learning framework to support distributed compilation in order to run both the training and inference of nanochat.

Contents

2.1 Learning Representations, Learning Sequences

2.1.1 XOR Learning with Feedforward Neural Network

2.1.2 Sentiment Learning with FNNs

In part 1 of the course notes we trained generalized linear models of the form $\hat{y} = \sigma(w^\top x + b)$, a single linear map composed with a link function. In part 2 we modify the function class and increase its expressivity by including non-linearities. The feedforward neural network, simply put, is a series of linear and non-linear layers of the form

$$f(x) = W_L\,\phi(\cdots\,\phi(W_2\,\phi(W_1 x + b_1) + b_2)\,\cdots) + b_L$$

where the $W_i, b_i$ are the weights and biases of the $i$-th linear layer, $\phi$ is an elementwise nonlinearity such as tanh, and $f: \mathbb{R}^n \to \mathbb{R}^m$. Conceptually, the linear layers are performing linear transformations that rotate, reflect, shear, and scale space, whereas the nonlinear transformations squash and twist space.
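
As a concrete instance, here is a two-layer network written directly in torch; the dimensions and the tanh nonlinearity are arbitrary choices for illustration.

# a two-layer feedforward network f: R^2 -> R (minimal sketch; dimensions are arbitrary)
import torch

W1, b1 = torch.randn(4, 2), torch.randn(4)     # linear layer: rotate/reflect/shear/scale
W2, b2 = torch.randn(1, 4), torch.randn(1)

def f(x):
  h = torch.tanh(x @ W1.T + b1)                # nonlinearity: squash/twist
  return h @ W2.T + b2

print(f(torch.randn(5, 2)).shape)              # torch.Size([5, 1])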

We will now use the same model of the feedforward neural network to accomplish two other goals: namely, representation learning and sequence learning.

2.1.3 Representation Learning with FNNs

2.1.4 Language Modeling with FNNs

A sequence model, simply put, is the conditional probability distribution $p(x_t \mid x_{<t})$ of an output token given the input tokens that precede it. A sequence of tokens can be a sentence of words in the domain of language, a series of pixels in the domain of vision, or a stream of waveform samples in the domain of audio.

Since we are modeling language as a stochastic phenomenon, we use the formal language of probability theory, where a probability space $(\Omega, \mathcal{F}, P)$ is a measurable space $(\Omega, \mathcal{F})$ equipped with a measure $P$. In the domain of language, the sample space $\Omega$ is the set of all tokens, modelling a vocabulary, and the event space $\mathcal{F}$ is the set of all token combinations, modelling a language. The measure $P$ assigns a weight to a particular token combination (a sentence, really) as an event, relative to the set of all possible token combinations (sentences) that make up the entire event space. Once we use a random variable to map events to $\mathbb{R}$, we can forget about the probability space and focus our attention on language models, which are joint probability distributions $p(x_1, \dots, x_T)$ over all sequences of tokens.
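
Concretely, writing the tokens of a sequence as $x_1, \dots, x_T$, a language model factorizes this joint distribution autoregressively with the chain rule of probability:

$$p(x_1, \dots, x_T) \;=\; \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1})$$

so that modeling a language reduces to modeling the next-token conditionals that the rest of this chapter is concerned with.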

Language modeling with n-grams.
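
An n-gram model truncates the conditioning context to the previous n-1 tokens and estimates each conditional by counting. The sketch below is a character-level bigram model over the same names.txt file used later in this section; it is illustrative, not the model we train.

# minimal character-bigram language model by counting (illustrative sketch)
import collections

words = open('./data/names.txt', 'r').read().splitlines()
counts = collections.Counter()
for w in words:
  chars = ['.'] + list(w) + ['.']                   # '.' marks the start/end of a name
  for c1, c2 in zip(chars, chars[1:]):
    counts[(c1, c2)] += 1

def p(c2, c1):                                      # P(next=c2 | current=c1)
  total = sum(n for (a, _), n in counts.items() if a == c1)
  return counts[(c1, c2)] / total if total else 0.0

print(p('n', 'a'))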

1. The model. The feedforward language model of Bengio et al. (2003) replaces the count table with a learned function: it embeds each of the previous t tokens into an e-dimensional vector with a shared embedding table, concatenates the t embeddings, and passes the result through a tanh MLP that outputs logits over the v-token vocabulary.

# FFN MODEL f: R^(t*e) -> R^v
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
  """
  model: Neural Language Models (Bengio et al. 2003)
  key:
  b: batch size, t: sequence length
  v: vocabulary size, e: dimension of embedding, d: dimension of model
  """

  def __init__(self, cfg):
    super().__init__()
    self.t, self.v, self.e = cfg.t, cfg.v, cfg.e
    self.wte = nn.Embedding(cfg.v, cfg.e)            # token embeddings table
    self.l1 = nn.Linear(cfg.t * cfg.e, cfg.d, bias=False)
    self.l2 = nn.Linear(cfg.d, cfg.d, bias=False)
    self.l3 = nn.Linear(cfg.d, cfg.v, bias=False)

  def forward(self, i, targets=None):
    # i holds the indices of the previous t tokens, shape (b, t)
    x = self.wte(i)                       # token embeddings of shape (b, t, e)
    x = x.view(-1, self.t * self.e)       # concat the t embeddings together: (b, t*e)
    x = self.l1(x).tanh()                 # and pass them through an MLP
    x = self.l2(x).tanh()
    x = self.l3(x)                        # logits over the vocabulary: (b, v)
    yhat = x

    # if we are given some desired targets also calculate the loss
    loss = None
    if targets is not None: loss = F.cross_entropy(yhat, targets, ignore_index=-1)
    return yhat, loss

2. The dataset. names.txt is a list of names, one per line, tokenized at the character level (26 letters plus '.' as a start/end marker, so v = 27). Each training example pairs a sliding context window of the previous t characters (padded with index 0 at the start of a name) with the next character as the target.

# FFN DATA d={(x^i,y^i)}
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def build_dataset(t):
  import random

  words = open('./data/names.txt', 'r').read().splitlines()
  v = sorted(list(set(''.join(words))))
  encode = { c:i+1 for i,c in enumerate(v) }
  encode['.'] = 0
  decode = { i:c for c,i in encode.items() }

  def gen_dataset(words, t):
    X, Y = [], []
    for w in words:
      context = [0] * t
      for c in w + '.':
        X.append(context)
        Y.append(encode[c])
        # print(''.join(decode[i] for i in context), '-->', decode[encode[c]])
        context = context[1:] + [encode[c]]
    X, Y = torch.tensor(X), torch.tensor(Y) # X:(N,C) Y:(N)
    return X, Y

  random.seed(42)
  random.shuffle(words)
  n1, n2 = int(0.8*len(words)), int(0.9*len(words))
  Xtraining, Ytraining = gen_dataset(words[:n1], t)
  Xdev, Ydev = gen_dataset(words[n1:n2], t)
  Xte, Yte = gen_dataset(words[n2:], t)
  return Xtraining, Ytraining

3. The training loop. Each step samples a minibatch of b examples, runs the forward pass to obtain logits and the cross-entropy loss, zeroes the parameter gradients, calls backward() to repopulate them, and updates every parameter with the gradient-descent rule $\theta^{(t+1)} := \theta^{(t)} - \alpha\,\nabla_\theta \mathcal{L}$.

# FFN TRAINING LOOP: theta^(t+1) := theta^t - alpha*grad(L)
if __name__ == "__main__":
  from types import SimpleNamespace
  cfg = SimpleNamespace(b=32, t=3, v=27, e=10, d=200)        # init hyperparameters
  X, Y = build_dataset(cfg.t)                                # init data
  model = MLP(cfg)                                           # init model (embedding table included as model.wte)
  params = list(model.parameters())

  N, losses, steps = X.shape[0], [], []                      # train
  for step in range(200000):
    i_b = torch.randint(0, N, (cfg.b,))                      # 0. sample a minibatch
    X_b, Y_b = X[i_b], Y[i_b]
    yhat, loss = model(X_b, Y_b)                             # 1. forward

    for p in params: p.grad = None
    loss.backward()                                          # 2. backward

    for p in params: p.data += -0.01 * p.grad                # 3. update (optimizer.step() done by hand)

    steps.append(step)
    losses.append(loss.log10().item())
    if step % 10000 == 0: print(f"step: {step}/200000, loss {loss.item()}")

  plt.plot(steps, losses)
  plt.show()

In the next two sections, 2.2 and 2.3, we will implement on picograd the two primary features that pytorch abstracts away from research scientists: the backward pass with automatic differentiation, and the device acceleration of the specified forward pass.

2.2 Programming Automatic Differentiation and Gradient Descent

So far we've been training feedforward neural networks using the magical Tensor.backward() function in order to materialize the gradient of the loss with respect to our parameter weights, which gradient descent uses in its update rule $\theta^{(t+1)} := \theta^{(t)} - \alpha\,\nabla_\theta \mathcal{L}$. We will now dive deeper into how .backward() is implemented.

2.2.1 Automatic Differentiation via Tensor.backward()

Consider the function $f(x_1, x_2) = e^{x_1}\sin^2(x_2)$, where $f: \mathbb{R}^2 \to \mathbb{R}$, and translate it to its computational counterpart in python with one-dimensional `Tensor`s:
import picograd as pg

def f(x1: pg.Tensor, x2: pg.Tensor) -> pg.Tensor:
  a = pg.exp(x1)
  b = pg.sin(x2)
  c = b**2
  d = a*c
  return d

Figure 1. Python source for the function $f(x_1, x_2) = e^{x_1}\sin^2(x_2)$, where $f: \mathbb{R}^2 \to \mathbb{R}$.

Here we've broken up the function to render the subexpressions more clearly, but this isn't necessary: automatic differentiation will work even if the function is expressed in one line. In part one, the development of picograd followed that of numpy, an array programming language similar to Matlab but embedded in the host language of Python, which could evaluate functions of the form $y = f(x)$, where Tensor objects stored their values in the value: field and the operations that produced those values with Op. For instance, evaluating the specified function f from above with 9 and 10

if __name__ == "__main__":
  print(f(pg.Tensor(9.0), pg.Tensor(10.0)))

populates the Tensor.value fields. In part one of the book we verified this with a REPL interface, but we can also represent the entire expression being evaluated as a graph of vertices and edges, where the vertices are Tensors (along with their Ops and values) and the edges are their data dependencies:

digraph G {
  bgcolor="#181818"; fontname="Helvetica,Arial,sans-serif"
  node [
    fontname="Helvetica,Arial,sans-serif" fontcolor = "#e6e6e6", color = "#e6e6e6", fillcolor = "#333333"
    style = filled,
  ]
  edge [
    fontname="Helvetica,Arial,sans-serif"
    color = "#e6e6e6",
    fontcolor = "#e6e6e6"
  ]

  graph [rankdir="LR"]
  
  x1 [label="9"];
  x2 [label="10"];

  expx1 [label="a: ℝ -> ℝ\la(x1) := exp(x1)\lval:8103.08, grad:"];
  sinx2 [label="b: ℝ -> ℝ\lb(x2) := sin(x2)\lval:-0.54, grad:"];
  sqrsin [label="c: ℝ -> ℝ\lc(b) := b^2\lval:0.30, grad:"];

  mulxy [label="d: ℝ, ℝ -> ℝ\ld(a,c) := a*c\lval:2398.2, grad:"];

  x1 -> expx1;
  x2 -> sinx2;
  sinx2 -> sqrsin

  sqrsin -> mulxy;
  expx1 -> mulxy;
}

Here you can see that even if the function was specified in one line, the expression always parses into a graph of Tensor vertices and data-dependency edges. You may have noticed the Tensor.grad fields, which are meant to store the values of the derivatives of the output with respect to each vertex, $\frac{\partial d}{\partial(\cdot)}$. The question that remains is how to populate these fields.

Taking a step back to differential calculus, deriving the derivative of $d = e^{x_1}\sin^2(x_2)$ by hand involves the application of the chain rule $\frac{\partial d}{\partial x} = \frac{\partial d}{\partial u}\,\frac{\partial u}{\partial x}$. Evaluating the derivative of the function with respect to its inputs $x_1$ and $x_2$ results in

$$\frac{\partial d}{\partial x_1} = e^{x_1}\sin^2(x_2), \qquad \frac{\partial d}{\partial x_2} = 2\,e^{x_1}\sin(x_2)\cos(x_2).$$

Symbolic and numeric differentiation. Symbolic differentiation has performance issues, since a large unrolled expression must be constructed in order to differentiate[^0], whereas numerical differentiation has correctness issues, since evaluating finite differences pits the step size against floating-point precision, resulting in numerical instability. (trace through EXAMPLE for both. talking nets widrow)
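
To see the numerical issue concretely, the sketch below (not part of picograd) approximates $\frac{\partial f}{\partial x_2}$ at $(9, 10)$ with central differences and compares it against the analytic value derived above; once the step size h shrinks past a point, cancellation between two nearly equal floats degrades the approximation instead of improving it.

# central-difference approximation of df/dx2 for f(x1,x2) = exp(x1)*sin(x2)^2 (illustrative sketch)
import math

def f(x1, x2): return math.exp(x1) * math.sin(x2)**2

x1, x2 = 9.0, 10.0
analytic = 2 * math.exp(x1) * math.sin(x2) * math.cos(x2)

for h in (1e-4, 1e-8, 1e-12):
  numeric = (f(x1, x2 + h) - f(x1, x2 - h)) / (2 * h)
  print(f"h={h:.0e}  numeric={numeric:+.6f}  analytic={analytic:+.6f}  error={abs(numeric - analytic):.2e}")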

To populate the Tensor.grad fields, the simplest idea would be to literally translate the manual derivation of the derivative into code. The translation from math to code involves a design decision: should we evaluate from outputs to inputs (symbolically outside-in, graphically right-to-left) or from inputs to outputs (symbolically inside-out, graphically left-to-right)? Although the former order seems more natural with symbolic expressions, there's nothing illegal about the latter.

import picograd as pg

def f(x1: pg.Tensor, x2: pg.Tensor) -> pg.Tensor:
  a = pg.exp(x1)
  b = pg.sin(x2)
  c = b**2
  d = a*c
  return d

# dict[f(x), f'(x)] of local derivatives (adjoints)
dd_da, dd_dc = [c, a] # d(a,c):=a*c ==> d'(a)=c, d'(c)=a
da_dx1 = pg.exp(x1) # a(x1):=exp(x1) ==> a'(x1)=exp(x1)
dc_db = 2*b # c(b):=b^2 ==> c'(b)=2b
db_dx2 = pg.cos(x2) # b(x2):=sin(x2) ==> b'(x2)=cos(x2)

# outputs to inputs: outside-in symbolically, right-to-left graphically
dd_dd = pg.Tensor(1) # base case
dd_da, dd_dc = [dd_dd*dd_da, dd_dd*dd_dc]
dd_dx1 = dd_da*da_dx1 # DONE for the x1->d path

dd_db = dd_dc*dc_db
dd_dx2 = dd_db*db_dx2 # DONE for the x2->d path

# inputs to outputs: inside-out symbolically, left-to-right graphically
dx1_dx1, dx2_dx2 = [pg.Tensor(1), pg.Tensor(1)] # base case
da_dx1 = da_dx1*dx1_dx1
dd_dx1 = dd_da*da_dx1 # DONE for the x1->d path

db_dx2 = db_dx2*dx2_dx2
dc_dx2 = dc_db*db_dx2
dd_dx2 = dd_dc*dc_dx2 # DONE for the x2->d path

Do you notice any difference in the number of evaluations between the two orders?

The outputs-to-inputs ordering takes 6 arithmetic operations (including the destructuring), whereas the inputs-to-outputs ordering takes 7. This is because the former can reuse dd_dd as a dynamic-programming solution to a subproblem shared by the two inputs, whereas the latter cannot. Taking a step back, we only get to reuse the output because the shape of the function is $\mathbb{R}^2 \to \mathbb{R}$. Alternatively, if $f$ had type $\mathbb{R} \to \mathbb{R}^2$, then the inputs-to-outputs ordering would be able to reuse results. This distinction is referred to as "forward-mode" vs "reverse-mode", and reflects the fact that for a function $f: \mathbb{R}^n \to \mathbb{R}^m$ the time complexity of forward-mode differentiation is proportional to $n$, whereas that of reverse-mode differentiation is proportional to $m$. If the expression graph fans in so that $n \gg m$, reverse-mode is preferred. If the expression graph fans out so that $n \ll m$, forward-mode is preferred. However, if we take a step back with a graph-theory lens, we can see that the derivative is a sum over paths, where each path is a product of local derivatives from the input source to the output sink. From a combinatorics perspective, we are calculating all the possible (ors) ways (ands) in which the inputs can perturb the output. That is:

$$\frac{\partial\, \text{output}}{\partial\, \text{input}} \;=\; \sum_{\text{paths } p\,:\,\text{input}\,\to\,\text{output}} \;\; \prod_{(u \to v)\,\in\, p} \frac{\partial v}{\partial u}$$

and as long as the operations along these paths are associative, we can choose the order in which we perform these path products to minimize the number of operations. Finding the optimal ordering is an NP-hard problem (it amounts to optimal Jacobian accumulation, i.e. choosing the best elimination ordering on the expression graph). For instance, if the expression graph is diamond-shaped, evaluating the derivative with forward-mode for the left half and reverse-mode for the right half would be more performant. In practice, we use reverse-mode as a heuristic, since most of the functions that are differentiated (so that they can be optimized) in the field of machine learning are neural networks of the form $\mathcal{L}: \mathbb{R}^n \to \mathbb{R}$, with many parameters and a single scalar loss.
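
To make the forward-mode side of this comparison concrete, here is a minimal dual-number sketch (not picograd's API): every intermediate value carries its derivative with respect to one chosen input, so obtaining the full gradient of an $\mathbb{R}^n \to \mathbb{R}$ function takes n forward passes, one per input seed.

# minimal forward-mode AD with dual numbers (illustrative sketch, not picograd's API)
import math

class Dual:
  def __init__(self, val, dot=0.0): self.val, self.dot = val, dot
  def __mul__(self, other): return Dual(self.val*other.val, self.val*other.dot + self.dot*other.val)

def exp(x): return Dual(math.exp(x.val), math.exp(x.val)*x.dot)
def sin(x): return Dual(math.sin(x.val), math.cos(x.val)*x.dot)

def f(x1, x2):
  a, b = exp(x1), sin(x2)
  c = b * b
  return a * c

# one forward pass per input: seed x1.dot=1 for dd/dx1, then x2.dot=1 for dd/dx2
dd_dx1 = f(Dual(9.0, 1.0), Dual(10.0, 0.0)).dot
dd_dx2 = f(Dual(9.0, 0.0), Dual(10.0, 1.0)).dot
print(dd_dx1, dd_dx2)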

How can we generalize this into an algorithm?
All we need are 1. mappings from each primitive op to its local derivative rule and 2. a topological sort of the expression graph.

For the derivative rules: the same way that optimizing compilers implement an optimization "manually" once so that it gets reused many times, the authors of deep learning frameworks implement derivatives manually once, and automatic differentiation then reuses them many times. In theory we could differentiate any expression with only a handful of rules such as those for addition and multiplication, but in practice most frameworks provide rules for more complex derivatives as sugar.

For the topological sort, we can simply reverse the ordering produced by a depth-first search:

def toposort(self):
  order: list[Op] = []
  visited: set[Op] = set()

  def dfs(node: Op) -> None:
    if node in visited: return
    visited.add(node)
    for src in node.src: dfs(src)
    order.append(node)

  dfs(self)
  return order

class Tensor():
  def backward(self):
    for t in reversed(self.toposort()):   # visit each node only after every node it feeds into
      ...                                 # apply that node's chain rule (fleshed out below)

We will now use this idea to modify the interpretation of our deep learning framework so that it not only evaluates $f(x)$ but $f'(x)$ as well. This is done by dynamically overloading the operators at runtime[^0] to trace the expression graph.
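
Here is a minimal sketch of what "overloading the operators to trace the expression graph" means; the names (OpCode, src, toposort) are hypothetical stand-ins chosen to mirror the picograd-style code below, not its actual implementation.

# minimal tracing-by-operator-overloading sketch (hypothetical names)
class OpCode: CONST, ADD, MUL = "CONST", "ADD", "MUL"

class Tensor:
  def __init__(self, value, op=OpCode.CONST, src=()):
    self.value, self.op, self.src, self.grad = value, op, src, None
  def __add__(self, other): return Tensor(self.value + other.value, OpCode.ADD, (self, other))
  def __mul__(self, other): return Tensor(self.value * other.value, OpCode.MUL, (self, other))

  def toposort(self):
    order, visited = [], set()
    def dfs(t):
      if id(t) in visited: return
      visited.add(id(t))
      for s in t.src: dfs(s)
      order.append(t)
    dfs(self)
    return order

x, y = Tensor(3.0), Tensor(4.0)
z = x * y + x                          # evaluating also records the graph: z.op=ADD, z.src=(x*y, x)
print([t.op for t in z.toposort()])    # ['CONST', 'CONST', 'MUL', 'ADD']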

chain_rules = PatternMatcher([
  (Pattern(OpCode.MATMUL, name="input"), lambda output_grad, input: (_____,)),
  (Pattern(OpCode.MATVEC, name="input"), lambda output_grad, input: (_____,)),
  (Pattern(OpCode.RECIPROCAL, name="input"), lambda output_grad, input: (-output_grad * input * input,)),
  (Pattern(OpCode.SIN, name="input"), lambda output_grad, input: ((math.pi/2 - input.src[0]).sin() * output_grad,)),
  (Pattern(OpCode.LOG2, name="input"), lambda output_grad, input: (output_grad / (input.src[0] * math.log(2)),)),
  (Pattern(OpCode.EXP2, name="input"), lambda output_grad, input: (input * output_grad * math.log(2),)),
  (Pattern(OpCode.SQRT, name="input"), lambda output_grad, input: (output_grad / (input*2),)),
  (Pattern(OpCode.ADD), lambda output_grad: (1.0*output_grad, 1.0*output_grad)),
  (Pattern(OpCode.MUL, name="input"), lambda output_grad, input: (input.src[1]*output_grad, input.src[0]*output_grad)),
])

class Tensor:
  def _forward(self, f:Callable, *other:Tensor) -> Tensor: #extra_args=(), **kwargs)
    out_tensor = evaluator.eval_uop([self, other], out_uop)

  def backward(self, grad:Tensor|None=None) -> Tensor:
    """
    backward performs by collecting tensors, computing gradients with automatic differentiation, and updating said tensors.
    """
    # 1. collect all tensors that requires grad by topologically sorting the graph of uops and filter
    all_uops = self.uop.toposort()
    tensors_require_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops and t.requires_grad]
    uops_require_grad = [t.uop for t in tensors_require_grad]
    assert grad is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    if not (self.is_floating_point() and all(t.is_floating_point() for t in tensors_require_grad)): raise RuntimeError("only float Tensors have gradient")
    
    # 2. compute the gradient with a map of tensors to partials
    if grad is None: grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) # base case is 1.0
    tens2grads = Tensor._automatically_differentiate(self.uop, grad.uop, set(uops_require_grad)) # skipping materializing zeroed grads for now
    grads = [Tensor(tens2grads[u], device=t.device) for t,u in zip(tensors_require_grad, uops_require_grad)] # initialize tensor grads on device, ordered like tensors_require_grad
    
    # 3. update the tensors that require grad with the gradient's partials
    for t,g in zip(tensors_require_grad, grads):
      assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
      t.grad = g if t.grad is None else (t.grad + g) # accumulate if t.grad exists
    return self

  @staticmethod
  def _automatically_differentiate(root:Op, root_grad:Op, targets:set[Op]) -> dict[Op, Op]:
    """
    _differentiate backpropagates partials on a topologically sorted expression graph with the chain rule
    and produces the gradient in the form of a map of ops to their partials (which, in turn, are ops)
    """
    tens2grads = {root: root_grad}

    # 1. topological sort
    in_target_path: dict[Op, bool] = {}
    for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
    dfs = list(root.toposort()) # lambda node: node.op not in {OpCode.DETACH, OpCode.ASSIGN} and in_target_path[node])) # don't flow through DETACH/ASSIGN or anything not in target path

    # 2. backpropagation with the chain rule
    for tensor in reversed(dfs):
      if tensor not in tens2grads: continue

      local_grads: tuple[Op|None, ...]|None = cast(tuple[Op, ...]|None, chain_rules.rewrite(tensor, ctx=tens2grads[tensor]))
      if local_grads is None: raise RuntimeError(f"failed to compute gradient for {tensor.op}\n\nin {str(tensor)[0:1000]}...")
      assert len(local_grads) == len(tensor.src), f"got {len(local_grads)} gradient, expected {len(tensor.src)}"

      for src, local_grad in zip(tensor.src, local_grads):
        if local_grad is None: continue
        if src in tens2grads: tens2grads[src] = tens2grads[src] + local_grad # accumulate here to handle fan-out: a source feeding multiple consumers receives a partial from each
        else: tens2grads[src] = local_grad # o/w initialize

    return tens2grads
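
Tying this back to the running example: assuming picograd exposes requires_grad and .grad the way the code above suggests, calling backward on the output of f would populate both partials of the expression graph drawn earlier in a single reverse pass (a hypothetical usage sketch).

# hypothetical usage, assuming the picograd API sketched above
import picograd as pg

x1 = pg.Tensor(9.0, requires_grad=True)
x2 = pg.Tensor(10.0, requires_grad=True)
d = f(x1, x2)              # d = exp(x1) * sin(x2)**2, traced as an expression graph
d.backward()               # reverse mode: one pass populates every partial
print(x1.grad)             # dd/dx1 = exp(x1)*sin(x2)^2
print(x2.grad)             # dd/dx2 = 2*exp(x1)*sin(x2)*cos(x2)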

To implement automatic differentiation with Tensor.backward(), there is a design decision to be made: implement it dynamically, by overloading operators and tracing the graph at runtime as above, or just-in-time[^3], as a source-to-source transformation, similar to the decision of how to implement types for general programming languages[^4].

Let's now move on to automatically differentiating the functions of neural networks, specifically the FFN language model from earlier. (johnson/ryan adams ordering) n^2 vs n^3

2.2.2 Stochastic Gradient Descent via optim.sgd
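
A minimal sketch of what optim.sgd abstracts, assuming the .grad fields are populated by Tensor.backward() as in 2.2.1: the optimizer gathers the parameters, zeroes their gradients before each backward pass, and applies the update rule $\theta \leftarrow \theta - \alpha\,\nabla_\theta\mathcal{L}$ after it. The names below mirror torch's optimizer interface and are illustrative, not picograd's actual API.

# minimal SGD sketch (illustrative; mirrors the manual update in the training loop above)
class SGD:
  def __init__(self, params, lr=0.01):
    self.params, self.lr = list(params), lr
  def zero_grad(self):
    for p in self.params: p.grad = None
  def step(self):
    for p in self.params: p.data += -self.lr * p.grad

# usage: replaces the manual "p.grad = None" and "p.data += -0.01 * p.grad" lines
# opt = SGD(model.parameters(), lr=0.01)
# opt.zero_grad(); loss.backward(); opt.step()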

2.3 Accelerating cuBLAS Kernels

https://arxiv.org/pdf/1410.0759

https://arxiv.org/pdf/1804.06826 https://arxiv.org/pdf/2512.02189v1 https://girl.surgery/bad_paper https://www.arxiv.org/pdf/2512.07004

2.3.1 PDP11 Problem: Throughput-Oriented Many Core Processors

// GPU crate: compiled to PTX by cuda_builder; #[kernel] and println! come from cuda_std
use cuda_std::prelude::*;

#[allow(improper_ctypes_definitions)]
#[kernel] pub unsafe fn main_gpu() {
    println!("of Tensor Programs!");
}

// CPU crate: loads the PTX and dispatches the kernel through the cust driver API
use cust::prelude::*;
use std::error::Error;

static PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/kernels.ptx")); // path assumes a cuda_builder build script

fn main() -> Result<(), Box<dyn Error>> {
  let _ctx = cust::quick_init()?; // Initialize the CUDA Driver API. `_ctx` must be kept alive until the end.
  let module = Module::from_ptx(PTX, &[])?; // Create a module from the PTX code compiled by `cuda_builder`.
  let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?; // Create a stream, which is like a thread for dispatching GPU calls.
  let kernel = module.get_function("main_gpu")?;
  unsafe { launch!(kernel<<<1, 1, 0, stream>>>())?; } // <<<grid dim, block dim, shared mem bytes, stream>>>
  stream.synchronize()?;
  Ok(())
}

2.3.2 Accelerating GEMM with CUDA(RS)

2.3.3 Accelerating GEMM with Data Reuse


2.3.4 Accelerating GEMM with Scheduling


2.3.5 Accelerating GEMM with Tensor Cores

2.4 Learning Sequences with Different Inductive Biases

Sequence learning...

2.4.1 CNN: Convolutional Neural Networks

2.4.2 RNN: Recurrent Neural Networks

2.4.3 BERT: Bidirectional Encoder Representations from Transformers

2.4.4 GPT: Generative Pretrained Transformers

#!/usr/bin/env python3
import os, argparse, contextlib
from typing import Optional, Union
with contextlib.suppress(ImportError): import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import gguf_load, torch_load, load_state_dict, get_state_dict
from extra.bench_log import BenchEvent, WallTimeEvent

MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")

class Attention:
  def __init__(self, dim, n_heads):
    self.c_attn = Linear(dim, 3*dim, bias=True)
    self.c_proj = Linear(dim, dim, bias=True)
    self.n_heads = n_heads
    self.dim = dim
    self.head_dim = dim // n_heads

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
    if mask is not None or start_pos.val == 0:
      # no symbolic shape qkv when consuming prompts
      start_pos = start_pos.val

    if HALF: x = x.half()
    xqkv = self.c_attn(x).reshape(None, None, 3, self.n_heads, self.head_dim)
    xq, xk, xv = [xqkv[:, :, i, :, :] for i in range(3)]
    bsz, seqlen, _, _ = xq.shape

    # create kv cache
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()

    # update the cache
    self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()

    if start_pos > 0:
      keys = self.cache_kv[0][:, :start_pos+seqlen, :, :]
      values = self.cache_kv[1][:, :start_pos+seqlen, :, :]
    else:
      keys = xk
      values = xv

    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))

class FeedForward:
  def __init__(self, dim, hidden_dim):
    self.c_fc = Linear(dim, hidden_dim, bias=True)
    self.c_proj = Linear(hidden_dim, dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())

class TransformerBlock:
  def __init__(self, dim, n_heads, norm_eps):
    self.attn = Attention(dim, n_heads)
    self.mlp = FeedForward(dim, 4*dim)
    self.ln_1 = LayerNorm(dim, norm_eps)
    self.ln_2 = LayerNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
    h = x + self.attn(self.ln_1(x), start_pos, mask).float()
    return (h + self.mlp(self.ln_2(h))).contiguous()

class Transformer:
  def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
    self.vocab_size = vocab_size
    self.wte = Embedding(vocab_size, dim)
    self.wpe = Embedding(max_seq_len, dim)
    self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
    self.ln_f = LayerNorm(dim, norm_eps)
    self.lm_head = Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0):
    if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
    if isinstance(tokens, UOp):
      seqlen = 1
      tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
    else:
      seqlen = tokens.shape[1]
      tok_emb = self.wte(tokens)

    # not symbolic when consuming the prompt
    selected_pos = (0, seqlen) if start_pos.val == 0 else (start_pos, start_pos+1)
    pos_emb = self.wpe(self.allpos.shrink((None, selected_pos)))

    h = tok_emb + pos_emb

    if HALF: h = h.half()

    mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None

    for hi in self.h: h = hi(h, start_pos, mask)

    logits = self.lm_head(self.ln_f(h))

    if logits.shape[1] == 0:
      # special case for empty prompt
      logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
    else:
      logits = logits[:, -1, :]

    if temperature < 1e-6:
      ret = logits.argmax(-1)
    else:
      ret = (logits / temperature).softmax().multinomial()
    return ret.flatten().realize()

  def __call__(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0) -> Tensor:
    forward = (self.forward_jit if JIT and (isinstance(tokens, UOp) or tokens.shape[1] == 1) else self.forward)
    return forward(tokens, start_pos, temperature)