1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
	// Connection identity; presumably fixed at construction (the
	// constructors are not visible here).
	conn        net.Conn
	isClient    bool
	handshakeFn func(context.Context) error // handshake routine invoked by handshakeContext
	quic        *quicState                  // non-nil when the connection runs over a QUIC transport

	// isHandshakeComplete is true while the connection is transferring
	// application data. It is read without holding any lock (for example by
	// Write and Close) and is cleared before a client renegotiation.
	isHandshakeComplete atomic.Bool

	// handshakeMutex serializes handshakes. handshakeErr is the sticky
	// result of the most recent handshake.
	handshakeMutex sync.Mutex
	handshakeErr   error
	vers           uint16  // negotiated TLS version; 0 before negotiation
	haveVers       bool    // whether vers has been negotiated yet
	config         *Config // configuration in effect for this connection

	// handshakes counts successfully completed handshakes, including
	// client-side renegotiations.
	handshakes      int
	extMasterSecret bool
	didResume       bool // NOTE(review): presumably set when a session was resumed
	didHRR          bool // NOTE(review): presumably set when a HelloRetryRequest occurred
	cipherSuite     uint16
	curveID         CurveID
	peerSigAlg      SignatureScheme
	ocspResponse    []byte
	scts            [][]byte
	peerCertificates []*x509.Certificate

	// verifiedChains holds certificate chains; NOTE(review): presumably the
	// chains built during verification rather than those presented by the peer.
	verifiedChains [][]*x509.Certificate

	serverName string

	// secureRenegotiation: NOTE(review): presumably records whether the
	// peer supports the secure renegotiation extension.
	secureRenegotiation bool

	// ekm: NOTE(review): presumably an exported-keying-material function.
	ekm func(label string, context []byte, length int) ([]byte, error)

	resumptionSecret []byte
	echAccepted      bool

	ticketKeys []ticketKey

	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the single close_notify attempt made
	// by closeNotify; closeNotifySent records that the attempt happened.
	// Both are accessed under c.out's lock.
	closeNotifyErr error

	closeNotifySent bool

	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol string

	// Record layer state.
	in, out  halfConn
	rawInput bytes.Buffer // raw bytes read from conn, not yet parsed into records
	input    bytes.Reader // decrypted application data waiting to be returned by Read
	hand     bytes.Buffer // handshake bytes waiting to be parsed into messages
	buffering bool        // when true, write appends to sendBuf instead of writing to conn
	sendBuf  []byte       // records buffered while buffering is true; sent by flush

	// bytesSent counts bytes written to conn; packetsSent counts calls to
	// maxPayloadSizeForWrite that sized an application-data record. Both
	// drive dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records or messages that did not advance
	// the protocol; it is capped at maxUselessRecords.
	retryCount int

	// activeCall indicates whether Close has been called in the low bit,
	// and the number of outstanding Write calls in the rest of the bits.
	activeCall atomic.Int32

	tmp [16]byte // scratch space, e.g. for the two alert bytes in sendAlertLocked
}
125
126
127
128
129
130
131 func (c *Conn) LocalAddr() net.Addr {
132 return c.conn.LocalAddr()
133 }
134
135
136 func (c *Conn) RemoteAddr() net.Addr {
137 return c.conn.RemoteAddr()
138 }
139
140
141
142
143 func (c *Conn) SetDeadline(t time.Time) error {
144 return c.conn.SetDeadline(t)
145 }
146
147
148
149 func (c *Conn) SetReadDeadline(t time.Time) error {
150 return c.conn.SetReadDeadline(t)
151 }
152
153
154
155
156 func (c *Conn) SetWriteDeadline(t time.Time) error {
157 return c.conn.SetWriteDeadline(t)
158 }
159
160
161
162
163 func (c *Conn) NetConn() net.Conn {
164 return c.conn
165 }
166
167
168
// A halfConn represents one direction of the record layer connection:
// Conn.in for receiving and Conn.out for sending.
type halfConn struct {
	// Locked by Conn methods via c.in.Lock / c.out.Lock.
	sync.Mutex

	err     error  // sticky error recorded by setErrorLocked
	version uint16 // protocol version
	cipher  any    // cipher.Stream, aead, or cbcMode; nil until keys are installed
	mac     hash.Hash
	seq     [8]byte // 64-bit big-endian record sequence number (see incSeq)

	// scratchBuf is reused for AEAD additional data (8-byte seq + 3 header
	// bytes + 2-byte length = 13) and as MAC scratch space.
	scratchBuf [13]byte

	nextCipher any       // pending cipher, installed by changeCipherSpec
	nextMac    hash.Hash // pending MAC, installed by changeCipherSpec

	level         QUICEncryptionLevel // encryption level recorded by setTrafficSecret
	trafficSecret []byte              // current TLS 1.3 traffic secret
}
186
// permanentError wraps a net.Error so that it is never reported as
// temporary. halfConn.setErrorLocked uses it to make network failures
// sticky.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message.
func (e *permanentError) Error() string {
	msg := e.err.Error()
	return msg
}

