# jax_onehot.py
  # Benchmark case: can the model write a jitted JAX one_hot(indexes, num_labels)?
# `question` is sent to the model and the code it answers with is extracted and
# run together with `test_case`. Judging by the evaluator stage names, the run's
# output is then searched for the `answer` text (TODO confirm in evaluator).
from evaluator import *

DESCRIPTION = "Test if the model can correctly convert a list of indexes to a one-hot vector in Python using JAX."

TAGS = ['code', 'python']

# NOTE(review): prompt kept verbatim (including "labeles") so scores stay
# comparable with earlier runs of this test.
question = """
I have list of indexes and I want to convert it to one hot vector using jax and the function should be jittable and the function should be jitted. name the function one_hot and it should get two arguments the first one is the indexes and the second one is the number of possible labeles. Just give me the code
"""

# The prompt never specifies a dtype, so a correct answer may legitimately
# return int or bool one-hot rows, e.g.
# `(indexes[:, None] == jnp.arange(n)).astype(jnp.int32)`. Those print as
# "[[0 1 0 0 0] ..." / "[[False  True ...", which never matches the float text
# produced by the expected expression. Casting to float32 before printing makes
# the comparison dtype-agnostic. Float32 outputs (jax.nn.one_hot,
# jnp.eye(n)[indexes]) print exactly as before.
test_case, answer = make_python_test(
    [
        (
            "str(one_hot(jnp.array([1,2,0]),5).astype(jnp.float32))",
            "str(jnp.array([[0.0,1.0,0.0,0.0,0.0],[0.0,0.0,1.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]))",
        )
    ],
    header='import jax.numpy as jnp',
)

# Stages, left to right: prompt -> LLMRun -> ExtractCode -> Echo
# -> PythonRun(test_case) -> Echo -> SubstringEvaluator(answer).
TestJaxOneHot = (
    question
    >> LLMRun()
    >> ExtractCode()
    >> Echo()
    >> PythonRun(test_case)
    >> Echo()
    >> SubstringEvaluator(answer)
)

if __name__ == "__main__":
    print(run_test(TestJaxOneHot))