Gideon Dresdner | March 14, 2025
“Dependency injection” is a ubiquitous $25 term for a 25¢ concept. It has also been known as the “Inversion of Control Principle” (another terrible term) or, and this is my favorite, the Hollywood Principle: “Don’t call us, we’ll call you.” At its core, dependency injection aims to explicitly define a routine’s dependencies through its arguments. In my opinion, this often results in cleaner code. More objectively speaking, it results in code that is more modular and easier to unit test.
Dependency injection (DI) is often associated with sophisticated frameworks such as Spring and Guice, which aim to reduce boilerplate by declaratively wiring up dependencies. But frameworks like these are not required to implement a DI-inspired codebase. In fact, it is a particularly exciting moment to be writing code using DI because we can easily generate boilerplate using AI-assisted programming. Furthermore, the Python community is adopting modern tooling such as mypy and dataclasses, which work well with DI. Overall, it has never been easier to adopt these principles in a lightweight manner.
Despite AI/ML hype and adoption skyrocketing, the field struggles to adopt basic software engineering techniques. A common justification is that AI/ML suffers from leaky abstractions and larger, more complex configuration; in short, that AI/ML is more complicated than your run-of-the-mill software [A Recipe for Training Neural Networks, gin-config readme [1]]. I don’t believe this. Instead, social and historical reasons underlie this neglect of software craftsmanship.
To date, AI/ML has maintained its academic roots. Many AI/ML practitioners are expected to have PhDs, and the field’s luminaries are often (former) professors. Publications, not software artifacts, are still the de facto currency of achievement. This association with academia leads engineers to believe that they are scientists, prioritizing ideas and equations over software architecture and implementation. We should embrace the fact that, from its inception with Wiener and others, an interdisciplinary, engineering ethos permeates AI/ML. This doesn’t mean doing less research; it just means being less dismissive of our engineering-oriented colleagues.
Remember: This is just Python. There are no complex libraries to install. There is no black-box framework running your code. This is simply a way of organizing your thoughts in code.
Let’s introduce a concrete example of what code written in a DI style looks like. Suppose we are implementing a training loop for some AI/ML model. The `Trainer` class implements this routine in the `train_one_epoch` method. In essence, the training algorithm doesn’t depend on the AI/ML model architecture, the incidental way that we’ve chosen to keep track of various metrics (e.g. using WandB), or even the dataset, as long as it satisfies mild type constraints. Thus, we pass each of those dependencies as arguments to the constructor. Constructing them is not our concern right now.
This has a positive effect on unit testing: we can implement a `NullWandbLogger` (see the null object pattern) which can be used to unit test `train_one_epoch` without touching wandb. This way we don’t have to resort to pytest patching and other metaprogramming tricks, or worry about external calls to the wandb API.

```python
class WandbLogger:
    def log(self, metrics: dict) -> None:
        ...


class NullWandbLogger:
    def log(self, metrics: dict) -> None:
        # Do nothing
        pass
```
```python
class SomeArchitecture:
    def __init__(self, n_channels: int, height: int, width: int):
        self.n_channels = n_channels
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...
```
```python
class Trainer:
    def __init__(
        self,
        model: SomeArchitecture,
        wandb: WandbLogger,
        train_data: Sequence[torch.Tensor],
    ) -> None:
        self.model = model
        self.wandb = wandb
        self.train_data = train_data

    def train_one_epoch(self) -> None:
        for data in self.train_data:
            self.model.forward(data)
            metrics = {
                # compute some metrics
            }
            self.wandb.log(metrics)
```
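To make the testing claim concrete, here is a minimal sketch of a unit test. To keep it self-contained and runnable without torch or wandb, it re-declares a stripped-down `Trainer` and uses plain numbers in place of tensors; `StubModel` and `CountingNullLogger` are hypothetical stand-ins, not part of the example above.

```python
class CountingNullLogger:
    """Null object for the logger interface; counts calls so tests can assert on it."""

    def __init__(self) -> None:
        self.calls = 0

    def log(self, metrics: dict) -> None:
        self.calls += 1  # no network, no wandb


class StubModel:
    """Stand-in for SomeArchitecture; records the batches it sees."""

    def __init__(self) -> None:
        self.seen = []

    def forward(self, x):
        self.seen.append(x)
        return x


class Trainer:
    def __init__(self, model, wandb, train_data) -> None:
        self.model = model
        self.wandb = wandb
        self.train_data = train_data

    def train_one_epoch(self) -> None:
        for data in self.train_data:
            self.model.forward(data)
            metrics = {"batch": float(data)}  # placeholder for real metrics
            self.wandb.log(metrics)


# The "test": no pytest patching, no external calls.
model, logger = StubModel(), CountingNullLogger()
Trainer(model=model, wandb=logger, train_data=[1.0, 2.0, 3.0]).train_one_epoch()
assert model.seen == [1.0, 2.0, 3.0]
assert logger.calls == 3
```

Because the logger is injected, swapping the real `WandbLogger` back in at production time is a one-line change at the call site.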
Our objects often have many dependencies. Some are essential and should be up to the practitioner to specify, e.g. learning rate, dataset path, etc. Others are an unfortunate consequence of incidental complexity, e.g. image size. To clearly marshal dependencies into these two categories, we can use a simple configurable-object pattern which leverages dependency injection. While it’s true that it can be difficult to differentiate between essential and incidental elements, this pattern makes it easy to change a dependency’s status simply by moving it between the object’s fields and the build method. Rather than brushing this distinction under the rug, this pattern forces the programmer to clarify their intentions.
In the example below, the practitioner configures the dataset’s `resolution`, but not the `input_shape`, which is fully determined by that resolution. Rather than making `input_shape` externally configurable, it is passed internally to `ModelConfig` as a build argument. This explicitly encodes the dependency between the dataset and the model while ensuring that the practitioner cannot override it. However, in unit tests, the model can still be constructed independently of any dataset, allowing for flexible and convenient testing.
```python
@dataclasses.dataclass
class DataConfig:
    path: str
    batch_size: int
    resolution: float

    def build(self) -> tuple[SomeDataset, torch.utils.data.DataLoader]:
        dataset = SomeDataset(path=self.path, resolution=self.resolution)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
        return dataset, dataloader
```
```python
@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    activation: Literal["relu", "gelu"]

    def build(self, input_shape: tuple) -> SomeModel:
        return SomeModel(
            input_shape=input_shape,
            hidden_dim=self.hidden_dim,
            activation=self.activation,
        )
```
```python
@dataclasses.dataclass
class TrainerConfig:
    train_data: DataConfig
    model: ModelConfig
    wandb: WandbConfig

    def build(self) -> Trainer:
        dataset, dataloader = self.train_data.build()
        model = self.model.build(dataset.input_shape)
        wandb = self.wandb.build()
        return Trainer(model=model, train_data=dataloader, wandb=wandb)
```
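The claim that the model can be built independently of any dataset is easy to demonstrate. The sketch below is self-contained: `Model` is a hypothetical stand-in for `SomeModel`, so the block runs without torch.

```python
import dataclasses
from typing import Literal


@dataclasses.dataclass
class Model:
    """Hypothetical stand-in for SomeModel."""

    input_shape: tuple
    hidden_dim: int
    activation: str


@dataclasses.dataclass
class ModelConfig:
    hidden_dim: int
    activation: Literal["relu", "gelu"]

    def build(self, input_shape: tuple) -> Model:
        return Model(
            input_shape=input_shape,
            hidden_dim=self.hidden_dim,
            activation=self.activation,
        )


# In a unit test, no dataset is required; supply a shape directly.
model = ModelConfig(hidden_dim=64, activation="relu").build(input_shape=(3, 32, 32))
assert model.input_shape == (3, 32, 32)
assert model.hidden_dim == 64
```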
These dataclasses can then be bound to YAML config files using dacite:

```python
with open("config.yaml") as file:
    yaml_configuration = yaml.safe_load(file)

config = dacite.from_dict(
    TrainerConfig,
    yaml_configuration,
    config=dacite.Config(strict=True),
)
trainer = config.build()
```
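For illustration, a `config.yaml` matching these dataclasses might look like the following. The field names under `train_data` and `model` come from `DataConfig` and `ModelConfig` above; the `wandb` key is hypothetical, since `WandbConfig` is not shown.

```yaml
train_data:
  path: data/train.csv      # hypothetical path
  batch_size: 32
  resolution: 0.25
model:
  hidden_dim: 128
  activation: gelu
wandb:
  project: my-project       # hypothetical WandbConfig field
```

With `strict=True`, dacite rejects unknown keys, so typos in the YAML fail fast instead of being silently ignored.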
By separating the construction of complex objects from their implementation, we see several positive outcomes. To appreciate them, consider some common anti-patterns. In the first, the constructor hardcodes its dependencies, including a call to `load_data_from_disk`, which means that you have to work through that function in order to test the training algorithm.

```python
class Trainer:
    def __init__(
        self,
        n_channels=2,
        image_height=32,
        image_width=32,
        dataset_path="dataset.csv",
    ):
        self.model = SomeArchitecture(n_channels, image_height, image_width)
        self.wandb = WandbLogger()  # Hardcoded logger
        self.train_data = load_data_from_disk(dataset_path)  # Hardcoded data source

    def train_one_epoch(self):
        for data in self.train_data:
            self.model.forward(data)
            metrics = {
                # compute some metrics
            }
            self.wandb.log(metrics)
```
Another anti-pattern is to rely on global variables: this `Trainer` does not declare what it requires, making it difficult to maintain.

```python
wandb_logger = WandbLogger()
model = SomeArchitecture(3, 32, 32)
train_data = load_data_from_disk("dataset.csv")


class Trainer:
    def train_one_epoch(self):
        for data in train_data:
            model.forward(data)
            metrics = {
                # compute some metrics
            }
            wandb_logger.log(metrics)
```
A third anti-pattern is to construct dependencies in factory methods such as `create_model()` instead of just passing in a model.

```python
class Trainer:
    def __init__(self):
        self.model = self.create_model()
        self.wandb = self.create_logger()
        self.train_data = self.load_data()

    def create_model(self):
        return SomeArchitecture(3, 32, 32)

    def create_logger(self):
        return WandbLogger()

    def load_data(self):
        return load_data_from_disk("dataset.csv")
```
A common anti-pattern is to group helper functions into a superclass:
```python
"""
Anti-pattern: using a superclass to group helper functions together and
forcing all users of that helper function to inherit from it.
"""


class BaseTrainer:
    def _compute_metrics(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> dict:
        """Helper method to compute training metrics"""
        loss = ((predictions - targets) ** 2).mean().item()
        return {"loss": loss}


class ImageTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)


class TextTrainer(BaseTrainer):
    def train_one_epoch(self, model, train_data):
        for data, target in train_data:
            predictions = model.forward(data)
            metrics = self._compute_metrics(predictions, target)  # Calls superclass
            print("Metrics:", metrics)
```
This suffers from a number of shortcomings:

- A subclass can change the behavior of `_compute_metrics`, which might affect other subclasses or cause unintended side effects. Programmers are usually aware of this, so they will make `_compute_metrics` have no side effects, i.e. no dependence on `self`. At this point, the function may as well be a simple, top-level function.
- `_compute_metrics` is an implicit dependency which each subclass is forced to inherit, even if it won’t use it.
- If `_compute_metrics` changes (e.g. adds an accuracy calculation), all subclasses inherit this change automatically.

A better alternative is to use composable functions:
```python
class Trainer:
    def __init__(
        self,
        model,
        train_data,
        # This is quick & dirty; a cleaner way would be to add more explicit typing
        # and classes to metrics.
        compute_metrics: Callable[[torch.Tensor, torch.Tensor], dict],
    ):
        self.model = model
        self.train_data = train_data
        self.compute_metrics = compute_metrics  # Injected dependency

    def train_one_epoch(self):
        for data, target in self.train_data:
            predictions = self.model.forward(data)
            metrics = self.compute_metrics(predictions, target)  # Explicit function call
            print("Metrics:", metrics)
```
```python
# Different metric functions for different tasks
def mse_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = ((predictions - targets) ** 2).mean().item()
    return {"loss": loss}


def cross_entropy_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    loss = torch.nn.functional.cross_entropy(predictions, targets).item()
    return {"cross_entropy_loss": loss}
```
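To show the payoff, here is a self-contained sketch of swapping metric functions. For the sake of a runnable example it uses plain floats instead of tensors and a trivial identity function in place of a model, and this `Trainer` returns the logged metrics so they can be checked; all three simplifications are assumptions of the sketch, not part of the example above.

```python
def mse_metrics(prediction: float, target: float) -> dict:
    return {"loss": (prediction - target) ** 2}


def abs_error_metrics(prediction: float, target: float) -> dict:
    return {"abs_error": abs(prediction - target)}


class Trainer:
    def __init__(self, model, train_data, compute_metrics):
        self.model = model
        self.train_data = train_data
        self.compute_metrics = compute_metrics  # injected dependency

    def train_one_epoch(self) -> list:
        logged = []
        for data, target in self.train_data:
            prediction = self.model(data)
            logged.append(self.compute_metrics(prediction, target))
        return logged


identity_model = lambda x: x
data = [(1.0, 1.0), (2.0, 3.0)]

# Same training loop, different metrics -- no subclassing required.
assert Trainer(identity_model, data, mse_metrics).train_one_epoch() == [
    {"loss": 0.0},
    {"loss": 1.0},
]
assert Trainer(identity_model, data, abs_error_metrics).train_one_epoch() == [
    {"abs_error": 0.0},
    {"abs_error": 1.0},
]
```

Each metric function can also be unit tested on its own, with no trainer in sight.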
This implementation is modular, explicit about its dependencies, and easy to unit test: swapping metric functions requires no subclassing, and each metric function can be tested on its own.
Thanks to Jeremy McGibbon for introducing me to Dacite, advocating for the use of dataclasses as configuration classes, and championing high-quality software development within my former team at AI2. Jordan Lewis provided invaluable feedback and helped connect these ideas to a broader software engineering perspective. I also appreciate my fellow Brightband colleagues—Ryan Keisler, Hans Mohrmann, and Taylor Mandelbaum—as well as Gabe Gaster and Kevin Lin for their thoughtful feedback on early drafts. Finally, Alex Merose’s enthusiasm has been incredibly encouraging, inspiring me to share this more widely. I’m excited to see how we can continue refining and streamlining this approach.
I’d love to hear your thoughts! Feel free to reach out via email at "me" at this domain, on Twitter @gideoknite, or connect with me on LinkedIn. Happy to chat!
[1] “Gin is particularly well suited for machine learning experiments (e.g. using TensorFlow), which tend to have many parameters, often nested in complex ways.”