Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OneHot op #559

Open
JoshPattman opened this issue Dec 24, 2023 · 0 comments
Open

OneHot op #559

JoshPattman opened this issue Dec 24, 2023 · 0 comments

Comments

@JoshPattman
Copy link
Contributor

I am writing some code that requires an op that takes a vector of indices (ints) and returns a matrix of one-hot vectors. I haven't opened this as a pull request because I am not 100% sure that it has enough error checking, or that some of the methods are implemented correctly. I may also simply have missed that this is already implemented.

I think that, if it's not already there, this would be a good addition that makes Gorgonia a bit easier to use.

package main

import (
	"fmt"
	"hash"
	"hash/fnv"

	"github.com/chewxy/hm"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

// OneHot builds a graph node that expands a vector of class indices into a
// matrix of one-hot rows.
//
// x is the node holding the indices, numClasses is the width of every one-hot
// row, and dType is the element type of the resulting matrix.
//
// An error is returned when x is nil, when numClasses is not positive, when
// dType is not one of the element types the op can fill (Int, Float64,
// Float32, Bool), or when gorgonia.ApplyOp itself fails.
func OneHot(x *gorgonia.Node, numClasses int, dType tensor.Dtype) (*gorgonia.Node, error) {
	if x == nil {
		return nil, fmt.Errorf("one-hot: input node is nil")
	}
	if numClasses < 1 {
		return nil, fmt.Errorf("one-hot: numClasses must be positive, got %d", numClasses)
	}
	// Reject element types up front rather than producing an all-zero matrix
	// when the op is eventually executed.
	switch dType {
	case tensor.Int, tensor.Float64, tensor.Float32, tensor.Bool:
	default:
		return nil, fmt.Errorf("one-hot: unsupported dtype %v", dType)
	}

	op := &oneHotOp{numClasses: numClasses, dType: dType}
	return gorgonia.ApplyOp(op, x)
}

// Compile-time checks that *oneHotOp satisfies the gorgonia interfaces it is
// meant to implement.
var (
	_ gorgonia.Op   = (*oneHotOp)(nil)
	_ gorgonia.SDOp = (*oneHotOp)(nil)
)

// oneHotOp is the op behind OneHot. It carries the configuration needed to
// turn a vector of indices into a matrix of one-hot rows.
type oneHotOp struct {
	// numClasses is the number of columns in the produced matrix.
	numClasses int
	// dType is the element type of the produced matrix.
	dType      tensor.Dtype
}

// DiffWRT implements gorgonia.SDOp.
//
// It returns one false entry per input: the op is not meant to be
// differentiated with respect to any of its inputs.
func (*oneHotOp) DiffWRT(inputs int) []bool {
	// A freshly made bool slice is all false, which is exactly the answer.
	notDifferentiable := make([]bool, inputs)
	return notDifferentiable
}

// SymDiff implements gorgonia.SDOp.
//
// The op is not differentiable (DiffWRT reports false for every input), so no
// symbolic derivative exists. Rather than panicking, the method reports this
// through its error return and yields no nodes.
func (*oneHotOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	return nil, fmt.Errorf("oneHotOp: symbolic differentiation is not supported")
}

// Arity implements gorgonia.Op.
//
// The op consumes exactly one input: the vector of indices.
func (*oneHotOp) Arity() int {
	const indexInputs = 1
	return indexInputs
}

// CallsExtern implements gorgonia.Op. It always reports false for this op.
func (*oneHotOp) CallsExtern() bool {
	const callsExtern = false
	return callsExtern
}

// Do implements gorgonia.Op.
//
// It expects a single input holding int indices and returns a freshly
// allocated tensor of element type op.dType with op.numClasses columns and
// one row per index. Row i has its "one" value (1 or true, depending on the
// element type) at column indices[i]; every other element keeps the zero
// value the tensor was created with.
//
// An error is returned when the number of inputs is not one, when the input
// data is not made of ints, when the input's shape claims more rows than
// there are indices, when op.dType is not Int, Float64, Float32 or Bool, when
// an index falls outside [0, numClasses), or when writing into the output
// tensor fails.
func (op *oneHotOp) Do(inp ...gorgonia.Value) (gorgonia.Value, error) {
	if len(inp) != 1 {
		return nil, fmt.Errorf("oneHotOp: expected 1 input, got %d", len(inp))
	}

	// Resolve the index data once instead of re-asserting it per element.
	var indices []int
	switch data := inp[0].Data().(type) {
	case []int:
		indices = data
	case int:
		// NOTE(review): some tensor versions may hand back a bare int for a
		// one-element value — confirm; accepting it is harmless either way.
		indices = []int{data}
	default:
		return nil, fmt.Errorf("oneHotOp: expected int indices, got %T", data)
	}

	// The leading dimension of the input decides the number of rows; a value
	// without dimensions falls back to the number of indices available.
	rows := len(indices)
	if shp := inp[0].Shape(); len(shp) > 0 {
		rows = shp[0]
	}
	if rows < 0 || rows > len(indices) {
		return nil, fmt.Errorf("oneHotOp: input shape declares %d rows but only %d indices are available", rows, len(indices))
	}

	// Pick the "hot" value matching the requested element type.
	var one interface{}
	switch op.dType {
	case tensor.Int:
		one = int(1)
	case tensor.Float64:
		one = float64(1)
	case tensor.Float32:
		one = float32(1)
	case tensor.Bool:
		one = true
	default:
		return nil, fmt.Errorf("oneHotOp: unsupported dtype %v", op.dType)
	}

	out := tensor.New(tensor.WithShape(rows, op.numClasses), tensor.Of(op.dType))
	for row, idx := range indices[:rows] {
		if idx < 0 || idx >= op.numClasses {
			return nil, fmt.Errorf("oneHotOp: index %d at position %d is out of range [0, %d)", idx, row, op.numClasses)
		}
		if err := out.SetAt(one, row, idx); err != nil {
			return nil, fmt.Errorf("oneHotOp: setting element (%d, %d): %w", row, idx, err)
		}
	}
	return out, nil
}

// InferShape implements gorgonia.Op.
//
// The output shape is the input shape with op.numClasses appended as a new
// trailing dimension. The input shape is cloned first so the caller's shape
// is never modified.
//
// An error is returned when the number of inputs is not one or when the input
// is not a tensor.Shape.
func (op *oneHotOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if len(inputs) != 1 {
		return nil, fmt.Errorf("oneHotOp: expected 1 input shape, got %d", len(inputs))
	}
	in, ok := inputs[0].(tensor.Shape)
	if !ok {
		return nil, fmt.Errorf("oneHotOp: expected a tensor.Shape, got %T", inputs[0])
	}

	out := in.Clone()
	out = append(out, op.numClasses)
	return out, nil
}

// OverwritesInput implements gorgonia.Op. It always returns -1.
// NOTE(review): by gorgonia convention -1 presumably means "no input is
// overwritten" — confirm against the Op interface documentation.
func (*oneHotOp) OverwritesInput() int {
	const overwritten = -1
	return overwritten
}

// ReturnsPtr implements gorgonia.Op. It always reports false for this op.
func (*oneHotOp) ReturnsPtr() bool {
	const returnsPtr = false
	return returnsPtr
}

// String implements gorgonia.Op and returns the op's fixed display name.
func (*oneHotOp) String() string {
	const opName = "OneHotOp"
	return opName
}

// Type implements gorgonia.Op.
//
// The op is typed as a function from a 1-dimensional tensor of ints (the
// indices) to a 2-dimensional tensor whose element type is op.dType, so the
// declared output type matches the tensor that Do actually builds.
func (op *oneHotOp) Type() hm.Type {
	indices := gorgonia.TensorType{
		Dims: 1,
		Of:   tensor.Int,
	}
	encoded := gorgonia.TensorType{
		Dims: 2,
		Of:   op.dType,
	}
	return hm.NewFnType(indices, encoded)
}

// WriteHash implements gorgonia.Op.
//
// It feeds a textual identity of the op into h. The identity includes the op
// name together with numClasses and dType, so that two one-hot ops configured
// differently do not produce the same hash input. A constant format string is
// used so that the op's text is never interpreted as formatting verbs.
func (op *oneHotOp) WriteHash(h hash.Hash) {
	// The hash.Hash contract states that Write never returns an error, so the
	// Fprintf result is intentionally not checked.
	fmt.Fprintf(h, "%s{numClasses: %d, dType: %v}", op.String(), op.numClasses, op.dType)
}

// Hashcode implements gorgonia.Op.
//
// It returns the 32-bit FNV-1a hash of whatever WriteHash emits, instead of
// panicking, so the op can be used anywhere a hash code is requested.
func (op *oneHotOp) Hashcode() uint32 {
	h := fnv.New32a()
	op.WriteHash(h)
	return h.Sum32()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant