Multi-class pose estimation with class weights #12802

Open
1 task done
xzxorb opened this issue May 18, 2024 · 3 comments
Labels
question Further information is requested

@xzxorb

xzxorb commented May 18, 2024

Search before asking

Question

Hi! I am using the yolov8-pose model for keypoint detection on multiple classes, but some classes have very poor precision. How can I modify the code to add class weights, so that some classes get a higher weight when the loss is computed? Could you give me some advice? Thank you very much.

Additional

No response

@xzxorb xzxorb added the question Further information is requested label May 18, 2024
@glenn-jocher
Member

Hello! To incorporate class weights into your YOLOv8-pose model for keypoint detection, you'll need to modify the loss calculation to factor in these weights. Currently, YOLOv8 does not support per-class loss weighting for pose estimation out of the box.

However, you can achieve this by customizing the loss function in the source code. Here’s a general approach:

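  1. Identify the part of the code where the keypoint loss is computed. In the `ultralytics` package this is `ultralytics/utils/loss.py` (the `v8PoseLoss` and `KeypointLoss` classes), not `models/yolo.py`.
  2. Introduce a weighting factor based on your class weights: define a tensor with one weight per class, index it with each target's class label, and multiply the resulting per-target weights into the keypoint loss before it is reduced to a scalar.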

Here's a conceptual snippet:

```python
# 'loss': unreduced keypoint loss with one value per target; 'class_weights': 1-D tensor with one weight per class
# targets[:, 5] is assumed to hold the class labels; adjust the index to wherever your labels actually sit
weighted_loss = (loss * class_weights[targets[:, 5].long()]).mean()  # weight each target, then reduce
```
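A minimal, self-contained sketch of the same idea on toy tensors, assuming you already have an unreduced per-instance keypoint loss and the matching class indices. The function name `class_weighted_mean` is hypothetical, not part of Ultralytics. It divides by the total weight so the loss scale stays comparable to a plain mean; the trade-off is that a batch containing only one class is unaffected by the weights, so a plain `.mean()` of the weighted values is a reasonable alternative.

```python
import torch


def class_weighted_mean(per_instance_loss: torch.Tensor,
                        cls: torch.Tensor,
                        class_weights: torch.Tensor) -> torch.Tensor:
    """Weight each instance's loss by its class weight, then take a weighted mean.

    per_instance_loss: shape (N,), unreduced loss, one value per matched instance
    cls: shape (N,), integer class index of each instance
    class_weights: shape (num_classes,), one weight per class
    """
    # Look up one weight per instance (moving the weights to the loss's device first).
    w = class_weights.to(per_instance_loss.device)[cls.long()]
    # Normalize by the total weight so the result stays on the scale of an unweighted mean.
    return (per_instance_loss * w).sum() / w.sum().clamp(min=1e-9)


# Toy example: class 1 is up-weighted 3x relative to class 0.
loss = torch.tensor([0.2, 0.8, 0.4])
cls = torch.tensor([0, 1, 0])
class_weights = torch.tensor([1.0, 3.0])
print(class_weighted_mean(loss, cls, class_weights))  # about 0.6 vs. an unweighted mean of about 0.467
```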

Define the `class_weights` tensor with exactly one entry per class, in class-index order, and make sure it is on the same device as your model and loss tensors.
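One possible way to build `class_weights`, sketched under stated assumptions: it uses inverse-frequency weighting, a common heuristic that YOLOv8 does not provide, and the helper `inverse_frequency_weights` and the instance counts are hypothetical. You could equally hand-set larger weights for the classes whose precision is poor. The weights are rescaled to average 1 so the overall loss magnitude does not drift.

```python
import torch


def inverse_frequency_weights(class_counts, device="cpu") -> torch.Tensor:
    """Hypothetical helper: weight each class by 1 / (its instance count),
    rescaled so the weights average to 1. One entry per class, in class-index order."""
    counts = torch.as_tensor(class_counts, dtype=torch.float32)
    weights = 1.0 / counts.clamp(min=1.0)             # guard against classes with zero instances
    weights = weights * len(weights) / weights.sum()  # rescale so the mean weight is 1
    return weights.to(device)                         # same device as the model and loss


# Example: three classes with 900, 300 and 100 labeled instances (made-up counts).
device = "cuda" if torch.cuda.is_available() else "cpu"
class_weights = inverse_frequency_weights([900, 300, 100], device=device)
print(class_weights)  # rarest class gets 9x the weight of the most common one
```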

Please adapt this concept to your specific implementation details and requirements. If you need further guidance, don't hesitate to ask!

@xzxorb
Author

xzxorb commented May 19, 2024

Thank you, I will try this later.

@glenn-jocher
Member

@xzxorb great! If you run into any further issues or have questions as you implement this, feel free to reach out. Happy coding! 😊