// Unwrap exposes the wrapped net.Error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	var inner error = e.err
	return inner
}

// Timeout forwards the wrapped error's Timeout result.
func (e *permanentError) Timeout() bool {
	timedOut := e.err.Timeout()
	return timedOut
}

// Temporary always reports false, regardless of the wrapped error.
func (e *permanentError) Temporary() bool {
	const temporary = false
	return temporary
}
195
196 func (hc *halfConn) setErrorLocked(err error) error {
197 if e, ok := err.(net.Error); ok {
198 hc.err = &permanentError{err: e}
199 } else {
200 hc.err = err
201 }
202 return hc.err
203 }
204
205
206
207 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
208 hc.version = version
209 hc.nextCipher = cipher
210 hc.nextMac = mac
211 }
212
213
214
215 func (hc *halfConn) changeCipherSpec() error {
216 if hc.nextCipher == nil || hc.version == VersionTLS13 {
217 return alertInternalError
218 }
219 hc.cipher = hc.nextCipher
220 hc.mac = hc.nextMac
221 hc.nextCipher = nil
222 hc.nextMac = nil
223 for i := range hc.seq {
224 hc.seq[i] = 0
225 }
226 return nil
227 }
228
229
230
231
232 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
233 hc.trafficSecret = secret
234 hc.level = level
235 key, iv := suite.trafficKey(secret)
236 hc.cipher = suite.aead(key, iv)
237 for i := range hc.seq {
238 hc.seq[i] = 0
239 }
240 }
241
242
243 func (hc *halfConn) incSeq() {
244 for i := 7; i >= 0; i-- {
245 hc.seq[i]++
246 if hc.seq[i] != 0 {
247 return
248 }
249 }
250
251
252
253
254 panic("TLS: sequence number wraparound")
255 }
256
257
258
259
260 func (hc *halfConn) explicitNonceLen() int {
261 if hc.cipher == nil {
262 return 0
263 }
264
265 switch c := hc.cipher.(type) {
266 case cipher.Stream:
267 return 0
268 case aead:
269 return c.explicitNonceLen()
270 case cbcMode:
271
272 if hc.version >= VersionTLS11 {
273 return c.BlockSize()
274 }
275 return 0
276 default:
277 panic("unknown cipher type")
278 }
279 }
280
281
282
283
// extractPadding returns, in constant time, the number of trailing bytes
// (padding bytes plus the padding-length byte) to remove from payload.
// good is 255 if the padding was valid and 0 otherwise. On invalid padding,
// toRemove is 1, so the unchecked bytes stay in the MAC input.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// t wraps around (its top bit set) iff paddingLen > len(payload)-1, so
	// good becomes 255 when the padding fits in the payload and 0 otherwise.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length plus the length byte itself.
	toCheck := 256
	// The length of the padded data is public, so a branch is fine here.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (the byte is part of the padding).
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear bits of good wherever a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND together all bits of good and replicate the result across every
	// bit: good ends up 0xff only if every bit was still set.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error so that any unchecked bytes are still
	// included in the MAC. Otherwise an attacker who can tell MAC failures
	// from padding failures could learn plaintext, POODLE-style. The caller
	// folds good into its MAC check.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
