Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion golang/modelfeature/ConfigurationHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -58,6 +59,10 @@ func (t *ConfigurationHandler[T]) Load() (bool, error) {
if s3Error != nil {
return false, fmt.Errorf("error fetching s3 file: %v", s3Error)
}
defer func() {
_, _ = io.Copy(io.Discard, getObjectOutput.Body)
_ = getObjectOutput.Body.Close()
}()
if !t.localCacheFactory.ShouldRefresh(t.fileIdentifierCacheKey, *getObjectOutput.ETag) {
Logger.Info().Msgf("Skipping refresh for %s", filename)
return false, nil
Expand All @@ -71,13 +76,15 @@ func (t *ConfigurationHandler[T]) Load() (bool, error) {
if localErr != nil {
return false, fmt.Errorf("error opening file: %v", localErr)
}
defer func() {
_ = filePointer.Close()
}()
if !t.localCacheFactory.ShouldRefreshLocal(t.fileIdentifierCacheKey, filePointer) {
Logger.Info().Msgf("Skipping refresh for %s", filename)
return false, nil
}

jsonData, err = t.daoFactory.GetDataFromLocal(filePointer)
_ = filePointer.Close()
}

if err != nil {
Expand Down
134 changes: 76 additions & 58 deletions golang/modelfeature/ModelResult.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,78 +85,96 @@ func (t *ModelResultHandler) Load(sspIdentifier string) error {
var putItemTotalSize int64

for modelIdentifier, modelDefinition := range modelConfiguration.ModelDefinitionByIdentifier {
modelResultFileName := t.BuildModelResultFileName(sspIdentifier, modelIdentifier)
modelResultValue, exists := ModelTypeValue[modelDefinition.Type]
if !exists {
// default to a low value model type (0.0) if not defined
Logger.Info().Msgf("model type [%s] not found in the [%+v]. Defaulting to LowValue", modelDefinition.Type, ModelTypeValue)
modelResultValue = 0.0
}

var modelResult []byte
var repositoryError error
if strings.HasPrefix(t.folderPrefix, repository.S3Prefix) {
// get bucket name from "s3://<bucket-name>"
s3BucketName := strings.TrimPrefix(t.folderPrefix, repository.S3Prefix)
getObjectOutput, s3Error := t.daoFactory.GetS3Object(context.TODO(), s3BucketName, modelResultFileName)
if s3Error != nil {
Logger.Error().Msgf("Error fetching S3 file %s/%s: %v", s3BucketName, modelResultFileName, s3Error)
continue

}
if !t.localCacheFactory.ShouldRefresh(repository.CacheKeyModelResultFileIdentifier, *(getObjectOutput.ETag)) {
Logger.Info().Msgf("Skipping refresh for %s", modelResultFileName)
continue
}

modelResult, repositoryError = t.daoFactory.ReadContent(getObjectOutput.Body)
} else {
// read from local file path
filePath := filepath.Join(t.folderPrefix, modelResultFileName)
filePointer, err := os.Open(filePath)
if err != nil {
Logger.Error().Msgf("Error opening file %s: %v", filePath, err)
continue
}
if !t.localCacheFactory.ShouldRefreshLocal(repository.CacheKeyModelResultFileIdentifier, filePointer) {
Logger.Info().Msgf("Skipping refresh for %s", filePath)
continue
}

modelResult, repositoryError = t.daoFactory.GetDataFromLocal(filePointer)
_ = filePointer.Close()
if err := t.loadSingleModel(sspIdentifier, modelIdentifier, modelResultValue, &putItemCounter, &putItemTotalSize); err != nil {
return err
}
}

Logger.Info().Msgf("Processed %d items with total size of %d bytes", putItemCounter, putItemTotalSize)
return nil
}

func (t *ModelResultHandler) loadSingleModel(sspIdentifier string, modelIdentifier string, modelResultValue float32, putItemCounter *int, putItemTotalSize *int64) error {
modelResultFileName := t.BuildModelResultFileName(sspIdentifier, modelIdentifier)

var modelResult []byte
var repositoryError error

if strings.HasPrefix(t.folderPrefix, repository.S3Prefix) {
// get bucket name from "s3://<bucket-name>"
s3BucketName := strings.TrimPrefix(t.folderPrefix, repository.S3Prefix)
getObjectOutput, s3Error := t.daoFactory.GetS3Object(context.TODO(), s3BucketName, modelResultFileName)
if s3Error != nil {
Logger.Error().Msgf("Error fetching S3 file %s/%s: %v", s3BucketName, modelResultFileName, s3Error)
return nil
}

if repositoryError != nil {
return fmt.Errorf("error getting data %w", repositoryError)
defer func() {
_, _ = io.Copy(io.Discard, getObjectOutput.Body)
_ = getObjectOutput.Body.Close()
}()

if !t.localCacheFactory.ShouldRefresh(repository.CacheKeyModelResultFileIdentifier, *(getObjectOutput.ETag)) {
Logger.Info().Msgf("Skipping refresh for %s", modelResultFileName)
return nil
}

// clear all entries from cache since new model is detected
t.localCacheFactory.ClearLocalCache(modelIdentifier)

reader := csv.NewReader(bytes.NewReader(modelResult))
//reader.ReuseRecord = true // Reuse the same slice for each record to reduce allocations
for {
record, readerError := reader.Read()
if readerError == io.EOF {
break
}
if readerError != nil {
Logger.Error().Msgf("Error reading record: %v", readerError)
continue
}

if !t.localCacheFactory.PutToLocalCache(modelIdentifier, record[0], modelResultValue) {
Logger.Error().Msgf("Error putting model result record to the local cache [%v] with Key [%v]", modelIdentifier, record[0])
continue
}

putItemCounter++
putItemTotalSize += int64(len(record[0])) // Only count the size of the Key, not the entire modelResultKeys
modelResult, repositoryError = t.daoFactory.ReadContent(getObjectOutput.Body)
} else {
// read from local file path
filePath := filepath.Join(t.folderPrefix, modelResultFileName)
filePointer, err := os.Open(filePath)
if err != nil {
Logger.Error().Msgf("Error opening file %s: %v", filePath, err)
return nil
}

defer func() {
_ = filePointer.Close()
}()

if !t.localCacheFactory.ShouldRefreshLocal(repository.CacheKeyModelResultFileIdentifier, filePointer) {
Logger.Info().Msgf("Skipping refresh for %s", filePath)
return nil
}

modelResult, repositoryError = t.daoFactory.GetDataFromLocal(filePointer)
}

Logger.Info().Msgf("Processed %d items with total size of %d bytes", putItemCounter, putItemTotalSize)
if repositoryError != nil {
return fmt.Errorf("error getting data %w", repositoryError)
}

// clear all entries from cache since new model is detected
t.localCacheFactory.ClearLocalCache(modelIdentifier)

reader := csv.NewReader(bytes.NewReader(modelResult))
// reader.ReuseRecord = true // Reuse the same slice for each record to reduce allocations
for {
record, readerError := reader.Read()
if readerError == io.EOF {
break
}
if readerError != nil {
Logger.Error().Msgf("Error reading record: %v", readerError)
continue
}

if !t.localCacheFactory.PutToLocalCache(modelIdentifier, record[0], modelResultValue) {
Logger.Error().Msgf("Error putting model result record to the local cache [%v] with Key [%v]", modelIdentifier, record[0])
continue
}

*putItemCounter++
*putItemTotalSize += int64(len(record[0])) // Only count the size of the Key, not the entire modelResultKeys
}
return nil
}

Expand Down