fix_torch_backward.py

from evaluator import *

DESCRIPTION = "Test if the model can fix and explain a bug in PyTorch code caused by forgetting to zero gradients between backward passes."

TAGS = ['code', 'python', 'fix']

code = """
```
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)


def PGD(model, input, target, loss_fn, epsilon, alpha, num_iter):
    perturbation = torch.zeros_like(input, requires_grad=True)

    for i in range(num_iter):
        output = model(input + perturbation)
        loss = loss_fn(output, target)

        loss.backward()

        # Update perturbation
        perturbation_grad = perturbation.grad.data
        perturbation.data = perturbation.data + alpha * perturbation_grad.sign()
        perturbation.data = torch.clamp(perturbation.data, -epsilon, epsilon)

    return input + perturbation


model = SimpleNet()

input = torch.tensor([[1.0]], requires_grad=True)
target = torch.tensor([[2.0]])

loss_fn = nn.MSELoss()

epsilon = 0.1
alpha = 0.01
num_iter = 2

adversarial_input = PGD(model, input, target, loss_fn, epsilon, alpha, num_iter)
print("Resulting advex", adversarial_input)
```
"""
  38. q1 = f"Fix the bug in this code. Rewrite the entire code exactly as-is but just fix any bugs.\n\n{code}"
TestTorchBackwardFix = q1 >> LLMRun() >> ExtractCode(keep_main=True) >> PythonRun() >> (SubstringEvaluator("tensor") & SubstringEvaluator("grad_fn=<Add"))
  40. q2 = f"Explain the bug in this code.\n\n{code}"
TestTorchBackwardExplain = q2 >> LLMRun() >> LLMRun("Below is a student's explanation for a bug in a torch function.\n<A>\nDoes the student's answer say specifically \"the bug is caused by not zeroing the gradient in the backward pass\"? Think out loud and then finally answer either \"The student passes\" or \"The student fails\".", llm=EVAL_LLM) >> Echo() >> SubstringEvaluator("The student passes")

if __name__ == "__main__":
    print(run_test(TestTorchBackwardExplain))
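    # Only the Explain variant runs by default; pass TestTorchBackwardFix
    # to run_test to exercise the fix-the-code pipeline as well.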