Author: andreas Date: 2010-07-18 23:12:14 -0700 (Sun, 18 Jul 2010) New Revision: 2242 Added: trunk/platforms/win32/plugins/SqueakSSL/ trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c Log: Add win32 implementation for SqueakSSL plugin. Added: trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c =================================================================== --- trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c (rev 0) +++ trunk/platforms/win32/plugins/SqueakSSL/sqWin32SSL.c 2010-07-19 06:12:14 UTC (rev 2242) @@ -0,0 +1,827 @@ +/* sqWin32SSL.c: SqueakSSL implementation for Windows */ + +#include <windows.h> +#include <errno.h> +#include <malloc.h> + +#include "sq.h" +#include "SqueakSSL.h" + +#define SECURITY_WIN32 +#include <security.h> +#include <schannel.h> +#include <wincrypt.h> + +typedef struct sqSSL { + int state; + int certFlags; + int loglevel; + + char *certName; + char *peerName; + + CredHandle sslCred; + CtxtHandle sslCtxt; + + SecBufferDesc sbdIn; + SecBufferDesc sbdOut; + SecBuffer inbuf[4]; + SecBuffer outbuf[4]; + + /* internal data buffer */ + char *dataBuf; + int dataLen; + int dataMax; + + SecPkgContext_StreamSizes sslSizes; +} sqSSL; + +static sqSSL **handleBuf = NULL; +static sqInt handleMax = 0; + + +/********************************************************************/ +/********************************************************************/ +/********************************************************************/ + +/* sslFromHandle: Maps a handle to an SSL */ +static sqSSL *sslFromHandle(sqInt handle) { + return handle < handleMax ? handleBuf[handle] : NULL; +} + + +/* sqPrintSBD: Prints a SecurityBuffer for debugging */ +static void sqPrintSBD(char *title, SecBufferDesc sbd) { + unsigned int i; + printf("%s\n", title); + for(i=0; i<sbd.cBuffers; i++) { + SecBuffer *buf = sbd.pBuffers + i; + printf("\tbuf[%d]: %d (%d bytes) ptr=%x\n", i,buf->BufferType, buf->cbBuffer, (int)buf->pvBuffer); + } +} + +/* sqCopyExtraData: Retains any SECBUFFER_EXTRA data. */ +static void sqCopyExtraData(sqSSL *ssl, SecBufferDesc sbd) { + unsigned int i; + if(sbd.pBuffers[0].BufferType == SECBUFFER_MISSING) { + if(ssl->loglevel) printf("sqCopyExtra: Encountered SECBUFFER_MISSING; retaining %d bytes\n", ssl->dataLen); + return; + } + ssl->dataLen = 0; + for(i=0; i<sbd.cBuffers; i++) { + SecBuffer *buf = sbd.pBuffers + i; + if(buf->BufferType == SECBUFFER_EXTRA) { + int count = buf->cbBuffer; + char *srcPtr = buf->pvBuffer; + char *dstPtr = ssl->dataBuf + ssl->dataLen; + if(ssl->loglevel) printf("sqCopyExtraData: Retaining %d bytes\n", count); + /* I *think* the extra buffers are always in input range. + Make sure that's the case or at least report it if not. */ + if(srcPtr < dstPtr || (srcPtr + count) > (ssl->dataBuf + ssl->dataMax)) { + if(ssl->loglevel) printf("sqCopyExtraDataSSL: Encountered out-of-range extra buffer\n"); + } + if(srcPtr != dstPtr) { + /* memmove() not memcpy() since the memory mayoverlap */ + memmove(dstPtr, srcPtr, count); + } + ssl->dataLen += count; + } + } +} + +/* Copies the data from a SecBufferDesc to dstBuf */ +static sqInt sqCopyDescToken(sqSSL *ssl, SecBufferDesc sbd, char *dstBuf, sqInt dstLen) { + unsigned int i; + int result = 0; + + if(ssl->loglevel) printf("sqCopyDescToken: \n"); + for(i = 0; i < sbd.cBuffers; i++) { + SecBuffer *buf = sbd.pBuffers + i; + if(ssl->loglevel) printf("\t type=%d, size=%d\n", buf->BufferType, buf->cbBuffer); + if(buf->BufferType == SECBUFFER_TOKEN) { + int count = buf->cbBuffer; + if(count > dstLen) return SQSSL_BUFFER_TOO_SMALL; + memcpy(dstBuf, buf->pvBuffer, count); + result += count; + dstBuf += count; + dstLen -= count; + FreeContextBuffer(buf->pvBuffer); + } + if(buf->BufferType == SECBUFFER_EXTRA) { + /* XXXX: Preserve contents for the next round */ + if(ssl->loglevel) printf("sqCopyDescToken: Unexpectedly encountered SECBUFFER_EXTRA\n"); + } + } + return result; +} + +/* Set up the local certificate for SSL */ +static sqInt sqSetupCert(sqSSL *ssl, char *certName, int server) { + SCHANNEL_CRED sc_cred = { 0 }; + SECURITY_STATUS ret; + HCERTSTORE hStore; + PCCERT_CONTEXT pContext = NULL; + + if(certName) { + hStore = CertOpenSystemStore(0, "MY"); + if(!hStore) { + if(ssl->loglevel) printf("sqSetupCert: CertOpenSystemStore failed\n"); + return 0; + } + pContext = CertFindCertificateInStore(hStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + 0, CERT_FIND_SUBJECT_STR_A, certName, NULL); + + /* XXXX: Fail? Or just not provide the cert? For now, fail. */ + if(!pContext) { + if(ssl->loglevel) printf("sqSetupCert: CertFindCertitficateInStore failed\n"); + CertCloseStore(hStore, 0); + return 0; + } + } + + sc_cred.dwVersion = SCHANNEL_CRED_VERSION; + sc_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION; + sc_cred.grbitEnabledProtocols = server ? SP_PROT_TLS1_SERVER | SP_PROT_SSL3_SERVER | SP_PROT_SSL2_SERVER : 0; + sc_cred.dwMinimumCipherStrength = 0; + sc_cred.dwMaximumCipherStrength = 0; + + if(pContext) { + sc_cred.cCreds = 1; + sc_cred.paCred = &pContext; + } else { + sc_cred.cCreds = 0; + } + + ret = AcquireCredentialsHandle(NULL, UNISP_NAME, + server ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND, + NULL, &sc_cred, NULL, NULL, &ssl->sslCred, NULL); + if(ssl->loglevel) printf("AquireCredentialsHandle returned: %x\n", ret); + + if(pContext) { + CertCloseStore(hStore, 0); + CertFreeCertificateContext(pContext); + } + + if (ret != SEC_E_OK) { + if(ssl->loglevel) printf("AquireCredentialsHandle error: %x\n", ret); + return 0; + } + return 1; +} + +/* sqExtractPeerName: Extract the name from the cert of the remote peer. */ +static int sqExtractPeerName(sqSSL *ssl) { + SECURITY_STATUS ret; + PCCERT_CONTEXT certHandle = NULL; + PCERT_NAME_INFO certInfo = NULL; + PCERT_RDN_ATTR certAttr = NULL; + DWORD dwSize = 0; + char tmpBuf[1024]; + + if(ssl->peerName) { + free(ssl->peerName); + ssl->peerName = NULL; + } + ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&certHandle); + /* No credentials were provided; can't extract peer name */ + if(ret == SEC_E_NO_CREDENTIALS) return 1; + + if(ret != SEC_E_OK) { + if(ssl->loglevel) printf("sqExtractPeerName: QueryContextAttributes failed (code = %x)\n", ret); + return 0; + } + + /* Figure out the size of the blob */ + if(!CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_NAME, + certHandle->pCertInfo->Subject.pbData, + certHandle->pCertInfo->Subject.cbData, + 0, NULL, &dwSize)) { + if(ssl->loglevel) printf("sqExtractPeerName: CryptDecodeObject failed\n"); + return 0; + } + + /* Get the contents */ + certInfo = alloca(dwSize); + if(!CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_NAME, + certHandle->pCertInfo->Subject.pbData, + certHandle->pCertInfo->Subject.cbData, + 0, certInfo, &dwSize)) { + if(ssl->loglevel) printf("sqExtractPeerName: CryptDecodeObject failed\n"); + return 0; + } + + /* Fetch the CN from the cert */ + certAttr = CertFindRDNAttr(szOID_COMMON_NAME, certInfo); + if(certAttr == NULL) return 0; + + /* Translate from RDN to string */ + if(CertRDNValueToStr(CERT_RDN_PRINTABLE_STRING, &certAttr->Value, tmpBuf, sizeof(tmpBuf)) == 0) return 0; + ssl->peerName = _strdup(tmpBuf); + if(ssl->loglevel) printf("sqExtractPeerName: Peer name is %s\n", ssl->peerName); + + return 1; +} + +/* sqVerifyCert: Verify the validity of the remote certificate */ +static int sqVerifyCert(sqSSL *ssl, int isServer) { + SECURITY_STATUS ret; + PCCERT_CONTEXT certHandle = NULL; + PCCERT_CHAIN_CONTEXT chainContext = NULL; + CERT_CHAIN_PARA chainPara; + SSL_EXTRA_CERT_CHAIN_POLICY_PARA epp; + CERT_CHAIN_POLICY_PARA policyPara; + CERT_CHAIN_POLICY_STATUS policyStatus; + + static LPSTR serverUsage[] = { + szOID_PKIX_KP_SERVER_AUTH, + szOID_SERVER_GATED_CRYPTO, + szOID_SGC_NETSCAPE + }; + static LPSTR clientUsage[] = { + szOID_PKIX_KP_CLIENT_AUTH + }; + + ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&certHandle); + /* No credentials were provided */ + if(ret == SEC_E_NO_CREDENTIALS) { + ssl->certFlags = SQSSL_NO_CERTIFICATE; + return 1; + } + + memset(&chainPara, 0, sizeof(chainPara)); + chainPara.cbSize = sizeof(chainPara); + chainPara.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; + if(!isServer) { + chainPara.RequestedUsage.Usage.cUsageIdentifier = 3; + chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = serverUsage; + } else { + chainPara.RequestedUsage.Usage.cUsageIdentifier = 1; + chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = clientUsage; + } + if(!CertGetCertificateChain(NULL, certHandle, NULL, + certHandle->hCertStore, + &chainPara, 0, NULL, &chainContext)) { + /* XXXX: Does this mean the other end did not provide a cert? */ + ssl->certFlags = SQSSL_OTHER_ISSUE; + goto done; + } + + memset(&epp, 0, sizeof(epp)); + epp.cbSize = sizeof(epp); + epp.dwAuthType = AUTHTYPE_SERVER; + epp.fdwChecks = 0; + epp.pwszServerName = NULL; + + memset(&policyPara, 0, sizeof(policyPara)); + policyPara.cbSize = sizeof(policyPara); + policyPara.dwFlags = 0; + policyPara.pvExtraPolicyPara = &epp; + memset(&policyStatus, 0, sizeof(policyStatus)); + policyStatus.cbSize = sizeof(policyStatus); + + /* We loop here CertVerifyCertificateChainPolicy() returns only a + single error even if there is more than one issue with the cert. */ + ssl->certFlags = 0; + while(true) { + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, + chainContext, + &policyPara, + &policyStatus)) { + ssl->certFlags |= SQSSL_OTHER_ISSUE; + goto done; + } + switch(policyStatus.dwError) { + case SEC_E_OK: + goto done; + case CERT_E_UNTRUSTEDROOT: + if(ssl->certFlags & SQSSL_UNTRUSTED_ROOT) goto done; + ssl->certFlags |= SQSSL_UNTRUSTED_ROOT; + epp.fdwChecks |= 0x00000100; /* SECURITY_FLAG_IGNORE_UNKNOWN_CA */ + break; + case CERT_E_EXPIRED: + if(ssl->certFlags & SQSSL_CERT_EXPIRED) goto done; + ssl->certFlags |= SQSSL_CERT_EXPIRED; + epp.fdwChecks |= 0x00002000; /* SECURITY_FLAG_IGNORE_CERT_DATE_INVALID */ + break; + case CERT_E_WRONG_USAGE: + if(ssl->certFlags & SQSSL_WRONG_USAGE) goto done; + ssl->certFlags |= SQSSL_WRONG_USAGE; + epp.fdwChecks |= 0x00000200; /* SECURITY_FLAG_IGNORE_WRONG_USAGE */ + case CERT_E_REVOKED: + if(ssl->certFlags & SQSSL_CERT_REVOKED) goto done; + ssl->certFlags |= SQSSL_CERT_REVOKED; + epp.fdwChecks |= 0x00000080; /* SECURITY_FLAG_IGNORE_REVOCATION */ + break; + default: + ssl->certFlags |= SQSSL_OTHER_ISSUE; + goto done; + } + } +done: + return 1; +} + +/********************************************************************/ +/********************************************************************/ +/********************************************************************/ + +/* sqCreateSSL: Creates a new SSL instance. + Arguments: None. + Returns: SSL handle. +*/ +sqInt sqCreateSSL(void) { + sqInt handle; + sqSSL *ssl = NULL; + + ssl = calloc(1, sizeof(sqSSL)); + + ssl->sbdIn.ulVersion = SECBUFFER_VERSION; + ssl->sbdIn.cBuffers = 4; + ssl->sbdIn.pBuffers = ssl->inbuf; + + ssl->sbdOut.ulVersion = SECBUFFER_VERSION; + ssl->sbdOut.cBuffers = 4; + ssl->sbdOut.pBuffers = ssl->outbuf; + + /* Find a free handle */ + for(handle = 1; handle < handleMax; handle++) + if(handleBuf[handle] == NULL) break; + + if(handle >= handleMax) { + int i, delta = 100; + /* Resize the handle buffer */ + handleBuf = realloc(handleBuf, (handleMax+delta)*sizeof(void*)); + for(i = handleMax; i < handleMax+delta; i++) + handleBuf[i] = NULL; + handleMax += delta; + } + handleBuf[handle] = ssl; + return handle; +} + +/* sqDestroySSL: Destroys an SSL instance. + Arguments: + handle - the SSL handle + Returns: Non-zero if successful. +*/ +sqInt sqDestroySSL(sqInt handle) { + sqSSL *ssl = sslFromHandle(handle); + if(ssl == NULL) return 0; + + FreeCredentialsHandle(&ssl->sslCred); + DeleteSecurityContext(&ssl->sslCtxt); + + if(ssl->certName) free(ssl->certName); + if(ssl->peerName) free(ssl->peerName); + if(ssl->dataBuf) free(ssl->dataBuf); + + free(ssl); + handleBuf[handle] = NULL; + return 1; +} + +/* sqConnectSSL: Start/continue an SSL client handshake. + Arguments: + handle - the SSL handle + srcBuf - the input token sent by the remote peer + srcLen - the size of the input token + dstBuf - the output buffer for a new token + dstLen - the size of the output buffer + Returns: The size of the output token or an error code. +*/ +sqInt sqConnectSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) { + SecBufferDesc *sbdIn = NULL; + SECURITY_STATUS ret; + SCHANNEL_CRED sc_cred = { 0 }; + ULONG sslFlags, retFlags; + sqSSL *ssl = sslFromHandle(handle); + + /* Verify state of session */ + if(ssl == NULL || (ssl->state != SQSSL_UNUSED && ssl->state != SQSSL_CONNECTING)) { + return SQSSL_INVALID_STATE; + } + + if(ssl->dataLen + srcLen > ssl->dataMax) { + /* resize the data buffer */ + ssl->dataMax += (srcLen < 4096) ? (4096) : (srcLen+1024); + ssl->dataBuf = realloc(ssl->dataBuf, ssl->dataMax); + if(!ssl->dataBuf) return SQSSL_OUT_OF_MEMORY; + } + if(ssl->loglevel) printf("sqConnectSSL: input token %d bytes\n", srcLen); + memcpy(ssl->dataBuf + ssl->dataLen, srcBuf, srcLen); + ssl->dataLen += srcLen; + + /* Standard flags for SSL connection */ + sslFlags = + ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_STREAM | + ISC_REQ_MANUAL_CRED_VALIDATION; + + /* Set up the input and output buffers */ + ssl->inbuf[0].BufferType = SECBUFFER_TOKEN; + ssl->inbuf[0].cbBuffer = ssl->dataLen; + ssl->inbuf[0].pvBuffer = ssl->dataBuf; + ssl->inbuf[1].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[1].cbBuffer = 0; + ssl->inbuf[1].pvBuffer = NULL; + ssl->inbuf[2].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[2].cbBuffer = 0; + ssl->inbuf[2].pvBuffer = NULL; + ssl->inbuf[3].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[3].cbBuffer = 0; + ssl->inbuf[3].pvBuffer = NULL; + ssl->sbdIn.cBuffers = 4; + + ssl->outbuf[0].BufferType = SECBUFFER_EMPTY; + ssl->outbuf[0].cbBuffer = 0; + ssl->outbuf[0].pvBuffer = NULL; + ssl->outbuf[1].BufferType = SECBUFFER_EMPTY; + ssl->outbuf[1].cbBuffer = 0; + ssl->outbuf[1].pvBuffer = NULL; + ssl->sbdOut.cBuffers = 2; + + if(ssl->loglevel) printf("sqConnectSSL: Input to InitSecCtxt is %d bytes\n", ssl->dataLen); + + if(ssl->state == SQSSL_UNUSED) { + ssl->state = SQSSL_CONNECTING; + + if(!sqSetupCert(ssl, ssl->certName, 0)) + /* FIXME. We need a different error code here. */ + return SQSSL_GENERIC_ERROR; + + ret = InitializeSecurityContext(&ssl->sslCred, NULL, NULL, + sslFlags, 0, 0, NULL, 0, &ssl->sslCtxt, + &ssl->sbdOut, &retFlags, NULL); + } else { + ret = InitializeSecurityContext(&ssl->sslCred, &ssl->sslCtxt, NULL, + sslFlags, 0, 0, &ssl->sbdIn, 0, NULL, + &ssl->sbdOut, &retFlags, NULL); + } + + if(ssl->loglevel) printf("InitializeSecurityContext returned: %x\n", ret); + + if(ssl->loglevel) sqPrintSBD("Input Buffers:", ssl->sbdIn); + if(ssl->loglevel) sqPrintSBD("Output Buffers:", ssl->sbdOut); + + if(ret != SEC_E_OK) { + int count; + /* Handle various failure conditions */ + switch(ret) { + case SEC_I_CONTINUE_NEEDED: + /* Send contents back to peer and come back with more data */ + count = sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen); + /* Sanity checks for buffers */ + if(ssl->inbuf[0].BufferType != SECBUFFER_TOKEN) { + if(ssl->loglevel) printf("sqConnectSSL: Unexpected buffer[0].BufferType -- %d\n", ssl->inbuf[0].BufferType); + } + if(ssl->inbuf[2].BufferType != SECBUFFER_EMPTY) { + if(ssl->loglevel) printf("sqConnectSSL: Unexpected buffer[2].BufferType -- %d\n", ssl->inbuf[0].BufferType); + } + + /* If there is SECBUFFER_EXTRA in the input we need to retain it */ + if(ssl->inbuf[1].BufferType == SECBUFFER_EXTRA) { + int extra = ssl->inbuf[1].cbBuffer; + if(ssl->loglevel) printf("sqConnectSSL: Retaining %d token bytes\n", extra); + memmove(ssl->dataBuf, ssl->dataBuf + (ssl->dataLen - extra), extra); + ssl->dataLen = extra; + } else ssl->dataLen = 0; + + /* Don't return zero (SQSSL_OK) when more data is needed */ + return count ? count : SQSSL_NEED_MORE_DATA; + default: + if(ssl->loglevel) printf("Unexpected return code %d\n", ret); + return SQSSL_GENERIC_ERROR; + } + } + + /* TODO: Look at retFlags */ + ssl->state = SQSSL_CONNECTED; + sqCopyExtraData(ssl, ssl->sbdOut); + ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_STREAM_SIZES, &ssl->sslSizes); + if(ssl->loglevel) printf("sqConnectSSL: Maximum message size is %d bytes\n", ssl->sslSizes.cbMaximumMessage); + + /* Extract the peer name */ + sqExtractPeerName(ssl); + + /* Verify the certificate (sets certFlags) */ + sqVerifyCert(ssl, false); + + return SQSSL_OK; +} + +/* sqAcceptSSL: Start/continue an SSL server handshake. + Arguments: + handle - the SSL handle + srcBuf - the input token sent by the remote peer + srcLen - the size of the input token + dstBuf - the output buffer for a new token + dstLen - the size of the output buffer + Returns: The size of the output token or an error code. +*/ +sqInt sqAcceptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) { + SECURITY_STATUS ret; + SCHANNEL_CRED sc_cred = { 0 }; + ULONG sslFlags, retFlags; + sqSSL *ssl = sslFromHandle(handle); + + /* Verify state of session */ + if(ssl == NULL || (ssl->state != SQSSL_UNUSED && ssl->state != SQSSL_ACCEPTING)) { + return SQSSL_INVALID_STATE; + } + + /* Standard flags for SSL connection */ + sslFlags = + ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_CONFIDENTIALITY | ASC_REQ_EXTENDED_ERROR | + ASC_REQ_INTEGRITY | ASC_REQ_REPLAY_DETECT | ASC_REQ_STREAM; + + ssl->inbuf[0].BufferType = SECBUFFER_TOKEN; + ssl->inbuf[0].cbBuffer = srcLen; + ssl->inbuf[0].pvBuffer = srcBuf; + ssl->inbuf[1].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[1].cbBuffer = 0; + ssl->inbuf[1].pvBuffer = NULL; + + ssl->sbdIn.cBuffers = 2; + + ssl->outbuf[0].BufferType = SECBUFFER_EMPTY; + ssl->outbuf[0].cbBuffer = 0; + ssl->outbuf[0].pvBuffer = NULL; + ssl->outbuf[1].BufferType = SECBUFFER_EMPTY; + ssl->outbuf[1].cbBuffer = 0; + ssl->outbuf[1].pvBuffer = NULL; + + ssl->sbdOut.cBuffers = 2; + + if(ssl->state == SQSSL_UNUSED) { + ssl->state = SQSSL_ACCEPTING; + + if(!sqSetupCert(ssl, ssl->certName, 1)) + return SQSSL_GENERIC_ERROR; + + ret = AcceptSecurityContext(&ssl->sslCred, NULL, &ssl->sbdIn, sslFlags, + SECURITY_NATIVE_DREP, &ssl->sslCtxt, &ssl->sbdOut, + &retFlags, NULL); + } else { + ret = AcceptSecurityContext(&ssl->sslCred, &ssl->sslCtxt, &ssl->sbdIn, sslFlags, + SECURITY_NATIVE_DREP, &ssl->sslCtxt, &ssl->sbdOut, + &retFlags, NULL); + } + + if(ssl->loglevel) printf("AcceptSecurityContext returned: %x\n", ret); + + if(ret != SEC_E_OK) { + /* Handle various failure conditions */ + switch(ret) { + case SEC_I_CONTINUE_NEEDED: + /* Send contents back to peer and come back with more data */ + return sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen); + default: + if(ssl->loglevel) printf("Unexpected return code %d\n", ret); + return SQSSL_GENERIC_ERROR; + } + } + + /* TODO: Look at retFlags */ + ssl->state = SQSSL_CONNECTED; + ret = QueryContextAttributes(&ssl->sslCtxt, SECPKG_ATTR_STREAM_SIZES, &ssl->sslSizes); + if(ssl->loglevel) printf("sqAcceptSSL: Maximum message size is %d bytes\n", ssl->sslSizes.cbMaximumMessage); + + /* Extract the peer name */ + sqExtractPeerName(ssl); + + /* Verify the certificate (sets certFlags) */ + sqVerifyCert(ssl, true); + + return sqCopyDescToken(ssl, ssl->sbdOut, dstBuf, dstLen); +} + +/* sqEncryptSSL: Encrypt data for SSL transmission. + Arguments: + handle - the SSL handle + srcBuf - the unencrypted input data + srcLen - the size of the input data + dstBuf - the output buffer for the encrypted contents + dstLen - the size of the output buffer + Returns: The size of the output generated or an error code. +*/ +sqInt sqEncryptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) { + SECURITY_STATUS ret; + sqInt total; + sqSSL *ssl = sslFromHandle(handle); + + if(ssl == NULL || ssl->state != SQSSL_CONNECTED) return SQSSL_INVALID_STATE; + + if(ssl->loglevel) printf("sqEncryptSSL: Encrypting %d bytes\n", srcLen); + + if(srcLen > (int)ssl->sslSizes.cbMaximumMessage) + return SQSSL_INPUT_TOO_LARGE; + + ssl->inbuf[0].BufferType = SECBUFFER_STREAM_HEADER; + ssl->inbuf[0].cbBuffer = ssl->sslSizes.cbHeader; + ssl->inbuf[0].pvBuffer = dstBuf; + + ssl->inbuf[1].BufferType = SECBUFFER_DATA; + ssl->inbuf[1].cbBuffer = srcLen; + ssl->inbuf[1].pvBuffer = dstBuf + ssl->inbuf[0].cbBuffer; + + ssl->inbuf[2].BufferType = SECBUFFER_STREAM_TRAILER; + ssl->inbuf[2].cbBuffer = ssl->sslSizes.cbTrailer; + ssl->inbuf[2].pvBuffer = dstBuf + ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer; + + ssl->inbuf[3].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[3].cbBuffer = 0; + ssl->inbuf[3].pvBuffer = NULL; + + ssl->sbdIn.cBuffers = 4; + + /* Check to ensure that encrypted contents fits dstBuf. + Fail with BUFFER_TOO_SMALL to allow caller to retry. */ + total = ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer + ssl->inbuf[2].cbBuffer; + if(dstLen < total) return SQSSL_BUFFER_TOO_SMALL; + + memcpy(ssl->inbuf[1].pvBuffer, srcBuf, srcLen); + + if(ssl->loglevel) printf("Header: %d; Data: %d; Trailer: %d\n", + ssl->inbuf[0].cbBuffer, ssl->inbuf[1].cbBuffer, ssl->inbuf[2].cbBuffer); + + ret = EncryptMessage(&ssl->sslCtxt, 0, &ssl->sbdIn, 0); + + if (ret != SEC_E_OK) { + if(ssl->loglevel) printf("EncryptMessage returned: %x\n", ret); + return SQSSL_GENERIC_ERROR; + } + + /* Return total amount of encrypted contents. + Must recompute total here since trailer may be overestimated */ + total = ssl->inbuf[0].cbBuffer + ssl->inbuf[1].cbBuffer + ssl->inbuf[2].cbBuffer; + return total; +} + +/* sqDecryptSSL: Decrypt data for SSL transmission. + Arguments: + handle - the SSL handle + srcBuf - the encrypted input data + srcLen - the size of the input data + dstBuf - the output buffer for the decrypted contents + dstLen - the size of the output buffer + Returns: The size of the output generated or an error code. +*/ +sqInt sqDecryptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt dstLen) { + int i, total; + SECURITY_STATUS ret; + sqSSL *ssl = sslFromHandle(handle); + + if(ssl == NULL || ssl->state != SQSSL_CONNECTED) return SQSSL_INVALID_STATE; + + if(ssl->dataLen + srcLen > ssl->dataMax) { + /* resize the read buffer */ + ssl->dataMax += (srcLen < 4096) ? (4096) : (srcLen+1024); + ssl->dataBuf = realloc(ssl->dataBuf, ssl->dataMax); + if(!ssl->dataBuf) return SQSSL_OUT_OF_MEMORY; + } + if(ssl->loglevel) printf("sqDecryptSSL: Input data %d bytes\n", srcLen); + memcpy(ssl->dataBuf + ssl->dataLen, srcBuf, srcLen); + ssl->dataLen += srcLen; + + if(ssl->loglevel) printf("sqDecryptSSL: Decrypting %d bytes\n", ssl->dataLen); + + ssl->inbuf[0].BufferType = SECBUFFER_DATA; + ssl->inbuf[0].cbBuffer = ssl->dataLen; + ssl->inbuf[0].pvBuffer = ssl->dataBuf; + + ssl->inbuf[1].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[1].cbBuffer = 0; + ssl->inbuf[1].pvBuffer = NULL; + + ssl->inbuf[2].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[2].cbBuffer = 0; + ssl->inbuf[2].pvBuffer = NULL; + + ssl->inbuf[3].BufferType = SECBUFFER_EMPTY; + ssl->inbuf[3].cbBuffer = 0; + ssl->inbuf[3].pvBuffer = NULL; + + ssl->sbdIn.cBuffers = 4; + ret = DecryptMessage(&ssl->sslCtxt, &ssl->sbdIn, 0, 0); + + /* Copy the result into destination buffer */ + total = 0; + for(i=0;i<4;i++) { + int buftype = ssl->inbuf[i].BufferType; + int count = ssl->inbuf[i].cbBuffer; + char *buffer = ssl->inbuf[i].pvBuffer; + if(ssl->loglevel) printf("buf[%d]: %d (%d bytes) ptr=%x\n", i,buftype, count, (int)buffer); + if(buftype == SECBUFFER_DATA) { + if(count > dstLen) return SQSSL_BUFFER_TOO_SMALL; + memcpy(dstBuf, buffer, count); + dstBuf += count; + total += count; + dstLen -= count; + } + } + + /* We ran out of steam. Hopefully this was because we've produced + a bunch'o bits from the decryption. */ + if(ret == SEC_E_OK || ret == SEC_E_INCOMPLETE_MESSAGE) { + /* Retain any remaining extra buffers and return output */ + sqCopyExtraData(ssl, ssl->sbdIn); + /* Return the total number of bytes decrypted */ + return total; + } + + if(ssl->loglevel) printf("DecryptMessage returned: %x\n", ret); + return SQSSL_GENERIC_ERROR; +} + +/* sqGetStringPropertySSL: Retrieve a string property from SSL. + Arguments: + handle - the ssl handle + propID - the property id to retrieve + Returns: The string value of the property. +*/ +char* sqGetStringPropertySSL(sqInt handle, int propID) { + sqSSL *ssl = sslFromHandle(handle); + + if(ssl == NULL) return NULL; + switch(propID) { + case SQSSL_PROP_PEERNAME: return ssl->peerName; + case SQSSL_PROP_CERTNAME: return ssl->certName; + default: + if(ssl->loglevel) printf("sqGetStringPropertySSL: Unknown property ID %d\n", propID); + return NULL; + } + return NULL; +} + +/* sqSetStringPropertySSL: Set a string property in SSL. + Arguments: + handle - the ssl handle + propID - the property id to retrieve + propName - the property string + propLen - the length of the property string + Returns: Non-zero if successful. +*/ +sqInt sqSetStringPropertySSL(sqInt handle, int propID, char *propName, sqInt propLen) { + sqSSL *ssl = sslFromHandle(handle); + char *property = NULL; + + if(ssl == NULL) return 0; + + if(propLen) { + property = calloc(1, propLen+1); + memcpy(property, propName, propLen); + }; + + if(ssl->loglevel) printf("sqSetStringPropertySSL(%d): %s\n", propID, property); + + switch(propID) { + case SQSSL_PROP_CERTNAME: ssl->certName = property; break; + default: + if(ssl->loglevel) printf("sqSetStringPropertySSL: Unknown property ID %d\n", propID); + return 0; + } + return 1; +} + +/* sqGetIntPropertySSL: Retrieve an integer property from SSL. + Arguments: + handle - the ssl handle + propID - the property id to retrieve + Returns: The integer value of the property. +*/ +int sqGetIntPropertySSL(sqInt handle, int propID) { + sqSSL *ssl = sslFromHandle(handle); + + if(ssl == NULL) return 0; + switch(propID) { + case SQSSL_PROP_SSLSTATE: return ssl->state; + case SQSSL_PROP_CERTSTATE: return ssl->certFlags; + case SQSSL_PROP_VERSION: return 1; + case SQSSL_PROP_LOGLEVEL: return ssl->loglevel; + default: + if(ssl->loglevel) printf("sqGetIntPropertySSL: Unknown property ID %d\n", propID); + return 0; + } + return 0; +} + +/* sqSetIntPropertySSL: Set an integer property in SSL. + Arguments: + handle - the ssl handle + propID - the property id to retrieve + propValue - the property value + Returns: Non-zero if successful. +*/ +sqInt sqSetIntPropertySSL(sqInt handle, sqInt propID, sqInt propValue) { + sqSSL *ssl = sslFromHandle(handle); + if(ssl == NULL) return 0; + + switch(propID) { + case SQSSL_PROP_LOGLEVEL: ssl->loglevel = propValue; break; + default: + if(ssl->loglevel) printf("sqSetIntPropertySSL: Unknown property ID %d\n", propID); + return 0; + } + return 0; +} |
Free forum by Nabble | Edit this page |