diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c index 778849e..fb20c6a 100644 --- a/src/backend/libpq/auth.c +++ b/src/backend/libpq/auth.c @@ -1216,8 +1216,7 @@ pg_SSPI_recvauth(Port *port) StringInfoData buf; SECURITY_STATUS r; CredHandle sspicred; - CtxtHandle *sspictx = NULL, - newctx; + CtxtHandle sspictx; TimeStamp expiry; ULONG contextattr; SecBufferDesc inbuf; @@ -1233,6 +1232,7 @@ pg_SSPI_recvauth(Port *port) DWORD domainnamesize = sizeof(domainname); SID_NAME_USE accountnameuse; HMODULE secur32; + bool isfirstcall = true; QUERY_SECURITY_CONTEXT_TOKEN_FN _QuerySecurityContextToken; /* @@ -1313,15 +1313,17 @@ pg_SSPI_recvauth(Port *port) (unsigned int) buf.len); r = AcceptSecurityContext(&sspicred, - sspictx, + (isfirstcall) ? NULL : &sspictx, &inbuf, ASC_REQ_ALLOCATE_MEMORY, SECURITY_NETWORK_DREP, - &newctx, + &sspictx, &outbuf, &contextattr, NULL); + isfirstcall = false; + /* input buffer no longer used */ pfree(buf.data); @@ -1343,26 +1345,12 @@ pg_SSPI_recvauth(Port *port) if (r != SEC_E_OK && r != SEC_I_CONTINUE_NEEDED) { - if (sspictx != NULL) - { - DeleteSecurityContext(sspictx); - free(sspictx); - } + DeleteSecurityContext(&sspictx); FreeCredentialsHandle(&sspicred); pg_SSPI_error(ERROR, _("could not accept SSPI security context"), r); } - if (sspictx == NULL) - { - sspictx = malloc(sizeof(CtxtHandle)); - if (sspictx == NULL) - ereport(ERROR, - (errmsg("out of memory"))); - - memcpy(sspictx, &newctx, sizeof(CtxtHandle)); - } - if (r == SEC_I_CONTINUE_NEEDED) elog(DEBUG4, "SSPI continue needed"); @@ -1401,7 +1389,7 @@ pg_SSPI_recvauth(Port *port) (int) GetLastError()))); } - r = (_QuerySecurityContextToken) (sspictx, &token); + r = (_QuerySecurityContextToken) (&sspictx, &token); if (r != SEC_E_OK) { FreeLibrary(secur32); @@ -1415,8 +1403,7 @@ pg_SSPI_recvauth(Port *port) * No longer need the security context, everything from here on uses the * token instead. */ - DeleteSecurityContext(sspictx); - free(sspictx); + DeleteSecurityContext(&sspictx); if (!GetTokenInformation(token, TokenUser, NULL, 0, &retlen) && GetLastError() != 122) ereport(ERROR,