I am writing some code that requires an op that takes a vector of indices (ints) and returns a matrix of one-hot vectors. I didn't make this a pull request because I am not 100% sure it has enough error checking, or that every method is implemented correctly. I may also simply have missed that this is already implemented.
If it's not already there, I think this would be a good addition that makes Gorgonia a bit easier to use.
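To make the intended behaviour concrete, here is a minimal plain-Go sketch of the transformation, independent of Gorgonia. The `oneHot` helper and its `[][]float64` return type are assumptions for illustration only; the point is the bounds check on each index, which is the kind of error checking the op probably needs.

```go
package main

import (
	"errors"
	"fmt"
)

// oneHot is a hypothetical helper (not part of Gorgonia). It turns a slice of
// class indices into a batch of one-hot rows, one row per index.
func oneHot(indices []int, numClasses int) ([][]float64, error) {
	if numClasses <= 0 {
		return nil, errors.New("numClasses must be positive")
	}
	out := make([][]float64, len(indices))
	for i, idx := range indices {
		// Reject indices that would fall outside the row.
		if idx < 0 || idx >= numClasses {
			return nil, fmt.Errorf("index %d at position %d is outside [0, %d)", idx, i, numClasses)
		}
		row := make([]float64, numClasses) // zero-initialised by Go
		row[idx] = 1
		out[i] = row
	}
	return out, nil
}

func main() {
	rows, err := oneHot([]int{2, 0, 1}, 3)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	for _, r := range rows {
		fmt.Println(r)
	}
}
```

For the indices 2, 0, 1 with three classes, the rows have their single 1 in columns 2, 0 and 1 respectively; an index of 3 would return an error instead of writing out of range.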
package main

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

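// OneHot applies the one-hot op to a vector of indices.
func OneHot(x *gorgonia.Node, numClasses int, dType tensor.Dtype) (*gorgonia.Node, error) {
	op := &oneHotOp{numClasses, dType}
	return gorgonia.ApplyOp(op, x)
}

// Compile-time checks that oneHotOp satisfies both interfaces.
var _ gorgonia.Op = &oneHotOp{}
var _ gorgonia.SDOp = &oneHotOp{}

type oneHotOp struct {
	numClasses int
	dType      tensor.Dtype
}
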
// DiffWRT implements gorgonia.SDOp.
func (*oneHotOp) DiffWRT(inputs int) []bool {
	// I'm pretty sure you can't, nor would ever want to, take the derivative of this op.
	return make([]bool, inputs)
}

// SymDiff implements gorgonia.SDOp.
func (*oneHotOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	panic("unimplemented (tho tbf this should never be called)")
}

// Arity implements gorgonia.Op.
func (*oneHotOp) Arity() int {
	return 1 // we expect just a vector of indices
}

// CallsExtern implements gorgonia.Op.
func (*oneHotOp) CallsExtern() bool {
	return false
}

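// Do implements gorgonia.Op.
func (op *oneHotOp) Do(inp ...gorgonia.Value) (gorgonia.Value, error) {
	// The input must be a vector backed by []int.
	indices, ok := inp[0].Data().([]int)
	if !ok {
		return nil, fmt.Errorf("oneHotOp: expected []int indices, got %T", inp[0].Data())
	}
	// Pick the "one" value that matches the requested output dtype.
	var one interface{}
	switch op.dType {
	case tensor.Int:
		one = int(1)
	case tensor.Float64:
		one = float64(1)
	case tensor.Float32:
		one = float32(1)
	case tensor.Bool:
		one = true
	default:
		return nil, fmt.Errorf("oneHotOp: unsupported dtype %v", op.dType)
	}
	batchSize := inp[0].Shape()[0]
	tens := tensor.New(tensor.WithShape(batchSize, op.numClasses), tensor.Of(op.dType))
	for i := 0; i < batchSize; i++ {
		index := indices[i]
		if index < 0 || index >= op.numClasses {
			return nil, fmt.Errorf("oneHotOp: index %d out of range [0, %d)", index, op.numClasses)
		}
		if err := tens.SetAt(one, i, index); err != nil {
			return nil, err
		}
	}
	return tens, nil
}
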
// InferShape implements gorgonia.Op.
func (op *oneHotOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	// Output shape is the input shape with numClasses appended.
	s := inputs[0].(tensor.Shape).Clone()
	s = append(s, op.numClasses)
	return s, nil
}

// OverwritesInput implements gorgonia.Op.
func (*oneHotOp) OverwritesInput() int {
	return -1
}

// ReturnsPtr implements gorgonia.Op.
func (*oneHotOp) ReturnsPtr() bool {
	return false
}

// String implements gorgonia.Op.
func (*oneHotOp) String() string {
	return "OneHotOp"
}

// Type implements gorgonia.Op.
func (op *oneHotOp) Type() hm.Type {
	ohTypeInput := gorgonia.TensorType{
		Dims: 1,
		Of:   tensor.Int,
	}
	// The output dtype follows op.dType rather than being fixed to Float64.
	ohTypeOutput := gorgonia.TensorType{
		Dims: 2,
		Of:   op.dType,
	}
	return hm.NewFnType(ohTypeInput, ohTypeOutput)
}

// I don't actually know what this is for (I just copied this code from another op).
func (op *oneHotOp) WriteHash(h hash.Hash) { fmt.Fprint(h, op.String()) }

// Hashcode implements gorgonia.Op.
func (*oneHotOp) Hashcode() uint32 {
	// I don't actually know what this is for.
	panic("unimplemented")
}
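On the two hashing methods: my understanding, which I have not verified against the Gorgonia source, is that they let the graph tell ops apart, so `WriteHash` should write everything that distinguishes one op instance from another, and `Hashcode` can be derived from it. A minimal standalone sketch of that pattern follows; `oneHotParams` and its `dtypeName` field are hypothetical stand-ins for the op and its dtype, and using FNV-1a is an assumption about what would be acceptable.

```go
package main

import (
	"fmt"
	"hash"
	"hash/fnv"
)

// oneHotParams is a hypothetical stand-in for the op's configuration.
type oneHotParams struct {
	numClasses int
	dtypeName  string
}

// WriteHash writes the fields that distinguish one op instance from another.
func (p oneHotParams) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "OneHotOp(%d,%s)", p.numClasses, p.dtypeName)
}

// Hashcode derives a 32-bit code by feeding WriteHash into an FNV-1a hasher.
func (p oneHotParams) Hashcode() uint32 {
	h := fnv.New32a()
	p.WriteHash(h)
	return h.Sum32()
}

func main() {
	a := oneHotParams{numClasses: 10, dtypeName: "float64"}
	b := oneHotParams{numClasses: 10, dtypeName: "float64"}
	// Two ops with the same configuration produce the same hash code.
	fmt.Println(a.Hashcode() == b.Hashcode())
}
```

With this shape, `Hashcode` no longer panics, and ops that differ in `numClasses` or dtype feed different bytes to the hasher.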