sftp/Test.cpp

#include "StdAfx.h"
#include "Test.h"
#include "UnixPath.h"

// Purpose: Creates the sfFTPLib SSHConnection COM object held in m_Connection.
// Returns S_OK when creation succeeds. On failure it logs the HRESULT (with a hint
// to re-register the DLL), asserts in debug builds and returns that HRESULT.
_Check_return_ HRESULT CTest::Initialize()
{
	const HRESULT hrCreate = m_Connection.CoCreateInstance(__uuidof(sfFTPLib::SSHConnection));
	if (FAILED(hrCreate))
	{
		Log(_T("Failed to create SFTPConnection instance. hr=0x%x.\nTry to register sfFTPLib.dll again."), hrCreate);
		ATLASSERT(0);
		return hrCreate;
	}
	return S_OK;
}

// Purpose: Drops this object's reference to the SSHConnection created by Initialize().
// Always succeeds.
_Check_return_ HRESULT CTest::Uninitialize()
{
	const HRESULT hrResult{ S_OK };
	m_Connection.Release();
	return hrResult;
}

void CTest::ReportLastStatus()
{
	sfFTPLib::SFTPStatus statusCode;
	ATLENSURE_SUCCEEDED(m_SFTP->get_LastStatusCode(&statusCode));
	Log(_T("LastStatusCode = %d."), statusCode);
	ATL::CComBSTR statusMessage;
	ATLENSURE_SUCCEEDED(m_SFTP->get_LastStatusMessage(&statusMessage));
	Log(_T("LastStatusMessage = \"%s\"."), (LPCTSTR)statusMessage);
}

// Purpose: Formats a FILETIME as an ISO 8601 timestamp "YYYY-MM-DDThh:mm:ssZ".
// No time-zone adjustment is applied.
// NOTE(review): the trailing 'Z' presumes ft is already UTC; confirm with callers.
// Returns S_OK and writes retval on success. Returns E_FAIL (retval untouched)
// when FileTimeToSystemTime rejects the value.
_Check_return_ HRESULT CTest::FILETIMEToISO8601(const FILETIME& ft, ATL::CString &retval)
{
	SYSTEMTIME sysTime;
	if (!::FileTimeToSystemTime(&ft, &sysTime))
	{
		return E_FAIL;
	}

	retval.Format(_T("%04hu-%02hu-%02huT%02hu:%02hu:%02huZ"),
		sysTime.wYear, sysTime.wMonth, sysTime.wDay,
		sysTime.wHour, sysTime.wMinute, sysTime.wSecond);
	return S_OK;
}

