tensorflow_datasets interferes with cuDNN · Issue #8302 · jax-ml/jax
Using tensorflow_datasets before calling jax.lax.conv_general_dilated makes the program crash:

import jax.numpy as jnp
import jax
import tensorflow_datasets as tfds

tfds.load("mnist", split='train')
lhs = jnp...
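One way to narrow down a report like this is to run the same convolution primitive in a fresh process without importing tensorflow_datasets at all. The sketch below does that. It assumes hypothetical NCHW input and OIHW kernel shapes, because the original `lhs` definition is truncated above. If this runs cleanly on the GPU while the version that calls `tfds.load` first crashes, the import and load order is implicated rather than the convolution itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes, not taken from the truncated report:
# lhs is NCHW (1 image, 1 channel, 5x5); rhs is OIHW (1 out channel, 1 in channel, 3x3).
lhs = jnp.ones((1, 1, 5, 5), dtype=jnp.float32)
rhs = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)

# With dimension_numbers left at its default, lax uses NCHW / OIHW / NCHW layouts.
out = jax.lax.conv_general_dilated(
    lhs,
    rhs,
    window_strides=(1, 1),
    padding="VALID",
)

# Force execution so any backend (e.g. cuDNN) failure surfaces here, not later.
out.block_until_ready()
print(jax.devices(), out.shape)
```

With "VALID" padding a 3x3 kernel over a 5x5 input yields a 3x3 output, and each element sums nine ones.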