swift
9e089ac3 - [TF] Fix gradients in sum(squeezinAxes:) and mean(squeezinAxes:) (#24164)

Commit
6 years ago
[TF] Fix gradients in sum(squeezinAxes:) and mean(squeezinAxes:) (#24164) * [TF] Fix gradients in sum(squeezinAxes:) and mean(squeezinAxes:) sum(squeezinAxes:) and mean(squeezinAxes:) were throwing an error during the bawckward pass because the gradients weren't unsqueezed before being broadcast. Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`. Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is. * Remove unused `_vjpMean` function. * Update Gradients.swift * Add test * Minor edit for consistency.
Author
Committer
Parents
Loading