# jnp_nn_bugfix.py

from evaluator import *

DESCRIPTION = "Test if the model can identify and fix a bug in a given jax.numpy function."

TAGS = ['code', 'python']

question = """
Fix the bug in this jnp function. Write out the full correct version of the code.

import jax
import jax.numpy as jnp

def init_mlp_params(input_size, hidden_size, output_size, key):
    k1, k2 = jax.random.split(key)
    # Initialize weights and biases for the first linear layer
    w1 = jax.random.normal(k1, (hidden_size, input_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    # Initialize weights and biases for the second linear layer
    w2 = jax.random.normal(k2, (output_size, hidden_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)
    return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

def mlp_apply(params, x):
    # First linear layer
    x = jnp.dot(x, params['w1'].T) + params['b1']
    # Relu activation
    x = jax.nn.relu(x)
    # Second linear layer
    x = jnp.dot(x, params['w2'].T) + params['b2']
    # Softmax activation
    return jax.nn.softmax(x, axis=1)

# Example usage:
key = jax.random.PRNGKey(0)
params = init_mlp_params(10, 20, 30, key)
x = jax.random.normal(key, (5, 10))  # random input similar to torch.randn
y = mlp_apply(params, x)
print(list(y.flatten().sum(1)))  # The output sum should be very close to [1, 1, 1, 1, 1]
"""

TestFixJnpBug = question >> LLMRun() >> ExtractCode(keep_main=True) >> \
    (PythonRun() >> (SubstringEvaluator("1.0,") | SubstringEvaluator("1.00000") |
                     SubstringEvaluator("1.0 ") | SubstringEvaluator("0.99999")))
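
# A sketch of how this pipeline is assumed to run: LLMRun() sends the question
# to the model, ExtractCode(keep_main=True) pulls the code out of the reply,
# PythonRun() executes it, and the OR-ed SubstringEvaluator checks accept any
# output in which the printed row sums appear as roughly 1.0.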

if __name__ == "__main__":
    print(run_test(TestFixJnpBug))