Skip to content

Commit 23a98ca

Browse files
Merge pull request #10 from SamBriskman/sam
Optimize device plugins
2 parents c1966ec + 838b41e commit 23a98ca

File tree

40 files changed

+629
-560
lines changed

40 files changed

+629
-560
lines changed

config/agent/agent.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ session:
3030
# https://asyncssh.readthedocs.io/en/latest/api.html#supported-algorithms
3131
# enabled all algorithms by default, add extra algorithms if needed
3232
encryption_algs: []
33+
# Set terminal size, optional item
34+
term_size:
35+
width: 1000
36+
height: 100
3337
profiles:
3438
# global profile, used when no other profile matche
3539
global:
@@ -38,6 +42,9 @@ session:
3842
expiration_time: "16:00"
3943
# If the idle time exceeds the maximum, close the session. Unit: Seconds
4044
max_idle_time: 600.0
45+
# https://asyncssh.readthedocs.io/en/stable/api.html#publickeyalgs
46+
# A list of server host key algorithms to allow during the SSH handshake
47+
server_host_key_algs: []
4148
# profile based on vendor/type/version, priority is higher than default
4249
vendor:
4350
cisco:
@@ -46,13 +53,16 @@ session:
4653
read_timeout: 60.0
4754
expiration_time: "16:00"
4855
max_idle_time: 600.0
56+
server_host_key_algs: []
4957
# 9.8:
5058
# read_timeout: 12.0
5159
# expiration_time: "16:00"
5260
# max_idle_time: 600.0
61+
# server_host_key_algs: []
5362
# profile based on IP address, priority is higher than vendor/type/version
5463
# ip:
5564
# 192.168.60.198:
5665
# read_timeout: 15.0
5766
# expiration_time: "16:00"
5867
# max_idle_time: 600.0
68+
# server_host_key_algs: []

packages/agent/src/netdriver_agent/client/channel.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
34
from abc import abstractmethod
45
from dependency_injector.providers import Configuration
6+
from netdriver_core.utils.terminal import simulate_output, simulate_output_oct_to_chinese
57
from pydantic import IPvAnyAddress
68
from re import Match, Pattern
79
from typing import Optional, Tuple, List
810
import asyncssh
11+
import re
912

10-
from netdriver_core.exception.errors import ChannelError, ChannelReadTimeout
13+
from netdriver_core.exception.errors import ChannelError
1114
from netdriver_core.log import logman
1215
from netdriver_core.utils.asyncu import async_timeout
1316

@@ -95,19 +98,26 @@
9598
_DEFAULT_READ_BUFFER_SIZE = 8192
9699

97100

98-
def update_ssh_config(kwargs: dict, config: Configuration) -> dict:
101+
def update_ssh_config(kwargs: dict, profile: dict, config: Configuration) -> dict:
99102
""" Update SSH configuration with defaults and provided parameters """
100-
extra_kex_algs = set(config.session.ssh.kex_algs() or [])
101-
extra_encryption_algs = (config.session.ssh.encryption_algs() or [])
103+
ssh = config.session.ssh
104+
extra_kex_algs = set(ssh.kex_algs() or [])
105+
extra_encryption_algs = (ssh.encryption_algs() or [])
102106
ssh_config = _DEFAUTL_SSH_CONFIG.copy()
103107
ssh_config["kex_algs"] = list(ssh_config["kex_algs"].union(extra_kex_algs))
104108
ssh_config["encryption_algs"] = list(ssh_config["encryption_algs"].union(extra_encryption_algs))
105-
ssh_config["login_timeout"] = config.session.ssh.login_timeout() or ssh_config["login_timeout"]
106-
ssh_config["connect_timeout"] = config.session.ssh.connect_timeout() or ssh_config["connect_timeout"]
107-
ssh_config["keepalive_interval"] = config.session.ssh.keepalive_interval() or ssh_config["keepalive_interval"]
108-
ssh_config["keepalive_count_max"] = config.session.ssh.keepalive_count_max() or ssh_config["keepalive_count_max"]
109+
ssh_config["login_timeout"] = ssh.login_timeout() or ssh_config["login_timeout"]
110+
ssh_config["connect_timeout"] = ssh.connect_timeout() or ssh_config["connect_timeout"]
111+
ssh_config["keepalive_interval"] = ssh.keepalive_interval() or ssh_config["keepalive_interval"]
112+
ssh_config["keepalive_count_max"] = ssh.keepalive_count_max() or ssh_config["keepalive_count_max"]
113+
server_host_key_algs = profile.get("server_host_key_algs", [])
114+
if server_host_key_algs:
115+
ssh_config["server_host_key_algs"] = server_host_key_algs
116+
term_size = ()
109117
kwargs.update(ssh_config)
110-
return kwargs
118+
if ssh.term_size() and ssh.term_size.width() and ssh.term_size.height():
119+
term_size = (ssh.term_size.width(), ssh.term_size.height())
120+
return kwargs, term_size
111121

112122

