Skip to content
2 changes: 1 addition & 1 deletion src/DIRAC/Core/Base/DB.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, dbname, fullname, debug=False, parentLogger=None):

result = getDBParameters(fullname)
if not result["OK"]:
raise RuntimeError(f"Cannot get database parameters: {result['Message']}")
raise RuntimeError(f"Cannot get database parameters for '{dbname}': {result['Message']}")

dbParameters = result["Value"]
self.dbHost = dbParameters["Host"]
Expand Down
27 changes: 15 additions & 12 deletions src/DIRAC/Core/DISET/private/Service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Service class implements the server side part of the DISET protocol
There are 2 main parts in this class:
Service class implements the server side part of the DISET protocol
There are 2 main parts in this class:

- All useful functions for initialization
- All useful functions to handle the requests
- All useful functions for initialization
- All useful functions to handle the requests
"""
# pylint: skip-file
# __searchInitFunctions gives RuntimeError: maximum recursion depth exceeded
Expand Down Expand Up @@ -102,17 +102,10 @@ def initialize(self):
}
self.securityLogging = Operations().getValue("EnableSecurityLogging", False)

# Initialize Monitoring
# The import needs to be here because of the CS must be initialized before importing
# this class (see https://github.com/DIRACGrid/DIRAC/issues/4793)
from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter

self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring")

# Call static initialization function
try:
self._handler["class"]._rh__initializeClass(
dict(self._serviceInfoDict), self._lockManager, self._msgBroker, self.activityMonitoringReporter
dict(self._serviceInfoDict), self._lockManager, self._msgBroker, None
)
if self._handler["init"]:
for initFunc in self._handler["init"]:
Expand All @@ -132,6 +125,10 @@ def initialize(self):
gLogger.exception(errMsg)
return S_ERROR(errMsg)
if self.activityMonitoring:
from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter

self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring")

gThreadScheduler.addPeriodicTask(30, self.__reportActivity)
gThreadScheduler.addPeriodicTask(100, self.__activityMonitoringReporting)

Expand Down Expand Up @@ -563,6 +560,9 @@ def _executeAction(self, trid, proposalTuple, handlerObj):
retStatus = "OK"
else:
retStatus = "ERROR"
from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter

self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring")
self.activityMonitoringReporter.addRecord(
{
"timestamp": int(TimeUtilities.toEpochMilliSeconds()),
Expand Down Expand Up @@ -592,6 +592,9 @@ def _mbReceivedMsg(self, trid, msgObj):
handlerObj = result["Value"]
response = handlerObj._rh_executeMessageCallback(msgObj)
if self.activityMonitoring and response["OK"]:
from DIRAC.MonitoringSystem.Client.MonitoringReporter import MonitoringReporter

self.activityMonitoringReporter = MonitoringReporter(monitoringType="ServiceMonitoring")
self.activityMonitoringReporter.addRecord(
{
"timestamp": int(TimeUtilities.toEpochMilliSeconds()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _insertFiles(self, lfns, uid, gid, connection=False):
if insertTuples:
fields = "FileID,GUID,Checksum,ChecksumType,CreationDate,ModificationDate,Mode"
req = f"INSERT INTO FC_FileInfo ({fields}) VALUES {','.join(insertTuples)}"
res = self.db._update(req)
res = self.db._update(req, conn=connection)
if not res["OK"]:
self._deleteFiles(toDelete, connection=connection)
for lfn in list(lfns):
Expand Down Expand Up @@ -841,7 +841,7 @@ def repairFileTables(self, connection=False):

fields = "FileID,GUID,CreationDate,ModificationDate,Mode"
req = f"INSERT INTO FC_FileInfo ({fields}) VALUES {','.join(insertTuples)}"
result = self.db._update(req)
result = self.db._update(req, conn=connection)
if not result["OK"]:
return result

Expand Down
87 changes: 59 additions & 28 deletions src/DIRAC/RequestManagementSystem/DB/RequestDB.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# We disable pylint no-callable because of https://github.com/PyCQA/pylint/issues/8138

""" Frontend for ReqDB
"""Frontend for ReqDB

:mod: RequestDB
:mod: RequestDB

=======================
=======================

.. module: RequestDB
.. module: RequestDB

:synopsis: db holding Requests
:synopsis: db holding Requests

db holding Request, Operation and File
db holding Request, Operation and File
"""

import datetime
import errno
import random

from urllib.parse import quote_plus

