1from sqlalchemy import and_
2from sqlalchemy import asc
3from sqlalchemy import bindparam
4from sqlalchemy import cast
5from sqlalchemy import desc
6from sqlalchemy import exc
7from sqlalchemy import except_
8from sqlalchemy import ForeignKey
9from sqlalchemy import func
10from sqlalchemy import INT
11from sqlalchemy import Integer
12from sqlalchemy import intersect
13from sqlalchemy import literal
14from sqlalchemy import literal_column
15from sqlalchemy import MetaData
16from sqlalchemy import not_
17from sqlalchemy import or_
18from sqlalchemy import select
19from sqlalchemy import sql
20from sqlalchemy import String
21from sqlalchemy import testing
22from sqlalchemy import text
23from sqlalchemy import tuple_
24from sqlalchemy import TypeDecorator
25from sqlalchemy import union
26from sqlalchemy import union_all
27from sqlalchemy import VARCHAR
28from sqlalchemy.engine import default
29from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
30from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
31from sqlalchemy.testing import assert_raises_message
32from sqlalchemy.testing import eq_
33from sqlalchemy.testing import fixtures
34from sqlalchemy.testing import is_
35from sqlalchemy.testing.schema import Column
36from sqlalchemy.testing.schema import Table
37from sqlalchemy.testing.util import resolve_lambda
38
39
class QueryTest(fixtures.TablesTest):
    """Backend round-trip tests for labels in ORDER BY, boolean column
    expressions, LIKE / ILIKE operators, text() percent signs, compiled
    execution, bind parameters, NULLS ordering and IN / expanding IN."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define ``users``, ``addresses`` and ``u2``.

        ``addresses.user_id`` references ``users.user_id``.  ``u2`` has the
        same columns as ``users`` but its primary key is declared without
        ``test_needs_autoincrement``.
        """
        Table(
            "users",
            metadata,
            Column(
                "user_id", INT, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_name", VARCHAR(20)),
            test_needs_acid=True,
        )
        Table(
            "addresses",
            metadata,
            Column(
                "address_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("user_id", Integer, ForeignKey("users.user_id")),
            Column("address", String(30)),
            test_needs_acid=True,
        )

        Table(
            "u2",
            metadata,
            Column("user_id", INT, primary_key=True),
            Column("user_name", VARCHAR(20)),
            test_needs_acid=True,
        )

    def test_order_by_label(self, connection):
        """test that a label within an ORDER BY works on each backend.

        This test should be modified to support [ticket:1068] when that ticket
        is implemented.  For now, you need to put the actual string in the
        ORDER BY.

        """

        users = self.tables.users

        connection.execute(
            users.insert(),
            [
                {"user_id": 7, "user_name": "jack"},
                {"user_id": 8, "user_name": "ed"},
                {"user_id": 9, "user_name": "fred"},
            ],
        )

        concat = ("test: " + users.c.user_name).label("thedata")
        eq_(
            connection.execute(select(concat).order_by("thedata")).fetchall(),
            [("test: ed",), ("test: fred",), ("test: jack",)],
        )

        # execute an identical statement a second time; the result must
        # not change
        eq_(
            connection.execute(select(concat).order_by("thedata")).fetchall(),
            [("test: ed",), ("test: fred",), ("test: jack",)],
        )

        # a second, independently constructed label with the same name,
        # this time ordered descending
        concat = ("test: " + users.c.user_name).label("thedata")
        eq_(
            connection.execute(
                select(concat).order_by(desc("thedata"))
            ).fetchall(),
            [("test: jack",), ("test: fred",), ("test: ed",)],
        )

    @testing.requires.order_by_label_with_expression
    def test_order_by_label_compound(self, connection):
        """test ORDER BY an expression that combines a label name
        (via literal_column()) with another value."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            [
                {"user_id": 7, "user_name": "jack"},
                {"user_id": 8, "user_name": "ed"},
                {"user_id": 9, "user_name": "fred"},
            ],
        )

        concat = ("test: " + users.c.user_name).label("thedata")
        eq_(
            connection.execute(
                select(concat).order_by(literal_column("thedata") + "x")
            ).fetchall(),
            [("test: ed",), ("test: fred",), ("test: jack",)],
        )

    @testing.requires.boolean_col_expressions
    def test_or_and_as_columns(self, connection):
        """test and_(), or_() and not_() of boolean literals selected as
        columns, both unlabeled (via scalar()) and labeled."""
        true, false = literal(True), literal(False)

        eq_(connection.execute(select(and_(true, false))).scalar(), False)
        eq_(connection.execute(select(and_(true, true))).scalar(), True)
        eq_(connection.execute(select(or_(true, false))).scalar(), True)
        eq_(connection.execute(select(or_(false, false))).scalar(), False)
        eq_(
            connection.execute(select(not_(or_(false, false)))).scalar(),
            True,
        )

        row = connection.execute(
            select(or_(false, false).label("x"), and_(true, false).label("y"))
        ).first()
        eq_(row.x, False)
        eq_(row.y, False)

        row = connection.execute(
            select(or_(true, false).label("x"), and_(true, false).label("y"))
        ).first()
        eq_(row.x, True)
        eq_(row.y, False)

    def test_select_tuple(self, connection):
        """test that SELECTing a tuple_() raises CompileError."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            {"user_id": 1, "user_name": "apples"},
        )

        assert_raises_message(
            exc.CompileError,
            r"Most backends don't support SELECTing from a tuple\(\) object.",
            connection.execute,
            select(tuple_(users.c.user_id, users.c.user_name)),
        )

    @testing.combinations(
        (
            lambda users: select(users.c.user_id).where(
                users.c.user_name.startswith("apple")
            ),
            [(1,)],
        ),
        (
            lambda users: select(users.c.user_id).where(
                users.c.user_name.contains("i % t")
            ),
            [(5,)],
        ),
        (
            lambda users: select(users.c.user_id).where(
                users.c.user_name.endswith("anas")
            ),
            [(3,)],
        ),
        (
            lambda users: select(users.c.user_id).where(
                users.c.user_name.contains("i % t", escape="&")
            ),
            [(5,)],
        ),
        argnames="expr,result",
    )
    def test_like_ops(self, connection, expr, result):
        """test startswith() / contains() / endswith(), including a search
        string containing '%' with and without an explicit escape char.

        ``expr`` is a lambda taking the ``users`` table; it is resolved
        against the fixture table before execution.
        """
        users = self.tables.users
        connection.execute(
            users.insert(),
            [
                {"user_id": 1, "user_name": "apples"},
                {"user_id": 2, "user_name": "oranges"},
                {"user_id": 3, "user_name": "bananas"},
                {"user_id": 4, "user_name": "legumes"},
                {"user_id": 5, "user_name": "hi % there"},
            ],
        )

        expr = resolve_lambda(expr, users=users)
        eq_(connection.execute(expr).fetchall(), result)

    @testing.requires.mod_operator_as_percent_sign
    @testing.emits_warning(".*now automatically escapes.*")
    def test_percents_in_text(self, connection):
        """test that '%' in text() works both as the modulo operator and
        inside string literals, where it is returned unchanged."""
        for expr, result in (
            (text("select 6 % 10"), 6),
            (text("select 17 % 10"), 7),
            (text("select '%'"), "%"),
            (text("select '%%'"), "%%"),
            (text("select '%%%'"), "%%%"),
            (text("select 'hello % world'"), "hello % world"),
        ):
            eq_(connection.scalar(expr), result)

    def test_ilike(self, connection):
        """test that ilike() matches case-insensitively; on PostgreSQL,
        also check that like() stays case-sensitive."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            [
                {"user_id": 1, "user_name": "one"},
                {"user_id": 2, "user_name": "TwO"},
                {"user_id": 3, "user_name": "ONE"},
                {"user_id": 4, "user_name": "OnE"},
            ],
        )

        eq_(
            connection.execute(
                select(users.c.user_id).where(users.c.user_name.ilike("one"))
            ).fetchall(),
            [(1,), (3,), (4,)],
        )

        eq_(
            connection.execute(
                select(users.c.user_id).where(users.c.user_name.ilike("TWO"))
            ).fetchall(),
            [(2,)],
        )

        if testing.against("postgresql"):
            eq_(
                connection.execute(
                    select(users.c.user_id).where(
                        users.c.user_name.like("one")
                    )
                ).fetchall(),
                [(1,)],
            )
            eq_(
                connection.execute(
                    select(users.c.user_id).where(
                        users.c.user_name.like("TWO")
                    )
                ).fetchall(),
                [],
            )

    def test_compiled_execute(self, connection):
        """test executing a pre-compiled SELECT with a named bindparam."""
        users = self.tables.users
        connection.execute(users.insert(), dict(user_id=7, user_name="jack"))
        s = (
            select(users)
            .where(users.c.user_id == bindparam("id"))
            .compile(connection)
        )
        eq_(connection.execute(s, dict(id=7)).first()._mapping["user_id"], 7)

    def test_compiled_insert_execute(self, connection):
        """as test_compiled_execute, but the INSERT is also executed in
        compiled form."""
        users = self.tables.users
        connection.execute(
            users.insert().compile(connection),
            dict(user_id=7, user_name="jack"),
        )
        s = (
            select(users)
            .where(users.c.user_id == bindparam("id"))
            .compile(connection)
        )
        eq_(connection.execute(s, dict(id=7)).first()._mapping["user_id"], 7)

    def test_repeated_bindparams(self, connection):
        """Tests that a BindParam can be used more than once.

        This should be run for DB-APIs with both positional and named
        paramstyles.
        """
        users = self.tables.users

        connection.execute(users.insert(), dict(user_id=7, user_name="jack"))
        connection.execute(users.insert(), dict(user_id=8, user_name="fred"))

        u = bindparam("userid")
        s = users.select().where(
            and_(users.c.user_name == u, users.c.user_name == u)
        )
        r = connection.execute(s, dict(userid="fred")).fetchall()
        assert len(r) == 1

    def test_bindparam_detection(self):
        """test which ``:name`` tokens in text() compile to "qmark" bind
        parameters and which are left in the SQL as literal text."""
        dialect = default.DefaultDialect(paramstyle="qmark")

        def prep(q):
            return str(sql.text(q).compile(dialect=dialect))

        def a_eq(got, wanted):
            # eq_() reports both values on failure
            eq_(got, wanted)

        a_eq(prep("select foo"), "select foo")
        a_eq(prep("time='12:30:00'"), "time='12:30:00'")
        # a name directly followed by another ':' is not turned into a bind
        a_eq(prep(":this:that"), ":this:that")
        a_eq(prep(":this :that"), "? ?")
        a_eq(prep("(:this),(:that :other)"), "(?),(? ?)")
        a_eq(prep("(:this),(:that:other)"), "(?),(:that:other)")
        a_eq(prep("(:this),(:that,:other)"), "(?),(?,?)")
        a_eq(prep("(:that_:other)"), "(:that_:other)")
        a_eq(prep("(:that_ :other)"), "(? ?)")
        a_eq(prep("(:that_other)"), "(?)")
        a_eq(prep("(:that$other)"), "(?)")
        a_eq(prep("(:that$:other)"), "(:that$:other)")
        a_eq(prep(".:that$ :other."), ".? ?.")

        # backslash-escaped colons
        a_eq(prep(r"select \foo"), r"select \foo")
        a_eq(prep(r"time='12\:30:00'"), r"time='12\:30:00'")
        a_eq(prep(r":this \:that"), "? :that")
        a_eq(prep(r"(\:that$other)"), "(:that$other)")
        a_eq(prep(r".\:that$ :other."), ".:that$ ?.")

    @testing.requires.standalone_binds
    def test_select_from_bindparam(self, connection):
        """Test result row processing when selecting from a plain bind
        param."""

        class MyInteger(TypeDecorator):
            # binds "INT_<n>" as the integer n; returns rows as "INT_<n>"
            impl = Integer
            cache_ok = True

            def process_bind_param(self, value, dialect):
                return int(value[4:])

            def process_result_value(self, value, dialect):
                return "INT_%d" % value

        eq_(
            connection.scalar(select(cast("INT_5", type_=MyInteger))),
            "INT_5",
        )
        eq_(
            connection.scalar(
                select(cast("INT_5", type_=MyInteger).label("foo"))
            ),
            "INT_5",
        )

    def test_order_by(self, connection):
        """Exercises ORDER BY clause generation.

        Tests simple, compound, aliased and DESC clauses.
        """

        users = self.tables.users

        connection.execute(users.insert(), dict(user_id=1, user_name="c"))
        connection.execute(users.insert(), dict(user_id=2, user_name="b"))
        connection.execute(users.insert(), dict(user_id=3, user_name="a"))

        def a_eq(executable, wanted):
            got = list(connection.execute(executable))
            eq_(got, wanted)

        for labels in False, True:
            label_style = (
                LABEL_STYLE_NONE
                if labels is False
                else LABEL_STYLE_TABLENAME_PLUS_COL
            )

            # go() only overrides the statement's label style when
            # labels is True; otherwise the statement is left as built
            def go(stmt):
                if labels:
                    stmt = stmt.set_label_style(label_style)
                return stmt

            a_eq(
                users.select()
                .order_by(users.c.user_id)
                .set_label_style(label_style),
                [(1, "c"), (2, "b"), (3, "a")],
            )

            a_eq(
                users.select()
                .order_by(users.c.user_name, users.c.user_id)
                .set_label_style(label_style),
                [(3, "a"), (2, "b"), (1, "c")],
            )

            a_eq(
                go(
                    select(users.c.user_id.label("foo")).order_by(
                        users.c.user_id
                    )
                ),
                [(1,), (2,), (3,)],
            )

            a_eq(
                go(
                    select(
                        users.c.user_id.label("foo"), users.c.user_name
                    ).order_by(users.c.user_name, users.c.user_id),
                ),
                [(3, "a"), (2, "b"), (1, "c")],
            )

            a_eq(
                users.select()
                .distinct()
                .order_by(users.c.user_id)
                .set_label_style(label_style),
                [(1, "c"), (2, "b"), (3, "a")],
            )

            a_eq(
                go(
                    select(users.c.user_id.label("foo"))
                    .distinct()
                    .order_by(users.c.user_id),
                ),
                [(1,), (2,), (3,)],
            )

            a_eq(
                go(
                    select(
                        users.c.user_id.label("a"),
                        users.c.user_id.label("b"),
                        users.c.user_name,
                    ).order_by(users.c.user_id),
                ),
                [(1, 1, "c"), (2, 2, "b"), (3, 3, "a")],
            )

            a_eq(
                users.select()
                .distinct()
                .order_by(desc(users.c.user_id))
                .set_label_style(label_style),
                [(3, "a"), (2, "b"), (1, "c")],
            )

            a_eq(
                go(
                    select(users.c.user_id.label("foo"))
                    .distinct()
                    .order_by(users.c.user_id.desc()),
                ),
                [(3,), (2,), (1,)],
            )

    @testing.requires.nullsordering
    def test_order_by_nulls(self, connection):
        """Exercises NULLS FIRST / NULLS LAST ordering.

        Tests nulls_first() and nulls_last() applied to a plain column, to
        ascending and descending expressions (method and function forms),
        and as the first of two ORDER BY criteria.
        """

        users = self.tables.users

        # user_id=1 is inserted with a NULL user_name
        connection.execute(users.insert(), dict(user_id=1))
        connection.execute(users.insert(), dict(user_id=2, user_name="b"))
        connection.execute(users.insert(), dict(user_id=3, user_name="a"))

        def a_eq(executable, wanted):
            got = list(connection.execute(executable))
            eq_(got, wanted)

        for labels in False, True:
            label_style = (
                LABEL_STYLE_NONE
                if labels is False
                else LABEL_STYLE_TABLENAME_PLUS_COL
            )
            a_eq(
                users.select()
                .order_by(users.c.user_name.nulls_first())
                .set_label_style(label_style),
                [(1, None), (3, "a"), (2, "b")],
            )

            a_eq(
                users.select()
                .order_by(users.c.user_name.nulls_last())
                .set_label_style(label_style),
                [(3, "a"), (2, "b"), (1, None)],
            )

            a_eq(
                users.select()
                .order_by(asc(users.c.user_name).nulls_first())
                .set_label_style(label_style),
                [(1, None), (3, "a"), (2, "b")],
            )

            a_eq(
                users.select()
                .order_by(asc(users.c.user_name).nulls_last())
                .set_label_style(label_style),
                [(3, "a"), (2, "b"), (1, None)],
            )

            a_eq(
                users.select()
                .order_by(users.c.user_name.desc().nulls_first())
                .set_label_style(label_style),
                [(1, None), (2, "b"), (3, "a")],
            )

            a_eq(
                users.select()
                .order_by(users.c.user_name.desc().nulls_last())
                .set_label_style(label_style),
                [(2, "b"), (3, "a"), (1, None)],
            )

            a_eq(
                users.select()
                .order_by(desc(users.c.user_name).nulls_first())
                .set_label_style(label_style),
                [(1, None), (2, "b"), (3, "a")],
            )

            a_eq(
                users.select()
                .order_by(desc(users.c.user_name).nulls_last())
                .set_label_style(label_style),
                [(2, "b"), (3, "a"), (1, None)],
            )

            a_eq(
                users.select()
                .order_by(
                    users.c.user_name.nulls_first(),
                    users.c.user_id,
                )
                .set_label_style(label_style),
                [(1, None), (3, "a"), (2, "b")],
            )

            a_eq(
                users.select()
                .order_by(users.c.user_name.nulls_last(), users.c.user_id)
                .set_label_style(label_style),
                [(3, "a"), (2, "b"), (1, None)],
            )

    def test_in_filtering(self, connection):
        """test the behavior of the in_() function."""
        users = self.tables.users

        connection.execute(users.insert(), dict(user_id=7, user_name="jack"))
        connection.execute(users.insert(), dict(user_id=8, user_name="fred"))
        connection.execute(users.insert(), dict(user_id=9, user_name=None))

        s = users.select().where(users.c.user_name.in_([]))
        r = connection.execute(s).fetchall()
        # No username is in empty set
        assert len(r) == 0

        # every row, including the NULL user_name, is NOT IN an empty set
        s = users.select().where(not_(users.c.user_name.in_([])))
        r = connection.execute(s).fetchall()
        assert len(r) == 3

        s = users.select().where(users.c.user_name.in_(["jack", "fred"]))
        r = connection.execute(s).fetchall()
        assert len(r) == 2

        s = users.select().where(not_(users.c.user_name.in_(["jack", "fred"])))
        r = connection.execute(s).fetchall()
        # Null values are not outside any set
        assert len(r) == 0

    def test_expanding_in(self, connection):
        """test IN against an expanding bindparam given one, several and
        zero values, and that executemany() with one raises."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="jack"),
                dict(user_id=8, user_name="fred"),
                dict(user_id=9, user_name=None),
            ],
        )

        stmt = (
            select(users)
            .where(users.c.user_name.in_(bindparam("uname", expanding=True)))
            .order_by(users.c.user_id)
        )

        eq_(
            connection.execute(stmt, {"uname": ["jack"]}).fetchall(),
            [(7, "jack")],
        )

        eq_(
            connection.execute(stmt, {"uname": ["jack", "fred"]}).fetchall(),
            [(7, "jack"), (8, "fred")],
        )

        eq_(connection.execute(stmt, {"uname": []}).fetchall(), [])

        assert_raises_message(
            exc.StatementError,
            "'expanding' parameters can't be used with executemany()",
            connection.execute,
            users.update().where(
                users.c.user_name.in_(bindparam("uname", expanding=True))
            ),
            [{"uname": ["fred"]}, {"uname": ["ed"]}],
        )

    @testing.requires.no_quoting_special_bind_names
    def test_expanding_in_special_chars(self, connection):
        """test expanding and plain bindparams whose names contain digits,
        first without and then with a '.' in the name."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="jack"),
                dict(user_id=8, user_name="fred"),
            ],
        )

        stmt = (
            select(users)
            .where(users.c.user_name.in_(bindparam("u35", expanding=True)))
            .where(users.c.user_id == bindparam("u46"))
            .order_by(users.c.user_id)
        )

        eq_(
            connection.execute(
                stmt, {"u35": ["jack", "fred"], "u46": 7}
            ).fetchall(),
            [(7, "jack")],
        )

        stmt = (
            select(users)
            .where(users.c.user_name.in_(bindparam("u.35", expanding=True)))
            .where(users.c.user_id == bindparam("u.46"))
            .order_by(users.c.user_id)
        )

        eq_(
            connection.execute(
                stmt, {"u.35": ["jack", "fred"], "u.46": 7}
            ).fetchall(),
            [(7, "jack")],
        )

    def test_expanding_in_multiple(self, connection):
        """test two independent expanding bindparams in one statement."""
        users = self.tables.users

        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="jack"),
                dict(user_id=8, user_name="fred"),
                dict(user_id=9, user_name="ed"),
            ],
        )

        stmt = (
            select(users)
            .where(users.c.user_name.in_(bindparam("uname", expanding=True)))
            .where(users.c.user_id.in_(bindparam("userid", expanding=True)))
            .order_by(users.c.user_id)
        )

        eq_(
            connection.execute(
                stmt, {"uname": ["jack", "fred", "ed"], "userid": [8, 9]}
            ).fetchall(),
            [(8, "fred"), (9, "ed")],
        )

    def test_expanding_in_repeated(self, connection):
        """test expanding bindparams whose names appear in both halves of
        a UNION."""
        users = self.tables.users

        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="jack"),
                dict(user_id=8, user_name="fred"),
                dict(user_id=9, user_name="ed"),
            ],
        )

        stmt = (
            select(users)
            .where(
                users.c.user_name.in_(bindparam("uname", expanding=True))
                | users.c.user_name.in_(bindparam("uname2", expanding=True))
            )
            .where(users.c.user_id == 8)
        )
        stmt = stmt.union(
            select(users)
            .where(
                users.c.user_name.in_(bindparam("uname", expanding=True))
                | users.c.user_name.in_(bindparam("uname2", expanding=True))
            )
            .where(users.c.user_id == 9)
        ).order_by("user_id")

        eq_(
            connection.execute(
                stmt,
                {
                    "uname": ["jack", "fred"],
                    "uname2": ["ed"],
                    # NOTE(review): "userid" is not referenced by stmt;
                    # confirm whether passing it is intentional
                    "userid": [8, 9],
                },
            ).fetchall(),
            [(8, "fred"), (9, "ed")],
        )

    @testing.requires.tuple_in
    def test_expanding_in_composite(self, connection):
        """test tuple_() IN against an expanding bindparam of tuples."""
        users = self.tables.users

        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="jack"),
                dict(user_id=8, user_name="fred"),
                dict(user_id=9, user_name=None),
            ],
        )

        stmt = (
            select(users)
            .where(
                tuple_(users.c.user_id, users.c.user_name).in_(
                    bindparam("uname", expanding=True)
                )
            )
            .order_by(users.c.user_id)
        )

        eq_(
            connection.execute(stmt, {"uname": [(7, "jack")]}).fetchall(),
            [(7, "jack")],
        )

        eq_(
            connection.execute(
                stmt, {"uname": [(7, "jack"), (8, "fred")]}
            ).fetchall(),
            [(7, "jack"), (8, "fred")],
        )

    def test_expanding_in_dont_alter_compiled(self, connection):
        """test for issue #5048

        Executing a compiled statement with an expanding bindparam must
        leave that compiled object's bind processors unchanged.
        """

        class NameWithProcess(TypeDecorator):
            # strips a 3-character prefix from every bound value
            impl = String
            cache_ok = True

            def process_bind_param(self, value, dialect):
                return value[3:]

        # a second Table object for the existing "users" table, giving
        # user_name the processing type above
        users = Table(
            "users",
            MetaData(),
            Column("user_id", Integer, primary_key=True),
            Column("user_name", NameWithProcess()),
        )

        connection.execute(
            users.insert(),
            [
                dict(user_id=7, user_name="AB jack"),
                dict(user_id=8, user_name="BE fred"),
                dict(user_id=9, user_name="GP ed"),
            ],
        )

        stmt = (
            select(users)
            .where(users.c.user_name.in_(bindparam("uname", expanding=True)))
            .order_by(users.c.user_id)
        )

        compiled = stmt.compile(testing.db)
        eq_(len(compiled._bind_processors), 1)

        eq_(
            connection.execute(
                compiled, {"uname": ["HJ jack", "RR fred"]}
            ).fetchall(),
            [(7, "jack"), (8, "fred")],
        )

        eq_(len(compiled._bind_processors), 1)

    @testing.skip_if(["mssql"])
    def test_bind_in(self, connection):
        """test calling IN against a bind parameter.

        this isn't allowed on several platforms since we
        generate ? = ?.

        """

        users = self.tables.users

        connection.execute(users.insert(), dict(user_id=7, user_name="jack"))
        connection.execute(users.insert(), dict(user_id=8, user_name="fred"))
        connection.execute(users.insert(), dict(user_id=9, user_name=None))

        u = bindparam("search_key", type_=String)

        s = users.select().where(not_(u.in_([])))
        r = connection.execute(s, dict(search_key="john")).fetchall()
        assert len(r) == 3
        r = connection.execute(s, dict(search_key=None)).fetchall()
        assert len(r) == 3

    def test_literal_in(self, connection):
        """similar to test_bind_in but use a bind with a value."""

        users = self.tables.users

        connection.execute(users.insert(), dict(user_id=7, user_name="jack"))
        connection.execute(users.insert(), dict(user_id=8, user_name="fred"))
        connection.execute(users.insert(), dict(user_id=9, user_name=None))

        s = users.select().where(not_(literal("john").in_([])))
        r = connection.execute(s).fetchall()
        assert len(r) == 3

    @testing.requires.boolean_col_expressions
    def test_empty_in_filtering_static(self, connection):
        """test the behavior of the in_() function when
        comparing against an empty collection, specifically
        that a proper boolean value is generated.

        """
        users = self.tables.users

        connection.execute(
            users.insert(),
            [
                {"user_id": 7, "user_name": "jack"},
                {"user_id": 8, "user_name": "ed"},
                {"user_id": 9, "user_name": None},
            ],
        )

        # "== True" / "== False" / "== None" build SQL comparisons here, so
        # they cannot be rewritten with "is"
        s = users.select().where(users.c.user_name.in_([]) == True)  # noqa
        r = connection.execute(s).fetchall()
        assert len(r) == 0
        s = users.select().where(users.c.user_name.in_([]) == False)  # noqa
        r = connection.execute(s).fetchall()
        assert len(r) == 3
        s = users.select().where(users.c.user_name.in_([]) == None)  # noqa
        r = connection.execute(s).fetchall()
        assert len(r) == 0
888
889
class RequiredBindTest(fixtures.TablesTest):
    """Executing a statement that lacks a value for the required bind
    parameter ``x`` must raise StatementError; also checks how the
    ``required`` flag of bindparam() is derived from its arguments."""

    run_create_tables = None
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        """Declare the ``foo`` table with ``id``, ``data`` and ``x``."""
        foo_columns = (
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("x", Integer),
        )
        Table("foo", metadata, *foo_columns)

    def _assert_raises(self, stmt, params):
        """Execute *stmt* with *params* on a fresh connection and expect
        the "missing value for 'x'" StatementError."""
        expected_message = "A value is required for bind parameter 'x'"
        with testing.db.connect() as conn:
            assert_raises_message(
                exc.StatementError,
                expected_message,
                conn.execute,
                stmt,
                params,
            )

    def test_insert(self):
        foo = self.tables.foo
        values_stmt = foo.insert().values(
            x=bindparam("x"), data=bindparam("data")
        )
        self._assert_raises(values_stmt, {"data": "data"})

    def test_select_where(self):
        foo = self.tables.foo
        by_data = foo.c.data == bindparam("data")
        by_x = foo.c.x == bindparam("x")
        self._assert_raises(
            select(foo).where(by_data).where(by_x), {"data": "data"}
        )

    @testing.requires.standalone_binds
    def test_select_columns(self):
        both_binds = select(bindparam("data"), bindparam("x"))
        self._assert_raises(both_binds, {"data": "data"})

    def test_text(self):
        raw_sql = "select * from foo where x=:x and data=:data1"
        self._assert_raises(text(raw_sql), {"data1": "data"})

    def test_required_flag(self):
        # a bindparam is required only when it has neither a value nor a
        # callable_, unless ``required`` is passed explicitly
        def c():
            return None

        expectations = (
            (bindparam("foo"), True),
            (bindparam("foo", required=False), False),
            (bindparam("foo", "bar"), False),
            (bindparam("foo", "bar", required=True), True),
            (bindparam("foo", callable_=c, required=True), True),
            (bindparam("foo", callable_=c), False),
            (bindparam("foo", callable_=c, required=False), False),
        )
        for param, expected in expectations:
            is_(param.required, expected)
949
950
class LimitTest(fixtures.TablesTest):
    """LIMIT, OFFSET and DISTINCT interactions.

    ``insert_data`` loads seven users with one address each.  The address
    column holds five distinct values, because "addr1" and "addr5" each
    appear twice.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("user_id", INT, primary_key=True),
            Column("user_name", VARCHAR(20)),
        )
        Table(
            "addresses",
            metadata,
            Column("address_id", Integer, primary_key=True),
            Column("user_id", Integer, ForeignKey("users.user_id")),
            Column("address", String(30)),
        )

    @classmethod
    def insert_data(cls, connection):
        users, addresses = cls.tables("users", "addresses")
        # (user_id, user_name, address); every user owns exactly one address
        # row whose address_id equals its user_id.
        rows = [
            (1, "john", "addr1"),
            (2, "jack", "addr1"),
            (3, "ed", "addr2"),
            (4, "wendy", "addr3"),
            (5, "laura", "addr4"),
            (6, "ralph", "addr5"),
            (7, "fido", "addr5"),
        ]
        # insert all users first so every addresses.user_id FK resolves
        connection.execute(
            users.insert(),
            [dict(user_id=uid, user_name=name) for uid, name, _ in rows],
        )
        connection.execute(
            addresses.insert(),
            [
                dict(address_id=uid, user_id=uid, address=address)
                for uid, _, address in rows
            ],
        )

    def test_select_limit(self, connection):
        users = self.tables.users
        r = connection.execute(
            users.select().limit(3).order_by(users.c.user_id)
        ).fetchall()
        eq_(r, [(1, "john"), (2, "jack"), (3, "ed")])

    @testing.requires.offset
    def test_select_limit_offset(self, connection):
        """Test the interaction between limit and offset"""

        users = self.tables.users

        r = connection.execute(
            users.select().limit(3).offset(2).order_by(users.c.user_id)
        ).fetchall()
        eq_(r, [(3, "ed"), (4, "wendy"), (5, "laura")])
        r = connection.execute(
            users.select().offset(5).order_by(users.c.user_id)
        ).fetchall()
        eq_(r, [(6, "ralph"), (7, "fido")])

    def test_select_distinct_limit(self, connection):
        """Test the interaction between limit and distinct"""

        addresses = self.tables.addresses

        r = sorted(
            row[0]
            for row in connection.execute(
                select(addresses.c.address).distinct().limit(3)
            )
        )
        eq_(len(r), 3, repr(r))
        # every returned address must be unique
        eq_(len(set(r)), len(r), repr(r))

    @testing.requires.offset
    def test_select_distinct_offset(self, connection):
        """Test the interaction between distinct and offset"""

        addresses = self.tables.addresses

        r = sorted(
            row[0]
            for row in connection.execute(
                select(addresses.c.address)
                .distinct()
                .offset(1)
                .order_by(addresses.c.address)
            ).fetchall()
        )
        # five distinct addresses, minus the one skipped by OFFSET 1
        eq_(len(r), 4)
        # all four rows must be unique.  The previous adjacent-pair check
        # compared r[2] against the list [3], which is always true, so it
        # never actually checked the fourth row.
        eq_(len(set(r)), len(r), repr(r))

    @testing.requires.offset
    def test_select_distinct_limit_offset(self, connection):
        """Test the interaction between distinct and limit/offset"""

        addresses = self.tables.addresses

        r = connection.execute(
            select(addresses.c.address)
            .order_by(addresses.c.address)
            .distinct()
            .offset(2)
            .limit(3)
        ).fetchall()
        eq_(len(r), 3, repr(r))
        self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
