Update keras, tf and new model usage, numpy 2.0 updates #1206

Open
JGSweets wants to merge 28 commits into capitalone:main from JGSweets:update-keras

Conversation

@JGSweets
Contributor

@JGSweets JGSweets commented Mar 13, 2026

This PR:

  • Allows Keras versions above 3.4
  • Allows the most recent TensorFlow releases (metal instead of macos)
  • Updates NumPy usage for v2+

NOTES:

@JGSweets JGSweets requested a review from a team as a code owner March 13, 2026 22:13
:param fn: Plugin function
:return: function
"""
global plugins_dict
Contributor Author

global is only needed when a function rebinds a module variable, like plugins_dict = {...}.

@JGSweets JGSweets changed the title from "[WIP] Update keras, tf and new model usage" to "Update keras, tf and new model usage" on May 6, 2026
@JGSweets
Contributor Author

JGSweets commented May 6, 2026

@shania-m this is the final PR that will update TF and keras to be current!

@JGSweets JGSweets changed the title from "Update keras, tf and new model usage" to "Update keras, tf and new model usage, numpy 2.0 updates" on May 8, 2026
@JGSweets
Contributor Author

@shania-m I've updated numpy to allow for 2.0+ as well here!

@shania-m
Contributor

@JGSweets please let me know when it’s ready for review

@JGSweets
Contributor Author

@shania-m fixed all the related issues. Updating mypy was more complicated than expected when trying to handle the numpy update!

@JGSweets
Contributor Author

@shania-m it is ready for review!

"""Compiles the loss for the given model and number of labels."""
# Compile the model
softmax_output_layer_name = model.output_names[0]
# losses = {softmax_output_layer_name: "categorical_crossentropy"}
Contributor

please remove

# Compile the model
softmax_output_layer_name = model.output_names[0]
# losses = {softmax_output_layer_name: "categorical_crossentropy"}
losses = ["categorical_crossentropy", None, None]
Contributor

Quick question — the loss assignment changed from dict-based (by output name) to list-based (by position):

CharacterLevelCnnModel (3 outputs)

losses = ["categorical_crossentropy", None, None]

CharLoadTFModel (2 outputs)

losses = ["categorical_crossentropy", None]

Can you confirm the output ordering is stable and these align correctly with the model outputs? Just want to make sure the positional assignment matches up since the dict approach was order-independent.

Contributor Author

Good callout! I believe the previous layers were list-based, which is why Keras 3 required list losses, and this code matched that. However, like you, I prefer the order-independent approach. I am looking into what it requires and into keeping backwards compatibility when loading a model that was originally list-based.
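
To illustrate the ordering concern, here is a plain-Python sketch (no Keras calls) of how a positional loss list could be re-keyed by output name. The helper `losses_by_name` and the output names other than `softmax_output` are hypothetical and not taken from the PR.

```python
def losses_by_name(output_names, positional_losses):
    """Pair each positional loss with its output name; skip outputs with no loss."""
    if len(output_names) != len(positional_losses):
        raise ValueError("one loss entry is required per model output")
    return {
        name: loss
        for name, loss in zip(output_names, positional_losses)
        if loss is not None
    }


# Hypothetical output names for a three-output model.
names = ["softmax_output", "argmax_output", "thresh_argmax_output"]
named = losses_by_name(names, ["categorical_crossentropy", None, None])
# named == {"softmax_output": "categorical_crossentropy"}
```

Once keyed by name, the mapping itself no longer depends on the order in which the outputs are listed. Whether a given Keras version accepts a dict that omits some outputs is not covered by this sketch.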

@shania-m
Contributor

A few small things, thanks for the contribution!


def _compile_model(self, num_labels: int) -> None:
"""Compile the model with dict-based losses and metrics."""
losses = {
Contributor Author

Ensures we use the dict-based solution.

@socket-security

socket-security Bot commented May 22, 2026

Review the following changes in direct dependencies. Learn more about Socket for GitHub.

Diff     Package     Version           Supply Chain Security  Vulnerability  Quality  Maintenance  License
Updated  numpy       1.26.4 ⏵ 2.4.6    75 (+1)                100            100      100          70
Updated  memray      1.11.0 ⏵ 1.19.3   72 (-20)               100 (+1)       100      100          100
Updated  keras       3.4.0 ⏵ 3.14.1    96                     100 (+75)      100      100          80
Updated  pre-commit  2.19.0 ⏵ 4.3.0    93 (-1)                100            100      100          100

View full report

cls, softmax_output: tf.Tensor, argmax_output: tf.Tensor | None = None
) -> dict[str, tf.Tensor]:
"""Return normalized dict outputs for training and inference."""
if argmax_output is None:
Contributor Author

Ensures normalized dict-based model outputs.

@socket-security

socket-security Bot commented May 22, 2026

Warning

Review the following alerts detected in dependencies.

According to your organization's Security Policy, it is recommended to resolve "Warn" alerts. Learn more about Socket for GitHub.

Action Severity Alert  (click "▶" to expand/collapse)
Warn High
License policy violation: pypi numpy under FSFAP

License: FSFAP - The applicable license policy does not permit this license (5) (numpy-2.4.6/vendored-meson/meson/test cases/frameworks/6 gettext/data3/metainfo.its)

From: requirements.txt pypi/numpy@2.4.6

ℹ Read more on: This package | This alert | What is a license policy violation?

Next steps: Take a moment to review the security alert above. Review the linked package source code to understand the potential risk. Ensure the package is not malicious before proceeding. If you're unsure how to proceed, reach out to your security team or ask the Socket team for help at support@socket.dev.

Suggestion: Find a package that does not violate your license policy or adjust your policy to allow this package's license.

Mark the package as acceptable risk. To ignore this alert only in this pull request, reply with the comment @SocketSecurity ignore pypi/numpy@2.4.6. You can also ignore all packages with @SocketSecurity ignore-all. To ignore an alert for all future pull requests, use Socket's Dashboard to change the triage state of this alert.

View full report

@classmethod
def _normalize_model_outputs(cls, model: tf.keras.Model) -> tf.keras.Model:
"""Convert list-style outputs to the normalized dict structure."""
return labeler_utils.normalize_tf_model_outputs(
Contributor Author

Converts the previous list style to the dict structure, for the consistency Keras 3 requires.


# boolean if the label mapping requires the mapping for index 0 reserved
requires_zero_mapping: bool = True
_SOFTMAX_OUTPUT = "softmax_output"
Contributor Author

normalize layer names


# boolean if the label mapping requires the mapping for index 0 reserved
requires_zero_mapping = False
_SOFTMAX_OUTPUT = "softmax_output"
Contributor Author

normalize layer names

num_labels, activation="softmax", name="softmax_output"
num_labels,
activation="softmax",
name=self._new_softmax_head_name(),
Contributor Author

Allows iterating on the layer name, due to Keras requirements.
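
A hedged sketch of one way such name iteration could work, assuming the requirement in question is that layer names must be unique within a model. The helper `next_unique_name` is hypothetical and is not the PR's `_new_softmax_head_name`, whose implementation is not shown in this conversation.

```python
def next_unique_name(base, existing):
    """Return `base`, or `base_<n>` with the smallest n >= 1 not in `existing`."""
    if base not in existing:
        return base
    suffix = 1
    while f"{base}_{suffix}" in existing:
        suffix += 1
    return f"{base}_{suffix}"


# Names already used by layers in a hypothetical model.
used = {"softmax_output", "softmax_output_1"}
new_name = next_unique_name("softmax_output", used)
# new_name == "softmax_output_2"
```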

Comment on lines +409 to +416
acc_value = next(
(value for key, value in model_results.items() if key.endswith("acc")),
np.nan,
)
f1_value = next(
(value for key, value in model_results.items() if "f1" in key.lower()),
np.nan,
)
Contributor Author

due to dict based output
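
The lookup in the snippet above relies on `next()` with a default value, so a missing metric yields NaN instead of raising `StopIteration`. A small standalone sketch of the same pattern follows; the result keys are hypothetical examples of prefixed metric names, and `math.nan` stands in for `np.nan` so that only the standard library is needed.

```python
import math


def pick_metrics(model_results):
    """Return (accuracy, f1) from a results dict whose keys may carry a prefix."""
    acc_value = next(
        (value for key, value in model_results.items() if key.endswith("acc")),
        math.nan,
    )
    f1_value = next(
        (value for key, value in model_results.items() if "f1" in key.lower()),
        math.nan,
    )
    return acc_value, f1_value


# Hypothetical keys, as might appear when metrics are named per output.
results = {"loss": 0.4, "softmax_output_acc": 0.91, "softmax_output_f1_score": 0.88}
acc, f1 = pick_metrics(results)
```

Matching on a suffix or substring avoids hard-coding the full key, at the cost of picking the first match if several keys qualify.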

BaseModel.__init__(self, label_mapping, parameters)

@classmethod
def _create_model_outputs(
Contributor Author

similar to char_load_tf_model.py but with the threshargmax

Comment on lines +760 to +767
acc_value = next(
(value for key, value in model_results.items() if key.endswith("acc")),
np.nan,
)
f1_value = next(
(value for key, value in model_results.items() if "f1" in key.lower()),
np.nan,
)
Contributor Author

due to dict based change

return None


def normalize_tf_model_outputs(
Contributor Author

This gives us backwards compatibility with the list-based models.
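
As a rough illustration of the idea, and not the PR's `normalize_tf_model_outputs`, the sketch below accepts either list-style or dict-style outputs and always returns a dict keyed by name. The helper `normalize_outputs` is hypothetical, and plain strings stand in for tensors.

```python
def normalize_outputs(outputs, names):
    """Return outputs as a dict keyed by `names`, from list- or dict-style input."""
    if isinstance(outputs, dict):
        missing = [name for name in names if name not in outputs]
        if missing:
            raise KeyError(f"missing outputs: {missing}")
        # Re-key in the requested order so downstream code sees one layout.
        return {name: outputs[name] for name in names}
    outputs = list(outputs)
    if len(outputs) != len(names):
        raise ValueError("expected one output per name")
    # Old format: rely on position once, here, and nowhere else.
    return dict(zip(names, outputs))


names = ["softmax_output", "argmax_output"]
from_list = normalize_outputs(["probs", "labels"], names)
from_dict = normalize_outputs({"argmax_output": "labels", "softmax_output": "probs"}, names)
```

Confining the positional assumption to a single conversion step is what lets the rest of the code stay order-independent.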

@JGSweets
Contributor Author

@shania-m sorry to add so much more, but this should be safer since it has the dict mapping back!

@shania-m
Contributor

Thanks for the contributions!
Some non-blocking recommendations:

  1. Add an upper bound to numpy: numpy>=1.22.0,<3.0.0 instead of numpy>=1.0.0, to prevent future breakage from numpy 3.
  2. Add an upper bound to keras: keras>3.4.0,<4.0.0, to protect against future Keras major versions.
  3. Add a test for loading old-format models: verify that _normalize_model_outputs correctly handles models saved with the previous list-style output format.
  4. Track numpy private API usage: add a code comment noting that _histograms_impl is a private module, linking to any numpy discussion about a public replacement.
  5. Update CHANGELOG: document the numpy 2.0 and keras >3.4 support as a notable change.
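
The bounds in items 1 and 2 can be read as ranges with an exclusive upper end. The sketch below checks a version against such a range; it is a deliberate simplification that handles only plain numeric versions and the four comparison operators, not full PEP 440 semantics, and the helper names are hypothetical.

```python
import operator

_OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt, "<": operator.lt}


def parse_version(text):
    """Turn '2.4.6' into (2, 4, 6); plain numeric release versions only."""
    return tuple(int(part) for part in text.split("."))


def satisfies(version, spec):
    """Check a version against a comma-separated spec such as '>=1.22.0,<3.0.0'."""
    for clause in spec.split(","):
        clause = clause.strip()
        # Try two-character operators first so '>=' is not read as '>'.
        for symbol in (">=", "<=", ">", "<"):
            if clause.startswith(symbol):
                bound = parse_version(clause[len(symbol):])
                if not _OPS[symbol](parse_version(version), bound):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True


numpy_ok = satisfies("2.4.6", ">=1.22.0,<3.0.0")   # inside the range
numpy_3 = satisfies("3.0.0", ">=1.22.0,<3.0.0")    # excluded by the upper bound
```

With the suggested bounds, a future 3.0.0 release of numpy or a 4.0.0 release of keras would be rejected at install time rather than failing at runtime.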

@JGSweets
Contributor Author

  1. Updated.
  2. Updated.
  3. Added unit tests to address this.
  4. Added a comment.
  5. Added a CHANGELOG

Thanks!

@shania-m
Contributor

Warning

Review the following alerts detected in dependencies.

According to your organization's Security Policy, it is recommended to resolve "Warn" alerts. Learn more about Socket for GitHub.
Action Severity Alert  (click "▶" to expand/collapse)
Warn High
License policy violation: pypi numpy under FSFAP
License: FSFAP - The applicable license policy does not permit this license (5) (numpy-2.4.6/vendored-meson/meson/test cases/frameworks/6 gettext/data3/metainfo.its)

From: requirements.txt pypi/numpy@2.4.6

ℹ Read more on: This package | This alert | What is a license policy violation?

Next steps: Take a moment to review the security alert above. Review
the linked package source code to understand the potential risk. Ensure the
package is not malicious before proceeding. If you're unsure how to proceed,
reach out to your security team or ask the Socket team for help at
support@socket.dev.

Suggestion: Find a package that does not violate your license policy or adjust your policy to allow this package's license.

_Mark the package as acceptable risk_. To ignore this alert only
in this pull request, reply with the comment
`@SocketSecurity ignore pypi/numpy@2.4.6`. You can
also ignore all packages with `@SocketSecurity ignore-all`.
To ignore an alert for all future pull requests, use Socket's Dashboard to
change the [triage state of this alert](https://socket.dev/dashboard/org/CapitalOne/diff-scan/fb93fa10-16b6-463d-98b5-4ecd9aace5bf/alert/QlzpKL-e6SclipJA2kKiS3Hc2nht0YycgksTG8-rSmHk).

View full report

@JGSweets I need to review these before approving

@JGSweets
Contributor Author

@shania-m of course! ty for working with me on this!

@JGSweets
Contributor Author

After this goes in, what would be the steps needed to make a release? I assume I cannot be part of that, but if so I'm happy to help achieve that as well!

@JGSweets
Contributor Author

JGSweets commented May 28, 2026

I realized it might also be good to add py3.12 / py3.13 to the test list, especially since py3.10 reaches EOL this fall.

Will do that in a subsequent PR.
