K-Means Land Classification with Dask#

K-Means is a clustering algorithm that partitions data into k "clusters", assigning each sample to the cluster whose centroid (the mean of that cluster's members) it is closest to. Applied to imagery, it produces a segmentation map whose clusters can stand in for rough, easily separable classes. The classification should not be considered accurate and requires verification; however, it is a great starting point for unsupervised classification problems where the goal is to determine separable classes.

For geospatial applications, we can use K-Means to create rough land-classification segmentation maps, or to generate labeled data automatically, provided supporting methods verify that the classification is correct.
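To make the centroid idea concrete, here is a minimal NumPy sketch of the classic assign-then-update loop (Lloyd's algorithm). It is an illustration only, not what `dask_ml` runs internally; the function name `kmeans`, the toy points, and the hand-picked starting centroids are all assumptions made for this example.

```python
import numpy as np

def kmeans(points, centroids, n_iter=10):
    """Toy K-Means: alternate nearest-centroid assignment and centroid update."""
    for _ in range(n_iter):
        # distance from every point to every centroid: shape (n_points, k)
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # move each centroid to the mean of the points assigned to it
        # (assumes no cluster ends up empty, which holds for this toy data)
        centroids = np.array(
            [points[labels == k].mean(axis=0) for k in range(len(centroids))]
        )
    return labels, centroids

# two well-separated groups of 2-D points
points = np.array(
    [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
)
# start from one point in each group (an assumption that keeps the example deterministic)
labels, centroids = kmeans(points, centroids=points[[0, 3]].copy())
```

In the notebook, each "point" is a pixel and each coordinate is one Sentinel-2 band, so the centroids are mean spectra rather than 2-D positions.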

[ ]:
# We will be using Sentinel-2 L2A imagery from Microsoft Planetary Computer STAC server:
!pip install planetary_computer
[ ]:
import os
import rasterio
import rioxarray
import pystac
import stackstac
import datetime
import planetary_computer
import dask
import json
import gcsfs

import dask_ml.cluster

import numpy as np
import xarray as xr
import rioxarray as rxr
import matplotlib.pyplot as plt
import geopandas as gpd

from skimage.exposure import rescale_intensity
from dask_gateway import Gateway
from shapely.geometry import Polygon
from pystac_client import Client

1. Initialize Dask Cluster#

We will use Dask to power our computations and the fitting and prediction steps of the K-Means algorithm.

[3]:
def register_gcsfs_client(username:str):
    # set up the gcsfs system with credentials
    print('registering gcsfs')
    tok = os.path.join(os.environ['HOME'], f'geoanalytics_{username}', 'geo.json') # Change this to your own cred file
    # cloudpickle will just send the path to tok, which the processing nodes can't see. Send json instead
    with open(tok) as f:  # close the credentials file once it has been read
        tok_dict = json.load(f)
    gcs = gcsfs.GCSFileSystem(token=tok_dict, access='read_write')
    return gcs

def register_dask_client(imgname:str=None):
    ''' Make the gcsfs filesystem available with credentials and start client '''
    # we want to set up a cluster
    client = None
    cluster = None
    print('registering cluster')
    gateway = Gateway()
    options = gateway.cluster_options()
    if imgname is not None:
        options.image = imgname
    cluster = gateway.new_cluster(options)
    print(cluster.name)
    client = cluster.get_client()
    client.restart() # flush nodes
    return client, cluster, options
[ ]:
username = input('Username: ')
gcs = register_gcsfs_client(username=username)
client, cluster, options = register_dask_client(imgname='pangeo/pangeo-notebook:2022.04.15')
[ ]:
cluster
[ ]:
cluster.scale(10)

AOI#

This AOI was generated from: https://www.keene.edu/campus/maps/tool/

For this demonstration, we will look at the Timberlea suburb in Montreal, Quebec, Canada.

[ ]:
polygon = {
  "coordinates": [
    [
      [
        -73.8847303,
        45.4294192
      ],
      [
        -73.883357,
        45.4445361
      ],
      [
        -73.9108229,
        45.4442049
      ],
      [
        -73.9120245,
        45.4263471
      ],
      [
        -73.8847303,
        45.4294192
      ]
    ]
  ],
  "type": "Polygon"
}
polygon
[8]:
lon_list = []
lat_list = []

for lon,lat in polygon['coordinates'][0]:
    lon_list.append(lon)
    lat_list.append(lat)
polygon_geom = Polygon(zip(lon_list, lat_list))
crs = 'EPSG:4326'
polygon = gpd.GeoDataFrame(index=[0], crs=crs, geometry=[polygon_geom])
polygon
[8]:
geometry
0 POLYGON ((-73.88473 45.42942, -73.88336 45.444...
[9]:
FOOTPRINT = polygon.to_crs('epsg:4326').geometry[0].envelope
FOOTPRINT
[9]:
../_images/3_scientific_workflows_02-kmeans-dask_14_0.svg
[ ]:
# Set up Stac Client
api = Client.open('https://planetarycomputer.microsoft.com/api/stac/v1')
api
[11]:
# CONFIG
# -------------
TGT_BANDS =  ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A']
YEAR = '2021'
BEGIN_MONTH = '07'
END_MONTH = '08'
DATE_RANGE = f'{YEAR}-{BEGIN_MONTH}-01/{YEAR}-{END_MONTH}-30'
MAX_CLOUD = 5
READ_IN_CHUNK = 4096
RESOLUTION = 10
# -------------
[12]:
# In our AOI with a max cloud cover of 5%, we find one image
items = api.search(
        collections = ['sentinel-2-l2a'],
        intersects = FOOTPRINT,
        query={"eo:cloud_cover": {"lt": MAX_CLOUD}},
        datetime = DATE_RANGE,
    ).get_all_items()

print(f'\tFound {len(items)} items')
        Found 1 items
[13]:
# Planetary Computer requires signing each item's asset HREFs before we can pull the assets
signed_items = [planetary_computer.sign(item).to_dict() for item in items]
[14]:
# Create an Xarray DataArray of the pystac_client query results
data = (
    stackstac.stack(
        signed_items,
        assets=TGT_BANDS,
        chunksize=READ_IN_CHUNK, # Set chunksize
        resolution=RESOLUTION, # Set all bands res to this
        bounds_latlon=FOOTPRINT.bounds, # clip to AOI bounds
    ).where(lambda x: x > 0, other=np.nan) # Convert nodata zero to np.nan
)
data
/srv/conda/envs/notebook/lib/python3.9/site-packages/stackstac/prepare.py:413: FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
  xs = pd.Float64Index(np.linspace(minx, maxx, width, endpoint=False))
/srv/conda/envs/notebook/lib/python3.9/site-packages/stackstac/prepare.py:414: FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
  ys = pd.Float64Index(np.linspace(maxy, miny, height, endpoint=False))
[14]:
<xarray.DataArray 'stackstac-3304aac02b71837fe117ad2b247ac08a' (time: 1,
                                                                band: 12,
                                                                y: 207, x: 228)>
dask.array<where, shape=(1, 12, 207, 228), dtype=float64, chunksize=(1, 1, 207, 228), chunktype=numpy.ndarray>
Coordinates: (12/44)
  * time                                     (time) datetime64[ns] 2021-08-03...
    id                                       (time) <U54 'S2A_MSIL2A_20210803...
  * band                                     (band) <U3 'B01' 'B02' ... 'B8A'
  * x                                        (x) float64 5.851e+05 ... 5.874e+05
  * y                                        (y) float64 5.033e+06 ... 5.031e+06
    eo:cloud_cover                           float64 2.999
    ...                                       ...
    proj:bbox                                object {4990200.0, 609780.0, 499...
    title                                    (band) <U37 'Band 1 - Coastal ae...
    common_name                              (band) object 'coastal' ... 'red...
    center_wavelength                        (band) float64 0.443 0.49 ... 0.865
    full_width_half_max                      (band) float64 0.027 ... 0.033
    epsg                                     int64 32618
Attributes:
    spec:        RasterSpec(epsg=32618, bounds=(585080, 5030880, 587360, 5032...
    crs:         epsg:32618
    transform:   | 10.00, 0.00, 585080.00|\n| 0.00,-10.00, 5032950.00|\n| 0.0...
    resolution:  10
[15]:
# Visualize each band
data[0].plot.imshow(x='x', y='y', col='band', col_wrap=5)
[15]:
<xarray.plot.facetgrid.FacetGrid at 0x7f2ceb0075b0>
../_images/3_scientific_workflows_02-kmeans-dask_20_1.png

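We are going to normalize our data and rescale it to an 8-bit [0, 255] range. Note that the `np.int8` cast used below is signed and only spans [-128, 127], so any scaled value above 127 wraps around to a negative number; `np.uint8` is the dtype that covers the full [0, 255] range.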

[16]:
def normalize(array):
    norm = ((array - array.min()) / (array.max() - array.min())*255).astype(np.int8)
    return norm
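A min-max normalization that stays inside [0, 255] could look like the sketch below. It uses plain NumPy on a tiny made-up array rather than the Dask-backed DataArray, and the helper name `normalize_uint8` is hypothetical. Mapping NaN (nodata) to 0 before the cast is an assumption of this example, and it does make nodata indistinguishable from the darkest valid pixel.

```python
import numpy as np

def normalize_uint8(array):
    """Min-max scale to [0, 255], ignoring NaNs, and store as unsigned 8-bit."""
    lo = np.nanmin(array)
    hi = np.nanmax(array)
    scaled = (array - lo) / (hi - lo) * 255
    # NaN has no integer representation, so map it to 0 before casting (assumption)
    return np.nan_to_num(scaled, nan=0.0).astype(np.uint8)

# made-up reflectance-like values with one nodata pixel
sample = np.array([[100.0, 2000.0], [np.nan, 4000.0]])
norm = normalize_uint8(sample)
```

Because the minimum and maximum are taken over the whole array, all bands share one scale here, as in the `normalize` function above.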
[17]:
data_norm = normalize(data)
data_norm
[17]:
<xarray.DataArray 'stackstac-3304aac02b71837fe117ad2b247ac08a' (time: 1,
                                                                band: 12,
                                                                y: 207, x: 228)>
dask.array<astype, shape=(1, 12, 207, 228), dtype=int8, chunksize=(1, 1, 207, 228), chunktype=numpy.ndarray>
Coordinates: (12/44)
  * time                                     (time) datetime64[ns] 2021-08-03...
    id                                       (time) <U54 'S2A_MSIL2A_20210803...
  * band                                     (band) <U3 'B01' 'B02' ... 'B8A'
  * x                                        (x) float64 5.851e+05 ... 5.874e+05
  * y                                        (y) float64 5.033e+06 ... 5.031e+06
    eo:cloud_cover                           float64 2.999
    ...                                       ...
    proj:bbox                                object {4990200.0, 609780.0, 499...
    title                                    (band) <U37 'Band 1 - Coastal ae...
    common_name                              (band) object 'coastal' ... 'red...
    center_wavelength                        (band) float64 0.443 0.49 ... 0.865
    full_width_half_max                      (band) float64 0.027 ... 0.033
    epsg                                     int64 32618
[18]:
# Visualize the normalized data and see that it's visually the same as before
data_norm[0].plot.imshow(x='x', y='y', col='band', col_wrap=5, vmin=0, vmax=255)
[18]:
<xarray.plot.facetgrid.FacetGrid at 0x7f2ce0266be0>
../_images/3_scientific_workflows_02-kmeans-dask_24_1.png

Initialize K-Means Algorithm#

[43]:
km = dask_ml.cluster.KMeans(n_clusters=4, oversampling_factor=0)
km
[43]:
KMeans(n_clusters=4, oversampling_factor=0)

Data Preprocessing#

Start by figuring out the shape of our data. Doing so will give us a better understanding of how to manipulate the data for the algorithm.

[44]:
arr_shape = data.shape
arr_shape
[44]:
(1, 12, 207, 228)

The K-Means algorithm requires a 2-dimensional array as input, with one row per sample and one column per feature. First we flatten each band individually, and then we transpose the array so that each row is a pixel and each "column" represents a band.

[45]:
arr = data_norm.data[0].reshape(arr_shape[1], arr_shape[2]*arr_shape[3]).T
arr
[45]:
Array Chunk
Bytes 553.08 kiB 46.09 kiB
Shape (47196, 12) (47196, 1)
Count 195 Tasks 12 Chunks
Type int8 numpy.ndarray
12 47196
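The reshape-and-transpose step is easier to follow on a small array. The sketch below uses NumPy with a made-up 3-band, 2 x 4 image (the sizes and variable names are assumptions for illustration) to show that each row of the result holds all band values of one pixel, and that per-pixel values reshape straight back onto the image grid.

```python
import numpy as np

bands, height, width = 3, 2, 4  # small stand-in for (12, 207, 228)
cube = np.arange(bands * height * width).reshape(bands, height, width)

# flatten each band to a row, then transpose so rows are pixels and columns are bands
pixels = cube.reshape(bands, height * width).T

# row i of `pixels` holds every band value of the pixel at (i // width, i % width)
row, col = 1, 2
pixel_matches = (pixels[row * width + col] == cube[:, row, col]).all()

# per-pixel results (e.g. cluster labels) reshape straight back to the image grid
first_band_as_image = pixels[:, 0].reshape(height, width)
```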

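The K-Means implementation needs every sample (row) to have all of its features in a single chunk, so the band axis must not be split across chunks (otherwise you will get errors). The array above is split into 12 single-band chunks, so we rechunk it so that the band axis forms one chunk.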

[46]:
arr_rc = arr.rechunk({1: arr.shape[1]})
arr_rc
[46]:
Array Chunk
Bytes 553.08 kiB 553.08 kiB
Shape (47196, 12) (47196, 12)
Count 196 Tasks 1 Chunks
Type int8 numpy.ndarray
12 47196
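The effect of that rechunk can be checked on a placeholder array of the same shape. This sketch assumes only `dask.array`; the zero-filled array stands in for the real imagery.

```python
import dask.array as da

# same shape and chunk layout as the flattened imagery: one chunk per band column
x = da.zeros((47196, 12), chunks=(47196, 1), dtype="uint8")
chunks_before = x.chunks  # 12 chunks of width 1 along axis 1

# merge the band axis (axis 1) into a single chunk; axis 0 is left unchanged
y = x.rechunk({1: x.shape[1]})
chunks_after = y.chunks
```

For imagery too large for a single chunk, the usual layout is to keep axis 1 whole and split only along axis 0 (the pixels).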

Fitting The K-Means Algorithm#

Here we will fit the K-Means model to our input AOI imagery.

[47]:
%%time
km.fit(arr_rc)
Found fewer than 4 clusters in init (found 1).
CPU times: user 4.09 s, sys: 75.4 ms, total: 4.16 s
Wall time: 23.3 s
[47]:
KMeans(n_clusters=4, oversampling_factor=0)

Predicting Our Classification Clusters#

Once the model is fitted, we can predict a cluster for each pixel based on the centroids calculated from the input AOI imagery. For simplicity, we are using the same input for both fitting and prediction. A fitted model can sometimes be applied to similar areas, but if the new area is too dissimilar, the results will not be reliable.

[48]:
%%time
pred = km.predict(arr_rc)
pred
CPU times: user 41.5 ms, sys: 1.88 ms, total: 43.4 ms
Wall time: 259 ms
[48]:
Array Chunk
Bytes 184.36 kiB 184.36 kiB
Shape (47196,) (47196,)
Count 114 Tasks 1 Chunks
Type int32 numpy.ndarray
47196 1
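Conceptually, prediction assigns each pixel the index of its nearest centroid. The NumPy sketch below illustrates that with made-up two-band centroids standing in for a fitted model's `cluster_centers_`. Using the distance to the nearest centroid as a rough indicator of how well a pixel fits is an assumption of this example, not something the notebook computes.

```python
import numpy as np

# hypothetical centroids, standing in for a fitted model's `cluster_centers_`
centroids = np.array([[10.0, 10.0], [200.0, 50.0]])
# three made-up pixels with two band values each
pixels = np.array([[12.0, 9.0], [190.0, 60.0], [100.0, 30.0]])

# distance from every pixel to every centroid: shape (n_pixels, n_clusters)
dists = np.linalg.norm(pixels[:, None, :] - centroids[None, :, :], axis=2)
labels = dists.argmin(axis=1)  # index of the nearest centroid
nearest = dists.min(axis=1)    # large values flag pixels far from every centroid
```

The third pixel sits roughly midway between the two centroids, so it still receives a label but has by far the largest `nearest` distance. Imagery from a dissimilar area tends to look like this across many pixels.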

To visualize the data, we need to reverse the steps we applied to the input. The prediction is a 1-D array with one label per pixel, so reshaping it back into the original Y and X dimensions is enough (the transpose in the cell below is a no-op on a 1-D array).

[49]:
pred = pred.T.reshape(arr_shape[2], arr_shape[3])
pred
[49]:
Array Chunk
Bytes 184.36 kiB 184.36 kiB
Shape (207, 228) (207, 228)
Count 116 Tasks 1 Chunks
Type int32 numpy.ndarray
228 207

Below is an RGB image (left) and the K-Means cluster prediction segmentation map (right) of our input AOI.

[50]:
fig, ax = plt.subplots(1, 2, figsize=(20,10))
ax[0].imshow(rescale_intensity(data_norm[0].sel(band=['B04','B03','B02']).data, in_range=(0,30)).transpose(1,2,0))
ax[1].imshow(pred)
[50]:
<matplotlib.image.AxesImage at 0x7f2cd2128190>
../_images/3_scientific_workflows_02-kmeans-dask_43_1.png

As you can see, K-Means easily picks out major concrete infrastructure such as industrial areas and highways. It also appears to do a decent job of separating different vegetation densities/types. However, in the residential/urban areas the algorithm doesn't do as well. That could be a consequence of the sensor's spatial resolution. The quality of the segmentation map could be improved by adding more indices (e.g. NDVI) to the inputs used for fitting and prediction.
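As a sketch of that last suggestion, NDVI can be computed from the red (B04) and near-infrared (B08) bands and appended as an extra feature column. The reflectance values and the two-band feature matrix below are made up for illustration. Real data also needs a guard for pixels where both bands are zero, which this example leaves out.

```python
import numpy as np

# toy reflectance values for three pixels; on Sentinel-2, B04 is red and B08 is near-infrared
red = np.array([0.05, 0.20, 0.10])
nir = np.array([0.45, 0.25, 0.10])

# NDVI = (NIR - Red) / (NIR + Red), ranging from -1 to 1
ndvi = (nir - red) / (nir + red)

# hypothetical feature matrix of shape (n_pixels, n_bands), here just the two bands
features = np.column_stack([red, nir])
features_with_ndvi = np.column_stack([features, ndvi])
```

Since NDVI lies in [-1, 1] while the normalized bands span [0, 255], it would likely need rescaling to a comparable range before clustering so that it is not drowned out in the distance calculation.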

[ ]:
cluster.shutdown()