# torch_to_jnp.py
  1. from evaluator import *
  2. DESCRIPTION = "Test if the model can convert a torch neural network to a jax numpy model."
  3. TAGS = ['code', 'python']
  4. question = """
  5. Convert the following PyTorch code to pure JAX. Do not use any other libraries.
  6. ```python
  7. import torch
  8. class MLP(torch.nn.Module):
  9. def __init__(self, input_size, hidden_size, output_size):
  10. super().__init__()
  11. self.linear1 = torch.nn.Linear(input_size, hidden_size)
  12. self.linear2 = torch.nn.Linear(hidden_size, output_size)
  13. def __call__(self, x):
  14. x = self.linear1(x)
  15. x = torch.relu(x)
  16. x = self.linear2(x)
  17. return torch.nn.functional.softmax(x, dim=1)
  18. model = MLP(10, 20, 30)
  19. x = torch.randn(5, 10)
  20. y = model(x)
  21. print(list(y.flatten().numpy().sum(1)))
  22. ```
  23. """
  24. TestTorchJnp = question >> LLMRun() >> ExtractCode(keep_main=True) >> \
  25. ((~SubstringEvaluator("import torch")) &
  26. (PythonRun() >> (SubstringEvaluator("1.0,") | SubstringEvaluator("1.00000") | SubstringEvaluator("1.0 ") | SubstringEvaluator("0.99999"))))
  27. if __name__ == "__main__":
  28. print(run_test(TestTorchJnp))