330
// roundUp rounds a up to the next multiple of b (b must be non-zero).
func roundUp(a, b int) int {
	rem := a % b
	pad := (b - rem) % b
	return a + pad
}
334
335
// cbcMode is the interface for block ciphers using cipher block chaining.
// Beyond cipher.BlockMode it lets the record layer install an explicit
// per-record IV before calling CryptBlocks.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
340
341
342
// decrypt authenticates and decrypts record, which must begin with a
// recordHeaderLen-byte header. It returns the plaintext and its (possibly
// inner, for TLS 1.3) record type. With no active cipher the payload is
// returned as-is.
//
// Decryption happens in place, so the returned plaintext aliases record.
// For MAC-based suites, record[3:5] is rewritten to the plaintext length so
// the header can be fed to the MAC. All failures are returned as alert
// values. The sequence number is advanced only on success.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are passed through without
	// being decrypted (RFC 8446, Appendix D.4).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// paddingGood stays 255 unless the CBC padding check below fails.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			// plaintext is assigned after the MAC check below.
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// With no explicit nonce on the wire, the sequence number is
			// passed as the nonce input.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			// TLS 1.3 authenticates the record header as-is. Earlier versions
			// authenticate seq || type || version || plaintext length.
			var additionalData []byte
			if hc.version == VersionTLS13 {
				additionalData = record[:recordHeaderLen]
			} else {
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// The ciphertext must be whole blocks, large enough for the
			// explicit IV, the MAC and at least one padding byte.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// The padding result is not acted on here. It is folded into the
			// MAC comparison below so padding and MAC failures look the same.
			// Bytes past the (secret) padding length are still fed to the MAC
			// to keep its cost roughly constant (limited Lucky13 mitigation).
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always carry the outer type
			// application_data; the real type is inside the plaintext.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// +1 allows for the inner ContentType byte.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip zero padding and take the inner ContentType from the last
			// non-zero byte. An all-zero plaintext is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		n := len(payload) - macSize - paddingLen
		// Constant-time "if n < 0 { n = 0 }".
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// Trailing bytes after the MAC are passed as extra data so the MAC
		// work does not depend on the secret padding length.
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Check MAC and padding together, in constant time, so an attacker
		// cannot tell a padding failure from a MAC failure.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
466
467
468
469
// sliceForAppend extends in by n bytes. head is the extended slice and tail
// is its last n bytes. If in has enough spare capacity, head shares its
// backing array. Otherwise a new array is allocated and in's contents are
// copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
480
481
482
// encrypt protects payload and appends it to record, which must already
// hold a recordHeaderLen-byte header. It adds the explicit nonce/IV, MAC
// and padding required by the active cipher. The header length field is
// updated to the final protected length and the sequence number is
// advanced. With no active cipher, payload is appended unchanged.
// rand supplies random explicit IVs for CBC ciphers.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// A short (8-byte) explicit AEAD nonce is too small to be safely
			// random, so the unique sequence number is used instead. CBC IVs
			// must be unpredictable (RFC 5246, Appendix F.3), so those, and
			// any 16-byte-or-longer nonce, are drawn from rand.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: the MAC is computed over the plaintext and both
		// are encrypted.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Append the real ContentType as the inner type byte and set the
			// outer type to application_data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header (used as additional data) must carry the final
			// ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always at least one byte of padding, up to a full block.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		// Every padding byte holds the padding length minus one.
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update length to include nonce, MAC and any block padding needed.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
567
568
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg contains a human-readable description of the error.
	Msg string

	// RecordHeader contains the first five bytes of buffered input at the
	// time of the error, i.e. the offending record header.
	RecordHeader [5]byte

	// Conn provides the underlying net.Conn when the first record received
	// did not look like TLS. For the other header errors, which are
	// preceded by an alert, it is nil.
	Conn net.Conn
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
583
584 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
585 err.Msg = msg
586 err.Conn = conn
587 copy(err.RecordHeader[:], c.rawInput.Bytes())
588 return err
589 }
590
591 func (c *Conn) readRecord() error {
592 return c.readRecordOrCCS(false)
593 }
594
595 func (c *Conn) readChangeCipherSpec() error {
596 return c.readRecordOrCCS(true)
597 }
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Some invariants:
//   - c.in must be locked
//   - c.input must be empty
//
// During the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.in.changeCipherSpec is called
//   - an error is returned
//
// After the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.input is set
//   - an error is returned
//
// Ignorable records (warning alerts, TLS 1.3 CCS and user_canceled, empty
// application data) are skipped via retryReadRecord.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// This function modifies c.rawInput, which owns the c.input memory.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read header, payload.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
		// is an error, but EOF is accepted here if and only if it falls
		// exactly on a record boundary.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without being made sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
	// start with a uint16 length where the MSB is set and the first record
	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
	// an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 froze the record layer version at TLS 1.2
		// (RFC 8446, Section 5.1).
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First message: be extra suspicious, this might not be a TLS peer
		// at all. Bail out before reading a full body if possible. A
		// version >= 0x1000 is implausible for real TLS.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		// decrypt reports every failure as an alert value.
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application Data messages are always protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a state-advancing message: reset the retry count.
		c.retryCount = 0
	}

	// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		if c.vers == VersionTLS13 {
			// TLS 1.3 removed warning-level alerts except for
			// user_canceled (RFC 8446, Section 6.1). That one is dropped
			// and the read retried; everything else is fatal.
			if alert(data[1]) == alertUserCanceled {
				// Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry.
				return c.retryReadRecord(expectChangeCipherSpec)
			}
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the record on the floor and retry.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages are not allowed to fragment across the CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records are ignored (RFC 8446,
		// Appendix D.4). A CCS received before the version is negotiated
		// (c.vers still unset) is not ignored here.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Some OpenSSL servers send empty records in order to randomize the
		// CBC IV. Ignore a limited number of empty records.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// Note that data is owned by c.rawInput, following the Next call above,
		// to avoid copying the plaintext. This is safe because c.rawInput is
		// not read from or written to until c.input is drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
795
796
797
798 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
799 c.retryCount++
800 if c.retryCount > maxUselessRecords {
801 c.sendAlert(alertUnexpectedMessage)
802 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
803 }
804 return c.readRecordOrCCS(expectChangeCipherSpec)
805 }
806
807
808
809
// atLeastReader reads from R, stopping with EOF once at least N bytes have
// been read. It differs from an io.LimitedReader in two ways: it does not
// cut short the last Read call, and it treats an early EOF from R as
// io.ErrUnexpectedEOF.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads into p from the underlying reader and decrements N by the
// number of bytes read.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	switch {
	case r.N > 0 && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		return n, io.EOF
	default:
		return n, err
	}
}
829
830
831
832 func (c *Conn) readFromUntil(r io.Reader, n int) error {
833 if c.rawInput.Len() >= n {
834 return nil
835 }
836 needs := n - c.rawInput.Len()
837
838
839
840 c.rawInput.Grow(needs + bytes.MinRead)
841 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
842 return err
843 }
844
845
846 func (c *Conn) sendAlertLocked(err alert) error {
847 if c.quic != nil {
848 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
849 }
850
851 switch err {
852 case alertNoRenegotiation, alertCloseNotify:
853 c.tmp[0] = alertLevelWarning
854 default:
855 c.tmp[0] = alertLevelError
856 }
857 c.tmp[1] = byte(err)
858
859 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
860 if err == alertCloseNotify {
861
862 return writeErr
863 }
864
865 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
866 }
867
868
869 func (c *Conn) sendAlert(err alert) error {
870 c.out.Lock()
871 defer c.out.Unlock()
872 return c.sendAlertLocked(err)
873 }
874
const (
	// tcpMSSEstimate is a conservative, fixed estimate of the TCP maximum
	// segment size, used as the per-packet budget for dynamic record
	// sizing. NOTE(review): the value matches the IPv6 minimum MTU (1280)
	// minus a 40-byte IPv6 header and a 32-byte TCP header with
	// timestamps.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written after which
	// dynamic record sizing stops and records use the full maxPlaintext.
	recordSizeBoostThreshold = 128 << 10
)
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
906 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
907 return maxPlaintext
908 }
909
910 if c.bytesSent >= recordSizeBoostThreshold {
911 return maxPlaintext
912 }
913
914
915 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
916 if c.out.cipher != nil {
917 switch ciph := c.out.cipher.(type) {
918 case cipher.Stream:
919 payloadBytes -= c.out.mac.Size()
920 case cipher.AEAD:
921 payloadBytes -= ciph.Overhead()
922 case cbcMode:
923 blockSize := ciph.BlockSize()
924
925
926 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
927
928
929 payloadBytes -= c.out.mac.Size()
930 default:
931 panic("unknown cipher type")
932 }
933 }
934 if c.vers == VersionTLS13 {
935 payloadBytes--
936 }
937
938
939 pkt := c.packetsSent
940 c.packetsSent++
941 if pkt > 1000 {
942 return maxPlaintext
943 }
944
945 n := payloadBytes * int(pkt+1)
946 if n > maxPlaintext {
947 n = maxPlaintext
948 }
949 return n
950 }
951
952 func (c *Conn) write(data []byte) (int, error) {
953 if c.buffering {
954 c.sendBuf = append(c.sendBuf, data...)
955 return len(data), nil
956 }
957
958 n, err := c.conn.Write(data)
959 c.bytesSent += int64(n)
960 return n, err
961 }
962
963 func (c *Conn) flush() (int, error) {
964 if len(c.sendBuf) == 0 {
965 return 0, nil
966 }
967
968 n, err := c.conn.Write(c.sendBuf)
969 c.bytesSent += int64(n)
970 c.sendBuf = nil
971 c.buffering = false
972 return n, err
973 }
974
975
// outBufPool pools record-sized scratch buffers for writeRecordLocked.
// Entries are *[]byte so the slice header itself can be reused.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
981
982
983
// writeRecordLocked writes a TLS record with the given type and payload to
// the connection, splitting data into several records if it exceeds
// maxPayloadSizeForWrite. It returns the number of payload bytes sent.
// Over QUIC, only handshake data is accepted and it goes to the QUIC crypto
// stream. After a non-TLS-1.3 ChangeCipherSpec the outbound cipher is
// switched. c.out must be locked.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the possibly grown slice back through the pooled pointer so
		// the pool keeps the larger buffer. NOTE(review): presumably done
		// this way, rather than Put(&outBuf), to keep the local slice header
		// from escaping to the heap.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Before negotiation, advertise TLS 1.0 in the record header.
			// Some TLS servers fail if the record version is greater than
			// TLS 1.0 for the initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 froze the record layer version to 1.2.
			// See RFC 8446, Section 5.1.
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1054
1055
1056
1057
1058 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1059 c.out.Lock()
1060 defer c.out.Unlock()
1061
1062 data, err := msg.marshal()
1063 if err != nil {
1064 return 0, err
1065 }
1066 if transcript != nil {
1067 transcript.Write(data)
1068 }
1069
1070 return c.writeRecordLocked(recordTypeHandshake, data)
1071 }
1072
1073
1074
1075 func (c *Conn) writeChangeCipherRecord() error {
1076 c.out.Lock()
1077 defer c.out.Unlock()
1078 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1079 return err
1080 }
1081
1082
1083 func (c *Conn) readHandshakeBytes(n int) error {
1084 if c.quic != nil {
1085 return c.quicReadHandshakeBytes(n)
1086 }
1087 for c.hand.Len() < n {
1088 if err := c.readRecord(); err != nil {
1089 return err
1090 }
1091 }
1092 return nil
1093 }
1094
1095
1096
1097
1098 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1099 if err := c.readHandshakeBytes(4); err != nil {
1100 return nil, err
1101 }
1102 data := c.hand.Bytes()
1103
1104 maxHandshakeSize := maxHandshake
1105
1106
1107
1108 if c.haveVers && data[0] == typeCertificate {
1109
1110
1111
1112 maxHandshakeSize = maxHandshakeCertificateMsg
1113 }
1114
1115 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1116 if n > maxHandshakeSize {
1117 c.sendAlertLocked(alertInternalError)
1118 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
1119 }
1120 if err := c.readHandshakeBytes(4 + n); err != nil {
1121 return nil, err
1122 }
1123 data = c.hand.Next(4 + n)
1124 return c.unmarshalHandshakeMessage(data, transcript)
1125 }
1126
1127 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1128 var m handshakeMessage
1129 switch data[0] {
1130 case typeHelloRequest:
1131 m = new(helloRequestMsg)
1132 case typeClientHello:
1133 m = new(clientHelloMsg)
1134 case typeServerHello:
1135 m = new(serverHelloMsg)
1136 case typeNewSessionTicket:
1137 if c.vers == VersionTLS13 {
1138 m = new(newSessionTicketMsgTLS13)
1139 } else {
1140 m = new(newSessionTicketMsg)
1141 }
1142 case typeCertificate:
1143 if c.vers == VersionTLS13 {
1144 m = new(certificateMsgTLS13)
1145 } else {
1146 m = new(certificateMsg)
1147 }
1148 case typeCertificateRequest:
1149 if c.vers == VersionTLS13 {
1150 m = new(certificateRequestMsgTLS13)
1151 } else {
1152 m = &certificateRequestMsg{
1153 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1154 }
1155 }
1156 case typeCertificateStatus:
1157 m = new(certificateStatusMsg)
1158 case typeServerKeyExchange:
1159 m = new(serverKeyExchangeMsg)
1160 case typeServerHelloDone:
1161 m = new(serverHelloDoneMsg)
1162 case typeClientKeyExchange:
1163 m = new(clientKeyExchangeMsg)
1164 case typeCertificateVerify:
1165 m = &certificateVerifyMsg{
1166 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1167 }
1168 case typeFinished:
1169 m = new(finishedMsg)
1170 case typeEncryptedExtensions:
1171 m = new(encryptedExtensionsMsg)
1172 case typeEndOfEarlyData:
1173 m = new(endOfEarlyDataMsg)
1174 case typeKeyUpdate:
1175 m = new(keyUpdateMsg)
1176 default:
1177 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1178 }
1179
1180
1181
1182
1183 data = append([]byte(nil), data...)
1184
1185 if !m.unmarshal(data) {
1186 return nil, c.in.setErrorLocked(c.sendAlert(alertDecodeError))
1187 }
1188
1189 if transcript != nil {
1190 transcript.Write(data)
1191 }
1192
1193 return m, nil
1194 }
1195
// errShutdown is returned by Write once a close_notify alert has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1199
1200
1201
1202
1203
1204
1205
// Write writes data to the connection.
//
// As Write calls Handshake, in order to prevent indefinite blocking a
// deadline must be set for both Read and Write before Write is called when
// the handshake has not yet completed. See SetDeadline, SetReadDeadline,
// and SetWriteDeadline.
//
// Write returns net.ErrClosed after Close, errShutdown after CloseWrite,
// and any sticky outbound error. A failed record write becomes sticky.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: refuse if the closed bit (bit 0) is set,
	// otherwise register this call by adding 2 to the active-call count.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 CBC uses the previous record's last ciphertext block as the
	// next IV, which makes it predictable (the BEAST chosen-plaintext
	// attack). Sending the first byte in its own record ("1/n-1 split")
	// effectively randomizes the IV for the rest of the data.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	// When err is nil, setErrorLocked stores and returns nil.
	return n + m, c.out.setErrorLocked(err)
}
1261
1262
1263 func (c *Conn) handleRenegotiation() error {
1264 if c.vers == VersionTLS13 {
1265 return errors.New("tls: internal error: unexpected renegotiation")
1266 }
1267
1268 msg, err := c.readHandshake(nil)
1269 if err != nil {
1270 return err
1271 }
1272
1273 helloReq, ok := msg.(*helloRequestMsg)
1274 if !ok {
1275 c.sendAlert(alertUnexpectedMessage)
1276 return unexpectedMessageError(helloReq, msg)
1277 }
1278
1279 if !c.isClient {
1280 return c.sendAlert(alertNoRenegotiation)
1281 }
1282
1283 switch c.config.Renegotiation {
1284 case RenegotiateNever:
1285 return c.sendAlert(alertNoRenegotiation)
1286 case RenegotiateOnceAsClient:
1287 if c.handshakes > 1 {
1288 return c.sendAlert(alertNoRenegotiation)
1289 }
1290 case RenegotiateFreelyAsClient:
1291
1292 default:
1293 c.sendAlert(alertInternalError)
1294 return errors.New("tls: unknown Renegotiation value")
1295 }
1296
1297 c.handshakeMutex.Lock()
1298 defer c.handshakeMutex.Unlock()
1299
1300 c.isHandshakeComplete.Store(false)
1301 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1302 c.handshakes++
1303 }
1304 return c.handshakeErr
1305 }
1306
1307
1308
1309 func (c *Conn) handlePostHandshakeMessage() error {
1310 if c.vers != VersionTLS13 {
1311 return c.handleRenegotiation()
1312 }
1313
1314 msg, err := c.readHandshake(nil)
1315 if err != nil {
1316 return err
1317 }
1318 c.retryCount++
1319 if c.retryCount > maxUselessRecords {
1320 c.sendAlert(alertUnexpectedMessage)
1321 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1322 }
1323
1324 switch msg := msg.(type) {
1325 case *newSessionTicketMsgTLS13:
1326 return c.handleNewSessionTicket(msg)
1327 case *keyUpdateMsg:
1328 return c.handleKeyUpdate(msg)
1329 }
1330
1331
1332
1333
1334 c.sendAlert(alertUnexpectedMessage)
1335 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1336 }
1337
// handleKeyUpdate processes a TLS 1.3 KeyUpdate message. If the peer asked
// for an update, our own KeyUpdate is sent and the write traffic secret is
// rotated (with c.out held until return). The read traffic secret is then
// rotated. KeyUpdate is rejected over QUIC.
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
	if c.quic != nil {
		c.sendAlert(alertUnexpectedMessage)
		return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
	}

	cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
	if cipherSuite == nil {
		return c.in.setErrorLocked(c.sendAlert(alertInternalError))
	}

	if keyUpdate.updateRequested {
		// Held until the function returns, which also covers the read
		// secret rotation below.
		c.out.Lock()
		defer c.out.Unlock()

		msg := &keyUpdateMsg{}
		msgBytes, err := msg.marshal()
		if err != nil {
			return err
		}
		_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
		if err != nil {
			// Surface the error at the next write.
			// NOTE(review): returning here also skips the read-secret
			// rotation below, so subsequent reads keep the old key even
			// though the peer has rotated its keys — confirm this is intended.
			c.out.setErrorLocked(err)
			return nil
		}

		newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
		c.setWriteTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
	}

	newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
	if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret); err != nil {
		return err
	}

	return nil
}
1376
1377
1378
1379
1380
1381
1382
1383 func (c *Conn) Read(b []byte) (int, error) {
1384 if err := c.Handshake(); err != nil {
1385 return 0, err
1386 }
1387 if len(b) == 0 {
1388
1389
1390 return 0, nil
1391 }
1392
1393 c.in.Lock()
1394 defer c.in.Unlock()
1395
1396 for c.input.Len() == 0 {
1397 if err := c.readRecord(); err != nil {
1398 return 0, err
1399 }
1400 for c.hand.Len() > 0 {
1401 if err := c.handlePostHandshakeMessage(); err != nil {
1402 return 0, err
1403 }
1404 }
1405 }
1406
1407 n, _ := c.input.Read(b)
1408
1409
1410
1411
1412
1413
1414
1415
1416 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1417 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1418 if err := c.readRecord(); err != nil {
1419 return n, err
1420 }
1421 }
1422
1423 return n, nil
1424 }
1425
1426
// Close closes the connection. If no Write is in flight and the handshake
// has completed, a close_notify alert is sent first. A failure to send it
// is reported only if closing the underlying connection succeeds. A second
// Close returns net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Conn.Write: set the closed bit (bit 0); the remaining
	// bits count in-flight Write calls.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// io.Writer and io.Closer should not be used concurrently.
		// If Close is called while a Write is currently in-flight,
		// interpret that as a sign that this Close is really just
		// being used to break the Write and/or clean up resources and
		// avoid sending the alertCloseNotify, which may block
		// waiting on handshakeMutex or the c.out mutex.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1461
