@cstorch.step_closure
def check_nan(loss):
    """Fail fast if ``loss`` contains any NaN value.

    Decorated with ``@cstorch.step_closure`` and called from ``training_step``
    with the loss produced by ``loss_fn``.

    Args:
        loss: Loss tensor to inspect.

    Raises:
        AssertionError: If any element of ``loss`` is NaN.
    """
    # Explicit raise instead of a bare ``assert``: assert statements are
    # stripped when Python runs with ``-O``, which would silently disable this
    # check. AssertionError is kept so existing handlers see the same type.
    if torch.isnan(loss).any():
        raise AssertionError("NaN detected in loss")
@cstorch.step_closure
def print_loss(loss):
    """Write the given loss value to stdout, prefixed with ``Loss: ``.

    Decorated with ``@cstorch.step_closure``.

    Args:
        loss: Loss value to report, e.g. the value returned by ``training_step``.
    """
    # str.format uses format(loss, "") exactly as the f-string form does.
    message = "Loss: {}".format(loss)
    print(message)
@cstorch.trace
def training_step(inputs, targets):
    """Run one forward pass, backward pass and optimizer update; return the loss.

    Decorated with ``@cstorch.trace``.

    Args:
        inputs: Batch fed to ``compiled_model``.
        targets: Targets passed to ``loss_fn`` together with the model outputs.

    Returns:
        The loss computed by ``loss_fn`` for this batch.
    """
    batch_loss = loss_fn(compiled_model(inputs), targets)
    # Hand the loss to the NaN sanity check (a cstorch step closure).
    check_nan(batch_loss)

    # Order matters: compute gradients, apply them, then clear them so they
    # do not accumulate into the next step.
    batch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss
# Training loop: each item yielded by ``executor`` is unpacked into an
# (inputs, targets) pair and passed to the traced ``training_step``.
# NOTE(review): ``executor`` is defined outside this snippet; presumably a
# cstorch data executor — confirm.
for inputs, targets in executor:
    loss = training_step(inputs, targets)
    # print_loss is a @cstorch.step_closure; presumably cstorch runs it once
    # the step's loss value is available rather than immediately — confirm
    # against the cstorch step_closure documentation.
    print_loss(loss)