void CTest::Run()
{
	const wchar_t host []{ L"localhost" };
	const long port{ 22 };
	const wchar_t username []{ L"user" };
	const wchar_t password []{ L"pass" };

	//m_Connection->Async = VARIANT_FALSE;
	ATLENSURE_SUCCEEDED(m_Connection->put_Host(ATL::CComBSTR(host)));
	ATLENSURE_SUCCEEDED(m_Connection->put_Port(port));
	ATLENSURE_SUCCEEDED(m_Connection->put_Username(ATL::CComBSTR(username)));
	ATLENSURE_SUCCEEDED(m_Connection->put_Password(ATL::CComBSTR(password)));

	ATL::CComPtr<sfFTPLib::IFileLogger> fileLogger;
	ATLENSURE_SUCCEEDED(fileLogger.CoCreateInstance(__uuidof(sfFTPLib::FileLogger)));
	ATLENSURE_SUCCEEDED(fileLogger->put_File(ATL::CComBSTR(L"ssh.log")));

	ATL::CComQIPtr<sfFTPLib::ILogger> logger(fileLogger);
	ATLENSURE_SUCCEEDED(m_Connection->put_Logger(logger));

	// Authentication
	// Notes:
	// - Titan FTP Server: If "publickey" authentication fails the server disconnects without accepting any further methods. e.g. password
	ATL::CComSafeArray<VARIANT> authentications(2);
	//authentications.SetAt(0, ATL::CComVariant(sfFTPLib::ftpSSHAuthenticationNone));
	authentications.SetAt(0, ATL::CComVariant(sfFTPLib::ftpSSHAuthenticationPassword));
	authentications.SetAt(1, ATL::CComVariant(sfFTPLib::ftpSSHAuthenticationPublicKey));
	//ATLENSURE_SUCCEEDED(m_Connection->put_Authentications(&ATL::CComVariant(authentications)));

	// Disable Compression
	// Uncomment to disable compression
	ATL::CComSafeArray<VARIANT> compressions(2);
	compressions.SetAt(0, ATL::CComVariant(sfFTPLib::ftpSSHCompressionzlibopenssh));
	compressions.SetAt(0, ATL::CComVariant(sfFTPLib::ftpSSHCompressionNone));
	//m_Connection->put_Compressions(&ATL::CComVariant(compressions));

	// Limit Encryptions
	ATL::CComSafeArray<VARIANT> encryptions(4);
	encryptions.SetAt(0, ATL::CComVariant(sfFTPLib::ftpEncryptionAES256CTR));
	encryptions.SetAt(1, ATL::CComVariant(sfFTPLib::ftpEncryptionAES192CTR));
	encryptions.SetAt(2, ATL::CComVariant(sfFTPLib::ftpEncryptionAES128CTR));
	encryptions.SetAt(3, ATL::CComVariant(sfFTPLib::ftpEncryption3DES));
	//m_Connection->put_Encryptions(&ATL::CComVariant(encryptions));

	// Limit KeyExchange Algorithms
	ATL::CComSafeArray<VARIANT> keyexchanges(2);
	keyexchanges.SetAt(0, ATL::CComVariant(sfFTPLib::ftpKeyExchangeDiffieHellmanGroup14SHA1));
	//m_Connection->put_KeyExchanges(&ATL::CComVariant(keyexchanges));

	ATL::CComPtr<sfFTPLib::IKeyManager> keyManager;
	ATLENSURE_SUCCEEDED(keyManager.CoCreateInstance(__uuidof(sfFTPLib::KeyManager)));
	
	// Uncomment to generate new key
	ATL::CComBSTR bstrFilePrivate(L"Identity");
	ATL::CComBSTR bstrFilePublic(L"Identity.pub");
	ATL::CComBSTR bstrPassword(L"");

#if 0
	// Uncomment to create private key
	// -> broken
	// For VShell copy public key (Identity.pub) to user's folder: C:\Program Files\VShell\PublicKey\<user>
	ATL::CComPtr<sfFTPLib::OpenSSLKey> pRSA;
	ATLENSURE_SUCCEEDED(pRSA.CoCreateInstance(__uuidof(sfFTPLib::OpenSSLKey)));

	// Generate 1024-bit RSA key
	pRSA->Generate(1024);
	// Save private key in PKCS12 format (.p12)
	ATLENSURE_SUCCEEDED(keyManager->SaveFile(sfFTPLib::ftpKeyFileFormatPKCS12, pRSA, sfFTPLib::ftpKeyTypePrivateKey, bstrFilePrivate, bstrPassword));
	// Save public key. Password is ignored.
	ATLENSURE_SUCCEEDED(keyManager->SaveFile(sfFTPLib::ftpKeyFileFormatSSH, pRSA, sfFTPLib::ftpKeyTypePublicKey, bstrFilePublic, bstrPassword));
	// Save public key (for OpenSSH only). Password is ignored.
	//ATLENSURE_SUCCEEDED(keyManager->SaveFile(sfFTPLib::ftpKeyFileFormatOpenSSH, pRSA, sfFTPLib::ftpKeyTypePublicKey, bstrFilePublic, bstrPassword);
#endif

	Log(_T("Loading private key \"%s\"."), (LPCTSTR) bstrFilePrivate);
	ATL::CComPtr<sfFTPLib::IKey> pKey;
	if (keyManager->LoadFile(bstrFilePrivate, bstrPassword, &pKey) == S_OK)
	{
		sfFTPLib::KeyType keyType;
		ATLENSURE_SUCCEEDED(pKey->get_Type(&keyType));
		if (keyType == sfFTPLib::ftpKeyTypePrivateKey)
		{
			m_Connection->put_PrivateKey(pKey);
			Log(_T("Private key successfully loaded from \"%s\"."), (LPCTSTR) bstrFilePrivate);
		}
	}
	else
	{
		Log(_T("Failed to load key."));
	}

	Log(_T("Connecting to %s Port: %u"), (LPCTSTR)host, port);
	ATLENSURE_SUCCEEDED(m_Connection->Connect());

	ATL::CComPtr<sfFTPLib::ISSHServerState> serverState;
	ATLENSURE_SUCCEEDED(m_Connection->get_ServerState(&serverState));
	ATL::CComBSTR remoteId;
	ATLENSURE_SUCCEEDED(serverState->get_RemoteId(&remoteId));
	Log(_T("%s"), (LPCTSTR) remoteId);

	SFTPTest();

	// Disconnect
	Log(_T("Disconnect"));
	ATLENSURE_SUCCEEDED(m_Connection->Disconnect());
}

