Compositing Rasters with Dask#

A geospatial developer's workflow will occasionally require compositing rasters to represent specific temporal ranges as a single image. This approach is often performed to visualize cloud-free imagery over an AOI.

In the following example, we will create a cloud-free median composite. Using "median" as the compositing method is, one, simple, and two, provides adequate resiliency against cloud shadows/edges that are not necessarily captured by the cloud mask.

[ ]:
# We will be using Sentinel-2 L2A imagery from Microsoft Planetary Computer STAC server:
!pip install planetary_computer
[ ]:
import os
import rasterio
import rioxarray
import pystac
import stackstac
import datetime
import planetary_computer
import dask
import json
import gcsfs

import dask.array as da
import numpy as np
import xarray as xr
import rioxarray as rxr
import matplotlib.pyplot as plt
import geopandas as gpd

from dask_gateway import Gateway
# from shapely.geometry import Point
from pystac_client import Client
from dask.distributed import performance_report
from typing import List
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles
from IPython.display import clear_output

Set up the Dask Cluster and GCSFS Client

[2]:
def register_gcsfs_client(username:str):
    '''Build a GCSFileSystem with read/write access from the user's credential file.

    Parameters
    ----------
    username : str
        Used to locate the credential file at
        $HOME/geoanalytics_<username>/geo.json.

    Returns
    -------
    gcsfs.GCSFileSystem
        Filesystem object authenticated with the loaded token.
    '''
    print('registering gcsfs')
    tok = os.path.join(os.environ['HOME'], f'geoanalytics_{username}', 'geo.json') # Change this to your own cred file
    # cloudpickle would only ship the *path* to workers, which they cannot see;
    # load the JSON here so the token dict itself is what gets serialized.
    # (Use a context manager so the file handle is closed — the original leaked it.)
    with open(tok) as f:
        tok_dict = json.load(f)
    gcs = gcsfs.GCSFileSystem(token=tok_dict, access='read_write')
    return gcs

def register_dask_client(imgname:str=None):
    '''Start a Dask Gateway cluster and return its client.

    Parameters
    ----------
    imgname : str, optional
        Container image to run on the workers; the Gateway default image is
        used when None.

    Returns
    -------
    tuple
        (client, cluster, options) for the newly created cluster.
    '''
    print('registering cluster')
    gateway = Gateway()
    options = gateway.cluster_options()
    if imgname is not None:  # idiomatic identity test (was `not imgname is None`)
        options.image = imgname
    cluster = gateway.new_cluster(options)
    print(cluster.name)
    client = cluster.get_client()
    client.restart() # flush nodes
    return client, cluster, options
[ ]:
username = input('Username: ')  # used to locate $HOME/geoanalytics_<username>/geo.json
gcs = register_gcsfs_client(username=username)
client, cluster, options = register_dask_client(imgname='pangeo/pangeo-notebook:2022.04.15')
[ ]:
# View Dask cluster details (dashboard link, current worker count)
cluster
[ ]:
# Scale Dask cluster to 30 workers
cluster.scale(30)
[ ]:
client.restart()  # start from a clean worker state before building the task graph
[ ]:
# Function to write from the dask cluster to the remote bucket
# Function to write from the dask cluster to the remote bucket
@dask.delayed
def write_ras(gcs, epsg, ras, b, pth):
    '''Write a composite band to the remote bucket as a Cloud-Optimized GeoTIFF.

    Parameters
    ----------
    gcs : gcsfs.GCSFileSystem
        Authenticated filesystem used to upload the COG.
    epsg : int
        EPSG code to stamp onto the raster's CRS metadata.
    ras : xarray.DataArray
        Band composite to write.
    b : str
        Band name, used only in the error message.
    pth : str
        Destination object path in the bucket.

    Returns
    -------
    str
        'success', or '<band>: <error>' when the write failed.
    '''
    import rioxarray  # registers the .rio accessor on the worker
    import tempfile
    try:
        ras = ras.rio.write_crs(epsg)
        # Use a per-task scratch directory: the original wrote fixed filenames
        # ('ras.tif'), which concurrent tasks on the same worker would clobber,
        # and the files leaked if an exception fired before os.remove().
        with tempfile.TemporaryDirectory() as tmp:
            ras_pth = os.path.join(tmp, 'ras.tif')
            cog_pth = os.path.join(tmp, 'ras_cog.tif')
            ras.rio.to_raster(ras_pth)
            # Turn the raster into a COG
            # (the original cog_translate call was over-indented — a SyntaxError)
            dst_profile = cog_profiles.get("deflate")
            cog_translate(
                ras_pth,
                cog_pth,
                dst_profile,
                in_memory=True,
                quiet=False,
            )
            # Use GCSFS Client to put COG into remote bucket
            gcs.put(cog_pth, pth)
        # Scratch files are removed automatically when the tempdir closes
        return 'success'
    except Exception as e:
        # Return error and associated band
        return f'{b}: {e}'

