Serving machine learning models isn't easy: many decisions have to be made to deploy them to production in a reliable and scalable way.
For simple models, it's perfectly reasonable to make them available as a synchronous API endpoint to be used by other services in our application. However, when dealing with neural networks, especially big ones that can only be trained efficiently on a GPU, it might not be that easy.
As soon as we reach the point at which our model needs a GPU to do reasonably performant inference, we should carefully consider all possible options for serving our model.
Ideally, we'd like the GPU instance to spend its capacity solely on inference tasks, leaving the API concerns of making it available to much smaller, concurrency-focused CPU instances.
To achieve this separation of concerns, we'll need a message broker to store our inference requests, as well as a data storage solution to keep the results around for a certain amount of time.
That's exactly the system we are going to build: we'll serve a PyTorch model, using RabbitMQ as our message broker, FastAPI as our REST service and Redis as our in-memory storage backend for the inference results.
A Simple MNIST Model
For the purpose of this post, we'll train an MNIST classification model, which allows us to classify a hand-drawn number.
It's one of the classic machine learning examples involving convolutional layers, which are commonly trained via GPU, so it makes for a good stand-in for whatever more advanced, GPU-heavy ML model we might be dealing with.
To save some time, we use a trimmed down version of the PyTorch Example for MNIST, condensing it to the minimum of what's needed to get a fairly good solution.
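Condensed, the network might look something like the sketch below. The layer sizes follow the official PyTorch MNIST example (with dropout omitted for brevity); treat it as an approximation of the trimmed-down training script rather than its exact contents.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """A small CNN classifying 28x28 grayscale MNIST digits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # 28x28 -> 26x26
        x = F.relu(self.conv2(x))   # 26x26 -> 24x24
        x = F.max_pool2d(x, 2)      # 24x24 -> 12x12
        x = torch.flatten(x, 1)     # 64 * 12 * 12 = 9216
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
```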
Starting the training process via
python -m train, we can see good progress.
> Train Epoch: 1 [0/60000 (0%)] Loss: 2.321489
> Train Epoch: 1 [12800/60000 (21%)] Loss: 0.204598
> Train Epoch: 3 [38400/60000 (64%)] Loss: 0.035516
> Train Epoch: 3 [51200/60000 (85%)] Loss: 0.122487
> Test set: Average loss: 0.0320, Accuracy: 9887/10000 (99%)
Training is done for 3 epochs only, achieving 99% accuracy on the test set, barely enough for our goal. 😉
Afterwards, the model state dictionary is saved to
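The save itself is a one-liner; the filename and the stand-in model below are placeholders for this sketch, not the actual names from the training script.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the trained MNIST network

# persist only the learned parameters, not the full Python object;
# "mnist_cnn.pt" is a placeholder filename for this sketch
torch.save(model.state_dict(), "mnist_cnn.pt")

# the worker later restores the weights into a fresh model instance
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("mnist_cnn.pt"))
```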
Setting up RabbitMQ and Redis
As our model is now ready to rock, we continue with the central pieces of our desired setup: the RabbitMQ and Redis services. To make our lives a bit easier during development, we spin up both of them using Docker.
docker run --rm -p 5672:5672 -p 15672:15672 -e RABBITMQ_DEFAULT_USER=my_user -e RABBITMQ_DEFAULT_PASS=my_password rabbitmq:3.10-management
docker run --rm -p 6379:6379 redis:latest
Done deal, easy as that: we have Redis running on port 6379 and RabbitMQ on port 5672.
On port 15672, RabbitMQ additionally exposes its management interface, which can be used for debugging and checking the message broker's health, but we'll hopefully not need it for now.
The next part of our journey is to configure FastAPI. While doing that, we extract the common RabbitMQ and Redis connectivity logic into a file called
bridge.py. We are going to use it on both sides: the API and the PyTorch worker.
As can be seen, the only thing this “bridge” does is connect to the RabbitMQ message broker and to the Redis data store we spun up earlier. At a later stage, we might want to load the credentials from the environment instead, to make deployments more flexible.
The real deal happens inside
api.py. This is where we accept the inference requests (uploaded files of hand-drawn digits) and forward them to the message broker.
We create a random UUID and attach it to the inference request, so the client is able to retrieve the result later.
Naturally, we need to define an additional endpoint which allows us to query the Redis data-store for the inference result, which might or might not be available.
If one is available, it's sent to the client upon request; if not, we respond with a pending status, implicitly asking the client to wait a little longer for the request to be processed.
Configuring the Worker
As the API is ready to take on requests and put them into the message queue, the logical next step is to implement the worker, so those requests get swiftly processed.
Again, we split the logic into 2 parts to make the code a bit more readable. First we create an
inference_model.py file, which loads the real PyTorch model and encapsulates the prediction and pre-processing workflow.
For this inference model to be useful, we'll finally also declare our worker in
It will be responsible for waiting for requests to come in, extracting the payload (the uploaded image), passing it to the inference model, and storing the result for later retrieval.
The worker, just like the API, uses the “bridge” to connect to the Redis data store and the RabbitMQ message broker.
Lastly, the message is acknowledged, signaling to the message broker that it has been processed successfully.
Trying it out
As our setup is now complete, it's time to run the API and our worker(s) to see if everything is working the way it's supposed to.
First, we start up FastAPI.
python -m uvicorn api:api
Second, we start the worker(s). In fact, we can instantiate the worker multiple times right away, to see if the task distribution works as anticipated. For testing, let's start 2 instances of our worker.
python -m worker
To check whether inference works, we'll submit multiple sample images from the MNIST dataset: a blurry 4, a 7 positioned at the bottom, and a reasonably centered 3.
Next, we'll send 3 curl requests to our API, expecting a different inference UUID in each response, which we can then use to check the inference status.
curl -X POST 'http://localhost:8000/classify' --form 'file=@"mnist_3.jpeg"'
curl -X POST 'http://localhost:8000/classify' --form 'file=@"mnist_4.jpg"'
curl -X POST 'http://localhost:8000/classify' --form 'file=@"mnist_7.webp"'
Great, we got an inference UUID back for each request. Looking at the terminal output of our 2 workers, we can also see that the task distribution worked successfully, as the first worker processed the 3 and the 7, while the second worker processed the 4.
By looking at the workers' outputs, we already know the predictions; nevertheless, we'll also test the last step, making sure that fetching the inference result by UUID works as expected on the client side.
curl -X GET 'http://localhost:8000/result/bac0e10d-87e8-4f93-b102-f6816a32b66b'
curl -X GET 'http://localhost:8000/result/8f021057-532a-47ac-896e-407fe14fb474'
curl -X GET 'http://localhost:8000/result/30684d7f-b86a-4014-9029-85554dccf3e7'
And there they are, 3, 4 and 7, good stuff!
Great, what's next?
Possible next steps would be to containerize our solution via Docker, so that it can be deployed a little easier to dedicated Docker machines or even to a Kubernetes cluster.
Also, it would very likely make sense to invest more time in the RabbitMQ and Redis documentation, hardening the setup against lost connections or bad data that can't be processed by a worker. Furthermore, an expiration time for the results stored in Redis would probably make sense.
So on and so forth, lots of possibilities to improve the proposed setup, which unfortunately can't all be covered in one post.
Why not use Celery (Python Task Queue Library)?
Celery is very promising, and in fact this post started out with a Celery-based setup; at some point, however, the downsides outweighed the upsides for me.
Machine learning inference tasks are very specific, and Celery is a distributed task queue with batteries included, meaning it comes with opinionated defaults. These defaults, however, might not be the best fit for ML inference tasks.
So it often boils down to forcefully ripping some batteries out again to get to a proper solution. For example, having to bypass its process-based concurrency and force a worker instance to run single-threaded, in order to make it work with CUDA.
Also, since Celery tries to be as easy to use as possible, we quickly find ourselves dealing with unnecessary dependencies in the API, “conditional imports”, or classes loaded dynamically via string parameters.
I'm not saying it's impossible; especially the last reference proves that it is very possible. Celery is a great tool, but under the circumstances of ML inference it doesn't let me separate the API from the worker as clearly as I'd like. So take this with a grain of salt, as it's just my subjective opinion.
We've come to the end, and by that end, you should now have a scalable solution to serve your GPU-intensive ML models for inference!