/*
 *  base64util.cc
 *  Base64Util
 *
 *  Created by Andrew Choi on 10/11/04.
 *  Copyright 2004 Andrew Choi. Perl Artistic License.
 *
 */

#include <sstream>
#include <iomanip>

#include "base64util.h"

// Encodes one group of three bytes as four base64 characters.
//
// The 24 input bits (c0 most significant) are split into four 6-bit indices,
// most significant first, each mapped through the standard base64 alphabet.
static std::string Encode3To4Base64(unsigned char c0, unsigned char c1, unsigned char c2)
{
  static const char alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  // Pack the three bytes into a single 24-bit value.
  const unsigned long group = (static_cast<unsigned long>(c0) << 16)
                            | (static_cast<unsigned long>(c1) << 8)
                            | static_cast<unsigned long>(c2);

  std::string encoded;
  for (int shift = 18; shift >= 0; shift -= 6)
    encoded += alphabet[(group >> shift) & 0x3f];

  return encoded;
}

// Encodes the bytes of s as base64 using the standard alphabet
// (A-Z, a-z, 0-9, '+', '/').
//
// The output is padded with '=' so its length is a multiple of 4. An empty
// input yields an empty string.
std::string EncodeBase64(std::string s)
{
  const std::string::size_type length = s.length();
  const std::string::size_type wholeGroupBytes = length / 3 * 3;

  std::string result;
  // Every started group of 3 input bytes produces 4 output characters.
  result.reserve((length + 2) / 3 * 4);

  // Index with the string's own size type. An int index would mix signed and
  // unsigned arithmetic and overflow for inputs longer than INT_MAX.
  std::string::size_type i = 0;
  for (; i < wholeGroupBytes; i += 3)
    result += Encode3To4Base64(s[i], s[i + 1], s[i + 2]);

  const std::string::size_type remaining = length - i;   // 0, 1 or 2
  if (remaining > 0) {
    // Zero-extend the leftover bytes to a full group and keep only the
    // characters that carry input bits (8 bits per byte, 6 bits per
    // character, rounded up). Pad the rest of the group with '='.
    unsigned char tail[3] = { 0, 0, 0 };
    for (std::string::size_type k = 0; k < remaining; ++k)
      tail[k] = s[i + k];

    const std::string::size_type kept = (remaining * 8 + 5) / 6;
    result += Encode3To4Base64(tail[0], tail[1], tail[2]).substr(0, kept);
    result.append(4 - kept, '=');
  }

  return result;
}

// Maps a single base64 character back to its 6-bit value.
//
// The pad character '=' deliberately maps to 0 so that a padded final group
// can be decoded without error. The bits it contributes are never used in the
// decoded output.
//
// Assumes an ASCII-compatible execution character set.
//
// Throws std::invalid_argument for any character outside the base64
// alphabet. The message carries the character's two-digit hex code.
static unsigned char Decode1Base64(unsigned char c)
{
  if (c >= 'A' && c <= 'Z')
    return static_cast<unsigned char>(c - 'A');
  if (c >= 'a' && c <= 'z')
    return static_cast<unsigned char>(c - 'a' + 26);
  if (c >= '0' && c <= '9')
    return static_cast<unsigned char>(c - '0' + 52);

  switch (c) {
  case '+': return 0x3e;
  case '/': return 0x3f;
  case '=': return 0x00;
  default:  break;
  }

  std::ostringstream hexCode;
  hexCode << std::hex << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(c & 0xff);
  throw std::invalid_argument("Invalid character in base64 string to decode (hex value = " + hexCode.str() + ")");
}

// Decodes a group of four base64 characters into three bytes.
//
// s4 is expected to hold at least 4 characters; only the first four are
// read. '=' decodes as zero bits, so a padded group still yields three bytes
// and the caller is responsible for trimming them.
//
// Throws std::invalid_argument (via Decode1Base64) on the first character,
// scanning left to right, that is not in the base64 alphabet.
static std::string Decode4To3Base64(std::string s4)
{
  // Accumulate the four 6-bit values into one 24-bit group.
  unsigned long group = 0;
  for (int k = 0; k < 4; ++k)
    group = (group << 6) | Decode1Base64(s4[k]);

  // Split the group back into three bytes, most significant first.
  std::string bytes;
  bytes += static_cast<unsigned char>((group >> 16) & 0xff);
  bytes += static_cast<unsigned char>((group >> 8) & 0xff);
  bytes += static_cast<unsigned char>(group & 0xff);

  return bytes;
}