void CTest::SFTPTest()
{
	ATLENSURE_SUCCEEDED(m_Connection->CreateSFTPConnection(&m_SFTP));

	ATL::CComPtr<sfFTPLib::IFileLogger> fileLogger;
	ATLENSURE_SUCCEEDED(fileLogger.CoCreateInstance(__uuidof(sfFTPLib::FileLogger)));
	fileLogger->put_File(ATL::CComBSTR(L"sftp.log"));

	ATL::CComQIPtr<sfFTPLib::ILogger> logger(fileLogger);
	ATLENSURE_SUCCEEDED(m_SFTP->put_Logger(logger));

	ATLENSURE_SUCCEEDED(m_SFTP->Connect());

	Log(_T("SFTP channel successfully opened."));

	// get current folder
	ATL::CComBSTR realPath = L".";
	Log(_T("RealPath \"%s\""), (LPCTSTR)realPath);

	ATL::CComBSTR currentFolder;
	ATLENSURE_SUCCEEDED(m_SFTP->RealPath(realPath, &currentFolder));
	Log(_T("Home Folder = %s"), (LPCTSTR)currentFolder);

	// overriding CurrentFolder for debug purpose
	//bstrCurrentFolder = L"/c/archive";
	//Log(_T("Overriding current folder. \"%s\""), (LPCTSTR)bstrCurrentFolder);

	Log(_T("Reading Directory \"%s\""), (LPCTSTR)currentFolder);

	ATL::CComPtr<sfFTPLib::IFTPItems> items;
	ATLENSURE_SUCCEEDED(m_SFTP->ReadDirectory(currentFolder, &items));

	long count;
	ATLENSURE_SUCCEEDED(items->get_Count(&count));
	Log(_T("Count = %d"), count);

	// Enum
	if (count > 0)
	{
		ATL::CComPtr<IUnknown> unkEnum;
		ATLENSURE_SUCCEEDED(items->get__NewEnum(&unkEnum));
		ATL::CComQIPtr<IEnumVARIANT> pEnum(unkEnum);

		ULONG CeltFetched;
		ATL::CComVariant variant;
		while (pEnum->Next(1, &variant, &CeltFetched) == S_OK)
		{
			if (variant.vt == VT_DISPATCH
				|| variant.vt == VT_UNKNOWN)
			{
				ATL::CComQIPtr<sfFTPLib::IFTPItem> pSFTPItem = variant.pdispVal;

				// TODO: Check for valid attributes (IsValidAttribute())
				sfFTPLib::ItemType itemType;
				ATLENSURE_SUCCEEDED(pSFTPItem->get_Type(&itemType));
				ATL::CComBSTR itemName;
				ATLENSURE_SUCCEEDED(pSFTPItem->get_Name(&itemName));
				ULONGLONG itemSize;
				ATLENSURE_SUCCEEDED(pSFTPItem->get_Size(&itemSize));

				ATL::CString str;
				str.Format(_T("Type=0x%x; Name=%s; Size=%I64u"), itemType, (LPCTSTR) itemName, itemSize);

				VARIANT_BOOL isValidAttribute;
				ATLENSURE_SUCCEEDED(pSFTPItem->IsValidAttribute(sfFTPLib::ftpFTPItemAttributeModifyTime, &isValidAttribute));
				if (isValidAttribute)
				{
					ATL::CString strTime;
					FILETIME modifyTime;
					ATLENSURE_SUCCEEDED(pSFTPItem->get_ModifyTime(&modifyTime));
					if (FILETIMEToISO8601(modifyTime, strTime) == S_OK)
					{
						str += _T("; ModifyTime=") + strTime;
					}
				}

				Log(str);
			}
			// need to manually clear variant
			variant.Clear();
		}
	}

	// MakeDirectory
	CUnixPath makeDirectory((LPCTSTR)currentFolder);
	makeDirectory.Append(_T("testfolder"));
	Log(_T("MakeDirectory \"%s\""), (LPCTSTR)makeDirectory);
	ATLENSURE_SUCCEEDED(m_SFTP->MakeDirectory(ATL::CComBSTR(makeDirectory)));
	Log(_T("Directory \"%s\" created."), (LPCTSTR)makeDirectory);

	// Rename
	CUnixPath renameFrom = makeDirectory;
	CUnixPath renameTo = (LPCTSTR)currentFolder;
	renameTo.Append(_T("testfolder2"));
	Log(_T("Rename \"%s\" to \"%s\""), (LPCTSTR)renameFrom, (LPCTSTR)renameTo);
	ATLENSURE_SUCCEEDED(m_SFTP->Rename(ATL::CComBSTR(renameFrom), ATL::CComBSTR(renameTo), 0));

	// RemoveDirectory
	CUnixPath RemoveDirectory = renameTo;
	Log(_T("RemoveDirectory \"%s\""), (LPCTSTR)RemoveDirectory);
	ATLENSURE_SUCCEEDED(m_SFTP->RemoveDirectory(ATL::CComBSTR(RemoveDirectory)));
	Log(_T("Directory \"%s\" removed."), (LPCTSTR)RemoveDirectory);

	// Creating temporary memory file
	ATL::CComPtr<IStream> pMemFile;
	static const DWORD dwSize = 1000 * 1024; // 1000 KiB
	if (CreateMemFile(dwSize, 0, &pMemFile) == S_OK)
	{
		// Upload File
		CUnixPath UploadFile((LPCTSTR)currentFolder);
		UploadFile.Append(_T("memfile"));
		ATL::CComBSTR bstrUploadFile = (LPCTSTR)UploadFile;

		Log(_T("UploadFile to \"%s\""), (LPCTSTR)bstrUploadFile);
		ATLENSURE_SUCCEEDED(m_SFTP->UploadFile(ATL::CComVariant(pMemFile.p), bstrUploadFile, sfFTPLib::ftpDataTransferTypeImage, 0,0));
		Log(_T("File successfully uploaded."));

		// Stat. Stat doesn't follow symbolic links.
		ATL::CComBSTR bstrStat = bstrUploadFile;
		Log(_T("Stat \"%s\""), (LPCTSTR)bstrStat);
		ATL::CComPtr<sfFTPLib::IFTPItem> pItem;
		ATLENSURE_SUCCEEDED(m_SFTP->Stat(bstrStat, sfFTPLib::ftpFTPItemAttributeSize, &pItem));
		VARIANT_BOOL isValidAttribute;
		ATLENSURE_SUCCEEDED(pItem->IsValidAttribute(sfFTPLib::ftpFTPItemAttributeSize, &isValidAttribute));
		if (isValidAttribute)
		{
			ULONGLONG size;
			ATLENSURE_SUCCEEDED(pItem->get_Size(&size));
			Log(_T("File Size = %I64u."), size);
		}
		// DownloadFile to memory file
		ATL::CComPtr<IStream> pDownloadMemFile;
		if (CreateMemFile(dwSize, 0, &pDownloadMemFile) == S_OK)
		{
			ATL::CComBSTR bstrDownloadFile = bstrUploadFile;
			Log(_T("DownloadFile \"%s\" to memfile"), (LPCTSTR)bstrDownloadFile);
			ATLENSURE_SUCCEEDED(m_SFTP->DownloadFileEx(bstrDownloadFile, ATL::CComVariant(pDownloadMemFile.p), sfFTPLib::ftpDataTransferTypeImage, 0, 0, sfFTPLib::ftpDownloadFlagReadBeyondEnd, nullptr));
			Log(_T("File successfully downloaded."));
		}

		// DownloadFile to physical file
		ATL::CComBSTR bstrDownloadFile = bstrUploadFile;
		TCHAR szCurrentDirectory[MAX_PATH] = {};
		::GetCurrentDirectory(ARRAYSIZE(szCurrentDirectory), szCurrentDirectory);
		::PathAppend(szCurrentDirectory, _T("Download"));
		::SHCreateDirectoryEx(nullptr, szCurrentDirectory, nullptr);
		::PathAppend(szCurrentDirectory, _T("memfile"));
		ATL::CComBSTR downloadLocalFile = szCurrentDirectory;
		Log(_T("DownloadFile \"%s\" to \"%s\""), (LPCTSTR)bstrDownloadFile, (LPCTSTR)downloadLocalFile);
		ATLENSURE_SUCCEEDED(m_SFTP->DownloadFile(bstrDownloadFile, ATL::CComVariant(szCurrentDirectory), sfFTPLib::ftpDataTransferTypeImage, 0, 0));
		Log(_T("File successfully downloaded."));
	}

	Log(_T("Closing channel."));
	ATLENSURE_SUCCEEDED(m_SFTP->Disconnect());
}

