Introduction
I use conda to manage my python environments. With the advent of Keras Core you can easily switch between different backend (TensorFlow, PyTorch, JAX, and NumPy).
My ususal workflow was:
- Create a conda environemt
conda create -n kcore python==3.9
- Activate the environment
conda activate kcore
- Set the
KERAS_BACKEND
environment variableexport KERAS_BACKEND="tensorflow"
I was frustrated with two things in particular:
- The backend conflicts in a single environment.
- Exporting the environment variable continuously.
Solution
To combat the backend conflicts in a single environment, an easy fix was to create different environments for each backend.
conda create -n kcore-tf python==3.9
conda activate kcore-tf
python -m pip install tensorflow keras-core
This is an environment that is exclusive to TensorFlow. Similarly you could also create one for PyTorch, JAX, and NumPy.
With the first frustration bidding adieu, let’s talk about the second problem. I was thinking of a state where I could switch to any environment, and the environment variable KERAS_BACKEND
was already set for me. Wouldn’t that be cool?
To do that I followed this great tutorial. Let’s go through the steps one at a time.
conda create kcore-jax python==3.9
conda activate kcore-jax
This creates and activates the jax environemnt.
python -m pip install jaxlib jax keras-core
This would now install jaxlib
, jax
, and keras-core
libraries.
cd $CONDA_PREFIX
mkdir -p ./etc/conda/activate.d
mkdir -p ./etc/conda/deactivate.d
touch ./etc/conda/activate.d/env_vars.sh
touch ./etc/conda/deactivate.d/env_vars.sh
Here we build a shell script which would help with the environment variables.
In the activate.d/env_vars.sh
we export the env variable.
#!/bin/sh
export KERAS_BACKEDN='jax'
While in the deactivate.d/env_vars.sh
we unset them.
#!/bin/sh
unset KERAS_BACKEND
Now, if you activate the kcore-jax
environment the KERAS_BACKEND
env var is already set. You can check it yourself with echo $KERAS_BACKEND
. What does this do you may ask?
import keras_core
print(keras_core.backend.backend())
Using JAX backend
And if you repeated the same process for all the conda environemnts, you would just have to activate your conda environment, and forget about everything else.