AOI#

This AOI was generated from: https://www.keene.edu/campus/maps/tool/

[ ]:
# Rough polygon around Vancouver in EPSG:4326
# GeoJSON Polygon geometry: one closed 5-vertex ring (first point repeated last)
poly = {
    "coordinates": [[
        [-122.1185303, 49.3304921],
        [-123.3078003, 49.3644891],
        [-123.2528687, 48.7742927],
        [-122.1817017, 48.7579995],
        [-122.1185303, 49.3304921],
    ]],
    "type": "Polygon",
}
poly
[ ]:
# Persist the geometry to disk so geopandas can open it as a file
poly_file_pth = '/tmp/geo.geojson'
with open(poly_file_pth, 'w') as fh:
    fh.write(json.dumps(poly))
[ ]:
# Load the AOI from the local filesystem into a GeoDataFrame
f = gpd.read_file(poly_file_pth)
f
[ ]:
# The FOOTPRINT needs to be enveloped for pystac_client to query with
# More complex shapes can be clipped at later stages of the workflow
FOOTPRINT = f.to_crs('epsg:4326').geometry[0].envelope  # shapely bounding envelope in WGS84
FOOTPRINT
[ ]:
FOOTPRINT.bounds  # (minx, miny, maxx, maxy) — passed to stackstac below

Set Up STAC Client#

[ ]:
# Set up STAC client against Microsoft's Planetary Computer catalogue
api = Client.open('https://planetarycomputer.microsoft.com/api/stac/v1')
api

Configuration#

[ ]:
# CONFIG
# -------------
BASE_PTH = 'gs://geoanalytics-user-shared-data'  # destination bucket for the composites
OUTPUT_DIR = 'tutorial_test'  # subdirectory within the bucket
TGT_BANDS =  ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A']  # Sentinel-2 bands to composite (SCL is handled separately as the mask)
YEARS = ['2020']  # one composite per year
BEGIN_MONTH = '06'  # season start month (inclusive)
END_MONTH = '09'  # season end month (inclusive)
MAX_CLOUD = 20.0  # max scene-level eo:cloud_cover (%) accepted by the STAC query
READ_IN_CHUNK = 4096  # dask chunk edge length (pixels) when stacking
RESOLUTION = 10  # resample every band to this resolution (metres)
TEMPORAL_CHUNK = {'time': -1, 'band': 1, 'x': 128, 'y': 128}  # NOTE(review): defined but not referenced in the visible pipeline — confirm whether still needed
SYNCHRONOUS = False # Write bands out one at a time - use if resources can't handle all bands at once for AOI
# -------------

Main Loop#

In the main loop we iterate over the target years, creating a composite for each of the target bands in each year.

