Skip to content
Snippets Groups Projects
cls_cmd.py 7.67 KiB
Newer Older
#/*
# * Licensed to the OpenAirInterface (OAI) Software Alliance under one or more
# * contributor license agreements.  See the NOTICE file distributed with
# * this work for additional information regarding copyright ownership.
# * The OpenAirInterface Software Alliance licenses this file to You under
# * the OAI Public License, Version 1.1  (the "License"); you may not use this file
# * except in compliance with the License.
# * You may obtain a copy of the License at
# *
# *      http://www.openairinterface.org/?page_id=698
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# *-------------------------------------------------------------------------------
# * For more information about the OpenAirInterface (OAI) Software Alliance:
# *      contact@openairinterface.org
# */
#---------------------------------------------------------------------
#
#   Required Python Version
#     Python 3.x
#
#---------------------------------------------------------------------

import abc
import logging
import subprocess as sp
import os
import paramiko
import uuid
import sys 

# connection timeout handed to paramiko SSHClient.connect() as 'timeout'
# (see RemoteCmd._lookup_ssh_config); paramiko interprets it in seconds
SSHTIMEOUT=7

# helper that returns either LocalCmd or RemoteCmd based on passed host name
def getConnection(host, d=None):
	"""Return a command runner suited to ``host``.

	Args:
		host: target host name. ``None``, ``""``, ``"none"`` and
		      ``"localhost"`` (compared case-insensitively) mean local execution.
		d: initial working directory handed to the runner.

	Returns:
		A ``LocalCmd`` for local execution, otherwise a ``RemoteCmd`` that
		connects to ``host``.
	"""
	isLocal = host is None or host.lower() in ("", "none", "localhost")
	if not isLocal:
		return RemoteCmd(host, d=d)
	return LocalCmd(d=d)

# provides a partial interface for the legacy SSHconnection class (getBefore(), command())
class Cmd(metaclass=abc.ABCMeta):
	"""Abstract command runner (local or remote execution).

	Subclasses must set ``self.cwd`` (``None`` or a directory path) before
	``cd()``/``command()`` are used. They must also implement ``run()``,
	``close()``, ``getBefore()``, ``copyin()`` and ``copyout()``.
	"""

	def cd(self, d, silent=False):
		"""Change the working directory used by subsequent commands.

		Only ``self.cwd`` is updated. A shell command (``pwd``) is run only
		when a relative path is given while no working directory is set.

		Args:
			d: new directory. ``None`` or ``''`` resets ``self.cwd`` to ``None``.
			   A path starting with ``/`` replaces it. Anything else is
			   appended to the current directory.
			silent: if True, log neither the resulting directory nor the
			   helper ``pwd`` command.
		"""
		if d is None or d == '':
			self.cwd = None
		elif d.startswith('/'):
			self.cwd = d
		else:
			if not self.cwd:
				# no cwd set: ask the shell where commands currently run
				self.cwd = self.run('pwd', silent=silent).stdout.strip()
			self.cwd += f"/{d}"
		if not silent:
			logging.debug(f'cd {self.cwd}')

	@abc.abstractmethod
	def run(self, line, timeout=300, silent=False):
		"""Execute ``line`` and return the result.

		The result must expose ``.stdout`` (text) and ``.returncode``. The
		implementations in this module return ``subprocess.CompletedProcess``.
		"""

	def command(self, commandline, expectedline=None, timeout=300, silent=False, resync=False):
		"""Run a command line, emulating the legacy SSHconnection.command().

		A command whose first word is ``cd`` is not executed. It is
		translated into a ``cd()`` call, so the directory persists for
		later commands. ``expectedline`` and ``resync`` are accepted only for
		compatibility and are ignored.

		Returns:
			Always 0. Use ``getBefore()`` to inspect the output.
		"""
		# split off the first word; whatever follows (extra blanks removed)
		# is the cd target
		verb, _, arg = commandline.strip().partition(' ')
		if verb == 'cd':
			self.cd(arg.strip(), silent)
		else:
			self.run(commandline, timeout, silent)
		return 0

	@abc.abstractmethod
	def close(self):
		"""Release any resources held by the runner."""

	@abc.abstractmethod
	def getBefore(self):
		"""Return the output text of the last ``run()``."""

	@abc.abstractmethod
	def copyin(self, src, tgt, recursive=False):
		"""Copy ``src`` on the target host to local path ``tgt``."""

	@abc.abstractmethod
	def copyout(self, src, tgt, recursive=False):
		"""Copy local path ``src`` to ``tgt`` on the target host."""

