# Copyright (c) 2024 Yuki Kimoto
# MIT License

class IO::Socket::SSL extends IO::Socket::IP {
  version "0.006";
  
  use IO::Socket::INET;
  use Net::SSLeay;
  use Net::SSLeay::SSL_CTX;
  use Net::SSLeay::SSL_METHOD;
  use Net::SSLeay::Constant as SSL;
  use Net::SSLeay::ERR;
  use Net::SSLeay::BIO;
  use Net::SSLeay::PEM;
  use StringBuffer;
  use Go;
  use Net::SSLeay::X509;
  use Net::SSLeay::EVP;
  use Fn;
  use Net::SSLeay::EVP_PKEY;
  use Net::SSLeay::Error::SSL_ERROR_WANT_READ;
  use Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
  use IO::Socket::SSL::Callback::BeforeConnectSSL;
  use IO::Socket::SSL::Callback::BeforeAcceptSSL;
  use Net::SSLeay::Callback::Verify;
  use List;
  use Regex;
  
  # SSL context built by configure_SSL from the SSL_* options.
  has ssl_ctx : ro Net::SSLeay::SSL_CTX;
  
  # SSL object of the connection; set by connect_SSL or accept_SSL.
  has ssl : ro Net::SSLeay;
  
  # Callbacks called by connect_SSL with ($self, $ssl) before the client handshake.
  has before_connect_SSL_cbs_list : ro List of IO::Socket::SSL::Callback::BeforeConnectSSL;
  
  # Callbacks called by accept_SSL with ($self, $ssl) before the server handshake.
  has before_accept_SSL_cbs_list : ro List of IO::Socket::SSL::Callback::BeforeAcceptSSL;
  
  # Values of the SSL_* options, stored by init.
  # The *_specified flags record whether the corresponding option was passed at all,
  # so that an explicit 0 can be told apart from "use the default".
  has SSL_verify_mode : int;
  
  has SSL_verify_callback : Net::SSLeay::Callback::Verify;
  
  has SSL_verify_mode_specified : int;
  
  has SSL_hostname : string;
  
  has SSL_cipher_list : string;
  
  has SSL_ciphersuites : string;
  
  has SSL_check_crl : int;
  
  has SSL_crl_file : string;
  
  has SSL_server : int;
  
  has SSL_server_specified : int;
  
  has SSL_alpn_protocols : string[];
  
  has SSL_startHandshake : int;
  
  has SSL_honor_cipher_order : int;
  
  has SSL_honor_cipher_order_specified : int;
  
  has SSL_ca_file : string;
  
  has SSL_ca_path : string;
  
  has SSL_ca : Net::SSLeay::X509[];
  
  has SSL_cert_file : string;
  
  has SSL_cert : Net::SSLeay::X509[];
  
  has SSL_key_file : string;
  
  has SSL_key : Net::SSLeay::EVP_PKEY;
  
  # Class Methods
  
  # Creates a socket from the given key/value options, then configures it.
  # Configuring a client socket (PeerAddr given) also performs the TLS handshake
  # unless SSL_startHandshake is 0.
  static method new : IO::Socket::SSL ($options : object[] = undef) {
    
    my $self = new IO::Socket::SSL;
    
    $self->init($options);
    
    $self->configure;
    
    return $self;
  }
  
  # Instance Methods
  
  # Option names accepted by this class: the parent's names plus the SSL_* names.
  protected method option_names : string[] () {
    
    my $option_names = Array->merge_string(
      $self->SUPER::option_names,
      [
        "SSL_verify_mode",
        "SSL_verify_callback",
        "SSL_hostname",
        "SSL_cipher_list",
        "SSL_ciphersuites",
        "SSL_check_crl",
        "SSL_crl_file",
        "SSL_server",
        "SSL_alpn_protocols",
        "SSL_startHandshake",
        "SSL_honor_cipher_order",
        "SSL_ca_file",
        "SSL_ca_path",
        "SSL_ca",
        "SSL_cert_file",
        "SSL_cert",
        "SSL_key_file",
        "SSL_key"
      ],
    );
    
    return $option_names;
  }
  