// errEarlyCloseWrite is returned by CloseWrite when the handshake has not
// yet completed.
var errEarlyCloseWrite error = errors.New("tls: CloseWrite called before handshake complete")
1463
1464
1465
1466
1467 func (c *Conn) CloseWrite() error {
1468 if !c.isHandshakeComplete.Load() {
1469 return errEarlyCloseWrite
1470 }
1471
1472 return c.closeNotify()
1473 }
1474
1475 func (c *Conn) closeNotify() error {
1476 c.out.Lock()
1477 defer c.out.Unlock()
1478
1479 if !c.closeNotifySent {
1480
1481 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1482 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1483 c.closeNotifySent = true
1484
1485 c.SetWriteDeadline(time.Now())
1486 }
1487 return c.closeNotifyErr
1488 }
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503 func (c *Conn) Handshake() error {
1504 return c.HandshakeContext(context.Background())
1505 }
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517 func (c *Conn) HandshakeContext(ctx context.Context) error {
1518
1519
1520 return c.handshakeContext(ctx)
1521 }
1522
1523 func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
1524
1525
1526
1527 if c.isHandshakeComplete.Load() {
1528 return nil
1529 }
1530
1531 handshakeCtx, cancel := context.WithCancel(ctx)
1532
1533
1534
1535 defer cancel()
1536
1537 if c.quic != nil {
1538 c.quic.cancelc = handshakeCtx.Done()
1539 c.quic.cancel = cancel
1540 } else if ctx.Done() != nil {
1541
1542
1543
1544
1545
1546 done := make(chan struct{})
1547 interruptRes := make(chan error, 1)
1548 defer func() {
1549 close(done)
1550 if ctxErr := <-interruptRes; ctxErr != nil {
1551
1552 ret = ctxErr
1553 }
1554 }()
1555 go func() {
1556 select {
1557 case <-handshakeCtx.Done():
1558
1559 _ = c.conn.Close()
1560 interruptRes <- handshakeCtx.Err()
1561 case <-done:
1562 interruptRes <- nil
1563 }
1564 }()
1565 }
1566
1567 c.handshakeMutex.Lock()
1568 defer c.handshakeMutex.Unlock()
1569
1570 if err := c.handshakeErr; err != nil {
1571 return err
1572 }
1573 if c.isHandshakeComplete.Load() {
1574 return nil
1575 }
1576
1577 c.in.Lock()
1578 defer c.in.Unlock()
1579
1580 c.handshakeErr = c.handshakeFn(handshakeCtx)
1581 if c.handshakeErr == nil {
1582 c.handshakes++
1583 } else {
1584
1585
1586 c.flush()
1587 }
1588
1589 if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
1590 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1591 }
1592 if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
1593 panic("tls: internal error: handshake returned an error but is marked successful")
1594 }
1595
1596 if c.quic != nil {
1597 if c.handshakeErr == nil {
1598 c.quicHandshakeComplete()
1599
1600
1601
1602 if err := c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret); err != nil {
1603 return err
1604 }
1605 } else {
1606 var a alert
1607 c.out.Lock()
1608 if !errors.As(c.out.err, &a) {
1609 a = alertInternalError
1610 }
1611 c.out.Unlock()
1612
1613
1614
1615
1616 c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
1617 }
1618 close(c.quic.blockedc)
1619 close(c.quic.signalc)
1620 }
1621
1622 return c.handshakeErr
1623 }
1624
1625
1626 func (c *Conn) ConnectionState() ConnectionState {
1627 c.handshakeMutex.Lock()
1628 defer c.handshakeMutex.Unlock()
1629 return c.connectionStateLocked()
1630 }
1631
// tlsunsafeekm is the GODEBUG setting consulted by connectionStateLocked.
// When its value is "1", the keying-material exporter (c.ekm) is used even on
// pre-TLS 1.3 connections without Extended Master Secret, and the non-default
// use is counted.
var tlsunsafeekm = godebug.New("tlsunsafeekm")
1633
1634 func (c *Conn) connectionStateLocked() ConnectionState {
1635 var state ConnectionState
1636 state.HandshakeComplete = c.isHandshakeComplete.Load()
1637 state.Version = c.vers
1638 state.NegotiatedProtocol = c.clientProtocol
1639 state.DidResume = c.didResume
1640 state.testingOnlyDidHRR = c.didHRR
1641 state.testingOnlyPeerSignatureAlgorithm = c.peerSigAlg
1642 state.CurveID = c.curveID
1643 state.NegotiatedProtocolIsMutual = true
1644 state.ServerName = c.serverName
1645 state.CipherSuite = c.cipherSuite
1646 state.PeerCertificates = c.peerCertificates
1647 state.VerifiedChains = c.verifiedChains
1648 state.SignedCertificateTimestamps = c.scts
1649 state.OCSPResponse = c.ocspResponse
1650 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1651 if c.clientFinishedIsFirst {
1652 state.TLSUnique = c.clientFinished[:]
1653 } else {
1654 state.TLSUnique = c.serverFinished[:]
1655 }
1656 }
1657 if c.config.Renegotiation != RenegotiateNever {
1658 state.ekm = noEKMBecauseRenegotiation
1659 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1660 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1661 if tlsunsafeekm.Value() == "1" {
1662 tlsunsafeekm.IncNonDefault()
1663 return c.ekm(label, context, length)
1664 }
1665 return noEKMBecauseNoEMS(label, context, length)
1666 }
1667 } else {
1668 state.ekm = c.ekm
1669 }
1670 state.ECHAccepted = c.echAccepted
1671 return state
1672 }
1673
1674
1675
1676 func (c *Conn) OCSPResponse() []byte {
1677 c.handshakeMutex.Lock()
1678 defer c.handshakeMutex.Unlock()
1679
1680 return c.ocspResponse
1681 }
1682
1683
1684
1685
1686 func (c *Conn) VerifyHostname(host string) error {
1687 c.handshakeMutex.Lock()
1688 defer c.handshakeMutex.Unlock()
1689 if !c.isClient {
1690 return errors.New("tls: VerifyHostname called on TLS server connection")
1691 }
1692 if !c.isHandshakeComplete.Load() {
1693 return errors.New("tls: handshake has not yet been performed")
1694 }
1695 if len(c.verifiedChains) == 0 {
1696 return errors.New("tls: handshake did not verify certificate chain")
1697 }
1698 return c.peerCertificates[0].VerifyHostname(host)
1699 }
1700
1701
1702
1703
1704 func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) error {
1705
1706
1707
1708 if c.hand.Len() != 0 {
1709 c.sendAlert(alertUnexpectedMessage)
1710 return errors.New("tls: handshake buffer not empty before setting read traffic secret")
1711 }
1712 c.in.setTrafficSecret(suite, level, secret)
1713 return nil
1714 }
1715
1716
1717
1718
1719 func (c *Conn) setWriteTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
1720 c.out.setTrafficSecret(suite, level, secret)
1721 }
1722
View as plain text