gotosocial/vendor/github.com/tdewolff/parse/v2/html/parse.go

404 lines
8.1 KiB
Go
Raw Normal View History

package html
import (
"bytes"
"fmt"
"io"
"strings"
"github.com/tdewolff/parse/v2"
"github.com/tdewolff/parse/v2/css"
)
type AST struct {
Children []*Tag
Text []byte
}
func (ast *AST) String() string {
sb := strings.Builder{}
for i, child := range ast.Children {
if i != 0 {
sb.WriteString("\n")
}
sb.WriteString(child.ASTString())
}
return sb.String()
}
type Attr struct {
Key, Val []byte
}
func (attr *Attr) String() string {
return fmt.Sprintf(`%s="%s"`, string(attr.Key), string(attr.Val))
}
type Tag struct {
Root *AST
Parent *Tag
Prev, Next *Tag
Children []*Tag
Index int
Name []byte
Attrs []Attr
textStart, textEnd int
}
func (tag *Tag) getAttr(key []byte) ([]byte, bool) {
for _, attr := range tag.Attrs {
if bytes.Equal(key, attr.Key) {
return attr.Val, true
}
}
return nil, false
}
func (tag *Tag) GetAttr(key string) (string, bool) {
val, ok := tag.getAttr([]byte(key))
return string(val), ok
}
func (tag *Tag) Text() string {
return string(tag.Root.Text[tag.textStart:tag.textEnd])
}
func (tag *Tag) String() string {
sb := strings.Builder{}
sb.WriteString("<")
sb.Write(tag.Name)
for _, attr := range tag.Attrs {
sb.WriteString(" ")
sb.WriteString(attr.String())
}
sb.WriteString(">")
return sb.String()
}
func (tag *Tag) ASTString() string {
sb := strings.Builder{}
sb.WriteString(tag.String())
for _, child := range tag.Children {
sb.WriteString("\n ")
s := child.ASTString()
s = strings.ReplaceAll(s, "\n", "\n ")
sb.WriteString(s)
}
return sb.String()
}
func Parse(r *parse.Input) (*AST, error) {
ast := &AST{}
root := &Tag{}
cur := root
l := NewLexer(r)
for {
tt, data := l.Next()
switch tt {
case ErrorToken:
if err := l.Err(); err != io.EOF {
return nil, err
}
ast.Children = root.Children
return ast, nil
case TextToken:
ast.Text = append(ast.Text, data...)
case StartTagToken:
child := &Tag{
Root: ast,
Parent: cur,
Index: len(cur.Children),
Name: l.Text(),
textStart: len(ast.Text),
}
if 0 < len(cur.Children) {
child.Prev = cur.Children[len(cur.Children)-1]
child.Prev.Next = child
}
cur.Children = append(cur.Children, child)
cur = child
case AttributeToken:
val := l.AttrVal()
if 0 < len(val) && (val[0] == '"' || val[0] == '\'') {
val = val[1 : len(val)-1]
}
cur.Attrs = append(cur.Attrs, Attr{l.AttrKey(), val})
case StartTagCloseToken:
if voidTags[string(cur.Name)] {
cur.textEnd = len(ast.Text)
cur = cur.Parent
}
case EndTagToken, StartTagVoidToken:
start := cur
for start != root && !bytes.Equal(l.Text(), start.Name) {
start = start.Parent
}
if start == root {
// ignore
} else {
parent := start.Parent
for cur != parent {
cur.textEnd = len(ast.Text)
cur = cur.Parent
}
}
}
}
}
func (ast *AST) Query(s string) (*Tag, error) {
sel, err := ParseSelector(s)
if err != nil {
return nil, err
}
for _, child := range ast.Children {
if match := child.query(sel); match != nil {
return match, nil
}
}
return nil, nil
}
func (tag *Tag) query(sel selector) *Tag {
if sel.AppliesTo(tag) {
return tag
}
for _, child := range tag.Children {
if match := child.query(sel); match != nil {
return match
}
}
return nil
}
func (ast *AST) QueryAll(s string) ([]*Tag, error) {
sel, err := ParseSelector(s)
if err != nil {
return nil, err
}
matches := []*Tag{}
for _, child := range ast.Children {
child.queryAll(&matches, sel)
}
return matches, nil
}
func (tag *Tag) queryAll(matches *[]*Tag, sel selector) {
if sel.AppliesTo(tag) {
*matches = append(*matches, tag)
}
for _, child := range tag.Children {
child.queryAll(matches, sel)
}
}
type attrSelector struct {
op byte // empty, =, ~, |
attr []byte
val []byte
}
func (sel attrSelector) AppliesTo(tag *Tag) bool {
val, ok := tag.getAttr(sel.attr)
if !ok {
return false
}
switch sel.op {
case 0:
return true
case '=':
return bytes.Equal(val, sel.val)
case '~':
if 0 < len(sel.val) {
vals := bytes.Split(val, []byte(" "))
for _, val := range vals {
if bytes.Equal(val, sel.val) {
return true
}
}
}
case '|':
return bytes.Equal(val, sel.val) || bytes.HasPrefix(val, append(sel.val, '-'))
}
return false
}
func (attr attrSelector) String() string {
sb := strings.Builder{}
sb.Write(attr.attr)
if attr.op != 0 {
sb.WriteByte(attr.op)
if attr.op != '=' {
sb.WriteByte('=')
}
sb.WriteByte('"')
sb.Write(attr.val)
sb.WriteByte('"')
}
return sb.String()
}
type selectorNode struct {
typ []byte // is * for universal
attrs []attrSelector
op byte // space or >, last is NULL
}
func (sel selectorNode) AppliesTo(tag *Tag) bool {
if 0 < len(sel.typ) && !bytes.Equal(sel.typ, []byte("*")) && !bytes.Equal(sel.typ, tag.Name) {
return false
}
for _, attr := range sel.attrs {
if !attr.AppliesTo(tag) {
return false
}
}
return true
}
func (sel selectorNode) String() string {
sb := strings.Builder{}
sb.Write(sel.typ)
for _, attr := range sel.attrs {
if bytes.Equal(attr.attr, []byte("id")) && attr.op == '=' {
sb.WriteByte('#')
sb.Write(attr.val)
} else if bytes.Equal(attr.attr, []byte("class")) && attr.op == '~' {
sb.WriteByte('.')
sb.Write(attr.val)
} else {
sb.WriteByte('[')
sb.WriteString(attr.String())
sb.WriteByte(']')
}
}
if sel.op != 0 {
sb.WriteByte(' ')
sb.WriteByte(sel.op)
sb.WriteByte(' ')
}
return sb.String()
}
type token struct {
tt css.TokenType
data []byte
}
type selector []selectorNode
func ParseSelector(s string) (selector, error) {
ts := []token{}
l := css.NewLexer(parse.NewInputString(s))
for {
tt, data := l.Next()
if tt == css.ErrorToken {
if err := l.Err(); err != io.EOF {
return selector{}, err
}
break
}
ts = append(ts, token{
tt: tt,
data: data,
})
}
sel := selector{}
node := selectorNode{}
for i := 0; i < len(ts); i++ {
t := ts[i]
if 0 < i && (t.tt == css.WhitespaceToken || t.tt == css.DelimToken && t.data[0] == '>') {
if t.tt == css.DelimToken {
node.op = '>'
} else {
node.op = ' '
}
sel = append(sel, node)
node = selectorNode{}
} else if t.tt == css.IdentToken || t.tt == css.DelimToken && t.data[0] == '*' {
node.typ = t.data
} else if t.tt == css.DelimToken && (t.data[0] == '.' || t.data[0] == '#') && i+1 < len(ts) && ts[i+1].tt == css.IdentToken {
if t.data[0] == '#' {
node.attrs = append(node.attrs, attrSelector{op: '=', attr: []byte("id"), val: ts[i+1].data})
} else {
node.attrs = append(node.attrs, attrSelector{op: '~', attr: []byte("class"), val: ts[i+1].data})
}
i++
} else if t.tt == css.DelimToken && t.data[0] == '[' && i+2 < len(ts) && ts[i+1].tt == css.IdentToken && ts[i+2].tt == css.DelimToken {
if ts[i+2].data[0] == ']' {
node.attrs = append(node.attrs, attrSelector{op: 0, attr: ts[i+1].data})
i += 2
} else if i+4 < len(ts) && ts[i+3].tt == css.IdentToken && ts[i+4].tt == css.DelimToken && ts[i+4].data[0] == ']' {
node.attrs = append(node.attrs, attrSelector{op: ts[i+2].data[0], attr: ts[i+1].data, val: ts[i+3].data})
i += 4
}
}
}
sel = append(sel, node)
return sel, nil
}
func (sels selector) AppliesTo(tag *Tag) bool {
if len(sels) == 0 {
return true
} else if !sels[len(sels)-1].AppliesTo(tag) {
return false
}
tag = tag.Parent
isel := len(sels) - 2
for 0 <= isel && tag != nil {
switch sels[isel].op {
case ' ':
for tag != nil {
if sels[isel].AppliesTo(tag) {
break
}
tag = tag.Parent
}
case '>':
if !sels[isel].AppliesTo(tag) {
return false
}
tag = tag.Parent
default:
return false
}
isel--
}
return len(sels) != 0 && isel == -1
}
func (sels selector) String() string {
if len(sels) == 0 {
return ""
}
sb := strings.Builder{}
for _, sel := range sels {
sb.WriteString(sel.String())
}
return sb.String()[1:]
}
var voidTags = map[string]bool{
"area": true,
"base": true,
"br": true,
"col": true,
"embed": true,
"hr": true,
"img": true,
"input": true,
"link": true,
"meta": true,
"source": true,
"track": true,
"wbr": true,
}