An nn.Embedding is essentially a lookup table that maps integer indices to dense vectors. The vectors start out randomly initialized and are then learned during training to capture meaningful relationships between the items they represent. You create one with nn.Embedding(num_embeddings, embedding_dim), where:
- num_embeddings: how many distinct items you want to embed (e.g., vocabulary size)
- embedding_dim: the size of the vector representing each item
Basic Concept:
# Simple illustration of what nn.Embedding does internally
import torch
import torch.nn as nn

class SimpleEmbedding:
    def __init__(self, num_embeddings, embedding_dim):
        # Create a lookup table matrix of size (num_embeddings × embedding_dim)
        self.weight = torch.randn(num_embeddings, embedding_dim)

    def forward(self, indices):
        # Return the rows of the table for the given indices
        return self.weight[indices]

# Real nn.Embedding usage
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# Input indices
indices = torch.tensor([0, 2, 5])

# Get embeddings for these indices
vectors = embedding(indices)  # shape (3, 3): one 3-dim vector per index
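A quick sanity check, continuing the snippet above: the output has one row per index, and those rows are exactly the corresponding rows of embedding.weight.

print(vectors.shape)                                     # torch.Size([3, 3])

# The layer output is just rows of the weight matrix
print(torch.equal(vectors, embedding.weight[indices]))   # True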
Visual Example:
num_embeddings = 10
embedding_dim = 3
Lookup Table (weight matrix):
[
  [0.1, 0.2, 0.3],   # index 0
  [0.4, 0.5, 0.6],   # index 1
  [0.7, 0.8, 0.9],   # index 2
  ...                # and so on
  [0.2, 0.3, 0.4],   # index 9
]

If indices = [0, 2], the output would be:
[
  [0.1, 0.2, 0.3],   # for index 0
  [0.7, 0.8, 0.9],   # for index 2
]
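To make the lookup concrete, here is a small sketch that writes the two rows shown above into the layer from the earlier snippet and repeats the lookup (the numbers are illustrative, not real trained values):

# Overwrite two rows of the (10 × 3) table with the values shown above
with torch.no_grad():
    embedding.weight[0] = torch.tensor([0.1, 0.2, 0.3])
    embedding.weight[2] = torch.tensor([0.7, 0.8, 0.9])

print(embedding(torch.tensor([0, 2])))   # rows 0 and 2 of the table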
The lookup table defines a vector space: each index is mapped to a point in a continuous (typically high-dimensional) space. This is a fundamental idea in representation learning and NLP.
# Example: Word Embedding Space
vocab_size = 10000 # number of words
embedding_dim = 300 # dimensions in vector space
word_embedding = nn.Embedding(vocab_size, embedding_dim)
# Each word gets a position in 300-dimensional space
cat_idx = 25
dog_idx = 47
animal_idx = 129
cat_vector = word_embedding(torch.tensor(cat_idx)) # [300]
dog_vector = word_embedding(torch.tensor(dog_idx)) # [300]
animal_vector = word_embedding(torch.tensor(animal_idx)) # [300]
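Lookups also work in batches: pass a 1-D (or 2-D) tensor of indices and you get one 300-dimensional vector per index. A short sketch continuing the snippet above:

# One lookup call for several words at once
word_ids = torch.tensor([cat_idx, dog_idx, animal_idx])   # shape (3,)
word_vectors = word_embedding(word_ids)                   # shape (3, 300)

# A (batch, seq_len) tensor of ids would give a (batch, seq_len, 300) result
print(word_vectors.shape)   # torch.Size([3, 300])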
In this vector space:
- Similar words should be close to each other
- Relationships can be captured by vector arithmetic
- Semantic meaning is distributed across dimensions
Visual example (in 2D for simplicity):
animal
↑
/ \
/ \
cat ---- dog
Real-world examples of vector space relationships (the word indices below are hypothetical):
# Famous example: king - man + woman ≈ queen
# (this holds for trained embeddings such as word2vec, not a freshly initialized layer)
king_idx, man_idx, woman_idx = 812, 14, 33
king_vector = word_embedding(torch.tensor(king_idx))
man_vector = word_embedding(torch.tensor(man_idx))
woman_vector = word_embedding(torch.tensor(woman_idx))
queen_vector = king_vector - man_vector + woman_vector

# Similar words cluster together:
# dog, puppy, canine would be close in one region of the vector space;
# cat, kitten, feline would be close in another
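How would you check that queen_vector actually lands near the real queen embedding? Closeness is usually measured with cosine similarity. A minimal sketch (queen_idx is again a hypothetical index, and the number is only meaningful for trained embeddings):

import torch.nn.functional as F

queen_idx = 310  # hypothetical index for "queen"
true_queen_vector = word_embedding(torch.tensor(queen_idx))

# Cosine similarity is close to 1.0 when two vectors point the same way;
# with random (untrained) weights this value carries no meaning.
similarity = F.cosine_similarity(queen_vector, true_queen_vector, dim=0)
print(similarity.item())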
This is why embeddings are powerful:
- They convert discrete items (words, positions) into continuous vectors
- They can capture semantic relationships
- They allow mathematical operations on concepts
- They enable neural networks to process symbolic data
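Because the lookup table is an ordinary trainable parameter, the vectors are learned by backpropagation like any other layer weight. A minimal sketch with a made-up toy model and data, just to show gradients reaching the embedding rows that were looked up:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy classifier: embed a word id, then map the vector to 2 classes
model = nn.Sequential(
    nn.Embedding(num_embeddings=100, embedding_dim=16),
    nn.Linear(16, 2),
)

word_ids = torch.tensor([3, 17, 42])   # fake batch of word ids
labels = torch.tensor([0, 1, 0])       # fake labels

logits = model(word_ids)               # shape (3, 2)
loss = F.cross_entropy(logits, labels)
loss.backward()

# Only the rows that were looked up receive non-zero gradients
print(model[0].weight.grad[3].abs().sum() > 0)   # tensor(True)
print(model[0].weight.grad[5].abs().sum() > 0)   # tensor(False)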
The same principle applies to positional embeddings:
max_len = 512
hidden_dim = 768
pos_embedding = nn.Embedding(max_len, hidden_dim)
# After training, nearby positions often end up with similar vectors
pos_5 = pos_embedding(torch.tensor(5))
pos_6 = pos_embedding(torch.tensor(6))
pos_100 = pos_embedding(torch.tensor(100))

# pos_5 and pos_6 would then typically be closer to each other than pos_5 and pos_100
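In a transformer-style model, token and positional embeddings are typically added together to form the input representation. A minimal sketch under that assumption (the vocabulary size and token ids below are made up; the other sizes match the snippet above):

import torch
import torch.nn as nn

vocab_size, max_len, hidden_dim = 30000, 512, 768
tok_embedding = nn.Embedding(vocab_size, hidden_dim)
pos_embedding = nn.Embedding(max_len, hidden_dim)

token_ids = torch.tensor([[101, 2054, 2003, 102]])        # shape (1, 4), made-up ids
positions = torch.arange(token_ids.size(1)).unsqueeze(0)  # [[0, 1, 2, 3]]

# Each token's vector is its word embedding plus the embedding of its position
x = tok_embedding(token_ids) + pos_embedding(positions)
print(x.shape)   # torch.Size([1, 4, 768])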