// Decodes the final 4-character group of a base64 string, honouring trailing
// '=' padding.
//
// s4 must contain exactly 4 characters. The accepted forms are:
//   "xxxx" -> 3 bytes
//   "xxx=" -> 2 bytes
//   "xx==" -> 1 byte
//
// Throws std::invalid_argument in these cases:
//   - s4 contains a character outside the base64 alphabet;
//   - s4 ends in more than two '=' characters;
//   - a '=' is followed by a non-pad character (e.g. "ab=c").
static std::string DecodeLast4To3Base64(std::string s4)
{
  // Decode first so that an invalid character is reported ahead of a padding
  // problem.
  std::string s3 = Decode4To3Base64(s4);

  // Length of the run of '=' at the end of the group.
  std::string::size_type padCount = 0;
  while (padCount < 4 && s4[3 - padCount] == '=')
    ++padCount;

  // A group with 3 or 4 pads would carry fewer than 8 data bits.
  if (padCount > 2) {
    std::ostringstream oss;
    oss << padCount;
    throw std::invalid_argument("Invalid number of pad characters (" + oss.str() + ") at the end of base64 string");
  }

  // Any '=' before the trailing run is misplaced padding.
  if (s4.find('=') < 4 - padCount)
    throw std::invalid_argument("Invalid pad character in last group of base64 string (\"..." + s4 + "\")");

  // Each pad character stands for one missing byte.
  return std::string(s3, 0, 3 - padCount);
}

// Decodes a padded base64 string (standard alphabet) back into raw bytes.
//
// An empty input decodes to an empty string, mirroring EncodeBase64("").
//
// Throws std::invalid_argument in these cases:
//   - the length is not a multiple of 4;
//   - the string contains a character outside the alphabet;
//   - a '=' appears outside the trailing padding of the last group;
//   - the padding is malformed.
//
// NOTE(review): the dynamic exception specification presumably has to match
// the declaration in base64util.h. It is ill-formed from C++17 on, and any
// other exception escaping this function (e.g. std::bad_alloc) would invoke
// std::unexpected().
std::string DecodeBase64(std::string s) throw (std::invalid_argument)
{
  std::string result;

  // The empty string encodes empty data. Handle it up front: the loop bound
  // and last-group extraction below compute s.length() - 4, which would wrap
  // around for an empty input.
  if (s.empty())
    return result;

  if (s.length() % 4 != 0)
    throw std::invalid_argument("Length of base64 string to decode (\"" + s + "\") is not a multiple of 4");

  const std::string::size_type lastGroup = s.length() - 4;
  result.reserve(s.length() / 4 * 3);

  // Every group except the last must be free of padding.
  for (std::string::size_type i = 0; i < lastGroup; i += 4) {
    std::string sub(s, i, 4);
    if (sub.find('=') != std::string::npos)
      throw std::invalid_argument("Invalid pad character in middle of base64 string (\"..." + sub + "...\")");
    result += Decode4To3Base64(sub);
  }

  result += DecodeLast4To3Base64(std::string(s, lastGroup, 4));

  return result;
}

// Encodes one group of five bytes as eight base32 characters.
//
// The 40 input bits (c0 most significant) are emitted as eight 5-bit indices,
// most significant first, each mapped through the RFC 4648 base32 alphabet.
static std::string Encode5To8Base32(unsigned char c0, unsigned char c1, unsigned char c2, unsigned char c3, unsigned char c4)
{
  static const char alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

  const unsigned char input[5] = { c0, c1, c2, c3, c4 };

  // Shift bytes into a small bit buffer and drain five bits at a time. Bits
  // shifted out of the top of the buffer have already been emitted. Unsigned
  // overflow is well-defined, so the buffer may wrap freely.
  unsigned int buffer = 0;
  int bitsHeld = 0;
  std::string encoded;
  for (int k = 0; k < 5; ++k) {
    buffer = (buffer << 8) | input[k];
    bitsHeld += 8;
    while (bitsHeld >= 5) {
      bitsHeld -= 5;
      encoded += alphabet[(buffer >> bitsHeld) & 0x1f];
    }
  }

  return encoded;
}

// Encodes the bytes of s as base32 using the RFC 4648 alphabet (A-Z, 2-7).
//
// The output is padded with '=' so its length is a multiple of 8. An empty
// input yields an empty string.
std::string EncodeBase32(std::string s)
{
  const std::string::size_type length = s.length();
  const std::string::size_type wholeGroupBytes = length / 5 * 5;

  std::string result;
  // Every started group of 5 input bytes produces 8 output characters.
  result.reserve((length + 4) / 5 * 8);

  // Index with the string's own size type. An int index would mix signed and
  // unsigned arithmetic and overflow for inputs longer than INT_MAX.
  std::string::size_type i = 0;
  for (; i < wholeGroupBytes; i += 5)
    result += Encode5To8Base32(s[i], s[i + 1], s[i + 2], s[i + 3], s[i + 4]);

  const std::string::size_type remaining = length - i;   // 0 to 4
  if (remaining > 0) {
    // Zero-extend the leftover bytes to a full group and keep only the
    // characters that carry input bits (8 bits per byte, 5 bits per
    // character, rounded up): 1->2, 2->4, 3->5, 4->7. Pad the rest with '='.
    unsigned char tail[5] = { 0, 0, 0, 0, 0 };
    for (std::string::size_type k = 0; k < remaining; ++k)
      tail[k] = s[i + k];

    const std::string::size_type kept = (remaining * 8 + 4) / 5;
    result += Encode5To8Base32(tail[0], tail[1], tail[2], tail[3], tail[4]).substr(0, kept);
    result.append(8 - kept, '=');
  }

  return result;
}