1076
1077
class CompoundTest(fixtures.TablesTest):

    """test compound statements like UNION, INTERSECT, particularly their
    ability to nest on different databases."""

    __backend__ = True

    run_inserts = "each"

    # (col3, col4) rows of each table as loaded by insert_data():
    #   t1: (aaa, aaa), (bbb, bbb), (ccc, ccc)
    #   t2: (aaa, bbb), (bbb, ccc), (ccc, aaa)
    #   t3: (aaa, ccc), (bbb, aaa), (ccc, bbb)
    # The expectations below are derived from those rows.  They are only
    # compared against, never mutated.

    # first two rows of t1 UNION last two rows of t2 (see _union_selects)
    _union_wanted = [
        ("aaa", "aaa"),
        ("bbb", "bbb"),
        ("bbb", "ccc"),
        ("ccc", "aaa"),
    ]

    # every (col3, col4) row of t2
    _t2_rows = [("aaa", "bbb"), ("bbb", "ccc"), ("ccc", "aaa")]

    # every (col3, col4) row of t3, which equals (t1 | t3) & (t2 | t3)
    _t3_rows = [("aaa", "ccc"), ("bbb", "aaa"), ("ccc", "bbb")]

    # (t1 | t2 | t3) - t2, i.e. the rows of t1 and t3
    _except_t2_wanted = [
        ("aaa", "aaa"),
        ("aaa", "ccc"),
        ("bbb", "aaa"),
        ("bbb", "bbb"),
        ("ccc", "bbb"),
        ("ccc", "ccc"),
    ]

    @classmethod
    def define_tables(cls, metadata):
        # t1, t2 and t3 share an identical layout
        for name in ("t1", "t2", "t3"):
            Table(
                name,
                metadata,
                Column(
                    "col1",
                    Integer,
                    test_needs_autoincrement=True,
                    primary_key=True,
                ),
                Column("col2", String(30)),
                Column("col3", String(40)),
                Column("col4", String(30)),
            )

    @classmethod
    def insert_data(cls, connection):
        t1, t2, t3 = cls.tables("t1", "t2", "t3")
        conn = connection
        conn.execute(
            t1.insert(),
            [
                dict(col2="t1col2r1", col3="aaa", col4="aaa"),
                dict(col2="t1col2r2", col3="bbb", col4="bbb"),
                dict(col2="t1col2r3", col3="ccc", col4="ccc"),
            ],
        )
        conn.execute(
            t2.insert(),
            [
                dict(col2="t2col2r1", col3="aaa", col4="bbb"),
                dict(col2="t2col2r2", col3="bbb", col4="ccc"),
                dict(col2="t2col2r3", col3="ccc", col4="aaa"),
            ],
        )
        conn.execute(
            t3.insert(),
            [
                dict(col2="t3col2r1", col3="aaa", col4="ccc"),
                dict(col2="t3col2r2", col3="bbb", col4="aaa"),
                dict(col2="t3col2r3", col3="ccc", col4="bbb"),
            ],
        )

    def _fetchall_sorted(self, executed):
        return sorted([tuple(row) for row in executed.fetchall()])

    def _union_selects(self):
        """Return the (t1, t2) SELECT pair combined by the test_union* tests.

        Both select col3/col4 labeled as such.  The t1 side is restricted to
        its first two rows and the t2 side to its last two.
        """
        t1, t2 = self.tables("t1", "t2")
        return (
            select(t1.c.col3.label("col3"), t1.c.col4.label("col4")).where(
                t1.c.col2.in_(["t1col2r1", "t1col2r2"]),
            ),
            select(t2.c.col3.label("col3"), t2.c.col4.label("col4")).where(
                t2.c.col2.in_(["t2col2r2", "t2col2r3"]),
            ),
        )

    @testing.requires.subqueries
    def test_union(self, connection):
        u = union(*self._union_selects())

        found1 = self._fetchall_sorted(connection.execute(u))
        eq_(found1, self._union_wanted)

        found2 = self._fetchall_sorted(
            connection.execute(u.alias("bar").select())
        )
        eq_(found2, self._union_wanted)

    def test_union_ordered(self, connection):
        u = union(*self._union_selects()).order_by("col3", "col4")

        eq_(connection.execute(u).fetchall(), self._union_wanted)

    @testing.requires.subqueries
    def test_union_ordered_alias(self, connection):
        u = union(*self._union_selects()).order_by("col3", "col4")

        eq_(
            connection.execute(u.alias("bar").select()).fetchall(),
            self._union_wanted,
        )

    @testing.crashes("oracle", "FIXME: unknown, verify not fails_on")
    @testing.fails_on(
        testing.requires._mysql_not_mariadb_104, "FIXME: unknown"
    )
    @testing.fails_on("sqlite", "FIXME: unknown")
    def test_union_all(self, connection):
        t1 = self.tables.t1

        e = union_all(
            select(t1.c.col3),
            union(select(t1.c.col3), select(t1.c.col3)),
        )

        wanted = [("aaa",), ("aaa",), ("bbb",), ("bbb",), ("ccc",), ("ccc",)]
        found1 = self._fetchall_sorted(connection.execute(e))
        eq_(found1, wanted)

        found2 = self._fetchall_sorted(
            connection.execute(e.alias("foo").select())
        )
        eq_(found2, wanted)

    def test_union_all_lightweight(self, connection):
        """like test_union_all, but breaks the sub-union into
        a subquery with an explicit column reference on the outside,
        more palatable to a wider variety of engines.

        """

        t1 = self.tables.t1

        u = union(select(t1.c.col3), select(t1.c.col3)).alias()

        e = union_all(select(t1.c.col3), select(u.c.col3))

        wanted = [("aaa",), ("aaa",), ("bbb",), ("bbb",), ("ccc",), ("ccc",)]
        found1 = self._fetchall_sorted(connection.execute(e))
        eq_(found1, wanted)

        found2 = self._fetchall_sorted(
            connection.execute(e.alias("foo").select())
        )
        eq_(found2, wanted)

    @testing.requires.intersect
    def test_intersect(self, connection):
        t2, t3 = self.tables("t2", "t3")

        i = intersect(
            select(t2.c.col3, t2.c.col4),
            select(t2.c.col3, t2.c.col4).where(t2.c.col4 == t3.c.col3),
        )

        found1 = self._fetchall_sorted(connection.execute(i))
        eq_(found1, self._t2_rows)

        found2 = self._fetchall_sorted(
            connection.execute(i.alias("bar").select())
        )
        eq_(found2, self._t2_rows)

    @testing.requires.except_
    @testing.fails_on("sqlite", "Can't handle this style of nesting")
    def test_except_style1(self, connection):
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        e = except_(
            union(
                select(t1.c.col3, t1.c.col4),
                select(t2.c.col3, t2.c.col4),
                select(t3.c.col3, t3.c.col4),
            ),
            select(t2.c.col3, t2.c.col4),
        )

        found = self._fetchall_sorted(connection.execute(e.alias().select()))
        eq_(found, self._except_t2_wanted)

    @testing.requires.except_
    def test_except_style2(self, connection):
        # same as style1, but add alias().select() to the except_().
        # sqlite can handle it now.

        t1, t2, t3 = self.tables("t1", "t2", "t3")

        e = except_(
            union(
                select(t1.c.col3, t1.c.col4),
                select(t2.c.col3, t2.c.col4),
                select(t3.c.col3, t3.c.col4),
            )
            .alias()
            .select(),
            select(t2.c.col3, t2.c.col4),
        )

        found1 = self._fetchall_sorted(connection.execute(e))
        eq_(found1, self._except_t2_wanted)

        found2 = self._fetchall_sorted(connection.execute(e.alias().select()))
        eq_(found2, self._except_t2_wanted)

    @testing.fails_on(
        ["sqlite", testing.requires._mysql_not_mariadb_104],
        "Can't handle this style of nesting",
    )
    @testing.requires.except_
    def test_except_style3(self, connection):
        # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        e = except_(
            select(t1.c.col3),  # aaa, bbb, ccc
            except_(
                select(t2.c.col3),  # aaa, bbb, ccc
                select(t3.c.col3).where(t3.c.col3 == "ccc"),  # ccc
            ),
        )
        eq_(connection.execute(e).fetchall(), [("ccc",)])
        eq_(connection.execute(e.alias("foo").select()).fetchall(), [("ccc",)])

    @testing.requires.except_
    def test_except_style4(self, connection):
        # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        e = except_(
            select(t1.c.col3),  # aaa, bbb, ccc
            except_(
                select(t2.c.col3),  # aaa, bbb, ccc
                select(t3.c.col3).where(t3.c.col3 == "ccc"),  # ccc
            )
            .alias()
            .select(),
        )

        eq_(connection.execute(e).fetchall(), [("ccc",)])
        eq_(connection.execute(e.alias().select()).fetchall(), [("ccc",)])

    @testing.requires.intersect
    @testing.fails_on(
        ["sqlite", testing.requires._mysql_not_mariadb_104],
        "sqlite can't handle leading parenthesis",
    )
    def test_intersect_unions(self, connection):
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        u = intersect(
            union(select(t1.c.col3, t1.c.col4), select(t3.c.col3, t3.c.col4)),
            union(select(t2.c.col3, t2.c.col4), select(t3.c.col3, t3.c.col4))
            .alias()
            .select(),
        )
        found = self._fetchall_sorted(connection.execute(u))

        eq_(found, self._t3_rows)

    @testing.requires.intersect
    def test_intersect_unions_2(self, connection):
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        u = intersect(
            union(select(t1.c.col3, t1.c.col4), select(t3.c.col3, t3.c.col4))
            .alias()
            .select(),
            union(select(t2.c.col3, t2.c.col4), select(t3.c.col3, t3.c.col4))
            .alias()
            .select(),
        )
        found = self._fetchall_sorted(connection.execute(u))

        eq_(found, self._t3_rows)

    @testing.requires.intersect
    def test_intersect_unions_3(self, connection):
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        u = intersect(
            select(t2.c.col3, t2.c.col4),
            union(
                select(t1.c.col3, t1.c.col4),
                select(t2.c.col3, t2.c.col4),
                select(t3.c.col3, t3.c.col4),
            )
            .alias()
            .select(),
        )
        found = self._fetchall_sorted(connection.execute(u))

        eq_(found, self._t2_rows)

    @testing.requires.intersect
    def test_composite_alias(self, connection):
        t1, t2, t3 = self.tables("t1", "t2", "t3")

        ua = intersect(
            select(t2.c.col3, t2.c.col4),
            union(
                select(t1.c.col3, t1.c.col4),
                select(t2.c.col3, t2.c.col4),
                select(t3.c.col3, t3.c.col4),
            )
            .alias()
            .select(),
        ).alias()

        found = self._fetchall_sorted(connection.execute(ua.select()))
        eq_(found, self._t2_rows)
