Payment Splitter

Split payments automatically with a smart contract in Vyper

Imagine you and your friends build a business and want to split the revenue automatically. Let's say six people join this venture. You get 5%, and each of your five partners gets 19%. So 5% + 19% + 19% + 19% + 19% + 19% = 100%. Perfectly balanced, as it should be. If a buyer sends 100 ETH to purchase one token from your smart contract, the spoils are split into 5 ETH for you and 19 ETH for each of your friends.
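To double-check the numbers, here is the same split redone in plain Python. It is only a sketch that mirrors the contract's integer arithmetic in wei (amount * shares / total shares); the variable names are illustrative and not part of the contract.

```python
# Mirrors the contract's integer math in wei: amount * shares / total_shares.
WEI_PER_ETH = 10**18

shares = [19, 19, 19, 19, 19, 5]  # five partners at 19%, you at 5%
total_shares = sum(shares)        # 100

payment = 100 * WEI_PER_ETH       # a buyer sends 100 ETH
amounts = [payment * s // total_shares for s in shares]

for s, amount in zip(shares, amounts):
    print(f"{s} shares -> {amount / WEI_PER_ETH} ETH")
```

Working in wei keeps everything in integers, which is also how the contract sees the amounts.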

Let's create a smart contract to do that.


(.venv) $ mkdir payment_splitter
(.venv) $ cd payment_splitter
(.venv) $ mamba init

Create contracts/PaymentSplitter.vy. Add the following code to the file:


"""
@title A Payment Splitter design pattern
@license MIT
@author Arjuna Sky Kok
@notice Translated from OpenZeppelin's PaymentSplitter code: https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/finance/PaymentSplitter.sol
"""

_totalShares: uint256
_totalReleased: uint256

_shares: HashMap[address, uint256]
_released: HashMap[address, uint256]
_payees: address[6]

event PayeeAdded:
    account: address
    shares: uint256

event PaymentReleased:
    to: address
    amount: uint256

event PaymentReceived:
    sender: address
    amount: uint256


@external
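def __init__():
    # Hard-coded payees. These are the first six eth-tester accounts used in
    # the tests; replace them with your partners' addresses.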
    self._payees[0] = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf
    self._payees[1] = 0x2B5AD5c4795c026514f8317c7a215E218DcCD6cF
    self._payees[2] = 0x6813Eb9362372EEF6200f3b1dbC3f819671cBA69
    self._payees[3] = 0x1efF47bc3a10a45D4B230B5d10E37751FE6AA718
    self._payees[4] = 0xe1AB8145F7E55DC933d51a18c793F901A3A0b276
    self._payees[5] = 0xE57bFE9F44b819898F47BF37E5AF72a0783e1141
    self._shares[self._payees[0]] = 19
    self._shares[self._payees[1]] = 19
    self._shares[self._payees[2]] = 19
    self._shares[self._payees[3]] = 19
    self._shares[self._payees[4]] = 19
    self._shares[self._payees[5]] = 5
    # Must equal the sum of the shares above, or some ETH gets stuck.
    self._totalShares = 100

    log PayeeAdded(self._payees[0], self._shares[self._payees[0]])
    log PayeeAdded(self._payees[1], self._shares[self._payees[1]])
    log PayeeAdded(self._payees[2], self._shares[self._payees[2]])
    log PayeeAdded(self._payees[3], self._shares[self._payees[3]])
    log PayeeAdded(self._payees[4], self._shares[self._payees[4]])
    log PayeeAdded(self._payees[5], self._shares[self._payees[5]])


@external
@payable
def __default__():
    log PaymentReceived(msg.sender, msg.value)


@external
@view
def totalShares() -> uint256:
    return self._totalShares


@external
@view
def totalReleased() -> uint256:
    return self._totalReleased


@external
@view
def shares(account: address) -> uint256:
    return self._shares[account]


@external
@view
def released(account: address) -> uint256:
    return self._released[account]


@external
@view
def payee(index: uint256) -> address:
    return self._payees[index]


@external
def release(account: address):
    assert self._shares[account] > 0, "Account has no shares"

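    # Each release lowers the balance, so add back what was already released.
    totalReceived: uint256 = self.balance + self._totalReleased
    # This account's cut of everything received, minus what it already withdrew.
    payment: uint256 = totalReceived * self._shares[account] / self._totalShares - self._released[account]

    assert payment != 0, "Account is not due payment"

    self._released[account] = self._released[account] + payment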
    self._totalReleased = self._totalReleased + payment

    send(account, payment)
    log PaymentReleased(account, payment)

At the time of writing, Vyper does not support dynamic array parameters yet. Also, a bug prevents you from calling another function inside the __init__ function. So you have to hard-code the agreement on how you want to split the payments. Inside the __init__ function, you initialize the payees' addresses in _payees and set their shares in _shares. The _shares mapping decides how the payments are split. But don't forget to set the total shares in _totalShares. You put 100 in _totalShares to make the numbers round and look like percentages: if your share is 5, your percentage of the spoils is 5% (5 / 100).

Make sure the shares in _shares add up to _totalShares. Otherwise, some of the spoils would be stuck in the smart contract and lost forever. For example, if you changed _totalShares to 200, half of every payment would be stuck in the smart contract.
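The effect of a mismatch is easy to check with the contract's formula. This Python sketch (the entitlements helper is illustrative, not part of the contract) computes how much of 100 ETH the six payees could ever withdraw when _totalShares is 100 and when it is 200.

```python
WEI_PER_ETH = 10**18


def entitlements(amount_wei, shares, total_shares):
    """Each payee's entitlement, using the contract's formula (illustrative helper)."""
    return [amount_wei * s // total_shares for s in shares]


shares = [19, 19, 19, 19, 19, 5]  # the shares hard-coded in __init__
received = 100 * WEI_PER_ETH

for total_shares in (100, 200):
    paid = sum(entitlements(received, shares, total_shares))
    stuck = received - paid
    print(f"_totalShares={total_shares}: payable {paid / WEI_PER_ETH} ETH, "
          f"stuck {stuck / WEI_PER_ETH} ETH")
```

With 200 the payees can only ever claim 50 ETH between them, and nothing in the contract can move the other 50 ETH.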

Let's look at how you withdraw your payment to your wallet. This is the release function:


@external
def release(account: address):
    assert self._shares[account] > 0, "Account has no shares"

    totalReceived: uint256 = self.balance + self._totalReleased
    payment: uint256 = totalReceived * self._shares[account] / self._totalShares - self._released[account]

    assert payment != 0, "Account is not due payment"

    self._released[account] = self._released[account] + payment
    self._totalReleased = self._totalReleased + payment

    send(account, payment)
    log PaymentReleased(account, payment)

Say someone sends 100 ETH to the smart contract. The balance becomes 100 ETH, and _totalReleased is still 0. Because every release lowers the balance, release adds _totalReleased back to self.balance to get totalReceived, the total amount the contract has ever received. It then takes your share of that total, subtracts what you have already withdrawn (nothing yet), and puts the result in payment. Then it adds payment to _released for your account and to _totalReleased. This means that if another payment arrives and you withdraw again, the money you withdrew in the past is taken into account. If another buyer sent 50 ETH, you could only withdraw 2.5 ETH, not 7.5 ETH (because you have already withdrawn 5 ETH). Finally, the smart contract sends the money to your account with send.
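To see this bookkeeping in action, here is a plain-Python model of __default__ and release that replays the 100 ETH and 50 ETH example. It is a sketch for illustration only: SplitterModel and its methods are hypothetical names, no real ETH moves, and Python's integer division (//) stands in for Vyper's uint256 division.

```python
WEI_PER_ETH = 10**18


class SplitterModel:
    """Plain-Python model of the contract's bookkeeping (hypothetical, no real ETH)."""

    def __init__(self, shares):
        self.shares = dict(shares)
        self.total_shares = sum(self.shares.values())
        self.balance = 0
        self.total_released = 0
        self.released = {account: 0 for account in self.shares}

    def receive(self, amount):
        # Mirrors __default__: incoming ETH just raises the balance.
        self.balance += amount

    def release(self, account):
        # Mirrors release(): rebuild the total ever received, take this
        # account's cut, and subtract what it has already withdrawn.
        assert self.shares.get(account, 0) > 0, "Account has no shares"
        total_received = self.balance + self.total_released
        payment = (total_received * self.shares[account] // self.total_shares
                   - self.released[account])
        assert payment != 0, "Account is not due payment"
        self.released[account] += payment
        self.total_released += payment
        self.balance -= payment  # stands in for send()
        return payment


model = SplitterModel({"you": 5, "p1": 19, "p2": 19, "p3": 19, "p4": 19, "p5": 19})
model.receive(100 * WEI_PER_ETH)
first = model.release("you")
model.receive(50 * WEI_PER_ETH)
second = model.release("you")
print(first / WEI_PER_ETH, second / WEI_PER_ETH)
```

With the 5% share, first comes out as 5 ETH and second as 2.5 ETH, matching the walkthrough.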

If you are wondering where the function that accepts payments from buyers is, it's __default__. It is marked @payable, so plain ETH transfers to the contract land there, and the function simply logs the payment with PaymentReceived.

The rest of the functions are simple getters and are self-explanatory.

Compile the smart contract.


(.venv) $ mamba compile

Now, let's add the test. Create test/test_payment_splitter.py. Add the following code to it:


from black_mamba.testlib import contract, eth_tester, TestContract
import pytest
from eth_tester.exceptions import TransactionFailed
import web3


class TestPaymentSplitter(TestContract):

    def test_init(self, eth_tester):
        accounts = eth_tester.get_accounts()
        ps_contract = contract("PaymentSplitter", [])

        assert ps_contract.functions.totalShares().call() == 100
        assert ps_contract.functions.totalReleased().call() == 0
        assert ps_contract.functions.shares(accounts[0]).call() == 19
        assert ps_contract.functions.shares(accounts[1]).call() == 19
        assert ps_contract.functions.shares(accounts[2]).call() == 19
        assert ps_contract.functions.shares(accounts[3]).call() == 19
        assert ps_contract.functions.shares(accounts[4]).call() == 19
        assert ps_contract.functions.shares(accounts[5]).call() == 5
        assert ps_contract.functions.payee(0).call() == accounts[0]
        assert ps_contract.functions.payee(1).call() == accounts[1]
        assert ps_contract.functions.payee(2).call() == accounts[2]
        assert ps_contract.functions.payee(3).call() == accounts[3]
        assert ps_contract.functions.payee(4).call() == accounts[4]
        assert ps_contract.functions.payee(5).call() == accounts[5]
        assert ps_contract.functions.released(accounts[0]).call() == 0
        assert ps_contract.functions.released(accounts[1]).call() == 0
        assert ps_contract.functions.released(accounts[2]).call() == 0
        assert ps_contract.functions.released(accounts[3]).call() == 0
        assert ps_contract.functions.released(accounts[4]).call() == 0
        assert ps_contract.functions.released(accounts[5]).call() == 0
        log = ps_contract.events.PayeeAdded.getLogs()
        assert log[0]["args"]["account"] == accounts[0]
        assert log[0]["args"]["shares"] == 19
        assert log[1]["args"]["account"] == accounts[1]
        assert log[1]["args"]["shares"] == 19
        assert log[2]["args"]["account"] == accounts[2]
        assert log[2]["args"]["shares"] == 19
        assert log[3]["args"]["account"] == accounts[3]
        assert log[3]["args"]["shares"] == 19
        assert log[4]["args"]["account"] == accounts[4]
        assert log[4]["args"]["shares"] == 19
        assert log[5]["args"]["account"] == accounts[5]
        assert log[5]["args"]["shares"] == 5


    def test_release(self, eth_tester):
        accounts = eth_tester.get_accounts()
        ps_contract = contract("PaymentSplitter", [])

        ether = web3.Web3.toWei(100, "ether")
        total_ether = web3.Web3.toWei(1000019, "ether")
        total_ether_lower_bound = web3.Web3.toWei(1000018, "ether")
        total_ether2 = web3.Web3.toWei(1000005, "ether")
        released_ether = web3.Web3.toWei(19, "ether")
        released_ether2 = web3.Web3.toWei(5, "ether")

        ps_contract.web3.eth.sendTransaction({"to": ps_contract.address, "from": accounts[6], "value": ether})

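        # release() is sent from accounts[7], so payees spend no gas on it. accounts[0]
        # presumably paid gas to deploy the contract, so its balance is only bounded.
        assert ps_contract.functions.release(accounts[0]).transact({ "from": accounts[7] })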
        assert ps_contract.web3.eth.get_balance(accounts[0]) < total_ether
        assert ps_contract.web3.eth.get_balance(accounts[0]) > total_ether_lower_bound
        assert ps_contract.functions.totalReleased().call() == released_ether

        assert ps_contract.functions.release(accounts[1]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[1]) == total_ether
        assert ps_contract.functions.totalReleased().call() == released_ether * 2

        assert ps_contract.functions.release(accounts[2]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[2]) == total_ether
        assert ps_contract.functions.totalReleased().call() == released_ether * 3

        assert ps_contract.functions.release(accounts[3]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[3]) == total_ether
        assert ps_contract.functions.totalReleased().call() == released_ether * 4

        assert ps_contract.functions.release(accounts[4]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[4]) == total_ether
        assert ps_contract.functions.totalReleased().call() == released_ether * 5

        assert ps_contract.functions.release(accounts[5]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[5]) == total_ether2
        assert ps_contract.functions.totalReleased().call() == released_ether * 5 + released_ether2

        assert ps_contract.functions.released(accounts[0]).call() == released_ether
        assert ps_contract.functions.released(accounts[1]).call() == released_ether
        assert ps_contract.functions.released(accounts[2]).call() == released_ether
        assert ps_contract.functions.released(accounts[3]).call() == released_ether
        assert ps_contract.functions.released(accounts[4]).call() == released_ether
        assert ps_contract.functions.released(accounts[5]).call() == released_ether2


    def test_release_second_times(self, eth_tester):
        accounts = eth_tester.get_accounts()
        ps_contract = contract("PaymentSplitter", [])

        ether = web3.Web3.toWei(100, "ether")
        total_ether = web3.Web3.toWei(1000019, "ether")
        total_ether2 = web3.Web3.toWei(1000038, "ether")
        released_ether = web3.Web3.toWei(19, "ether")

        ps_contract.web3.eth.sendTransaction({"to": ps_contract.address, "from": accounts[6], "value": ether})

        assert ps_contract.functions.release(accounts[1]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[1]) == total_ether
        assert ps_contract.functions.totalReleased().call() == released_ether

        ps_contract.web3.eth.sendTransaction({"to": ps_contract.address, "from": accounts[6], "value": ether})

        assert ps_contract.functions.release(accounts[1]).transact({ "from": accounts[7] })
        assert ps_contract.web3.eth.get_balance(accounts[1]) == total_ether2
        assert ps_contract.functions.totalReleased().call() == released_ether * 2

This test file has three test methods. The first tests __init__. The second tests release. The third tests release across two rounds: a payment, a withdrawal from an account, another payment, and another withdrawal from the same account. The expected balances, such as 1000019 ETH, come from the 1,000,000 ETH that eth-tester gives each test account plus the released share. You can run the tests like this:


(.venv) $ py.test test/test_payment_splitter.py

If you want to modify this smart contract to suit your needs, start with __init__. For example, you can create an ICO smart contract and split the raised capital with this design pattern.
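Because the agreement is hard-coded, the __init__ lines are easy to get wrong by hand. One option is to generate them. The helper below, vyper_init_body, is hypothetical (it is not part of Vyper or Black Mamba): it takes (address, shares) pairs, refuses a list whose shares do not add up to the total, and returns lines you can paste into __init__. You would also need to resize the `_payees: address[6]` declaration to match the number of payees.

```python
def vyper_init_body(payees, total_shares=100):
    """Render the hard-coded body of a PaymentSplitter __init__ (hypothetical helper).

    payees: list of (address, shares) pairs. Its length must match the size
    of the _payees array declared in the contract.
    """
    if sum(s for _, s in payees) != total_shares:
        raise ValueError("shares must add up to total_shares")
    lines = []
    for i, (address, _) in enumerate(payees):
        lines.append(f"    self._payees[{i}] = {address}")
    for i, (_, s) in enumerate(payees):
        lines.append(f"    self._shares[self._payees[{i}]] = {s}")
    lines.append(f"    self._totalShares = {total_shares}")
    lines.append("")
    for i in range(len(payees)):
        lines.append(f"    log PayeeAdded(self._payees[{i}], self._shares[self._payees[{i}]])")
    return "\n".join(lines)


# Example: a two-person 60/40 split (the contract would then need address[2]).
print(vyper_init_body([
    ("0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf", 60),
    ("0x2B5AD5c4795c026514f8317c7a215E218DcCD6cF", 40),
]))
```

Checking the sum before generating anything catches the stuck-funds mistake before the contract is ever deployed.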

Follow me on Twitter: @arjunaskykok.