diff --git a/geos-processing/src/geos/processing/generic_processing_tools/AttributesDiff.py b/geos-processing/src/geos/processing/generic_processing_tools/AttributesDiff.py index 9e793447..21b18479 100644 --- a/geos-processing/src/geos/processing/generic_processing_tools/AttributesDiff.py +++ b/geos-processing/src/geos/processing/generic_processing_tools/AttributesDiff.py @@ -9,6 +9,7 @@ from typing_extensions import Self, Any from vtkmodules.vtkCommonDataModel import vtkMultiBlockDataSet, vtkDataSet +from vtkmodules.vtkCommonDataModel import VTK_TRIANGLE,VTK_QUAD,VTK_LINE, VTK_POLY_LINE from geos.mesh.utils.arrayModifiers import createAttribute from geos.mesh.utils.arrayHelpers import ( getAttributeSet, getNumberOfComponents, getArrayInObject ) @@ -53,7 +54,7 @@ # Set the attributes to compare: dictAttributesToCompare: dict[ Piece, set[ str ] ] - attributesDiffFilter.setDicAttributesToCompare( dicAttributesToCompare ) + attributesDiffFilter.setDictAttributesToCompare( dictAttributesToCompare ) # Set the inf norm computation (if wanted): computeInfNorm: bool @@ -75,6 +76,8 @@ class AttributesDiff: def __init__( self: Self, speHandler: bool = False, + computePoints: bool = True, + computeCells: bool = True, ) -> None: """Compute differences (L1 and inf norm) between two identical meshes attributes. @@ -96,6 +99,9 @@ def __init__( self.outputMesh: vtkMultiBlockDataSet | vtkDataSet = vtkMultiBlockDataSet() + self.computeCells: bool = computeCells + self.computePoints: bool = computePoints + # Logger. self.logger: Logger if not speHandler: @@ -138,12 +144,14 @@ def setMeshes( raise ValueError( "The list of meshes must contain two meshes." ) if listMeshes[ 0 ].GetClassName() != listMeshes[ 1 ].GetClassName(): - raise TypeError( "The meshes must have the same type." ) + raise TypeError( f"The meshes must have the same type. \n{listMeshes[0].GetClassName()} and {listMeshes[1].GetClassName()}" ) + + dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = {} + if self.computeCells: + dictMeshesMaxElementId.update({ Piece.CELLS: [ 0, 0 ]}) + if self.computePoints: + dictMeshesMaxElementId.update({ Piece.POINTS: [ 0, 0 ]}) - dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = { - Piece.CELLS: [ 0, 0 ], - Piece.POINTS: [ 0, 0 ], - } if isinstance( listMeshes[ 0 ], vtkDataSet ): for meshId, mesh in enumerate( listMeshes ): for piece in dictMeshesMaxElementId: @@ -171,8 +179,10 @@ def setMeshes( raise ValueError( f"The total number of { piece.value } in the meshes must be the same." ) self.listMeshes = listMeshes - self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1 - self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1 + if self.computeCells: + self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1 + if self.computePoints: + self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1 self.outputMesh = listMeshes[ 0 ].NewInstance() self.outputMesh.ShallowCopy( listMeshes[ 0 ] ) self._computeDictSharedAttributes() @@ -219,6 +229,9 @@ def setDictAttributesToCompare( self: Self, dictAttributesToCompare: dict[ Piece Raises: ValueError: At least one attribute to compare is not a shared attribute. 
""" + assert not ((Piece.CELLS in dictAttributesToCompare) ^ (self.computeCells)) + assert not ((Piece.POINTS in dictAttributesToCompare) ^ (self.computePoints)) + for piece, setSharedAttributesToCompare in dictAttributesToCompare.items(): if not setSharedAttributesToCompare.issubset( self.dictSharedAttributes[ piece ] ): wrongAttributes: set[ str ] = setSharedAttributesToCompare.difference( @@ -307,6 +320,8 @@ def _computeDictAttributesArray( self: Self ) -> None: listMeshBlockId: list[ int ] = getBlockElementIndexesFlatten( mesh ) for meshBlockId in listMeshBlockId: dataset: vtkDataSet = vtkDataSet.SafeDownCast( mesh.GetDataSet( meshBlockId ) ) + if dataset.GetCell(0).GetCellType() in [VTK_TRIANGLE, VTK_QUAD, VTK_LINE, VTK_POLY_LINE]: + continue arrayAttributeData = getArrayInObject( dataset, attributeName, piece ) nbAttributeComponents = getNumberOfComponents( dataset, attributeName, piece ) lToG: npt.NDArray[ Any ] = getArrayInObject( dataset, "localToGlobalMap", piece )