from sqlalchemy import (
Expand All @@ -32,6 +32,7 @@
create_engine,
distinct,
func,
inspect,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import backref, joinedload, registry, relationship, sessionmaker
Expand Down Expand Up @@ -187,6 +188,38 @@ class RequestDB:
db holding requests
"""

@staticmethod
def _get_column(table_name, column_name):
"""Resolve supported ORM column attributes without evaluating input."""

models = {"Request": Request, "Operation": Operation}
aliases = {"Status": "_Status"}

model = models.get(table_name)
if model is None:
raise ValueError(f"Unknown table '{table_name}'")

resolved_name = aliases.get(column_name, column_name)
if resolved_name not in inspect(model).column_attrs:
raise ValueError(f"Unknown {table_name} attribute '{column_name}'")

return getattr(model, resolved_name)

@classmethod
def _apply_web_filter(cls, query, table_name, column_name, value):
column = cls._get_column(table_name, column_name)
if isinstance(value, list):
return query.filter(column.in_(value))
return query.filter(column == value)

@classmethod
def _get_order_expression(cls, table_name, column_name, direction):
column = cls._get_column(table_name, column_name)
normalized_direction = direction.lower()
if normalized_direction not in {"asc", "desc"}:
raise ValueError(f"Unknown sort direction '{direction}'")
return getattr(column, normalized_direction)()

def __getDBConnectionInfo(self, fullname):
"""Collect from the CS all the info needed to connect to the DB.
This should be in a base class eventually
Expand Down Expand Up @@ -704,13 +737,12 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems):
elif key == "Status":
key = "_Status"

if isinstance(value, list):
summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}.in_({value})"))
else:
summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}") == value)
summaryQuery = self._apply_web_filter(summaryQuery, tableName, key, value)

if sortList:
summaryQuery = summaryQuery.order_by(eval(f"Request.{sortList[0][0]}.{sortList[0][1].lower()}()"))
summaryQuery = summaryQuery.order_by(
self._get_order_expression("Request", sortList[0][0], sortList[0][1])
)

try:
requestLists = summaryQuery.all()
Expand Down Expand Up @@ -744,6 +776,8 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems):
resultDict["TotalRecords"] = nRequests

return S_OK(resultDict)
except ValueError as e:
return S_ERROR(str(e))
#
except Exception as e:
self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e)
Expand All @@ -763,16 +797,14 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):

session = self.DBSession()

if groupingAttribute == "Type":
groupingAttribute = "Operation.Type"
elif groupingAttribute == "Status":
groupingAttribute = "Request._Status"
else:
groupingAttribute = f"Request.{groupingAttribute}"

try:
if groupingAttribute == "Type":
groupingColumn = self._get_column("Operation", "Type")
else:
groupingColumn = self._get_column("Request", groupingAttribute)

summaryQuery = session.query(
eval(groupingAttribute), func.count(Request.RequestID) # pylint: disable=not-callable,no-member
groupingColumn, func.count(Request.RequestID) # pylint: disable=not-callable,no-member
)

for key, value in selectDict.items():
Expand All @@ -788,12 +820,9 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):
elif key == "Status":
key = "_Status"

if isinstance(value, list):
summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}.in_({value})"))
else:
summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}") == value)
summaryQuery = self._apply_web_filter(summaryQuery, objectType, key, value)

summaryQuery = summaryQuery.group_by(eval(groupingAttribute))
summaryQuery = summaryQuery.group_by(groupingColumn)

try:
requestLists = summaryQuery.all()
Expand All @@ -805,6 +834,8 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):

return S_OK(resultDict)

except ValueError as e:
return S_ERROR(str(e))
except Exception as e:
self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e)
return S_ERROR(f"getRequestSummaryWeb: unexpected exception : {e}")
Expand All @@ -817,11 +848,11 @@ def getDistinctValues(self, tableName, columnName):

session = self.DBSession()
distinctValues = []
if columnName == "Status":
columnName = "_Status"
try:
result = session.query(distinct(eval(f"{tableName}.{columnName}"))).all()
result = session.query(distinct(self._get_column(tableName, columnName))).all()
distinctValues = [dist[0] for dist in result]
except ValueError as e:
return S_ERROR(str(e))
except NoResultFound:
pass
except Exception as e:
Expand Down
34 changes: 15 additions & 19 deletions src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,21 +302,11 @@ def _submitPilotsPerQueue(self, queueName: str):
self.failedQueues[queueName] += 1
return S_OK(0)

# Adjust queueCPUTime: needed to generate the proxy
if "CPUTime" not in queueDictionary["ParametersDict"]:
self.log.error("CPU time limit is not specified, skipping", f"queue {queueName}")
return S_ERROR(f"CPU time limit is not specified, skipping queue {queueName}")

queueCPUTime = int(queueDictionary["ParametersDict"]["CPUTime"])
if queueCPUTime > self.maxQueueLength:
queueCPUTime = self.maxQueueLength

# Get CE instance
ce = self.queueDict[queueName]["CE"]

