package pgdialect

import (
	"bytes"
	"database/sql"
	"encoding/hex"
	"fmt"
	"io"
	"time"

	"github.com/uptrace/bun/internal"
	"github.com/uptrace/bun/internal/parser"
	"github.com/uptrace/bun/schema"
)

// MultiRange is an ordered list of Range values sharing the element type T.
// NOTE(review): presumably the Go counterpart of a PostgreSQL multirange
// column; no Scan/AppendQuery methods for it are visible in this part of the
// file — confirm where (de)serialization lives.
type MultiRange[T any] []Range[T]

// Range is a pair of endpoints of type T together with the bracket characters
// that say whether each endpoint belongs to the range.
//
// The Scan and AppendQuery methods of *Range[T] delegate element handling to
// scanElem/appendElem, which currently only accept time.Time and panic for
// any other T.
type Range[T any] struct {
	// Lower is the lower endpoint of the range.
	Lower T
	// Upper is the upper endpoint of the range.
	Upper T

	// LowerBound is the opening bracket written before Lower.
	LowerBound RangeBound
	// UpperBound is the closing bracket written after Upper.
	UpperBound RangeBound
}

// RangeBound is the single bracket character that opens or closes a range
// literal. Range.Scan stores the raw first/last byte of the input in it and
// Range.AppendQuery writes it back verbatim, so the value is not validated
// against the RangeBound* constants.
type RangeBound byte

// Bracket characters usable as RangeBound values:
//
//	'['  lower endpoint is part of the range (inclusive)
//	']'  upper endpoint is part of the range (inclusive)
//	'('  lower endpoint is not part of the range (exclusive)
//	')'  upper endpoint is not part of the range (exclusive)
const (
	RangeBoundInclusiveLeft  RangeBound = '['
	RangeBoundInclusiveRight RangeBound = ']'
	RangeBoundExclusiveLeft  RangeBound = '('
	RangeBoundExclusiveRight RangeBound = ')'
)

// NewRange builds the half-open range [lower, upper): the lower endpoint is
// inclusive and the upper endpoint is exclusive.
func NewRange[T any](lower, upper T) Range[T] {
	var r Range[T]
	r.Lower, r.Upper = lower, upper
	r.LowerBound, r.UpperBound = RangeBoundInclusiveLeft, RangeBoundExclusiveRight
	return r
}

var _ sql.Scanner = (*Range[any])(nil)

// Scan implements sql.Scanner. It parses the textual form of a range,
//
//	<bound><lower>,<upper><bound>
//
// for example ["2024-01-01 00:00:00+00","2024-02-01 00:00:00+00"), and stores
// the pieces in r.
//
// anySrc must be a []byte or a string. Any other value, including nil (SQL
// NULL), yields an error instead of a panic, because Range has no way to
// represent a missing value.
//
// Errors:
//   - io.ErrUnexpectedEOF when the input ends before all parts were read;
//   - a formatted error when the separator is not ',' or when bytes remain
//     after the closing bound;
//   - whatever scanElem returns for a malformed element.
//
// scanElem panics when T is not a supported element type. On error r may be
// left partially updated.
func (r *Range[T]) Scan(anySrc any) (err error) {
	var src []byte
	switch v := anySrc.(type) {
	case []byte:
		src = v
	case string:
		src = []byte(v)
	default:
		return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
	}

	if len(src) == 0 {
		return io.ErrUnexpectedEOF
	}
	// The opening bracket is stored as-is; it is not validated here.
	r.LowerBound = RangeBound(src[0])
	src = src[1:]

	src, err = scanElem(&r.Lower, src)
	if err != nil {
		return err
	}

	if len(src) == 0 {
		return io.ErrUnexpectedEOF
	}
	if ch := src[0]; ch != ',' {
		return fmt.Errorf("got %q, wanted %q", ch, ',')
	}
	src = src[1:]

	src, err = scanElem(&r.Upper, src)
	if err != nil {
		return err
	}

	if len(src) == 0 {
		return io.ErrUnexpectedEOF
	}
	// The closing bracket is stored as-is as well.
	r.UpperBound = RangeBound(src[0])
	src = src[1:]

	if len(src) > 0 {
		return fmt.Errorf("unread data: %q", src)
	}
	return nil
}

var _ schema.QueryAppender = (*Range[any])(nil)

// AppendQuery implements schema.QueryAppender. It appends the range to buf as
//
//	<LowerBound><Lower>,<Upper><UpperBound>
//
// and returns the extended slice. The returned error is always nil;
// unsupported element types make appendElem panic instead.
//
// The formatter argument is not used. Its name shadows the fmt package inside
// this method.
//
// NOTE(review): the text is emitted without surrounding SQL quotes or a type
// cast — confirm that the caller is expected to supply them.
func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
	buf = appendElem(append(buf, byte(r.LowerBound)), r.Lower)
	buf = appendElem(append(buf, ','), r.Upper)
	return append(buf, byte(r.UpperBound)), nil
}

// appendElem appends one range endpoint to buf and returns the extended
// slice. Only time.Time is supported: it is written by appendTime and wrapped
// in double quotes. Any other dynamic type causes a panic.
func appendElem(buf []byte, val any) []byte {
	tm, ok := val.(time.Time)
	if !ok {
		panic(fmt.Errorf("unsupported range type: %T", val))
	}

	buf = append(buf, '"')
	buf = appendTime(buf, tm)
	return append(buf, '"')
}