// Maps a single base32 character back to its 5-bit value.
//
// Only the uppercase RFC 4648 alphabet (A-Z, 2-7) is accepted. The pad
// character '=' maps to 0 so that a padded final group can be decoded; its
// bits are discarded by the caller.
//
// Assumes an ASCII-compatible execution character set.
//
// Throws std::invalid_argument for any other character. The message carries
// the character's two-digit hex code.
static unsigned char Decode1Base32(unsigned char c)
{
  if (c >= 'A' && c <= 'Z')
    return static_cast<unsigned char>(c - 'A');
  if (c >= '2' && c <= '7')
    return static_cast<unsigned char>(c - '2' + 26);
  if (c == '=')
    return 0x00;

  std::ostringstream hexCode;
  hexCode << std::hex << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(c & 0xff);
  throw std::invalid_argument("Invalid character in base32 string to decode (hex value = " + hexCode.str() + ")");
}

// Decodes a group of eight base32 characters into five bytes.
//
// s8 is expected to hold at least 8 characters; only the first eight are
// read. '=' decodes as zero bits, so a padded group still yields five bytes
// and the caller is responsible for trimming them.
//
// Throws std::invalid_argument (via Decode1Base32) on the first character,
// scanning left to right, that is not in the base32 alphabet.
static std::string Decode8To5Base32(std::string s8)
{
  // Feed 5 bits per character into a small bit buffer and emit a byte
  // whenever at least 8 bits are pending. Adding 5 bits to at most 7 pending
  // bits never reaches 16, so at most one byte is emitted per character.
  unsigned int buffer = 0;
  int pending = 0;
  std::string bytes;
  for (int k = 0; k < 8; ++k) {
    buffer = (buffer << 5) | Decode1Base32(s8[k]);
    pending += 5;
    if (pending >= 8) {
      pending -= 8;
      bytes += static_cast<unsigned char>((buffer >> pending) & 0xff);
    }
  }

  return bytes;
}

// Decodes the final 8-character group of a base32 string, honouring trailing
// '=' padding.
//
// s8 must contain exactly 8 characters. The only valid trailing pad counts
// are 0, 1, 3, 4 and 6, which leave 5, 4, 3, 2 and 1 decoded bytes
// respectively.
//
// Throws std::invalid_argument in these cases:
//   - s8 contains a character outside the base32 alphabet;
//   - the trailing run of '=' has any other length (2, 5, 7 or 8);
//   - a '=' is followed by a non-pad character (e.g. "ABCDE=G=").
static std::string DecodeLast8To5Base32(std::string s8)
{
  // Decode first so that an invalid character is reported ahead of a padding
  // problem.
  std::string s5 = Decode8To5Base32(s8);

  // Length of the run of '=' at the end of the group.
  std::string::size_type padCount = 0;
  while (padCount < 8 && s8[7 - padCount] == '=')
    ++padCount;

  std::string::size_type byteCount = 0;
  switch (padCount) {
  case 0: byteCount = 5; break;
  case 1: byteCount = 4; break;
  case 3: byteCount = 3; break;
  case 4: byteCount = 2; break;
  case 6: byteCount = 1; break;
  default: {
    std::ostringstream oss;
    oss << padCount;
    throw std::invalid_argument("Invalid number of pad characters (" + oss.str() + ") at the end of base32 string");
  }
  }

  // Any '=' before the trailing run is misplaced padding.
  if (s8.find('=') < 8 - padCount)
    throw std::invalid_argument("Invalid pad character in last group of base32 string (\"..." + s8 + "\")");

  return std::string(s5, 0, byteCount);
}

// Decodes a padded base32 string (RFC 4648 uppercase alphabet) back into raw
// bytes.
//
// An empty input decodes to an empty string, mirroring EncodeBase32("").
//
// Throws std::invalid_argument in these cases:
//   - the length is not a multiple of 8;
//   - the string contains a character outside the alphabet;
//   - a '=' appears outside the trailing padding of the last group;
//   - the padding is malformed.
//
// NOTE(review): the dynamic exception specification presumably has to match
// the declaration in base64util.h. It is ill-formed from C++17 on, and any
// other exception escaping this function (e.g. std::bad_alloc) would invoke
// std::unexpected().
std::string DecodeBase32(std::string s) throw (std::invalid_argument)
{
  std::string result;

  // The empty string encodes empty data. Handle it up front: the loop bound
  // and last-group extraction below compute s.length() - 8, which would wrap
  // around for an empty input.
  if (s.empty())
    return result;

  if (s.length() % 8 != 0)
    throw std::invalid_argument("Length of base32 string to decode (\"" + s + "\") is not a multiple of 8");

  const std::string::size_type lastGroup = s.length() - 8;
  result.reserve(s.length() / 8 * 5);

  // Every group except the last must be free of padding.
  for (std::string::size_type i = 0; i < lastGroup; i += 8) {
    std::string sub(s, i, 8);
    if (sub.find('=') != std::string::npos)
      throw std::invalid_argument("Invalid pad character in middle of base32 string (\"..." + sub + "...\")");
    result += Decode8To5Base32(sub);
  }

  result += DecodeLast8To5Base32(std::string(s, lastGroup, 8));

  return result;
}


