# Source code for eosets.mosaic

# Copyright 2026, SERTIT-ICube - France, https://sertit.unistra.fr/
# This file is part of eosets project
#     https://github.com/sertit/eosets
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class implementing the mosaic object"""

import contextlib
import logging
import os
from collections import defaultdict
from enum import unique

try:
    from typing import Self
except Exception:
    from typing_extensions import Self

import geopandas as gpd
import xarray as xr
from eoreader import cache
from eoreader.bands import BandsType, to_band, to_str
from eoreader.products import Product
from eoreader.reader import Reader
from eoreader.utils import UINT16_NODATA
from sertit import AnyPath, files, path, rasters, types
from sertit.misc import ListEnum
from sertit.types import AnyPathStrType

from eosets import EOSETS_NAME
from eosets.exceptions import IncompatibleProducts
from eosets.set import GeometryCheck, GeometryCheckType, Set
from eosets.utils import (
    AnyProductType,
    look_for_prod_band_file,
    read,
    stack,
)

# Module-level EOReader reader, used to open the mosaic components given as paths
READER = Reader()

# Logger named after EOSETS_NAME
LOGGER = logging.getLogger(EOSETS_NAME)


@unique
class MosaicMethod(ListEnum):
    """Available mosaicing methods."""

    # Each value is the name of the sertit.rasters function performing the merge
    GTIFF = "merge_gtiff"
    VRT = "merge_vrt"
class Mosaic(Set):
    """Class of mosaic objects, composed of several contiguous EOReader's products acquired the same day."""

    def __init__(
        self,
        paths: list | AnyProductType,
        output_path: AnyPathStrType = None,
        id: str = None,
        remove_tmp: bool = True,
        contiguity_check: GeometryCheckType = GeometryCheck.EXTENT,
        mosaic_method: MosaicMethod | str = MosaicMethod.VRT,
        **kwargs,
    ):
        """
        Create a mosaic from several contiguous products acquired the same day.

        Args:
            paths (list | AnyProductType): Paths (or EOReader products) composing the mosaic
            output_path (AnyPathStrType): Output path, passed to the base class
            id (str): ID of the mosaic. If not given, it is built from the date and the constellations of the products.
            remove_tmp (bool): Passed to the base class
            contiguity_check (GeometryCheckType): Method used to check the contiguity of the products
            mosaic_method (MosaicMethod | str): Mosaicing method
            **kwargs: Other arguments, passed to the base class, to EOReader when opening paths and to :code:`post_init`.
                :code:`date` and :code:`datetime` override the first product's ones.

        Raises:
            IncompatibleProducts: Incompatible products if not contiguous or not the same date
        """
        # Manage reference product
        self.prods: dict = {}
        """ Products (contiguous and acquired the same day). """

        # We need the date in _manage_prods
        self.date = None
        """ Date of the mosaic. If not provided in kwargs, using the first product's date. """

        self.datetime = None
        """ Datetime of the mosaic. If not provided in kwargs, using the first product's datetime. """

        self.mosaic_method = MosaicMethod.convert_from(mosaic_method)[0]
        """ Mosaicing method. If GTIFF is specified, the temporary files from every product will be removed, if VRT is specified, they will not."""

        contiguity_check = GeometryCheck.convert_from(contiguity_check)[0]

        # Init the base class
        super().__init__(
            output_path,
            id,
            remove_tmp,
            **kwargs,
        )

        # Update products of the mosaic
        self._manage_prods(paths, contiguity_check, **kwargs)

        # Create condensed_name: [{date}-{sat_id}_]{???}
        # TODO: is it OK ?
        # TODO: if fixed date, change that
        # TODO: add sth ?
        if self.id is None:
            # dict.fromkeys deduplicates while keeping the products' order.
            # A set would make the name depend on the (per-process randomized) string hash order,
            # giving a different ID from one run to another.
            constellations = "-".join(
                dict.fromkeys(prod.constellation_id for prod in self.get_prods())
            )
            self.condensed_name = f"{self.date.strftime('%Y%m%d')}_{constellations}"
            self.id = self.condensed_name
        else:
            self.condensed_name = self.id

        # Post init at the set level
        self.post_init(**kwargs)

    def clean_tmp(self):
        """
        Clean the temporary directory of the current mosaic
        """
        for prod in self.get_prods():
            prod.clean_tmp()

    def clear(self):
        """
        Clear this mosaic's cache
        """
        # Delegate to the clear() method of every component product
        for prod in self.get_prods():
            prod.clear()

    def _manage_output(self):
        """
        Manage the output specifically for this child class:
        redirect every product's output to the mosaic's temporary folder.
        """
        for prod in self.get_prods():
            # Never mind for non-existing files: they have already been copied :)
            with contextlib.suppress(FileNotFoundError):
                prod.output = self._get_tmp_folder(writable=True)

    def _manage_prods(
        self,
        paths: list | AnyProductType,
        contiguity_check: GeometryCheck,
        **kwargs,
    ):
        """
        Manage products attributes and check the compatibility of the mosaic's components

        Args:
            paths (list | AnyProductType): Paths (or products) of the mosaic. A single path or product is accepted.
            contiguity_check (GeometryCheck): Method to check the contiguity of the mosaic
            **kwargs: Other arguments (passed to EOReader when opening paths). :code:`date` and :code:`datetime` override the first product's ones.

        Raises:
            ValueError: No product given, or EOReader cannot open one of the paths
            NotImplementedError: One of the items is neither a path nor an EOReader product
            IncompatibleProducts: Incompatible products if not contiguous or not the same date
        """

        def _open(prod_or_path: AnyProductType) -> Product:
            """Open an EOReader product (or use the given one) and redirect its output to the mosaic's temporary folder"""
            if path.is_path(prod_or_path):
                prod_: Product = READER.open(
                    prod_or_path,
                    remove_tmp=self._remove_tmp,
                    **kwargs,
                )
            elif isinstance(prod_or_path, Product):
                prod_ = prod_or_path
            else:
                raise NotImplementedError(
                    "You should give either a path or 'eoreader.Product' to build your Mosaic!"
                )

            if prod_ is None:
                # Report the path that actually failed (not always the first one)
                raise ValueError(
                    f"There is no existing products in EOReader corresponding to {prod_or_path}"
                )

            # Set output
            prod_.output = self._get_tmp_folder(writable=True)

            return prod_

        # Normalize to a list: accept a single path/product as well as any iterable of them.
        # A Product is wrapped explicitly in case it happens to be iterable itself. TODO confirm against eoreader
        if isinstance(paths, Product) or not types.is_iterable(paths):
            paths = [paths]
        else:
            paths = list(paths)

        if not paths:
            raise ValueError("A mosaic should be composed of at least one product!")

        # Nof prods (before checks)
        self.nof_prods = len(paths)

        # Open first product as a reference
        first_prod = _open(paths[0])
        self.prods[first_prod.condensed_name] = first_prod

        # Open others
        for prod_path in paths[1:]:
            prod = _open(prod_path)

            # Ensure compatibility of the mosaic component (same sensor type and date) before registering it
            self.check_compatibility(first_prod, prod)
            self.prods[prod.condensed_name] = prod

        self.check_contiguity(contiguity_check)

        # Create full_name
        self.date = kwargs.pop("date", first_prod.date)
        self.datetime = kwargs.pop("datetime", first_prod.datetime)
        self.full_name = "-".join(prod.condensed_name for prod in self.get_prods())

    def get_mosaics(self) -> list[Self]:
        """
        Get all the mosaics as a list. A mosaic only contains itself.

        Returns:
            list: Mosaics list (only this mosaic)
        """
        return [self]

    def check_compatibility(self, first_prod: Product, prod: Product) -> None:
        """
        Check if the mosaic products are coherent between each other.

        - Same sensor type
        - Same date

        TODO: same constellation ?

        If not, throws a IncompatibleProducts error.

        Args:
            first_prod(Product): First product, to be checked against
            prod (Product): Product to check

        Raises:
            IncompatibleProducts: Incompatible products if not the same sensor type or not the same date
        """
        # Check same sensor_type
        if first_prod.sensor_type != prod.sensor_type:
            raise IncompatibleProducts(
                f"Components of a mosaic should have the same sensor type! {first_prod.sensor_type.name=} != {prod.sensor_type.name=}"
            )

        # Check same date
        if first_prod.date != prod.date:
            raise IncompatibleProducts(
                f"Components of a mosaic should have the same date! {first_prod.date=} != {prod.date=}"
            )

    def check_contiguity(self, check_contiguity: GeometryCheck):
        """
        Check the contiguity of the mosaic: the union of the products' geometries must be a single part.

        Args:
            check_contiguity (GeometryCheck): Contiguity checking method

        Raises:
            IncompatibleProducts: Incompatible products if not contiguous according to the given method
        """
        if check_contiguity == GeometryCheck.EXTENT:
            if len(self.extent()) > 1:
                raise IncompatibleProducts(
                    "The mosaic should have a contiguous extent!"
                )
        elif check_contiguity == GeometryCheck.FOOTPRINT:
            if len(self.footprint()) > 1:
                raise IncompatibleProducts(
                    "The mosaic should have a contiguous footprint!"
                )
        else:
            LOGGER.warning("The contiguity of your mosaic won't be checked!")

    def read_mtd(self):
        """Read the mosaic's metadata, but not implemented for now."""
        # TODO: how ? Just return the fields that are shared between mosaic's components ? Or create a XML from scratch ?
        raise NotImplementedError

    def _union_geometries(self, geometry_fct_name: str) -> gpd.GeoDataFrame:
        """
        Union one geometry of every product, expressed in the CRS of the first product.

        Args:
            geometry_fct_name (str): Name of the product method returning the geometry (:code:`footprint` or :code:`extent`)

        Returns:
            gpd.GeoDataFrame: Geometry of the first product if there is only one product,
            otherwise the dissolved then exploded union of all the products' geometries
        """
        ref_prod = self.get_first_prod()
        union: gpd.GeoDataFrame = getattr(ref_prod, geometry_fct_name)()
        if self.nof_prods > 1:
            # The reference CRS does not change between products: look it up once
            ref_crs = ref_prod.crs()
            for prod in self.get_prods()[1:]:
                union = union.overlay(
                    getattr(prod, geometry_fct_name)().to_crs(ref_crs), how="union"
                )

            # Dissolve and explode the union (one row per contiguous part)
            union = union.dissolve().explode(index_parts=True)

        return union

    @cache
    def footprint(self) -> gpd.GeoDataFrame:
        """
        Get the footprint of the mosaic.

        Returns:
            gpd.GeoDataFrame: Footprint of the mosaic
        """
        return self._union_geometries("footprint")

    @cache
    def extent(self) -> gpd.GeoDataFrame:
        """
        Get the extent of the mosaic.

        Returns:
            gpd.GeoDataFrame: Extent of the mosaic
        """
        return self._union_geometries("extent")

    def _get_band_suffix(self):
        """Get the band suffix"""
        # For multiple products, a mosaic is needed
        # For one product, just copy the raster band so set it to tif
        return self.mosaic_method.name.lower() if self.nof_prods > 1 else "tif"

    def load(
        self,
        bands: BandsType,
        pixel_size: float = None,
        **kwargs,
    ) -> xr.Dataset:
        """
        Load the bands and compute the wanted spectral indices.

        Args:
            bands (BandsType): Wanted bands
            pixel_size (float): Pixel size of the returned Dataset. If not specified, use the mosaic's pixel size.
            **kwargs: Other arguments used to load bands

        Returns:
            xr.Dataset: Wanted bands as xr.Datasets

        Raises:
            AssertionError: One of the products does not have one of the bands to load
        """
        # Override default pixel size
        if pixel_size is None:
            pixel_size = self.default_pixel_size

        # Get merge function and extension
        merge_fct = getattr(rasters, self.mosaic_method.value)

        # Convert just in case
        bands = to_band(bands)

        # Get the bands to be loaded
        bands_to_load, bands_path = self.get_bands_to_load(
            bands, self._get_band_suffix(), pixel_size=pixel_size, **kwargs
        )

        # Check validity of the bands.
        # Explicit raise (instead of assert) so that the check survives `python -O`,
        # keeping the same exception type as before.
        for prod in self.get_prods():
            for band in bands_to_load:
                if not prod.has_band(band):
                    raise AssertionError(
                        f"{prod.condensed_name} has not a {to_str(band)[0]} band."
                    )

        # Load and reorganize bands
        prod_band_paths = defaultdict(list)
        if bands_to_load:
            for prod in self.get_prods():
                prod: Product
                LOGGER.debug(
                    f"*** Loading {to_str(bands_to_load)} for {prod.condensed_name} ***"
                )

                # Don't leave it to None to ensure looking for the correct band name multi-resolution constellations (i.e. SWIR to 10m)
                if pixel_size is None:
                    pixel_size = prod.pixel_size

                # Load bands
                prod.load(bands_to_load, pixel_size, **kwargs)

                # Store paths
                for band in bands_to_load:
                    band_path = look_for_prod_band_file(
                        prod, band, pixel_size, **kwargs
                    )
                    prod_band_paths[band].append(str(band_path))

        # Merge
        merged_dict = {}
        for band, output_path in bands_path.items():
            if not output_path.is_file():
                if self.nof_prods > 1:
                    LOGGER.debug(f"Merging bands {to_str(band)[0]}")
                    if self.mosaic_method == MosaicMethod.VRT:
                        # A VRT references its sources: copy them in the mosaic folder so they stay available
                        prod_paths = []
                        for band_path in prod_band_paths[band]:
                            out_path, exists = self._get_out_path(
                                os.path.basename(band_path)
                            )
                            if not exists:
                                # Copy any band, even EOReader's as this band may be needed several times in indices computation (don't move it)
                                files.copy(band_path, out_path)
                            prod_paths.append(out_path)
                    else:
                        prod_paths = prod_band_paths[band]

                    # Don't pass kwargs here because of unwanted errors
                    merge_fct(prod_paths, output_path)
                else:
                    # Copy any band, even EOReader's as this band may be needed several times in indices computation (don't move it)
                    files.copy(prod_band_paths[band][0], output_path)

            # Load in memory and update attribute
            merged_dict[band] = self._update_attrs(read(output_path), bands, **kwargs)

        # Collocate VRTs
        LOGGER.debug("Collocating bands")
        merged_dict = self._collocate_bands(merged_dict)

        # Create a dataset (only after collocation)
        coords = merged_dict[bands[0]].coords if merged_dict else None

        # Make sure the dataset has the bands in the right order -> re-order the input dict
        mos_ds = xr.Dataset({key: merged_dict[key] for key in bands}, coords=coords)

        # Update attributes
        mos_ds = self._update_xds_attrs(mos_ds, bands)

        return mos_ds

    def stack(
        self,
        bands: list,
        pixel_size: float = None,
        stack_path: AnyPathStrType = None,
        save_as_int: bool = False,
        **kwargs,
    ) -> xr.DataArray:
        """
        Stack bands and index of a mosaic.

        Args:
            bands (list): Bands and index combination
            pixel_size (float): Stack pixel size. If not specified, use the product pixel size.
            stack_path (AnyPathStrType): Stack path. If it already exists, the stack is read from disk.
            save_as_int (bool): Convert stack to uint16 to save disk space (and therefore multiply the values by 10.000)
            **kwargs: Other arguments passed to :code:`load` or :code:`rioxarray.to_raster()` (such as :code:`compress`)

        Returns:
            xr.DataArray: Stack as a DataArray
        """
        bands = to_band(bands)

        if stack_path:
            stack_path = AnyPath(stack_path)
            if stack_path.is_file():
                return read(stack_path, pixel_size=pixel_size)

            # Use the path API rather than os.makedirs(str(...)):
            # the latter would create a bogus local directory for cloud paths (i.e. 's3:/...')
            stack_path.parent.mkdir(parents=True, exist_ok=True)

        # Create the analysis stack
        band_ds = self.load(bands, pixel_size=pixel_size, **kwargs)

        # Stack bands
        default_nodata = UINT16_NODATA if save_as_int else self.nodata
        kwargs["nodata"] = kwargs.get("nodata", default_nodata)
        stk, dtype = stack(band_ds, **kwargs)

        # Update stack's attributes
        stk = self._update_attrs(stk, bands, **kwargs)

        # Write on disk
        if stack_path:
            self._write_stack(band_ds, stk, stack_path, save_as_int, dtype, **kwargs)

        return stk

    def _collocate_bands(self, bands: dict, reference: xr.DataArray = None) -> dict:
        """
        Collocate all bands from a dict, delegating to the first product.

        Args:
            bands (dict): Dict of bands to collocate if needed
            reference (xr.DataArray): Reference array

        Returns:
            dict: Collocated bands
        """
        return self.get_first_prod()._collocate_bands(bands, reference)

    def _update_attrs_constellation_specific(
        self, xarr: xr.DataArray, bands: list, **kwargs
    ) -> xr.DataArray:
        """
        Update attributes of the given array (constellation specific):
        set the acquisition date to the mosaic's date.

        Args:
            xarr (xr.DataArray): Array whose attributes need an update
            bands (list): Bands of the array (unused here)

        Returns:
            xr.DataArray: Updated array/dataset
        """
        xarr.attrs["acquisition_date"] = self.date
        return xarr
AnyMosaicType = list | AnyProductType | Mosaic
""" Any Mosaic type (either a list of paths or products, a path, a product or a mosaic itself) """