// scanElem reads one range endpoint from the front of src, stores it through
// ptr, and returns the bytes that follow the endpoint.
//
// Only *time.Time destinations are supported; the endpoint must be a
// double-quoted literal that internal.ParseTime accepts. Any other
// destination type causes a panic. On error the returned slice is nil.
func scanElem(ptr any, src []byte) ([]byte, error) {
	dst, ok := ptr.(*time.Time)
	if !ok {
		panic(fmt.Errorf("unsupported range type: %T", ptr))
	}

	rest, lit, err := readStringLiteral(src)
	if err != nil {
		return nil, err
	}

	parsed, err := internal.ParseTime(internal.String(lit))
	if err != nil {
		return nil, err
	}

	*dst = parsed
	return rest, nil
}

// readStringLiteral consumes a double-quoted literal from the front of src.
//
// It returns, in order, the bytes remaining after the closing quote, the
// unescaped contents of the literal, and an error. Both slices are nil when
// an error is returned. The opening quote is consumed with Skip and the body
// with ReadSubstring on a parser created for this call.
func readStringLiteral(src []byte) ([]byte, []byte, error) {
	p := newParser(src)

	err := p.Skip('"')
	if err != nil {
		return nil, nil, err
	}

	lit, err := p.ReadSubstring('"')
	if err != nil {
		return nil, nil, err
	}

	return p.Remaining(), lit, nil
}

//------------------------------------------------------------------------------

// pgparser wraps the generic parser.Parser with helpers for reading quoted
// substrings and range literals.
type pgparser struct {
	parser.Parser
	// buf is scratch space reused by readSubstring and ReadRange. Slices
	// returned by those methods may alias it and are overwritten by the next
	// call.
	buf []byte
}

// newParser allocates a pgparser and resets it to read from b.
func newParser(b []byte) *pgparser {
	p := new(pgparser)
	p.Reset(b)
	return p
}

// ReadLiteral steps the parser back with Unread and then returns what ReadSep
// yields for the ',' separator. The second result of ReadSep is discarded.
//
// The ch argument is not used.
// NOTE(review): presumably ch is the first byte of the literal that the
// caller has already consumed, which is why Unread is called — confirm
// against the callers.
func (p *pgparser) ReadLiteral(ch byte) []byte {
	p.Unread()
	lit, _ := p.ReadSep(',')
	return lit
}

// ReadUnescapedSubstring calls readSubstring with escaped=false, so doubled
// single quotes are kept as two characters. Backslash escapes are still
// processed. The ch argument is forwarded but has no effect there.
func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) {
	return p.readSubstring(ch, false)
}

// ReadSubstring calls readSubstring with escaped=true, so a doubled single
// quote ('') is collapsed into one in addition to the backslash escapes. The
// ch argument is forwarded but has no effect there.
func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) {
	return p.readSubstring(ch, true)
}

// readSubstring reads the body of a double-quoted string, starting at the
// byte after the opening quote, and consumes the closing quote.
//
// Unescaping rules:
//   - `\\` and `\"` produce the character after the backslash;
//   - a backslash followed by any other byte is kept literally;
//   - when escaped is true, two consecutive single quotes produce one.
//
// If the unescaped text starts with `\x` and has even length, the remainder
// is hex-decoded and the decoded bytes are returned in a newly allocated
// slice. Otherwise the result is p.buf itself, which is overwritten by the
// next readSubstring or ReadRange call.
//
// NOTE: the ch argument is overwritten by the first ReadByte before it is
// ever read, so the value passed in has no effect; the terminator is always
// '"'.
//
// Errors come from ReadByte (returned whenever a read fails before the
// closing quote is seen) and from hex.Decode.
func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) {
	// ch is reused from here on as "the current byte".
	ch, err := p.ReadByte()
	if err != nil {
		return nil, err
	}

	p.buf = p.buf[:0]
	for {
		if ch == '"' {
			break
		}

		// One byte of lookahead is needed to recognize escape pairs.
		next, err := p.ReadByte()
		if err != nil {
			return nil, err
		}

		if ch == '\\' {
			switch next {
			case '\\', '"':
				// Escaped backslash or quote: emit it and advance past the pair.
				p.buf = append(p.buf, next)

				ch, err = p.ReadByte()
				if err != nil {
					return nil, err
				}
			default:
				// Not a recognized escape: keep the backslash and reprocess
				// next as the current byte.
				p.buf = append(p.buf, '\\')
				ch = next
			}
			continue
		}

		// Doubled single quote collapses to one when escaped is set.
		if escaped && ch == '\'' && next == '\'' {
			p.buf = append(p.buf, next)
			ch, err = p.ReadByte()
			if err != nil {
				return nil, err
			}
			continue
		}

		p.buf = append(p.buf, ch)
		ch = next
	}

	// `\x` followed by an even number of characters is treated as hex data.
	if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 {
		data := p.buf[2:]
		buf := make([]byte, hex.DecodedLen(len(data)))
		n, err := hex.Decode(buf, data)
		if err != nil {
			return nil, err
		}
		return buf[:n], nil
	}

	return p.buf, nil
}

// ReadRange collects a range literal into the parser's scratch buffer: first
// ch (the already-consumed opening byte), then every following byte up to and
// including the first ']' or ')'.
//
// The returned slice aliases p.buf and is overwritten by the next
// readSubstring or ReadRange call. The error is always nil.
//
// NOTE(review): if the input runs out before a closing bracket, whatever was
// collected so far is returned without an error; a ']' or ')' inside a quoted
// element also ends the scan. Confirm that callers tolerate both.
func (p *pgparser) ReadRange(ch byte) ([]byte, error) {
	p.buf = append(p.buf[:0], ch)

	for p.Valid() {
		b := p.Read()
		p.buf = append(p.buf, b)

		switch b {
		case ']', ')':
			return p.buf, nil
		}
	}

	return p.buf, nil
}