Files
ANSCORE/anslicensing/validator.h

225 lines
6.0 KiB
C++

#pragma once
#include <memory>
#include "bitstream.h"
#include "base64.h"
#include "base32.h"
#ifdef _WIN32
#include <tchar.h>
#endif
using namespace std;
class KeyValidatorImpl {
public:
KeyValidatorImpl()
{
}
KeyValidatorImpl(const LicenseTemplateImpl* templ, const char * key = NULL)
{
SetKeyTemplate(templ);
if (key)
SetKey(key);
}
~KeyValidatorImpl()
{
}
void SetKeyTemplate(const LicenseTemplateImpl* templ)
{
m_keyTemplate = templ;
if (templ->m_validationDataSize)
{
m_validationData.Create(templ->m_validationDataSize);
for (const auto& field : m_keyTemplate->m_validationFields)
m_validationData.AddField(field.first.c_str(), field.second.type, field.second.size, field.second.offset);
}
else
m_validationData.Create(0);
}
void SetValidationData(const char * fieldName, const void * buf, int len)
{
m_validationData.Set(fieldName, buf, len);
}
void SetValidationData(const char * fieldName, const char * data)
{
m_validationData.Set(fieldName, data);
}
void SetValidationData(const char * fieldName, int data)
{
m_validationData.Set(fieldName, data);
}
void SetValidationData(const char * fieldName, int year, int month, int day)
{
m_validationData.Set(fieldName, year, month, day);
}
void SetKey(const char * licKey)
{
if (!licKey)
throw new LicensingException(STATUS_INVALID_PARAM, "invalid license key");
string key(licKey);
if (key.length() < m_keyTemplate->GetCharactersPerGroup() * m_keyTemplate->GetNumberOfGroups() +
strlen(m_keyTemplate->GetHeader()) +
strlen(m_keyTemplate->GetFooter()) +
strlen(m_keyTemplate->GetGroupSeparator()) * (m_keyTemplate->GetNumberOfGroups() - 1))
{
throw new LicensingException(STATUS_INVALID_LICENSE_KEY, "invalid license key (too short)");
}
// remove header and footer (if any)
if (!m_keyTemplate->m_header.empty())
{
key.erase(0, m_keyTemplate->m_header.length() + 2);
}
if (!m_keyTemplate->m_footer.empty())
{
key.erase(key.length() - m_keyTemplate->m_footer.length() - 2, m_keyTemplate->m_footer.length() + 2);
}
// ungroup license key
for (int i = 0, erasePos = 0; i < m_keyTemplate->m_numGroups - 1; i++)
{
erasePos += m_keyTemplate->m_charsPerGroup;
key.erase(erasePos, m_keyTemplate->m_groupSeparator.length());
}
// decode license key
switch ( m_keyTemplate->m_keyEncoding )
{
case ENCODING_BASE32X:
{
BASE32 base32;
int padLen;
int len = base32.encode_pad_length(((int)key.length() * 5 + 7) >> 3, &padLen);
if (len > (int)key.length())
key.append(len - key.length(), 'A');
if (padLen)
key.append(padLen,'=');
auto keyBuf = base32.decode(key.c_str(), (int)key.length(), &len );
if (keyBuf.empty())
throw new LicensingException(STATUS_INVALID_LICENSE_KEY);
// reverse last byte
keyBuf[ len - 1 ] = (unsigned char)(((keyBuf[ len - 1 ] * 0x0802LU & 0x22110LU) | (keyBuf[ len - 1 ] * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16);
m_keyData.Attach(keyBuf, m_keyTemplate->m_keyEncoding * m_keyTemplate->m_charsPerGroup * m_keyTemplate->m_numGroups);
}
break;
case ENCODING_BASE64X:
{
BASE64 base64;
int padLen;
int len = base64.encode_pad_length(((int)key.length() * 6 + 7) >> 3, &padLen);
if (len > (int)key.length())
key.append(len - key.length(), 'A');
if (padLen)
key.append(padLen,'=');
auto keyBuf = base64.decode(key.c_str(), (int)key.length(), &len );
if (keyBuf.empty())
throw new LicensingException(STATUS_INVALID_LICENSE_KEY);
// reverse last byte
keyBuf[ len - 1 ] = (unsigned char)(((keyBuf[ len - 1 ] * 0x0802LU & 0x22110LU) | (keyBuf[ len - 1 ] * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16);
m_keyData.Attach(keyBuf, m_keyTemplate->m_keyEncoding * m_keyTemplate->m_charsPerGroup * m_keyTemplate->m_numGroups);
}
break;
default:
throw new LicensingException(STATUS_INVALID_KEY_ENCODING);
}
for (const auto& field : m_keyTemplate->m_dataFields)
m_keyData.AddField(field.first.c_str(), field.second.type, field.second.size, field.second.offset);
}
bool IsKeyValid()
{
ECC::Verifier verifier;
BitStream signedData,
signature;
signedData.Create(m_keyTemplate->m_dataSize + m_keyTemplate->m_validationDataSize);
signedData.Clear();
if (m_keyTemplate->m_dataSize)
signedData.Write(m_keyData.GetBuffer(), m_keyTemplate->m_dataSize);
if (m_keyTemplate->m_validationDataSize)
signedData.Write(m_validationData.GetBuffer(), m_keyTemplate->m_validationDataSize);
signature.Create(m_keyTemplate->m_signatureSize);
m_keyData.GetBitStream().Seek(m_keyTemplate->m_dataSize);
m_keyData.GetBitStream().Read(signature.GetBuffer(), m_keyTemplate->m_signatureSize);
signature.ReleaseBuffer(m_keyTemplate->m_signatureSize);
signature.Seek(m_keyTemplate->m_signatureSize);
signature.ZeroPadToNextByte();
signature.ReleaseBuffer(m_keyTemplate->m_signatureSize);
// we use a different algorithm than ECDSA when the signature size must be smaller than twice the key size
if (m_keyTemplate->m_signatureSize < (m_keyTemplate->m_signatureKeySize << 1))
verifier.SetHashSize(m_keyTemplate->m_signatureSize - m_keyTemplate->m_signatureKeySize);
else
verifier.SetHashSize(0);
verifier.SetPublicKey(m_keyTemplate->m_verificationKey.get());
return verifier.Verify(signedData.GetBuffer(), (signedData.GetSize() + 7) >> 3, signature.GetBuffer(), (signature.GetSize() + 7) >> 3 , m_keyTemplate->m_signatureSize);
}
void QueryKeyData(const char * fieldName, void * buf, int * len)
{
m_keyData.Get(fieldName, buf, len);
}
int QueryIntKeyData(const char * fieldName)
{
return m_keyData.GetInt(fieldName);
}
void QueryDateKeyData(const char * fieldName, int * year, int * month, int * day)
{
m_keyData.GetDate(fieldName, year, month, day);
}
void QueryValidationData(const char * fieldName, void * buf, int * len)
{
return m_validationData.Get(fieldName, buf, len);
}
public:
const LicenseTemplateImpl* m_keyTemplate;
BitStruct m_keyData;
BitStruct m_validationData;
};