// Purpose: printf-style logger. Formats the message into a fixed-size buffer,
// appends a newline and writes it to ATLTRACE output and to stdout via _tprintf.
// pszFormat: printf-style format string followed by its arguments.
// Output longer than LOG_EVENT_MSG_SIZE - 1 characters is cut off.
void CTest::Log(_In_z_ PCTSTR pszFormat, ...)
{
	// max limit of log message set to 4096. Increase if message gets cut.
	const int LOG_EVENT_MSG_SIZE = 4096;

	TCHAR chMsg[LOG_EVENT_MSG_SIZE];
	va_list pArg;

	va_start(pArg, pszFormat);
	// count = size - 1: write at most that many characters, then null-terminate.
	_vsntprintf_s(chMsg, LOG_EVENT_MSG_SIZE, LOG_EVENT_MSG_SIZE-1, pszFormat, pArg);
	// Every va_start must be paired with va_end.
	va_end(pArg);

	ATL::CString strMsg(chMsg);
	strMsg += _T("\n");
	// The already formatted text is passed as an argument, never as a format string.
	// It may contain '%' (e.g. remote file names), which would otherwise be
	// re-interpreted and read arguments that were never passed.
	ATLTRACE(_T("%s"), strMsg.GetString());
	_tprintf(_T("%s"), strMsg.GetString());
}

