The Multi-head Attention Wrapper is an implementation pattern that stacks multiple instances of a causal self-attention module to create a multi-head attention mechanism.
Implementation Details
- Structure: It creates `num_heads` instances of the `CausalAttention` class (see the code sketch after this list).
- Forward Pass:
- It loops through each attention head to compute its output (context vector).
- It concatenates these outputs along the column dimension (last dimension) to form the final context vector matrix.
- Parameters: The wrapper accepts parameters such as:
  - `d_in` (Input embedding dimension)
  - `d_out` (Output dimension for each head)
  - `context_length` (Number of tokens)
  - `dropout` (Dropout rate)
  - `num_heads` (Number of attention heads)
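The following is a minimal sketch of this wrapper pattern in PyTorch. It assumes a `CausalAttention` module with the parameters listed above; the class and attribute names here are illustrative, not a definitive implementation.

```python
import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    """Single-head causal self-attention (simplified for illustration)."""

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask blocks attention to future tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # context vectors: (b, num_tokens, d_out)


class MultiHeadAttentionWrapper(nn.Module):
    """Stacks num_heads CausalAttention instances and concatenates their outputs."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        )

    def forward(self, x):
        # Run each head on the same input, then concatenate along the last dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)
```

Because each head produces a `d_out`-dimensional context vector per token, the concatenated output has dimension `d_out * num_heads`.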
Example
If num_heads=2 and d_out=2, the wrapper creates two causal attention instances. Given a batch of token sequences, the wrapper processes the batch through both instances and concatenates the results, so each token ends up with a context vector of dimension d_out * num_heads = 4.
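A usage sketch for this example, assuming the classes defined above (the input tensor shape here is hypothetical):

```python
torch.manual_seed(123)

# Hypothetical batch: 2 sequences, 6 tokens each, embedding dimension d_in=3
batch = torch.rand(2, 6, 3)

mha = MultiHeadAttentionWrapper(
    d_in=3, d_out=2, context_length=6, dropout=0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs.shape)  # torch.Size([2, 6, 4]) -- last dim is d_out * num_heads
```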
