'''
Module dedicated to pixel classification.
'''
import numpy as np
import xarray as xr
import pandas as pd
import logging
from s2cloudless import S2PixelCloudDetector
FLAG_NAME = 'flags'
[docs]
class Settings:
    '''
    Settings used for the different masking procedures.

    The instance only stores plain attributes: the wavelengths used to pick
    the spectral bands, the thresholds applied to build the masks, the
    s2cloudless parameters and the number of bits available in the flag raster.
    '''

    def __init__(self):
        # wavelengths used to select the visible, NIR and the two SWIR bands
        self.bvis, self.bnir = 490, 842
        self.bswir, self.bswir2 = 1610, 2190

        # extra bands (e.g., cirrus, water vapor);
        # no cirrus band is configured by default
        self.bcirrus = None
        self.bwv = 945

        # thresholds applied to band values or normalized indices
        self.swir_threshold = 0.2  # upper limit for the second SWIR band
        self.ndwi_threshold = 0.0  # visible/NIR normalized index
        self.vis_swir_index_threshold = 0.2  # visible/SWIR normalized index
        self.thin_cirrus_threshold = 3e-3
        self.opac_cirrus_threshold = 9e-3

        # s2cloudless parameters: (probability threshold, dilation size)
        # for the low and high confidence cloud masks
        self.low_confid_cloud_proba_thresh = 0.6
        self.low_confid_cloud_dilation = 12
        self.high_confid_cloud_proba_thresh = 0.8
        self.high_confid_cloud_dilation = 5

        # number of bits (i.e., individual flags) handled in the flag raster
        self.number_of_flags = 32