113123
class Channel:
@@ -125,7 +135,6 @@ async def create(cls,
125135
username: Optional[str] = None,
126136
password: Optional[str] = None,
127137
encode: str = "utf-8",
128-
term_size: Tuple = None,
129138
logger: object = None,
130139
profile: dict = {},
131140
config: Configuration = None,
@@ -136,13 +145,13 @@ async def create(cls,
136145
cls._read_channel_until_timeout = profile.get("read_timeout", DEFAULT_SESSION_PROFILE.get("read_timeout", 10))
137146

138147
if protocol == "ssh":
139-
kwargs = update_ssh_config(kwargs, config)
148+
kwargs, term_size = update_ssh_config(kwargs, profile, config)
140149
conn = await asyncssh.connect(
141150
host=str(ip), port=port, username=username, password=password,
142151
encoding=encode, **kwargs)
143152
terminal = await conn.create_process(term_type="ansi", term_size=term_size)
144153
terminal.stdout.channel.set_encoding(encoding=encode, errors='replace')
145-
return SSHChannel(conn, terminal, logger=logger)
154+
return SSHChannel(conn, terminal, logger=logger, encode=encode)
146155
else:
147156
raise ValueError(f"protocol {protocol} not supported.")
148157

@@ -195,11 +204,13 @@ class SSHChannel(Channel):
195204

196205
def __init__(self, conn: asyncssh.SSHClientConnection,
197206
terminal: asyncssh.SSHClientProcess,
198-
logger: object = None) -> None:
207+
logger: object = None,
208+
encode: str = "utf-8") -> None:
199209
""" SSH Channel """
200210
self._conn = conn
201211
self._terminal = terminal
202212
self._logger = logger
213+
self._encode = encode
203214

204215
def _check_channel(self):
205216
if not self._conn:
@@ -235,7 +246,7 @@ async def read_channel_until(
235246
:return: str, the data read from the channel
236247
"""
237248
self._check_channel()
238-
output = ReadBuffer(cmd=cmd)
249+
output = ReadBuffer(cmd=cmd, encode=self._encode)
239250
while not self.read_at_eof():
240251
chunk = await self.read_channel(self._read_buffer_size)
241252
output.append(chunk)
@@ -276,20 +287,74 @@ class ReadBuffer:
276287
_cmd: str
277288
_is_cmd_displayed: bool = False
278289

279-
def __init__(self, cmd: str = '', line_break: str = '\n') -> None:
290+
def __init__(self, cmd: str = '', line_break: str = '\n', encode: str = None) -> None:
280291
""" Initialize read buffer """
281292
self._buffer = []
282293
self._last_line_pos = (0, 0)
283294
self._line_break = line_break
284295
self._cmd = cmd
285296
self._is_cmd_displayed = False
297+
self._encode = encode
286298

287-
def _check_cmd_displayed(self, line: str = '') -> bool:
299+
def _check_cmd_displayed(self, pattern: Pattern, line: str = '') -> bool:
288300
if not self._is_cmd_displayed and self._cmd and line:
289301
# check if the command is displayed in the line
290-
if self._cmd in line:
302+
if self._cmd_in_line(pattern, line):
291303
self._is_cmd_displayed = True
292304
log.trace(f"Command '{self._cmd}' is displayed in the line: {line}")
305+
306+
def _cmd_in_line(self, pattern: Pattern, line: str = '') -> bool:
307+
"""check if the line contains cmd"""
308+
309+
log.trace(f"Line repr = {repr(line)}")
310+
log.trace(f"Line escape = {line.encode('unicode_escape').decode('ascii')}")
311+
312+
# Topsec output extra ' \r' char
313+
# Fortinet output extra ' \x08' char
314+
if self._cmd in re.sub(r'\s[\r\x08]', '', line):
315+
return True
316+
317+
# Juniper input extra spaces, and the output will remove the extra spaces
318+
if self._cmd.replace(' ', '') in re.sub(r'[\x07\s]', '', line):
319+
return True
320+
321+
# Fortinet input Chinese and output octal char
322+
chinese = simulate_output_oct_to_chinese(output=line, encoding=self._encode)
323+
log.trace(f"Line oct to chinese: {repr(chinese)}")
324+
if self._cmd in chinese:
325+
return True
326+
327+
# Line Remove prompt
328+
for index in range(1, len(line)):
329+
if re.match(pattern, line[:index]):
330+
line = line[index:].lstrip()
331+
break
332+
log.trace(f"Line remove prompt = {repr(line)}")
333+
334+
# Topsec chinese escape failed, character ignored
335+
if '\ufffd' in line:
336+
line_splits = re.sub(r"(\s\r|\r\n)", '', line).split('\ufffd')
337+
log.trace(f"Line split by \\ufffd = {line_splits}")
338+
for line_split in line_splits:
339+
if line_split not in self._cmd:
340+
return False
341+
return True
342+
343+
line = simulate_output(line)
344+
log.trace(f"Line simulate output = {repr(line)}")
345+
346+
# Line simulate output remove prompt
347+
for index in range(1, len(line)):
348+
if re.match(pattern, line[:index]):
349+
line = line[index:].lstrip()
350+
break
351+
log.trace(f"Line simulate output remove prompt = {repr(line)}")
352+
353+
# Array or Cisco display ultra wide processing
354+
if '$' in line and line.replace('\x08', '').split('$')[0] in self._cmd:
355+
return True
356+
357+
return False
293358

294359
def _is_real_prompt(self) -> bool:
295360
if self._cmd:
@@ -332,7 +397,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma
332397
while lb_pos != -1:
333398
# found a line break, concat the line
334399
line = ''.join([line, self._buffer[i][line_start_pos:lb_pos], self._line_break])
335-
self._check_cmd_displayed(line)
400+
self._check_cmd_displayed(pattern, line)
336401
line_start_pos = lb_pos + len(self._line_break)
337402
log.trace(f"Checking buffer[{i}][:{line_start_pos}]: {line}")
338403
matched = pattern.search(line)
@@ -350,7 +415,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma
350415

351416
# no line break found, check the rest of buffer item
352417
line = ''.join([line, self._buffer[i][line_start_pos:]])
353-
self._check_cmd_displayed(line)
418+
self._check_cmd_displayed(pattern, line)
354419
line_start_pos += len(line)
355420
if i == buffer_size - 1:
356421
# if no line break found and no more buffer, check the last line

packages/agent/src/netdriver_agent/client/merger.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)