# Copyright (c) 2024 Yuki Kimoto
# MIT License

class IO::Socket::SSL extends IO::Socket::IP {
  version "0.019";
  
  use IO::Socket::SSL::Callback::BeforeConnectSSL;
  use IO::Socket::SSL::Callback::BeforeAcceptSSL;
  
  use Net::SSLeay;
  use Net::SSLeay::SSL_CTX;
  use Net::SSLeay::SSL_METHOD;
  use Net::SSLeay::Constant as SSL;
  use Net::SSLeay::ERR;
  use Net::SSLeay::BIO;
  use Net::SSLeay::PEM;
  use Net::SSLeay::X509;
  use Net::SSLeay::X509_STORE;
  use Net::SSLeay::EVP;
  use Net::SSLeay::EVP_PKEY;
  use Net::SSLeay::Error::SSL_ERROR_WANT_READ;
  use Net::SSLeay::Error::SSL_ERROR_WANT_WRITE;
  use Net::SSLeay::Callback::Verify;
  use Net::SSLeay::Callback::PemPassword;
  
  use Fn;
  use StringBuffer;
  use Go;
  use List;
  use Re;
  
  # Enumerations
  #
  # Socket categories. configure_SSL chooses one and stores it in the SocketCategory field:
  #   SOCKET_CATEGORY_SERVER   - the Listen field is greater than 0.
  #   SOCKET_CATEGORY_CLIENT   - otherwise, the PeerAddr field is set.
  #   SOCKET_CATEGORY_ACCEPTED - otherwise, the FD field is greater than 0.
  enum {
    SOCKET_CATEGORY_CLIENT,
    SOCKET_CATEGORY_SERVER,
    SOCKET_CATEGORY_ACCEPTED,
  }
  
  # Fields
  
  # The SSL context. configure_SSL builds it for client and server sockets;
  # accept copies the listening socket's context into each accepted socket.
  has ssl_ctx : ro Net::SSLeay::SSL_CTX;
  
  # The SSL object. It is assigned by connect_SSL or accept_SSL only after the
  # handshake loop has finished without an exception.
  has ssl : ro Net::SSLeay;
  
  # Callbacks invoked by connect_SSL with ($self, $ssl) before the handshake starts.
  has before_connect_SSL_cbs : ro IO::Socket::SSL::Callback::BeforeConnectSSL[];
  
  # Callbacks invoked by accept_SSL with ($self, $ssl) before the handshake starts.
  has before_accept_SSL_cbs : ro IO::Socket::SSL::Callback::BeforeAcceptSSL[];
  
  # Undocumented Fields
  
  # One of the SOCKET_CATEGORY_* enumeration values. Set by configure_SSL.
  has SocketCategory : int;
  
  # SSL_startHandshake option (default 1). If true, configure calls connect_SSL
  # for a client socket and accept calls accept_SSL on the accepted socket.
  has SSL_startHandshake : int;
  
  # SSL_verify_mode option. init assigns it only if the option is given;
  # otherwise configure_SSL sets SSL_VERIFY_NONE for a server and SSL_VERIFY_PEER for a client.
  has SSL_verify_mode : int;
  
  # SSL_verify_callback option. Passed to set_verify of the SSL context together with SSL_verify_mode.
  has SSL_verify_callback : Net::SSLeay::Callback::Verify;
  
  # SSL_hostname option. For a client socket, passed to set_tlsext_host_name before connecting.
  # If it is not given, configure_SSL falls back to PeerAddr unless PeerAddr looks like an IP address.
  has SSL_hostname : string;
  
  # SSL_passwd_cb option. Passed to set_default_passwd_cb of the SSL context.
  has SSL_passwd_cb : Net::SSLeay::Callback::PemPassword;
  
  # SSL_ca option. Certificates added to the certificate store of the SSL context.
  # Cannot be combined with SSL_ca_file or SSL_ca_path.
  has SSL_ca : Net::SSLeay::X509[];
  
  # SSL_ca_file option. Passed to load_verify_locations as the first argument.
  has SSL_ca_file : string;
  
  # SSL_ca_path option. Passed to load_verify_locations as the second argument.
  has SSL_ca_path : string;
  
  # SSL_cert option. The first element is passed to use_certificate, and the
  # remaining elements to add_extra_chain_cert. Cannot be combined with SSL_cert_file.
  has SSL_cert : Net::SSLeay::X509[];
  
  # SSL_cert_file option. Passed to use_certificate_chain_file.
  has SSL_cert_file : string;
  
  # SSL_key option. Passed to use_PrivateKey. Cannot be combined with SSL_key_file.
  has SSL_key : Net::SSLeay::EVP_PKEY;
  
  # SSL_key_file option. Passed to use_PrivateKey_file with SSL_FILETYPE_PEM.
  has SSL_key_file : string;
  
  # SSL_crl_file option. The CRLs read from this file are added to the certificate store.
  has SSL_crl_file : string;
  
  # SSL_check_crl option (default 0). If true, X509_V_FLAG_CRL_CHECK is set on the certificate store.
  has SSL_check_crl : int;
  
  # SSL_alpn_protocols option. Passed to set_alpn_protos_with_protocols for a client
  # and to set_alpn_select_cb_with_protocols for a server.
  has SSL_alpn_protocols : string[];
  
  # Class Methods
  # Creates a new IO::Socket::SSL object, initializes its fields from $options
  # with init, and then configures it with configure.
  #
  # $options: option name/value pairs. May be undef.
  #
  # Returns the created socket.
  #
  # Exceptions thrown by init or configure are propagated.
  static method new : IO::Socket::SSL ($options : object[] = undef) {
    
    my $socket = new IO::Socket::SSL;
    
    # Store the options first, then set up the socket and the SSL context.
    $socket->init($options);
    $socket->configure;
    
    return $socket;
  }
  
  # Instance Methods
  # Returns the option names accepted by this class: the option names of the
  # super class followed by the SSL_* option names defined here.
  protected method option_names : string[] () {
    
    # Options added by IO::Socket::SSL.
    my $ssl_option_names = [
      "SSL_startHandshake",
      "SSL_verify_mode",
      "SSL_verify_callback",
      "SSL_hostname",
      "SSL_passwd_cb",
      "SSL_ca",
      "SSL_ca_file",
      "SSL_ca_path",
      "SSL_crl_file",
      "SSL_check_crl",
      "SSL_cert",
      "SSL_cert_file",
      "SSL_key",
      "SSL_key_file",
      "SSL_alpn_protocols",
    ];
    
    my $super_option_names = $self->SUPER::option_names;
    
    return Array->merge_string($super_option_names, $ssl_option_names);
  }
  
  # Initializes the fields of this object from $options.
  #
  # The super class is initialized first, the callback arrays are created
  # empty, and each SSL_* option is copied into the field of the same name.
  #
  # $options: option name/value pairs. May be undef.
  #
  # Exceptions:
  # Dies if SSL_ca is given together with SSL_ca_file or SSL_ca_path, if
  # SSL_cert is given together with SSL_cert_file, or if SSL_key is given
  # together with SSL_key_file.
  protected method init : void ($options : object[] = undef) {
    
    $self->SUPER::init($options);
    
    my $options_h = Hash->new($options);
    
    # Callback arrays
    $self->{before_connect_SSL_cbs} = new IO::Socket::SSL::Callback::BeforeConnectSSL[0];
    $self->{before_accept_SSL_cbs} = new IO::Socket::SSL::Callback::BeforeAcceptSSL[0];
    
    # SSL_verify_mode is assigned only if it is given, because configure_SSL
    # checks whether the field exists to decide on a default.
    if ($options_h->exists("SSL_verify_mode")) {
      $self->{SSL_verify_mode} = (int)$options_h->get("SSL_verify_mode");
    }
    
    # Verification, SNI, and password callback options
    $self->{SSL_verify_callback} = (Net::SSLeay::Callback::Verify)$options_h->get_or_default("SSL_verify_callback", undef);
    $self->{SSL_hostname} = (string)$options_h->get_or_default("SSL_hostname", undef);
    $self->{SSL_passwd_cb} = (Net::SSLeay::Callback::PemPassword)$options_h->get_or_default("SSL_passwd_cb", undef);
    
    # CRL options
    $self->{SSL_crl_file} = (string)$options_h->get_or_default("SSL_crl_file", undef);
    $self->{SSL_check_crl} = (int)$options_h->get_or_default("SSL_check_crl", 0);
    
    # The handshake is started automatically by default.
    $self->{SSL_startHandshake} = (int)$options_h->get_or_default("SSL_startHandshake", 1);
    
    # CA certificates: objects, or a file and/or a directory
    my $SSL_ca = (Net::SSLeay::X509[])$options_h->get_or_default("SSL_ca", undef);
    $self->{SSL_ca} = $SSL_ca;
    
    my $SSL_ca_file = (string)$options_h->get_or_default("SSL_ca_file", undef);
    $self->{SSL_ca_file} = $SSL_ca_file;
    
    my $SSL_ca_path = (string)$options_h->get_or_default("SSL_ca_path", undef);
    $self->{SSL_ca_path} = $SSL_ca_path;
    
    if ($SSL_ca && ($SSL_ca_file || $SSL_ca_path)) {
      die "If SSL_ca option is specified, SSL_ca_file option or SSL_ca_path option cannot be specified.";
    }
    
    # Own certificate: objects or a file
    my $SSL_cert = (Net::SSLeay::X509[])$options_h->get_or_default("SSL_cert", undef);
    $self->{SSL_cert} = $SSL_cert;
    
    my $SSL_cert_file = (string)$options_h->get_or_default("SSL_cert_file", undef);
    $self->{SSL_cert_file} = $SSL_cert_file;
    
    if ($SSL_cert && $SSL_cert_file) {
      die "If SSL_cert option is specified, SSL_cert_file option cannot be specified.";
    }
    
    # Private key: an object or a file
    my $SSL_key = (Net::SSLeay::EVP_PKEY)$options_h->get_or_default("SSL_key", undef);
    $self->{SSL_key} = $SSL_key;
    
    my $SSL_key_file = (string)$options_h->get_or_default("SSL_key_file", undef);
    $self->{SSL_key_file} = $SSL_key_file;
    
    if ($SSL_key && $SSL_key_file) {
      die "If SSL_key option is specified, SSL_key_file option cannot be specified.";
    }
    
    # ALPN protocols
    $self->{SSL_alpn_protocols} = (string[])$options_h->get_or_default("SSL_alpn_protocols", undef);
  }
  
  # Configures the socket with the super class, sets up SSL with configure_SSL,
  # and then starts the client handshake with connect_SSL if SSL_startHandshake
  # is true and the socket is a client socket.
  protected method configure : void () {
    
    $self->SUPER::configure;
    
    $self->configure_SSL;
    
    # Only a client socket starts its handshake here. The handshake of an
    # accepted socket is started by the accept method.
    if ($self->{SSL_startHandshake} && $self->{SocketCategory} == &SOCKET_CATEGORY_CLIENT) {
      $self->connect_SSL;
    }
  }
  
  # Detects the socket category, stores it in the SocketCategory field, and,
  # for client and server sockets, builds a Net::SSLeay::SSL_CTX object from
  # the SSL_* fields and stores it in the ssl_ctx field.
  #
  # For an accepted socket this method returns right after the category is
  # stored, without creating an SSL context.
  #
  # Exceptions:
  # Dies if no socket category can be detected from Listen, PeerAddr, and FD.
  # Exceptions thrown by the Net::SSLeay calls are propagated, except for the
  # PEM_R_NO_START_LINE error that ends the CRL reading loop.
  protected method configure_SSL : void () {
    
    my $PeerAddr = $self->{PeerAddr};
    
    my $Listen = $self->{Listen};
    
    my $FD = $self->{FD};
    
    # SocketCategory
    # Listen takes precedence over PeerAddr, and PeerAddr over FD.
    my $SocketCategory = -1;
    if ($Listen > 0) {
      $SocketCategory = &SOCKET_CATEGORY_SERVER;
    }
    elsif ($PeerAddr) {
      $SocketCategory = &SOCKET_CATEGORY_CLIENT;
    }
    elsif ($FD > 0) {
      $SocketCategory = &SOCKET_CATEGORY_ACCEPTED;
    }
    else {
      die "[Unexpected Error]A socket category cannot be detected.";
    }
    $self->{SocketCategory} = $SocketCategory;
    # An accepted socket builds no context of its own here; the accept method
    # assigns the listening socket's ssl_ctx to it.
    if ($SocketCategory == &SOCKET_CATEGORY_ACCEPTED) {
      return;
    }
    
    # ssl_method
    my $ssl_method = (Net::SSLeay::SSL_METHOD)undef;
    if ($SocketCategory == &SOCKET_CATEGORY_SERVER) {
      $ssl_method = Net::SSLeay::SSL_METHOD->TLS_server_method;
    }
    elsif ($SocketCategory == &SOCKET_CATEGORY_CLIENT) {
      $ssl_method = Net::SSLeay::SSL_METHOD->TLS_client_method;
    }
    # NOTE(review): this branch looks unreachable, because only the server and
    # client categories get past the early return above.
    else {
      $ssl_method = Net::SSLeay::SSL_METHOD->TLS_method;
    }
    
    # ssl_ctx
    my $ssl_ctx = Net::SSLeay::SSL_CTX->new($ssl_method);
    
    # SSL_verify_mode
    # The field exists only if init received the SSL_verify_mode option.
    unless (exists $self->{SSL_verify_mode}) {
      if ($SocketCategory == &SOCKET_CATEGORY_SERVER) {
        $self->{SSL_verify_mode} = SSL->SSL_VERIFY_NONE;
      }
      elsif ($SocketCategory == &SOCKET_CATEGORY_CLIENT) {
        $self->{SSL_verify_mode} = SSL->SSL_VERIFY_PEER;
      }
    }
    
    # SSL_verify_callback
    my $SSL_verify_callback = $self->{SSL_verify_callback};
    $ssl_ctx->set_verify($self->{SSL_verify_mode}, $SSL_verify_callback);
    
    # verify_param
    # The host name used for verification is PeerAddr, unless it looks like an IP address.
    # NOTE(review): any PeerAddr containing ":" is treated as an IP address, so
    # no host name is set for it. Confirm that PeerAddr never has a
    # "host:port" form. No IP address is set on the verify parameter either.
    my $verify_param = $ssl_ctx->get0_param;
    $verify_param->set_hostflags(SSL->X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
    my $PeerAddr_is_ip_address = 0;
    if ($PeerAddr) {
      # Maybe IPv6
      if (Fn->contains($PeerAddr, ":")) {
        $PeerAddr_is_ip_address = 1;
      }
      # IPv4
      elsif (Re->m($PeerAddr, "^(\d+)(?:\.(\d+)\.(\d+)\.(\d+)|[\d\.]*)$")) {
        $PeerAddr_is_ip_address = 1;
      }
      
      unless ($PeerAddr_is_ip_address) {
        $verify_param->set1_host($PeerAddr);
      }
    }
    
    # SSL_hostname
    # For a client, the host name is the SSL_hostname option, or PeerAddr if
    # the option is not given and PeerAddr does not look like an IP address.
    my $SSL_hostname = $self->{SSL_hostname};
    if ($SocketCategory == &SOCKET_CATEGORY_CLIENT) {
      unless ($SSL_hostname) {
        if ($PeerAddr && !$PeerAddr_is_ip_address) {
          $SSL_hostname = $PeerAddr;
        }
      }
      
      # The SSL object does not exist yet, so the host name is set by a
      # callback that connect_SSL runs before the handshake.
      if (length $SSL_hostname) {
        $self->add_before_connect_SSL_cb([$SSL_hostname : string] method : void ($socket : IO::Socket::SSL, $ssl : Net::SSLeay) {
          $ssl->set_tlsext_host_name($SSL_hostname);
        });
      }
    }
    
    # SSL_passwd_cb
    my $SSL_passwd_cb = $self->{SSL_passwd_cb};
    $ssl_ctx->set_default_passwd_cb($SSL_passwd_cb);
    
    # x509_store
    # The flags are accumulated here and applied once near the end of this method.
    my $x509_store = $ssl_ctx->get_cert_store;
    my $x509_store_flags = 0;
    
    # SSL_ca, SSL_ca_file, SSL_ca_path
    # If none of them is given, the default verify paths are used.
    my $SSL_ca_file = $self->{SSL_ca_file};
    my $SSL_ca_path = $self->{SSL_ca_path};
    
    my $SSL_ca = $self->{SSL_ca};
    if ($SSL_ca) {
      for (my $i = 0; $i < @$SSL_ca; $i++) {
        my $x509 = $SSL_ca->[$i];
        $x509_store->add_cert($x509);
      }
    }
    elsif ($SSL_ca_file || $SSL_ca_path) {
      $ssl_ctx->load_verify_locations($SSL_ca_file, $SSL_ca_path);
    }
    else {
      if (Sys::OS->is_windows) {
        $ssl_ctx->set_default_verify_paths_windows;
      }
      else {
        $ssl_ctx->set_default_verify_paths;
      }
    }
    
    # SSL_cert, SSL_cert_file
    # The first certificate of SSL_cert is the certificate of this socket, and
    # the remaining ones are added as extra chain certificates.
    my $SSL_cert = $self->{SSL_cert};
    my $SSL_cert_file = $self->{SSL_cert_file};
    if ($SSL_cert) {
      for (my $i = 0; $i < @$SSL_cert; $i++) {
        my $x509 = $SSL_cert->[$i];
        if ($i == 0) {
          $ssl_ctx->use_certificate($x509);
        }
        else {
          $ssl_ctx->add_extra_chain_cert($x509);
        }
      }
    }
    elsif ($SSL_cert_file) {
      $ssl_ctx->use_certificate_chain_file($SSL_cert_file);
    }
    
    # SSL_key, SSL_key_file
    my $SSL_key = $self->{SSL_key};
    my $SSL_key_file = $self->{SSL_key_file};
    if ($SSL_key) {
      $ssl_ctx->use_PrivateKey($SSL_key);
    }
    elsif ($SSL_key_file) {
      $ssl_ctx->use_PrivateKey_file($SSL_key_file, SSL->SSL_FILETYPE_PEM);
    }
    
    # SSL_crl_file
    # CRLs are read one by one and added to the store until reading fails with
    # PEM_R_NO_START_LINE (presumably the end of the PEM data - TODO confirm).
    # Any other error is rethrown.
    my $SSL_crl_file = $self->{SSL_crl_file};
    if ($SSL_crl_file) {
      my $bio = Net::SSLeay::BIO->new_file($SSL_crl_file, "r");
      while (1) {
        my $x509_crl = (Net::SSLeay::X509_CRL)undef;
        
        eval { $x509_crl = Net::SSLeay::PEM->read_bio_X509_CRL($bio); }
        
        if ($@) {
          if ($@ isa Net::SSLeay::Error::PEM_R_NO_START_LINE) {
            last;
          }
          else {
            die $@;
          }
        }
        
        $x509_store->add_crl($x509_crl);
      }
    }
    
    # SSL_check_crl
    my $SSL_check_crl = $self->{SSL_check_crl};
    if ($SSL_check_crl) {
      $x509_store_flags |= SSL->X509_V_FLAG_CRL_CHECK;
    }
    
    # SSL_alpn_protocols
    # A client sets the protocols it offers, and a server sets a select callback
    # built from its protocols.
    my $SSL_alpn_protocols = $self->{SSL_alpn_protocols};
    if ($SSL_alpn_protocols) {
      if ($SocketCategory == &SOCKET_CATEGORY_CLIENT) {
        $ssl_ctx->set_alpn_protos_with_protocols($SSL_alpn_protocols);
      }
      elsif ($SocketCategory == &SOCKET_CATEGORY_SERVER) {
        $ssl_ctx->set_alpn_select_cb_with_protocols($SSL_alpn_protocols);
      }
    }
    
    $x509_store->set_flags($x509_store_flags);
    
    $self->{ssl_ctx} = $ssl_ctx;
    
  }
  
  # Performs the client side of the SSL handshake on the file descriptor of this socket.
  #
  # A new Net::SSLeay object is created from the ssl_ctx field, the callbacks in
  # the before_connect_SSL_cbs field are called with ($self, $ssl) in index
  # order, and connect is retried until it finishes without an exception. The
  # ssl field is set only after the handshake has succeeded.
  #
  # Exceptions:
  # Any exception thrown by connect other than SSL_ERROR_WANT_READ and
  # SSL_ERROR_WANT_WRITE is rethrown.
  method connect_SSL : void () {
    
    my $fd = $self->{FD};
    
    my $ssl = Net::SSLeay->new($self->{ssl_ctx});
    $ssl->set_fd($fd);
    
    # Callbacks before the handshake
    my $cbs = (List of IO::Socket::SSL::Callback::BeforeConnectSSL)List->new_ref($self->{before_connect_SSL_cbs});
    for (my $cb_index = 0; $cb_index < @$cbs; $cb_index++) {
      $cbs->[$cb_index]->($self, $ssl);
    }
    
    # Handshake
    while (1) {
      my $timeout = $self->get_remaining_duration(1);
      
      eval { $ssl->connect; }
      
      # Finished
      unless ($@) {
        last;
      }
      
      # Wait until the file descriptor is ready for the wanted direction, then retry.
      if ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_READ) {
        Go->gosched_io_read($fd, $timeout);
      }
      elsif ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_WRITE) {
        Go->gosched_io_write($fd, $timeout);
      }
      else {
        die $@;
      }
    }
    
    $self->{ssl} = $ssl;
  }
  
  # Performs the server side of the SSL handshake on the file descriptor of this socket.
  #
  # A new Net::SSLeay object is created from the ssl_ctx field, the callbacks in
  # the before_accept_SSL_cbs field are called with ($self, $ssl) in index
  # order, and accept is retried until it finishes without an exception. The
  # ssl field is set only after the handshake has succeeded.
  #
  # Exceptions:
  # Any exception thrown by accept other than SSL_ERROR_WANT_READ and
  # SSL_ERROR_WANT_WRITE is rethrown.
  method accept_SSL : void () {
    
    my $fd = $self->{FD};
    
    my $ssl = Net::SSLeay->new($self->{ssl_ctx});
    $ssl->set_fd($fd);
    
    # Callbacks before the handshake
    my $cbs = (List of IO::Socket::SSL::Callback::BeforeAcceptSSL)List->new_ref($self->{before_accept_SSL_cbs});
    for (my $cb_index = 0; $cb_index < @$cbs; $cb_index++) {
      $cbs->[$cb_index]->($self, $ssl);
    }
    
    # Handshake
    while (1) {
      my $timeout = $self->get_remaining_duration(0);
      
      eval { $ssl->accept; }
      
      # Finished
      unless ($@) {
        last;
      }
      
      # Wait until the file descriptor is ready for the wanted direction, then retry.
      if ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_READ) {
        Go->gosched_io_read($fd, $timeout);
      }
      elsif ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_WRITE) {
        Go->gosched_io_write($fd, $timeout);
      }
      else {
        die $@;
      }
    }
    
    $self->{ssl} = $ssl;
  }
  
  # Accepts a connection with the accept method of the super class and returns
  # the accepted socket as an IO::Socket::SSL object.
  #
  # The accepted socket shares the ssl_ctx of this socket. If the
  # SSL_startHandshake field of this socket is true, accept_SSL is called on
  # the accepted socket before it is returned.
  #
  # $peer_ref: passed to the accept method of the super class as it is.
  method accept : IO::Socket::SSL ($peer_ref : Sys::Socket::Sockaddr[] = undef) {
    
    my $client_socket = (IO::Socket::SSL)$self->SUPER::accept($peer_ref);
    
    # Share the SSL context of the listening socket.
    $client_socket->{ssl_ctx} = $self->{ssl_ctx};
    
    if ($self->{SSL_startHandshake}) {
      $client_socket->accept_SSL;
    }
    
    return $client_socket;
  }
  
  # Reads data from the SSL connection into $buffer with the read method of the
  # ssl field, and returns the value returned by that method.
  #
  # $buffer, $length, and $offset are passed to the read method as they are.
  #
  # The read is retried while it throws SSL_ERROR_WANT_READ or
  # SSL_ERROR_WANT_WRITE. Any other exception is rethrown.
  method sysread : int ($buffer : mutable string, $length : int = -1, $offset : int = 0) {
    
    my $fd = $self->{FD};
    my $ssl = $self->{ssl};
    
    my $ret = -1;
    while (1) {
      my $timeout = $self->get_remaining_duration(0);
      
      eval { $ret = $ssl->read($buffer, $length, $offset); }
      
      # Finished
      unless ($@) {
        last;
      }
      
      # Wait until the file descriptor is ready for the wanted direction, then retry.
      if ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_READ) {
        Go->gosched_io_read($fd, $timeout);
      }
      elsif ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_WRITE) {
        Go->gosched_io_write($fd, $timeout);
      }
      else {
        die $@;
      }
    }
    
    return $ret;
  }
  
  # Deprecated. Prints a deprecation message with diag, and then calls sysread
  # with the same arguments and returns its return value.
  method read : int ($buffer : mutable string, $length : int = -1, $offset : int = 0) {
    
    # The spelling of "deprecated" in this message is corrected.
    diag "IO::Socket::SSL#read method is deprecated. Use sysread method instead.";
    
    return $self->sysread($buffer, $length, $offset);
  }
  
  # Writes data in $buffer to the SSL connection with the write method of the
  # ssl field, and returns the value returned by that method.
  #
  # $buffer, $length, and $offset are passed to the write method as they are.
  #
  # The write is retried while it throws SSL_ERROR_WANT_READ or
  # SSL_ERROR_WANT_WRITE. Any other exception is rethrown.
  method syswrite : int ($buffer : string, $length : int = -1, $offset : int = 0) {
    
    my $fd = $self->{FD};
    my $ssl = $self->{ssl};
    
    my $ret = -1;
    while (1) {
      my $timeout = $self->get_remaining_duration(1);
      
      eval { $ret = $ssl->write($buffer, $length, $offset); }
      
      # Finished
      unless ($@) {
        last;
      }
      
      # Wait until the file descriptor is ready for the wanted direction, then retry.
      if ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_READ) {
        Go->gosched_io_read($fd, $timeout);
      }
      elsif ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_WRITE) {
        Go->gosched_io_write($fd, $timeout);
      }
      else {
        die $@;
      }
    }
    
    return $ret;
  }
  
  # Deprecated. Prints a deprecation message with diag, and then calls syswrite
  # with the same arguments and returns its return value.
  method write : int ($buffer : string, $length : int = -1, $offset : int = 0) {
    
    # The spelling of "deprecated" in this message is corrected.
    diag "IO::Socket::SSL#write method is deprecated. Use syswrite method instead.";
    
    return $self->syswrite($buffer, $length, $offset);
  }
  
  # Shuts down the SSL connection with the shutdown method of the ssl field,
  # and returns the value returned by that method.
  #
  # The shutdown is retried while it throws SSL_ERROR_WANT_READ or
  # SSL_ERROR_WANT_WRITE. Any other exception is rethrown.
  method shutdown_SSL : int () {
    
    my $fd = $self->{FD};
    my $ssl = $self->{ssl};
    
    my $ret = -1;
    while (1) {
      my $timeout = $self->get_remaining_duration(1);
      
      eval { $ret = $ssl->shutdown; }
      
      # Finished
      unless ($@) {
        last;
      }
      
      # Wait until the file descriptor is ready for the wanted direction, then retry.
      if ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_READ) {
        Go->gosched_io_read($fd, $timeout);
      }
      elsif ($@ isa Net::SSLeay::Error::SSL_ERROR_WANT_WRITE) {
        Go->gosched_io_write($fd, $timeout);
      }
      else {
        die $@;
      }
    }
    
    return $ret;
  }
  
  # Returns the value returned by the get0_alpn_selected_return_string method
  # of the ssl field.
  method alpn_selected : string () {
    
    return $self->{ssl}->get0_alpn_selected_return_string;
  }
  
  # Returns the value returned by the get_version method of the ssl field.
  method get_sslversion : string () {
    
    return $self->{ssl}->get_version;
  }
  
  # Returns the value returned by the version method of the ssl field.
  method get_sslversion_int : int () {
    
    return $self->{ssl}->version;
  }
  
  # Returns the value returned by the get_cipher method of the ssl field.
  method get_cipher : string () {
    
    return $self->{ssl}->get_cipher;
  }
  
  # Returns the value returned by the get_servername method of the ssl field
  # called with the TLSEXT_NAMETYPE_host_name name type.
  method get_servername : string () {
    
    my $name_type = SSL->TLSEXT_NAMETYPE_host_name;
    
    return $self->{ssl}->get_servername($name_type);
  }
  
  # Returns the value returned by the get_peer_certificate method of the ssl field.
  method peer_certificate : Net::SSLeay::X509 () {
    
    return $self->{ssl}->get_peer_certificate;
  }
  
  # Returns the certificates of the peer.
  #
  # If peer_certificate returns no certificate, an empty array is returned.
  # Otherwise the result of get_peer_cert_chain is returned. For an accepted
  # socket, the peer certificate is put in front of that chain.
  method peer_certificates : Net::SSLeay::X509[] () {
    
    my $ssl = $self->{ssl};
    
    my $peer_cert = $self->peer_certificate;
    
    # No peer certificate
    unless ($peer_cert) {
      return new Net::SSLeay::X509[0];
    }
    
    my $chain = $ssl->get_peer_cert_chain;
    
    # Accepted socket: the peer certificate followed by the chain
    if ($self->{SocketCategory} == &SOCKET_CATEGORY_ACCEPTED) {
      return (Net::SSLeay::X509[])Array->merge_object([$peer_cert], $chain);
    }
    
    return $chain;
  }
  
  # Returns the value returned by the get_certificate method of the ssl field.
  method sock_certificate : Net::SSLeay::X509 () {
    
    return $self->{ssl}->get_certificate;
  }
  
  # Returns the value returned by the dump_peer_certificate method of the ssl field.
  method dump_peer_certificate : string () {
    
    return $self->{ssl}->dump_peer_certificate;
  }
  
  # Adds the callback $cb to the callbacks in the before_connect_SSL_cbs field.
  # connect_SSL calls these callbacks before the handshake.
  #
  # NOTE(review): the callback is pushed through a List created by
  # List->new_ref. Confirm that the pushed element is visible through the
  # before_connect_SSL_cbs field afterwards.
  method add_before_connect_SSL_cb : void ($cb : IO::Socket::SSL::Callback::BeforeConnectSSL) {
    
    my $cbs_list = List->new_ref($self->{before_connect_SSL_cbs});
    
    $cbs_list->push($cb);
  }
  
  # Adds the callback $cb to the callbacks in the before_accept_SSL_cbs field.
  # accept_SSL calls these callbacks before the handshake.
  #
  # NOTE(review): the callback is pushed through a List created by
  # List->new_ref. Confirm that the pushed element is visible through the
  # before_accept_SSL_cbs field afterwards.
  method add_before_accept_SSL_cb : void ($cb : IO::Socket::SSL::Callback::BeforeAcceptSSL) {
    
    my $cbs_list = List->new_ref($self->{before_accept_SSL_cbs});
    
    $cbs_list->push($cb);
  }
  
  # Not allowed in this class. Always throws an exception.
  # The class name in the message is corrected from "IO::Scoekt::SSL".
  method stat : Sys::IO::Stat () {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  # Not allowed in this class. Always throws an exception.
  # The class name in the message is corrected from "IO::Scoekt::SSL".
  method send : int ($buffer : string, $flags : int = 0, $length : int = -1, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  # Not allowed in this class. Always throws an exception.
  # The class name in the message is corrected from "IO::Scoekt::SSL".
  method sendto : int ($buffer : string, $flags : int, $to : Sys::Socket::Sockaddr, $length : int = -1, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  # Not allowed in this class. Always throws an exception.
  # The class name in the message is corrected from "IO::Scoekt::SSL".
  method recv : int ($buffer : mutable string, $length : int = -1, $flags : int = 0, $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  # Not allowed in this class. Always throws an exception.
  # The class name in the message is corrected from "IO::Scoekt::SSL".
  method recvfrom : int ($buffer : mutable string, $length : int, $flags : int, $from_ref : Sys::Socket::Sockaddr[], $offset : int = 0) {
    die "This method is not allowed in IO::Socket::SSL.";
  }
  
  # The destructor. If the socket is still opened, the SSL connection is shut
  # down (only if the ssl field is set) and the socket is closed.
  #
  # The SSL shutdown is best-effort: if shutdown_SSL throws an exception (for
  # example because the peer is already gone), the exception is reported with
  # warn and the socket is still closed. Without this guard, an exception from
  # shutdown_SSL would skip the close call and leave the socket opened.
  method DESTROY : void () {
    
    unless ($self->opened) {
      return;
    }
    
    if ($self->{ssl}) {
      eval { $self->shutdown_SSL; }
      
      if ($@) {
        warn $@;
      }
    }
    
    $self->close;
  }

}