fix(tests): run causal_softmax reference on CPU
#612
+4 −3