  # Stores the SSL_* options in fields and creates the empty callback lists.
  # Dies if a certificate/key is given both as an object and as a file.
  protected method init : void ($options : object[] = undef) {
    
    my $options_h = Hash->new($options);
    
    $self->SUPER::init($options);
    
    $self->{before_connect_SSL_cbs_list} = List->new(new IO::Socket::SSL::Callback::BeforeConnectSSL[0]);
    
    $self->{before_accept_SSL_cbs_list} = List->new(new IO::Socket::SSL::Callback::BeforeAcceptSSL[0]);
    
    if ($options_h->exists("SSL_verify_mode")) {
      $self->{SSL_verify_mode_specified} = 1;
      $self->{SSL_verify_mode} = $options_h->get_int("SSL_verify_mode");
    }
    
    $self->{SSL_verify_callback} = (Net::SSLeay::Callback::Verify)$options_h->get_or_default("SSL_verify_callback", undef);
    
    $self->{SSL_hostname} = $options_h->get_or_default_string("SSL_hostname", undef);
    
    $self->{SSL_cipher_list} = $options_h->get_or_default_string("SSL_cipher_list", "DEFAULT !EXP !MEDIUM !LOW !eNULL !aNULL !RC4 !DES !MD5 !PSK !SRP");
    
    $self->{SSL_ciphersuites} = $options_h->get_or_default_string("SSL_ciphersuites", undef);
    
    $self->{SSL_check_crl} = $options_h->get_or_default_int("SSL_check_crl", 0);
    
    # Default to undef (no CRL file) so configure_SSL loads a CRL only when a file name is given.
    $self->{SSL_crl_file} = $options_h->get_or_default_string("SSL_crl_file", undef);
    
    # Like SSL_verify_mode, an explicitly given SSL_server (even 0) overrides the
    # Listen-based default chosen in configure_SSL.
    if ($options_h->exists("SSL_server")) {
      $self->{SSL_server_specified} = 1;
      $self->{SSL_server} = $options_h->get_int("SSL_server");
    }
    
    $self->{SSL_alpn_protocols} = (string[])$options_h->get_or_default("SSL_alpn_protocols", undef);
    
    $self->{SSL_startHandshake} = $options_h->get_or_default_int("SSL_startHandshake", 1);
    
    if ($options_h->exists("SSL_honor_cipher_order")) {
      $self->{SSL_honor_cipher_order_specified} = 1;
      $self->{SSL_honor_cipher_order} = $options_h->get_int("SSL_honor_cipher_order");
    }
    
    my $SSL_ca_file = $options_h->get_or_default_string("SSL_ca_file", undef);
    $self->{SSL_ca_file} = $SSL_ca_file;
    
    my $SSL_ca_path = $options_h->get_or_default_string("SSL_ca_path", undef);
    $self->{SSL_ca_path} = $SSL_ca_path;
    
    my $SSL_ca = (Net::SSLeay::X509[])$options_h->get_or_default("SSL_ca", undef);
    $self->{SSL_ca} = $SSL_ca;
    
    if ($SSL_ca && ($SSL_ca_file || $SSL_ca_path)) {
      die "If SSL_ca option is specified, SSL_ca_file option or SSL_ca_path option cannot be specified.";
    }
    
    my $SSL_cert_file = $options_h->get_or_default_string("SSL_cert_file", undef);
    $self->{SSL_cert_file} = $SSL_cert_file;
    
    my $SSL_cert = (Net::SSLeay::X509[])$options_h->get_or_default("SSL_cert", undef);
    $self->{SSL_cert} = $SSL_cert;
    
    if ($SSL_cert && $SSL_cert_file) {
      die "If SSL_cert option is specified, SSL_cert_file option cannot be specified.";
    }
    
    my $SSL_key_file = $options_h->get_or_default_string("SSL_key_file", undef);
    $self->{SSL_key_file} = $SSL_key_file;
    
    my $SSL_key = (Net::SSLeay::EVP_PKEY)$options_h->get_or_default("SSL_key", undef);
    $self->{SSL_key} = $SSL_key;
    
    if ($SSL_key && $SSL_key_file) {
      die "If SSL_key option is specified, SSL_key_file option cannot be specified.";
    }
    
  }
  
  # Configures the underlying socket, builds the SSL context, and, for a client
  # socket (PeerAddr given), starts the handshake when SSL_startHandshake is true.
  protected method configure : void () {
    
    $self->SUPER::configure;
    
    $self->configure_SSL;
    
    # Client
    if ($self->{PeerAddr} && $self->{SSL_startHandshake}) {
      $self->connect_SSL;
    }
  }
  
