Dependency Injection for Artificial Intelligence (DI4AI)

Gideon Dresdner | March 14, 2025

Introduction

“Dependency injection” is a ubiquitous $25 term for a 25¢ concept. It has also been known as the “Inversion of Control Principle” (another terrible term) or, and this is my favorite, the Hollywood Principle: “Don’t call us, we’ll call you.” At its core, dependency injection aims to explicitly define a routine’s dependencies through its arguments. In my opinion, this often results in cleaner code. More objectively speaking, it results in code that is more modular and easier to unit test.

Dependency injection (DI) is often associated with sophisticated frameworks such as Spring and Guice, which aim to reduce boilerplate by declaratively wiring up dependencies. But frameworks like these are not required to implement a DI-inspired codebase. In fact, it is a particularly exciting moment to be writing code in a DI style because we can easily generate the boilerplate using AI-assisted programming. Furthermore, the Python community is adopting modern tooling such as mypy and dataclasses, which work well with DI. Overall, it has never been easier to adopt these principles in a lightweight manner.

Despite skyrocketing AI/ML hype and adoption, the field still struggles to adopt basic software engineering techniques. A common explanation is that AI/ML suffers from leaky abstractions and larger, more complex configuration; in short, that AI/ML is more complicated than run-of-the-mill software [A Recipe for Training Neural Networks, gin-config README1]. I don’t believe this. Instead, social and historical reasons underlie this neglect of software craftsmanship.

To date, AI/ML has maintained its academic roots. Many AI/ML practitioners are expected to have PhDs, and the field’s luminaries are often (former) professors. Publications, not software artifacts, are still the de facto currency of achievement. This association with academia leads engineers to believe that they are scientists, prioritizing ideas and equations over software architecture and implementation. We should embrace the fact that, from its inception with Wiener and others, an interdisciplinary, engineering ethos has permeated AI/ML. This doesn’t mean doing less research; it just means being less dismissive of our engineering-oriented colleagues.

Remember: This is just Python. There are no complex libraries to install. There is no black-box framework running your code. This is simply a way of organizing your thoughts in code.

Example: Trainer class

Let’s introduce a concrete example of what code written in a DI-style looks like. Suppose we are implementing a training loop for some AI/ML model. The Trainer class implements this routine in the train_one_epoch method. In essence, the training algorithm doesn’t depend on the AI/ML model architecture, the incidental way that we’ve chosen to keep track of various metrics, e.g. using WandB, or even the dataset, as long as it satisfies mild type constraints.

Thus, we pass each of those dependencies as arguments to the constructor. Constructing them is not our concern right now.

This has a positive effect on unit testing:

  1. It’s easy to implement NullWandbLogger (see null object pattern), which can be used to unit test train_one_epoch without touching wandb (a sketch of such a test follows the code below). This way we don’t have to resort to pytest patching and other metaprogramming tricks, or worry about external calls to the wandb API, etc.
  2. We can trivially create dummy train data by just passing in a sequence, e.g. a list, that we hard-code in a unit test. This avoids a lot of incidental complexity — dealing with dummy data files on disk, etc. — streamlining unit testing.
  3. We can pass in a simple, small neural network architecture to test Trainer independently of whatever complex neural network architecture we are using.
from typing import Sequence

import torch


class WandbLogger:
    def log(self, metrics: dict) -> None:
        ...


class NullWandbLogger:
    def log(self, metrics: dict) -> None:
        # Do nothing
        pass


class SomeArchitecture:
    def __init__(self, n_channels: int, height: int, width: int):
        self.n_channels = n_channels
        self.height = height
        self.width = width
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class Trainer:
    def __init__(
        self, 
        model: SomeArchitecture, 
        wandb: WandbLogger, 
        train_data: Sequence[torch.Tensor]
    ) -> None:
        self.model = model
        self.wandb = wandb
        self.train_data = train_data

    def train_one_epoch(self) -> None:
        for data in self.train_data:
            self.model.forward(data)
            metrics = {
                # compute some metrics
            }
            self.wandb.log(metrics)
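
To make the unit-testing benefits above concrete, here is a minimal sketch of such a test. DummyModel and test_train_one_epoch_runs are illustrative names, not part of the code above; at runtime DummyModel and NullWandbLogger satisfy Trainer by duck typing (under strict mypy you would type those parameters as Protocols or shared base classes):

class DummyModel:
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # An identity "architecture" is enough to exercise the training loop.
        return x


def test_train_one_epoch_runs() -> None:
    trainer = Trainer(
        model=DummyModel(),
        wandb=NullWandbLogger(),
        train_data=[torch.zeros(3, 32, 32), torch.ones(3, 32, 32)],
    )
    trainer.train_one_epoch()  # no wandb calls, no files on disk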

Building configurable objects

Our objects often have many dependencies. Some are essential and should be up to the practitioner to specify, e.g. learning rate, dataset path, etc. Others are an unfortunate consequence of incidental complexity, e.g. image size. To clearly marshal dependencies into these two categories, we can use a simple configurable-object pattern that leverages dependency injection. While it can be difficult to differentiate between essential and incidental elements, this pattern makes it easy to change a dependency’s status simply by moving it between the object’s fields and the build method. Rather than brushing this distinction under the rug, the pattern forces the programmer to clarify their intentions.

In the example below, the practitioner configures the dataset’s resolution, but not the input_shape, which is fully determined by that resolution. Rather than making input_shape externally configurable, it is passed internally to ModelConfig as a build argument. This explicitly encodes the dependency between the dataset and the model while ensuring that the practitioner cannot override it. However, in unit tests, the model can still be constructed independently of any dataset, allowing for flexible and convenient testing.

import dataclasses
from typing import Literal

import torch.utils.data


@dataclasses.dataclass
class DataConfig:
    path: str
    batch_size: int
    resolution: float

    def build(self) -> tuple[SomeDataset, torch.utils.data.DataLoader]:
        dataset = SomeDataset(path=self.path, resolution=self.resolution)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
        return dataset, dataloader


@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    activation: Literal["relu", "gelu"]

    def build(self, input_shape: tuple) -> SomeModel:
        return SomeModel(
            input_shape=input_shape, 
            hidden_dim=self.hidden_dim, 
            activation=self.activation, 
        )


@dataclasses.dataclass
class TrainerConfig:
    train_data: DataConfig
    model: ModelConfig
    wandb: WandbConfig

    def build(self) -> Trainer:
        dataset, dataloader = self.train_data.build()
        model = self.model.build(dataset.input_shape)
        wandb = self.wandb.build()
        return Trainer(model=model, train_data=dataloader, wandb=wandb)

These dataclasses can then be bound to yaml config files using dacite:

import dacite
import yaml

with open("config.yaml") as file:
    yaml_configuration = yaml.safe_load(file)

config = dacite.from_dict(
    TrainerConfig, 
    yaml_configuration, 
    config=dacite.Config(strict=True)
)
trainer = config.build()
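
For illustration, the nested dictionary that yaml.safe_load would produce for this TrainerConfig might look like the following; the field names mirror the dataclasses above, while the values (and the fields of WandbConfig, which isn’t shown) are made up:

yaml_configuration = {
    "train_data": {"path": "data/train.csv", "batch_size": 32, "resolution": 0.5},
    "model": {"hidden_dim": 128, "activation": "gelu"},
    "wandb": {},  # whatever fields the WandbConfig dataclass declares
}

dacite recursively instantiates the nested dataclasses from this structure, and strict=True makes it raise if the YAML contains keys that the dataclasses don’t declare.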

By separating the construction of complex objects from their implementation, we see several positive outcomes: configuration stays declarative and bound to plain dataclasses, the essential-versus-incidental distinction is explicit, and the underlying objects can still be constructed directly in unit tests without any config files.

Anti-Examples

Hard-coded dependencies

class Trainer:
    def __init__(
        self, 
        n_channels=2, 
        image_height=32, 
        image_width=32, 
        dataset_path="dataset.csv"
    ):
        self.model = SomeArchitecture(n_channels, image_height, image_width)
        self.wandb = WandbLogger()  # Hardcoded logger
        self.train_data = load_data_from_disk(dataset_path)  # Hardcoded data source

    def train_one_epoch(self):
        for data in self.train_data:
            self.model.forward(data)
            metrics = {  # compute some metrics
            }
            self.wandb.log(metrics)

Global state

wandb_logger = WandbLogger()
model = SomeArchitecture(3, 32, 32)
train_data = load_data_from_disk("dataset.csv")

class Trainer:
    def train_one_epoch(self):
        for data in train_data:
            model.forward(data)
            metrics = {}  # compute some metrics
            wandb_logger.log(metrics)

Excessive Factory Methods (Over-Engineering)

class Trainer:
    def __init__(self):
        self.model = self.create_model()
        self.wandb = self.create_logger()
        self.train_data = self.load_data()

    def create_model(self):
        return SomeArchitecture(3, 32, 32)

    def create_logger(self):
        return WandbLogger()

    def load_data(self):
        return load_data_from_disk("dataset.csv")

Composition over inheritance

A common anti-pattern is to group helper functions into a superclass:

"""
Anti-pattern: using a superclass to group helper functions together and
forcing all users of that helper function to inherit from it.
"""


class BaseTrainer:
    def _compute_metrics(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> dict:
        """Helper method to compute training metrics"""
        loss = ((predictions - targets) ** 2).mean().item()
        return {"loss": loss}

class ImageTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)

class TextTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)

This suffers from a number of shortcomings: every trainer must inherit from BaseTrainer just to reuse a small helper, the metric computation cannot be swapped or tested independently of the trainer hierarchy, and the base class tends to accumulate unrelated helpers over time.

A better alternative is to use composable functions:

from typing import Callable

import torch


class Trainer:
    def __init__(
        self,
        model,
        train_data,
        # Quick and dirty: a cleaner approach would give metrics an explicit
        # type or class instead of a plain dict.
        compute_metrics: Callable[[torch.Tensor, torch.Tensor], dict],
    ):
        self.model = model
        self.train_data = train_data
        self.compute_metrics = compute_metrics  # Injected dependency

    def train_one_epoch(self):
        for data, target in self.train_data:
            predictions = self.model.forward(data)
            metrics = self.compute_metrics(predictions, target)  # Explicit function call
            print("Metrics:", metrics)

# Different metric functions for different tasks
def mse_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = ((predictions - targets) ** 2).mean().item()
    return {"loss": loss}

def cross_entropy_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = torch.nn.functional.cross_entropy(predictions, targets).item()
    return {"cross_entropy_loss": loss}
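
As a hypothetical usage sketch (image_model, text_model, and the two datasets are assumed to exist elsewhere), the same Trainer serves both tasks simply by injecting a different metrics function:

image_trainer = Trainer(
    model=image_model,
    train_data=image_train_data,
    compute_metrics=mse_metrics,
)
text_trainer = Trainer(
    model=text_model,
    train_data=text_train_data,
    compute_metrics=cross_entropy_metrics,
)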

This implementation is easier to test (each metrics function is a plain function), easier to extend (a new task needs only a new metrics function, not a new subclass), and keeps Trainer free of task-specific logic.

Shoutouts

Thanks to Jeremy McGibbon for introducing me to Dacite, advocating for the use of dataclasses as configuration classes, and championing high-quality software development within my former team at AI2. Jordan Lewis provided invaluable feedback and helped connect these ideas to a broader software engineering perspective. I also appreciate my fellow Brightband colleagues—Ryan Keisler, Hans Mohrmann, and Taylor Mandelbaum—as well as Gabe Gaster and Kevin Lin for their thoughtful feedback on early drafts. Finally, Alex Merose’s enthusiasm has been incredibly encouraging, inspiring me to share this more widely. I’m excited to see how we can continue refining and streamlining this approach.

I’d love to hear your thoughts! Feel free to reach out via email at "me" at this domain, on Twitter @gideoknite, or connect with me on LinkedIn. Happy to chat!


  1. “Gin is particularly well suited for machine learning experiments (e.g. using TensorFlow), which tend to have many parameters, often nested in complex ways.”↩︎