Neural Agents That Never Forget

Building Adaptive Memory Systems with Episodic Replay and Meta-Learning in PyTorch

By Taufia Hussain — November 11, 2025 · ~14 min read

Brain-inspired AI: combining differentiable memory, experience replay, and meta-learning to build agents that adapt over time without forgetting.

Modern neural networks are incredibly good at fitting data but pretty bad at learning continually. Train them on Task A, then on Task B, and they often overwrite what they learned earlier. This problem is known as catastrophic forgetting.

In this tutorial, we will build a small neural memory agent that can learn a sequence of tasks while still remembering previous ones. The idea is inspired by the brain’s episodic memory system and uses three key ingredients: a differentiable episodic memory, experience replay, and a lightweight meta-learning hook.

Who is this for? You are comfortable with basic PyTorch and want to understand how memory-augmented neural networks and continual learning systems actually work under the hood.

Why Continual Learning is Hard (and Why Memory Helps)

Standard deep learning assumes a fixed dataset: shuffle, train, converge, deploy. But real-world agents — robots, trading bots, adaptive scientific tools — encounter streams of data where the distribution keeps changing over time. In these settings, a naive training loop quickly overwrites old knowledge.
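To see the overwriting effect in isolation, here is a minimal sketch with a single linear layer and two toy tasks that directly contradict each other. The tasks, step counts, and helper names (`fit`, `mse`) are assumptions made for illustration; they are not part of the agent built later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two toy regression tasks on the same inputs (illustrative only):
# task A wants y = 2x, task B wants y = -2x.
x = torch.randn(256, 1)
y_a, y_b = 2.0 * x, -2.0 * x

model = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def fit(y, steps=200):
    """Plain full-batch gradient descent on one task."""
    for _ in range(steps):
        opt.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        opt.step()

def mse(y):
    """Current error of the model on a task, without tracking gradients."""
    with torch.no_grad():
        return F.mse_loss(model(x), y).item()

fit(y_a)
err_a_before = mse(y_a)   # error on A right after learning A
fit(y_b)                  # naive sequential training: no replay, no memory
err_a_after = mse(y_a)    # error on A after learning B

print(f"Task A error before B: {err_a_before:.4f}, after B: {err_a_after:.4f}")
```

The two tasks are deliberately incompatible, so this is a worst case; real task sequences usually overlap more, and the forgetting is correspondingly less extreme.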

Biological brains handle this by combining several types of memory.

Our neural agent will mimic a tiny part of this: an episodic memory module that can store compressed representations of what it has seen, and retrieve them later using similarity-based addressing.

Step 1: Setting Up the Memory Configuration

We start by defining a configuration object that controls the structure of our episodic memory. This is helpful if you later want to experiment with different memory sizes or head counts.

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

@dataclass
class EpisodicMemoryConfig:
    memory_slots: int = 128   # how many rows in the memory matrix
    slot_dim: int = 64        # dimensionality of each memory slot
    read_heads: int = 3       # how many parallel read operations
    write_heads: int = 1      # (kept for extensibility)

Think of this as a little key–value store inside your network: each row is a slot that can contain a vector embedding of some experience (e.g. a task, an input pattern, or a hidden state snapshot).
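To make the key–value picture concrete, the sketch below looks up a toy memory by content: it scores every row against a key with cosine similarity and turns the scores into weights with a softmax. The function name `content_address` and the tiny hand-written memory are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def content_address(key, memory, strength):
    """Softmax over cosine similarity between one key and every memory row."""
    key_norm = F.normalize(key, dim=-1)
    mem_norm = F.normalize(memory, dim=-1)
    similarity = mem_norm @ key_norm          # [slots]
    return F.softmax(strength * similarity, dim=-1)

# A toy memory with 4 slots of dimension 3 (values are arbitrary)
memory = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0],
                       [1.0, 1.0, 0.0]])
key = torch.tensor([0.0, 2.0, 0.0])           # points in the direction of slot 1

soft = content_address(key, memory, strength=1.0)    # spread over several slots
sharp = content_address(key, memory, strength=20.0)  # concentrated on slot 1
value = sharp @ memory                                # the "read": a weighted sum of rows
```

A larger strength sharpens the softmax, so the read behaves more like an exact lookup; a smaller one blends several slots together.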

Step 2: Implementing Differentiable Episodic Memory

Our episodic memory is simply a matrix that the model can address by content: given a query vector (a “key”), it computes cosine similarities to all memory slots and uses a softmax to turn these scores into attention weights. Note that the matrix is registered as a buffer, so it is changed by write operations rather than by the optimizer.

class EpisodicMemory(nn.Module):
    def __init__(self, config: EpisodicMemoryConfig):
        super().__init__()
        self.config = config
        # memory: [memory_slots, slot_dim]
        self.register_buffer("memory",
                             torch.zeros(config.memory_slots, config.slot_dim))
        self.register_buffer("usage",
                             torch.zeros(config.memory_slots))

    def address(self, key, strength):
        # key: [slot_dim]
        key_norm = F.normalize(key, dim=-1)
        mem_norm = F.normalize(self.memory, dim=-1)
        similarity = torch.matmul(key_norm, mem_norm.t())  # [memory_slots]
        return F.softmax(strength * similarity, dim=-1)

    def read(self, keys, strengths):
        reads = []
        for i in range(self.config.read_heads):
            weights = self.address(keys[i], strengths[i])      # [memory_slots]
            read_vec = torch.matmul(weights, self.memory)      # [slot_dim]
            reads.append(read_vec)
        # concatenate all read heads
        return torch.cat(reads, dim=-1)    # [read_heads * slot_dim]

    def write(self, key, vector, erase, strength):
        weights = self.address(key, strength)                  # [memory_slots]
        erase_mat = torch.outer(weights.squeeze(), erase)      # [slots, dim]
        add_mat = torch.outer(weights.squeeze(), vector)       # [slots, dim]
        self.memory = self.memory * (1 - erase_mat) + add_mat
        self.usage = 0.99 * self.usage + weights.squeeze()

The usage vector is a simple heuristic that tracks which memory slots are being touched more often. In a more advanced version, you could use it to implement “least-used” or “age-based” replacement strategies.
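As a rough sketch of the “least-used” idea, the function below overwrites whichever slot currently has the lowest usage score. The name `write_least_used`, the hard (non-differentiable) overwrite, and the example usage values are assumptions for illustration; the class above does not do this.

```python
import torch

def write_least_used(memory, usage, vector, decay=0.99):
    """Overwrite the slot with the lowest usage score and mark it as used.

    Returns new tensors instead of modifying the inputs in place.
    """
    slot = int(torch.argmin(usage))   # index of the least-used slot
    memory = memory.clone()
    usage = decay * usage             # decay all scores, as in the usage update above
    memory[slot] = vector             # hard overwrite of that slot
    usage[slot] = usage[slot] + 1.0   # the slot was just touched
    return memory, usage, slot

memory = torch.zeros(4, 3)
usage = torch.tensor([0.9, 0.1, 0.5, 0.7])    # made-up usage scores
memory, usage, slot = write_least_used(memory, usage, torch.ones(3))
print(slot, usage)
```

Because the slot is chosen with `argmin`, no gradient flows through the choice of slot. That is fine for a replacement policy, but it is different from the soft, attention-weighted write used in the class.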

Step 3: Designing the Memory Controller

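The controller plays the role of a tiny “prefrontal cortex”. It receives the current input, processes it through an LSTM, and then decides what to read from memory (read keys and strengths), what to write and erase (write key, write vector, erase vector, write strength), and what to output from its own state combined with the read vectors.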

class MemoryController(nn.Module):
    def __init__(self, input_dim, hidden_dim, config: EpisodicMemoryConfig):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        total_read_dim = config.read_heads * config.slot_dim

        # Read parameters
        self.read_keys     = nn.Linear(hidden_dim, total_read_dim)
        self.read_strength = nn.Linear(hidden_dim, config.read_heads)

        # Write parameters
        self.write_key     = nn.Linear(hidden_dim, config.slot_dim)
        self.write_vector  = nn.Linear(hidden_dim, config.slot_dim)
        self.erase_vector  = nn.Linear(hidden_dim, config.slot_dim)
        self.write_strength = nn.Linear(hidden_dim, 1)

        # Final output on top of controller state + read vectors
        self.output = nn.Linear(hidden_dim + total_read_dim, input_dim)

    def forward(self, x, memory: EpisodicMemory, hidden=None):
        # x: [feat_dim] → add a time dim; nn.LSTM treats the resulting 2-D
        # input as an unbatched sequence of length 1
        lstm_out, hidden = self.lstm(x.unsqueeze(0), hidden)
        state = lstm_out.squeeze(0)  # [hidden_dim]

        # Compute read parameters
        read_keys = self.read_keys(state).view(memory.config.read_heads, -1)
        read_s    = F.softplus(self.read_strength(state))

        # Compute write parameters
        write_k = self.write_key(state)
        write_v = torch.tanh(self.write_vector(state))
        erase_v = torch.sigmoid(self.erase_vector(state))
        write_s = F.softplus(self.write_strength(state))

        # Interact with memory
        read_vecs = memory.read(read_keys, read_s)
        memory.write(write_k, write_v, erase_v, write_s)

        # Produce output
        combined = torch.cat([state, read_vecs], dim=-1)
        out = self.output(combined)
        return out, hidden

In practice, this controller could sit on top of sensory encoders (images, audio, text). For our demo, we’ll work with synthetic vector inputs so we can focus on the memory mechanics.
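The write parameters produced by the controller feed the erase-then-add rule in the memory's write method. The worked example below applies that rule to a tiny memory with hand-picked values, so the effect of each term is visible. The values are idealised (a sigmoid never reaches exactly 0 or 1) and chosen purely for illustration.

```python
import torch

# One tiny memory: 3 slots, 2 dimensions, every entry starts at 1.0
memory = torch.ones(3, 2)

weights = torch.tensor([1.0, 0.0, 0.0])   # write attention: all on slot 0
erase   = torch.tensor([1.0, 0.0])        # erase gate: wipe dim 0, keep dim 1
vector  = torch.tensor([0.5, -0.5])       # content to add (tanh range)

erase_mat = torch.outer(weights, erase)   # [slots, dim]
add_mat   = torch.outer(weights, vector)  # [slots, dim]
new_memory = memory * (1 - erase_mat) + add_mat

print(new_memory)
```

Slot 0 ends up as `[0.5, 0.5]`: its first dimension was erased and then set to 0.5, while its second dimension kept its old value of 1.0 and had -0.5 added. Slots 1 and 2 are untouched because their write weight is zero.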

Step 4: Experience Replay for Stability

If we simply train on tasks one after another, the controller will still forget. To mitigate this, we maintain a small replay buffer of past examples and periodically mix them into the current training batch.

class ReplayBuffer:
    def __init__(self, capacity=5000):
        self.buffer   = deque(maxlen=capacity)
        self.priority = deque(maxlen=capacity)

    def push(self, experience, p=1.0):
        self.buffer.append(experience)
        self.priority.append(p)

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        probs = np.array(self.priority, dtype=np.float32)
        probs = probs / probs.sum()
        idx   = np.random.choice(len(self.buffer), batch_size, p=probs)
        return [self.buffer[i] for i in idx]

Here we store a priority per sample (e.g. proportional to the loss). Harder or more surprising examples are replayed slightly more often, similar to prioritized experience replay in reinforcement learning.
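A quick way to check how strongly the priorities skew sampling is to draw many indices and count them. The sketch below does this with three made-up priorities and a seeded NumPy generator; it is independent of the `ReplayBuffer` class and only mirrors its normalise-then-sample logic.

```python
import numpy as np

rng = np.random.default_rng(0)            # seeded so the run is repeatable

priorities = np.array([0.1, 0.1, 5.0])    # made-up losses for three stored examples
probs = priorities / priorities.sum()     # normalise to a probability vector

# Sample with replacement, as np.random.choice does by default
draws = rng.choice(len(priorities), size=1000, p=probs)
counts = np.bincount(draws, minlength=len(priorities))
print(counts.tolist())
```

With priorities this far apart, the high-loss example dominates the draws. Note also that in the buffer above a priority is set once at push time and never refreshed, so an example that was hard early on keeps its high priority even after the model has learned it.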

Optional: A Simple Meta-Learning Hook

To keep the post focused, we only sketch a minimal meta-learning hook. The idea is a fast adaptation step that updates a copy of the parameters on a small support set.

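class MetaLearner(nn.Module):
    def __init__(self, model, memory):
        super().__init__()
        self.model = model      # e.g. a MemoryController
        self.memory = memory    # the EpisodicMemory it reads from and writes to

    def adapt(self, x, y, steps=3, lr=0.01):
        names  = [n for n, _ in self.model.named_parameters()]
        params = list(self.model.parameters())
        fast_params = {n: p.clone() for n, p in zip(names, params)}
        for _ in range(steps):
            # NOTE: the forward pass still uses the base parameters; a full
            # implementation would run it with fast_params instead.
            y_pred, _ = self.model(x, self.memory)
            loss = F.mse_loss(y_pred, y)
            # allow_unused: the write heads may not influence this loss
            grads = torch.autograd.grad(loss, params, create_graph=True,
                                        allow_unused=True)
            fast_params = {
                n: p if g is None else p - lr * g
                for (n, p), g in zip(fast_params.items(), grads)
            }
        return fast_params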

You can later extend this into a full MAML-style loop where the outer optimizer updates the base parameters based on post-adaptation performance.
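To show what such an outer loop looks like, here is a minimal sketch of one MAML-style meta-update on a toy linear model. A single weight tensor `w` stands in for the controller, and `inner_adapt`, the learning rates, and the synthetic support and query sets are all assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Base parameter of a toy linear model y = x @ w (stand-in for the controller)
w = torch.zeros(3, 1, requires_grad=True)
outer_opt = torch.optim.SGD([w], lr=0.1)

def inner_adapt(w, x, y, steps=3, lr=0.05):
    """Take a few gradient steps on the support set, keeping the graph."""
    fast_w = w
    for _ in range(steps):
        loss = F.mse_loss(x @ fast_w, y)
        (grad,) = torch.autograd.grad(loss, fast_w, create_graph=True)
        fast_w = fast_w - lr * grad       # differentiable update
    return fast_w

# One made-up task with a support set and a query set
true_w = torch.tensor([[1.0], [-2.0], [0.5]])
x_support, x_query = torch.randn(8, 3), torch.randn(8, 3)
y_support, y_query = x_support @ true_w, x_query @ true_w

fast_w = inner_adapt(w, x_support, y_support)
query_loss = F.mse_loss(x_query @ fast_w, y_query)   # post-adaptation performance

outer_opt.zero_grad()
query_loss.backward()      # gradient flows through the inner steps back to w
outer_opt.step()
```

The important detail is `create_graph=True`: it keeps the inner updates differentiable, so the outer gradient accounts for how the base parameters shape the adapted ones. A real meta-training loop would average the query loss over many tasks before each outer step.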

Step 5: Wrapping Everything into a Continual Learning Agent

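Now we build a small ContinualAgent class that holds the memory configuration, the episodic memory, the controller, a replay buffer, and an Adam optimizer over the controller’s parameters.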

class ContinualAgent:
    def __init__(self, input_dim=64, hidden_dim=128):
        self.config   = EpisodicMemoryConfig()
        self.memory   = EpisodicMemory(self.config)
        self.controller = MemoryController(input_dim, hidden_dim, self.config)
        self.replay   = ReplayBuffer()
        self.optimizer = torch.optim.Adam(self.controller.parameters(),
                                          lr=1e-3)

    def train_step(self, x, y, replay_ratio=0.3):
        self.optimizer.zero_grad()

        # 1) Current example
        pred, _ = self.controller(x, self.memory)
        loss = F.mse_loss(pred, y)
        self.replay.push((x.detach(), y.detach()), p=loss.item() + 1e-6)

        # 2) Mix in a few replay samples
        if len(self.replay) >= 16:
            replay_batch = self.replay.sample(batch_size=8)
            for rx, ry in replay_batch:
                rpred, _ = self.controller(rx, self.memory)
                rloss = F.mse_loss(rpred, ry)
                loss = loss + replay_ratio * rloss

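        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 1.0)
        self.optimizer.step()
        # The written memory still references this step's autograd graph.
        # Detach it so the next step does not backpropagate through a
        # graph that has already been freed.
        self.memory.memory = self.memory.memory.detach()
        self.memory.usage  = self.memory.usage.detach()
        return loss.item()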

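    def evaluate(self, data):
        self.controller.eval()
        # The forward pass also writes to memory, so snapshot it here and
        # restore it afterwards to keep evaluation free of side effects
        saved_memory = self.memory.memory.clone()
        saved_usage  = self.memory.usage.clone()
        total = 0.0
        with torch.no_grad():
            for x, y in data:
                pred, _ = self.controller(x, self.memory)
                total += F.mse_loss(pred, y).item()
        self.memory.memory = saved_memory
        self.memory.usage  = saved_usage
        self.controller.train()
        return total / len(data)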

The key idea: during training on Task t, we keep replaying a few samples from the buffer, which still holds examples from tasks 1 ... t−1 alongside those of the current task. The memory bank simultaneously accumulates representations written across tasks, rather than only for the most recent one.

Step 6: Creating Synthetic Tasks for Continual Learning

To keep things lightweight, we will generate a few synthetic regression tasks. Each task maps a random input vector to a transformed version using a different nonlinearity.

def make_task(task_id: int, n_samples: int = 100):
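    # Local generator: reproducible per task without resetting the global RNG
    gen = torch.Generator().manual_seed(task_id)
    x = torch.randn(n_samples, 64, generator=gen)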

    if task_id == 0:
        # Task 1: sine of the mean
        y = torch.sin(x.mean(dim=1, keepdim=True).expand(-1, 64))
    elif task_id == 1:
        # Task 2: scaled cosine
        y = 0.5 * torch.cos(x.mean(dim=1, keepdim=True).expand(-1, 64))
    else:
        # Task 3+: tanh with shifting
        y = torch.tanh(x + task_id * 0.3)

    return [(x[i], y[i]) for i in range(n_samples)]

This setup gives us tasks that are related but not identical, which is exactly the regime where continual learning becomes interesting.

Step 7: Running the Continual Learning Demo

Let us put it all together and watch how the test error evolves as the agent encounters new tasks.

def run_demo(num_tasks: int = 4):
    print(" Neural Memory Agent — Continual Learning Demo")
    print("=" * 60)

    agent = ContinualAgent()
    history = {"task": [], "error": []}

    for task_id in range(num_tasks):
        print(f"\n Training on Task {task_id + 1}/{num_tasks}")
        # Draw one dataset per task and split it so that evaluation rows are
        # never trained on (same task_id and size reproduce the same rows)
        train_data = make_task(task_id, n_samples=110)[:80]

        for epoch in range(20):
            total_loss = 0.0
            for x, y in train_data:
                total_loss += agent.train_step(x, y)
            if epoch % 5 == 0:
                print(f"  Epoch {epoch:02d} | Loss = {total_loss/len(train_data):.4f}")

        # Evaluate on held-out rows of all tasks seen so far
        print("\n   Evaluation on all seen tasks:")
        for eval_id in range(task_id + 1):
            eval_data = make_task(eval_id, n_samples=110)[80:]
            err = agent.evaluate(eval_data)
            print(f"    Task {eval_id + 1}: MSE = {err:.4f}")
            if eval_id == task_id:
                history["task"].append(eval_id + 1)
                history["error"].append(err)

    return agent, history

Visualizing Memory and Performance

After training, it is useful to look both at the final memory matrix and the evolution of test error across tasks.

def plot_results(agent, history):
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Left: memory heatmap
    ax = axes[0]
    mem = agent.memory.memory.detach().numpy()
    im = ax.imshow(mem, aspect="auto", cmap="viridis")
    ax.set_title("Episodic Memory Bank")
    ax.set_xlabel("Memory Dimension")
    ax.set_ylabel("Memory Slot")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Right: error over tasks
    ax = axes[1]
    ax.plot(history["task"], history["error"], marker="o", linewidth=2)
    ax.set_title("Test Error on Newest Task")
    ax.set_xlabel("Task number")
    ax.set_ylabel("MSE")
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig("neural_memory_agent_results.png", dpi=150)
    plt.show()
    print(" Saved plot to 'neural_memory_agent_results.png'")

if __name__ == "__main__":
    agent, hist = run_demo(num_tasks=4)
    plot_results(agent, hist)

Interpreting the Results

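When you run the script, you should see the average training loss printed every five epochs, the MSE on every task seen so far after each task finishes, and a two-panel figure (memory heatmap and newest-task test error) saved as neural_memory_agent_results.png.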

This is not a production-ready continual learning system, but it illustrates an important point: giving your network the ability to store and retrieve compressed experiences, plus a mechanism to revisit past data, is a practical way to reduce catastrophic forgetting. Compare the per-task MSE values across stages to see how much is retained in your own run.
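One way to check this on your own runs is to record the per-task MSE values that the demo prints after each stage and compute how much each earlier task degraded. The helper below is a simple sketch of such a measure. The name `forgetting` and the numbers in `errors` are placeholders for illustration, not results produced by this script.

```python
# errors[i][j] = MSE on task j measured after training on task i (j <= i).
# These numbers are placeholders for illustration, not measured results.
errors = [
    [0.10],
    [0.14, 0.08],
    [0.19, 0.11, 0.05],
]

def forgetting(errors):
    """For each earlier task: final error minus its error right after it was learned."""
    final = errors[-1]
    return [final[j] - errors[j][j] for j in range(len(final) - 1)]

scores = forgetting(errors)
print([round(s, 2) for s in scores])
```

Values near zero mean a task was retained; large positive values mean it was overwritten. Running the demo once with replay and once without (for example by setting `replay_ratio=0`) and comparing these scores is a more direct test than looking at the newest-task error alone.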

Where to Go Next

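Here are a few directions you can explore on top of this tutorial, all hinted at above: use the usage vector for least-used or age-based slot replacement, extend the meta-learning hook into a full MAML-style loop, and put the controller on top of real sensory encoders (images, audio, text) instead of synthetic vectors.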

If you work in science or engineering, you can imagine plugging this kind of agent into tools that need to adapt over days or weeks, for example, a lab analysis system that gradually learns from each new experiment while still remembering the patterns it saw months ago.

I’ll be experimenting with these ideas inside DataLens.Tools for building adaptive, AI-assisted analysis workflows for neuroscience and biology data.

Final Thoughts

Continual learning sits at the intersection of deep learning, neuroscience, and real-world deployment. By combining differentiable memory, episodic replay, and meta-learning, we can move a little closer to agents that behave less like static models and more like learning systems.

Feel free to adapt this code, plug in your own tasks, and extend the architecture. If you build something interesting with it, I would love to hear about it.