Study on LLM Inference
Other Details about LLMs that I Keep Forgetting:
Basics
- Multi-head and Single-head Attention: In multi-head attention, each attention head has its own key and value weights (W_k, W_v), so we can see each head as having its own K-V space. This is similar to having multiple libraries that we can search to find the details most relevant to our query (Q). Single-head attention has just one head, and therefore one K-V space.
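- What is Q, K, and V: I keep confusing the weight matrices with Q, K, and V themselves. The weight matrices W_q, W_k, and W_v are learned (updated) during training. Multiplying the embeddings (user input tokens converted to vectors) by W_q gives Q; multiplying them by W_k and W_v gives K and V. Then the dot product between Q and K, scaled by 1/sqrt(d_k) (d_k is the key dimension), is passed through a softmax to get attention weights that say how similar each key is to each query. The output is the weighted sum of the V vectors using those weights, so every value contributes in proportion to its score rather than only the single most similar one.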
- Number of Parameters != Hidden Dimension: The total parameter count is roughly the number of parameters in each layer (dominated by terms proportional to $h^2$) multiplied by the number of layers ($L$), plus the parameters of the embedding layers. The hidden size ($h$) measures the model's width, while the total parameter count measures the model's overall scale, which depends on its width, its depth ($L$), and other factors.
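A minimal sketch of the Q/K/V computation described above, in plain Python with no libraries so every step is visible. It assumes a single head, a causal mask (a token cannot attend to later tokens, as in decoder LLMs), and toy identity weights; the helper names matmul, softmax and single_head_attention are hypothetical, not from any particular framework.

```python
import math


def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m); matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(row):
    """Turn a row of scores into weights that sum to 1."""
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(x, w_q, w_k, w_v, causal=True):
    """x: (seq_len x d_model) token embeddings; w_q, w_k, w_v: (d_model x d_k) learned weights."""
    q = matmul(x, w_q)  # Q = X W_q
    k = matmul(x, w_k)  # K = X W_k
    v = matmul(x, w_v)  # V = X W_v
    d_k = len(w_q[0])
    scores = []
    for i, q_row in enumerate(q):
        row = []
        for j, k_row in enumerate(k):
            if causal and j > i:
                row.append(float("-inf"))  # a token may not attend to later tokens
            else:
                row.append(sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k))
        scores.append(row)
    weights = [softmax(row) for row in scores]  # one probability distribution per query
    return matmul(weights, v)  # each output row is a weighted average of the V rows


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0]]  # two toy token embeddings
out = single_head_attention(x, identity, identity, identity)
print(out)
```

With identity weights the first token can only attend to itself, so its output is exactly its own value vector; the second token splits its attention over both tokens, weighted toward itself because its query matches its own key more closely.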
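A rough way to see why parameter count and hidden size are different numbers. This sketch assumes a GPT-style decoder with standard multi-head attention, an MLP that expands to 4h and back, one embedding table, and no biases or layer norms; real models (for example with gated MLPs or a separate output head) will differ. The config values below are hypothetical.

```python
def estimate_params(hidden: int, n_layers: int, vocab_size: int, ffn_mult: int = 4) -> int:
    """Rough parameter count for a GPT-style decoder (assumes MHA, a 2-matrix MLP,
    one shared embedding table, and ignores biases and layer norms)."""
    attention = 4 * hidden * hidden         # W_q, W_k, W_v and the output projection, each h x h
    mlp = 2 * hidden * (ffn_mult * hidden)  # up-projection h -> 4h, down-projection 4h -> h
    per_layer = attention + mlp             # 12 * h^2 when ffn_mult == 4
    embeddings = vocab_size * hidden        # token embedding table
    return n_layers * per_layer + embeddings


# Hypothetical config: h = 4096, L = 32, vocabulary of 32,000 tokens.
total = estimate_params(hidden=4096, n_layers=32, vocab_size=32_000)
print(f"{total:,} parameters (~{total / 1e9:.2f}B)")  # 6,573,522,944 parameters (~6.57B)
```

Doubling h roughly quadruples the per-layer count, while doubling L only doubles it, which is why two models with the same hidden size can have very different totals.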
Good Analogy
- It helps to think of an attention block (whether multi-head or single-head) as a big building. Inside it are one or more libraries, which correspond to the attention heads.
- MHA (multi-head attention): each attention head has its own W_q, W_k, and W_v matrices, learned during training. These essentially provide different “perspectives” on the same user input.
- MQA (multi-query attention): many questions to a single library. This is the most aggressive way to save KV Cache memory: each attention head still has its own W_q, but all heads share a single W_k and W_v, so only one K and one V per token per layer need to be cached.
- GQA (grouped-query attention): a hybrid of MHA and MQA. Each attention head has its own W_q (so there are multiple Qs), but the query heads are split into groups and each group shares one W_k and W_v. GQA with a single group is MQA; with as many groups as heads it is MHA.
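A back-of-the-envelope sketch of why MQA and GQA save KV Cache memory: the cache holds one K and one V vector per token for each KV head in each layer, so its size scales with the number of KV heads rather than the number of query heads. The model dimensions and fp16 storage (2 bytes per element) are assumptions for illustration, and kv_head_for_query_head shows one common way to map query heads to shared KV groups (contiguous groups), assumed here rather than taken from a specific model.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int, seq_len: int,
                   batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Memory for the KV Cache: one K and one V vector per token, per KV head, per layer."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


def kv_head_for_query_head(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Which shared K/V head (the "library") a given query head reads from."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("query heads must divide evenly into KV groups")
    group_size = n_q_heads // n_kv_heads  # query heads per shared K/V head
    return q_head // group_size


# Hypothetical model: 32 layers, 32 query heads of size 128, 4,096-token context, fp16.
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=32, n_kv_heads=n_kv, head_dim=128, seq_len=4096)
    print(f"{name}: {size / 2**20:,.0f} MiB, query head 5 -> KV head "
          f"{kv_head_for_query_head(5, 32, n_kv)}")
```

With these assumed numbers the cache drops from 2,048 MiB (MHA) to 512 MiB (GQA with 8 groups) and 64 MiB (MQA) for one 4,096-token sequence.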
LLM Inference
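- Prefill Phase: processes the entire user input prompt in one pass and populates the KV Cache → the logits for the first new token are available
- Token Sampling: a sampling strategy (e.g. greedy, temperature, top-k/top-p) selects the next token from the logits
- Decode Phase: generates one token per step by running the newest token through the Transformer blocks (Attention + MLP layers)
- Attention: the newest token is used to calculate its K and V, which are appended to the KV Cache
- MLP: transforms the attention output; after the last block, the output head produces logits that are sampled, and the loop repeats until EOS (end-of-sequence) is generated or a length limit is reached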
- Requires access to key and value activations of previously processed tokens to perform attention
- Stored in KV Cache
- Additional notes:
- There can be multiple Transformer blocks
- each is responsible for extracting different context for the same token
- each layer has its own KV Cache
- each layer produces an updated vector for each token (token_x) that serves as the input to the next Transformer block
- There are multiple attention heads (run in parallel) within the attention layer of a single Transformer block
- each attention head can focus on different parts of the sequence for a given token
- each attention head outputs a vector (of size hidden size / number of heads) for the token
- these per-head vectors are concatenated, then projected by an output matrix (often written W_o) back to the hidden dimension expected by the MLP layer
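- For a Transformer block and a given input $X$, the $Q$, $K$, and $V$ matrices are generated by the learned matrices $W_q$, $W_k$, $W_v$: $Q = XW_q$, $K = XW_k$, $V = XW_v$. Each is then split equally across the attention heads in that layer (with GQA/MQA, $K$ and $V$ are split across fewer KV heads).
- $Q$ is discarded after the current step's attention computation
- $K$ and $V$ are stored in the KV Cache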
- Optimization Techniques:
- Tensor Parallelism (TP) / Pipeline Parallelism (PP)
- Prefill / Decode Priority Scheduling
- Note: the right optimization technique depends on the model and workload
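The prefill → sample → decode loop above, reduced to a toy that only tracks the control flow and the KV Cache bookkeeping. Everything model-related here is hypothetical: toy_block replaces real attention with a plain average and the "output head" is a made-up mapping, so the generated tokens are meaningless. Real prefill also processes the whole prompt in one batched, causally masked pass rather than token by token.

```python
from dataclasses import dataclass, field

EOS_ID = 0        # hypothetical end-of-sequence token id
VOCAB_SIZE = 10   # hypothetical toy vocabulary size


@dataclass
class LayerKVCache:
    """Cached K and V for every token processed so far, for ONE layer."""
    keys: list = field(default_factory=list)
    values: list = field(default_factory=list)


def toy_block(x: float, cache: LayerKVCache) -> float:
    """Stand-in for a Transformer block (not a real model): compute this token's K/V,
    append them to the cache, then 'attend' over every cached value."""
    k, v = x, x  # real models compute X W_k and X W_v here
    cache.keys.append(k)
    cache.values.append(v)
    attended = sum(cache.values) / len(cache.values)  # plain average instead of softmax weights
    return x + attended  # this output becomes the input to the next block


def forward(token: int, caches: list) -> list:
    """Run one token through all blocks and return toy logits over the vocabulary."""
    h = float(token)
    for cache in caches:  # each layer has its own KV Cache
        h = toy_block(h, cache)
    target = int(h) % VOCAB_SIZE  # hypothetical output head
    return [1.0 if i == target else 0.0 for i in range(VOCAB_SIZE)]


def generate(prompt: list, n_layers: int = 2, max_new_tokens: int = 8):
    if not prompt:
        raise ValueError("prompt must contain at least one token")
    caches = [LayerKVCache() for _ in range(n_layers)]
    # Prefill: the whole prompt populates the caches (token by token here for clarity).
    for tok in prompt:
        logits = forward(tok, caches)
    generated = []
    for _ in range(max_new_tokens):
        next_tok = max(range(VOCAB_SIZE), key=lambda i: logits[i])  # greedy sampling
        if next_tok == EOS_ID:
            break
        generated.append(next_tok)
        # Decode: only the new token is processed; past K/V are read from the cache.
        logits = forward(next_tok, caches)
    return generated, caches


tokens, caches = generate([3, 5, 7])
print(tokens, [len(c.keys) for c in caches])
```

The useful invariant: every layer's cache ends up with one entry per processed token (prompt plus generated tokens), and each decode step adds exactly one entry per layer. An EOS token that ends generation is never fed back in, so it never reaches the cache.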
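A sketch of the token sampling step, assuming two common strategies: temperature scaling and top-k filtering (top-p is left out for brevity). The function name and signature are illustrative, not any library's API.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=None, rng=None):
    """Pick the next token id from logits.
    temperature == 0 -> greedy (argmax); otherwise sample from softmax(logits / temperature),
    optionally restricted to the top_k highest-scoring tokens."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]  # keep only the k most likely tokens
    scaled = [logits[i] / temperature for i in candidates]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]  # unnormalized softmax; choices() normalizes
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_next_token(logits, temperature=0))  # 0 (greedy)
print(sample_next_token(logits, temperature=0.8, top_k=2, rng=random.Random(0)))
```

Lower temperatures sharpen the distribution toward the argmax, and temperature 0 is treated as plain greedy decoding. Passing a seeded random.Random makes runs reproducible.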
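A minimal sketch of one decode step for a single attention layer, showing the per-head split, the concatenation, the output projection, and which tensors are kept: K and V go into the cache, while Q is used once and dropped. It assumes standard multi-head attention (as many KV heads as query heads), plain Python lists, and a dict-based cache; all names are hypothetical.

```python
import math


def split_heads(vec, n_heads):
    """Split one token's length-d_model vector into n_heads equal chunks."""
    if len(vec) % n_heads != 0:
        raise ValueError("hidden size must divide evenly across heads")
    head_dim = len(vec) // n_heads
    return [vec[h * head_dim:(h + 1) * head_dim] for h in range(n_heads)]


def concat_heads(head_outputs):
    """Concatenate per-head output vectors back into one length-d_model vector."""
    return [x for head in head_outputs for x in head]


def project(vec, w_o):
    """Multiply a row vector by the output projection w_o (d_model x d_model)."""
    return [sum(vec[i] * w_o[i][j] for i in range(len(vec))) for j in range(len(w_o[0]))]


def decode_attention_step(q, k, v, kv_cache, n_heads, w_o):
    """One decode step of one attention layer for the newest token.
    q, k, v: this token's projections (X W_q, X W_k, X W_v), each of length d_model."""
    kv_cache["k"].append(split_heads(k, n_heads))  # K and V are kept for future steps
    kv_cache["v"].append(split_heads(v, n_heads))
    head_outputs = []
    for h, q_h in enumerate(split_heads(q, n_heads)):  # Q is used here, then discarded
        keys = [tok[h] for tok in kv_cache["k"]]
        values = [tok[h] for tok in kv_cache["v"]]
        scores = [sum(a * b for a, b in zip(q_h, k_h)) / math.sqrt(len(q_h)) for k_h in keys]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        head_outputs.append(
            [sum(e / total * v_h[i] for e, v_h in zip(exps, values)) for i in range(len(q_h))]
        )
    return project(concat_heads(head_outputs), w_o)


w_o = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]  # identity, for illustration
cache = {"k": [], "v": []}
out = decode_attention_step([1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0],
                            [0.5, -1.0, 2.0, 3.0], cache, n_heads=2, w_o=w_o)
print(out, len(cache["k"]))
```

The cache dict never receives a q entry, which is the point: future steps only need the K and V of past tokens.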
Decode Step: Attention + MLP
- Attention Kernel - dependent on request history (KV Cache)
- Current token to be processed (Query token) is compared against the Key vectors of all previous tokens (and itself) to calculate the attention scores
- Attention score: how relevant each previous token is to the current one
- Attention score is used to create a weighted average of the Value vectors for all previous tokens
- Output is a vector of fixed size that represents the token’s meaning with the given context
- The work required is proportional to the length of the KV Cache (request history)
- MLP Kernel/Layer - independent of KV Cache
- MLP takes the output of the Attention layer as its input
- Doesn't need the KV Cache because its input vector already incorporates the context of the request history (through attention)
- Amount of computation done in MLP is constant for any single token (regardless of the position of the token)
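A rough FLOP count that makes the attention-vs-MLP contrast above concrete for one layer and one new token. It assumes multi-head attention with hidden size h split evenly across heads, an MLP that expands to 4h, 2 FLOPs per multiply-add, and ignores the softmax and the Q/K/V/output projections (which, like the MLP, cost the same for every token). The hidden size is hypothetical.

```python
def attention_kernel_flops(context_len: int, hidden: int) -> int:
    """FLOPs for one new token in one layer's attention kernel, summed over all heads
    (the head dims add up to hidden). Excludes the Q/K/V/output projections."""
    scores = 2 * context_len * hidden        # q . k for every cached key (multiply + add)
    weighted_sum = 2 * context_len * hidden  # softmax-weighted sum of every cached value
    return scores + weighted_sum


def mlp_flops(hidden: int, ffn_mult: int = 4) -> int:
    """FLOPs for one token in one layer's MLP (up-projection h -> 4h, down-projection 4h -> h)."""
    return 2 * 2 * hidden * (ffn_mult * hidden)


hidden = 4096  # hypothetical hidden size
for context_len in (1_024, 4_096, 16_384):
    print(context_len, attention_kernel_flops(context_len, hidden), mlp_flops(hidden))
```

Under these assumptions the attention kernel grows linearly with the context while the MLP stays fixed, and the two match at a context length of 4h (16,384 here). This counts arithmetic only; in practice decode is often limited by memory bandwidth, since every step also has to read the whole KV Cache.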