11#!/usr/bin/env python3
22# -*- coding: utf-8 -*-
3+
34from abc import abstractmethod
45from dependency_injector .providers import Configuration
6+ from netdriver_core .utils .terminal import simulate_output , simulate_output_oct_to_chinese
57from pydantic import IPvAnyAddress
68from re import Match , Pattern
79from typing import Optional , Tuple , List
810import asyncssh
11+ import re
912
10- from netdriver_core .exception .errors import ChannelError , ChannelReadTimeout
13+ from netdriver_core .exception .errors import ChannelError
1114from netdriver_core .log import logman
1215from netdriver_core .utils .asyncu import async_timeout
1316
9598_DEFAULT_READ_BUFFER_SIZE = 8192
9699
97100
98- def update_ssh_config (kwargs : dict , config : Configuration ) -> dict :
101+ def update_ssh_config (kwargs : dict , profile : dict , config : Configuration ) -> dict :
99102 """ Update SSH configuration with defaults and provided parameters """
100- extra_kex_algs = set (config .session .ssh .kex_algs () or [])
101- extra_encryption_algs = (config .session .ssh .encryption_algs () or [])
103+ ssh = config .session .ssh
104+ extra_kex_algs = set (ssh .kex_algs () or [])
105+ extra_encryption_algs = (ssh .encryption_algs () or [])
102106 ssh_config = _DEFAUTL_SSH_CONFIG .copy ()
103107 ssh_config ["kex_algs" ] = list (ssh_config ["kex_algs" ].union (extra_kex_algs ))
104108 ssh_config ["encryption_algs" ] = list (ssh_config ["encryption_algs" ].union (extra_encryption_algs ))
105- ssh_config ["login_timeout" ] = config .session .ssh .login_timeout () or ssh_config ["login_timeout" ]
106- ssh_config ["connect_timeout" ] = config .session .ssh .connect_timeout () or ssh_config ["connect_timeout" ]
107- ssh_config ["keepalive_interval" ] = config .session .ssh .keepalive_interval () or ssh_config ["keepalive_interval" ]
108- ssh_config ["keepalive_count_max" ] = config .session .ssh .keepalive_count_max () or ssh_config ["keepalive_count_max" ]
109+ ssh_config ["login_timeout" ] = ssh .login_timeout () or ssh_config ["login_timeout" ]
110+ ssh_config ["connect_timeout" ] = ssh .connect_timeout () or ssh_config ["connect_timeout" ]
111+ ssh_config ["keepalive_interval" ] = ssh .keepalive_interval () or ssh_config ["keepalive_interval" ]
112+ ssh_config ["keepalive_count_max" ] = ssh .keepalive_count_max () or ssh_config ["keepalive_count_max" ]
113+ server_host_key_algs = profile .get ("server_host_key_algs" , [])
114+ if server_host_key_algs :
115+ ssh_config ["server_host_key_algs" ] = server_host_key_algs
116+ term_size = ()
109117 kwargs .update (ssh_config )
110- return kwargs
118+ if ssh .term_size () and ssh .term_size .width () and ssh .term_size .height ():
119+ term_size = (ssh .term_size .width (), ssh .term_size .height ())
120+ return kwargs , term_size
111121
112122
113123class Channel :
@@ -125,7 +135,6 @@ async def create(cls,
125135 username : Optional [str ] = None ,
126136 password : Optional [str ] = None ,
127137 encode : str = "utf-8" ,
128- term_size : Tuple = None ,
129138 logger : object = None ,
130139 profile : dict = {},
131140 config : Configuration = None ,
@@ -136,13 +145,13 @@ async def create(cls,
136145 cls ._read_channel_until_timeout = profile .get ("read_timeout" , DEFAULT_SESSION_PROFILE .get ("read_timeout" , 10 ))
137146
138147 if protocol == "ssh" :
139- kwargs = update_ssh_config (kwargs , config )
148+ kwargs , term_size = update_ssh_config (kwargs , profile , config )
140149 conn = await asyncssh .connect (
141150 host = str (ip ), port = port , username = username , password = password ,
142151 encoding = encode , ** kwargs )
143152 terminal = await conn .create_process (term_type = "ansi" , term_size = term_size )
144153 terminal .stdout .channel .set_encoding (encoding = encode , errors = 'replace' )
145- return SSHChannel (conn , terminal , logger = logger )
154+ return SSHChannel (conn , terminal , logger = logger , encode = encode )
146155 else :
147156 raise ValueError (f"protocol { protocol } not supported." )
148157
@@ -195,11 +204,13 @@ class SSHChannel(Channel):
195204
196205 def __init__ (self , conn : asyncssh .SSHClientConnection ,
197206 terminal : asyncssh .SSHClientProcess ,
198- logger : object = None ) -> None :
207+ logger : object = None ,
208+ encode : str = "utf-8" ) -> None :
199209 """ SSH Channel """
200210 self ._conn = conn
201211 self ._terminal = terminal
202212 self ._logger = logger
213+ self ._encode = encode
203214
204215 def _check_channel (self ):
205216 if not self ._conn :
@@ -235,7 +246,7 @@ async def read_channel_until(
235246 :return: str, the data read from the channel
236247 """
237248 self ._check_channel ()
238- output = ReadBuffer (cmd = cmd )
249+ output = ReadBuffer (cmd = cmd , encode = self . _encode )
239250 while not self .read_at_eof ():
240251 chunk = await self .read_channel (self ._read_buffer_size )
241252 output .append (chunk )
@@ -276,20 +287,74 @@ class ReadBuffer:
276287 _cmd : str
277288 _is_cmd_displayed : bool = False
278289
279- def __init__ (self , cmd : str = '' , line_break : str = '\n ' ) -> None :
290+ def __init__ (self , cmd : str = '' , line_break : str = '\n ' , encode : str = None ) -> None :
280291 """ Initialize read buffer """
281292 self ._buffer = []
282293 self ._last_line_pos = (0 , 0 )
283294 self ._line_break = line_break
284295 self ._cmd = cmd
285296 self ._is_cmd_displayed = False
297+ self ._encode = encode
286298
287- def _check_cmd_displayed (self , line : str = '' ) -> bool :
299+ def _check_cmd_displayed (self , pattern : Pattern , line : str = '' ) -> bool :
288300 if not self ._is_cmd_displayed and self ._cmd and line :
289301 # check if the command is displayed in the line
290- if self ._cmd in line :
302+ if self ._cmd_in_line ( pattern , line ) :
291303 self ._is_cmd_displayed = True
292304 log .trace (f"Command '{ self ._cmd } ' is displayed in the line: { line } " )
305+
306+ def _cmd_in_line (self , pattern : Pattern , line : str = '' ) -> bool :
307+ """check if the line contains cmd"""
308+
309+ log .trace (f"Line repr = { repr (line )} " )
310+ log .trace (f"Line escape = { line .encode ('unicode_escape' ).decode ('ascii' )} " )
311+
312+ # Topsec output extra ' \r' char
313+ # Fortinet output extra ' \x08' char
314+ if self ._cmd in re .sub (r'\s[\r\x08]' , '' , line ):
315+ return True
316+
317+ # Juniper input extra spaces, and the output will remove the extra spaces
318+ if self ._cmd .replace (' ' , '' ) in re .sub (r'[\x07\s]' , '' , line ):
319+ return True
320+
321+ # Fortinet input Chinese and output octal char
322+ chinese = simulate_output_oct_to_chinese (output = line , encoding = self ._encode )
323+ log .trace (f"Line oct to chinese: { repr (chinese )} " )
324+ if self ._cmd in chinese :
325+ return True
326+
327+ # Line Remove prompt
328+ for index in range (1 , len (line )):
329+ if re .match (pattern , line [:index ]):
330+ line = line [index :].lstrip ()
331+ break
332+ log .trace (f"Line remove prompt = { repr (line )} " )
333+
334+ # Topsec chinese escape failed, character ignored
335+ if '\ufffd ' in line :
336+ line_splits = re .sub (r"(\s\r|\r\n)" , '' , line ).split ('\ufffd ' )
337+ log .trace (f"Line split by \\ ufffd = { line_splits } " )
338+ for line_split in line_splits :
339+ if line_split not in self ._cmd :
340+ return False
341+ return True
342+
343+ line = simulate_output (line )
344+ log .trace (f"Line simulate output = { repr (line )} " )
345+
346+ # Line simulate output remove prompt
347+ for index in range (1 , len (line )):
348+ if re .match (pattern , line [:index ]):
349+ line = line [index :].lstrip ()
350+ break
351+ log .trace (f"Line simulate output remove prompt = { repr (line )} " )
352+
353+ # Array or Cisco display ultra wide processing
354+ if '$' in line and line .replace ('\x08 ' , '' ).split ('$' )[0 ] in self ._cmd :
355+ return True
356+
357+ return False
293358
294359 def _is_real_prompt (self ) -> bool :
295360 if self ._cmd :
@@ -332,7 +397,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma
332397 while lb_pos != - 1 :
333398 # found a line break, concat the line
334399 line = '' .join ([line , self ._buffer [i ][line_start_pos :lb_pos ], self ._line_break ])
335- self ._check_cmd_displayed (line )
400+ self ._check_cmd_displayed (pattern , line )
336401 line_start_pos = lb_pos + len (self ._line_break )
337402 log .trace (f"Checking buffer[{ i } ][:{ line_start_pos } ]: { line } " )
338403 matched = pattern .search (line )
@@ -350,7 +415,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma
350415
351416 # no line break found, check the rest of buffer item
352417 line = '' .join ([line , self ._buffer [i ][line_start_pos :]])
353- self ._check_cmd_displayed (line )
418+ self ._check_cmd_displayed (pattern , line )
354419 line_start_pos += len (line )
355420 if i == buffer_size - 1 :
356421 # if no line break found and no more buffer, check the last line
0 commit comments