diff --git a/src/DIRAC/Core/Base/DB.py b/src/DIRAC/Core/Base/DB.py index 6eee6c84197..ff487d00cfc 100755 --- a/src/DIRAC/Core/Base/DB.py +++ b/src/DIRAC/Core/Base/DB.py @@ -14,7 +14,7 @@ def __init__(self, dbname, fullname, debug=False, parentLogger=None): result = getDBParameters(fullname) if not result["OK"]: - raise RuntimeError(f"Cannot get database parameters: {result['Message']}") + raise RuntimeError(f"Cannot get database parameters for '{dbname}': {result['Message']}") dbParameters = result["Value"] self.dbHost = dbParameters["Host"] diff --git a/src/DIRAC/Core/DISET/private/Service.py b/src/DIRAC/Core/DISET/private/Service.py index 25b2058900c..e3ea36c3718 100644 --- a/src/DIRAC/Core/DISET/private/Service.py +++ b/src/DIRAC/Core/DISET/private/Service.py @@ -1,9 +1,9 @@ """ - Service class implements the server side part of the DISET protocol - There are 2 main parts in this class: +Service class implements the server side part of the DISET protocol +There are 2 main parts in this class: - - All useful functions for initialization - - All useful functions to handle the requests +- All useful functions for initialization +- All useful functions to handle the requests """ # pylint: skip-file # __searchInitFunctions gives RuntimeError: maximum recursion depth exceeded @@ -102,17 +102,10 @@ def initialize(self): } self.securityLogging = Operations().getValue("EnableSecurityLogging", False) - # Initialize Monitoring - # The import needs to be here because of the CS must be initialized before importing - # this class (see https://github.com/DIRACGrid/DIRAC/issues/4793) - from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter - - self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring") - # Call static initialization function try: self._handler["class"]._rh__initializeClass( - dict(self._serviceInfoDict), self._lockManager, self._msgBroker, self.activityMonitoringReporter + dict(self._serviceInfoDict), self._lockManager, self._msgBroker, None ) if self._handler["init"]: for initFunc in self._handler["init"]: @@ -132,6 +125,10 @@ def initialize(self): gLogger.exception(errMsg) return S_ERROR(errMsg) if self.activityMonitoring: + from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter + + self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring") + gThreadScheduler.addPeriodicTask(30, self.__reportActivity) gThreadScheduler.addPeriodicTask(100, self.__activityMonitoringReporting) @@ -563,6 +560,9 @@ def _executeAction(self, trid, proposalTuple, handlerObj): retStatus = "OK" else: retStatus = "ERROR" + from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter + + self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring") self.activityMonitoringReporter.addRecord( { "timestamp": int(TimeUtilities.toEpochMilliSeconds()), @@ -592,6 +592,9 @@ def _mbReceivedMsg(self, trid, msgObj): handlerObj = result["Value"] response = handlerObj._rh_executeMessageCallback(msgObj) if self.activityMonitoring and response["OK"]: + from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter + + self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring") self.activityMonitoringReporter.addRecord( { "timestamp": int(TimeUtilities.toEpochMilliSeconds()), diff --git a/src/DIRAC/DataManagementSystem/DB/FileCatalogComponents/FileManager/FileManager.py b/src/DIRAC/DataManagementSystem/DB/FileCatalogComponents/FileManager/FileManager.py index c2ec3e42b3f..f920c624bce 100755 --- a/src/DIRAC/DataManagementSystem/DB/FileCatalogComponents/FileManager/FileManager.py +++ b/src/DIRAC/DataManagementSystem/DB/FileCatalogComponents/FileManager/FileManager.py @@ -298,7 +298,7 @@ def _insertFiles(self, lfns, uid, gid, connection=False): if insertTuples: fields = "FileID,GUID,Checksum,ChecksumType,CreationDate,ModificationDate,Mode" req = f"INSERT INTO FC_FileInfo ({fields}) VALUES {','.join(insertTuples)}" - res = self.db._update(req) + res = self.db._update(req, conn=connection) if not res["OK"]: self._deleteFiles(toDelete, connection=connection) for lfn in list(lfns): @@ -841,7 +841,7 @@ def repairFileTables(self, connection=False): fields = "FileID,GUID,CreationDate,ModificationDate,Mode" req = f"INSERT INTO FC_FileInfo ({fields}) VALUES {','.join(insertTuples)}" - result = self.db._update(req) + result = self.db._update(req, conn=connection) if not result["OK"]: return result diff --git a/src/DIRAC/RequestManagementSystem/DB/RequestDB.py b/src/DIRAC/RequestManagementSystem/DB/RequestDB.py index 618677084b1..249363d9c7a 100644 --- a/src/DIRAC/RequestManagementSystem/DB/RequestDB.py +++ b/src/DIRAC/RequestManagementSystem/DB/RequestDB.py @@ -1,21 +1,21 @@ # We disable pylint no-callable because of https://github.com/PyCQA/pylint/issues/8138 -""" Frontend for ReqDB +"""Frontend for ReqDB - :mod: RequestDB +:mod: RequestDB - ======================= +======================= - .. module: RequestDB +.. module: RequestDB - :synopsis: db holding Requests +:synopsis: db holding Requests - db holding Request, Operation and File +db holding Request, Operation and File """ + import datetime import errno import random - from urllib.parse import quote_plus from sqlalchemy import ( @@ -32,6 +32,7 @@ create_engine, distinct, func, + inspect, ) from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import backref, joinedload, registry, relationship, sessionmaker @@ -187,6 +188,38 @@ class RequestDB: db holding requests """ + @staticmethod + def _get_column(table_name, column_name): + """Resolve supported ORM column attributes without evaluating input.""" + + models = {"Request": Request, "Operation": Operation} + aliases = {"Status": "_Status"} + + model = models.get(table_name) + if model is None: + raise ValueError(f"Unknown table '{table_name}'") + + resolved_name = aliases.get(column_name, column_name) + if resolved_name not in inspect(model).column_attrs: + raise ValueError(f"Unknown {table_name} attribute '{column_name}'") + + return getattr(model, resolved_name) + + @classmethod + def _apply_web_filter(cls, query, table_name, column_name, value): + column = cls._get_column(table_name, column_name) + if isinstance(value, list): + return query.filter(column.in_(value)) + return query.filter(column == value) + + @classmethod + def _get_order_expression(cls, table_name, column_name, direction): + column = cls._get_column(table_name, column_name) + normalized_direction = direction.lower() + if normalized_direction not in {"asc", "desc"}: + raise ValueError(f"Unknown sort direction '{direction}'") + return getattr(column, normalized_direction)() + def __getDBConnectionInfo(self, fullname): """Collect from the CS all the info needed to connect to the DB. This should be in a base class eventually @@ -704,13 +737,12 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems): elif key == "Status": key = "_Status" - if isinstance(value, list): - summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}.in_({value})")) - else: - summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}") == value) + summaryQuery = self._apply_web_filter(summaryQuery, tableName, key, value) if sortList: - summaryQuery = summaryQuery.order_by(eval(f"Request.{sortList[0][0]}.{sortList[0][1].lower()}()")) + summaryQuery = summaryQuery.order_by( + self._get_order_expression("Request", sortList[0][0], sortList[0][1]) + ) try: requestLists = summaryQuery.all() @@ -744,6 +776,8 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems): resultDict["TotalRecords"] = nRequests return S_OK(resultDict) + except ValueError as e: + return S_ERROR(str(e)) # except Exception as e: self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e) @@ -763,16 +797,14 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict): session = self.DBSession() - if groupingAttribute == "Type": - groupingAttribute = "Operation.Type" - elif groupingAttribute == "Status": - groupingAttribute = "Request._Status" - else: - groupingAttribute = f"Request.{groupingAttribute}" - try: + if groupingAttribute == "Type": + groupingColumn = self._get_column("Operation", "Type") + else: + groupingColumn = self._get_column("Request", groupingAttribute) + summaryQuery = session.query( - eval(groupingAttribute), func.count(Request.RequestID) # pylint: disable=not-callable,no-member + groupingColumn, func.count(Request.RequestID) # pylint: disable=not-callable,no-member ) for key, value in selectDict.items(): @@ -788,12 +820,9 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict): elif key == "Status": key = "_Status" - if isinstance(value, list): - summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}.in_({value})")) - else: - summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}") == value) + summaryQuery = self._apply_web_filter(summaryQuery, objectType, key, value) - summaryQuery = summaryQuery.group_by(eval(groupingAttribute)) + summaryQuery = summaryQuery.group_by(groupingColumn) try: requestLists = summaryQuery.all() @@ -805,6 +834,8 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict): return S_OK(resultDict) + except ValueError as e: + return S_ERROR(str(e)) except Exception as e: self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e) return S_ERROR(f"getRequestSummaryWeb: unexpected exception : {e}") @@ -817,11 +848,11 @@ def getDistinctValues(self, tableName, columnName): session = self.DBSession() distinctValues = [] - if columnName == "Status": - columnName = "_Status" try: - result = session.query(distinct(eval(f"{tableName}.{columnName}"))).all() + result = session.query(distinct(self._get_column(tableName, columnName))).all() distinctValues = [dist[0] for dist in result] + except ValueError as e: + return S_ERROR(str(e)) except NoResultFound: pass except Exception as e: diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py index 28dff17ed46..418413adc84 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py @@ -302,21 +302,11 @@ def _submitPilotsPerQueue(self, queueName: str): self.failedQueues[queueName] += 1 return S_OK(0) - # Adjust queueCPUTime: needed to generate the proxy - if "CPUTime" not in queueDictionary["ParametersDict"]: - self.log.error("CPU time limit is not specified, skipping", f"queue {queueName}") - return S_ERROR(f"CPU time limit is not specified, skipping queue {queueName}") - - queueCPUTime = int(queueDictionary["ParametersDict"]["CPUTime"]) - if queueCPUTime > self.maxQueueLength: - queueCPUTime = self.maxQueueLength - # Get CE instance ce = self.queueDict[queueName]["CE"] - # Set credentials - cpuTime = queueCPUTime + 86400 - result = self._setCredentials(ce, cpuTime) + # Set credentials: needed for authenticated CE operations (e.g. ce.available() on AREX) + result = self._setCredentials(ce, 3600) if not result["OK"]: self.log.error("Failed to set credentials:", result["Message"]) return result @@ -901,20 +891,18 @@ def __supportToken(self, ce: ComputingElement) -> bool: return "Token" in ce.ceParameters.get("Tag", []) or f"Token:{self.vo}" in ce.ceParameters.get("Tag", []) def _setCredentials(self, ce: ComputingElement, proxyMinimumRequiredValidity: int): - """ + """Add a proxy and a token to the ComputingElement. :param ce: ComputingElement instance :param proxyMinimumRequiredValidity: number of seconds needed to perform an operation with the proxy - :param tokenMinimumRequiredValidity: number of seconds needed to perform an operation with the token """ getNewProxy = False # If the CE does not already embed a proxy, we need one if not ce.proxy: getNewProxy = True - - # If the CE embeds a proxy that is too short to perform a given operation, we need a new one - if ce.proxy: + else: + # If the CE embeds a proxy that is too short to perform a given operation, we need a new one result = ce.proxy.getRemainingSecs() if not result["OK"]: return result @@ -924,11 +912,19 @@ def _setCredentials(self, ce: ComputingElement, proxyMinimumRequiredValidity: in # Generate a new proxy if needed if getNewProxy: - self.log.verbose("Getting pilot proxy", f"for {self.pilotDN}/{self.vo} {proxyMinimumRequiredValidity} long") + proxyRequestedValidity = max(proxyMinimumRequiredValidity, 86400) + self.log.verbose("Getting pilot proxy", f"for {self.pilotDN}/{self.vo} {proxyRequestedValidity} long") pilotGroup = Operations(vo=self.vo).getValue("Pilot/GenericPilotGroup") - result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, pilotGroup, proxyMinimumRequiredValidity) + result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, pilotGroup, proxyRequestedValidity) if not result["OK"]: return result + result_validity = result["Value"].getRemainingSecs() + if not result_validity["OK"]: + return result_validity + if result_validity["Value"] < proxyRequestedValidity: + self.log.warn( + f"The validity of the generated proxy ({result_validity['Value']} seconds) is less than the requested {proxyRequestedValidity} seconds" + ) ce.setProxy(result["Value"]) # Get valid token if needed diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_SiteDirector.py b/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_SiteDirector.py index 8e2d7179b7e..6d720e2e446 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_SiteDirector.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_SiteDirector.py @@ -1,5 +1,4 @@ -""" Test class for SiteDirector -""" +"""Test class for SiteDirector""" # pylint: disable=protected-access import datetime @@ -145,6 +144,7 @@ mockPMProxy = MagicMock() mockPMProxy.dumpAllToString.return_value = {"OK": True, "Value": "fakeProxy"} +mockPMProxy.getRemainingSecs.return_value = {"OK": True, "Value": 1000} mockPMProxyReply = MagicMock() mockPMProxyReply.return_value = {"OK": True, "Value": mockPMProxy} @@ -183,6 +183,10 @@ def mock_getElementStatus(ceNamesList, *args, **kwargs): mocker.patch( "DIRAC.WorkloadManagementSystem.Agent.SiteDirector.gProxyManager.downloadProxy", side_effect=mockPMProxyReply ) + mocker.patch( + "DIRAC.WorkloadManagementSystem.Agent.SiteDirector.gProxyManager.getPilotProxyFromDIRACGroup", + side_effect=mockPMProxyReply, + ) sd = SiteDirector() # Set logger @@ -288,7 +292,8 @@ def test_getPilotWrapper(mocker, sd, pilotWrapperDirectory): assert os.path.exists(res) and os.path.isfile(res) -def test__submitPilotsToQueue(sd): +@pytest.mark.parametrize("proxy_validity", [1, 1000, 900000]) +def test__submitPilotsToQueue(sd, proxy_validity): """Testing SiteDirector()._submitPilotsToQueue()""" # Create a MagicMock that does not have the workingDirectory # attribute (https://cpython-test-docs.readthedocs.io/en/latest/library/unittest.mock.html#deleting-attributes) @@ -297,6 +302,7 @@ def test__submitPilotsToQueue(sd): del ceMock.workingDirectory proxyObject_mock = MagicMock() proxyObject_mock.dumpAllToString.return_value = S_OK("aProxy") + proxyObject_mock.getRemainingSecs.return_value = S_OK(proxy_validity) ceMock.proxy = proxyObject_mock sd.queueCECache = {"ce1.site1.com_condor": {"CE": ceMock, "Hash": "3d0dd0c60fffa900c511d7442e9c7634"}}