# Set credentials
cpuTime = queueCPUTime + 86400
result = self._setCredentials(ce, cpuTime)
# Set credentials: needed for authenticated CE operations (e.g. ce.available() on AREX)
result = self._setCredentials(ce, 3600)
if not result["OK"]:
self.log.error("Failed to set credentials:", result["Message"])
return result
Expand Down Expand Up @@ -901,20 +891,18 @@ def __supportToken(self, ce: ComputingElement) -> bool:
return "Token" in ce.ceParameters.get("Tag", []) or f"Token:{self.vo}" in ce.ceParameters.get("Tag", [])

def _setCredentials(self, ce: ComputingElement, proxyMinimumRequiredValidity: int):
"""
"""Add a proxy and a token to the ComputingElement.

:param ce: ComputingElement instance
:param proxyMinimumRequiredValidity: number of seconds needed to perform an operation with the proxy
:param tokenMinimumRequiredValidity: number of seconds needed to perform an operation with the token
"""
getNewProxy = False

# If the CE does not already embed a proxy, we need one
if not ce.proxy:
getNewProxy = True

# If the CE embeds a proxy that is too short to perform a given operation, we need a new one
if ce.proxy:
else:
# If the CE embeds a proxy that is too short to perform a given operation, we need a new one
result = ce.proxy.getRemainingSecs()
if not result["OK"]:
return result
Expand All @@ -924,11 +912,19 @@ def _setCredentials(self, ce: ComputingElement, proxyMinimumRequiredValidity: in

# Generate a new proxy if needed
if getNewProxy:
self.log.verbose("Getting pilot proxy", f"for {self.pilotDN}/{self.vo} {proxyMinimumRequiredValidity} long")
proxyRequestedValidity = max(proxyMinimumRequiredValidity, 86400)
self.log.verbose("Getting pilot proxy", f"for {self.pilotDN}/{self.vo} {proxyRequestedValidity} long")
pilotGroup = Operations(vo=self.vo).getValue("Pilot/GenericPilotGroup")
result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, pilotGroup, proxyMinimumRequiredValidity)
result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, pilotGroup, proxyRequestedValidity)
if not result["OK"]:
return result
result_validity = result["Value"].getRemainingSecs()
if not result_validity["OK"]:
return result_validity
if result_validity["Value"] < proxyRequestedValidity:
self.log.warn(
f"The validity of the generated proxy ({result_validity['Value']} seconds) is less than the requested {proxyRequestedValidity} seconds"
)
ce.setProxy(result["Value"])

# Get valid token if needed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Test class for SiteDirector
"""
"""Test class for SiteDirector"""
# pylint: disable=protected-access

import datetime
Expand Down Expand Up @@ -145,6 +144,7 @@

mockPMProxy = MagicMock()
mockPMProxy.dumpAllToString.return_value = {"OK": True, "Value": "fakeProxy"}
mockPMProxy.getRemainingSecs.return_value = {"OK": True, "Value": 1000}
mockPMProxyReply = MagicMock()
mockPMProxyReply.return_value = {"OK": True, "Value": mockPMProxy}

Expand Down Expand Up @@ -183,6 +183,10 @@ def mock_getElementStatus(ceNamesList, *args, **kwargs):
mocker.patch(
"DIRAC.WorkloadManagementSystem.Agent.SiteDirector.gProxyManager.downloadProxy", side_effect=mockPMProxyReply
)
mocker.patch(
"DIRAC.WorkloadManagementSystem.Agent.SiteDirector.gProxyManager.getPilotProxyFromDIRACGroup",
side_effect=mockPMProxyReply,
)
sd = SiteDirector()

# Set logger
Expand Down Expand Up @@ -288,7 +292,8 @@ def test_getPilotWrapper(mocker, sd, pilotWrapperDirectory):
assert os.path.exists(res) and os.path.isfile(res)


def test__submitPilotsToQueue(sd):
@pytest.mark.parametrize("proxy_validity", [1, 1000, 900000])
def test__submitPilotsToQueue(sd, proxy_validity):
"""Testing SiteDirector()._submitPilotsToQueue()"""
# Create a MagicMock that does not have the workingDirectory
# attribute (https://cpython-test-docs.readthedocs.io/en/latest/library/unittest.mock.html#deleting-attributes)
Expand All @@ -297,6 +302,7 @@ def test__submitPilotsToQueue(sd):
del ceMock.workingDirectory
proxyObject_mock = MagicMock()
proxyObject_mock.dumpAllToString.return_value = S_OK("aProxy")
proxyObject_mock.getRemainingSecs.return_value = S_OK(proxy_validity)
ceMock.proxy = proxyObject_mock

sd.queueCECache = {"ce1.site1.com_condor": {"CE": ceMock, "Hash": "3d0dd0c60fffa900c511d7442e9c7634"}}
Expand Down
Loading