  # Builds $self->{ssl_ctx} from the stored SSL_* options.
  protected method configure_SSL : void () {
    
    my $peer_addr = $self->{PeerAddr};
    
    # Server mode: the SSL_server option if it was given, otherwise implied by a positive Listen.
    my $SSL_server = 0;
    if ($self->{SSL_server_specified}) {
      $SSL_server = $self->{SSL_server};
    }
    elsif ($self->{Listen} > 0) {
      $SSL_server = 1;
    }
    
    my $ssl_method = Net::SSLeay::SSL_METHOD->TLS_method;
    
    my $ssl_ctx = Net::SSLeay::SSL_CTX->new($ssl_method);
    
    $ssl_ctx->set_cipher_list($self->{SSL_cipher_list});
    
    my $SSL_ciphersuites = $self->{SSL_ciphersuites};
    if ($SSL_ciphersuites) {
      $ssl_ctx->set_ciphersuites($SSL_ciphersuites);
    }
    
    # Peer verification: the SSL_verify_mode option if given; otherwise a client
    # verifies the server and a server does not request client certificates.
    my $SSL_verify_mode = -1;
    if ($self->{SSL_verify_mode_specified}) {
      $SSL_verify_mode = $self->{SSL_verify_mode};
    }
    elsif ($SSL_server) {
      $SSL_verify_mode = SSL->SSL_VERIFY_NONE;
    }
    else {
      $SSL_verify_mode = SSL->SSL_VERIFY_PEER;
    }
    
    $ssl_ctx->set_verify($SSL_verify_mode, $self->{SSL_verify_callback});
    
    my $ssl_options = SSL->SSL_OP_ALL | SSL->SSL_OP_SINGLE_DH_USE | SSL->SSL_OP_SINGLE_ECDH_USE;
    
    my $x509_store = $ssl_ctx->get_cert_store;
    
    # Trust anchors: explicit X509 objects, else CA file/path, else the system defaults.
    my $SSL_ca_file = $self->{SSL_ca_file};
    
    my $SSL_ca_path = $self->{SSL_ca_path};
    
    my $SSL_ca = $self->{SSL_ca};
    
    if ($SSL_ca) {
      for (my $i = 0; $i < @$SSL_ca; $i++) {
        $x509_store->add_cert($SSL_ca->[$i]);
      }
    }
    elsif ($SSL_ca_file || $SSL_ca_path) {
      $ssl_ctx->load_verify_locations($SSL_ca_file, $SSL_ca_path);
    }
    elsif (Sys::OS->is_windows) {
      $ssl_ctx->set_default_verify_paths_windows;
    }
    else {
      $ssl_ctx->set_default_verify_paths;
    }
    
    # Own certificate: a chain file, or X509 objects where the first one is the
    # leaf certificate and the rest are added as extra chain certificates.
    my $SSL_cert_file = $self->{SSL_cert_file};
    
    my $SSL_cert = $self->{SSL_cert};
    
    if ($SSL_cert_file) {
      $ssl_ctx->use_certificate_chain_file($SSL_cert_file);
    }
    elsif ($SSL_cert) {
      for (my $i = 0; $i < @$SSL_cert; $i++) {
        my $x509 = $SSL_cert->[$i];
        if ($i == 0) {
          $ssl_ctx->use_certificate($x509);
        }
        else {
          $ssl_ctx->add_extra_chain_cert($x509);
        }
      }
    }
    
    # Own private key: a PEM file or an EVP_PKEY object.
    my $SSL_key_file = $self->{SSL_key_file};
    
    my $SSL_key = $self->{SSL_key};
    
    if ($SSL_key_file) {
      $ssl_ctx->use_PrivateKey_file($SSL_key_file, SSL->SSL_FILETYPE_PEM);
    }
    elsif ($SSL_key) {
      $ssl_ctx->use_PrivateKey($SSL_key);
    }
    
    # Certificate store flags; CRL checking is enabled only on request, and a
    # CRL is loaded from SSL_crl_file (PEM) only when that option is set.
    my $x509_store_flags = SSL->X509_V_FLAG_TRUSTED_FIRST;
    if ($self->{SSL_check_crl}) {
      $x509_store_flags |= SSL->X509_V_FLAG_CRL_CHECK;
      
      my $SSL_crl_file = $self->{SSL_crl_file};
      if ($SSL_crl_file) {
        my $bio = Net::SSLeay::BIO->new_file($SSL_crl_file, "r");
        my $crl = Net::SSLeay::PEM->read_bio_X509_CRL($bio);
        $x509_store->add_crl($crl);
      }
    }
    
    $x509_store->set_flags($x509_store_flags);
    
    # ALPN: a server selects from the given protocols, a client advertises them.
    my $SSL_alpn_protocols = $self->{SSL_alpn_protocols};
    if ($SSL_alpn_protocols) {
      if ($SSL_server) {
        $ssl_ctx->set_alpn_select_cb_with_protocols($SSL_alpn_protocols);
      }
      else {
        $ssl_ctx->set_alpn_protos_with_protocols($SSL_alpn_protocols);
      }
    }
    
    # Name checked against the peer certificate: PeerAddr, as an IP address or as a host name.
    my $verify_param = $ssl_ctx->get0_param;
    
    $verify_param->set_hostflags(SSL->X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
    
    my $peer_addr_is_ip_address = 0;
    if ($peer_addr) {
      # Maybe IPv6
      if (Fn->contains($peer_addr, ":")) {
        $peer_addr_is_ip_address = 1;
      }
      # IPv4
      elsif (Regex->new("^(\d+)(?:\.(\d+)\.(\d+)\.(\d+)|[\d\.]*)$")->match($peer_addr)) {
        $peer_addr_is_ip_address = 1;
      }
      
      if ($peer_addr_is_ip_address) {
        $verify_param->set1_ip_asc($peer_addr);
      }
      else {
        $verify_param->set1_host($peer_addr);
      }
    }
    
    # SNI host name: SSL_hostname, or PeerAddr when it is a host name rather than an IP address.
    my $SSL_hostname = $self->{SSL_hostname};
    
    if (!$SSL_hostname && $peer_addr && !$peer_addr_is_ip_address) {
      $SSL_hostname = $peer_addr;
    }
    
    if ($SSL_hostname && $SSL_hostname ne "") {
      $self->add_before_connect_SSL_cb([$SSL_hostname : string] method : void ($socket : IO::Socket::SSL, $ssl : Net::SSLeay) {
        $ssl->set_tlsext_host_name($SSL_hostname);
      });
    }
    
    # Server cipher preference: the SSL_honor_cipher_order option if given,
    # otherwise on for servers (including servers implied by Listen).
    my $SSL_honor_cipher_order = 0;
    if ($self->{SSL_honor_cipher_order_specified}) {
      $SSL_honor_cipher_order = $self->{SSL_honor_cipher_order};
    }
    elsif ($SSL_server) {
      $SSL_honor_cipher_order = 1;
    }
    
    if ($SSL_honor_cipher_order) {
      $ssl_options |= SSL->SSL_OP_CIPHER_SERVER_PREFERENCE;
    }
    
    $ssl_ctx->set_options($ssl_options);
    
    $self->{ssl_ctx} = $ssl_ctx;
    
  }
  
  # Performs the client-side TLS handshake on this socket's FD and stores the
  # SSL object in $self->{ssl}. The before-connect callbacks run first.
  # SSL_ERROR_WANT_READ/WANT_WRITE are retried after waiting for the FD with Timeout;
  # any other error is rethrown.
  method connect_SSL : void () {
    
    my $timeout = $self->{Timeout};
    
    my $fd = $self->{FD};
    
    my $ssl = Net::SSLeay->new($self->{ssl_ctx});
    
    $ssl->set_fd($fd);
    
    my $before_connect_SSL_cbs_list = $self->{before_connect_SSL_cbs_list};
    my $before_connect_SSL_cbs_length = $before_connect_SSL_cbs_list->length;
    for (my $i = 0; $i < $before_connect_SSL_cbs_length; $i++) {
      my $cb = (IO::Socket::SSL::Callback::BeforeConnectSSL)$before_connect_SSL_cbs_list->get($i);
      $cb->($self, $ssl);
    }
    
    while (1) {
      eval { $ssl->connect; }
      
      if ($@) {
        my $again_read = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_READ;
        
        my $again_write = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
        
        if ($again_read) {
          Go->gosched_io_read($fd, $timeout);
          next;
        }
        elsif ($again_write) {
          Go->gosched_io_write($fd, $timeout);
          next;
        }
        else {
          die $@;
        }
      }
      else {
        last;
      }
    }
    
    $self->{ssl} = $ssl;
    
  }
  
  # Performs the server-side TLS handshake on this socket's FD and stores the
  # SSL object in $self->{ssl}. The before-accept callbacks run first.
  # SSL_ERROR_WANT_READ/WANT_WRITE are retried after waiting for the FD with Timeout;
  # any other error is rethrown.
  method accept_SSL : void () {
    
    my $fd = $self->{FD};
    
    my $timeout = $self->{Timeout};
    
    my $ssl = Net::SSLeay->new($self->{ssl_ctx});
    
    $ssl->set_fd($fd);
    
    my $before_accept_SSL_cbs_list = $self->{before_accept_SSL_cbs_list};
    my $before_accept_SSL_cbs_length = $before_accept_SSL_cbs_list->length;
    for (my $i = 0; $i < $before_accept_SSL_cbs_length; $i++) {
      my $cb = (IO::Socket::SSL::Callback::BeforeAcceptSSL)$before_accept_SSL_cbs_list->get($i);
      $cb->($self, $ssl);
    }
    
    while (1) {
      eval { $ssl->accept; }
      
      if ($@) {
        my $again_read = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_READ;
        
        my $again_write = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
        
        if ($again_read) {
          Go->gosched_io_read($fd, $timeout);
          next;
        }
        elsif ($again_write) {
          Go->gosched_io_write($fd, $timeout);
          next;
        }
        else {
          die $@;
        }
      }
      else {
        last;
      }
    }
    
    $self->{ssl} = $ssl;
  }
  
  # Accepts a connection on this listening socket and, when SSL_startHandshake
  # is true, performs the TLS handshake on the accepted socket.
  method accept : IO::Socket::SSL ($peer_ref : Sys::Socket::Sockaddr[] = undef) {
    my $client = (IO::Socket::SSL)$self->SUPER::accept($peer_ref);
    
    if ($self->{SSL_startHandshake}) {
      # The handshake must run on the accepted connection's FD, not on the listening
      # socket. It uses the listening socket's SSL context (server certificate/key)
      # and the accept callbacks that were registered on the listening socket.
      $client->{ssl_ctx} = $self->{ssl_ctx};
      $client->{before_accept_SSL_cbs_list} = $self->{before_accept_SSL_cbs_list};
      
      # The accepted socket is the server side of its TLS connection (peer_certificates relies on this).
      $client->{SSL_server} = 1;
      
      $client->accept_SSL;
    }
    
    return $client;
  }
  
  # Reads decrypted data into $buffer; returns the value of Net::SSLeay#read.
  # WANT_READ/WANT_WRITE are retried after waiting for the FD.
  method read : int ($buffer : mutable string, $length : int = -1, $offset : int = 0) {
    
    my $ssl = $self->{ssl};
    
    my $fd = $self->{FD};
    
    my $timeout = $self->{Timeout};
    
    my $read_length = -1;
    while (1) {
      eval { $read_length = $ssl->read($buffer, $length, $offset); }
      
      if ($@) {
        my $again_read = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_READ;
        
        my $again_write = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
        
        if ($again_read) {
          Go->gosched_io_read($fd, $timeout);
          next;
        }
        elsif ($again_write) {
          Go->gosched_io_write($fd, $timeout);
          next;
        }
        else {
          die $@;
        }
      }
      else {
        last;
      }
    }
    
    return $read_length;
  }
  
  # Writes $buffer through the TLS connection; returns the value of Net::SSLeay#write.
  # WANT_READ/WANT_WRITE are retried after waiting for the FD.
  method write : int ($buffer : string, $length : int = -1, $offset : int = 0) {
    
    my $ssl = $self->{ssl};
    
    my $fd = $self->{FD};
    
    my $timeout = $self->{Timeout};
    
    my $write_length = -1;
    while (1) {
      
      eval { $write_length = $ssl->write($buffer, $length, $offset); }
      
      if ($@) {
        my $again_read = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_READ;
        
        my $again_write = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
        
        if ($again_read) {
          Go->gosched_io_read($fd, $timeout);
          next;
        }
        elsif ($again_write) {
          Go->gosched_io_write($fd, $timeout);
          next;
        }
        else {
          die $@;
        }
      }
      else {
        last;
      }
    }
    
    return $write_length;
  }
  
  # Calls Net::SSLeay#shutdown once (retrying only on WANT_READ/WANT_WRITE)
  # and returns its status.
  method shutdown_SSL : int () {
    
    my $ssl = $self->{ssl};
    
    my $fd = $self->{FD};
    
    my $timeout = $self->{Timeout};
    
    my $status = -1;
    while (1) {
      
      eval { $status = $ssl->shutdown; }
      
      if ($@) {
        my $again_read = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_READ;
        
        my $again_write = eval_error_id isa_error Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
        
        if ($again_read) {
          Go->gosched_io_read($fd, $timeout);
          next;
        }
        elsif ($again_write) {
          Go->gosched_io_write($fd, $timeout);
          next;
        }
        else {
          die $@;
        }
      }
      else {
        last;
      }
    }
    
    return $status;
  }
  
  # Shuts down the TLS session (client and accepted sockets only) and closes the socket.
  # The socket is closed even if the TLS shutdown fails; the shutdown error is
  # rethrown after the close so the caller still sees it.
  method close : void () {
    
    my $shutdown_error = (string)undef;
    
    # Only client sockets and accepted sockets have a TLS session to shut down.
    if ($self->{ssl} && !($self->{Listen} > 0)) {
      eval { $self->shutdown_SSL; }
      
      $shutdown_error = $@;
    }
    
    $self->SUPER::close;
    
    if ($shutdown_error) {
      die $shutdown_error;
    }
  }
  
  # Returns the result of Net::SSLeay#dump_peer_certificate for this connection.
  method dump_peer_certificate : string () {
    return $self->{ssl}->dump_peer_certificate;
  }
  
  # Returns the ALPN protocol selected for this connection.
  method alpn_selected : string () {
    return $self->{ssl}->get0_alpn_selected_return_string;
  }
  
  # Returns the protocol version name of this connection (Net::SSLeay#get_version).
  method get_sslversion : string () {
    return $self->{ssl}->get_version;
  }
  
  # Returns the protocol version number of this connection (Net::SSLeay#version).
  method get_sslversion_int : int () {
    return $self->{ssl}->version;
  }
  
  # Returns the cipher name of this connection (Net::SSLeay#get_cipher).
  method get_cipher : string () {
    return $self->{ssl}->get_cipher;
  }
  
  # Returns the SNI host name of this connection.
  method get_servername : string () {
    return $self->{ssl}->get_servername(SSL->TLSEXT_NAMETYPE_host_name);
  }
  
  # Returns the peer's certificate (Net::SSLeay#get_peer_certificate).
  method peer_certificate : Net::SSLeay::X509 () {
    return $self->{ssl}->get_peer_certificate;
  }
  
  # Returns the peer's certificate chain with the peer certificate first,
  # or an empty array if the peer presented no certificate.
  method peer_certificates : Net::SSLeay::X509[] () {
    
    my $ssl = $self->{ssl};
    
    my $x509_peer = $self->peer_certificate;
    
    unless ($x509_peer) {
      return new Net::SSLeay::X509[0];
    }
    
    my $x509s_original = $ssl->get_peer_cert_chain;
    
    # On the server side OpenSSL's peer chain does not contain the peer
    # certificate itself, so it is prepended here.
    if ($self->{SSL_server}) {
      return (Net::SSLeay::X509[])Array->merge_object([$x509_peer], $x509s_original);
    }
    
    return $x509s_original;
  }
  
  # Returns this side's own certificate (Net::SSLeay#get_certificate).
  method sock_certificate : Net::SSLeay::X509 () {
    return $self->{ssl}->get_certificate;
  }
  
  # Registers a callback called by connect_SSL before the client handshake.
  method add_before_connect_SSL_cb : void ($cb : IO::Socket::SSL::Callback::BeforeConnectSSL) {
    $self->{before_connect_SSL_cbs_list}->push($cb);
  }
  
  # Registers a callback called by accept_SSL before the server handshake.
  # For a listening socket, accept passes these callbacks to each accepted socket.
  method add_before_accept_SSL_cb : void ($cb : IO::Socket::SSL::Callback::BeforeAcceptSSL) {
    $self->{before_accept_SSL_cbs_list}->push($cb);
  }
  
  # The following parent methods bypass TLS and are therefore disabled.
  
  method stat : Sys::IO::Stat () {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  method send : int ($buffer : string, $flags : int = 0, $length : int = -1, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  method sendto : int ($buffer : string, $flags : int, $to : Sys::Socket::Sockaddr, $length : int = -1, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  method recv : int ($buffer : mutable string, $length : int = -1, $flags : int = 0, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  method recvfrom : int ($buffer : mutable string, $length : int, $flags : int, $from_ref : Sys::Socket::Sockaddr[], $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
}