// Purpose: Creates an in-memory IStream backed by global memory (GlobalAlloc).
// nSize:      size of the backing buffer in bytes.
// fillMethod: 0: zero data (GMEM_ZEROINIT), 1: fill with the repeating byte pattern 0-255.
//             Any other value leaves the data zeroed.
// retval:     receives the new stream; set to nullptr on failure.
// Returns:
// - S_OK on success.
// - E_OUTOFMEMORY if GlobalAlloc fails.
// - E_FAIL if the memory cannot be locked.
// - The failing HRESULT of CreateStreamOnHGlobal.
// The global memory is freed on every failure path.
_Check_return_ HRESULT CTest::CreateMemFile(DWORD nSize, int fillMethod, _COM_Outptr_ IStream **retval)
{
	*retval = nullptr;

	const auto hMem = ::GlobalAlloc(GMEM_MOVEABLE | GMEM_ZEROINIT, static_cast<SIZE_T>(nSize));
	if (!hMem)
	{
		return E_OUTOFMEMORY;
	}

	const auto pImage = static_cast<BYTE*>(::GlobalLock(hMem));
	if (!pImage)
	{
		// No stream owns hMem yet, so release it here to avoid a leak.
		::GlobalFree(hMem);
		return E_FAIL;
	}

	if (fillMethod == 1)
	{
		// fill with 0-255
		for (DWORD i = 0; i < nSize; i++)
		{
			pImage[i] = static_cast<BYTE>(i);
		}
	}
	::GlobalUnlock(hMem);

	// With fDeleteOnRelease = TRUE the stream takes ownership of hMem and frees it
	// when released. Ownership only transfers if the stream is actually created.
	const HRESULT hr = ::CreateStreamOnHGlobal(hMem, TRUE, retval);
	if (FAILED(hr))
	{
		*retval = nullptr;
		::GlobalFree(hMem);
	}
	return hr;
}