Feature: Multimasks training for RAMP #240

Open
wants to merge 9 commits into base: master
3 changes: 2 additions & 1 deletion backend/aiproject/settings.py
@@ -11,7 +11,7 @@
"""

import os

import logging
import dj_database_url
import environ
from corsheaders.defaults import default_headers
@@ -185,6 +185,7 @@
STATIC_ROOT = os.path.join(BASE_DIR, "api_static")

if DEBUG:
logging.info("Enabling oauthlib insecure transport in debug mode")
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"


38 changes: 32 additions & 6 deletions backend/core/tasks.py
@@ -3,9 +3,9 @@
import os
import shutil
import sys
import tarfile
import traceback
from shutil import rmtree
import tarfile

import hot_fair_utilities
import ramp.utils
@@ -37,6 +37,7 @@

DEFAULT_TILE_SIZE = 256


def xz_folder(folder_path, output_filename, remove_original=False):
"""
Compresses a folder and its contents into a .tar.xz file and optionally removes the original folder.
@@ -47,8 +48,8 @@ def xz_folder(folder_path, output_filename, remove_original=False):
- remove_original: If True, the original folder is removed after compression.
"""

if not output_filename.endswith('.tar.xz'):
output_filename += '.tar.xz'
if not output_filename.endswith(".tar.xz"):
output_filename += ".tar.xz"

with tarfile.open(output_filename, "w:xz") as tar:
tar.add(folder_path, arcname=os.path.basename(folder_path))
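As a usage note, a minimal sketch of calling this helper (paths are hypothetical, not from this diff):

```python
# Hypothetical paths, for illustration only.
xz_folder(
    "/tmp/ramp-data/123/preprocessed",
    "/tmp/trainings/123/preprocessed",  # ".tar.xz" is appended if missing
    remove_original=True,  # source folder is deleted after the archive is written
)
# -> creates /tmp/trainings/123/preprocessed.tar.xz
```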
@@ -67,6 +68,9 @@ def train_model(
source_imagery,
feedback=None,
freeze_layers=False,
multimasks=False,
input_contact_spacing=8,
input_boundary_width=3,
):
training_instance = get_object_or_404(Training, id=training_id)
training_instance.status = "RUNNING"
@@ -182,12 +186,22 @@ def train_model(
# preprocess
model_input_image_path = f"{base_path}/input"
preprocess_output = f"/{base_path}/preprocessed"

if multimasks:
logger.info(
"Using multiple masks for training : background, footprint, boundary, contact"
)
else:
logger.info("Using binary masks for training : background, footprint")
preprocess(
input_path=model_input_image_path,
output_path=preprocess_output,
rasterize=True,
rasterize_options=["binary"],
georeference_images=True,
multimasks=multimasks,
input_contact_spacing=input_contact_spacing,
input_boundary_width=input_boundary_width,
)
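To make the new flags concrete: in a multimask setup each pixel is labelled with one of four classes (background, building footprint, boundary, contact), where `input_boundary_width` controls how thick the edge ring is and `input_contact_spacing` how far apart two buildings can be while the gap still counts as contact. Below is a rough, self-contained sketch of that encoding, assuming pixel units and simple morphology; hot_fair_utilities' actual implementation may differ:

```python
import numpy as np
from scipy import ndimage

BACKGROUND, FOOTPRINT, BOUNDARY, CONTACT = 0, 1, 2, 3

def sketch_multimask(instances, boundary_width=3, contact_spacing=8):
    """Toy 4-class mask builder, NOT the library's exact algorithm.

    instances: 2D int array of per-building instance ids (0 = background).
    """
    mask = np.where(instances > 0, FOOTPRINT, BACKGROUND).astype(np.uint8)

    # Boundary: footprint pixels within boundary_width of the building edge.
    interior = ndimage.binary_erosion(instances > 0, iterations=boundary_width)
    mask[(instances > 0) & ~interior] = BOUNDARY

    # Contact: background pixels where halos of two different buildings meet.
    halo_steps = max(1, contact_spacing // 2)
    claimed = np.zeros_like(instances)
    contact = np.zeros(instances.shape, dtype=bool)
    for i in np.unique(instances):
        if i == 0:
            continue
        halo = ndimage.binary_dilation(instances == i, iterations=halo_steps)
        contact |= halo & (claimed > 0)  # halo meets an earlier building's halo
        claimed[halo & (claimed == 0)] = i
    mask[contact & (instances == 0)] = CONTACT
    return mask
```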

# train
@@ -206,6 +220,7 @@
),
model_home=os.environ["RAMP_HOME"],
epoch_size=epochs,
multimasks=multimasks,
batch_size=batch_size,
freeze_layers=freeze_layers,
)
@@ -218,6 +233,7 @@
model="ramp",
model_home=os.environ["RAMP_HOME"],
freeze_layers=freeze_layers,
multimasks=multimasks,
)

# copy final model to output
@@ -272,9 +288,19 @@ def train_model(
f.write(json.dumps(aoi_serializer.data))

# copy aois and labels to preprocess output before compressing it to tar
shutil.copyfile(os.path.join(output_path, "aois.geojson"), os.path.join(preprocess_output,'aois.geojson'))
shutil.copyfile(os.path.join(output_path, "labels.geojson"), os.path.join(preprocess_output,'labels.geojson'))
xz_folder(preprocess_output, os.path.join(output_path, "preprocessed.tar.xz"), remove_original=True)
shutil.copyfile(
os.path.join(output_path, "aois.geojson"),
os.path.join(preprocess_output, "aois.geojson"),
)
shutil.copyfile(
os.path.join(output_path, "labels.geojson"),
os.path.join(preprocess_output, "labels.geojson"),
)
xz_folder(
preprocess_output,
os.path.join(output_path, "preprocessed.tar.xz"),
remove_original=True,
)

# now remove ramp-data; all outputs have been copied to our training workspace
shutil.rmtree(base_path)
64 changes: 48 additions & 16 deletions backend/core/views.py
@@ -81,6 +81,15 @@ class DatasetViewSet(
class TrainingSerializer(
serializers.ModelSerializer
):  # serializers translate model objects to API representations

multimasks = serializers.BooleanField(required=False, default=False)
input_contact_spacing = serializers.IntegerField(
required=False, default=8, min_value=0, max_value=20
)
input_boundary_width = serializers.IntegerField(
required=False, default=3, min_value=0, max_value=10
)

class Meta:
model = Training
fields = "__all__" # defining all the fields to be included in curd for now , we can restrict few if we want
@@ -126,6 +135,16 @@ def create(self, validated_data):
user = self.context["request"].user
validated_data["created_by"] = user
# create the model instance
multimasks = validated_data.get("multimasks", False)
input_contact_spacing = validated_data.get("input_contact_spacing", 8)  # match serializer default
input_boundary_width = validated_data.get("input_boundary_width", 3)  # match serializer default

# drop serializer-only fields before creating the Training row
for key in ("multimasks", "input_contact_spacing", "input_boundary_width"):
validated_data.pop(key, None)

instance = Training.objects.create(**validated_data)
logging.info("Sending record to redis queue")
# run your function here
@@ -138,11 +157,16 @@
source_imagery=instance.source_imagery
or instance.model.dataset.source_imagery,
freeze_layers=instance.freeze_layers,
multimasks=multimasks,
input_contact_spacing=input_contact_spacing,
input_boundary_width=input_boundary_width,
)
logging.info("Record saved in queue")

if not instance.source_imagery:
instance.source_imagery = instance.model.dataset.source_imagery
if multimasks:
instance.description += f" Multimask params: contact_spacing={input_contact_spacing}, boundary_width={input_boundary_width}"
instance.task_id = task.id
instance.save()
print(f"Saved train model request to queue with id {task.id}")
@@ -194,7 +218,7 @@ class FeedbackLabelViewset(viewsets.ModelViewSet):
bbox_filter_field = "geom"
filter_backends = (
InBBoxFilter,  # accepts a bbox query param, e.g. api/v1/label/?in_bbox=-90,29,-89,35
DjangoFilterBackend
DjangoFilterBackend,
)
bbox_filter_include_overlapping = True
filterset_fields = ["feedback_aoi", "feedback_aoi__training"]
@@ -345,9 +369,9 @@ def download_training_data(request, dataset_id: int):
response = HttpResponse(open(zip_temp_path, "rb"))
response.headers["Content-Type"] = "application/x-zip-compressed"

response.headers[
"Content-Disposition"
] = f"attachment; filename=training_{dataset_id}_all_data.zip"
response.headers["Content-Disposition"] = (
f"attachment; filename=training_{dataset_id}_all_data.zip"
)
return response
else:
# "error": "File Doesn't Exist or has been cleared up from system",
@@ -555,12 +579,16 @@ def post(self, request, *args, **kwargs):
zoom_level=zoom_level,
tms_url=source,
tile_size=DEFAULT_TILE_SIZE,
confidence=deserialized_data["confidence"] / 100
if "confidence" in deserialized_data
else 0.5,
tile_overlap_distance=deserialized_data["tile_overlap_distance"]
if "tile_overlap_distance" in deserialized_data
else 0.15,
confidence=(
deserialized_data["confidence"] / 100
if "confidence" in deserialized_data
else 0.5
),
tile_overlap_distance=(
deserialized_data["tile_overlap_distance"]
if "tile_overlap_distance" in deserialized_data
else 0.15
),
)
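A small aside on the reformatted defaults above: an equivalent, more compact spelling uses dict.get (a sketch, not part of this diff; note the API receives confidence as a percentage):

```python
# Equivalent defaulting with dict.get; example input shown for clarity.
deserialized_data = {"confidence": 90}
confidence = deserialized_data.get("confidence", 50) / 100  # -> 0.9 (default 0.5)
tile_overlap_distance = deserialized_data.get("tile_overlap_distance", 0.15)  # -> 0.15
```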
print(
f"It took {round(time.time()-start_time)}sec for generating predictions"
@@ -571,12 +599,16 @@
if use_josm_q is True:
feature["geometry"] = othogonalize_poly(
feature["geometry"],
maxAngleChange=deserialized_data["max_angle_change"]
if "max_angle_change" in deserialized_data
else 15,
skewTolerance=deserialized_data["skew_tolerance"]
if "skew_tolerance" in deserialized_data
else 15,
maxAngleChange=(
deserialized_data["max_angle_change"]
if "max_angle_change" in deserialized_data
else 15
),
skewTolerance=(
deserialized_data["skew_tolerance"]
if "skew_tolerance" in deserialized_data
else 15
),
)

print(
1 change: 1 addition & 0 deletions backend/fAIr-utilities
Submodule fAIr-utilities added at 93debb