#include "config.h"

#include <boost/algorithm/string/predicate.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
#include <boost/timer/timer.hpp>

#include <algorithm>
#include <exception>
#include <iostream>

namespace fs = std::filesystem;
namespace pt = boost::property_tree;
using namespace std::string_literals;

namespace {

 // Strips a trailing ":port" suffix from a Host header value, in place.
 // Plain hostnames and IPv4 addresses lose everything from their last ':'.
 // Bracketed IPv6 literals such as "[xx::yy]:8080" are handled too: a ':' is
 // only treated as the port separator when no ']' follows it, so "[xx::yy]"
 // is left untouched.
 void RemovePortFromHostname(std::string& hostname)
 {
  const auto separator = hostname.find_last_of(":]");
  if (separator == std::string::npos)
   return; // no port and no IPv6 brackets
  if (hostname[separator] != ':')
   return; // last special char is ']': bracketed IPv6 without a port
  hostname.erase(separator);
 }

} // anonymous namespace

// Reads the XML configuration file into the member fields. It then fills in
// default socket sites, validates cross-references and builds the per-socket
// host look-up tables.
//
// filename: path of the configuration file; default_filename is used if empty.
// Throws:
//  - boost::property_tree exceptions if the file can't be parsed, or if a
//    mandatory element (user, group, threads, <path requested="...">) is
//    missing or not convertible;
//  - std::runtime_error on duplicate or unknown elements, an empty site name,
//    an unknown protocol, an invalid port spec (checked for non-root only),
//    or a socket serving an unconfigured site.
// Malformed or duplicate <auth> entries are only reported on stderr.
void Config::readConfigfile(const std::filesystem::path& filename)
{
 fs::path used_filename{filename};
 if (used_filename.empty()) {
  used_filename = default_filename;
 }

 // Parses a <port> value for the privileged-port check. Plain std::stoi
 // reports only "stoi" for an empty or non-numeric value, and silently accepts
 // trailing garbage ("80x" -> 80). Both are rejected here with a message that
 // names the value.
 const auto parse_port = [](const std::string& port) -> int {
  size_t consumed{0};
  int number{0};
  try {
   number = std::stoi(port, &consumed);
  } catch (const std::exception&) {
   consumed = 0; // std::invalid_argument / std::out_of_range
  }
  if (port.empty() || consumed != port.size())
   throw std::runtime_error("Invalid port spec: "s + port);
  return number;
 };

 pt::ptree tree;

 // Comments are dropped; leading/trailing whitespace of text content is trimmed.
 pt::read_xml(used_filename, tree, pt::xml_parser::no_comments | pt::xml_parser::trim_whitespace);

 // mandatory: get<T>(path) throws if the element is missing or not convertible
 m_user = tree.get<std::string>("webserver.user");

 m_group = tree.get<std::string>("webserver.group");

 m_threads = tree.get<int>("webserver.threads");

 // optional entries
 m_statistics_path = tree.get<std::string>("webserver.statisticspath", "/var/lib/webserver/stats.db");

 auto elements = tree.get_child_optional("webserver");
 if (elements) {
  for (const auto& element: *elements) {
   if (element.first == "plugin-directory"s) {
    if (std::find(m_plugin_directories.begin(), m_plugin_directories.end(), element.second.data()) == m_plugin_directories.end())
     m_plugin_directories.push_back(element.second.data());
    else
     throw std::runtime_error("Found double plugin-directory element: "s + element.second.data());
   } else if (element.first == "sites"s) {
    for (const auto& site: element.second) {
     if (site.first != "site"s)
      throw std::runtime_error("<site> expected in <sites>");
     std::string site_name;
     Site site_struct;
     for (const auto& x: site.second) {
      if (x.first == "name"s) {
       if (site_name == "")
        site_name = x.second.data();
       else
        throw std::runtime_error("Found double site name: "s + x.second.data());
      } else if (x.first == "host"s) {
       if (site_struct.hosts.find(x.second.data()) == site_struct.hosts.end())
        site_struct.hosts.insert(x.second.data());
       else
        throw std::runtime_error("Found double site host element: "s + x.second.data());
      } else if (x.first == "path"s) {
       Path path;
       // bind by reference: no need to copy the attribute subtree
       const auto& attrs = x.second.get_child("<xmlattr>");
       path.requested = attrs.get<std::string>("requested");
       for (const auto& param: x.second) { // get all sub-elements of <path>
        if (param.first.size() > 0 && param.first[0] != '<') { // exclude meta-elements like <xmlattr>
         if (param.first == "auth") {
          // Bad <auth> entries are reported and skipped instead of aborting.
          // NOTE(review): duplicate logins only produce a warning here, while
          // every other duplicate element is fatal - confirm this is intended.
          try {
           std::string login{param.second.get<std::string>("<xmlattr>.login")};
           if (path.auth.find(login) == path.auth.end()) {
            std::string password{param.second.get<std::string>("<xmlattr>.password")};
            path.auth[login] = password;
           } else {
            std::cerr << "Warning: Can't read auth data from config: " << "Found double auth (login): "s << login << std::endl;
           }
          } catch (const std::exception& ex) {
           std::cerr << "Warning: Can't read auth data from config: " << ex.what() << std::endl;
          }
         } else {
          if (path.params.find(param.first) == path.params.end())
           path.params[param.first] = param.second.data();
          else
           throw std::runtime_error("Found double path param: "s + param.first + ": " + param.second.data());
         }
        }
       }
       if (std::find_if(site_struct.paths.begin(), site_struct.paths.end(), [&](const Path& p){ return p.requested == path.requested;}) == site_struct.paths.end())
        site_struct.paths.push_back(path);
       else
        throw std::runtime_error("Found double path spec: "s + path.requested);
      } else if (x.first == "certpath"s) {
       if (site_struct.cert_path == "")
        site_struct.cert_path = x.second.data();
       else
        throw std::runtime_error("Found double certpath: "s + x.second.data());
      } else if (x.first == "keypath"s) {
       if (site_struct.key_path == "")
        site_struct.key_path = x.second.data();
       else
        throw std::runtime_error("Found double keypath: "s + x.second.data());
      } else
       throw std::runtime_error("Unknown element: "s + x.first);
     }
     if (site_name.empty())
      throw std::runtime_error("Empty site name");
     if (m_sites.find(site_name) == m_sites.end())
      m_sites[site_name] = site_struct;
     else
      throw std::runtime_error("Found double site spec: "s + site_name);
    }
   } else if (element.first == "sockets"s) {
    for (const auto& socket: element.second) {
     if (socket.first != "socket"s)
      throw std::runtime_error("<socket> expected in <sockets>");
     Socket socket_struct;
     for (const auto& x: socket.second) {
      if (x.first == "address"s) {
       if (socket_struct.address == "")
        socket_struct.address = x.second.data();
       else
        throw std::runtime_error("Found double address spec: "s + x.second.data());
      } else if (x.first == "port"s) {
       if (socket_struct.port == "")
        socket_struct.port = x.second.data();
       else
        throw std::runtime_error("Found double port spec: "s + x.second.data());
      } else if (x.first == "protocol"s) {
       if (x.second.data() == "http"s)
        socket_struct.protocol = SocketProtocol::HTTP;
       else if (x.second.data() == "https"s)
        socket_struct.protocol = SocketProtocol::HTTPS;
       else
        throw std::runtime_error("Unknown protocol: "s + x.second.data());
      } else if (x.first == "site"s) {
       std::string site {x.second.data()};
       if (socket_struct.serve_sites.find(site) == socket_struct.serve_sites.end()) {
        socket_struct.serve_sites.insert(site);
       } else {
        throw std::runtime_error("Site "s + site + " already defined for "s + socket_struct.address + ", port " + socket_struct.port);
       }
      } else
       throw std::runtime_error("Unknown element: "s + x.first);
     }
     // Ports below 1024 are privileged: without root they are skipped.
     // The port is only parsed (and validated) in the non-root case.
     if (geteuid() != 0 && parse_port(socket_struct.port) < 1024)
      std::cout << "Warning: Skipping privileged port " << socket_struct.port << std::endl;
     else
      m_sockets.push_back(socket_struct);
    }
   }
  }
 }

 expand_socket_sites();

 validate();

 create_look_up_table();
}

