How do I get a dictionary's key from its value when the values are arrays?
```python
import jax.numpy as jnp

my_dict = {0: jnp.array([0, 0]), 1: jnp.array([0, 1]), 2: jnp.array([0, 2])}

def get_key(val, my_dict):
    for key, value in my_dict.items():
        if val == value:  # elementwise comparison: yields a boolean array
            return key

x = jnp.array([0, 1])
get_key(x, my_dict)
```
I am getting the following error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
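The error comes from `val == value`: on JAX (and NumPy) arrays, `==` compares element by element and returns a boolean array such as `[True, False]`, which `if` cannot reduce to a single truth value. A minimal sketch of one fix, assuming you want exact equality of shape and every element, is to compare with `jnp.array_equal`, which returns a single boolean. The reverse-index variant at the end is an assumption-laden alternative for repeated lookups: it only works if the values are small arrays that can be converted to hashable tuples, and `reverse` is a hypothetical name.

```python
import jax.numpy as jnp

my_dict = {0: jnp.array([0, 0]), 1: jnp.array([0, 1]), 2: jnp.array([0, 2])}


def get_key(val, my_dict):
    """Return the first key whose array value equals `val`, or None if absent."""
    for key, value in my_dict.items():
        # jnp.array_equal checks shape and every element, returning one
        # boolean instead of an elementwise array, so `if` can evaluate it.
        if jnp.array_equal(val, value):
            return key
    return None


x = jnp.array([0, 1])
print(get_key(x, my_dict))  # 1

# Alternative for many lookups (assumption: values are small arrays that
# can be turned into tuples of Python numbers). Build a reverse index once;
# tuples are hashable, arrays are not. If several keys share a value, this
# index keeps the last such key, unlike get_key, which returns the first.
reverse = {tuple(v.tolist()): k for k, v in my_dict.items()}
print(reverse.get(tuple(x.tolist())))  # 1
```

The loop costs one comparison per entry on every call; the reverse index pays that cost once up front and then answers each lookup with a dictionary access.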