|
| 1 | +package dnscrypt |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "crypto/ed25519" |
| 7 | + "errors" |
| 8 | + "net" |
| 9 | + "strings" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/alphasoc/flightsim/simulator/encdns" |
| 13 | + dnsstamps "github.com/jedisct1/go-dnsstamps" |
| 14 | + "golang.org/x/net/dns/dnsmessage" |
| 15 | +) |
| 16 | + |
// Client is the basic DNSCrypt client. It is meant to be wrapped by providers.
type Client struct {
	// Net is the network name handed to net.Dialer.DialContext when the
	// client contacts a resolver (presumably "udp" — confirm against callers).
	Net string
}
| 21 | + |
| 22 | +// ResolverInfo contains DNSCrypt resolver information necessary for decryption/encryption. |
| 23 | +type ResolverInfo struct { |
| 24 | + SecretKey [keySize]byte // Client short-term secret key |
| 25 | + PublicKey [keySize]byte // Client short-term public key |
| 26 | + |
| 27 | + ServerPublicKey ed25519.PublicKey // Resolver public key (this key is used to validate cert signature) |
| 28 | + ServerAddress string // Server IP address |
| 29 | + ProviderName string // Provider name |
| 30 | + |
| 31 | + ResolverCert *Cert // Certificate info (obtained with the first unencrypted DNS request) |
| 32 | + SharedKey [keySize]byte // Shared key that is to be used to encrypt/decrypt messages |
| 33 | +} |
| 34 | + |
| 35 | +// findCertMagic is a bit of a hack to find the beginning of the certificate. Returns |
| 36 | +// the start of the certificate magic, or -1 if not found. |
| 37 | +func findCertMagic(b []byte) int { |
| 38 | + return bytes.Index(b, certMagic[0:]) |
| 39 | +} |
| 40 | + |
| 41 | +// fetchCert loads DNSCrypt cert from the specified server. |
| 42 | +func (c *Client) fetchCert(ctx context.Context, stamp dnsstamps.ServerStamp) (*Cert, error) { |
| 43 | + providerName := stamp.ProviderName |
| 44 | + if !strings.HasSuffix(providerName, ".") { |
| 45 | + providerName = providerName + "." |
| 46 | + } |
| 47 | + |
| 48 | + dnsReq, err := encdns.NewUDPRequest(providerName, dnsmessage.TypeTXT) |
| 49 | + if err != nil { |
| 50 | + return nil, err |
| 51 | + } |
| 52 | + d := net.Dialer{} |
| 53 | + // ctx, cancelFn := context.WithTimeout(ctx, 500*time.Millisecond) |
| 54 | + // defer cancelFn() |
| 55 | + conn, err := d.DialContext(ctx, c.Net, stamp.ServerAddrStr) |
| 56 | + if err != nil { |
| 57 | + return nil, err |
| 58 | + } |
| 59 | + defer conn.Close() |
| 60 | + _, err = conn.Write(dnsReq) |
| 61 | + if err != nil { |
| 62 | + return nil, err |
| 63 | + } |
| 64 | + b := make([]byte, 2048) |
| 65 | + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) |
| 66 | + n, err := conn.Read(b) |
| 67 | + if err != nil { |
| 68 | + return nil, err |
| 69 | + } |
| 70 | + // Check certificate response rcode==0. |
| 71 | + certMsg := dnsmessage.Message{} |
| 72 | + if err := certMsg.Unpack(b[0:n]); err != nil || certMsg.RCode != dnsmessage.RCodeSuccess { |
| 73 | + return nil, errors.New(ErrInvalidDNSResponse) |
| 74 | + } |
| 75 | + certIdx := findCertMagic(b) |
| 76 | + if certIdx == -1 { |
| 77 | + return nil, errors.New(ErrCertMagic) |
| 78 | + } |
| 79 | + certStr := b[certIdx:] |
| 80 | + cert := &Cert{} |
| 81 | + err = cert.Deserialize(certStr) |
| 82 | + if err != nil { |
| 83 | + return nil, err |
| 84 | + } |
| 85 | + return cert, nil |
| 86 | +} |
| 87 | + |
| 88 | +// Dial dials the server specified by stampStr, returning a *ResolverInfo and an error. |
| 89 | +func (c *Client) Dial(ctx context.Context, stampStr string) (*ResolverInfo, error) { |
| 90 | + stamp, err := dnsstamps.NewServerStampFromString(stampStr) |
| 91 | + if err != nil { |
| 92 | + return nil, err |
| 93 | + } |
| 94 | + if stamp.Proto != dnsstamps.StampProtoTypeDNSCrypt { |
| 95 | + return nil, errors.New(ErrInvalidDNSStamp) |
| 96 | + } |
| 97 | + resolverInfo := &ResolverInfo{} |
| 98 | + // Generate the secret/public pair. |
| 99 | + resolverInfo.SecretKey, resolverInfo.PublicKey = generateRandomKeyPair() |
| 100 | + // Set the provider properties. |
| 101 | + resolverInfo.ServerPublicKey = stamp.ServerPk |
| 102 | + resolverInfo.ServerAddress = stamp.ServerAddrStr |
| 103 | + resolverInfo.ProviderName = stamp.ProviderName |
| 104 | + cert, err := c.fetchCert(ctx, stamp) |
| 105 | + if err != nil { |
| 106 | + return nil, err |
| 107 | + } |
| 108 | + resolverInfo.ResolverCert = cert |
| 109 | + // Compute shared key that we'll use to encrypt/decrypt messages. |
| 110 | + sharedKey, err := computeSharedKey(cert.EsVersion, &resolverInfo.SecretKey, &cert.ResolverPk) |
| 111 | + if err != nil { |
| 112 | + return nil, err |
| 113 | + } |
| 114 | + resolverInfo.SharedKey = sharedKey |
| 115 | + return resolverInfo, nil |
| 116 | +} |
| 117 | + |
| 118 | +// Encrypt encrypts a DNS message using shared key from the resolver info. It returns a |
| 119 | +// []byte and an error. |
| 120 | +func (c *Client) Encrypt(m []byte, resolverInfo *ResolverInfo) ([]byte, error) { |
| 121 | + q := EncryptedQuery{ |
| 122 | + EsVersion: resolverInfo.ResolverCert.EsVersion, |
| 123 | + ClientMagic: resolverInfo.ResolverCert.ClientMagic, |
| 124 | + ClientPk: resolverInfo.PublicKey, |
| 125 | + } |
| 126 | + // query, err := m.Pack() |
| 127 | + // if err != nil { |
| 128 | + // return nil, err |
| 129 | + // } |
| 130 | + b, err := q.Encrypt(m, resolverInfo.SharedKey) |
| 131 | + if len(b) > MinMsgSize { |
| 132 | + return nil, errors.New(ErrQueryTooLarge) |
| 133 | + } |
| 134 | + |
| 135 | + return b, err |
| 136 | +} |
| 137 | + |
| 138 | +// decrypts decrypts a DNS message using a shared key from the resolver info. It returns |
| 139 | +// a []byte and an error. |
| 140 | +func (c *Client) Decrypt(b []byte, resolverInfo *ResolverInfo) ([]byte, error) { |
| 141 | + dr := EncryptedResponse{ |
| 142 | + EsVersion: resolverInfo.ResolverCert.EsVersion, |
| 143 | + } |
| 144 | + msg, err := dr.Decrypt(b, resolverInfo.SharedKey) |
| 145 | + if err != nil { |
| 146 | + return nil, err |
| 147 | + } |
| 148 | + return msg, nil |
| 149 | +} |
0 commit comments