class LocalCmd(Cmd):
	"""Run shell commands on the local host via ``subprocess``."""

	def __init__(self, d = None):
		"""Create a local runner.

		Args:
			d: working directory for commands, or None to use the current
			   process's working directory.
		"""
		self.cwd = d
		if self.cwd is not None:
			logging.debug(f'Working dir is {self.cwd}')
		# placeholder result so getBefore() works before the first run()
		self.cp = sp.CompletedProcess(args='', returncode=0, stdout='')

	def run(self, line, timeout=300, silent=False, reportNonZero=True):
		"""Execute ``line`` through the shell in ``self.cwd``.

		A command ending in ``&`` is run without capturing output, with a
		5 s timeout, so the call returns without waiting for the background job.

		Args:
			line: shell command line.
			timeout: seconds before subprocess.run gives up.
			silent: if True, do not log the command line.
			reportNonZero: if True, log a warning on a non-zero return code.

		Returns:
			subprocess.CompletedProcess whose ``stdout`` holds the stripped,
			decoded combined stdout/stderr. If running raised (e.g. on timeout),
			``returncode`` is 255 and ``stdout`` holds the exception text.
			The result is also stored for getBefore().
		"""
		if not silent:
			logging.info(line)
		try:
			if line.strip().endswith('&'):
				# capturing stdout would make subprocess wait for the background
				# job to finish; run it without a pipe so we return immediately
				ret = sp.run(line, shell=True, cwd=self.cwd, timeout=5)
			else:
				ret = sp.run(line, shell=True, cwd=self.cwd, stdout=sp.PIPE, stderr=sp.STDOUT, timeout=timeout)
		except Exception as e:
			ret = sp.CompletedProcess(args=line, returncode=255, stdout=f'Exception: {str(e)}'.encode('utf-8'))
		# stdout is None when it was not captured. Commands may also print
		# bytes that are not valid UTF-8; replace those instead of raising
		# UnicodeDecodeError.
		ret.stdout = (ret.stdout or b'').decode('utf-8', errors='replace').strip()
		if reportNonZero and ret.returncode != 0:
			logging.warning(f'command "{ret.args}" returned non-zero returncode {ret.returncode}: output:\n{ret.stdout}')
		self.cp = ret
		return ret

	def close(self):
		"""Nothing to release for local execution."""
		pass

	def getBefore(self):
		"""Return the output text of the last run()."""
		return self.cp.stdout

	def copyin(self, src, tgt, recursive=False):
		"""Copy ``src`` to ``tgt`` with ``cp`` (``-r`` if ``recursive``).

		Raises:
			Exception: if either path is not absolute (including empty).
		"""
		if not (src.startswith('/') and tgt.startswith('/')):
			raise Exception('support only absolute file paths!')
		opt = '-r' if recursive else ''
		self.run(f'cp {opt} {src} {tgt}')

	def copyout(self, src, tgt, recursive=False):
		"""Same as copyin(): both ends are on the local host."""
		self.copyin(src, tgt, recursive)

class RemoteCmd(Cmd):
	"""Run shell commands on a remote host over SSH using paramiko.

	Connection parameters are read from ``~/.ssh/config``.
	"""

	def __init__(self, hostname, d=None):
		"""Connect to ``hostname``, trying up to 3 times.

		Args:
			hostname: host name/alias, looked up in ``~/.ssh/config``.
			d: working directory on the remote host, or None.

		Raises:
			FileNotFoundError: ``~/.ssh/config`` does not exist.
			KeyError: no ``user``/``identityfile`` configured for the host.
			Exception: every connection attempt failed.
		"""
		maxTries = 3
		logging.getLogger('paramiko').setLevel(logging.ERROR) # prevent spamming through Paramiko
		self.client = paramiko.SSHClient()
		self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
		cfg = RemoteCmd._lookup_ssh_config(hostname)
		self.cwd = d
		# placeholder result so getBefore() works before the first run()
		self.cp = sp.CompletedProcess(args='', returncode=0, stdout='')
		for attempt in range(1, maxTries + 1):
			try:
				self.client.connect(**cfg)
				return
			except Exception as e:
				# catch Exception, not everything, so Ctrl-C still aborts
				logging.error(f'Could not connect to {hostname}, attempt {attempt}/{maxTries}: {e}')
		self.client.close()
		raise Exception("Error: max retries, did not connect to host")

	@staticmethod
	def _lookup_ssh_config(hostname):
		"""Build the keyword arguments for SSHClient.connect() from ``~/.ssh/config``.

		Raises:
			FileNotFoundError: the config file does not exist.
			KeyError: the entry for ``hostname`` lacks ``user`` or ``identityfile``.
		"""
		ssh_config = paramiko.SSHConfig()
		user_config_file = os.path.expanduser("~/.ssh/config")
		if os.path.exists(user_config_file):
			with open(user_config_file) as f:
				ssh_config.parse(f)
		else:
			raise FileNotFoundError('class needs SSH config at ~/.ssh/config')
		ucfg = ssh_config.lookup(hostname)
		if 'identityfile' not in ucfg or 'user' not in ucfg:
			raise KeyError(f'no identityfile or user in SSH config for host {hostname}')
		cfg = {'hostname':hostname, 'username':ucfg['user'], 'key_filename':ucfg['identityfile'], 'timeout':SSHTIMEOUT}
		if 'hostname' in ucfg:
			cfg['hostname'] = ucfg['hostname'] # override user-given hostname with what is in config
		if 'port' in ucfg:
			cfg['port'] = int(ucfg['port'])
		if 'proxycommand' in ucfg:
			cfg['sock'] = paramiko.ProxyCommand(ucfg['proxycommand'])
		return cfg

	def run(self, line, timeout=300, silent=False, reportNonZero=True):
		"""Execute ``line`` in a shell on the remote host.

		If a working directory is set, the command is prefixed with
		``cd <cwd> &&``. A command ending in ``&`` is launched without
		waiting, and is reported with returncode 0 and empty output.

		Args:
			line: shell command line.
			timeout: passed to paramiko exec_command() as the channel timeout.
			silent: if True, do not log the command line.
			reportNonZero: if True, log a warning on a non-zero exit status.

		Returns:
			subprocess.CompletedProcess whose ``stdout`` holds the stripped,
			decoded stdout and stderr, merged as in LocalCmd. If execution
			raised, ``returncode`` is 255 and ``stdout`` holds the exception
			text. The result is also stored for getBefore().
		"""
		if not silent:
			logging.info(line)
		if self.cwd:
			line = f"cd {self.cwd} && {line}"
		try:
			if line.strip().endswith('&'):
				# if we wait for stdout, Paramiko does not return before the end of the command
				# however, we don't want to wait for commands with &, so just return fake command
				self.client.exec_command(line, timeout = 5)
				ret = sp.CompletedProcess(args=line, returncode=0, stdout=b'')
			else:
				stdin, stdout, stderr = self.client.exec_command(line, timeout=timeout)
				# Merge stderr into stdout: unread stderr data would otherwise
				# fill the channel window and stall the remote process.
				stdout.channel.set_combined_stderr(True)
				# Drain the output BEFORE waiting for the exit status;
				# recv_exit_status() can hang forever while output is unread.
				output = stdout.read()
				returncode = stdout.channel.recv_exit_status()
				ret = sp.CompletedProcess(args=line, returncode=returncode, stdout=output)
		except Exception as e:
			ret = sp.CompletedProcess(args=line, returncode=255, stdout=f'Exception: {str(e)}'.encode('utf-8'))
		# remote output may contain bytes that are not valid UTF-8
		ret.stdout = ret.stdout.decode('utf-8', errors='replace').strip()
		if reportNonZero and ret.returncode != 0:
			logging.warning(f'command "{line}" returned non-zero returncode {ret.returncode}: output:\n{ret.stdout}')
		self.cp = ret
		return ret

	def close(self):
		"""Close the SSH connection."""
		self.client.close()

	def getBefore(self):
		"""Return the output text of the last run()."""
		return self.cp.stdout

	def copyout(self, src, tgt, recursive=False):
		"""Copy local ``src`` to ``tgt`` on the remote host via SFTP.

		With ``recursive``, ``src`` is packed into a local /tmp tar archive,
		uploaded, and unpacked inside the remote directory ``tgt``.
		NOTE(review): tar records ``src`` with the path as given, so the tree
		under ``tgt`` mirrors that path; confirm callers expect this.
		"""
		logging.debug(f"copyout: local:{src} -> remote:{tgt}")
		if not recursive:
			with self.client.open_sftp() as sftp:
				sftp.put(src, tgt)
			return
		tmpfile = f"{uuid.uuid4()}.tar"
		abstmpfile = f"/tmp/{tmpfile}"
		cmd = LocalCmd()
		cmd.run(f"tar -cf {abstmpfile} {src}")
		try:
			with self.client.open_sftp() as sftp:
				sftp.put(abstmpfile, abstmpfile)
		finally:
			# remove the local archive even if the upload failed
			cmd.run(f"rm {abstmpfile}")
		self.run(f"mv {abstmpfile} {tgt}; cd {tgt} && tar -xf {tmpfile} && rm {tmpfile}")

	def copyin(self, src, tgt, recursive=False):
		"""Copy remote ``src`` to local ``tgt`` via SFTP.

		With ``recursive``, ``src`` is packed into a remote /tmp tar archive,
		downloaded, and unpacked inside the local directory ``tgt``.
		NOTE(review): tar records ``src`` with the path as given, so the tree
		under ``tgt`` mirrors that path; confirm callers expect this.
		"""
		logging.debug(f"copyin: remote:{src} -> local:{tgt}")
		if not recursive:
			with self.client.open_sftp() as sftp:
				sftp.get(src, tgt)
			return
		tmpfile = f"{uuid.uuid4()}.tar"
		abstmpfile = f"/tmp/{tmpfile}"
		self.run(f"tar -cf {abstmpfile} {src}")
		try:
			with self.client.open_sftp() as sftp:
				sftp.get(abstmpfile, abstmpfile)
		finally:
			# remove the remote archive even if the download failed
			self.run(f"rm {abstmpfile}")
		cmd = LocalCmd()
		cmd.run(f"mv {abstmpfile} {tgt}; cd {tgt} && tar -xf {tmpfile} && rm {tmpfile}")