[docs]
class Masking(Settings):
    '''
    Class dedicated to pixel classification and masking.

    Every ``*_mask`` method sets one or several bits of the integer raster
    ``self.flags`` and stores the name and the description of each flag at the
    index of its bit in ``self.flag_names`` and ``self.flag_descriptions``.
    :meth:`process` chains the masking methods and writes the result into the
    product under the variable name given by ``FLAG_NAME``.
    '''

    def __init__(self, prod):
        '''
        :param prod: image raster containing the spectral bands; its
            ``attrs['constellation_id']`` is used to identify the sensor
        '''
        Settings.__init__(self)
        self.prod = prod

        # Sensor-dependent setup. Both switches default to False so that a
        # product from another constellation does not leave them undefined
        # (in that case bcirrus keeps its default from Settings).
        self.sentinel2 = False
        self.landsat = False
        constellation_id = prod.attrs['constellation_id']
        if "S2" in constellation_id:
            self.bcirrus = 1375
            self.sentinel2 = True
        elif constellation_id in ('L8', 'L9'):
            self.bcirrus = 1370
            self.landsat = True

        # one slot per bit; unused slots stay None
        self.flag_descriptions = np.empty(self.number_of_flags, dtype='object')
        self.flag_names = np.empty(self.number_of_flags, dtype='object')

        # placeholder, the flag raster is actually created by nodata_mask()
        self.flags = xr.DataArray()
        self.flags.attrs['long_name'] = 'flags computed from l1c image'
        self.flag_stats = {}

    def nodata_mask(self,
                    bitmask=0,
                    name='nodata',
                    description='nodata in input image'):
        '''
        Set flag nodata: condition either nan or crazy numerical values (> 1e3),
        checked on the band at index 1 of the ``wl`` dimension.

        This method (re)initializes the flag raster: any bit previously set in
        ``self.flags`` is discarded.

        :param bitmask: bit number on which the flag is coded
        :param name: name of the flag
        :param description: description of the flag
        :return:
        '''
        band = self.prod.bands.isel(wl=1)
        nodata = np.isnan(band) | (band > 1e3)
        # shift to the requested bit so that the raster agrees with the
        # name/description stored below (identical to the raw mask for bit 0)
        self.flags = nodata.astype(np.int32) << bitmask

        # set naming and description as attributes
        self.flag_descriptions[bitmask] = description
        self.flag_names[bitmask] = name

    def cloud_mask(self,
                   bitmasks=[1, 2],
                   names=['cloud_p06', 'cloud_p08'],
                   descriptions=['low confidence cloud mask from s2cloudless with settings proba.',
                                 'high confidence cloud mask from s2cloudless with settings proba.']
                   ):
        '''
        Apply s2cloudless masking with two levels of confidence (low first,
        then high). The cloud probability maps are computed once and
        thresholded/dilated with the settings of each level.

        :param bitmasks: bit numbers on which the flags are coded
        :param names: names of the flags
        :param descriptions: descriptions of the flags
        :return:
        '''
        logging.info('cloud masking with s2cloudless')

        # s2cloudless input: one image with the band axis moved last
        # NOTE(review): assumes bands are ordered (wl, y, x) - confirm
        bands = np.array([self.prod.bands.values.transpose((1, 2, 0))])

        levels = ((self.low_confid_cloud_proba_thresh, self.low_confid_cloud_dilation),
                  (self.high_confid_cloud_proba_thresh, self.high_confid_cloud_dilation))

        probability_maps = None
        for ii, (threshold, dilation) in enumerate(levels):
            bitmask = bitmasks[ii]
            cloud_detector = S2PixelCloudDetector(threshold=threshold, average_over=1,
                                                  dilation_size=dilation, all_bands=True)
            if probability_maps is None:
                # expensive step, shared by both confidence levels
                probability_maps = cloud_detector.get_cloud_probability_maps(bands)
            cloud_mask = cloud_detector.get_mask_from_prob(probability_maps)[0]
            self.flags = self.flags | ((cloud_mask == 1) << bitmask)

            # add name and description
            self.flag_descriptions[bitmask] = descriptions[ii] + \
                ' threshold={:.2f}, dilation_size={:d}'.format(threshold, dilation)
            self.flag_names[bitmask] = names[ii]

    def water_mask(self,
                   bitmasks=[3, 4],
                   names=['water_swir_visible_index', 'water_red_visible_index'],
                   descriptions=['water mask from normalized index from swir and visible band',
                                 'water mask from normalized index from nir and visible band, warning could fail for turbid waters'],
                   ):
        '''
        Apply water/land mask.
        Two normalized difference indices are thresholded: visible/SWIR
        (first flag) and visible/NIR (second flag).

        :param bitmasks: bit numbers on which the flags are coded
        :param names: names of the flags
        :param descriptions: descriptions of the flags
        :return:
        '''
        logging.info('water masking')

        visible = self.prod.bands.sel(wl=self.bvis, method='nearest')
        nir = self.prod.bands.sel(wl=self.bnir, method='nearest')
        swir = self.prod.bands.sel(wl=self.bswir, method='nearest')

        ndwi = (visible - nir) / (visible + nir)
        ndwi_swir = (visible - swir) / (visible + swir)

        # set flags raster
        self.flags = self.flags | (
                ((ndwi_swir.values > self.vis_swir_index_threshold) << bitmasks[0]) |
                ((ndwi.values > self.ndwi_threshold) << bitmasks[1]))

        # set naming and description as attributes
        self.flag_descriptions[bitmasks[0]] = (descriptions[0] +
                                               ', bands centered on {:.1f} and {:.1f} nm'.format(self.bvis,
                                                                                                  self.bswir) +
                                               ', threshold={:.2f}'.format(self.vis_swir_index_threshold))
        self.flag_names[bitmasks[0]] = names[0]
        self.flag_descriptions[bitmasks[1]] = (descriptions[1] +
                                               ', bands centered on {:.1f} and {:.1f} nm'.format(self.bvis,
                                                                                                  self.bnir) +
                                               ', threshold={:.2f}'.format(self.ndwi_threshold))
        self.flag_names[bitmasks[1]] = names[1]

    def cirrus_mask(self,
                    bitmasks=[5, 6],
                    names=['thin_cirrus', 'opac_cirrus'],
                    descriptions=['thin cirrus mask from cirrus band',
                                  'opac cirrus mask from cirrus band']
                    ):
        '''
        Compute cirrus mask for thin and opac high clouds from the spectral band
        the closest to ``self.bcirrus`` (around 1375 nm).

        :param bitmasks: bit numbers on which the flags are coded
        :param names: names of the flags
        :param descriptions: descriptions of the flags
        :return:
        '''
        logging.info('cirrus masking')

        cirrus = self.prod.bands.sel(wl=self.bcirrus, method='nearest')
        thresholds = [self.thin_cirrus_threshold, self.opac_cirrus_threshold]

        for ii, threshold in enumerate(thresholds):
            self.flags = self.flags | ((cirrus.values > threshold) << bitmasks[ii])
            # set naming and description as attributes
            self.flag_descriptions[bitmasks[ii]] = descriptions[ii] + ', threshold {:.4f}'.format(threshold)
            self.flag_names[bitmasks[ii]] = names[ii]

    def high_swir_mask(self,
                       bitmask=7,
                       name='high_swir',
                       description='high swir (bright cloud, too bright reflection...)'
                       ):
        '''
        Flag to mask (too) high values of the swir band, generally to remove
        potential cloud contamination.

        :param bitmask: bit number on which the flag is coded
        :param name: name of the flag
        :param description: description used as attribute
        :return:
        '''
        logging.info('high swir masking')

        # nearest-neighbour selection, as for the other bands: the actual
        # central wavelength is not necessarily exactly self.bswir2
        b2200 = self.prod.bands.sel(wl=self.bswir2, method='nearest')
        self.flags = self.flags | ((b2200.values > self.swir_threshold) << bitmask)

        # set naming and description as attributes
        self.flag_descriptions[bitmask] = description + ', threshold {:.4f}'.format(self.swir_threshold)
        self.flag_names[bitmask] = name

    def surfwater_mask(self,
                       bitmasks=[8, 9, 10],
                       names=['surfwater_land', 'surfwater_water', 'surfwater_cloud_and_shadow'],
                       descriptions=['land mask from surfwater input file',
                                     'water mask from surfwater input file',
                                     'cloud and shadow mask from surfwater input file'],
                       ):
        '''
        Apply surfwater masks read from ``self.prod.surfwater``:
        value 0 -> first flag, value 1 -> second flag, values >= 2 -> third flag.

        TODO give reference for surfwater algorithm and code

        :param bitmasks: bit numbers on which the flags are coded
        :param names: names of the flags
        :param descriptions: descriptions of the flags
        :return:
        '''
        logging.info('surfwater masking')

        surfwater = self.prod.surfwater

        # set flags raster
        self.flags = self.flags | (
                ((0 == surfwater) << bitmasks[0]) |
                ((1 == surfwater) << bitmasks[1]) |
                ((2 <= surfwater) << bitmasks[2])
        )

        for ii in range(3):
            # set naming and description as attributes
            self.flag_descriptions[bitmasks[ii]] = descriptions[ii]
            self.flag_names[bitmasks[ii]] = names[ii]

    def duplicate_landsat_flags(self,
                                bitmasks=[11, 12, 13, 14, 15, 16, 17],
                                landsat_bitmasks=[1, 2, 3, 4, 5, 6, 7],
                                names=['l1_dilated_cloud', 'l1_cirrus', 'l1_cloud',
                                       'l1_cloud_shadow', 'l1_snow', 'l1_clear', 'l1_water'],
                                descriptions=['landsat dilated cloud mask',
                                              'landsat cirrus mask',
                                              'landsat cloud mask',
                                              'landsat cloud shadow mask',
                                              'landsat snow mask',
                                              'landsat clear mask',
                                              'landsat water mask'],
                                flags_landsat_band='flags_l1'
                                ):
        '''
        Save and duplicate landsat main masks from level 1 flags: bit
        ``landsat_bitmasks[i]`` of the landsat flags band is copied to bit
        ``bitmasks[i]`` of the flag raster.

        :param bitmasks: bit numbers on which the flags are coded
        :param landsat_bitmasks: bit numbers of the landsat flags
        :param names: names of the flags
        :param descriptions: descriptions of the flags
        :param flags_landsat_band: name of the landsat flags band
        :return:
        '''
        logging.info('duplicating landsat level 1 flags')

        landsat_flags = self.prod[flags_landsat_band]
        for ii, bit in enumerate(bitmasks):
            flag_value = 1 << landsat_bitmasks[ii]
            mask = (landsat_flags & flag_value) != 0

            # set flags raster
            self.flags = self.flags | (mask << bit)
            self.flag_descriptions[bit] = descriptions[ii]
            self.flag_names[bit] = names[ii]

    def get_stats(self):
        '''
        Compute image statistics for each flag (fraction of pixels for which
        the flag is raised) and save them into the dictionary ``self.flag_stats``
        with keys ``'flag_' + flag name``. The statistics are also added to the
        attributes of the product.

        :return:
        '''
        flags = self.prod[FLAG_NAME]
        for bit, flag_name in enumerate(flags.attrs['flag_names']):
            # unused bits carry the string 'None' (see process())
            if flag_name == 'None':
                continue
            flag = (flags & (1 << bit)) != 0
            self.flag_stats['flag_' + flag_name] = float(flag.sum() / flag.count())

        self.prod.attrs.update(self.flag_stats)

    def print_stats(self):
        '''
        Provide pandas Dataframe with image statistics of each flag.
        The statistics are read from the product attributes, i.e.,
        :meth:`get_stats` (or :meth:`process`) must have been called before.

        :return: dflags:: pandas Dataframe with columns name, description, bit
            and statistics (one row per defined flag)
        '''
        pflags = self.prod[FLAG_NAME]

        # construct dataframe:
        dflags = pd.DataFrame({'name': pflags.attrs["flag_names"]})
        dflags['description'] = pflags.attrs["flag_descriptions"]
        dflags['bit'] = dflags.index

        # keep defined flags only; explicit copy so that the column added
        # below is not written onto a slice of the original frame
        dflags = dflags[dflags.name != "None"].copy()
        dflags['statistics'] = [self.prod.attrs["flag_" + name] for name in dflags.name]

        return dflags

    def process(self,
                output="prod"):
        '''
        Generate the flags raster and attributes.

        :param output:
            - if None returns nothing but the raster is updated within the masking object
            - if "prod" returns the full raster updated with the flags variable and attributes
            - if "flags" returns xarray DataArray of the flags
              plus the dictionary of flags statistics (as attributes)
        :return: see :param output
        '''
        # apply the masking processors; nodata_mask goes first because it
        # creates the flag raster the other processors add their bits to
        self.nodata_mask()
        if self.sentinel2:
            self.cloud_mask()
        elif self.landsat:
            self.duplicate_landsat_flags()
        self.water_mask()
        if self.bcirrus:
            self.cirrus_mask()
        self.high_swir_mask()
        self.surfwater_mask()

        if isinstance(self.flags, xr.DataArray):
            self.prod[FLAG_NAME] = (('y', 'x'), self.flags.values)
        else:
            self.prod[FLAG_NAME] = (('y', 'x'), self.flags)

        # set the attributes with the names and description of the flags
        # (unused slots are converted into the string 'None')
        self.prod[FLAG_NAME].attrs['flag_descriptions'] = self.flag_descriptions.astype(str)
        self.prod[FLAG_NAME].attrs['flag_names'] = self.flag_names.astype(str)

        # compute flag statistics over the image
        self.get_stats()
        self.prod.attrs.update(self.flag_stats)

        if output == "prod":
            return self.prod
        elif output == "flags":
            flags = self.prod[FLAG_NAME]
            flags.attrs.update(self.flag_stats)
            return flags

    @staticmethod
    def create_mask(flags,
                    tomask=[0, 2],
                    tokeep=[3],
                    mask_name="mask",
                    _type=np.uint8
                    ):
        '''
        Create binary mask from bitmask flags, with selection of bitmask to mask or to keep (by bit number).
        The masking convention is: good pixels for mask == 0, bad pixels when mask == 1

        A pixel is bad when at least one of the ``tomask`` bits is raised, or
        (if ``tokeep`` is not empty) when none of the ``tokeep`` bits is raised.
        If both lists are empty, the mask is 0 everywhere.

        :param flags: xarray dataarray with bitmask flags
        :param tomask: array of bitmask flags used to mask
        :param tokeep: array of bitmask flags for which pixels are kept (= good quality)
        :param mask_name: name of the output mask
        :param _type: type of the array (uint8 is recommended)
        :return: mask, xarray dataarray of type ``_type`` named ``mask_name``

        Example::

            mask = Masking.create_mask(raster.flags,
                                       tomask=[0, 2, 11],
                                       tokeep=[3],
                                       mask_name="mask_from_flags")

        returns a DataArray named 'mask_from_flags' (dtype uint8) with attributes:

            long_name:    binary mask from flags
            description:  good pixels for mask == 0, bad pixels when mask == 1
        '''
        # default output when no bit is requested
        mask = xr.zeros_like(flags, dtype=_type)

        # bitwise OR so that a bit number given twice does not carry
        flag_value_tomask = 0
        for bitnum in tomask:
            flag_value_tomask |= 1 << bitnum
        flag_value_tokeep = 0
        for bitnum in tokeep:
            flag_value_tokeep |= 1 << bitnum

        # len() is used (instead of truthiness) to also accept numpy arrays
        has_tomask = len(tomask) > 0
        has_tokeep = len(tokeep) > 0

        if has_tomask and has_tokeep:
            mask = ((flags & flag_value_tomask) != 0) | ((flags & flag_value_tokeep) == 0)
        elif has_tokeep:
            mask = (flags & flag_value_tokeep) == 0
        elif has_tomask:
            mask = (flags & flag_value_tomask) != 0

        # requested type whatever the combination of tomask/tokeep
        mask = mask.astype(_type)

        mask.attrs["long_name"] = "binary mask from flags"
        mask.attrs["description"] = "good pixels for mask == 0, bad pixels when mask == 1"
        mask.name = mask_name
        return mask