Dependency Injection for Artificial Intelligence (DI4AI)

Gideon Dresdner | March 14, 2025

Introduction

“Dependency injection” is a ubiquitous $25 term for a 25¢ concept. It has also been known as the “Inversion of Control Principle” (another terrible term) or, and this is my favorite, the Hollywood Principle: “Don’t call us, we’ll call you.” At its core, dependency injection aims to explicitly define a routine’s dependencies through its arguments. In my opinion, this often results in cleaner code. More objectively speaking, it results in code that is more modular and easier to unit test.
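To make the definition concrete, here is a minimal sketch in plain Python. The function names and the greeting logic are hypothetical, chosen only for illustration. The first version reaches out for its dependency (the system clock) on its own. The second declares that dependency as an argument, so a test can inject a fixed clock instead of patching the datetime module.

```python
from datetime import datetime
from typing import Callable


def greeting_hardcoded() -> str:
    # Hidden dependency: the routine reads the system clock itself.
    hour = datetime.now().hour
    return "Good morning" if hour < 12 else "Good afternoon"


def greeting(now: Callable[[], datetime] = datetime.now) -> str:
    # Explicit dependency: the caller decides where "now" comes from.
    hour = now().hour
    return "Good morning" if hour < 12 else "Good afternoon"


def fixed_clock() -> datetime:
    # A deterministic clock for tests: always 09:00 on March 14, 2025.
    return datetime(2025, 3, 14, 9, 0)


print(greeting(fixed_clock))  # Good morning
```

The default argument keeps production call sites unchanged; it is a pragmatic choice, not a requirement of dependency injection.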

Dependency injection (DI) is often associated with sophisticated Java frameworks such as Spring and Guice, which aim to reduce boilerplate by declaratively wiring up dependencies. But frameworks like these are not required to implement a DI-inspired codebase. In fact, it is a particularly exciting moment to be writing code using DI because we can easily generate boilerplate using AI-assisted programming. Furthermore, the Python community is adopting modern tooling such as mypy and dataclasses, which work well with DI. Overall, it has never been easier to adopt these principles in a lightweight manner.

Even as AI/ML adoption surges, the software engineers developing these systems often struggle to adopt basic software engineering techniques. A common misconception is that AI/ML suffers from leaky abstractions and larger, more complex configuration: in short, that AI/ML is more complicated than run-of-the-mill software [A Recipe for Training Neural Networks, gin-config README¹]. I don’t believe this. Instead, social and historical reasons underlie this neglect of software craftsmanship.

To date, AI/ML has maintained its academic roots. Many AI/ML practitioners are expected to have PhDs, and the field’s luminaries are often (former) professors. Publications, not software artifacts, are still the de facto currency of achievement. This association with academia leads engineers to see themselves as scientists, prioritizing ideas and equations over software architecture and implementation. We should embrace the fact that an interdisciplinary, engineering ethos has permeated AI/ML since its inception with Wiener and others. This doesn’t mean doing less research; it just means being less dismissive of our engineering-oriented colleagues.

Remember: This is just Python. There are no complex libraries to install. There is no black-box framework running your code. This is simply a way of organizing your thoughts in code.

Example: Trainer class

Let’s look at a concrete example of code written in a DI style. Suppose we are implementing a training loop for some AI/ML model. The Trainer class implements this routine in its train_one_epoch method. In essence, the training algorithm doesn’t depend on the model architecture, on the incidental way we’ve chosen to keep track of metrics (e.g. using wandb), or even on the dataset, as long as each satisfies mild type constraints.

Thus, we pass each of those dependencies as arguments to the constructor. Constructing them is not our concern right now.

This has a positive effect on unit testing:

  1. It’s easy to implement NullWandbLogger (see the null object pattern), which can be used to unit test train_one_epoch without touching wandb. This way we don’t have to resort to pytest monkeypatching and other metaprogramming tricks, or worry about external calls to the wandb API.
  2. We can trivially create dummy training data by passing in a sequence, e.g. a list, that we hard-code in a unit test. This avoids a lot of incidental complexity, such as managing dummy data files on disk, and streamlines unit testing.
  3. We can pass in a simple, small neural network architecture to test Trainer independently of whatever complex neural network architecture we are using.

from typing import Iterable, Protocol

import torch


class Logger(Protocol):
    # Both loggers below satisfy this structural type, so mypy accepts either.
    def log(self, metrics: dict) -> None: ...


class WandbLogger:
    def log(self, metrics: dict) -> None:
        ...


class NullWandbLogger:
    def log(self, metrics: dict) -> None:
        # Null object: same interface as WandbLogger, but does nothing
        pass


class SomeArchitecture:
    def __init__(self, n_channels: int, height: int, width: int):
        self.n_channels = n_channels
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class Trainer:
    def __init__(
        self,
        model: SomeArchitecture,
        wandb: Logger,
        train_data: Iterable[torch.Tensor],
    ) -> None:
        # Every dependency is passed in; Trainer constructs none of them.
        self.model = model
        self.wandb = wandb
        self.train_data = train_data

    def train_one_epoch(self) -> None:
        for data in self.train_data:
            self.model.forward(data)
            metrics: dict = {
                # compute some metrics
            }
            self.wandb.log(metrics)
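To illustrate the testing benefits listed above, here is a torch-free sketch of a unit test. It assumes a Trainer with the same constructor as above but drops the torch types so it runs with the standard library alone. `RecordingLogger` and `DoubleModel` are hypothetical test doubles, and the `"n_items"` metric is made up for the example. A recording logger is a close cousin of NullWandbLogger: it never touches wandb either, but it remembers what was logged so the test can assert on it.

```python
from typing import Iterable, Protocol


class Logger(Protocol):
    # Structural type: anything with a matching log() method qualifies.
    def log(self, metrics: dict) -> None: ...


class Trainer:
    # Same constructor shape as the Trainer above, minus the torch types.
    def __init__(self, model, wandb: Logger, train_data: Iterable) -> None:
        self.model = model
        self.wandb = wandb
        self.train_data = train_data

    def train_one_epoch(self) -> None:
        for data in self.train_data:
            output = self.model.forward(data)
            # Hypothetical metric: the number of values in the output.
            self.wandb.log({"n_items": len(output)})


class RecordingLogger:
    """Test double: keeps metrics in memory instead of sending them to wandb."""

    def __init__(self) -> None:
        self.logged: list[dict] = []

    def log(self, metrics: dict) -> None:
        self.logged.append(metrics)


class DoubleModel:
    """A tiny stand-in architecture that doubles every value."""

    def forward(self, x: list[float]) -> list[float]:
        return [2 * v for v in x]


def test_train_one_epoch_logs_once_per_batch() -> None:
    logger = RecordingLogger()
    train_data = [[1.0, 2.0], [3.0]]  # hard-coded, in-memory "dataset"
    trainer = Trainer(model=DoubleModel(), wandb=logger, train_data=train_data)
    trainer.train_one_epoch()
    assert logger.logged == [{"n_items": 2}, {"n_items": 1}]


test_train_one_epoch_logs_once_per_batch()
```

No files, network calls, or patching are involved: every collaborator is an ordinary object handed to the constructor.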

Building configurable objects

Our objects often have many dependencies. Some are essential and should be up to the practitioner to specify, e.g. learning rate, dataset path, etc. Others are an unfortunate consequence of incidental complexity, e.g. image size. To clearly sort dependencies into these two categories, we can use a simple configurable-object pattern that leverages dependency injection. While it can be difficult to tell essential and incidental elements apart, this pattern makes it easy to change a dependency’s status: simply move it between the config’s fields and the arguments of its build method. Rather than sweeping this distinction under the rug, this pattern forces the programmer to clarify their intentions.

In the example below, the practitioner configures the dataset’s resolution but not the model’s input_shape, which is fully determined by that resolution. Rather than making input_shape externally configurable, we pass it to ModelConfig.build as an argument. This explicitly encodes the dependency between the dataset and the model while ensuring that the practitioner cannot override it. In unit tests, however, the model can still be constructed independently of any dataset, which allows for flexible and convenient testing.

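from __future__ import annotations

import dataclasses
from typing import Literal

import torch


@dataclasses.dataclass
class DataConfig:
    path: str
    batch_size: int
    resolution: float

    def build(self) -> tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
        # construct a dataset and a dataloader using this class's attributes
        ...


@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    activation: Literal["relu", "gelu"]

    def build(self, input_shape: tuple) -> SomeModel:
        # input_shape is a build argument rather than a field: it is derived
        # from the data, so the practitioner cannot override it.
        return SomeModel(
            input_shape=input_shape,
            hidden_dim=self.hidden_dim,
            activation=self.activation,
        )


@dataclasses.dataclass
class TrainerConfig:
    train_data: DataConfig
    model: ModelConfig
    wandb: WandbConfig  # defined analogously; its build() returns a WandbLogger

    def build(self) -> Trainer:
        dataset, dataloader = self.train_data.build()
        # assumes our dataset class exposes an input_shape attribute
        model = self.model.build(dataset.input_shape)
        wandb = self.wandb.build()
        return Trainer(model=model, train_data=dataloader, wandb=wandb)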
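The same pattern can be exercised without torch, which makes its two properties easy to check in isolation. The sketch below rests on loudly stated assumptions: `ToyDataset`, `ToyModel`, and `ExperimentConfig` are hypothetical stand-ins, resolution is an integer here, a dataset of resolution `r` is assumed to yield inputs of shape `(r, r)`, and the dataloader is omitted. The top-level config derives `input_shape` from the data, while a unit test can call `ModelConfig.build` directly with any shape it likes.

```python
import dataclasses
from typing import Literal


@dataclasses.dataclass
class ToyDataset:
    resolution: int

    @property
    def input_shape(self) -> tuple[int, int]:
        # Hypothetical rule: inputs are square images of side `resolution`.
        return (self.resolution, self.resolution)


@dataclasses.dataclass
class ToyModel:
    input_shape: tuple[int, int]
    hidden_dim: int
    activation: str


@dataclasses.dataclass
class DataConfig:
    resolution: int  # essential: chosen by the practitioner

    def build(self) -> ToyDataset:
        return ToyDataset(self.resolution)


@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    activation: Literal["relu", "gelu"]

    def build(self, input_shape: tuple[int, int]) -> ToyModel:
        # input_shape is incidental: supplied by the caller, never configured.
        return ToyModel(input_shape, self.hidden_dim, self.activation)


@dataclasses.dataclass
class ExperimentConfig:
    data: DataConfig
    model: ModelConfig

    def build(self) -> tuple[ToyDataset, ToyModel]:
        dataset = self.data.build()
        model = self.model.build(dataset.input_shape)  # wiring happens here
        return dataset, model


config = ExperimentConfig(DataConfig(resolution=64), ModelConfig(128, "gelu"))
_, model = config.build()
print(model.input_shape)  # (64, 64)

# In a unit test, build the model with no dataset at all.
tiny = ModelConfig(hidden_dim=4, activation="relu").build(input_shape=(2, 2))
```

Moving `input_shape` from a `build` argument into `ModelConfig`'s fields would be all it takes to make it practitioner-configurable, which is the "change a dependency's status" move described above.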

By separating the construction of complex objects from their implementation, we see several positive outcomes:

Anti-Examples

These are examples of what we hope to avoid with this pattern. We call these examples “anti-examples.”

Hard-coded dependencies

Early in our programming journeys, we are taught to avoid hard-coding whenever possible. But hard-coding can take subtle, pernicious forms. For example, consider a class whose constructor hard-codes how data is loaded instead of accepting an iterable dataset as an argument. A good indicator of whether a design is “natural” is how easy it is to test. In the example below, testing requires creating a toy dataset on disk, whereas a more flexible implementation would allow for an ephemeral, in-memory test dataset.

class Trainer:
    def __init__(
        self, 
        n_channels=2, 
        image_height=32, 
        image_width=32, 
        dataset_path="dataset.csv"
    ):
        self.model = SomeArchitecture(n_channels, image_height, image_width)
        self.wandb = WandbLogger()  # Hardcoded logger
        self.train_data = load_data_from_disk(dataset_path)  # Hardcoded data source

    def train_one_epoch(self):
        for data in self.train_data:
            self.model.forward(data)
            metrics = {  # compute some metrics
            }
            self.wandb.log(metrics)
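The difference in testing effort shows up even in a tiny function. In this sketch, the function names and the one-number-per-line file format are hypothetical. The first version hard-codes where its data comes from, so its test must write a temporary file; the second accepts any iterable of numbers, so a plain list suffices.

```python
import os
import tempfile
from typing import Iterable


def mean_from_disk(path: str) -> float:
    # Hard-coded data source: reads one number per line from a file.
    with open(path) as f:
        values = [float(line) for line in f if line.strip()]
    return sum(values) / len(values)


def mean(values: Iterable[float]) -> float:
    # Injected data source: any iterable of numbers will do.
    items = list(values)
    return sum(items) / len(items)


# Testing the hard-coded version needs a real file on disk...
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("1\n2\n3\n")
    path = f.name
try:
    assert mean_from_disk(path) == 2.0
finally:
    os.remove(path)

# ...while the injected version is tested with an in-memory list.
assert mean([1.0, 2.0, 3.0]) == 2.0
```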

Global state

Another common early lesson in programming is to avoid global state. Once again, testability serves as a useful yardstick for code quality. Global state makes code harder to test and reason about: you often need to jump around the codebase to track down where things are defined. It also clutters the mental model with irrelevant details. For example, when using a model in a training loop, we shouldn’t need to know anything beyond the fact that it has a forward method. In the example below, Trainer does not declare what it requires, which makes it difficult to test and maintain.

wandb_logger = WandbLogger()
model = SomeArchitecture(3, 32, 32)
train_data = load_data_from_disk("dataset.csv")

class Trainer:
    def train_one_epoch(self):
        for data in train_data:
            model.forward(data)
            metrics = {}  # compute some metrics
            wandb_logger.log(metrics)
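One concrete way global state hurts testability is that tests stop being independent: state left behind by one test leaks into the next. In this hypothetical sketch, a module-level list stands in for a global logger's buffer, and the injected version lets each caller own its buffer.

```python
# Global state: every call appends to the same module-level buffer.
LOGGED: list[dict] = []


def log_metrics_global(metrics: dict) -> None:
    LOGGED.append(metrics)


def log_metrics(metrics: dict, sink: list[dict]) -> None:
    # Injected dependency: the caller owns the buffer.
    sink.append(metrics)


# Two "tests" using the global version interfere with each other.
log_metrics_global({"loss": 1.0})
log_metrics_global({"loss": 0.5})
assert len(LOGGED) == 2  # the second test sees the first test's leftovers

# With injection, each test starts from a clean, local buffer.
first: list[dict] = []
second: list[dict] = []
log_metrics({"loss": 1.0}, first)
log_metrics({"loss": 0.5}, second)
assert first == [{"loss": 1.0}] and second == [{"loss": 0.5}]
```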

Excessive Factory Methods (Over-Engineering)

Abstraction often feels like good hygiene. But too much indirection can actually obscure what’s going on, making code harder to understand, test, and reuse. Consider the following Trainer class — it looks clean at first glance, but it hides several problematic decisions behind unnecessary wrappers and hard-coded logic.

class Trainer:
    def __init__(self):
        self.model = self.create_model()
        self.wandb = self.create_logger()
        self.train_data = self.load_data()

    def create_model(self):
        return SomeArchitecture(3, 32, 32)

    def create_logger(self):
        return WandbLogger()

    def load_data(self):
        return load_data_from_disk("dataset.csv")
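One consequence of this design appears as soon as we try to test it: because Trainer builds its own dependencies, the only way to swap one out is to subclass and override the factory method (or patch it). The sketch below is a torch-free, hypothetical reduction of the example. The `WandbLogger` and `NullWandbLogger` here are placeholder reimplementations, and the `RuntimeError` merely stands in for a real network call.

```python
class WandbLogger:
    def log(self, metrics: dict) -> None:
        raise RuntimeError("would call the real wandb API")


class NullWandbLogger:
    def log(self, metrics: dict) -> None:
        pass  # Null object: does nothing


class FactoryTrainer:
    def __init__(self) -> None:
        self.wandb = self.create_logger()

    def create_logger(self):
        return WandbLogger()


# To test without wandb, we must subclass just to override one factory...
class TestableFactoryTrainer(FactoryTrainer):
    def create_logger(self):
        return NullWandbLogger()


class InjectedTrainer:
    def __init__(self, wandb) -> None:
        self.wandb = wandb


# ...whereas with injection the test simply passes the object it wants.
trainer_a = TestableFactoryTrainer()
trainer_b = InjectedTrainer(wandb=NullWandbLogger())
trainer_a.wandb.log({"loss": 0.1})
trainer_b.wandb.log({"loss": 0.1})
```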

Composition over inheritance

A common anti-pattern is to group helper functions into a superclass:

"""
Anti-pattern: using a superclass to group helper functions together and
forcing all users of that helper function to inherit from it.
"""

import torch


class BaseTrainer:
    def _compute_metrics(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> dict:
        """Helper method to compute training metrics"""
        loss = ((predictions - targets) ** 2).mean().item()
        return {"loss": loss}

class ImageTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)

class TextTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)

This suffers from a number of shortcomings:

A better alternative is to use composable functions:

from typing import Callable

import torch


class Trainer:
    def __init__(
        self,
        model,
        train_data,
        # This is quick & dirty; a cleaner approach would add more explicit
        # typing and dedicated classes for metrics.
        compute_metrics: Callable[[torch.Tensor, torch.Tensor], dict],
    ):
        self.model = model
        self.train_data = train_data
        self.compute_metrics = compute_metrics  # Injected dependency

    def train_one_epoch(self):
        for data, target in self.train_data:
            predictions = self.model.forward(data)
            metrics = self.compute_metrics(predictions, target)  # Explicit function call
            print("Metrics:", metrics)

# Different metric functions for different tasks
def mse_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = ((predictions - targets) ** 2).mean().item()
    return {"loss": loss}

def cross_entropy_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = torch.nn.functional.cross_entropy(predictions, targets).item()
    return {"cross_entropy_loss": loss}
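Because compute_metrics is just a callable, anything with a matching signature can be injected: a plain function, a lambda in a test, or a `functools.partial` that fixes extra parameters. The sketch below uses plain Python lists instead of tensors so it runs without torch. Its `Trainer` mirrors the one above in reduced form (it returns metrics rather than printing them), and `scaled_mae_metrics` with its `scale` parameter and `IdentityModel` are hypothetical examples.

```python
import functools
from typing import Callable, Iterable, Sequence

# A metric function maps (predictions, targets) to a dict of named values.
MetricFn = Callable[[Sequence[float], Sequence[float]], dict]


class Trainer:
    # Reduced, torch-free version of the Trainer above.
    def __init__(self, model, train_data: Iterable, compute_metrics: MetricFn):
        self.model = model
        self.train_data = train_data
        self.compute_metrics = compute_metrics  # injected dependency

    def train_one_epoch(self) -> list[dict]:
        # Returns the metrics instead of printing them, which eases testing.
        history = []
        for data, target in self.train_data:
            predictions = self.model.forward(data)
            history.append(self.compute_metrics(predictions, target))
        return history


class IdentityModel:
    """Trivial stand-in model: returns its input unchanged."""

    def forward(self, x: Sequence[float]) -> Sequence[float]:
        return x


def scaled_mae_metrics(
    predictions: Sequence[float], targets: Sequence[float], scale: float = 1.0
) -> dict:
    # Hypothetical metric: mean absolute error multiplied by `scale`.
    errors = [abs(p - t) for p, t in zip(predictions, targets)]
    return {"mae": scale * sum(errors) / len(errors)}


data = [([1.0, 2.0], [1.0, 4.0])]  # one (inputs, targets) batch

# functools.partial fixes `scale`, leaving a two-argument callable to inject.
trainer = Trainer(IdentityModel(), data, functools.partial(scaled_mae_metrics, scale=10.0))
print(trainer.train_one_epoch())  # [{'mae': 10.0}]

# In a unit test, a lambda is enough to check the wiring.
stub = Trainer(IdentityModel(), data, lambda p, t: {"n": len(p)})
print(stub.train_one_epoch())  # [{'n': 2}]
```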

This implementation is:

Shoutouts

Thanks to Jeremy McGibbon for advocating for the use of dataclasses as configuration classes but more generally for championing high-quality software development within my former team at Ai2. Jordan Lewis provided invaluable feedback and helped connect these ideas to a broader software engineering perspective. I also appreciate my fellow Brightband colleagues—Ryan Keisler, Hans Mohrmann, and Taylor Mandelbaum—as well as Gabe Gaster and Kevin Lin for their thoughtful feedback on early drafts. Alex Merose’s enthusiasm has been incredibly encouraging, inspiring me to share this more widely. I’m excited to see how we can continue refining and streamlining this approach. Finally, a big thanks to Ivan Savov [minireference] who generously provided detailed feedback.

I’d love to hear your thoughts! Feel free to reach out via email at "me" at this domain, on Twitter @gideoknite, or connect with me on LinkedIn. Happy to chat!

You might also enjoy Recurse Center — my favorite programming community, and a favorite of others like Julia Evans too.


  1. “Gin is particularly well suited for machine learning experiments (e.g. using TensorFlow), which tend to have many parameters, often nested in complex ways.”