void Config::expand_socket_sites()
{
 // if no serving site is defined for a socket, serve all sites there
 for (auto& socket: m_sockets) {
  if (socket.serve_sites.empty()) {
   for (const auto& site: m_sites) {
    socket.serve_sites.insert(site.first);
   }
  }
 }
}

// just throws on inconsistency
void Config::validate()
{
 // make sure all m_sockets.serve_sites are configured in m_sites

 for (auto& socket: m_sockets) {
  for (auto& serve_site: socket.serve_sites) {
   if (m_sites.find(serve_site) == m_sites.end())
    throw std::runtime_error("Found serve_site "s + serve_site + " without configured site"s);
  }
 }
}

void Config::create_look_up_table()
{
 for (auto& socket: m_sockets) {
  for (const auto& site_name: socket.serve_sites) {
   Site& site {m_sites.at(site_name)}; // omit error check: validation previously made sure this exists
   for (const auto& host: site.hosts) {
    socket.host_lut[host] = &site;
   }
  }
 }
}

// Loads the configuration from filename (default_filename if empty) and prints
// a summary to stdout. Any exception from reading or validating the file
// propagates to the caller.
Config::Config(const std::filesystem::path& filename)
{
 this->readConfigfile(filename);
 this->dump();
}

std::string Config::User() const
{
 return m_user;
}