[ ]:
%%time
# Main pipeline: for each year, query scenes, build an SCL cloud mask,
# then build a delayed median composite per band.
write_futs = []  # delayed write tasks, computed together at the end when not SYNCHRONOUS
for year in YEARS:
    OUT_PTH = f'{BASE_PTH}/{OUTPUT_DIR}/{year}'
    # NOTE(review): the end day is hard-coded to 30 — fine for September, but
    # would miss the 31st if END_MONTH were changed to a 31-day month.
    date_range = f'{year}-{BEGIN_MONTH}-01/{year}-{END_MONTH}-30'

    # Query the Planetary Computer STAC server with pystac_client
    print(f'[Querying] {year}')
    items = api.search(
        collections = ['sentinel-2-l2a'],
        intersects = FOOTPRINT,
        query={"eo:cloud_cover": {"lt": MAX_CLOUD}},  # scene-level cloud filter
        datetime = date_range,
    ).get_all_items()  # NOTE(review): deprecated in newer pystac-client; item_collection() is the replacement

    print(f'\tFound {len(items)} items')
    # planetarycomputer requires signed URLs to access Asset HREFs.
    print('\t[Signing data links]')
    signed_items = [planetary_computer.sign(item).to_dict() for item in items]

    # Pull out SCL DataArray before bands are looped through
    # since this will not change per band.
    scl_stk = (
        stackstac.stack(
            signed_items,
            assets=['SCL'],
            chunksize=READ_IN_CHUNK, # Set chunksize
            resolution=RESOLUTION, # Set all bands res to this
            bounds_latlon=FOOTPRINT.bounds, # clip to AOI bounds
        )
    )
    # Create binary mask [np.nan, 1]: listed SCL classes become np.nan, all others 1
    # https://sentinels.copernicus.eu/web/sentinel/technical-guides/sentinel-2-msi/level-2a/algorithm
    scl_stk.data = da.where(
            ((scl_stk.data==0)| # nodata
             (scl_stk.data==1)| # Saturated or Defective
             (scl_stk.data==8)| # cloud: medium probability
             (scl_stk.data==9)| # cloud: high probability
             (scl_stk.data==10)| # cloud: thin cirrus
             (scl_stk.data==11) # snow
            ), np.nan, 1)

    # Iterate over bands and build composite DAG
    for band in TGT_BANDS:
        clear_output(wait=True) # clear Jupyter Cell output
        print(f'[Processing {band}]')

        # Convert STAC query into a xarray.DataArray
        # with stackstac
        print('\t[Converting STAC query to DataArray]')
        data = (
            stackstac.stack(
                signed_items,
                assets=[band],
                chunksize=READ_IN_CHUNK, # Set chunksize
                resolution=RESOLUTION, # Set all bands res to this
                bounds_latlon=FOOTPRINT.bounds, # clip to AOI bounds
            ).where(lambda x: x > 0, other=np.nan) # Convert nodata zero to np.nan
        )

        # Mask the bands with the accompanying SCL band per time
        print('\t[Masking data with SCL]')
        masked = data.copy()
        masked.data = data.data * scl_stk.data # multiplying by np.nan masks unwanted pixels

        # Create median composite
        print('\t[Creating Median composite]')
        # skip np.nan in temporal stack with skipna=True
        median = masked.median(dim='time', skipna=True, keep_attrs=True)
        median = median.chunk({'band': 1, 'y': 'auto', 'x': 'auto'})
        median = median.transpose('band', 'y', 'x')

        # Cast the xarray.DataArray to uint16 for compact storage.
        # NOTE(review): np.nan does not survive an unsigned-integer cast
        # (the result is undefined) — confirm nodata handling downstream.
        median = median.astype(np.uint16)

        # Get EPSG from median metadata
        epsg = median.coords['epsg'].values.tolist()


        if SYNCHRONOUS:
            # Issues with large AOI's - limited resources - so compute each composite
            # individually to relieve the Dask Cluster
            print(dask.compute(write_ras(gcs, epsg, median, band, f'{OUT_PTH}/{band}.tif')))
        else:
            # Write out each band to a file asynchronously
            print(f'\t[Processing and Writing {band}]')
            median.name = band
            median.attrs['long_name'] = band
            write_futs.append(write_ras(gcs, epsg, median, band, f'{OUT_PTH}/{band}.tif'))

if not SYNCHRONOUS:
    clear_output(wait=True)
    # BUG FIX: `write_futs` is a plain Python list, which has no `.visualize()`
    # method — the original line raised AttributeError. dask.visualize traverses
    # the list of Delayed objects and renders the combined task graph.
    dask.visualize(write_futs)
    # Compute every delayed write at once, recording a profiling report,
    # and print the per-band status strings returned by write_ras.
    with performance_report('dask_report.html'):
        print(dask.compute(write_futs)[0])

Shut Dask Cluster Down#

Make sure to shut down the Dask Cluster so as not to incur costs

[ ]:
cluster.shutdown()  # release all Dask workers so the cluster stops incurring costs

Remove Intermediate Files#

The files in /tmp will be deleted at shutdown of your session, however, if you plan on continuing to work, then cleaning up old and no-longer-needed files is helpful housekeeping

[ ]:
os.remove(poly_file_pth)  # delete the temporary AOI GeoJSON written earlier