1463
1464
1465class JoinTest(fixtures.TablesTest):
1466
1467    """Tests join execution.
1468
1469    The compiled SQL emitted by the dialect might be ANSI joins or
1470    theta joins ('old oracle style', with (+) for OUTER).  This test
1471    tries to exercise join syntax and uncover any inconsistencies in
1472    `JOIN rhs ON lhs.col=rhs.col` vs `rhs.col=lhs.col`.  At least one
1473    database seems to be sensitive to this.
1474    """
1475
1476    __backend__ = True
1477
1478    @classmethod
1479    def define_tables(cls, metadata):
1480        Table(
1481            "t1",
1482            metadata,
1483            Column("t1_id", Integer, primary_key=True),
1484            Column("name", String(32)),
1485        )
1486        Table(
1487            "t2",
1488            metadata,
1489            Column("t2_id", Integer, primary_key=True),
1490            Column("t1_id", Integer, ForeignKey("t1.t1_id")),
1491            Column("name", String(32)),
1492        )
1493        Table(
1494            "t3",
1495            metadata,
1496            Column("t3_id", Integer, primary_key=True),
1497            Column("t2_id", Integer, ForeignKey("t2.t2_id")),
1498            Column("name", String(32)),
1499        )
1500
1501    @classmethod
1502    def insert_data(cls, connection):
1503        conn = connection
1504        # t1.10 -> t2.20 -> t3.30
1505        # t1.11 -> t2.21
1506        # t1.12
1507        t1, t2, t3 = cls.tables("t1", "t2", "t3")
1508
1509        conn.execute(
1510            t1.insert(),
1511            [
1512                {"t1_id": 10, "name": "t1 #10"},
1513                {"t1_id": 11, "name": "t1 #11"},
1514                {"t1_id": 12, "name": "t1 #12"},
1515            ],
1516        )
1517        conn.execute(
1518            t2.insert(),
1519            [
1520                {"t2_id": 20, "t1_id": 10, "name": "t2 #20"},
1521                {"t2_id": 21, "t1_id": 11, "name": "t2 #21"},
1522            ],
1523        )
1524        conn.execute(
1525            t3.insert(), [{"t3_id": 30, "t2_id": 20, "name": "t3 #30"}]
1526        )
1527
1528    def assertRows(self, statement, expected):
1529        """Execute a statement and assert that rows returned equal expected."""
1530        with testing.db.connect() as conn:
1531            found = sorted(
1532                [tuple(row) for row in conn.execute(statement).fetchall()]
1533            )
1534            eq_(found, sorted(expected))
1535
1536    def test_join_x1(self):
1537        """Joins t1->t2."""
1538        t1, t2, t3 = self.tables("t1", "t2", "t3")
1539
1540        for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id):
1541            expr = select(t1.c.t1_id, t2.c.t2_id).select_from(
1542                t1.join(t2, criteria)
1543            )
1544            self.assertRows(expr, [(10, 20), (11, 21)])
1545
1546    def test_join_x2(self):
1547        """Joins t1->t2->t3."""
1548        t1, t2, t3 = self.tables("t1", "t2", "t3")
1549
1550        for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id):
1551            expr = select(t1.c.t1_id, t2.c.t2_id).select_from(
1552                t1.join(t2, criteria)
1553            )
1554            self.assertRows(expr, [(10, 20), (11, 21)])
1555
1556    def test_outerjoin_x1(self):
1557        """Outer joins t1->t2."""
1558        t1, t2, t3 = self.tables("t1", "t2", "t3")
1559
1560        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1561            expr = select(t1.c.t1_id, t2.c.t2_id).select_from(
1562                t1.join(t2).join(t3, criteria)
1563            )
1564            self.assertRows(expr, [(10, 20)])
1565
1566    def test_outerjoin_x2(self):
1567        """Outer joins t1->t2,t3."""
1568        t1, t2, t3 = self.tables("t1", "t2", "t3")
1569
1570        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1571            expr = select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id).select_from(
1572                t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1573                    t3, criteria
1574                )
1575            )
1576            self.assertRows(
1577                expr, [(10, 20, 30), (11, 21, None), (12, None, None)]
1578            )
1579
1580    def test_outerjoin_where_x2_t1(self):
1581        """Outer joins t1->t2,t3, where on t1."""
1582        t1, t2, t3 = self.tables("t1", "t2", "t3")
1583
1584        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1585            expr = (
1586                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1587                .where(t1.c.name == "t1 #10")
1588                .select_from(
1589                    (
1590                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1591                            t3, criteria
1592                        )
1593                    )
1594                )
1595            )
1596            self.assertRows(expr, [(10, 20, 30)])
1597
1598            expr = (
1599                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1600                .where(t1.c.t1_id < 12)
1601                .select_from(
1602                    (
1603                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1604                            t3, criteria
1605                        )
1606                    )
1607                )
1608            )
1609            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
1610
1611    def test_outerjoin_where_x2_t2(self):
1612        """Outer joins t1->t2,t3, where on t2."""
1613        t1, t2, t3 = self.tables("t1", "t2", "t3")
1614
1615        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1616            expr = (
1617                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1618                .where(t2.c.name == "t2 #20")
1619                .select_from(
1620                    (
1621                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1622                            t3, criteria
1623                        )
1624                    )
1625                )
1626            )
1627            self.assertRows(expr, [(10, 20, 30)])
1628
1629            expr = (
1630                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1631                .where(t2.c.t2_id < 29)
1632                .select_from(
1633                    (
1634                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1635                            t3, criteria
1636                        )
1637                    )
1638                )
1639            )
1640            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
1641
1642    def test_outerjoin_where_x2_t3(self):
1643        """Outer joins t1->t2,t3, where on t3."""
1644        t1, t2, t3 = self.tables("t1", "t2", "t3")
1645
1646        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1647            expr = (
1648                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1649                .where(t3.c.name == "t3 #30")
1650                .select_from(
1651                    (
1652                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1653                            t3, criteria
1654                        )
1655                    )
1656                )
1657            )
1658            self.assertRows(expr, [(10, 20, 30)])
1659
1660            expr = (
1661                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1662                .where(t3.c.t3_id < 39)
1663                .select_from(
1664                    (
1665                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1666                            t3, criteria
1667                        )
1668                    )
1669                )
1670            )
1671            self.assertRows(expr, [(10, 20, 30)])
1672
1673    def test_outerjoin_where_x2_t1t3(self):
1674        """Outer joins t1->t2,t3, where on t1 and t3."""
1675
1676        t1, t2, t3 = self.tables("t1", "t2", "t3")
1677
1678        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1679            expr = (
1680                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1681                .where(and_(t1.c.name == "t1 #10", t3.c.name == "t3 #30"))
1682                .select_from(
1683                    t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1684                        t3, criteria
1685                    )
1686                )
1687            )
1688
1689            self.assertRows(expr, [(10, 20, 30)])
1690
1691            expr = (
1692                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1693                .where(and_(t1.c.t1_id < 19, t3.c.t3_id < 39))
1694                .select_from(
1695                    (
1696                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1697                            t3, criteria
1698                        )
1699                    )
1700                )
1701            )
1702            self.assertRows(expr, [(10, 20, 30)])
1703
1704    def test_outerjoin_where_x2_t1t2(self):
1705        """Outer joins t1->t2,t3, where on t1 and t2."""
1706
1707        t1, t2, t3 = self.tables("t1", "t2", "t3")
1708
1709        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1710            expr = (
1711                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1712                .where(and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20"))
1713                .select_from(
1714                    (
1715                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1716                            t3, criteria
1717                        )
1718                    )
1719                )
1720            )
1721            self.assertRows(expr, [(10, 20, 30)])
1722
1723            expr = (
1724                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1725                .where(and_(t1.c.t1_id < 12, t2.c.t2_id < 39))
1726                .select_from(
1727                    (
1728                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1729                            t3, criteria
1730                        )
1731                    )
1732                )
1733            )
1734            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
1735
1736    def test_outerjoin_where_x2_t1t2t3(self):
1737        """Outer joins t1->t2,t3, where on t1, t2 and t3."""
1738        t1, t2, t3 = self.tables("t1", "t2", "t3")
1739
1740        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1741            expr = (
1742                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1743                .where(
1744                    and_(
1745                        t1.c.name == "t1 #10",
1746                        t2.c.name == "t2 #20",
1747                        t3.c.name == "t3 #30",
1748                    )
1749                )
1750                .select_from(
1751                    (
1752                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1753                            t3, criteria
1754                        )
1755                    )
1756                )
1757            )
1758            self.assertRows(expr, [(10, 20, 30)])
1759
1760            expr = (
1761                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1762                .where(and_(t1.c.t1_id < 19, t2.c.t2_id < 29, t3.c.t3_id < 39))
1763                .select_from(
1764                    (
1765                        t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin(
1766                            t3, criteria
1767                        )
1768                    )
1769                )
1770            )
1771            self.assertRows(expr, [(10, 20, 30)])
1772
1773    def test_mixed(self):
1774        """Joins t1->t2, outer t2->t3."""
1775        t1, t2, t3 = self.tables("t1", "t2", "t3")
1776
1777        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1778            expr = select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id).select_from(
1779                (t1.join(t2).outerjoin(t3, criteria)),
1780            )
1781            print(expr)
1782            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
1783
1784    def test_mixed_where(self):
1785        """Joins t1->t2, outer t2->t3, plus a where on each table in turn."""
1786        t1, t2, t3 = self.tables("t1", "t2", "t3")
1787
1788        for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id):
1789            expr = (
1790                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1791                .where(
1792                    t1.c.name == "t1 #10",
1793                )
1794                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1795            )
1796            self.assertRows(expr, [(10, 20, 30)])
1797
1798            expr = (
1799                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1800                .where(
1801                    t2.c.name == "t2 #20",
1802                )
1803                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1804            )
1805            self.assertRows(expr, [(10, 20, 30)])
1806
1807            expr = (
1808                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1809                .where(
1810                    t3.c.name == "t3 #30",
1811                )
1812                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1813            )
1814            self.assertRows(expr, [(10, 20, 30)])
1815
1816            expr = (
1817                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1818                .where(
1819                    and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20"),
1820                )
1821                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1822            )
1823            self.assertRows(expr, [(10, 20, 30)])
1824
1825            expr = (
1826                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1827                .where(
1828                    and_(t2.c.name == "t2 #20", t3.c.name == "t3 #30"),
1829                )
1830                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1831            )
1832            self.assertRows(expr, [(10, 20, 30)])
1833
1834            expr = (
1835                select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id)
1836                .where(
1837                    and_(
1838                        t1.c.name == "t1 #10",
1839                        t2.c.name == "t2 #20",
1840                        t3.c.name == "t3 #30",
1841                    ),
1842                )
1843                .select_from((t1.join(t2).outerjoin(t3, criteria)))
1844            )
1845            self.assertRows(expr, [(10, 20, 30)])
1846
1847
class OperatorTest(fixtures.TablesTest):
    """Round-trip tests of SQL operators and functions on the ``flds`` table."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define ``flds``: an autoincrement PK plus an int and a string column."""
        Table(
            "flds",
            metadata,
            Column(
                "idcol",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("intcol", Integer),
            Column("strcol", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        """Insert two rows: (intcol=5, strcol="foo") and (13, "bar")."""
        flds = cls.tables.flds
        connection.execute(
            flds.insert(),
            [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")],
        )

    # TODO: seems like more tests warranted for this setup.
    def test_modulo(self, connection):
        """The ``%`` operator on an integer column computes a SQL modulo."""
        flds = self.tables.flds

        # 5 % 3 == 2 and 13 % 3 == 1; assumes idcol follows insertion order.
        eq_(
            connection.execute(
                select(flds.c.intcol % 3).order_by(flds.c.idcol)
            ).fetchall(),
            [(2,), (1,)],
        )

    @testing.requires.window_functions
    def test_over(self, connection):
        """``row_number().over(order_by=...)`` numbers rows by ``strcol``."""
        flds = self.tables.flds

        # The window's ORDER BY only decides the numbering; the result rows
        # need their own ORDER BY or the row order is backend-dependent.
        # Ordering by strcol too gives "bar" (13) first, then "foo" (5).
        eq_(
            connection.execute(
                select(
                    flds.c.intcol,
                    func.row_number().over(order_by=flds.c.strcol),
                ).order_by(flds.c.strcol)
            ).fetchall(),
            [(13, 1), (5, 2)],
        )
1898