std::string Config::Group() const
{
 return m_group;
}

int Config::Threads() const
{
 return m_threads;
}

// Returns the statistics database location (webserver.statisticspath).
fs::path Config::statistics_path() const
{
 return this->m_statistics_path;
}

const std::vector<std::string>& Config::PluginDirectories() const
{
 return m_plugin_directories;
}

const std::unordered_map<std::string, Site>& Config::Sites() const
{
 return m_sites;
}

const std::vector<Socket>& Config::Sockets() const
{
 return m_sockets;
}

// Prints the effective configuration to stdout. It covers user, group,
// threads, the statistics path, plugin directories, each site with its hosts,
// paths, path params and key/cert, and each socket with the sites it serves.
// Auth credentials are not printed.
void Config::dump() const
{
 std::ostream& out{std::cout};

 out << "=== Configuration ===========================" << std::endl;
 out << "User: " << m_user << std::endl
     << "Group: " << m_group << std::endl;

 out << "Threads: " << m_threads << std::endl;

 out << "Statistics Path: " << statistics_path() << std::endl;

 out << "Plugin Directories:";
 for (const auto& directory: m_plugin_directories)
  out << " " << directory;
 out << std::endl;

 for (const auto& [name, site]: m_sites) {
  out << "Site: " << name << ":";
  for (const auto& host: site.hosts)
   out << " " << host;
  out << std::endl;

  if (site.paths.empty())
   out << "  Warning: No paths configured." << std::endl;

  for (const auto& path: site.paths) {
   out << "  Path: " << path.requested << std::endl;
   for (const auto& [key, value]: path.params)
    out << "    " << key << ": " << value << std::endl;
  }

  if (site.key_path != ""s) {
   out << "  Key: " << site.key_path.generic_string() << std::endl
       << "  Cert: " << site.cert_path.generic_string() << std::endl;
  }
 }

 for (const auto& socket: m_sockets) {
  const char* const protocol_name{socket.protocol == SocketProtocol::HTTP ? "HTTP" : "HTTPS"};
  out << "Socket: " << socket.address << ":" << socket.port << " (" << protocol_name << ")" << std::endl;
  out << "  Serving:";
  for (const auto& served: socket.serve_sites)
   out << " " << served;
  out << std::endl;
 }

 out << "=============================================" << std::endl;
}

// Finds the Path configuration responsible for a request.
//
// socket:         the socket the request arrived on (its host_lut selects the site)
// requested_host: Host header value; an optional ":port" suffix is ignored
// requested_path: request target, e.g. "/app/index.html?x=1"
//
// Among the site's paths, the longest `requested` prefix wins. A prefix only
// matches when it is followed in requested_path by '/', '?' or the end of the
// string, or when the prefix itself ends with '/'. So "/app" matches
// "/app/x" and "/app?y", but not "/apple".
//
// Throws std::out_of_range if the host is not served on this socket or no path matches.
const Path& Config::GetPath(const Socket& socket, const std::string& requested_host, const std::string& requested_path) const
{
 //boost::timer::auto_cpu_timer t;

 std::string host{requested_host};
 RemovePortFromHostname(host);

 const Site& site{*socket.host_lut.at(host)}; // can throw out_of_range

 const Path* result{nullptr};
 size_t path_len{0}; // length of the longest matching prefix found so far

 for (const auto& path: site.paths) {
  const std::string& prefix{path.requested};

  // Length check first: it skips prefixes that can't beat the current best.
  // It also guarantees prefix is non-empty below, so prefix.back() is valid.
  // The former prefix[size() - 1] read out of bounds for an empty prefix.
  if (prefix.size() <= path_len || !boost::starts_with(requested_path, prefix))
   continue;

  // starts_with guarantees requested_path.size() >= prefix.size(), so this
  // index is valid; at equality it yields the string's terminating '\0'.
  const char next{requested_path[prefix.size()]};
  if (next == '/' || next == '?' || next == '\0' || prefix.back() == '/') {
   path_len = prefix.size();
   result = &path;
  }
 }

 if (result == nullptr)
  throw std::out_of_range("Path not found for "s + requested_host + " " + requested_path);

 return *result;
}

bool Config::PluginIsConfigured(const std::string& name) const
{
 for (const auto& site: m_sites) {
  for (const auto& path: site.second.paths) {
   auto it{path.params.find("plugin")};
   if (it != path.params.end() && it->second == name)
    return true;
  }
 }
 return false;
}