123456789101112131415161718192021222324252627282930313233343536373839 |
# Benchmark case: ask a model to port a small PyTorch MLP to pure JAX, then
# run what it wrote and look at the printed row sums.
from evaluator import *

DESCRIPTION = "Test if the model can convert a torch neural network to a jax numpy model."

TAGS = ['code', 'python']

# Prompt handed to the model. The text below is runtime data and is kept
# verbatim apart from restoring the indentation the embedded snippet needs
# in order to be valid Python.
# NOTE(review): the snippet's last line flattens the output before summing
# over axis 1, which looks unintended -- confirm before changing, since the
# wording of the prompt is part of what is being tested.
question = """
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)
    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)
model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
"""

# First half of the pipeline: prompt -> LLMRun -> ExtractCode(keep_main=True).
# (Stage behavior is defined in `evaluator`; presumably this yields the
# program the model wrote -- verify against that module.)
_extracted_program = question >> LLMRun() >> ExtractCode(keep_main=True)

# Second half: both conditions must hold.
#   1. The text does not contain "import torch".
#   2. After PythonRun, the output contains one of the accepted spellings of
#      "approximately 1" -- the snippet prints sums of softmax outputs, and
#      softmax rows sum to 1.
_acceptance = (
    (~SubstringEvaluator("import torch"))
    & (
        PythonRun()
        >> (
            SubstringEvaluator("1.0,")
            | SubstringEvaluator("1.00000")
            | SubstringEvaluator("1.0 ")
            | SubstringEvaluator("0.99999")
        )
    )
)

TestTorchJnp = _extracted_program >> _acceptance

if __name__ == "__main__":
    print(run_test(TestTorchJnp))
|